diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE new file mode 100644 index 0000000000000..989e95ccd0135 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE @@ -0,0 +1,12 @@ +## What changes were proposed in this pull request? + +(Please fill in changes proposed in this fix) + + +## How was this patch tested? + +(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) + + +(If this patch involves UI changes, please attach a screenshot; otherwise, remove this) + diff --git a/.gitignore b/.gitignore index debad77ec2ad3..05afbb5e5ed69 100644 --- a/.gitignore +++ b/.gitignore @@ -17,8 +17,6 @@ cache work/ out/ .DS_Store -third_party/libmesos.so -third_party/libmesos.dylib build/apache-maven* build/zinc* build/scala* @@ -50,6 +48,7 @@ spark-tests.log streaming-tests.log dependency-reduced-pom.xml .ensime +.ensime_cache/ .ensime_lucene checkpoint derby.log @@ -59,8 +58,6 @@ dev/create-release/*final spark-*-bin-*.tgz unit-tests.log /lib/ -ec2/lib/ -rat-results.txt scalastyle.txt scalastyle-output.xml R-unit-tests.log @@ -74,3 +71,7 @@ metastore/ warehouse/ TempStatsStore/ sql/hive-thriftserver/test_warehouses + +# For R session data +.RHistory +.RData diff --git a/.rat-excludes b/.rat-excludes deleted file mode 100644 index 08fba6d351d6a..0000000000000 --- a/.rat-excludes +++ /dev/null @@ -1,85 +0,0 @@ -target -cache -.gitignore -.gitattributes -.project -.classpath -.mima-excludes -.generated-mima-excludes -.generated-mima-class-excludes -.generated-mima-member-excludes -.rat-excludes -.*md -derby.log -TAGS -RELEASE -control -docs -slaves -spark-env.cmd -bootstrap-tooltip.js -jquery-1.11.1.min.js -d3.min.js -dagre-d3.min.js -graphlib-dot.min.js -sorttable.js -vis.min.js -vis.min.css -.*avsc -.*txt -.*json -.*data -.*log -cloudpickle.py -heapq3.py -join.py -SparkExprTyper.scala -SparkILoop.scala -SparkILoopInit.scala -SparkIMain.scala -SparkImports.scala -SparkJLineCompletion.scala -SparkJLineReader.scala -SparkMemberHandlers.scala -SparkReplReporter.scala -sbt -sbt-launch-lib.bash -plugins.sbt -work -.*\.q -.*\.qv -golden -test.out/* -.*iml -service.properties -db.lck -build/* -dist/* -.*out -.*ipr -.*iws -logs -.*scalastyle-output.xml -.*dependency-reduced-pom.xml -known_translations -json_expectation -local-1422981759269/* -local-1422981780767/* -local-1425081759269/* -local-1426533911241/* -local-1426633911242/* -local-1430917381534/* -local-1430917381535_1 -local-1430917381535_2 -DESCRIPTION -NAMESPACE -test_support/* -.*Rd -help/* -html/* -INDEX -.lintr -gen-java.* -.*avpr -org.apache.spark.sql.sources.DataSourceRegister -.*parquet diff --git a/LICENSE b/LICENSE index 0db2d14465bd3..9714b3b1e4d17 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,3 @@ - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ @@ -237,9 +236,9 @@ The following components are provided under a BSD-style license. See project lin The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) - (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) + (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) + (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) @@ -250,22 +249,21 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.10.5 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.10.5 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.10.5 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.10.5 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.10.5 - http://www.scala-lang.org/) - (BSD-style) scalacheck (org.scalacheck:scalacheck_2.10:1.10.0 - http://www.scalacheck.org) - (BSD-style) spire (org.spire-math:spire_2.10:0.7.1 - http://spire-math.org) - (BSD-style) spire-macros (org.spire-math:spire-macros_2.10:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware.kryo:kryo:2.21 - http://code.google.com/p/kryo/) - (New BSD License) MinLog (com.esotericsoftware.minlog:minlog:1.2 - http://code.google.com/p/minlog/) - (New BSD License) ReflectASM (com.esotericsoftware.reflectasm:reflectasm:1.07 - http://code.google.com/p/reflectasm/) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.11.7 - http://www.scala-lang.org/) + (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) + (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) + (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) + (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) + (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.9 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.9.2 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) @@ -284,7 +282,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) SLF4J API Module (org.slf4j:slf4j-api:1.7.5 - http://www.slf4j.org) (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) - (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) + (MIT License) scopt (com.github.scopt:scopt_2.11:3.2.0 - https://github.com/scopt/scopt) (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) @@ -292,3 +290,9 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) + (MIT License) datatables (http://datatables.net/license) + (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) + (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) + (MIT License) blockUI (http://jquery.malsup.com/block/) + (MIT License) RowsGroup (http://datatables.net/license/mit) + (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) diff --git a/NOTICE b/NOTICE index 7f7769f73047f..2a6fe237dcbea 100644 --- a/NOTICE +++ b/NOTICE @@ -48,7 +48,6 @@ Eclipse Public License 1.0 The following components are provided under the Eclipse Public License 1.0. See project link for details. - (Eclipse Public License - Version 1.0) mqtt-client (org.eclipse.paho:mqtt-client:0.4.0 - http://www.eclipse.org/paho/mqtt-client) (Eclipse Public License v1.0) Eclipse JDT Core (org.eclipse.jdt:core:3.1.1 - http://www.eclipse.org/jdt/) ======================================================================== @@ -606,4 +605,63 @@ Vis.js uses and redistributes the following third-party libraries: - keycharm https://github.com/AlexDM0/keycharm - The MIT License \ No newline at end of file + The MIT License + +=============================================================================== + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. +=============================================================================== + +For CSV functionality: + +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2015 Ayasdi Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +=============================================================================== +For dev/sparktestsupport/toposort.py: + +Copyright 2014 True Blade Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/R/README.md b/R/README.md index 005f56da1670c..810bfc14e977e 100644 --- a/R/README.md +++ b/R/README.md @@ -1,6 +1,16 @@ # R on Spark SparkR is an R package that provides a light-weight frontend to use Spark from R. +### Installing sparkR + +Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. +By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` the full path of the base directory where R is installed, before running install-dev.sh script. +Example: +``` +# where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript +export R_HOME=/home/username/R +./install-dev.sh +``` ### SparkR development @@ -30,7 +40,7 @@ To set other options like driver memory, executor memory etc. you can pass in th If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example ``` # Set this to where Spark is installed -Sys.setenv(SPARK_HOME="/Users/shivaram/spark") +Sys.setenv(SPARK_HOME="/Users/username/spark") # This line loads SparkR from the installed directory .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) library(SparkR) @@ -41,7 +51,7 @@ sc <- sparkR.init(master="local") The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. -Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below. +Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. #### Generating documentation @@ -50,9 +60,9 @@ The SparkR documentation (Rd files and HTML files) are not a part of the source ### Examples, Unit tests SparkR comes with several sample programs in the `examples/src/main/r` directory. -To run one of them, use `./bin/sparkR `. For example: +To run one of them, use `./bin/spark-submit `. For example: - ./bin/sparkR examples/src/main/r/dataframe.R + ./bin/spark-submit examples/src/main/r/dataframe.R You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): @@ -60,7 +70,7 @@ You can also run the unit-tests for SparkR by running (you need to install the [ ./R/run-tests.sh ### Running on YARN -The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run +The `./bin/spark-submit` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run ``` export YARN_CONF_DIR=/etc/hadoop/conf ./bin/spark-submit --master yarn examples/src/main/r/dataframe.R diff --git a/R/install-dev.bat b/R/install-dev.bat index 008a5c668bc45..ed1c91ae3a0ff 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,3 +25,9 @@ set SPARK_HOME=%~dp0.. MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ + +rem Zip the SparkR package so that it can be distributed to worker nodes on YARN +pushd %SPARK_HOME%\R\lib +%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR +popd + diff --git a/R/install-dev.sh b/R/install-dev.sh index 59d98c9c7a646..befd413c4cd26 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -35,11 +35,22 @@ LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR pushd $FWDIR > /dev/null +if [ ! -z "$R_HOME" ] + then + R_SCRIPT_PATH="$R_HOME/bin" + else + R_SCRIPT_PATH="$(dirname $(which R))" +fi +echo "USING R_HOME = $R_HOME" # Generate Rd files if devtools is installed -Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' +"$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' # Install SparkR to $LIB_DIR -R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ +"$R_SCRIPT_PATH/"R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ + +# Zip the SparkR package so that it can be distributed to worker nodes on YARN +cd $LIB_DIR +jar cfM "$LIB_DIR/sparkr.zip" SparkR popd > /dev/null diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 3d6edb70ec98e..7179438efc1d9 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.6.0 +Version: 2.0.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman @@ -11,17 +11,19 @@ Depends: R (>= 3.0), methods, Suggests: - testthat + testthat, + e1071, + survival Description: R frontend for Spark License: Apache License (== 2.0) Collate: 'schema.R' 'generics.R' 'jobj.R' - 'RDD.R' - 'pairRDD.R' 'column.R' 'group.R' + 'RDD.R' + 'pairRDD.R' 'DataFrame.R' 'SQLContext.R' 'backend.R' @@ -34,4 +36,6 @@ Collate: 'serialize.R' 'sparkR.R' 'stats.R' + 'types.R' 'utils.R' +RoxygenNote: 5.0.1 diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index cd9537a2655f0..f48c61c1d59c5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -13,7 +13,11 @@ export("print.jobj") # MLlib integration exportMethods("glm", "predict", - "summary") + "summary", + "kmeans", + "fitted", + "naiveBayes", + "survreg") # Job group lifecycle management methods export("setJobGroup", @@ -23,17 +27,26 @@ export("setJobGroup", exportClasses("DataFrame") exportMethods("arrange", + "as.data.frame", "attach", "cache", "collect", + "colnames", + "colnames<-", + "coltypes", + "coltypes<-", "columns", "count", "cov", "corr", + "covar_samp", + "covar_pop", "crosstab", "describe", "dim", "distinct", + "drop", + "dropDuplicates", "dropna", "dtypes", "except", @@ -54,6 +67,7 @@ exportMethods("arrange", "mutate", "na.omit", "names", + "names<-", "ncol", "nrow", "orderBy", @@ -83,9 +97,13 @@ exportMethods("arrange", "unique", "unpersist", "where", + "with", "withColumn", "withColumnRenamed", - "write.df") + "write.df", + "write.json", + "write.parquet", + "write.text") exportClasses("Column") @@ -95,6 +113,8 @@ exportMethods("%in%", "add_months", "alias", "approxCountDistinct", + "approxQuantile", + "array_contains", "asc", "ascii", "asin", @@ -119,15 +139,18 @@ exportMethods("%in%", "count", "countDistinct", "crc32", - "cumeDist", + "hash", + "cume_dist", "date_add", "date_format", "date_sub", "datediff", "dayofmonth", "dayofyear", - "denseRank", + "decode", + "dense_rank", "desc", + "encode", "endsWith", "exp", "explode", @@ -152,6 +175,7 @@ exportMethods("%in%", "isNaN", "isNotNull", "isNull", + "kurtosis", "lag", "last", "last_day", @@ -183,7 +207,7 @@ exportMethods("%in%", "next_day", "ntile", "otherwise", - "percentRank", + "percent_rank", "pmod", "quarter", "rand", @@ -195,7 +219,7 @@ exportMethods("%in%", "rint", "rlike", "round", - "rowNumber", + "row_number", "rpad", "rtrim", "second", @@ -204,12 +228,19 @@ exportMethods("%in%", "shiftLeft", "shiftRight", "shiftRightUnsigned", + "sd", "sign", "signum", "sin", "sinh", "size", + "skewness", + "sort_array", "soundex", + "stddev", + "stddev_pop", + "stddev_samp", + "struct", "sqrt", "startsWith", "substr", @@ -228,8 +259,13 @@ exportMethods("%in%", "unhex", "unix_timestamp", "upper", + "var", + "variance", + "var_pop", + "var_samp", "weekofyear", "when", + "window", "year") exportClasses("GroupedData") @@ -248,8 +284,12 @@ export("as.DataFrame", "loadDF", "parquetFile", "read.df", + "read.json", + "read.parquet", + "read.text", "sql", - "table", + "str", + "tableToDF", "tableNames", "tables", "uncacheTable") @@ -262,5 +302,3 @@ export("structField", "structType.jobj", "structType.structField", "print.structType") - -export("as.data.frame") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index df5bc8137187b..a64a013b654ef 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -24,14 +24,14 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame #' @description DataFrames can be created using functions like \link{createDataFrame}, -#' \link{jsonFile}, \link{table} etc. -#' @family dataframe_funcs +#' \link{read.json}, \link{table} etc. +#' @family DataFrame functions #' @rdname DataFrame #' @docType class #' #' @slot env An R environment that stores bookkeeping states of the DataFrame #' @slot sdf A Java object reference to the backing Scala DataFrame -#' @seealso \link{createDataFrame}, \link{jsonFile}, \link{table} +#' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} #' @export #' @examples @@ -68,7 +68,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname printSchema #' @name printSchema #' @export @@ -77,7 +77,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -93,7 +93,7 @@ setMethod("printSchema", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname schema #' @name schema #' @export @@ -102,7 +102,7 @@ setMethod("printSchema", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -117,7 +117,7 @@ setMethod("schema", #' #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname explain #' @name explain #' @export @@ -126,7 +126,7 @@ setMethod("schema", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -148,7 +148,7 @@ setMethod("explain", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname isLocal #' @name isLocal #' @export @@ -157,7 +157,7 @@ setMethod("explain", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -173,7 +173,7 @@ setMethod("isLocal", #' @param x A SparkSQL DataFrame #' @param numRows The number of rows to print. Defaults to 20. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname showDF #' @name showDF #' @export @@ -182,7 +182,7 @@ setMethod("isLocal", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' showDF(df) #'} setMethod("showDF", @@ -198,7 +198,7 @@ setMethod("showDF", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname show #' @name show #' @export @@ -207,7 +207,7 @@ setMethod("showDF", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -225,7 +225,7 @@ setMethod("show", "DataFrame", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname dtypes #' @name dtypes #' @export @@ -234,7 +234,7 @@ setMethod("show", "DataFrame", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -251,18 +251,19 @@ setMethod("dtypes", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname columns #' @name columns -#' @aliases names + #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' columns(df) +#' colnames(df) #'} setMethod("columns", signature(x = "DataFrame"), @@ -272,7 +273,6 @@ setMethod("columns", }) }) -#' @family dataframe_funcs #' @rdname columns #' @name names setMethod("names", @@ -281,7 +281,6 @@ setMethod("names", columns(x) }) -#' @family dataframe_funcs #' @rdname columns #' @name names<- setMethod("names<-", @@ -293,6 +292,141 @@ setMethod("names<-", } }) +#' @rdname columns +#' @name colnames +setMethod("colnames", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' @rdname columns +#' @name colnames<- +setMethod("colnames<-", + signature(x = "DataFrame"), + function(x, value) { + + # Check parameter integrity + if (class(value) != "character") { + stop("Invalid column names.") + } + + if (length(value) != ncol(x)) { + stop( + "Column names must have the same length as the number of columns in the dataset.") + } + + if (any(is.na(value))) { + stop("Column names cannot be NA.") + } + + # Check if the column names have . in it + if (any(regexec(".", value, fixed = TRUE)[[1]][1] != -1)) { + stop("Colum names cannot contain the '.' symbol.") + } + + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) + dataFrame(sdf) + }) + +#' coltypes +#' +#' Get column types of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @return value A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @name coltypes +#' @family DataFrame functions +#' @export +#' @examples +#'\dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#'} +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES = F, FUN = function(x) { + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) + +#' coltypes +#' +#' Set the column types of a DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @param value A character vector with the target column types for the given +#' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA +#' to keep that column as-is. +#' @rdname coltypes +#' @name coltypes<- +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' coltypes(df) <- c("character", "integer") +#' coltypes(df) <- c(NA, "numeric") +#'} +setMethod("coltypes<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + cols <- columns(x) + ncols <- length(cols) + if (length(value) == 0) { + stop("Cannot set types of an empty DataFrame with no Column") + } + if (length(value) != ncols) { + stop("Length of type vector should match the number of columns for DataFrame") + } + newCols <- lapply(seq_len(ncols), function(i) { + col <- getColumn(x, cols[i]) + if (!is.na(value[i])) { + stype <- rToSQLTypes[[value[i]]] + if (is.null(stype)) { + stop("Only atomic type is supported for column types") + } + cast(col, stype) + } else { + col + } + }) + nx <- select(x, newCols) + dataFrame(nx@sdf) + }) + #' Register Temporary Table #' #' Registers a DataFrame as a Temporary Table in the SQLContext @@ -300,7 +434,7 @@ setMethod("names<-", #' @param x A SparkSQL DataFrame #' @param tableName A character vector containing the name of the table #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname registerTempTable #' @name registerTempTable #' @export @@ -309,7 +443,7 @@ setMethod("names<-", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "json_df") #' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} @@ -328,7 +462,7 @@ setMethod("registerTempTable", #' @param overwrite A logical argument indicating whether or not to overwrite #' the existing rows in the table. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname insertInto #' @name insertInto #' @export @@ -344,7 +478,10 @@ setMethod("registerTempTable", setMethod("insertInto", signature(x = "DataFrame", tableName = "character"), function(x, tableName, overwrite = FALSE) { - callJMethod(x@sdf, "insertInto", tableName, overwrite) + jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) + write <- callJMethod(x@sdf, "write") + write <- callJMethod(write, "mode", jmode) + callJMethod(write, "insertInto", tableName) }) #' Cache @@ -353,7 +490,7 @@ setMethod("insertInto", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname cache #' @name cache #' @export @@ -362,7 +499,7 @@ setMethod("insertInto", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -381,7 +518,7 @@ setMethod("cache", #' #' @param x The DataFrame to persist #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname persist #' @name persist #' @export @@ -390,7 +527,7 @@ setMethod("cache", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -409,7 +546,7 @@ setMethod("persist", #' @param x The DataFrame to unpersist #' @param blocking Whether to block until all blocks are deleted #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname unpersist-methods #' @name unpersist #' @export @@ -418,7 +555,7 @@ setMethod("persist", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -437,7 +574,7 @@ setMethod("unpersist", #' @param x A SparkSQL DataFrame #' @param numPartitions The number of partitions to use. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname repartition #' @name repartition #' @export @@ -446,7 +583,7 @@ setMethod("unpersist", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", @@ -456,25 +593,24 @@ setMethod("repartition", dataFrame(sdf) }) -# toJSON -# -# Convert the rows of a DataFrame into JSON objects and return an RDD where -# each element contains a JSON string. -# -# @param x A SparkSQL DataFrame -# @return A StringRRDD of JSON objects -# -# @family dataframe_funcs -# @rdname tojson -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# path <- "path/to/file.json" -# df <- jsonFile(sqlContext, path) -# newRDD <- toJSON(df) -#} +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @family DataFrame functions +#' @rdname tojson +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' newRDD <- toJSON(df) +#'} setMethod("toJSON", signature(x = "DataFrame"), function(x) { @@ -483,30 +619,97 @@ setMethod("toJSON", RDD(jrdd, serializedMode = "string") }) -#' saveAsParquetFile +#' write.json +#' +#' Save the contents of a DataFrame as a JSON file (one object per line). Files written out +#' with this method can be read back in as a DataFrame using read.json(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' +#' @family DataFrame functions +#' @rdname write.json +#' @name write.json +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' write.json(df, "/tmp/sparkr-tmp/") +#'} +setMethod("write.json", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "json", path)) + }) + +#' write.parquet #' #' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out -#' with this method can be read back in as a DataFrame using parquetFile(). +#' with this method can be read back in as a DataFrame using read.parquet(). #' #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved #' -#' @family dataframe_funcs -#' @rdname saveAsParquetFile -#' @name saveAsParquetFile +#' @family DataFrame functions +#' @rdname write.parquet +#' @name write.parquet #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) -#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#' df <- read.json(sqlContext, path) +#' write.parquet(df, "/tmp/sparkr-tmp1/") +#' saveAsParquetFile(df, "/tmp/sparkr-tmp2/") #'} +setMethod("write.parquet", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "parquet", path)) + }) + +#' @rdname write.parquet +#' @name saveAsParquetFile +#' @export setMethod("saveAsParquetFile", signature(x = "DataFrame", path = "character"), function(x, path) { - invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + .Deprecated("write.parquet") + write.parquet(x, path) + }) + +#' write.text +#' +#' Saves the content of the DataFrame in a text file at the specified path. +#' The DataFrame must have only one column of string type with the name "value". +#' Each row becomes a new line in the output file. +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' +#' @family DataFrame functions +#' @rdname write.text +#' @name write.text +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.txt" +#' df <- read.text(sqlContext, path) +#' write.text(df, "/tmp/sparkr-tmp/") +#'} +setMethod("write.text", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "text", path)) }) #' Distinct @@ -515,7 +718,7 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname distinct #' @name distinct #' @export @@ -524,7 +727,7 @@ setMethod("saveAsParquetFile", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -534,14 +737,8 @@ setMethod("distinct", dataFrame(sdf) }) -#' @title Distinct rows in a DataFrame -# -#' @description Returns a new DataFrame containing distinct rows in this DataFrame -#' -#' @family dataframe_funcs -#' @rdname unique +#' @rdname distinct #' @name unique -#' @aliases distinct setMethod("unique", signature(x = "DataFrame"), function(x) { @@ -555,58 +752,61 @@ setMethod("unique", #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname sample -#' @aliases sample_frac +#' @name sample #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", - # TODO : Figure out how to send integer as java.lang.Long to JVM so - # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { + function(x, withReplacement, fraction, seed) { if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + if (!missing(seed)) { + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + } else { + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + } dataFrame(sdf) }) -#' @family dataframe_funcs #' @rdname sample #' @name sample_frac setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { - sample(x, withReplacement, fraction) + function(x, withReplacement, fraction, seed) { + sample(x, withReplacement, fraction, seed) }) -#' Count +#' nrow #' #' Returns the number of rows in a DataFrame #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs -#' @rdname count +#' @family DataFrame functions +#' @rdname nrow #' @name count -#' @aliases nrow #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' count(df) #' } setMethod("count", @@ -615,14 +815,8 @@ setMethod("count", callJMethod(x@sdf, "count") }) -#' @title Number of rows for a DataFrame -#' @description Returns number of rows in a DataFrames -#' #' @name nrow -#' -#' @family dataframe_funcs #' @rdname nrow -#' @aliases count setMethod("nrow", signature(x = "DataFrame"), function(x) { @@ -633,7 +827,7 @@ setMethod("nrow", #' #' @param x a SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname ncol #' @name ncol #' @export @@ -642,7 +836,7 @@ setMethod("nrow", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' ncol(df) #' } setMethod("ncol", @@ -654,7 +848,7 @@ setMethod("ncol", #' Returns the dimentions (number of rows and columns) of a DataFrame #' @param x a SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname dim #' @name dim #' @export @@ -663,7 +857,7 @@ setMethod("ncol", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dim(df) #' } setMethod("dim", @@ -678,7 +872,7 @@ setMethod("dim", #' @param stringsAsFactors (Optional) A logical indicating whether or not string columns #' should be converted to factors. FALSE by default. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname collect #' @name collect #' @export @@ -687,15 +881,15 @@ setMethod("dim", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } setMethod("collect", signature(x = "DataFrame"), function(x, stringsAsFactors = FALSE) { - names <- columns(x) - ncol <- length(names) + dtypes <- dtypes(x) + ncol <- length(dtypes) if (ncol <= 0) { # empty data.frame with 0 columns and 0 rows data.frame() @@ -718,25 +912,29 @@ setMethod("collect", # data of complex type can be held. But getting a cell from a column # of list type returns a list instead of a vector. So for columns of # non-complex type, append them as vector. + # + # For columns of complex type, be careful to access them. + # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. col <- listCols[[colIndex]] if (length(col) <= 0) { - df[[names[colIndex]]] <- col + df[[colIndex]] <- col } else { - # TODO: more robust check on column of primitive types - vec <- do.call(c, col) - if (class(vec) != "list") { - df[[names[colIndex]]] <- vec + colType <- dtypes[[colIndex]][[2]] + # Note that "binary" columns behave like complex types. + if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { + vec <- do.call(c, col) + stopifnot(class(vec) != "list") + df[[colIndex]] <- vec } else { - # For columns of complex type, be careful to access them. - # Get a column of complex type returns a list. - # Get a cell from a column of complex type returns a list instead of a vector. - df[[names[colIndex]]] <- col - } + df[[colIndex]] <- col + } + } } + names(df) <- names(x) + df } - df - } - }) + }) #' Limit #' @@ -746,7 +944,7 @@ setMethod("collect", #' @param num The number of rows to return #' @return A new DataFrame containing the number of rows specified. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname limit #' @name limit #' @export @@ -755,7 +953,7 @@ setMethod("collect", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -767,7 +965,7 @@ setMethod("limit", #' Take the first NUM rows of a DataFrame and return a the results as a data.frame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname take #' @name take #' @export @@ -776,7 +974,7 @@ setMethod("limit", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -796,7 +994,7 @@ setMethod("take", #' @param num The number of rows to return. Default is 6. #' @return A data.frame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname head #' @name head #' @export @@ -805,7 +1003,7 @@ setMethod("take", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' head(df) #' } setMethod("head", @@ -819,7 +1017,7 @@ setMethod("head", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname first #' @name first #' @export @@ -828,7 +1026,7 @@ setMethod("head", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' first(df) #' } setMethod("first", @@ -837,23 +1035,21 @@ setMethod("first", take(x, 1) }) -# toRDD -# -# Converts a Spark DataFrame to an RDD while preserving column names. -# -# @param x A Spark DataFrame -# -# @family dataframe_funcs -# @rdname DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# path <- "path/to/file.json" -# df <- jsonFile(sqlContext, path) -# rdd <- toRDD(df) -# } +#' toRDD +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' rdd <- toRDD(df) +#'} setMethod("toRDD", signature(x = "DataFrame"), function(x) { @@ -873,8 +1069,7 @@ setMethod("toRDD", #' @param x a DataFrame #' @return a GroupedData #' @seealso GroupedData -#' @aliases group_by -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname groupBy #' @name groupBy #' @export @@ -899,7 +1094,6 @@ setMethod("groupBy", groupedData(sgd) }) -#' @family dataframe_funcs #' @rdname groupBy #' @name group_by setMethod("group_by", @@ -913,10 +1107,9 @@ setMethod("group_by", #' Compute aggregates by specifying a list of columns #' #' @param x a DataFrame -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname agg #' @name agg -#' @aliases summarize #' @export setMethod("agg", signature(x = "DataFrame"), @@ -924,7 +1117,6 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @family dataframe_funcs #' @rdname agg #' @name summarize setMethod("summarize", @@ -940,8 +1132,8 @@ setMethod("summarize", # the requested map function. # ################################################################################### -# @family dataframe_funcs -# @rdname lapply +#' @rdname lapply +#' @noRd setMethod("lapply", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { @@ -949,24 +1141,25 @@ setMethod("lapply", lapply(rdd, FUN) }) -# @family dataframe_funcs -# @rdname lapply +#' @rdname lapply +#' @noRd setMethod("map", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { lapply(X, FUN) }) -# @family dataframe_funcs -# @rdname flatMap +#' @rdname flatMap +#' @noRd setMethod("flatMap", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { rdd <- toRDD(X) flatMap(rdd, FUN) }) -# @family dataframe_funcs -# @rdname lapplyPartition + +#' @rdname lapplyPartition +#' @noRd setMethod("lapplyPartition", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { @@ -974,16 +1167,16 @@ setMethod("lapplyPartition", lapplyPartition(rdd, FUN) }) -# @family dataframe_funcs -# @rdname lapplyPartition +#' @rdname lapplyPartition +#' @noRd setMethod("mapPartitions", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { lapplyPartition(X, FUN) }) -# @family dataframe_funcs -# @rdname foreach +#' @rdname foreach +#' @noRd setMethod("foreach", signature(x = "DataFrame", func = "function"), function(x, func) { @@ -991,8 +1184,8 @@ setMethod("foreach", foreach(rdd, func) }) -# @family dataframe_funcs -# @rdname foreach +#' @rdname foreach +#' @noRd setMethod("foreachPartition", signature(x = "DataFrame", func = "function"), function(x, func) { @@ -1019,23 +1212,10 @@ setMethod("$", signature(x = "DataFrame"), setMethod("$<-", signature(x = "DataFrame"), function(x, name, value) { stopifnot(class(value) == "Column" || is.null(value)) - cols <- columns(x) - if (name %in% cols) { - if (is.null(value)) { - cols <- Filter(function(c) { c != name }, cols) - } - cols <- lapply(cols, function(c) { - if (c == name) { - alias(value, name) - } else { - col(c) - } - }) - nx <- select(x, cols) + + if (is.null(value)) { + nx <- drop(x, name) } else { - if (is.null(value)) { - return(x) - } nx <- withColumn(x, name, value) } x@sdf <- nx@sdf @@ -1087,14 +1267,13 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' #' Return subsets of DataFrame according to given conditions #' @param x A DataFrame -#' @param subset A logical expression to filter on rows +#' @param subset (Optional) A logical expression to filter on rows #' @param select expression for the single Column or a list of columns to select from the DataFrame #' @return A new DataFrame containing only the rows that meet the condition with selected columns #' @export -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname subset #' @name subset -#' @aliases [ #' @family subsetting functions #' @examples #' \dontrun{ @@ -1109,10 +1288,15 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' df[df$age %in% c(19, 30), 1:2] #' subset(df, df$age %in% c(19, 30), 1:2) #' subset(df, df$age %in% c(19), select = c(1,2)) +#' subset(df, select = c(1,2)) #' } setMethod("subset", signature(x = "DataFrame"), function(x, subset, select, ...) { - x[subset, select, ...] + if (missing(subset)) { + x[, select, ...] + } else { + x[subset, select, ...] + } }) #' Select @@ -1122,7 +1306,7 @@ setMethod("subset", signature(x = "DataFrame"), #' @param col A list of columns or single Column or name #' @return A new DataFrame with selected columns #' @export -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname select #' @name select #' @family subsetting functions @@ -1150,7 +1334,7 @@ setMethod("select", signature(x = "DataFrame", col = "character"), } }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname select #' @export setMethod("select", signature(x = "DataFrame", col = "Column"), @@ -1162,7 +1346,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), dataFrame(sdf) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname select #' @export setMethod("select", @@ -1187,7 +1371,7 @@ setMethod("select", #' @param expr A string containing a SQL expression #' @param ... Additional expressions #' @return A DataFrame -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname selectExpr #' @name selectExpr #' @export @@ -1196,7 +1380,7 @@ setMethod("select", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", @@ -1209,29 +1393,33 @@ setMethod("selectExpr", #' WithColumn #' -#' Return a new DataFrame with the specified column added. +#' Return a new DataFrame by adding a column or replacing the existing column +#' that has the same name. #' #' @param x A DataFrame -#' @param colName A string containing the name of the new column. +#' @param colName A column name. #' @param col A Column expression. -#' @return A DataFrame with the new column added. -#' @family dataframe_funcs +#' @return A DataFrame with the new column added or the existing column replaced. +#' @family DataFrame functions #' @rdname withColumn #' @name withColumn -#' @aliases mutate transform +#' @seealso \link{rename} \link{mutate} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) +#' # Replace an existing column +#' newDF2 <- withColumn(newDF, "newCol", newDF$col1) #' } setMethod("withColumn", signature(x = "DataFrame", colName = "character", col = "Column"), function(x, colName, col) { - select(x, x$"*", alias(col, colName)) + sdf <- callJMethod(x@sdf, "withColumn", colName, col@jc) + dataFrame(sdf) }) #' Mutate @@ -1241,17 +1429,17 @@ setMethod("withColumn", #' @param .data A DataFrame #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. -#' @family dataframe_funcs -#' @rdname withColumn +#' @family DataFrame functions +#' @rdname mutate #' @name mutate -#' @aliases withColumn transform +#' @seealso \link{rename} \link{withColumn} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) @@ -1275,17 +1463,15 @@ setMethod("mutate", }) #' @export -#' @family dataframe_funcs -#' @rdname withColumn +#' @rdname mutate #' @name transform -#' @aliases withColumn mutate setMethod("transform", signature(`_data` = "DataFrame"), function(`_data`, ...) { mutate(`_data`, ...) }) -#' WithColumnRenamed +#' rename #' #' Rename an existing column in a DataFrame. #' @@ -1293,16 +1479,17 @@ setMethod("transform", #' @param existingCol The name of the column you want to change. #' @param newCol The new column name. #' @return A DataFrame with the column name changed. -#' @family dataframe_funcs -#' @rdname withColumnRenamed +#' @family DataFrame functions +#' @rdname rename #' @name withColumnRenamed +#' @seealso \link{mutate} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1318,24 +1505,16 @@ setMethod("withColumnRenamed", select(x, cols) }) -#' Rename -#' -#' Rename an existing column in a DataFrame. -#' -#' @param x A DataFrame -#' @param newCol A named pair of the form new_column_name = existing_column -#' @return A DataFrame with the column name changed. -#' @family dataframe_funcs -#' @rdname withColumnRenamed +#' @param newColPair A named pair of the form new_column_name = existing_column +#' @rdname rename #' @name rename -#' @aliases withColumnRenamed #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1370,17 +1549,16 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param decreasing A logical argument indicating sorting order for columns when #' a character vector is specified for col #' @return A DataFrame where all elements are sorted. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname arrange #' @name arrange -#' @aliases orderby #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' arrange(df, df$col1) #' arrange(df, asc(df$col1), desc(abs(df$col2))) #' arrange(df, "col1", decreasing = TRUE) @@ -1397,8 +1575,8 @@ setMethod("arrange", dataFrame(sdf) }) -#' @family dataframe_funcs #' @rdname arrange +#' @name arrange #' @export setMethod("arrange", signature(x = "DataFrame", col = "character"), @@ -1429,9 +1607,9 @@ setMethod("arrange", do.call("arrange", c(x, jcols)) }) -#' @family dataframe_funcs #' @rdname arrange -#' @name orderby +#' @name orderBy +#' @export setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1446,7 +1624,7 @@ setMethod("orderBy", #' @param condition The condition to filter on. This may either be a Column expression #' or a string containing a SQL statement #' @return A DataFrame containing only the rows that meet the condition. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname filter #' @name filter #' @family subsetting functions @@ -1456,7 +1634,7 @@ setMethod("orderBy", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } @@ -1470,7 +1648,7 @@ setMethod("filter", dataFrame(sdf) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname filter #' @name where setMethod("where", @@ -1479,6 +1657,36 @@ setMethod("where", filter(x, condition) }) +#' dropDuplicates +#' +#' Returns a new DataFrame with duplicate rows removed, considering only +#' the subset of columns. +#' +#' @param x A DataFrame. +#' @param colnames A character vector of column names. +#' @return A DataFrame with duplicate rows removed. +#' @family DataFrame functions +#' @rdname dropduplicates +#' @name dropDuplicates +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' dropDuplicates(df) +#' dropDuplicates(df, c("col1", "col2")) +#' } +setMethod("dropDuplicates", + signature(x = "DataFrame"), + function(x, colNames = columns(x)) { + stopifnot(class(colNames) == "character") + + sdf <- callJMethod(x@sdf, "dropDuplicates", as.list(colNames)) + dataFrame(sdf) + }) + #' Join #' #' Join two DataFrames based on the given join expression. @@ -1491,16 +1699,17 @@ setMethod("where", #' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', #' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname join #' @name join +#' @seealso \link{merge} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1530,9 +1739,7 @@ setMethod("join", dataFrame(sdf) }) -#' #' @name merge -#' @aliases join #' @title Merges two data frames #' @param x the first data frame to be joined #' @param y the second data frame to be joined @@ -1550,15 +1757,16 @@ setMethod("join", #' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right #' outer join will be returned. If all.x and all.y are set to TRUE, a full #' outer join will be returned. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname merge +#' @seealso \link{join} #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' merge(df1, df2) # Performs a Cartesian #' merge(df1, df2, by = "col1") # Performs an inner join based on expression #' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE) @@ -1571,7 +1779,7 @@ setMethod("merge", signature(x = "DataFrame", y = "DataFrame"), function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by, all = FALSE, all.x = all, all.y = all, - sort = TRUE, suffixes = c("_x","_y"), ... ) { + sort = TRUE, suffixes = c("_x", "_y"), ... ) { if (length(suffixes) != 2) { stop("suffixes must have length 2") @@ -1673,7 +1881,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { cols } -#' UnionAll +#' rbind #' #' Return a new DataFrame containing the union of rows in this DataFrame #' and another DataFrame. This is equivalent to `UNION ALL` in SQL. @@ -1682,16 +1890,16 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. -#' @family dataframe_funcs -#' @rdname unionAll +#' @family DataFrame functions +#' @rdname rbind #' @name unionAll #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1702,13 +1910,11 @@ setMethod("unionAll", }) #' @title Union two or more DataFrames -#' #' @description Returns a new DataFrame containing rows of all parameters. #' -#' @family dataframe_funcs #' @rdname rbind #' @name rbind -#' @aliases unionAll +#' @export setMethod("rbind", signature(... = "DataFrame"), function(x, ..., deparse.level = 1) { @@ -1727,7 +1933,7 @@ setMethod("rbind", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the intersect. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname intersect #' @name intersect #' @export @@ -1735,8 +1941,8 @@ setMethod("rbind", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1754,7 +1960,7 @@ setMethod("intersect", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the except operation. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname except #' @name except #' @export @@ -1762,8 +1968,8 @@ setMethod("intersect", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1792,51 +1998,52 @@ setMethod("except", #' @param df A SparkSQL DataFrame #' @param path A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname write.df #' @name write.df -#' @aliases saveDF #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") #' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) #' } setMethod("write.df", signature(df = "DataFrame", path = "character"), - function(df, path, source = NULL, mode = "append", ...){ + function(df, path, source = NULL, mode = "error", ...){ if (is.null(source)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + } else if (exists(".sparkRHivesc", envir = .sparkREnv)) { + sqlContext <- get(".sparkRHivesc", envir = .sparkREnv) + } else { + stop("sparkRHive or sparkRSQL context has to be specified") + } source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } - allModes <- c("append", "overwrite", "error", "ignore") - # nolint start - if (!(mode %in% allModes)) { - stop('mode should be one of "append", "overwrite", "error", "ignore"') - } - # nolint end - jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) if (!is.null(path)) { options[["path"]] <- path } - callJMethod(df@sdf, "save", source, jmode, options) + write <- callJMethod(df@sdf, "write") + write <- callJMethod(write, "format", source) + write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "save", path) }) -#' @family dataframe_funcs #' @rdname write.df #' @name saveDF #' @export setMethod("saveDF", signature(df = "DataFrame", path = "character"), - function(df, path, source = NULL, mode = "append", ...){ + function(df, path, source = NULL, mode = "error", ...){ write.df(df, path, source, mode, ...) }) @@ -1859,9 +2066,9 @@ setMethod("saveDF", #' @param df A SparkSQL DataFrame #' @param tableName A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname saveAsTable #' @name saveAsTable #' @export @@ -1870,30 +2077,34 @@ setMethod("saveDF", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", - signature(df = "DataFrame", tableName = "character", source = "character", - mode = "character"), - function(df, tableName, source = NULL, mode="append", ...){ + signature(df = "DataFrame", tableName = "character"), + function(df, tableName, source = NULL, mode="error", ...){ if (is.null(source)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") - } - allModes <- c("append", "overwrite", "error", "ignore") - # nolint start - if (!(mode %in% allModes)) { - stop('mode should be one of "append", "overwrite", "error", "ignore"') + if (exists(".sparkRSQLsc", envir = .sparkREnv)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + } else if (exists(".sparkRHivesc", envir = .sparkREnv)) { + sqlContext <- get(".sparkRHivesc", envir = .sparkREnv) + } else { + stop("sparkRHive or sparkRSQL context has to be specified") + } + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") } - # nolint end - jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) - callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + + write <- callJMethod(df@sdf, "write") + write <- callJMethod(write, "format", source) + write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "options", options) + callJMethod(write, "saveAsTable", tableName) }) -#' describe +#' summary #' #' Computes statistics for numeric columns. #' If no columns are given, this function computes statistics for all numerical columns. @@ -1902,17 +2113,16 @@ setMethod("saveAsTable", #' @param col A string of name #' @param ... Additional expressions #' @return A DataFrame -#' @family dataframe_funcs -#' @rdname describe +#' @family DataFrame functions +#' @rdname summary #' @name describe -#' @aliases summary #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") @@ -1925,8 +2135,7 @@ setMethod("describe", dataFrame(sdf) }) -#' @family dataframe_funcs -#' @rdname describe +#' @rdname summary #' @name describe setMethod("describe", signature(x = "DataFrame"), @@ -1936,17 +2145,12 @@ setMethod("describe", dataFrame(sdf) }) -#' @title Summary -#' -#' @description Computes statistics for numeric columns of the DataFrame -#' -#' @family dataframe_funcs #' @rdname summary #' @name summary setMethod("summary", - signature(x = "DataFrame"), - function(x) { - describe(x) + signature(object = "DataFrame"), + function(object, ...) { + describe(object) }) @@ -1965,17 +2169,16 @@ setMethod("summary", #' @param cols Optional list of column names to consider. #' @return A DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname nafunctions #' @name dropna -#' @aliases na.omit #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' dropna(df) #' } setMethod("dropna", @@ -1995,7 +2198,6 @@ setMethod("dropna", dataFrame(sdf) }) -#' @family dataframe_funcs #' @rdname nafunctions #' @name na.omit #' @export @@ -2021,9 +2223,7 @@ setMethod("na.omit", #' type are ignored. For example, if value is a character, and #' subset contains a non-character column, then the non-character #' column is simply ignored. -#' @return A DataFrame #' -#' @family dataframe_funcs #' @rdname nafunctions #' @name fillna #' @export @@ -2032,7 +2232,7 @@ setMethod("na.omit", #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' fillna(df, 1) #' fillna(df, list("age" = 20, "name" = "unknown")) #' } @@ -2087,7 +2287,7 @@ setMethod("fillna", #' @title Download data from a DataFrame into a data.frame #' @param x a DataFrame #' @return a data.frame -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname as.data.frame #' @examples \dontrun{ #' @@ -2099,7 +2299,7 @@ setMethod("as.data.frame", function(x, ...) { # Check if additional parameters have been passed if (length(list(...)) > 0) { - stop(paste("Unused argument(s): ", paste(list(...), collapse=", "))) + stop(paste("Unused argument(s): ", paste(list(...), collapse = ", "))) } collect(x) }) @@ -2108,7 +2308,7 @@ setMethod("as.data.frame", #' the DataFrame is searched by R when evaluating a variable, so columns in #' the DataFrame can be accessed by simply giving their names. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname attach #' @title Attach DataFrame to R search path #' @param what (DataFrame) The DataFrame to attach @@ -2126,11 +2326,145 @@ setMethod("as.data.frame", setMethod("attach", signature(what = "DataFrame"), function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) { - cols <- columns(what) - stopifnot(length(cols) > 0) - newEnv <- new.env() - for (i in 1:length(cols)) { - assign(x = cols[i], value = what[, cols[i]], envir = newEnv) - } + newEnv <- assignNewEnv(what) attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) }) + +#' Evaluate a R expression in an environment constructed from a DataFrame +#' with() allows access to columns of a DataFrame by simply referring to +#' their name. It appends every column of a DataFrame into a new +#' environment. Then, the given expression is evaluated in this new +#' environment. +#' +#' @rdname with +#' @title Evaluate a R expression in an environment constructed from a DataFrame +#' @param data (DataFrame) DataFrame to use for constructing an environment. +#' @param expr (expression) Expression to evaluate. +#' @param ... arguments to be passed to future methods. +#' @examples +#' \dontrun{ +#' with(irisDf, nrow(Sepal_Width)) +#' } +#' @seealso \link{attach} +setMethod("with", + signature(data = "DataFrame"), + function(data, expr, ...) { + newEnv <- assignNewEnv(data) + eval(substitute(expr), envir = newEnv, enclos = newEnv) + }) + +#' Display the structure of a DataFrame, including column names, column types, as well as a +#' a small sample of rows. +#' @name str +#' @title Compactly display the structure of a dataset +#' @rdname str +#' @family DataFrame functions +#' @param object a DataFrame +#' @examples \dontrun{ +#' # Create a DataFrame from the Iris dataset +#' irisDF <- createDataFrame(sqlContext, iris) +#' +#' # Show the structure of the DataFrame +#' str(irisDF) +#' } +setMethod("str", + signature(object = "DataFrame"), + function(object) { + + # TODO: These could be made global parameters, though in R it's not the case + MAX_CHAR_PER_ROW <- 120 + MAX_COLS <- 100 + + # Get the column names and types of the DataFrame + names <- names(object) + types <- coltypes(object) + + # Get the first elements of the dataset. Limit number of columns accordingly + localDF <- if (ncol(object) > MAX_COLS) { + head(object[, c(1:MAX_COLS)]) + } else { + head(object) + } + + # The number of observations will not be displayed as computing the + # number of rows is a very expensive operation + cat(paste0("'", class(object), "': ", length(names), " variables:\n")) + + if (nrow(localDF) > 0) { + for (i in 1 : ncol(localDF)) { + # Get the first elements for each column + + firstElements <- if (types[i] == "character") { + paste(paste0("\"", localDF[, i], "\""), collapse = " ") + } else { + paste(localDF[, i], collapse = " ") + } + + # Add the corresponding number of spaces for alignment + spaces <- paste(rep(" ", max(nchar(names) - nchar(names[i]))), collapse = "") + + # Get the short type. For 'character', it would be 'chr'; + # 'for numeric', it's 'num', etc. + dataType <- SHORT_TYPES[[types[i]]] + if (is.null(dataType)) { + dataType <- substring(types[i], 1, 3) + } + + # Concatenate the colnames, coltypes, and first + # elements of each column + line <- paste0(" $ ", names[i], spaces, ": ", + dataType, " ", firstElements) + + # Chop off extra characters if this is too long + cat(substr(line, 1, MAX_CHAR_PER_ROW)) + cat("\n") + } + + if (ncol(localDF) < ncol(object)) { + cat(paste0("\nDisplaying first ", ncol(localDF), " columns only.")) + } + } + }) + +#' drop +#' +#' Returns a new DataFrame with columns dropped. +#' This is a no-op if schema doesn't contain column name(s). +#' +#' @param x A SparkSQL DataFrame. +#' @param cols A character vector of column names or a Column. +#' @return A DataFrame +#' +#' @family DataFrame functions +#' @rdname drop +#' @name drop +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlCtx, path) +#' drop(df, "col1") +#' drop(df, c("col1", "col2")) +#' drop(df, df$col1) +#' } +setMethod("drop", + signature(x = "DataFrame"), + function(x, col) { + stopifnot(class(col) == "character" || class(col) == "Column") + + if (class(col) == "Column") { + sdf <- callJMethod(x@sdf, "drop", col@jc) + } else { + sdf <- callJMethod(x@sdf, "drop", as.list(col)) + } + dataFrame(sdf) + }) + +# Expose base::drop +setMethod("drop", + signature(x = "ANY"), + function(x) { + base::drop(x) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 051e441d4e063..35c4e6f1afaf4 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -19,16 +19,15 @@ setOldClass("jobj") -# @title S4 class that represents an RDD -# @description RDD can be created using functions like -# \code{parallelize}, \code{textFile} etc. -# @rdname RDD -# @seealso parallelize, textFile -# -# @slot env An R environment that stores bookkeeping states of the RDD -# @slot jrdd Java object reference to the backing JavaRDD -# to an RDD -# @export +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. +#' @rdname RDD +#' @seealso parallelize, textFile +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @noRd setClass("RDD", slots = list(env = "environment", jrdd = "jobj")) @@ -68,7 +67,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, setMethod("show", "RDD", function(object) { - cat(paste(callJMethod(getJRDD(object), "toString"), "\n", sep="")) + cat(paste(callJMethod(getJRDD(object), "toString"), "\n", sep = "")) }) setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { @@ -111,14 +110,13 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) .Object }) -# @rdname RDD -# @export -# -# @param jrdd Java object reference to the backing JavaRDD -# @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD -# stores strings, and "row" if the RDD stores the rows of a DataFrame -# @param isCached TRUE if the RDD is cached -# @param isCheckpointed TRUE if the RDD has been checkpointed +#' @rdname RDD +#' @noRd +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, isCheckpointed = FALSE) { new("RDD", jrdd, serializedMode, isCached, isCheckpointed) @@ -182,7 +180,7 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), } # Save the serialization flag after we create a RRDD rdd@env$serializedMode <- serializedMode - rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") rdd@env$jrdd_val }) @@ -201,19 +199,20 @@ setValidity("RDD", ############ Actions and Transformations ############ -# Persist an RDD -# -# Persist this RDD with the default storage level (MEMORY_ONLY). -# -# @param x The RDD to cache -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# cache(rdd) -#} -# @rdname cache-methods -# @aliases cache,RDD-method +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +#' @noRd setMethod("cache", signature(x = "RDD"), function(x) { @@ -222,22 +221,23 @@ setMethod("cache", x }) -# Persist an RDD -# -# Persist this RDD with the specified storage level. For details of the -# supported storage levels, refer to -# http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. -# -# @param x The RDD to persist -# @param newLevel The new storage level to be assigned -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# persist(rdd, "MEMORY_AND_DISK") -#} -# @rdname persist -# @aliases persist,RDD-method +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#'\url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +#' @noRd setMethod("persist", signature(x = "RDD", newLevel = "character"), function(x, newLevel = "MEMORY_ONLY") { @@ -246,21 +246,22 @@ setMethod("persist", x }) -# Unpersist an RDD -# -# Mark the RDD as non-persistent, and remove all blocks for it from memory and -# disk. -# -# @param x The RDD to unpersist -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# cache(rdd) # rdd@@env$isCached == TRUE -# unpersist(rdd) # rdd@@env$isCached == FALSE -#} -# @rdname unpersist-methods -# @aliases unpersist,RDD-method +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +#' @noRd setMethod("unpersist", signature(x = "RDD"), function(x) { @@ -269,24 +270,25 @@ setMethod("unpersist", x }) -# Checkpoint an RDD -# -# Mark this RDD for checkpointing. It will be saved to a file inside the -# checkpoint directory set with setCheckpointDir() and all references to its -# parent RDDs will be removed. This function must be called before any job has -# been executed on this RDD. It is strongly recommended that this RDD is -# persisted in memory, otherwise saving it on a file will require recomputation. -# -# @param x The RDD to checkpoint -# @examples -#\dontrun{ -# sc <- sparkR.init() -# setCheckpointDir(sc, "checkpoint") -# rdd <- parallelize(sc, 1:10, 2L) -# checkpoint(rdd) -#} -# @rdname checkpoint-methods -# @aliases checkpoint,RDD-method +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoint") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +#' @noRd setMethod("checkpoint", signature(x = "RDD"), function(x) { @@ -296,44 +298,57 @@ setMethod("checkpoint", x }) -# Gets the number of partitions of an RDD -# -# @param x A RDD. -# @return the number of partitions of rdd as an integer. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# numPartitions(rdd) # 2L -#} -# @rdname numPartitions -# @aliases numPartitions,RDD-method +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' getNumPartitions(rdd) # 2L +#'} +#' @rdname getNumPartitions +#' @aliases getNumPartitions,RDD-method +#' @noRd +setMethod("getNumPartitions", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "getNumPartitions") + }) + +#' Gets the number of partitions of an RDD, the same as getNumPartitions. +#' But this function has been deprecated, please use getNumPartitions. +#' +#' @rdname getNumPartitions +#' @aliases numPartitions,RDD-method +#' @noRd setMethod("numPartitions", signature(x = "RDD"), function(x) { - jrdd <- getJRDD(x) - partitions <- callJMethod(jrdd, "partitions") - callJMethod(partitions, "size") + .Deprecated("getNumPartitions") + getNumPartitions(x) }) -# Collect elements of an RDD -# -# @description -# \code{collect} returns a list that contains all of the elements in this RDD. -# -# @param x The RDD to collect -# @param ... Other optional arguments to collect -# @param flatten FALSE if the list should not flattened -# @return a list containing elements in the RDD -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# collect(rdd) # list from 1 to 10 -# collectPartition(rdd, 0L) # list from 1 to 5 -#} -# @rdname collect-methods -# @aliases collect,RDD-method +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +#' @noRd setMethod("collect", signature(x = "RDD"), function(x, flatten = TRUE) { @@ -344,12 +359,13 @@ setMethod("collect", }) -# @description -# \code{collectPartition} returns a list that contains all of the elements -# in the specified partition of the RDD. -# @param partitionId the partition to collect (starts from 0) -# @rdname collect-methods -# @aliases collectPartition,integer,RDD-method +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +#' @noRd setMethod("collectPartition", signature(x = "RDD", partitionId = "integer"), function(x, partitionId) { @@ -362,17 +378,20 @@ setMethod("collectPartition", serializedMode = getSerializedMode(x)) }) -# @description -# \code{collectAsMap} returns a named list as a map that contains all of the elements -# in a key-value pair RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) -# collectAsMap(rdd) # list(`1` = 2, `3` = 4) -#} -# @rdname collect-methods -# @aliases collectAsMap,RDD-method +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +# nolint end +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +#' @noRd setMethod("collectAsMap", signature(x = "RDD"), function(x) { @@ -382,19 +401,20 @@ setMethod("collectAsMap", as.list(map) }) -# Return the number of elements in the RDD. -# -# @param x The RDD to count -# @return number of elements in the RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# count(rdd) # 10 -# length(rdd) # Same as count -#} -# @rdname count -# @aliases count,RDD-method +#' Return the number of elements in the RDD. +#' +#' @param x The RDD to count +#' @return number of elements in the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +#' @noRd setMethod("count", signature(x = "RDD"), function(x) { @@ -406,55 +426,59 @@ setMethod("count", sum(as.integer(vals)) }) -# Return the number of elements in the RDD -# @export -# @rdname count +#' Return the number of elements in the RDD +#' @rdname count +#' @noRd setMethod("length", signature(x = "RDD"), function(x) { count(x) }) -# Return the count of each unique value in this RDD as a list of -# (value, count) pairs. -# -# Same as countByValue in Spark. -# -# @param x The RDD to count -# @return list of (value, count) pairs, where count is number of each unique -# value in rdd. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, c(1,2,3,2,1)) -# countByValue(rdd) # (1,2L), (2,2L), (3,1L) -#} -# @rdname countByValue -# @aliases countByValue,RDD-method +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +# nolint end +#' @rdname countByValue +#' @aliases countByValue,RDD-method +#' @noRd setMethod("countByValue", signature(x = "RDD"), function(x) { ones <- lapply(x, function(item) { list(item, 1L) }) - collect(reduceByKey(ones, `+`, numPartitions(x))) + collect(reduceByKey(ones, `+`, getNumPartitions(x))) }) -# Apply a function to all elements -# -# This function creates a new RDD by applying the given transformation to all -# elements of the given RDD -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each element -# @return a new RDD created by the transformation. -# @rdname lapply -# @aliases lapply -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) -# collect(multiplyByTwo) # 2,4,6... -#} +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @noRd +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} setMethod("lapply", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -464,31 +488,33 @@ setMethod("lapply", lapplyPartitionsWithIndex(X, func) }) -# @rdname lapply -# @aliases map,RDD,function-method +#' @rdname lapply +#' @aliases map,RDD,function-method +#' @noRd setMethod("map", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapply(X, FUN) }) -# Flatten results after apply a function to all elements -# -# This function return a new RDD by first applying a function to all -# elements of this RDD, and then flattening the results. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each element -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) -# collect(multiplyByTwo) # 2,20,4,40,6,60... -#} -# @rdname flatMap -# @aliases flatMap,RDD,function-method +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +#' @noRd setMethod("flatMap", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -501,83 +527,90 @@ setMethod("flatMap", lapplyPartition(X, partitionFunc) }) -# Apply a function to each partition of an RDD -# -# Return a new RDD by applying a function to each partition of this RDD. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each partition. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) -# collect(partitionSum) # 15, 40 -#} -# @rdname lapplyPartition -# @aliases lapplyPartition,RDD,function-method +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +#' @noRd setMethod("lapplyPartition", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) }) -# mapPartitions is the same as lapplyPartition. -# -# @rdname lapplyPartition -# @aliases mapPartitions,RDD,function-method +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +#' @noRd setMethod("mapPartitions", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartition(X, FUN) }) -# Return a new RDD by applying a function to each partition of this RDD, while -# tracking the index of the original partition. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each partition; takes the partition -# index and a list of elements in the particular partition. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 5L) -# prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { -# partIndex * Reduce("+", part) }) -# collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 -#} -# @rdname lapplyPartitionsWithIndex -# @aliases lapplyPartitionsWithIndex,RDD,function-method +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { +#' partIndex * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +#' @noRd setMethod("lapplyPartitionsWithIndex", signature(X = "RDD", FUN = "function"), function(X, FUN) { PipelinedRDD(X, FUN) }) -# @rdname lapplyPartitionsWithIndex -# @aliases mapPartitionsWithIndex,RDD,function-method +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +#' @noRd setMethod("mapPartitionsWithIndex", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartitionsWithIndex(X, FUN) }) -# This function returns a new RDD containing only the elements that satisfy -# a predicate (i.e. returning TRUE in a given logical function). -# The same as `filter()' in Spark. -# -# @param x The RDD to be filtered. -# @param f A unary predicate function. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) -#} -# @rdname filterRDD -# @aliases filterRDD,RDD,function-method +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +# nolint end +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +#' @noRd setMethod("filterRDD", signature(x = "RDD", f = "function"), function(x, f) { @@ -587,30 +620,32 @@ setMethod("filterRDD", lapplyPartition(x, filter.func) }) -# @rdname filterRDD -# @aliases Filter +#' @rdname filterRDD +#' @aliases Filter +#' @noRd setMethod("Filter", signature(f = "function", x = "RDD"), function(f, x) { filterRDD(x, f) }) -# Reduce across elements of an RDD. -# -# This function reduces the elements of this RDD using the -# specified commutative and associative binary operator. -# -# @param x The RDD to reduce -# @param func Commutative and associative function to apply on elements -# of the RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# reduce(rdd, "+") # 55 -#} -# @rdname reduce -# @aliases reduce,RDD,ANY-method +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. +#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +#' @noRd setMethod("reduce", signature(x = "RDD", func = "ANY"), function(x, func) { @@ -624,70 +659,74 @@ setMethod("reduce", Reduce(func, partitionList) }) -# Get the maximum element of an RDD. -# -# @param x The RDD to get the maximum element from -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# maximum(rdd) # 10 -#} -# @rdname maximum -# @aliases maximum,RDD +#' Get the maximum element of an RDD. +#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +#' @noRd setMethod("maximum", signature(x = "RDD"), function(x) { reduce(x, max) }) -# Get the minimum element of an RDD. -# -# @param x The RDD to get the minimum element from -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# minimum(rdd) # 1 -#} -# @rdname minimum -# @aliases minimum,RDD +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +#' @noRd setMethod("minimum", signature(x = "RDD"), function(x) { reduce(x, min) }) -# Add up the elements in an RDD. -# -# @param x The RDD to add up the elements in -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# sumRDD(rdd) # 55 -#} -# @rdname sumRDD -# @aliases sumRDD,RDD +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +#' @noRd setMethod("sumRDD", signature(x = "RDD"), function(x) { reduce(x, "+") }) -# Applies a function to all elements in an RDD, and force evaluation. -# -# @param x The RDD to apply the function -# @param func The function to be applied. -# @return invisible NULL. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# foreach(rdd, function(x) { save(x, file=...) }) -#} -# @rdname foreach -# @aliases foreach,RDD,function-method +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +#' @noRd setMethod("foreach", signature(x = "RDD", func = "function"), function(x, func) { @@ -698,44 +737,48 @@ setMethod("foreach", invisible(collect(mapPartitions(x, partition.func))) }) -# Applies a function to each partition in an RDD, and force evaluation. -# -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# foreachPartition(rdd, function(part) { save(part, file=...); NULL }) -#} -# @rdname foreach -# @aliases foreachPartition,RDD,function-method +#' Applies a function to each partition in an RDD, and force evaluation. +#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +#' @noRd setMethod("foreachPartition", signature(x = "RDD", func = "function"), function(x, func) { invisible(collect(mapPartitions(x, func))) }) -# Take elements from an RDD. -# -# This function takes the first NUM elements in the RDD and -# returns them in a list. -# -# @param x The RDD to take elements from -# @param num Number of elements to take -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# take(rdd, 2L) # list(1, 2) -#} -# @rdname take -# @aliases take,RDD,numeric-method +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +# nolint end +#' @rdname take +#' @aliases take,RDD,numeric-method +#' @noRd setMethod("take", signature(x = "RDD", num = "numeric"), function(x, num) { resList <- list() index <- -1 jrdd <- getJRDD(x) - numPartitions <- numPartitions(x) + numPartitions <- getNumPartitions(x) serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size @@ -763,42 +806,45 @@ setMethod("take", }) -# First -# -# Return the first element of an RDD -# -# @rdname first -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# first(rdd) -# } +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +#' @noRd setMethod("first", signature(x = "RDD"), function(x) { take(x, 1)[[1]] }) -# Removes the duplicates from RDD. -# -# This function returns a new RDD containing the distinct elements in the -# given RDD. The same as `distinct()' in Spark. -# -# @param x The RDD to remove duplicates from. -# @param numPartitions Number of partitions to create. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, c(1,2,2,3,3,3)) -# sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) -#} -# @rdname distinct -# @aliases distinct,RDD-method +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +# nolint end +#' @rdname distinct +#' @aliases distinct,RDD-method +#' @noRd setMethod("distinct", signature(x = "RDD"), - function(x, numPartitions = SparkR:::numPartitions(x)) { + function(x, numPartitions = SparkR:::getNumPartitions(x)) { identical.mapped <- lapply(x, function(x) { list(x, NULL) }) reduced <- reduceByKey(identical.mapped, function(x, y) { x }, @@ -807,24 +853,25 @@ setMethod("distinct", resRDD }) -# Return an RDD that is a sampled subset of the given RDD. -# -# The same as `sample()' in Spark. (We rename it due to signature -# inconsistencies with the `sample()' function in R's base package.) -# -# @param x The RDD to sample elements from -# @param withReplacement Sampling with replacement or not -# @param fraction The (rough) sample target fraction -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements -# collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates -#} -# @rdname sampleRDD -# @aliases sampleRDD,RDD +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +#' @noRd setMethod("sampleRDD", signature(x = "RDD", withReplacement = "logical", fraction = "numeric", seed = "integer"), @@ -868,23 +915,24 @@ setMethod("sampleRDD", lapplyPartitionsWithIndex(x, samplingFunc) }) -# Return a list of the elements that are a sampled subset of the given RDD. -# -# @param x The RDD to sample elements from -# @param withReplacement Sampling with replacement or not -# @param num Number of elements to return -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:100) -# # exactly 5 elements sampled, which may not be distinct -# takeSample(rdd, TRUE, 5L, 1618L) -# # exactly 5 distinct elements sampled -# takeSample(rdd, FALSE, 5L, 16181618L) -#} -# @rdname takeSample -# @aliases takeSample,RDD +#' Return a list of the elements that are a sampled subset of the given RDD. +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +#' @noRd setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", num = "integer", seed = "integer"), function(x, withReplacement, num, seed) { @@ -931,18 +979,21 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", base::sample(samples)[1:total] }) -# Creates tuples of the elements in this RDD by applying a function. -# -# @param x The RDD. -# @param func The function to be applied. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3)) -# collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) -#} -# @rdname keyBy -# @aliases keyBy,RDD +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +# nolint end +#' @rdname keyBy +#' @aliases keyBy,RDD +#' @noRd setMethod("keyBy", signature(x = "RDD", func = "function"), function(x, func) { @@ -952,49 +1003,51 @@ setMethod("keyBy", lapply(x, apply.func) }) -# Return a new RDD that has exactly numPartitions partitions. -# Can increase or decrease the level of parallelism in this RDD. Internally, -# this uses a shuffle to redistribute data. -# If you are decreasing the number of partitions in this RDD, consider using -# coalesce, which can avoid performing a shuffle. -# -# @param x The RDD. -# @param numPartitions Number of partitions to create. -# @seealso coalesce -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) -# numPartitions(rdd) # 4 -# numPartitions(repartition(rdd, 2L)) # 2 -#} -# @rdname repartition -# @aliases repartition,RDD +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' getNumPartitions(rdd) # 4 +#' getNumPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +#' @noRd setMethod("repartition", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { coalesce(x, numPartitions, TRUE) }) -# Return a new RDD that is reduced into numPartitions partitions. -# -# @param x The RDD. -# @param numPartitions Number of partitions to create. -# @seealso repartition -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) -# numPartitions(rdd) # 3 -# numPartitions(coalesce(rdd, 1L)) # 1 -#} -# @rdname coalesce -# @aliases coalesce,RDD +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' getNumPartitions(rdd) # 3 +#' getNumPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +#' @noRd setMethod("coalesce", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, shuffle = FALSE) { numPartitions <- numToInt(numPartitions) - if (shuffle || numPartitions > SparkR:::numPartitions(x)) { + if (shuffle || numPartitions > SparkR:::getNumPartitions(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed start <- as.integer(base::sample(numPartitions, 1) - 1) @@ -1013,19 +1066,20 @@ setMethod("coalesce", } }) -# Save this RDD as a SequenceFile of serialized objects. -# -# @param x The RDD to save -# @param path The directory where the file is saved -# @seealso objectFile -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# saveAsObjectFile(rdd, "/tmp/sparkR-tmp") -#} -# @rdname saveAsObjectFile -# @aliases saveAsObjectFile,RDD +#' Save this RDD as a SequenceFile of serialized objects. +#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +#' @noRd setMethod("saveAsObjectFile", signature(x = "RDD", path = "character"), function(x, path) { @@ -1038,18 +1092,19 @@ setMethod("saveAsObjectFile", invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) }) -# Save this RDD as a text file, using string representations of elements. -# -# @param x The RDD to save -# @param path The directory where the partitions of the text file are saved -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# saveAsTextFile(rdd, "/tmp/sparkR-tmp") -#} -# @rdname saveAsTextFile -# @aliases saveAsTextFile,RDD +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the partitions of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +#' @noRd setMethod("saveAsTextFile", signature(x = "RDD", path = "character"), function(x, path) { @@ -1062,24 +1117,27 @@ setMethod("saveAsTextFile", callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) }) -# Sort an RDD by the given key function. -# -# @param x An RDD to be sorted. -# @param func A function used to compute the sort key for each element. -# @param ascending A flag to indicate whether the sorting is ascending or descending. -# @param numPartitions Number of partitions to create. -# @return An RDD where all elements are sorted. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(3, 2, 1)) -# collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) -#} -# @rdname sortBy -# @aliases sortBy,RDD,RDD-method +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +# nolint end +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +#' @noRd setMethod("sortBy", signature(x = "RDD", func = "function"), - function(x, func, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { values(sortByKey(keyBy(x, func), ascending, numPartitions)) }) @@ -1111,7 +1169,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList <- list() index <- -1 jrdd <- getJRDD(newRdd) - numPartitions <- numPartitions(newRdd) + numPartitions <- getNumPartitions(newRdd) serializedModeRDD <- getSerializedMode(newRdd) while (TRUE) { @@ -1138,97 +1196,101 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList } -# Returns the first N elements from an RDD in ascending order. -# -# @param x An RDD. -# @param num Number of elements to return. -# @return The first N elements from the RDD in ascending order. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) -# takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) -#} -# @rdname takeOrdered -# @aliases takeOrdered,RDD,RDD-method +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +# nolint end +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +#' @noRd setMethod("takeOrdered", signature(x = "RDD", num = "integer"), function(x, num) { takeOrderedElem(x, num) }) -# Returns the top N elements from an RDD. -# -# @param x An RDD. -# @param num Number of elements to return. -# @return The top N elements from the RDD. -# @rdname top -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) -# top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) -#} -# @rdname top -# @aliases top,RDD,RDD-method +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +# nolint end +#' @aliases top,RDD,RDD-method +#' @noRd setMethod("top", signature(x = "RDD", num = "integer"), function(x, num) { takeOrderedElem(x, num, FALSE) }) -# Fold an RDD using a given associative function and a neutral "zero value". -# -# Aggregate the elements of each partition, and then the results for all the -# partitions, using a given associative function and a neutral "zero value". -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param op An associative function for the folding operation. -# @return The folding result. -# @rdname fold -# @seealso reduce -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) -# fold(rdd, 0, "+") # 15 -#} -# @rdname fold -# @aliases fold,RDD,RDD-method +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @aliases fold,RDD,RDD-method +#' @noRd setMethod("fold", signature(x = "RDD", zeroValue = "ANY", op = "ANY"), function(x, zeroValue, op) { aggregateRDD(x, zeroValue, op, op) }) -# Aggregate an RDD using the given combine functions and a neutral "zero value". -# -# Aggregate the elements of each partition, and then the results for all the -# partitions, using given combine functions and a neutral "zero value". -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param seqOp A function to aggregate the RDD elements. It may return a different -# result type from the type of the RDD elements. -# @param combOp A function to aggregate results of seqOp. -# @return The aggregation result. -# @rdname aggregateRDD -# @seealso reduce -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4)) -# zeroValue <- list(0, 0) -# seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } -# combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -# aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) -#} -# @rdname aggregateRDD -# @aliases aggregateRDD,RDD,RDD-method +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. +#' @rdname aggregateRDD +#' @seealso reduce +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +# nolint end +#' @aliases aggregateRDD,RDD,RDD-method +#' @noRd setMethod("aggregateRDD", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), function(x, zeroValue, seqOp, combOp) { @@ -1241,25 +1303,24 @@ setMethod("aggregateRDD", Reduce(combOp, partitionList, zeroValue) }) -# Pipes elements to a forked external process. -# -# The same as 'pipe()' in Spark. -# -# @param x The RDD whose elements are piped to the forked external process. -# @param command The command to fork an external process. -# @param env A named list to set environment variables of the external process. -# @return A new RDD created by piping all elements to a forked external process. -# @rdname pipeRDD -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# collect(pipeRDD(rdd, "more") -# Output: c("1", "2", ..., "10") -#} -# @rdname pipeRDD -# @aliases pipeRDD,RDD,character-method +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @aliases pipeRDD,RDD,character-method +#' @noRd setMethod("pipeRDD", signature(x = "RDD", command = "character"), function(x, command, env = list()) { @@ -1274,42 +1335,40 @@ setMethod("pipeRDD", lapplyPartition(x, func) }) -# TODO: Consider caching the name in the RDD's environment -# Return an RDD's name. -# -# @param x The RDD whose name is returned. -# @rdname name -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1,2,3)) -# name(rdd) # NULL (if not set before) -#} -# @rdname name -# @aliases name,RDD +#' TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. +#' @rdname name +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @aliases name,RDD +#' @noRd setMethod("name", signature(x = "RDD"), function(x) { callJMethod(getJRDD(x), "name") }) -# Set an RDD's name. -# -# @param x The RDD whose name is to be set. -# @param name The RDD name to be set. -# @return a new RDD renamed. -# @rdname setName -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1,2,3)) -# setName(rdd, "myRDD") -# name(rdd) # "myRDD" -#} -# @rdname setName -# @aliases setName,RDD +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @aliases setName,RDD +#' @noRd setMethod("setName", signature(x = "RDD", name = "character"), function(x, name) { @@ -1317,29 +1376,32 @@ setMethod("setName", x }) -# Zip an RDD with generated unique Long IDs. -# -# Items in the kth partition will get ids k, n+k, 2*n+k, ..., where -# n is the number of partitions. So there may exist gaps, but this -# method won't trigger a spark job, which is different from -# zipWithIndex. -# -# @param x An RDD to be zipped. -# @return An RDD with zipped items. -# @seealso zipWithIndex -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithUniqueId(rdd)) -# # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) -#} -# @rdname zipWithUniqueId -# @aliases zipWithUniqueId,RDD +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithIndex +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +# nolint end +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +#' @noRd setMethod("zipWithUniqueId", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) partitionFunc <- function(partIndex, part) { mapply( @@ -1354,32 +1416,35 @@ setMethod("zipWithUniqueId", lapplyPartitionsWithIndex(x, partitionFunc) }) -# Zip an RDD with its element indices. -# -# The ordering is first based on the partition index and then the -# ordering of items within each partition. So the first item in -# the first partition gets index 0, and the last item in the last -# partition receives the largest index. -# -# This method needs to trigger a Spark job when this RDD contains -# more than one partition. -# -# @param x An RDD to be zipped. -# @return An RDD with zipped items. -# @seealso zipWithUniqueId -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithIndex(rdd)) -# # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) -#} -# @rdname zipWithIndex -# @aliases zipWithIndex,RDD +#' Zip an RDD with its element indices. +#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +# nolint end +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +#' @noRd setMethod("zipWithIndex", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) if (n > 1) { nums <- collect(lapplyPartition(x, function(part) { @@ -1407,20 +1472,23 @@ setMethod("zipWithIndex", lapplyPartitionsWithIndex(x, partitionFunc) }) -# Coalesce all elements within each partition of an RDD into a list. -# -# @param x An RDD. -# @return An RDD created by coalescing all elements within -# each partition into a list. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, as.list(1:4), 2L) -# collect(glom(rdd)) -# # list(list(1, 2), list(3, 4)) -#} -# @rdname glom -# @aliases glom,RDD +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +# nolint end +#' @rdname glom +#' @aliases glom,RDD +#' @noRd setMethod("glom", signature(x = "RDD"), function(x) { @@ -1433,21 +1501,22 @@ setMethod("glom", ############ Binary Functions ############# -# Return the union RDD of two RDDs. -# The same as union() in Spark. -# -# @param x An RDD. -# @param y An RDD. -# @return a new RDD created by performing the simple union (witout removing -# duplicates) of two input RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 -#} -# @rdname unionRDD -# @aliases unionRDD,RDD,RDD-method +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +#' @noRd setMethod("unionRDD", signature(x = "RDD", y = "RDD"), function(x, y) { @@ -1464,32 +1533,35 @@ setMethod("unionRDD", union.rdd }) -# Zip an RDD with another RDD. -# -# Zips this RDD with another one, returning key-value pairs with the -# first element in each RDD second element in each RDD, etc. Assumes -# that the two RDDs have the same number of partitions and the same -# number of elements in each partition (e.g. one was made through -# a map on the other). -# -# @param x An RDD to be zipped. -# @param other Another RDD to be zipped. -# @return An RDD zipped from the two RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, 0:4) -# rdd2 <- parallelize(sc, 1000:1004) -# collect(zipRDD(rdd1, rdd2)) -# # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) -#} -# @rdname zipRDD -# @aliases zipRDD,RDD +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +# nolint end +#' @rdname zipRDD +#' @aliases zipRDD,RDD +#' @noRd setMethod("zipRDD", signature(x = "RDD", other = "RDD"), function(x, other) { - n1 <- numPartitions(x) - n2 <- numPartitions(other) + n1 <- getNumPartitions(x) + n2 <- getNumPartitions(other) if (n1 != n2) { stop("Can only zip RDDs which have the same number of partitions.") } @@ -1503,24 +1575,27 @@ setMethod("zipRDD", mergePartitions(rdd, TRUE) }) -# Cartesian product of this RDD and another one. -# -# Return the Cartesian product of this RDD and another one, -# that is, the RDD of all pairs of elements (a, b) where a -# is in this and b is in other. -# -# @param x An RDD. -# @param other An RDD. -# @return A new RDD which is the Cartesian product of these two RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:2) -# sortByKey(cartesian(rdd, rdd)) -# # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) -#} -# @rdname cartesian -# @aliases cartesian,RDD,RDD-method +#' Cartesian product of this RDD and another one. +#' +#' Return the Cartesian product of this RDD and another one, +#' that is, the RDD of all pairs of elements (a, b) where a +#' is in this and b is in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @return A new RDD which is the Cartesian product of these two RDDs. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2) +#' sortByKey(cartesian(rdd, rdd)) +#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) +#'} +# nolint end +#' @rdname cartesian +#' @aliases cartesian,RDD,RDD-method +#' @noRd setMethod("cartesian", signature(x = "RDD", other = "RDD"), function(x, other) { @@ -1533,58 +1608,64 @@ setMethod("cartesian", mergePartitions(rdd, FALSE) }) -# Subtract an RDD with another RDD. -# -# Return an RDD with the elements from this that are not in other. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions Number of the partitions in the result RDD. -# @return An RDD with the elements from this that are not in other. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) -# rdd2 <- parallelize(sc, list(2, 4)) -# collect(subtract(rdd1, rdd2)) -# # list(1, 1, 3) -#} -# @rdname subtract -# @aliases subtract,RDD +#' Subtract an RDD with another RDD. +#' +#' Return an RDD with the elements from this that are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the elements from this that are not in other. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) +#' rdd2 <- parallelize(sc, list(2, 4)) +#' collect(subtract(rdd1, rdd2)) +#' # list(1, 1, 3) +#'} +# nolint end +#' @rdname subtract +#' @aliases subtract,RDD +#' @noRd setMethod("subtract", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { mapFunction <- function(e) { list(e, NA) } rdd1 <- map(x, mapFunction) rdd2 <- map(other, mapFunction) keys(subtractByKey(rdd1, rdd2, numPartitions)) }) -# Intersection of this RDD and another one. -# -# Return the intersection of this RDD and another one. -# The output will not contain any duplicate elements, -# even if the input RDDs did. Performs a hash partition -# across the cluster. -# Note that this method performs a shuffle internally. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions The number of partitions in the result RDD. -# @return An RDD which is the intersection of these two RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) -# rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) -# collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) -# # list(1, 2, 3) -#} -# @rdname intersection -# @aliases intersection,RDD +#' Intersection of this RDD and another one. +#' +#' Return the intersection of this RDD and another one. +#' The output will not contain any duplicate elements, +#' even if the input RDDs did. Performs a hash partition +#' across the cluster. +#' Note that this method performs a shuffle internally. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions The number of partitions in the result RDD. +#' @return An RDD which is the intersection of these two RDDs. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) +#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) +#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' # list(1, 2, 3) +#'} +# nolint end +#' @rdname intersection +#' @aliases intersection,RDD +#' @noRd setMethod("intersection", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { rdd1 <- map(x, function(v) { list(v, NA) }) rdd2 <- map(other, function(v) { list(v, NA) }) @@ -1597,26 +1678,29 @@ setMethod("intersection", keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) -# Zips an RDD's partitions with one (or more) RDD(s). -# Same as zipPartitions in Spark. -# -# @param ... RDDs to be zipped. -# @param func A function to transform zipped partitions. -# @return A new RDD by applying a function to the zipped partitions. -# Assumes that all the RDDs have the *same number of partitions*, but -# does *not* require them to have the same number of elements in each partition. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 -# rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 -# rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -# collect(zipPartitions(rdd1, rdd2, rdd3, -# func = function(x, y, z) { list(list(x, y, z))} )) -# # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) -#} -# @rdname zipRDD -# @aliases zipPartitions,RDD +#' Zips an RDD's partitions with one (or more) RDD(s). +#' Same as zipPartitions in Spark. +#' +#' @param ... RDDs to be zipped. +#' @param func A function to transform zipped partitions. +#' @return A new RDD by applying a function to the zipped partitions. +#' Assumes that all the RDDs have the *same number of partitions*, but +#' does *not* require them to have the same number of elements in each partition. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 +#' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 +#' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 +#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' func = function(x, y, z) { list(list(x, y, z))} )) +#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) +#'} +# nolint end +#' @rdname zipRDD +#' @aliases zipPartitions,RDD +#' @noRd setMethod("zipPartitions", "RDD", function(..., func) { @@ -1624,7 +1708,7 @@ setMethod("zipPartitions", if (length(rrdds) == 1) { return(rrdds[[1]]) } - nPart <- sapply(rrdds, numPartitions) + nPart <- sapply(rrdds, getNumPartitions) if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 1bf025cce4376..16a2578678cd3 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -17,27 +17,33 @@ # SQLcontext.R: SQLContext-driven functions + +# Map top level R type to SQL type +getInternalType <- function(x) { + # class of POSIXlt is c("POSIXlt" "POSIXt") + switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + struct = "struct", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) +} + #' infer the SQL type infer_type <- function(x) { if (is.null(x)) { stop("can not infer type from NULL") } - # class of POSIXlt is c("POSIXlt" "POSIXt") - type <- switch(class(x)[[1]], - integer = "integer", - character = "string", - logical = "boolean", - double = "double", - numeric = "double", - raw = "binary", - list = "array", - struct = "struct", - environment = "map", - Date = "date", - POSIXlt = "timestamp", - POSIXct = "timestamp", - stop(paste("Unsupported type for DataFrame:", class(x)))) + type <- getInternalType(x) if (type == "map") { stopifnot(length(x) > 0) @@ -57,7 +63,7 @@ infer_type <- function(x) { }) type <- Reduce(paste0, type) type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">") - } else if (length(x) > 1) { + } else if (length(x) > 1 && type != "binary") { paste0("array<", infer_type(x[[1]]), ">") } else { type @@ -90,19 +96,25 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 if (is.null(schema)) { schema <- names(data) } - n <- nrow(data) - m <- ncol(data) + # get rid of factor type - dropFactor <- function(x) { + cleanCols <- function(x) { if (is.factor(x)) { as.character(x) } else { x } } - data <- lapply(1:n, function(i) { - lapply(1:m, function(j) { dropFactor(data[i,j]) }) - }) + + # drop factors and wrap lists + data <- setNames(lapply(data, cleanCols), NULL) + + # check if all columns have supported type + lapply(data, getInternalType) + + # convert to rows + args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE) + data <- do.call(mapply, append(args, data)) } if (is.list(data)) { sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) @@ -144,7 +156,6 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 } stopifnot(class(schema) == "structType") - # schemaString <- tojson(schema) jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") @@ -160,22 +171,21 @@ as.DataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { createDataFrame(sqlContext, data, schema, samplingRatio) } -# toDF -# -# Converts an RDD to a DataFrame by infer the types. -# -# @param x An RDD -# -# @rdname DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -# df <- toDF(rdd) -# } - +#' toDF +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param x An RDD +#' +#' @rdname DataFrame +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- toDF(rdd) +#'} setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), @@ -198,69 +208,116 @@ setMethod("toDF", signature(x = "RDD"), #' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return DataFrame +#' @rdname read.json +#' @name read.json #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) #' df <- jsonFile(sqlContext, path) #' } - -jsonFile <- function(sqlContext, path) { +read.json <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - path <- suppressWarnings(normalizePath(path)) - # Convert a string vector of paths to a string containing comma separated paths - path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlContext, "jsonFile", path) + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } +#' @rdname read.json +#' @name jsonFile +#' @export +jsonFile <- function(sqlContext, path) { + .Deprecated("read.json") + read.json(sqlContext, path) +} -# JSON RDD -# -# Loads an RDD storing one JSON object per string as a DataFrame. -# -# @param sqlContext SQLContext to use -# @param rdd An RDD of JSON string -# @param schema A StructType object to use as schema -# @param samplingRatio The ratio of simpling used to infer the schema -# @return A DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# rdd <- texFile(sc, "path/to/json") -# df <- jsonRDD(sqlContext, rdd) -# } + +#' JSON RDD +#' +#' Loads an RDD storing one JSON object per string as a DataFrame. +#' +#' @param sqlContext SQLContext to use +#' @param rdd An RDD of JSON string +#' @param schema A StructType object to use as schema +#' @param samplingRatio The ratio of simpling used to infer the schema +#' @return A DataFrame +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' rdd <- texFile(sc, "path/to/json") +#' df <- jsonRDD(sqlContext, rdd) +#'} # TODO: support schema jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { + .Deprecated("read.json") rdd <- serializeToString(rdd) if (is.null(schema)) { - sdf <- callJMethod(sqlContext, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + read <- callJMethod(sqlContext, "read") + # samplingRatio is deprecated + sdf <- callJMethod(read, "json", callJMethod(getJRDD(rdd), "rdd")) dataFrame(sdf) } else { stop("not implemented") } } - #' Create a DataFrame from a Parquet file. #' #' Loads a Parquet file, returning the result as a DataFrame. #' #' @param sqlContext SQLContext to use -#' @param ... Path(s) of parquet file(s) to read. +#' @param path Path of file to read. A vector of multiple paths is allowed. #' @return DataFrame +#' @rdname read.parquet +#' @name read.parquet #' @export +read.parquet <- function(sqlContext, path) { + # Allow the user to have a more flexible definiton of the text file path + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "parquet", paths) + dataFrame(sdf) +} +#' @rdname read.parquet +#' @name parquetFile +#' @export # TODO: Implement saveasParquetFile and write examples for both parquetFile <- function(sqlContext, ...) { + .Deprecated("read.parquet") + read.parquet(sqlContext, unlist(list(...))) +} + +#' Create a DataFrame from a text file. +#' +#' Loads a text file and returns a DataFrame with a single string column named "value". +#' Each line in the text file is a new row in the resulting DataFrame. +#' +#' @param sqlContext SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @rdname read.text +#' @name read.text +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.txt" +#' df <- read.text(sqlContext, path) +#' } +read.text <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - paths <- lapply(list(...), function(x) suppressWarnings(normalizePath(x))) - sdf <- callJMethod(sqlContext, "parquetFile", paths) + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "text", paths) dataFrame(sdf) } @@ -277,7 +334,7 @@ parquetFile <- function(sqlContext, ...) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- sql(sqlContext, "SELECT * FROM table") #' } @@ -295,23 +352,24 @@ sql <- function(sqlContext, sqlQuery) { #' @param sqlContext SQLContext to use #' @param tableName The SparkSQL Table to convert to a DataFrame. #' @return DataFrame +#' @rdname tableToDF +#' @name tableToDF #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- table(sqlContext, "table") +#' new_df <- tableToDF(sqlContext, "table") #' } -table <- function(sqlContext, tableName) { +tableToDF <- function(sqlContext, tableName) { sdf <- callJMethod(sqlContext, "table", tableName) dataFrame(sdf) } - #' Tables #' #' Returns a DataFrame containing names of tables in the given database. @@ -374,7 +432,7 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' cacheTable(sqlContext, "table") #' } @@ -396,7 +454,7 @@ cacheTable <- function(sqlContext, tableName) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' uncacheTable(sqlContext, "table") #' } diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 2403925b267c8..38f0eed95e065 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -51,7 +51,6 @@ Broadcast <- function(id, value, jBroadcastRef, objName) { # # @param bcast The broadcast variable to get # @rdname broadcast -# @aliases value,Broadcast-method setMethod("value", signature(bcast = "Broadcast"), function(bcast) { diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index c811d1dac3bd5..25e99390a9c89 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -44,12 +44,16 @@ determineSparkSubmitBin <- function() { } generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + jars <- paste0(jars, collapse = ",") if (jars != "") { - jars <- paste("--jars", jars) + # construct the jars argument with a space between --jars and comma-separated values + jars <- paste0("--jars ", jars) } - if (!identical(packages, "")) { - packages <- paste("--packages", packages) + packages <- paste0(packages, collapse = ",") + if (packages != "") { + # construct the packages argument with a space between --packages and comma-separated values + packages <- paste0("--packages ", packages) } combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 20de3907b7dd9..3ffd9a9890b2e 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -56,7 +56,7 @@ operators <- list( "&" = "and", "|" = "or", #, "!" = "unary_$bang" "^" = "pow" ) -column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") createOperator <- function(op) { @@ -209,13 +209,13 @@ setMethod("cast", setMethod("%in%", signature(x = "Column"), function(x, table) { - jc <- callJMethod(x@jc, "in", as.list(table)) + jc <- callJMethod(x@jc, "isin", as.list(table)) return(column(jc)) }) #' otherwise #' -#' If values in the specified column are null, returns the value. +#' If values in the specified column are null, returns the value. #' Can be used in conjunction with `when` to specify a default value for expressions. #' #' @rdname otherwise @@ -225,7 +225,7 @@ setMethod("%in%", setMethod("otherwise", signature(x = "Column", value = "ANY"), function(x, value) { - value <- ifelse(class(value) == "Column", value@jc, value) + value <- if (class(value) == "Column") { value@jc } else { value } jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 720990e1c6087..b0e67c8ad26ab 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -25,23 +25,23 @@ getMinPartitions <- function(sc, minPartitions) { as.integer(minPartitions) } -# Create an RDD from a text file. -# -# This function reads a text file from HDFS, a local file system (available on all -# nodes), or any Hadoop-supported file system URI, and creates an -# RDD of strings from it. -# -# @param sc SparkContext to use -# @param path Path of file to read. A vector of multiple paths is allowed. -# @param minPartitions Minimum number of partitions to be created. If NULL, the default -# value is chosen based on available parallelism. -# @return RDD where each item is of type \code{character} -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# lines <- textFile(sc, "myfile.txt") -#} +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} textFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) @@ -53,23 +53,23 @@ textFile <- function(sc, path, minPartitions = NULL) { RDD(jrdd, "string") } -# Load an RDD saved as a SequenceFile containing serialized objects. -# -# The file to be loaded should be one that was previously generated by calling -# saveAsObjectFile() of the RDD class. -# -# @param sc SparkContext to use -# @param path Path of file to read. A vector of multiple paths is allowed. -# @param minPartitions Minimum number of partitions to be created. If NULL, the default -# value is chosen based on available parallelism. -# @return RDD containing serialized R objects. -# @seealso saveAsObjectFile -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- objectFile(sc, "myfile") -#} +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} objectFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) @@ -81,29 +81,32 @@ objectFile <- function(sc, path, minPartitions = NULL) { RDD(jrdd, "byte") } -# Create an RDD from a homogeneous list or vector. -# -# This function creates an RDD from a local homogeneous list in R. The elements -# in the list are split into \code{numSlices} slices and distributed to nodes -# in the cluster. -# -# @param sc SparkContext to use -# @param coll collection to parallelize -# @param numSlices number of partitions to create in the RDD -# @return an RDD created from this collection -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2) -# # The RDD should contain 10 elements -# length(rdd) -#} +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} parallelize <- function(sc, coll, numSlices = 1) { # TODO: bound/safeguard numSlices # TODO: unit tests for if the split works for all primitives # TODO: support matrix, data frame, etc + # nolint start + # suppress lintr warning: Place a space before left parenthesis, except in a function call. if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { + # nolint end if (is.data.frame(coll)) { message(paste("context.R: A data frame is parallelized by columns.")) } else { @@ -133,33 +136,32 @@ parallelize <- function(sc, coll, numSlices = 1) { RDD(jrdd, "byte") } -# Include this specified package on all workers -# -# This function can be used to include a package on all workers before the -# user's code is executed. This is useful in scenarios where other R package -# functions are used in a function passed to functions like \code{lapply}. -# NOTE: The package is assumed to be installed on every node in the Spark -# cluster. -# -# @param sc SparkContext to use -# @param pkg Package name -# -# @export -# @examples -#\dontrun{ -# library(Matrix) -# -# sc <- sparkR.init() -# # Include the matrix library we will be using -# includePackage(sc, Matrix) -# -# generateSparse <- function(x) { -# sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) -# } -# -# rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) -# collect(rdd) -#} +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. +#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' @noRd +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} includePackage <- function(sc, pkg) { pkg <- as.character(substitute(pkg)) if (exists(".packages", .sparkREnv)) { @@ -171,30 +173,30 @@ includePackage <- function(sc, pkg) { .sparkREnv$.packages <- packages } -# @title Broadcast a variable to all workers -# -# @description -# Broadcast a read-only variable to the cluster, returning a \code{Broadcast} -# object for reading it in distributed functions. -# -# @param sc Spark Context to use -# @param object Object to be broadcast -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:2, 2L) -# -# # Large Matrix object that we want to broadcast -# randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) -# randomMatBr <- broadcast(sc, randomMat) -# -# # Use the broadcast variable inside the function -# useBroadcast <- function(x) { -# sum(value(randomMatBr) * x) -# } -# sumRDD <- lapply(rdd, useBroadcast) -#} +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. +#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} broadcast <- function(sc, object) { objName <- as.character(substitute(object)) serializedObj <- serialize(object, connection = NULL) @@ -205,21 +207,21 @@ broadcast <- function(sc, object) { Broadcast(id, object, jBroadcast, objName) } -# @title Set the checkpoint directory -# -# Set the directory under which RDDs are going to be checkpointed. The -# directory must be a HDFS path if running on a cluster. -# -# @param sc Spark Context to use -# @param dirName Directory path -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# setCheckpointDir(sc, "~/checkpoint") -# rdd <- parallelize(sc, 1:2, 2L) -# checkpoint(rdd) -#} +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. +#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoint") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} setCheckpointDir <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index f7e56e43016ea..eefdf178733fd 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -17,6 +17,7 @@ # Utility functions to deserialize objects from Java. +# nolint start # Type mapping from Java to R # # void -> NULL @@ -32,6 +33,8 @@ # # Array[T] -> list() # Object -> jobj +# +# nolint end readObject <- function(con) { # Read type first @@ -183,7 +186,7 @@ readMultipleObjects <- function(inputCon) { # of the objects, so the number of objects varies, we try to read # all objects in a loop until the end of the stream. data <- list() - while(TRUE) { + while (TRUE) { # If reaching the end of the stream, type returned should be "". type <- readType(inputCon) if (type == "") { diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d7fd279279137..db877b2d63d30 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -37,7 +37,7 @@ setMethod("lit", signature("ANY"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "lit", - ifelse(class(x) == "Column", x@jc, x)) + if (class(x) == "Column") { x@jc } else { x }) column(jc) }) @@ -259,6 +259,79 @@ setMethod("column", function(x) { col(x) }) +#' corr +#' +#' Computes the Pearson Correlation Coefficient for two Columns. +#' +#' @rdname corr +#' @name corr +#' @family math_funcs +#' @export +#' @examples \dontrun{corr(df$c, df$d)} +setMethod("corr", signature(x = "Column"), + function(x, col2) { + stopifnot(class(col2) == "Column") + jc <- callJStatic("org.apache.spark.sql.functions", "corr", x@jc, col2@jc) + column(jc) + }) + +#' cov +#' +#' Compute the sample covariance between two expressions. +#' +#' @rdname cov +#' @name cov +#' @family math_funcs +#' @export +#' @examples +#' \dontrun{ +#' cov(df$c, df$d) +#' cov("c", "d") +#' covar_samp(df$c, df$d) +#' covar_samp("c", "d") +#' } +setMethod("cov", signature(x = "characterOrColumn"), + function(x, col2) { + stopifnot(is(class(col2), "characterOrColumn")) + covar_samp(x, col2) + }) + +#' @rdname cov +#' @name covar_samp +setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), + function(col1, col2) { + stopifnot(class(col1) == class(col2)) + if (class(col1) == "Column") { + col1 <- col1@jc + col2 <- col2@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "covar_samp", col1, col2) + column(jc) + }) + +#' covar_pop +#' +#' Compute the population covariance between two expressions. +#' +#' @rdname covar_pop +#' @name covar_pop +#' @family math_funcs +#' @export +#' @examples +#' \dontrun{ +#' covar_pop(df$c, df$d) +#' covar_pop("c", "d") +#' } +setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), + function(col1, col2) { + stopifnot(class(col1) == class(col2)) + if (class(col1) == "Column") { + col1 <- col1@jc + col2 <- col2@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "covar_pop", col1, col2) + column(jc) + }) #' cos #' @@ -325,6 +398,26 @@ setMethod("crc32", column(jc) }) +#' hash +#' +#' Calculates the hash code of given columns, and returns the result as a int column. +#' +#' @rdname hash +#' @name hash +#' @family misc_funcs +#' @export +#' @examples \dontrun{hash(df$c)} +setMethod("hash", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "hash", jcols) + column(jc) + }) + #' dayofmonth #' #' Extracts the day of the month as an integer from a given date/timestamp/string. @@ -357,6 +450,40 @@ setMethod("dayofyear", column(jc) }) +#' decode +#' +#' Computes the first argument into a string from a binary using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname decode +#' @name decode +#' @family string_funcs +#' @export +#' @examples \dontrun{decode(df$c, "UTF-8")} +setMethod("decode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "decode", x@jc, charset) + column(jc) + }) + +#' encode +#' +#' Computes the first argument into a binary from a string using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname encode +#' @name encode +#' @family string_funcs +#' @export +#' @examples \dontrun{encode(df$c, "UTF-8")} +setMethod("encode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "encode", x@jc, charset) + column(jc) + }) + #' exp #' #' Computes the exponential of the given value. @@ -373,22 +500,6 @@ setMethod("exp", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. -#' -#' @rdname explode -#' @name explode -#' @family collection_funcs -#' @export -#' @examples \dontrun{explode(df$c)} -setMethod("explode", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) - column(jc) - }) - #' expm1 #' #' Computes the exponential of the given value minus one. @@ -425,15 +536,27 @@ setMethod("factorial", #' #' Aggregate function: returns the first value in a group. #' +#' The function by default returns the first values it sees. It will return the first non-missing +#' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' #' @rdname first #' @name first #' @family agg_funcs #' @export -#' @examples \dontrun{first(df$c)} +#' @examples +#' \dontrun{ +#' first(df$c) +#' first(df$c, TRUE) +#' } setMethod("first", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "first", x@jc) + signature(x = "characterOrColumn"), + function(x, na.rm = FALSE) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + jc <- callJStatic("org.apache.spark.sql.functions", "first", col, na.rm) column(jc) }) @@ -504,19 +627,47 @@ setMethod("initcap", column(jc) }) -#' isNaN +#' is.nan #' -#' Return true iff the column is NaN. +#' Return true if the column is NaN, alias for \link{isnan} #' -#' @rdname isNaN -#' @name isNaN +#' @rdname is.nan +#' @name is.nan #' @family normal_funcs #' @export -#' @examples \dontrun{isNaN(df$c)} -setMethod("isNaN", +#' @examples +#' \dontrun{ +#' is.nan(df$c) +#' isnan(df$c) +#' } +setMethod("is.nan", + signature(x = "Column"), + function(x) { + isnan(x) + }) + +#' @rdname is.nan +#' @name isnan +setMethod("isnan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) + column(jc) + }) + +#' kurtosis +#' +#' Aggregate function: returns the kurtosis of the values in a group. +#' +#' @rdname kurtosis +#' @name kurtosis +#' @family agg_funcs +#' @export +#' @examples \dontrun{kurtosis(df$c)} +setMethod("kurtosis", signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "isNaN", x@jc) + jc <- callJStatic("org.apache.spark.sql.functions", "kurtosis", x@jc) column(jc) }) @@ -524,15 +675,27 @@ setMethod("isNaN", #' #' Aggregate function: returns the last value in a group. #' +#' The function by default returns the last values it sees. It will return the last non-missing +#' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' #' @rdname last #' @name last #' @family agg_funcs #' @export -#' @examples \dontrun{last(df$c)} +#' @examples +#' \dontrun{ +#' last(df$c) +#' last(df$c, TRUE) +#' } setMethod("last", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "last", x@jc) + signature(x = "characterOrColumn"), + function(x, na.rm = FALSE) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + jc <- callJStatic("org.apache.spark.sql.functions", "last", col, na.rm) column(jc) }) @@ -861,6 +1024,28 @@ setMethod("rtrim", column(jc) }) +#' sd +#' +#' Aggregate function: alias for \link{stddev_samp} +#' +#' @rdname sd +#' @name sd +#' @family agg_funcs +#' @seealso \link{stddev_pop}, \link{stddev_samp} +#' @export +#' @examples +#'\dontrun{ +#'stddev(df$c) +#'select(df, stddev(df$age)) +#'agg(df, sd(df$age)) +#'} +setMethod("sd", + signature(x = "Column"), + function(x) { + # In R, sample standard deviation is calculated with the sd() function. + stddev_samp(x) + }) + #' second #' #' Extracts the seconds as an integer from a given date/timestamp/string. @@ -942,19 +1127,19 @@ setMethod("sinh", column(jc) }) -#' size +#' skewness #' -#' Returns length of array or map. +#' Aggregate function: returns the skewness of the values in a group. #' -#' @rdname size -#' @name size -#' @family collection_funcs +#' @rdname skewness +#' @name skewness +#' @family agg_funcs #' @export -#' @examples \dontrun{size(df$c)} -setMethod("size", +#' @examples \dontrun{skewness(df$c)} +setMethod("skewness", signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + jc <- callJStatic("org.apache.spark.sql.functions", "skewness", x@jc) column(jc) }) @@ -974,6 +1159,74 @@ setMethod("soundex", column(jc) }) +#' @rdname sd +#' @name stddev +setMethod("stddev", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev", x@jc) + column(jc) + }) + +#' stddev_pop +#' +#' Aggregate function: returns the population standard deviation of the expression in a group. +#' +#' @rdname stddev_pop +#' @name stddev_pop +#' @family agg_funcs +#' @seealso \link{sd}, \link{stddev_samp} +#' @export +#' @examples \dontrun{stddev_pop(df$c)} +setMethod("stddev_pop", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev_pop", x@jc) + column(jc) + }) + +#' stddev_samp +#' +#' Aggregate function: returns the unbiased sample standard deviation of the expression in a group. +#' +#' @rdname stddev_samp +#' @name stddev_samp +#' @family agg_funcs +#' @seealso \link{stddev_pop}, \link{sd} +#' @export +#' @examples \dontrun{stddev_samp(df$c)} +setMethod("stddev_samp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev_samp", x@jc) + column(jc) + }) + +#' struct +#' +#' Creates a new struct column that composes multiple input columns. +#' +#' @rdname struct +#' @name struct +#' @family normal_funcs +#' @export +#' @examples +#' \dontrun{ +#' struct(df$c, df$d) +#' struct("col1", "col2") +#' } +setMethod("struct", + signature(x = "characterOrColumn"), + function(x, ...) { + if (class(x) == "Column") { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "struct", jcols) + } else { + jc <- callJStatic("org.apache.spark.sql.functions", "struct", x, list(...)) + } + column(jc) + }) + #' sqrt #' #' Computes the square root of the specified float value. @@ -1168,6 +1421,71 @@ setMethod("upper", column(jc) }) +#' var +#' +#' Aggregate function: alias for \link{var_samp}. +#' +#' @rdname var +#' @name var +#' @family agg_funcs +#' @seealso \link{var_pop}, \link{var_samp} +#' @export +#' @examples +#'\dontrun{ +#'variance(df$c) +#'select(df, var_pop(df$age)) +#'agg(df, var(df$age)) +#'} +setMethod("var", + signature(x = "Column"), + function(x) { + # In R, sample variance is calculated with the var() function. + var_samp(x) + }) + +#' @rdname var +#' @name variance +setMethod("variance", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "variance", x@jc) + column(jc) + }) + +#' var_pop +#' +#' Aggregate function: returns the population variance of the values in a group. +#' +#' @rdname var_pop +#' @name var_pop +#' @family agg_funcs +#' @seealso \link{var}, \link{var_samp} +#' @export +#' @examples \dontrun{var_pop(df$c)} +setMethod("var_pop", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "var_pop", x@jc) + column(jc) + }) + +#' var_samp +#' +#' Aggregate function: returns the unbiased variance of the values in a group. +#' +#' @rdname var_samp +#' @name var_samp +#' @family agg_funcs +#' @seealso \link{var_pop}, \link{var} +#' @export +#' @examples \dontrun{var_samp(df$c)} +setMethod("var_samp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "var_samp", x@jc) + column(jc) + }) + #' weekofyear #' #' Extracts the week number as an integer from a given date/timestamp/string. @@ -1337,9 +1655,10 @@ setMethod("pmod", signature(y = "Column"), #' @name approxCountDistinct #' @return the approximate number of distinct items in a group. #' @export +#' @examples \dontrun{approxCountDistinct(df$c, 0.02)} setMethod("approxCountDistinct", signature(x = "Column"), - function(x, rsd = 0.95) { + function(x, rsd = 0.05) { jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) column(jc) }) @@ -1351,14 +1670,16 @@ setMethod("approxCountDistinct", #' @name countDistinct #' @return the number of distinct items in a group. #' @export +#' @examples \dontrun{countDistinct(df$c)} setMethod("countDistinct", signature(x = "Column"), function(x, ...) { - jcol <- lapply(list(...), function (x) { + jcols <- lapply(list(...), function (x) { + stopifnot(class(x) == "Column") x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - jcol) + jcols) column(jc) }) @@ -1371,10 +1692,14 @@ setMethod("countDistinct", #' @rdname concat #' @name concat #' @export +#' @examples \dontrun{concat(df$strings, df$strings2)} setMethod("concat", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) column(jc) }) @@ -1388,11 +1713,15 @@ setMethod("concat", #' @rdname greatest #' @name greatest #' @export +#' @examples \dontrun{greatest(df$c, df$d)} setMethod("greatest", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) column(jc) }) @@ -1400,17 +1729,21 @@ setMethod("greatest", #' least #' #' Returns the least value of the list of column names, skipping null values. -#' This function takes at least 2 parameters. It will return null iff all parameters are null. +#' This function takes at least 2 parameters. It will return null if all parameters are null. #' #' @family normal_funcs #' @rdname least #' @name least #' @export +#' @examples \dontrun{least(df$c, df$d)} setMethod("least", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) column(jc) }) @@ -1419,11 +1752,10 @@ setMethod("least", #' #' Computes the ceiling of the given value. #' -#' @family math_funcs #' @rdname ceil -#' @name ceil -#' @aliases ceil +#' @name ceiling #' @export +#' @examples \dontrun{ceiling(df$c)} setMethod("ceiling", signature(x = "Column"), function(x) { @@ -1434,11 +1766,10 @@ setMethod("ceiling", #' #' Computes the signum of the given value. #' -#' @family math_funcs #' @rdname signum -#' @name signum -#' @aliases signum +#' @name sign #' @export +#' @examples \dontrun{sign(df$c)} setMethod("sign", signature(x = "Column"), function(x) { signum(x) @@ -1448,11 +1779,10 @@ setMethod("sign", signature(x = "Column"), #' #' Aggregate function: returns the number of distinct items in a group. #' -#' @family agg_funcs #' @rdname countDistinct -#' @name countDistinct -#' @aliases countDistinct +#' @name n_distinct #' @export +#' @examples \dontrun{n_distinct(df$c)} setMethod("n_distinct", signature(x = "Column"), function(x, ...) { countDistinct(x, ...) @@ -1462,11 +1792,10 @@ setMethod("n_distinct", signature(x = "Column"), #' #' Aggregate function: returns the number of items in a group. #' -#' @family agg_funcs #' @rdname count -#' @name count -#' @aliases count +#' @name n #' @export +#' @examples \dontrun{n(df$c)} setMethod("n", signature(x = "Column"), function(x) { count(x) @@ -1487,6 +1816,7 @@ setMethod("n", signature(x = "Column"), #' @rdname date_format #' @name date_format #' @export +#' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) @@ -1501,6 +1831,7 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' @rdname from_utc_timestamp #' @name from_utc_timestamp #' @export +#' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) @@ -1519,6 +1850,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' @rdname instr #' @name instr #' @export +#' @examples \dontrun{instr(df$c, 'b')} setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) @@ -1533,13 +1865,18 @@ setMethod("instr", signature(y = "Column", x = "character"), #' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first #' Sunday after 2015-07-27. #' -#' Day of the week parameter is case insensitive, and accepts: +#' Day of the week parameter is case insensitive, and accepts first three or two characters: #' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' #' @family datetime_funcs #' @rdname next_day #' @name next_day #' @export +#' @examples +#'\dontrun{ +#'next_day(df$d, 'Sun') +#'next_day(df$d, 'Sunday') +#'} setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) @@ -1554,6 +1891,7 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' @rdname to_utc_timestamp #' @name to_utc_timestamp #' @export +#' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) @@ -1567,8 +1905,8 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' @name add_months #' @family datetime_funcs #' @rdname add_months -#' @name add_months #' @export +#' @examples \dontrun{add_months(df$d, 1)} setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) @@ -1583,6 +1921,7 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' @rdname date_add #' @name date_add #' @export +#' @examples \dontrun{date_add(df$d, 1)} setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) @@ -1597,6 +1936,7 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname date_sub #' @name date_sub #' @export +#' @examples \dontrun{date_sub(df$d, 1)} setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) @@ -1605,16 +1945,19 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' format_number #' -#' Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, #' and returns the result as a string column. #' -#' If d is 0, the result has no decimal point or fractional part. -#' If d < 0, the result will be null.' +#' If x is 0, the result has no decimal point or fractional part. +#' If x < 0, the result will be null. #' +#' @param y column to format +#' @param x number of decimal place to format to #' @family string_funcs #' @rdname format_number #' @name format_number #' @export +#' @examples \dontrun{format_number(df$n, 4)} setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1634,6 +1977,7 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' @rdname sha2 #' @name sha2 #' @export +#' @examples \dontrun{sha2(df$c, 256)} setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) @@ -1642,13 +1986,14 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' shiftLeft #' -#' Shift the the given value numBits left. If the given value is a long value, this function +#' Shift the given value numBits left. If the given value is a long value, this function #' will return a long value else it will return an integer value. #' #' @family math_funcs #' @rdname shiftLeft #' @name shiftLeft #' @export +#' @examples \dontrun{shiftLeft(df$c, 1)} setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1659,13 +2004,14 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' shiftRight #' -#' Shift the the given value numBits right. If the given value is a long value, it will return +#' Shift the given value numBits right. If the given value is a long value, it will return #' a long value else it will return an integer value. #' #' @family math_funcs #' @rdname shiftRight #' @name shiftRight #' @export +#' @examples \dontrun{shiftRight(df$c, 1)} setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1676,13 +2022,14 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' shiftRightUnsigned #' -#' Unsigned shift the the given value numBits right. If the given value is a long value, +#' Unsigned shift the given value numBits right. If the given value is a long value, #' it will return a long value else it will return an integer value. #' #' @family math_funcs #' @rdname shiftRightUnsigned #' @name shiftRightUnsigned #' @export +#' @examples \dontrun{shiftRightUnsigned(df$c, 1)} setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1700,6 +2047,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @rdname concat_ws #' @name concat_ws #' @export +#' @examples \dontrun{concat_ws('-', df$s, df$d)} setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) @@ -1715,6 +2063,7 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @rdname conv #' @name conv #' @export +#' @examples \dontrun{conv(df$n, 2, 16)} setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { fromBase <- as.integer(fromBase) @@ -1734,6 +2083,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' @rdname expr #' @name expr #' @export +#' @examples \dontrun{expr('length(name)')} setMethod("expr", signature(x = "character"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) @@ -1748,6 +2098,7 @@ setMethod("expr", signature(x = "character"), #' @rdname format_string #' @name format_string #' @export +#' @examples \dontrun{format_string('%d %s', df$a, df$b)} setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { jcols <- lapply(list(x, ...), function(arg) { arg@jc }) @@ -1767,6 +2118,11 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname from_unixtime #' @name from_unixtime #' @export +#' @examples +#'\dontrun{ +#'from_unixtime(df$t) +#'from_unixtime(df$t, 'yyyy/MM/dd HH') +#'} setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1775,6 +2131,69 @@ setMethod("from_unixtime", signature(x = "Column"), column(jc) }) +#' window +#' +#' Bucketize rows into one or more time windows given a timestamp specifying column. Window +#' starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window +#' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in +#' the order of months are not supported. +#' +#' The time column must be of TimestampType. +#' +#' Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid +#' interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. +#' If the `slideDuration` is not provided, the windows will be tumbling windows. +#' +#' The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start +#' window intervals. For example, in order to have hourly tumbling windows that start 15 minutes +#' past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. +#' +#' The output column will be a struct called 'window' by default with the nested columns 'start' +#' and 'end'. +#' +#' @family datetime_funcs +#' @rdname window +#' @name window +#' @export +#' @examples +#'\dontrun{ +#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, +#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... +#' window(df$time, "1 minute", "15 seconds", "10 seconds") +#' +#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, +#' # 09:01:15-09:02:15... +#' window(df$time, "1 minute", startTime = "15 seconds") +#' +#' # Thirty second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... +#' window(df$time, "30 seconds", "10 seconds") +#'} +setMethod("window", signature(x = "Column"), + function(x, windowDuration, slideDuration = NULL, startTime = NULL) { + stopifnot(is.character(windowDuration)) + if (!is.null(slideDuration) && !is.null(startTime)) { + stopifnot(is.character(slideDuration) && is.character(startTime)) + jc <- callJStatic("org.apache.spark.sql.functions", + "window", + x@jc, windowDuration, slideDuration, startTime) + } else if (!is.null(slideDuration)) { + stopifnot(is.character(slideDuration)) + jc <- callJStatic("org.apache.spark.sql.functions", + "window", + x@jc, windowDuration, slideDuration) + } else if (!is.null(startTime)) { + stopifnot(is.character(startTime)) + jc <- callJStatic("org.apache.spark.sql.functions", + "window", + x@jc, windowDuration, windowDuration, startTime) + } else { + jc <- callJStatic("org.apache.spark.sql.functions", + "window", + x@jc, windowDuration) + } + column(jc) + }) + #' locate #' #' Locate the position of the first occurrence of substr. @@ -1785,6 +2204,7 @@ setMethod("from_unixtime", signature(x = "Column"), #' @rdname locate #' @name locate #' @export +#' @examples \dontrun{locate('b', df$c, 1)} setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 0) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1801,6 +2221,7 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @rdname lpad #' @name lpad #' @export +#' @examples \dontrun{lpad(df$c, 6, '#')} setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1817,12 +2238,13 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname rand #' @name rand #' @export +#' @examples \dontrun{rand()} setMethod("rand", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "rand") column(jc) }) -#' @family normal_funcs + #' @rdname rand #' @name rand #' @export @@ -1840,12 +2262,13 @@ setMethod("rand", signature(seed = "numeric"), #' @rdname randn #' @name randn #' @export +#' @examples \dontrun{randn()} setMethod("randn", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "randn") column(jc) }) -#' @family normal_funcs + #' @rdname randn #' @name randn #' @export @@ -1863,6 +2286,7 @@ setMethod("randn", signature(seed = "numeric"), #' @rdname regexp_extract #' @name regexp_extract #' @export +#' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), function(x, pattern, idx) { @@ -1880,6 +2304,7 @@ setMethod("regexp_extract", #' @rdname regexp_replace #' @name regexp_replace #' @export +#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), function(x, pattern, replacement) { @@ -1897,6 +2322,7 @@ setMethod("regexp_replace", #' @rdname rpad #' @name rpad #' @export +#' @examples \dontrun{rpad(df$c, 6, '#')} setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1910,12 +2336,17 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' Returns the substring from string str before count occurrences of the delimiter delim. #' If count is positive, everything the left of the final delimiter (counting from left) is #' returned. If count is negative, every to the right of the final delimiter (counting from the -#' right) is returned. substring <- index performs a case-sensitive match when searching for delim. +#' right) is returned. substring_index performs a case-sensitive match when searching for delim. #' #' @family string_funcs #' @rdname substring_index #' @name substring_index #' @export +#' @examples +#'\dontrun{ +#'substring_index(df$c, '.', 2) +#'substring_index(df$c, '.', -1) +#'} setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), function(x, delim, count) { @@ -1936,6 +2367,7 @@ setMethod("substring_index", #' @rdname translate #' @name translate #' @export +#' @examples \dontrun{translate(df$c, 'rnlt', '123')} setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), function(x, matchingString, replaceString) { @@ -1952,12 +2384,18 @@ setMethod("translate", #' @rdname unix_timestamp #' @name unix_timestamp #' @export +#' @examples +#'\dontrun{ +#'unix_timestamp() +#'unix_timestamp(df$t) +#'unix_timestamp(df$t, 'yyyy-MM-dd HH') +#'} setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") column(jc) }) -#' @family datetime_funcs + #' @rdname unix_timestamp #' @name unix_timestamp #' @export @@ -1966,7 +2404,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) column(jc) }) -#' @family datetime_funcs + #' @rdname unix_timestamp #' @name unix_timestamp #' @export @@ -1983,11 +2421,13 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @family normal_funcs #' @rdname when #' @name when +#' @seealso \link{ifelse} #' @export +#' @examples \dontrun{when(df$age == 2, df$age + 1)} setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc - value <- ifelse(class(value) == "Column", value@jc, value) + value <- if (class(value) == "Column") { value@jc } else { value } jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value) column(jc) }) @@ -2000,13 +2440,18 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @family normal_funcs #' @rdname ifelse #' @name ifelse +#' @seealso \link{when} #' @export +#' @examples \dontrun{ +#' ifelse(df$a > 1 & df$b > 2, 0, 1) +#' ifelse(df$a > 1, df$a, 1) +#' } setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { test <- test@jc - yes <- ifelse(class(yes) == "Column", yes@jc, yes) - no <- ifelse(class(no) == "Column", no@jc, no) + yes <- if (class(yes) == "Column") { yes@jc } else { yes } + no <- if (class(no) == "Column") { no@jc } else { no } jc <- callJMethod(callJStatic("org.apache.spark.sql.functions", "when", test, yes), @@ -2016,47 +2461,47 @@ setMethod("ifelse", ###################### Window functions###################### -#' cumeDist +#' cume_dist #' #' Window function: returns the cumulative distribution of values within a window partition, #' i.e. the fraction of rows that are below the current row. -#' +#' #' N = total number of rows in the partition -#' cumeDist(x) = number of values before (and including) x / N -#' +#' cume_dist(x) = number of values before (and including) x / N +#' #' This is equivalent to the CUME_DIST function in SQL. #' -#' @rdname cumeDist -#' @name cumeDist +#' @rdname cume_dist +#' @name cume_dist #' @family window_funcs #' @export -#' @examples \dontrun{cumeDist()} -setMethod("cumeDist", +#' @examples \dontrun{cume_dist()} +setMethod("cume_dist", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "cumeDist") + jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist") column(jc) }) -#' denseRank -#' +#' dense_rank +#' #' Window function: returns the rank of rows within a window partition, without any gaps. -#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking -#' sequence when there are ties. That is, if you were ranking a competition using denseRank +#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking +#' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. -#' +#' #' This is equivalent to the DENSE_RANK function in SQL. #' -#' @rdname denseRank -#' @name denseRank +#' @rdname dense_rank +#' @name dense_rank #' @family window_funcs #' @export -#' @examples \dontrun{denseRank()} -setMethod("denseRank", +#' @examples \dontrun{dense_rank()} +setMethod("dense_rank", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "denseRank") + jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank") column(jc) }) @@ -2065,7 +2510,7 @@ setMethod("denseRank", #' Window function: returns the value that is `offset` rows before the current row, and #' `defaultValue` if there is less than `offset` rows before the current row. For example, #' an `offset` of one will return the previous row at any given point in the window partition. -#' +#' #' This is equivalent to the LAG function in SQL. #' #' @rdname lag @@ -2074,7 +2519,7 @@ setMethod("denseRank", #' @export #' @examples \dontrun{lag(df$c)} setMethod("lag", - signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + signature(x = "characterOrColumn"), function(x, offset, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc @@ -2092,7 +2537,7 @@ setMethod("lag", #' Window function: returns the value that is `offset` rows after the current row, and #' `null` if there is less than `offset` rows after the current row. For example, #' an `offset` of one will return the next row at any given point in the window partition. -#' +#' #' This is equivalent to the LEAD function in SQL. #' #' @rdname lead @@ -2119,7 +2564,7 @@ setMethod("lead", #' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window #' partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. -#' +#' #' This is equivalent to the NTILE function in SQL. #' #' @rdname ntile @@ -2134,37 +2579,37 @@ setMethod("ntile", column(jc) }) -#' percentRank +#' percent_rank #' #' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. -#' +#' #' This is computed by: -#' +#' #' (rank of row in its partition - 1) / (number of rows in the partition - 1) #' #' This is equivalent to the PERCENT_RANK function in SQL. #' -#' @rdname percentRank -#' @name percentRank +#' @rdname percent_rank +#' @name percent_rank #' @family window_funcs #' @export -#' @examples \dontrun{percentRank()} -setMethod("percentRank", +#' @examples \dontrun{percent_rank()} +setMethod("percent_rank", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "percentRank") + jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank") column(jc) }) #' rank #' #' Window function: returns the rank of rows within a window partition. -#' +#' #' The difference between rank and denseRank is that denseRank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using denseRank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. -#' +#' #' This is equivalent to the RANK function in SQL. #' #' @rdname rank @@ -2186,20 +2631,97 @@ setMethod("rank", base::rank(x, ...) }) -#' rowNumber +#' row_number #' #' Window function: returns a sequential number starting at 1 within a window partition. -#' +#' #' This is equivalent to the ROW_NUMBER function in SQL. #' -#' @rdname rowNumber -#' @name rowNumber +#' @rdname row_number +#' @name row_number #' @family window_funcs #' @export -#' @examples \dontrun{rowNumber()} -setMethod("rowNumber", +#' @examples \dontrun{row_number()} +setMethod("row_number", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") + jc <- callJStatic("org.apache.spark.sql.functions", "row_number") + column(jc) + }) + +###################### Collection functions###################### + +#' array_contains +#' +#' Returns true if the array contain the value. +#' +#' @param x A Column +#' @param value A value to be checked if contained in the column +#' @rdname array_contains +#' @name array_contains +#' @family collection_funcs +#' @export +#' @examples \dontrun{array_contains(df$c, 1)} +setMethod("array_contains", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_contains", x@jc, value) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' sort_array +#' +#' Sorts the input array for the given column in ascending order, +#' according to the natural ordering of the array elements. +#' +#' @param x A Column to sort +#' @param asc A logical flag indicating the sorting order. +#' TRUE, sorting is in ascending order. +#' FALSE, sorting is in descending order. +#' @rdname sort_array +#' @name sort_array +#' @family collection_funcs +#' @export +#' @examples +#' \dontrun{ +#' sort_array(df$c) +#' sort_array(df$c, FALSE) +#' } +setMethod("sort_array", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0b35340e48e42..ecdeea5ec4912 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -67,6 +67,13 @@ setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) # @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) +# @rdname statfunctions +# @export +setGeneric("approxQuantile", + function(x, col, probabilities, relativeError) { + standardGeneric("approxQuantile") + }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) @@ -77,7 +84,7 @@ setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) # @rdname first # @export -setGeneric("first", function(x) { standardGeneric("first") }) +setGeneric("first", function(x, ...) { standardGeneric("first") }) # @rdname flatMap # @export @@ -88,12 +95,8 @@ setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) # @export setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) -# @rdname foreach -# @export setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) -# @rdname foreach -# @export setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) # The jrdd accessor function. @@ -107,27 +110,17 @@ setGeneric("glom", function(x) { standardGeneric("glom") }) # @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) -# @rdname lapplyPartition -# @export setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) -# @rdname lapplyPartitionsWithIndex -# @export setGeneric("lapplyPartitionsWithIndex", function(X, FUN) { standardGeneric("lapplyPartitionsWithIndex") }) -# @rdname lapply -# @export setGeneric("map", function(X, FUN) { standardGeneric("map") }) -# @rdname lapplyPartition -# @export setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) -# @rdname lapplyPartitionsWithIndex -# @export setGeneric("mapPartitionsWithIndex", function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) @@ -147,7 +140,11 @@ setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @export setGeneric("name", function(x) { standardGeneric("name") }) -# @rdname numPartitions +# @rdname getNumPartitions +# @export +setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) + +# @rdname getNumPartitions # @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) @@ -388,7 +385,6 @@ setGeneric("subtractByKey", setGeneric("value", function(bcast) { standardGeneric("value") }) - #################### DataFrame Methods ######################## #' @rdname agg @@ -399,22 +395,65 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) +#' @rdname as.data.frame +#' @export +setGeneric("as.data.frame") + +#' @rdname attach +#' @export +setGeneric("attach") + +#' @rdname columns +#' @export +setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) + +#' @rdname columns +#' @export +setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) + #' @rdname schema #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) #' @rdname statfunctions #' @export -setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) +setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname statfunctions #' @export -setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) +setGeneric("corr", function(x, ...) {standardGeneric("corr") }) -#' @rdname describe +#' @rdname statfunctions +#' @export +setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) + +#' @rdname statfunctions +#' @export +setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) + +#' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname drop +#' @export +setGeneric("drop", function(x, ...) { standardGeneric("drop") }) + +#' @rdname dropduplicates +#' @export +setGeneric("dropDuplicates", + function(x, colNames = columns(x)) { + standardGeneric("dropDuplicates") + }) + #' @rdname nafunctions #' @export setGeneric("dropna", @@ -473,11 +512,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) -#' rdname merge +#' @rdname merge #' @export setGeneric("merge") -#' @rdname withColumn +#' @rdname mutate #' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) @@ -489,7 +528,7 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) #' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) @@ -513,27 +552,46 @@ setGeneric("sample_frac", #' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) -#' @rdname saveAsParquetFile -#' @export -setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) - #' @rdname saveAsTable #' @export -setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { +setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", ...) { standardGeneric("saveAsTable") }) -#' @rdname withColumn +#' @export +setGeneric("str") + +#' @rdname mutate #' @export setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) +setGeneric("write.df", function(df, path, source = NULL, mode = "error", ...) { + standardGeneric("write.df") +}) #' @rdname write.df #' @export -setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) +setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) { + standardGeneric("saveDF") +}) + +#' @rdname write.json +#' @export +setGeneric("write.json", function(x, path) { standardGeneric("write.json") }) + +#' @rdname write.parquet +#' @export +setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") }) + +#' @rdname write.parquet +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + +#' @rdname write.text +#' @export +setGeneric("write.text", function(x, path) { standardGeneric("write.text") }) #' @rdname schema #' @export @@ -549,29 +607,25 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @rdname showDF #' @export -setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) +setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) # @rdname subset # @export -setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) +setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname agg #' @export -setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) +setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary #' @export -setGeneric("summary", function(x, ...) { standardGeneric("summary") }) +setGeneric("summary", function(object, ...) { standardGeneric("summary") }) -# @rdname tojson -# @export setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) -#' @rdname DataFrame -#' @export setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) -#' @rdname unionAll +#' @rdname rbind #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) @@ -579,15 +633,22 @@ setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) #' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) +#' @rdname with +#' @export +setGeneric("with") + #' @rdname withColumn #' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname write.df +#' @export +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) ###################### Column Methods ########################## @@ -623,6 +684,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) +#' @rdname column +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -662,6 +727,10 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @export setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) +#' @rdname array_contains +#' @export +setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) + #' @rdname ascii #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -714,9 +783,13 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname cumeDist +#' @rdname hash #' @export -setGeneric("cumeDist", function(x) { standardGeneric("cumeDist") }) +setGeneric("hash", function(x, ...) { standardGeneric("hash") }) + +#' @rdname cume_dist +#' @export +setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) #' @rdname datediff #' @export @@ -742,9 +815,17 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname denseRank +#' @rdname decode +#' @export +setGeneric("decode", function(x, charset) { standardGeneric("decode") }) + +#' @rdname dense_rank #' @export -setGeneric("denseRank", function(x) { standardGeneric("denseRank") }) +setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) + +#' @rdname encode +#' @export +setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname explode #' @export @@ -794,17 +875,21 @@ setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname isNaN +#' @rdname is.nan #' @export -setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) +setGeneric("isnan", function(x) { standardGeneric("isnan") }) + +#' @rdname kurtosis +#' @export +setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname lag #' @export -setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) +setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last #' @export -setGeneric("last", function(x) { standardGeneric("last") }) +setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @rdname last_day #' @export @@ -882,9 +967,9 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @rdname percentRank +#' @rdname percent_rank #' @export -setGeneric("percentRank", function(x) { standardGeneric("percentRank") }) +setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") }) #' @rdname pmod #' @export @@ -923,9 +1008,9 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @export setGeneric("rint", function(x, ...) { standardGeneric("rint") }) -#' @rdname rowNumber +#' @rdname row_number #' @export -setGeneric("rowNumber", function(x) { standardGeneric("rowNumber") }) +setGeneric("row_number", function(x) { standardGeneric("row_number") }) #' @rdname rpad #' @export @@ -935,6 +1020,10 @@ setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @export setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) +#' @rdname sd +#' @export +setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) + #' @rdname second #' @export setGeneric("second", function(x) { standardGeneric("second") }) @@ -967,10 +1056,34 @@ setGeneric("signum", function(x) { standardGeneric("signum") }) #' @export setGeneric("size", function(x) { standardGeneric("size") }) +#' @rdname skewness +#' @export +setGeneric("skewness", function(x) { standardGeneric("skewness") }) + +#' @rdname sort_array +#' @export +setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) +#' @rdname sd +#' @export +setGeneric("stddev", function(x) { standardGeneric("stddev") }) + +#' @rdname stddev_pop +#' @export +setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) + +#' @rdname stddev_samp +#' @export +setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) + +#' @rdname struct +#' @export +setGeneric("struct", function(x, ...) { standardGeneric("struct") }) + #' @rdname substring_index #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) @@ -1019,27 +1132,58 @@ setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timesta #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname var +#' @export +setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) + +#' @rdname var +#' @export +setGeneric("variance", function(x) { standardGeneric("variance") }) + +#' @rdname var_pop +#' @export +setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) + +#' @rdname var_samp +#' @export +setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) + #' @rdname weekofyear #' @export setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) +#' @rdname window +#' @export +setGeneric("window", function(x, ...) { standardGeneric("window") }) + #' @rdname year #' @export setGeneric("year", function(x) { standardGeneric("year") }) - #' @rdname glm #' @export setGeneric("glm") +#' @rdname predict +#' @export +setGeneric("predict", function(object, ...) { standardGeneric("predict") }) + #' @rdname rbind #' @export setGeneric("rbind", signature = "...") -#' @rdname as.data.frame +#' @rdname kmeans #' @export -setGeneric("as.data.frame") +setGeneric("kmeans") -#' @rdname attach +#' @rdname fitted #' @export -setGeneric("attach") +setGeneric("fitted") + +#' @rdname naiveBayes +#' @export +setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") }) + +#' @rdname survreg +#' @export +setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 4cab1a69f601a..23b49aebda05f 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -68,7 +68,7 @@ setMethod("count", dataFrame(callJMethod(x@sgd, "count")) }) -#' Agg +#' summarize #' #' Aggregates on the entire DataFrame without groups. #' The resulting DataFrame will also contain the grouping columns. @@ -78,11 +78,14 @@ setMethod("count", #' #' @param x a GroupedData #' @return a DataFrame -#' @rdname agg +#' @rdname summarize +#' @name agg +#' @family agg_funcs #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' -#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df4 <- summarize(df, ageSum = max(df$age)) #' } setMethod("agg", signature(x = "GroupedData"), @@ -109,16 +112,19 @@ setMethod("agg", dataFrame(sdf) }) -#' @rdname agg -#' @aliases agg +#' @rdname summarize +#' @name summarize setMethod("summarize", signature(x = "GroupedData"), function(x, ...) { agg(x, ...) }) -# sum/mean/avg/min/max -methods <- c("sum", "mean", "avg", "min", "max") +# Aggregate Functions by name +methods <- c("avg", "max", "mean", "min", "sum") + +# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", +# "variance", "var_samp", "var_pop" createMethod <- function(name) { setMethod(name, diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 60bfadb8e7503..31bca16580451 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -17,85 +17,350 @@ # mllib.R: Provides methods for MLlib integration -#' @title S4 class that represents a PipelineModel -#' @param model A Java object reference to the backing Scala PipelineModel +#' @title S4 class that represents a generalized linear model +#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper #' @export -setClass("PipelineModel", representation(model = "jobj")) +setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) + +#' @title S4 class that represents a NaiveBayesModel +#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper +#' @export +setClass("NaiveBayesModel", representation(jobj = "jobj")) + +#' @title S4 class that represents a AFTSurvivalRegressionModel +#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper +#' @export +setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) + +#' @title S4 class that represents a KMeansModel +#' @param jobj a Java object reference to the backing Scala KMeansModel +#' @export +setClass("KMeansModel", representation(jobj = "jobj")) #' Fits a generalized linear model #' -#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' Fits a generalized linear model, similarly to R's glm(). #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param data DataFrame for training -#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. -#' @param lambda Regularization parameter -#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) -#' @return a fitted MLlib model +#' @param data DataFrame for training. +#' @param family A description of the error distribution and link function to be used in the model. +#' This can be a character string naming a family function, a family function or +#' the result of a call to a family function. Refer R family at +#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' @param epsilon Positive convergence tolerance of iterations. +#' @param maxit Integer giving the maximal number of IRLS iterations. +#' @return a fitted generalized linear model #' @rdname glm #' @export #' @examples -#'\dontrun{ +#' \dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' data(iris) #' df <- createDataFrame(sqlContext, iris) #' model <- glm(Sepal_Length ~ Sepal_Width, df, family="gaussian") #' summary(model) -#'} +#' } setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, - standardize = TRUE, solver = "auto") { - family <- match.arg(family) - model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitRModelFormula", deparse(formula), data@sdf, family, lambda, - alpha, standardize, solver) - return(new("PipelineModel", model = model)) + function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) { + if (is.character(family)) { + family <- get(family, mode = "function", envir = parent.frame()) + } + if (is.function(family)) { + family <- family() + } + if (is.null(family$family)) { + print(family) + stop("'family' not recognized") + } + + formula <- paste(deparse(formula), collapse = "") + + jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", + "fit", formula, data@sdf, family$family, family$link, + epsilon, as.integer(maxit)) + return(new("GeneralizedLinearRegressionModel", jobj = jobj)) }) -#' Make predictions from a model +#' Get the summary of a generalized linear model #' -#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param object A fitted MLlib model +#' @param object A fitted generalized linear model +#' @return coefficients the model's coefficients, intercept +#' @rdname summary +#' @export +#' @examples +#' \dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#' } +setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) + +#' Make predictions from a generalized linear model +#' +#' Makes predictions from a generalized linear model produced by glm(), similarly to R's predict(). +#' +#' @param object A fitted generalized linear model #' @param newData DataFrame for testing -#' @return DataFrame containing predicted values +#' @return DataFrame containing predicted labels in a column named "prediction" #' @rdname predict #' @export #' @examples -#'\dontrun{ +#' \dontrun{ #' model <- glm(y ~ x, trainingData) #' predicted <- predict(model, testData) #' showDF(predicted) +#' } +setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) + }) + +#' Make predictions from a naive Bayes model +#' +#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict. +#' +#' @param object A fitted naive Bayes model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted labels in a column named "prediction" +#' @rdname predict +#' @export +#' @examples +#' \dontrun{ +#' model <- naiveBayes(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) #'} -setMethod("predict", signature(object = "PipelineModel"), +setMethod("predict", signature(object = "NaiveBayesModel"), function(object, newData) { - return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) -#' Get the summary of a model +#' Get the summary of a naive Bayes model #' -#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary(). #' -#' @param x A fitted MLlib model -#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See -#' summary.glm for more information. +#' @param object A fitted MLlib model +#' @return a list containing 'apriori', the label distribution, and 'tables', conditional +# probabilities given the target label #' @rdname summary #' @export #' @examples -#'\dontrun{ -#' model <- glm(y ~ x, trainingData) +#' \dontrun{ +#' model <- naiveBayes(y ~ x, trainingData) #' summary(model) #'} -setMethod("summary", signature(x = "PipelineModel"), - function(x, ...) { - features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", x@model) - coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", x@model) +setMethod("summary", signature(object = "NaiveBayesModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "features") + labels <- callJMethod(jobj, "labels") + apriori <- callJMethod(jobj, "apriori") + apriori <- t(as.matrix(unlist(apriori))) + colnames(apriori) <- unlist(labels) + tables <- callJMethod(jobj, "tables") + tables <- matrix(tables, nrow = length(labels)) + rownames(tables) <- unlist(labels) + colnames(tables) <- unlist(features) + return(list(apriori = apriori, tables = tables)) + }) + +#' Fit a k-means model +#' +#' Fit a k-means model, similarly to R's kmeans(). +#' +#' @param x DataFrame for training +#' @param centers Number of centers +#' @param iter.max Maximum iteration number +#' @param algorithm Algorithm choosen to fit the model +#' @return A fitted k-means model +#' @rdname kmeans +#' @export +#' @examples +#' \dontrun{ +#' model <- kmeans(x, centers = 2, algorithm="random") +#' } +setMethod("kmeans", signature(x = "DataFrame"), + function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) { + columnNames <- as.array(colnames(x)) + algorithm <- match.arg(algorithm) + jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf, + centers, iter.max, algorithm, columnNames) + return(new("KMeansModel", jobj = jobj)) + }) + +#' Get fitted result from a k-means model +#' +#' Get fitted result from a k-means model, similarly to R's fitted(). +#' +#' @param object A fitted k-means model +#' @return DataFrame containing fitted values +#' @rdname fitted +#' @export +#' @examples +#' \dontrun{ +#' model <- kmeans(trainingData, 2) +#' fitted.model <- fitted(model) +#' showDF(fitted.model) +#'} +setMethod("fitted", signature(object = "KMeansModel"), + function(object, method = c("centers", "classes"), ...) { + method <- match.arg(method) + return(dataFrame(callJMethod(object@jobj, "fitted", method))) + }) + +#' Get the summary of a k-means model +#' +#' Returns the summary of a k-means model produced by kmeans(), +#' similarly to R's summary(). +#' +#' @param object a fitted k-means model +#' @return the model's coefficients, size and cluster +#' @rdname summary +#' @export +#' @examples +#' \dontrun{ +#' model <- kmeans(trainingData, 2) +#' summary(model) +#' } +setMethod("summary", signature(object = "KMeansModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "features") + coefficients <- callJMethod(jobj, "coefficients") + cluster <- callJMethod(jobj, "cluster") + k <- callJMethod(jobj, "k") + size <- callJMethod(jobj, "size") + coefficients <- t(matrix(coefficients, ncol = k)) + colnames(coefficients) <- unlist(features) + rownames(coefficients) <- 1:k + return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster))) + }) + +#' Make predictions from a k-means model +#' +#' Make predictions from a model produced by kmeans(). +#' +#' @param object A fitted k-means model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted labels in a column named "prediction" +#' @rdname predict +#' @export +#' @examples +#' \dontrun{ +#' model <- kmeans(trainingData, 2) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#' } +setMethod("predict", signature(object = "KMeansModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) + }) + +#' Fit a Bernoulli naive Bayes model +#' +#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only +#' categorical features are supported. The input should be a DataFrame of observations instead of a +#' contingency table. +#' +#' @param object A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param data DataFrame for training +#' @param laplace Smoothing parameter +#' @return a fitted naive Bayes model +#' @rdname naiveBayes +#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/} +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(sqlContext, infert) +#' model <- naiveBayes(education ~ ., df, laplace = 0) +#'} +setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"), + function(formula, data, laplace = 0, ...) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", + formula, data@sdf, laplace) + return(new("NaiveBayesModel", jobj = jobj)) + }) + +#' Fit an accelerated failure time (AFT) survival regression model. +#' +#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg(). +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' Note that operator '.' is not supported currently. +#' @param data DataFrame for training. +#' @return a fitted AFT survival regression model +#' @rdname survreg +#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/} +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(sqlContext, ovarian) +#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df) +#' } +setMethod("survreg", signature(formula = "formula", data = "DataFrame"), + function(formula, data, ...) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", + "fit", formula, data@sdf) + return(new("AFTSurvivalRegressionModel", jobj = jobj)) + }) + +#' Get the summary of an AFT survival regression model +#' +#' Returns the summary of an AFT survival regression model produced by survreg(), +#' similarly to R's summary(). +#' +#' @param object a fitted AFT survival regression model +#' @return coefficients the model's coefficients, intercept and log(scale). +#' @rdname summary +#' @export +#' @examples +#' \dontrun{ +#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData) +#' summary(model) +#' } +setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") + colnames(coefficients) <- c("Value") rownames(coefficients) <- unlist(features) return(list(coefficients = coefficients)) }) + +#' Make predictions from an AFT survival regression model +#' +#' Make predictions from a model produced by survreg(), similarly to R package survival's predict. +#' +#' @param object A fitted AFT survival regression model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted labels in a column named "prediction" +#' @rdname predict +#' @export +#' @examples +#' \dontrun{ +#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#' } +setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) + }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 199c3fd6ab1b2..4075ef4377acf 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -21,23 +21,26 @@ NULL ############ Actions and Transformations ############ -# Look up elements of a key in an RDD -# -# @description -# \code{lookup} returns a list of values in this RDD for key key. -# -# @param x The RDD to collect -# @param key The key to look up for -# @return a list of values in this RDD for key key -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(c(1, 1), c(2, 2), c(1, 3)) -# rdd <- parallelize(sc, pairs) -# lookup(rdd, 1) # list(1, 3) -#} -# @rdname lookup -# @aliases lookup,RDD-method +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +# nolint end +#' @rdname lookup +#' @aliases lookup,RDD-method +#' @noRd setMethod("lookup", signature(x = "RDD", key = "ANY"), function(x, key) { @@ -49,21 +52,24 @@ setMethod("lookup", collect(valsRDD) }) -# Count the number of elements for each key, and return the result to the -# master as lists of (key, count) pairs. -# -# Same as countByKey in Spark. -# -# @param x The RDD to count keys. -# @return list of (key, count) pairs, where count is number of each key in rdd. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) -# countByKey(rdd) # ("a", 2L), ("b", 1L) -#} -# @rdname countByKey -# @aliases countByKey,RDD-method +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +# nolint end +#' @rdname countByKey +#' @aliases countByKey,RDD-method +#' @noRd setMethod("countByKey", signature(x = "RDD"), function(x) { @@ -71,17 +77,20 @@ setMethod("countByKey", countByValue(keys) }) -# Return an RDD with the keys of each tuple. -# -# @param x The RDD from which the keys of each tuple is returned. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -# collect(keys(rdd)) # list(1, 3) -#} -# @rdname keys -# @aliases keys,RDD +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +# nolint end +#' @rdname keys +#' @aliases keys,RDD +#' @noRd setMethod("keys", signature(x = "RDD"), function(x) { @@ -91,17 +100,20 @@ setMethod("keys", lapply(x, func) }) -# Return an RDD with the values of each tuple. -# -# @param x The RDD from which the values of each tuple is returned. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -# collect(values(rdd)) # list(2, 4) -#} -# @rdname values -# @aliases values,RDD +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +# nolint end +#' @rdname values +#' @aliases values,RDD +#' @noRd setMethod("values", signature(x = "RDD"), function(x) { @@ -111,23 +123,24 @@ setMethod("values", lapply(x, func) }) -# Applies a function to all values of the elements, without modifying the keys. -# -# The same as `mapValues()' in Spark. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on the value of each element. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# makePairs <- lapply(rdd, function(x) { list(x, x) }) -# collect(mapValues(makePairs, function(x) { x * 2) }) -# Output: list(list(1,2), list(2,4), list(3,6), ...) -#} -# @rdname mapValues -# @aliases mapValues,RDD,function-method +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +#' @noRd setMethod("mapValues", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -137,23 +150,24 @@ setMethod("mapValues", lapply(X, func) }) -# Pass each value in the key-value pair RDD through a flatMap function without -# changing the keys; this also retains the original RDD's partitioning. -# -# The same as 'flatMapValues()' in Spark. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on the value of each element. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) -# collect(flatMapValues(rdd, function(x) { x })) -# Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) -#} -# @rdname flatMapValues -# @aliases flatMapValues,RDD,function-method +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +#' @noRd setMethod("flatMapValues", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -165,38 +179,34 @@ setMethod("flatMapValues", ############ Shuffle Functions ############ -# Partition an RDD by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# For each element of this RDD, the partitioner is used to compute a hash -# function and the RDD is partitioned using this hash value. -# -# @param x The RDD to partition. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param numPartitions Number of partitions to create. -# @param ... Other optional arguments to partitionBy. -# -# @param partitionFunc The partition function to use. Uses a default hashCode -# function if not provided -# @return An RDD partitioned using the specified partitioner. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- partitionBy(rdd, 2L) -# collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) -#} -# @rdname partitionBy -# @aliases partitionBy,RDD,integer-method +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +#' @noRd setMethod("partitionBy", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, partitionFunc = hashCode) { - - #if (missing(partitionFunc)) { - # partitionFunc <- hashCode - #} - partitionFunc <- cleanClosure(partitionFunc) serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) @@ -233,27 +243,28 @@ setMethod("partitionBy", RDD(r, serializedMode = "byte") }) -# Group values by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and group values for each key in the RDD into a single sequence. -# -# @param x The RDD to group. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, list(V)) -# @seealso reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- groupByKey(rdd, 2L) -# grouped <- collect(parts) -# grouped[[1]] # Should be a list(1, list(2, 4)) -#} -# @rdname groupByKey -# @aliases groupByKey,RDD,integer-method +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +#' @noRd setMethod("groupByKey", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { @@ -291,28 +302,29 @@ setMethod("groupByKey", lapplyPartition(shuffled, groupVals) }) -# Merge values by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and merges the values for each key using an associative reduce function. -# -# @param x The RDD to reduce by key. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param combineFunc The associative reduce function to use. -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, V') where V' is the merged -# value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- reduceByKey(rdd, "+", 2L) -# reduced <- collect(parts) -# reduced[[1]] # Should be a list(1, 6) -#} -# @rdname reduceByKey -# @aliases reduceByKey,RDD,integer-method +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative and commutative reduce function. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative and commutative reduce function to use. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +#' @noRd setMethod("reduceByKey", signature(x = "RDD", combineFunc = "ANY", numPartitions = "numeric"), function(x, combineFunc, numPartitions) { @@ -332,27 +344,30 @@ setMethod("reduceByKey", lapplyPartition(shuffled, reduceVals) }) -# Merge values by key locally -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and merges the values for each key using an associative reduce function, but return the -# results immediately to the driver as an R list. -# -# @param x The RDD to reduce by key. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param combineFunc The associative reduce function to use. -# @return A list of elements of type list(K, V') where V' is the merged value for each key -# @seealso reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# reduced <- reduceByKeyLocally(rdd, "+") -# reduced # list(list(1, 6), list(1.1, 3)) -#} -# @rdname reduceByKeyLocally -# @aliases reduceByKeyLocally,RDD,integer-method +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative and commutative reduce function, but +#' return the results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative and commutative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +# nolint end +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +#' @noRd setMethod("reduceByKeyLocally", signature(x = "RDD", combineFunc = "ANY"), function(x, combineFunc) { @@ -384,41 +399,42 @@ setMethod("reduceByKeyLocally", convertEnvsToList(merged[[1]], merged[[2]]) }) -# Combine values by key -# -# Generic function to combine the elements for each key using a custom set of -# aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], -# for a "combined type" C. Note that V and C can be different -- for example, one -# might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). - -# Users provide three functions: -# \itemize{ -# \item createCombiner, which turns a V into a C (e.g., creates a one-element list) -# \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - -# \item mergeCombiners, to combine two C's into a single one (e.g., concatentates -# two lists). -# } -# -# @param x The RDD to combine. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param createCombiner Create a combiner (C) given a value (V) -# @param mergeValue Merge the given value (V) with an existing combiner (C) -# @param mergeCombiners Merge two combiners and return a new combiner -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, C) where C is the combined type -# -# @seealso groupByKey, reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) -# combined <- collect(parts) -# combined[[1]] # Should be a list(1, 6) -#} -# @rdname combineByKey -# @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' @seealso groupByKey, reduceByKey +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +# nolint end +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +#' @noRd setMethod("combineByKey", signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", mergeCombiners = "ANY", numPartitions = "numeric"), @@ -450,36 +466,39 @@ setMethod("combineByKey", lapplyPartition(shuffled, mergeAfterShuffle) }) -# Aggregate a pair RDD by each key. -# -# Aggregate the values of each key in an RDD, using given combine functions -# and a neutral "zero value". This function can return a different result type, -# U, than the type of the values in this RDD, V. Thus, we need one operation -# for merging a V into a U and one operation for merging two U's, The former -# operation is used for merging values within a partition, and the latter is -# used for merging values between partitions. To avoid memory allocation, both -# of these functions are allowed to modify and return their first argument -# instead of creating a new U. -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param seqOp A function to aggregate the values of each key. It may return -# a different result type from the type of the values. -# @param combOp A function to aggregate results of seqOp. -# @return An RDD containing the aggregation result. -# @seealso foldByKey, combineByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) -# zeroValue <- list(0, 0) -# seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } -# combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -# aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) -# # list(list(1, list(3, 2)), list(2, list(7, 2))) -#} -# @rdname aggregateByKey -# @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. +#' @seealso foldByKey, combineByKey +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +# nolint end +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +#' @noRd setMethod("aggregateByKey", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY", numPartitions = "numeric"), @@ -491,26 +510,29 @@ setMethod("aggregateByKey", combineByKey(x, createCombiner, seqOp, combOp, numPartitions) }) -# Fold a pair RDD by each key. -# -# Aggregate the values of each key in an RDD, using an associative function "func" -# and a neutral "zero value" which may be added to the result an arbitrary -# number of times, and must not change the result (e.g., 0 for addition, or -# 1 for multiplication.). -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param func An associative function for folding values of each key. -# @return An RDD containing the aggregation result. -# @seealso aggregateByKey, combineByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) -# foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) -#} -# @rdname foldByKey -# @aliases foldByKey,RDD,ANY,ANY,integer-method +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. +#' @seealso aggregateByKey, combineByKey +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +# nolint end +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +#' @noRd setMethod("foldByKey", signature(x = "RDD", zeroValue = "ANY", func = "ANY", numPartitions = "numeric"), @@ -520,28 +542,31 @@ setMethod("foldByKey", ############ Binary Functions ############# -# Join two RDDs -# -# @description -# \code{join} This function joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return a new RDD containing all pairs of elements with matching keys in -# two input RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) -#} -# @rdname join-methods -# @aliases join,RDD,RDD-method +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +# nolint end +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +#' @noRd setMethod("join", signature(x = "RDD", y = "RDD"), function(x, y, numPartitions) { @@ -556,30 +581,33 @@ setMethod("join", doJoin) }) -# Left outer join two RDDs -# -# @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, v) in x, the resulting RDD will either contain -# all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) -# if no elements in rdd2 have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# leftOuterJoin(rdd1, rdd2, 2L) -# # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) -#} -# @rdname join-methods -# @aliases leftOuterJoin,RDD,RDD-method +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +# nolint end +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +#' @noRd setMethod("leftOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -593,30 +621,33 @@ setMethod("leftOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# Right outer join two RDDs -# -# @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, w) in y, the resulting RDD will either contain -# all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) -# if no elements in x have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rightOuterJoin(rdd1, rdd2, 2L) -# # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) -#} -# @rdname join-methods -# @aliases rightOuterJoin,RDD,RDD-method +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +# nolint end +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +#' @noRd setMethod("rightOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -630,33 +661,36 @@ setMethod("rightOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# Full outer join two RDDs -# -# @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, v) in x and (k, w) in y, the resulting RDD -# will contain all pairs (k, (v, w)) for both (k, v) in x and -# (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements -# in x/y have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) -# rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), -# # list(1, list(3, 1)), -# # list(2, list(NULL, 4))) -# # list(3, list(3, NULL)), -#} -# @rdname join-methods -# @aliases fullOuterJoin,RDD,RDD-method +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +# nolint end +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +#' @noRd setMethod("fullOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -670,23 +704,26 @@ setMethod("fullOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# For each key k in several RDDs, return a resulting RDD that -# whose values are a list of values for the key in all RDDs. -# -# @param ... Several RDDs. -# @param numPartitions Number of partitions to create. -# @return a new RDD containing all pairs of elements with values in a list -# in all RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# cogroup(rdd1, rdd2, numPartitions = 2L) -# # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) -#} -# @rdname cogroup -# @aliases cogroup,RDD-method +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +# nolint end +#' @rdname cogroup +#' @aliases cogroup,RDD-method +#' @noRd setMethod("cogroup", "RDD", function(..., numPartitions) { @@ -722,23 +759,26 @@ setMethod("cogroup", group.func) }) -# Sort a (k, v) pair RDD by k. -# -# @param x A (k, v) pair RDD to be sorted. -# @param ascending A flag to indicate whether the sorting is ascending or descending. -# @param numPartitions Number of partitions to create. -# @return An RDD where all (k, v) pair elements are sorted. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) -# collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) -#} -# @rdname sortByKey -# @aliases sortByKey,RDD,RDD-method +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +# nolint end +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +#' @noRd setMethod("sortByKey", signature(x = "RDD"), - function(x, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { rangeBounds <- list() if (numPartitions > 1) { @@ -784,28 +824,31 @@ setMethod("sortByKey", lapplyPartition(newRDD, partitionFunc) }) -# Subtract a pair RDD with another pair RDD. -# -# Return an RDD with the pairs from x whose keys are not in other. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions Number of the partitions in the result RDD. -# @return An RDD with the pairs from x whose keys are not in other. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), -# list("b", 5), list("a", 2))) -# rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) -# collect(subtractByKey(rdd1, rdd2)) -# # list(list("b", 4), list("b", 5)) -#} -# @rdname subtractByKey -# @aliases subtractByKey,RDD +#' Subtract a pair RDD with another pair RDD. +#' +#' Return an RDD with the pairs from x whose keys are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the pairs from x whose keys are not in other. +#' @examples +# nolint start +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), +#' list("b", 5), list("a", 2))) +#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) +#' collect(subtractByKey(rdd1, rdd2)) +#' # list(list("b", 4), list("b", 5)) +#'} +# nolint end +#' @rdname subtractByKey +#' @aliases subtractByKey,RDD +#' @noRd setMethod("subtractByKey", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { filterFunction <- function(elem) { iters <- elem[[2]] (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) @@ -818,41 +861,42 @@ setMethod("subtractByKey", function (v) { v[[1]] }) }) -# Return a subset of this RDD sampled by key. -# -# @description -# \code{sampleByKey} Create a sample of this RDD using variable sampling rates -# for different keys as specified by fractions, a key to sampling rate map. -# -# @param x The RDD to sample elements by key, where each element is -# list(K, V) or c(K, V). -# @param withReplacement Sampling with replacement or not -# @param fraction The (rough) sample target fraction -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3000) -# pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) -# else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) -# fractions <- list(a = 0.2, b = 0.1, c = 0.3) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) -# 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE -# 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE -# 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE -# lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE -# lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE -# lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE -# lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE -# lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE -# lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE -# fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored -# fractions <- list(a = 0.2, b = 0.1) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" -#} -# @rdname sampleByKey -# @aliases sampleByKey,RDD-method +#' Return a subset of this RDD sampled by key. +#' +#' @description +#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates +#' for different keys as specified by fractions, a key to sampling rate map. +#' +#' @param x The RDD to sample elements by key, where each element is +#' list(K, V) or c(K, V). +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3000) +#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) +#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) +#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE +#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE +#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE +#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE +#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE +#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE +#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE +#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE +#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored +#' fractions <- list(a = 0.2, b = 0.1) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" +#'} +#' @rdname sampleByKey +#' @aliases sampleByKey,RDD-method +#' @noRd setMethod("sampleByKey", signature(x = "RDD", withReplacement = "logical", fractions = "vector", seed = "integer"), diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 6f0e9a94e9bfa..c6ddb562270b7 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -115,20 +115,7 @@ structField.jobj <- function(x) { } checkType <- function(type) { - primtiveTypes <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - if (type %in% primtiveTypes) { + if (!is.null(PRIMITIVE_TYPES[[type]])) { return() } else { # Check complex types diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 17082b4e52fcf..3bbf60d9b668c 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -17,6 +17,7 @@ # Utility functions to serialize R objects so they can be read in Java. +# nolint start # Type mapping from R to Java # # NULL -> Void @@ -31,6 +32,7 @@ # list[T] -> Array[T], where T is one of above mentioned types # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend +# nolint end getSerdeType <- function(object) { type <- class(object)[[1]] @@ -52,7 +54,7 @@ writeObject <- function(con, object, writeType = TRUE) { # passing in vectors as arrays and instead require arrays to be passed # as lists. type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") - # Checking types is needed here, since ‘is.na’ only handles atomic vectors, + # Checking types is needed here, since 'is.na' only handles atomic vectors, # lists and pairlists if (type %in% c("integer", "character", "logical", "double", "numeric")) { if (is.na(object)) { @@ -98,7 +100,7 @@ writeJobj <- function(con, value) { writeString <- function(con, value) { utfVal <- enc2utf8(value) writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) - writeBin(utfVal, con, endian = "big", useBytes=TRUE) + writeBin(utfVal, con, endian = "big", useBytes = TRUE) } writeInt <- function(con, value) { diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 004d08e74e1cd..c187869fdf121 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -34,7 +34,6 @@ connExists <- function(env) { sparkR.stop <- function() { env <- .sparkREnv if (exists(".sparkRCon", envir = env)) { - # cat("Stopping SparkR\n") if (exists(".sparkRjsc", envir = env)) { sc <- get(".sparkRjsc", envir = env) callJMethod(sc, "stop") @@ -49,6 +48,12 @@ sparkR.stop <- function() { } } + # Remove the R package lib path from .libPaths() + if (exists(".libPath", envir = env)) { + libPath <- get(".libPath", envir = env) + .libPaths(.libPaths()[.libPaths() != libPath]) + } + if (exists(".backendLaunched", envir = env)) { callJStatic("SparkRHandler", "stopBackend") } @@ -78,16 +83,16 @@ sparkR.stop <- function() { #' Initialize a new Spark Context. #' #' This function initializes a new SparkContext. For details on how to initialize -#' and use SparkR, refer to SparkR programming guide at +#' and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparkcontext-sqlcontext}. #' -#' @param master The Spark master URL. +#' @param master The Spark master URL #' @param appName Application name to register with cluster manager #' @param sparkHome Spark Home directory -#' @param sparkEnvir Named list of environment variables to set on worker nodes. -#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. -#' @param sparkJars Character string vector of jar files to pass to the worker nodes. -#' @param sparkPackages Character string vector of packages from spark-packages.org +#' @param sparkEnvir Named list of environment variables to set on worker nodes +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors +#' @param sparkJars Character vector of jar files to pass to the worker nodes +#' @param sparkPackages Character vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -97,7 +102,9 @@ sparkR.stop <- function() { #' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", #' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), -#' c("jarfile1.jar","jarfile2.jar")) +#' c("one.jar", "two.jar", "three.jar"), +#' c("com.databricks:spark-avro_2.10:2.0.1", +#' "com.databricks:spark-csv_2.10:1.3.0")) #'} sparkR.init <- function( @@ -115,15 +122,8 @@ sparkR.init <- function( return(get(".sparkRjsc", envir = .sparkREnv)) } - jars <- suppressWarnings(normalizePath(as.character(sparkJars))) - - # Classpath separator is ";" on Windows - # URI needs four /// as from http://stackoverflow.com/a/18522792 - if (.Platform$OS.type == "unix") { - uriSep <- "//" - } else { - uriSep <- "////" - } + jars <- processSparkJars(sparkJars) + packages <- processSparkPackages(sparkPackages) sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) @@ -140,7 +140,7 @@ sparkR.init <- function( sparkHome = sparkHome, jars = jars, sparkSubmitOpts = submitOps, - packages = sparkPackages) + packages = packages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -153,17 +153,23 @@ sparkR.init <- function( if (!file.exists(path)) { stop("JVM is not ready after 10 seconds") } - f <- file(path, open="rb") + f <- file(path, open = "rb") backendPort <- readInt(f) monitorPort <- readInt(f) + rLibPath <- readString(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || - length(monitorPort) == 0 || monitorPort == 0) { + length(monitorPort) == 0 || monitorPort == 0 || + length(rLibPath) != 1) { stop("JVM failed to launch") } assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) + if (rLibPath != "") { + assign(".libPath", rLibPath, envir = .sparkREnv) + .libPaths(c(rLibPath, .libPaths())) + } } .sparkREnv$backendPort <- backendPort @@ -179,13 +185,19 @@ sparkR.init <- function( } sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv) - if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { + if (is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- - paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + paste0("$LD_LIBRARY_PATH:", Sys.getenv("LD_LIBRARY_PATH")) } - nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- lapply(nonEmptyJars, + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + uriSep <- "//" + } else { + uriSep <- "////" + } + localJarPaths <- lapply(jars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs @@ -287,7 +299,7 @@ sparkRHive.init <- function(jsc = NULL) { #' #' @param sc existing spark context #' @param groupid the ID to be assigned to job groups -#' @param description description for the the job group ID +#' @param description description for the job group ID #' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation #' @examples #'\dontrun{ @@ -355,3 +367,22 @@ getClientModeSparkSubmitOpts <- function(submitOps, sparkEnvirMap) { # --option must be before the application class "sparkr-shell" in submitOps paste0(paste0(envirToOps, collapse = ""), submitOps) } + +# Utility function that handles sparkJars argument, and normalize paths +processSparkJars <- function(jars) { + splittedJars <- splitString(jars) + if (length(splittedJars) > length(jars)) { + warning("sparkJars as a comma-separated string is deprecated, use character vector instead") + } + normalized <- suppressWarnings(normalizePath(splittedJars)) + normalized +} + +# Utility function that handles sparkPackages argument +processSparkPackages <- function(packages) { + splittedPackages <- splitString(packages) + if (length(splittedPackages) > length(packages)) { + warning("sparkPackages as a comma-separated string is deprecated, use character vector instead") + } + splittedPackages +} diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index f79329b115404..edf72937c633a 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -66,8 +66,9 @@ setMethod("crosstab", #' cov <- cov(df, "title", "gender") #' } setMethod("cov", - signature(x = "DataFrame", col1 = "character", col2 = "character"), + signature(x = "DataFrame"), function(x, col1, col2) { + stopifnot(class(col1) == "character" && class(col2) == "character") statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "cov", col1, col2) }) @@ -77,7 +78,7 @@ setMethod("cov", #' Calculates the correlation of two columns of a DataFrame. #' Currently only supports the Pearson Correlation Coefficient. #' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. -#' +#' #' @param x A SparkSQL DataFrame #' @param col1 the name of the first column #' @param col2 the name of the second column @@ -95,8 +96,9 @@ setMethod("cov", #' corr <- corr(df, "title", "gender", method = "pearson") #' } setMethod("corr", - signature(x = "DataFrame", col1 = "character", col2 = "character"), + signature(x = "DataFrame"), function(x, col1, col2, method = "pearson") { + stopifnot(class(col1) == "character" && class(col2) == "character") statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "corr", col1, col2, method) }) @@ -109,7 +111,7 @@ setMethod("corr", #' #' @param x A SparkSQL DataFrame. #' @param cols A vector column names to search frequent items in. -#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. #' Should be greater than 1e-4. Default support = 0.01. #' @return a local R data.frame with the frequent items in each column #' @@ -128,10 +130,49 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"), collect(dataFrame(sct)) }) +#' approxQuantile +#' +#' Calculates the approximate quantiles of a numerical column of a DataFrame. +#' +#' The result of this algorithm has the following deterministic bound: +#' If the DataFrame has N elements and if we request the quantile at probability `p` up to error +#' `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank +#' of `x` is close to (p * N). More precisely, +#' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). +#' This method implements a variation of the Greenwald-Khanna algorithm (with some speed +#' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 +#' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. +#' +#' @param x A SparkSQL DataFrame. +#' @param col The name of the numerical column. +#' @param probabilities A list of quantile probabilities. Each number must belong to [0, 1]. +#' For example 0 is the minimum, 0.5 is the median, 1 is the maximum. +#' @param relativeError The relative target precision to achieve (>= 0). If set to zero, +#' the exact quantiles are computed, which could be very expensive. +#' Note that values greater than 1 are accepted but give the same result as 1. +#' @return The approximate quantiles at the given probabilities. +#' +#' @rdname statfunctions +#' @name approxQuantile +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) +#' } +setMethod("approxQuantile", + signature(x = "DataFrame", col = "character", + probabilities = "numeric", relativeError = "numeric"), + function(x, col, probabilities, relativeError) { + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "approxQuantile", col, + as.list(probabilities), relativeError) + }) + #' sampleBy #' #' Returns a stratified sample without replacement based on the fraction given on each stratum. -#' +#' #' @param x A SparkSQL DataFrame #' @param col column that defines strata #' @param fractions A named list giving sampling fraction for each stratum. If a stratum is diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R new file mode 100644 index 0000000000000..ad048b1cd1795 --- /dev/null +++ b/R/pkg/R/types.R @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# types.R. This file handles the data type mapping between Spark and R + +# The primitive data types, where names(PRIMITIVE_TYPES) are Scala types whereas +# values are equivalent R types. This is stored in an environment to allow for +# more efficient look up (environments use hashmaps). +PRIMITIVE_TYPES <- as.environment(list( + "tinyint" = "integer", + "smallint" = "integer", + "int" = "integer", + "bigint" = "numeric", + "float" = "numeric", + "double" = "numeric", + "decimal" = "numeric", + "string" = "character", + "binary" = "raw", + "boolean" = "logical", + "timestamp" = "POSIXct", + "date" = "Date", + # following types are not SQL types returned by dtypes(). They are listed here for usage + # by checkType() in schema.R. + # TODO: refactor checkType() in schema.R. + "byte" = "integer", + "integer" = "integer" + )) + +# The complex data types. These do not have any direct mapping to R's types. +COMPLEX_TYPES <- list( + "map" = NA, + "array" = NA, + "struct" = NA) + +# The full list of data types. +DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) + +SHORT_TYPES <- as.environment(list( + "character" = "chr", + "logical" = "logi", + "POSIXct" = "POSIXct", + "integer" = "int", + "numeric" = "num", + "raw" = "raw", + "Date" = "Date", + "map" = "map", + "array" = "array", + "struct" = "struct" +)) + +# An environment for mapping R to Scala, names are R types and values are Scala types. +rToSQLTypes <- as.environment(list( + "integer" = "integer", # in R, integer is 32bit + "numeric" = "double", # in R, numeric == double which is 64bit + "double" = "double", + "character" = "string", + "logical" = "boolean")) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 0b9e2957fe9a5..fb6575cb42907 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -158,7 +158,7 @@ wrapInt <- function(value) { # Multiply `val` by 31 and add `addVal` to the result. Ensures that # integer-overflows are handled at every step. mult31AndAdd <- function(val, addVal) { - vec <- c(bitwShiftL(val, c(4,3,2,1,0)), addVal) + vec <- c(bitwShiftL(val, c(4, 3, 2, 1, 0)), addVal) Reduce(function(a, b) { wrapInt(as.numeric(a) + as.numeric(b)) }, @@ -202,7 +202,7 @@ serializeToString <- function(rdd) { # This function amortizes the allocation cost by doubling # the size of the list every time it fills up. addItemToAccumulator <- function(acc, item) { - if(acc$counter == acc$size) { + if (acc$counter == acc$size) { acc$size <- acc$size * 2 length(acc$data) <- acc$size } @@ -623,3 +623,30 @@ convertNamedListToEnv <- function(namedList) { } env } + +# Assign a new environment for attach() and with() methods +assignNewEnv <- function(data) { + stopifnot(class(data) == "DataFrame") + cols <- columns(data) + stopifnot(length(cols) > 0) + + env <- new.env() + for (i in 1:length(cols)) { + assign(x = cols[i], value = data[, cols[i]], envir = env) + } + env +} + +# Utility function to split by ',' and whitespace, remove empty tokens +splitString <- function(input) { + Filter(nzchar, unlist(strsplit(input, ",|\\s"))) +} + +convertToJSaveMode <- function(mode) { + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') # nolint + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode +} diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 2a8a8213d0849..8c75c19ca7ac3 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -17,6 +17,7 @@ .First <- function() { packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") - .libPaths(c(packageDir, .libPaths())) - Sys.setenv(NOAWT=1) + dirs <- strsplit(packageDir, ",")[[1]] + .libPaths(c(dirs, .libPaths())) + Sys.setenv(NOAWT = 1) } diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 7189f1a260934..90a3761e41f82 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -38,7 +38,7 @@ if (nchar(sparkVer) == 0) { cat("\n") } else { - cat(" version ", sparkVer, "\n") + cat(" version ", sparkVer, "\n") } cat(" /_/", "\n") cat("\n") diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R deleted file mode 100644 index 80c1b89a4c627..0000000000000 --- a/R/pkg/inst/tests/test_context.R +++ /dev/null @@ -1,94 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -context("test functions in sparkR.R") - -test_that("repeatedly starting and stopping SparkR", { - for (i in 1:4) { - sc <- sparkR.init() - rdd <- parallelize(sc, 1:20, 2L) - expect_equal(count(rdd), 20) - sparkR.stop() - } -}) - -test_that("repeatedly starting and stopping SparkR SQL", { - for (i in 1:4) { - sc <- sparkR.init() - sqlContext <- sparkRSQL.init(sc) - df <- createDataFrame(sqlContext, data.frame(a = 1:20)) - expect_equal(count(df), 20) - sparkR.stop() - } -}) - -test_that("rdd GC across sparkR.stop", { - sparkR.stop() - sc <- sparkR.init() # sc should get id 0 - rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 - rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 - sparkR.stop() - - sc <- sparkR.init() # sc should get id 0 again - - # GC rdd1 before creating rdd3 and rdd2 after - rm(rdd1) - gc() - - rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now - rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now - - rm(rdd2) - gc() - - count(rdd3) - count(rdd4) -}) - -test_that("job group functions can be called", { - sc <- sparkR.init() - setJobGroup(sc, "groupId", "job description", TRUE) - cancelJobGroup(sc, "groupId") - clearJobGroup(sc) -}) - -test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { - e <- new.env() - e[["spark.driver.memory"]] <- "512m" - ops <- getClientModeSparkSubmitOpts("sparkrmain", e) - expect_equal("--driver-memory \"512m\" sparkrmain", ops) - - e[["spark.driver.memory"]] <- "5g" - e[["spark.driver.extraClassPath"]] <- "/opt/class_path" # nolint - e[["spark.driver.extraJavaOptions"]] <- "-XX:+UseCompressedOops -XX:+UseCompressedStrings" - e[["spark.driver.extraLibraryPath"]] <- "/usr/local/hadoop/lib" # nolint - e[["random"]] <- "skipthis" - ops2 <- getClientModeSparkSubmitOpts("sparkr-shell", e) - # nolint start - expect_equal(ops2, paste0("--driver-class-path \"/opt/class_path\" --driver-java-options \"", - "-XX:+UseCompressedOops -XX:+UseCompressedStrings\" --driver-library-path \"", - "/usr/local/hadoop/lib\" --driver-memory \"5g\" sparkr-shell")) - # nolint end - - e[["spark.driver.extraClassPath"]] <- "/" # too short - ops3 <- getClientModeSparkSubmitOpts("--driver-memory 4g sparkr-shell2", e) - # nolint start - expect_equal(ops3, paste0("--driver-java-options \"-XX:+UseCompressedOops ", - "-XX:+UseCompressedStrings\" --driver-library-path \"/usr/local/hadoop/lib\"", - " --driver-memory 4g sparkr-shell2")) - # nolint end -}) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R deleted file mode 100644 index 032cfef061fd3..0000000000000 --- a/R/pkg/inst/tests/test_mllib.R +++ /dev/null @@ -1,86 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -library(testthat) - -context("MLlib functions") - -# Tests for MLlib functions in SparkR - -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) - -test_that("glm and predict", { - training <- createDataFrame(sqlContext, iris) - test <- select(training, "Sepal_Length") - model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") - prediction <- predict(model, test) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") -}) - -test_that("predictions match with native glm", { - training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("dot minus and intercept vs native glm", { - training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ . - Species + 0, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("feature interaction vs native glm", { - training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("summary coefficients match with native glm", { - training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) - coefs <- as.vector(stats$coefficients) - rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - expect_true(all(abs(rCoefs - coefs) < 1e-6)) - expect_true(all( - as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) -}) - -test_that("summary coefficients match with native glm of family 'binomial'", { - df <- createDataFrame(sqlContext, iris) - training <- filter(df, df$Species != "setosa") - stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, - family = "binomial")) - coefs <- as.vector(stats$coefficients) - - rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] - rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, - family = binomial(link = "logit")))) - - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Sepal_Width"))) -}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R deleted file mode 100644 index b4a4d03b2643b..0000000000000 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ /dev/null @@ -1,1499 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -library(testthat) - -context("SparkSQL functions") - -# Utility function for easily checking the values of a StructField -checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { - expect_equal(class(actual), "structField") - expect_equal(actual$name(), expectedName) - expect_equal(actual$dataType.toString(), expectedType) - expect_equal(actual$nullable(), expectedNullable) -} - -# Tests for SparkSQL functions in SparkR - -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) - -mockLines <- c("{\"name\":\"Michael\"}", - "{\"name\":\"Andy\", \"age\":30}", - "{\"name\":\"Justin\", \"age\":19}") -jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") -parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") -writeLines(mockLines, jsonPath) - -# For test nafunctions, like dropna(), fillna(),... -mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", - "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", - "{\"name\":\"David\",\"age\":60,\"height\":null}", - "{\"name\":\"Amy\",\"age\":null,\"height\":null}", - "{\"name\":null,\"age\":null,\"height\":null}") -jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") -writeLines(mockLinesNa, jsonPathNa) - -# For test complex types in DataFrame -mockLinesComplexType <- - c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", - "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", - "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}") -complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") -writeLines(mockLinesComplexType, complexTypeJsonPath) - -test_that("infer types and check types", { - expect_equal(infer_type(1L), "integer") - expect_equal(infer_type(1.0), "double") - expect_equal(infer_type("abc"), "string") - expect_equal(infer_type(TRUE), "boolean") - expect_equal(infer_type(as.Date("2015-03-11")), "date") - expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") - expect_equal(infer_type(c(1L, 2L)), "array") - expect_equal(infer_type(list(1L, 2L)), "array") - expect_equal(infer_type(listToStruct(list(a = 1L, b = "2"))), "struct") - e <- new.env() - assign("a", 1L, envir = e) - expect_equal(infer_type(e), "map") - - expect_error(checkType("map"), "Key type in a map must be string or character") -}) - -test_that("structType and structField", { - testField <- structField("a", "string") - expect_is(testField, "structField") - expect_equal(testField$name(), "a") - expect_true(testField$nullable()) - - testSchema <- structType(testField, structField("b", "integer")) - expect_is(testSchema, "structType") - expect_is(testSchema$fields()[[2]], "structField") - expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") -}) - -test_that("create DataFrame from RDD", { - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) - dfAsDF <- as.DataFrame(sqlContext, rdd, list("a", "b")) - expect_is(df, "DataFrame") - expect_is(dfAsDF, "DataFrame") - expect_equal(count(df), 10) - expect_equal(count(dfAsDF), 10) - expect_equal(nrow(df), 10) - expect_equal(nrow(dfAsDF), 10) - expect_equal(ncol(df), 2) - expect_equal(ncol(dfAsDF), 2) - expect_equal(dim(df), c(10, 2)) - expect_equal(dim(dfAsDF), c(10, 2)) - expect_equal(columns(df), c("a", "b")) - expect_equal(columns(dfAsDF), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - expect_equal(dtypes(dfAsDF), list(c("a", "int"), c("b", "string"))) - - df <- createDataFrame(sqlContext, rdd) - dfAsDF <- as.DataFrame(sqlContext, rdd) - expect_is(df, "DataFrame") - expect_is(dfAsDF, "DataFrame") - expect_equal(columns(df), c("_1", "_2")) - expect_equal(columns(dfAsDF), c("_1", "_2")) - - schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE)) - df <- createDataFrame(sqlContext, rdd, schema) - expect_is(df, "DataFrame") - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- createDataFrame(sqlContext, rdd) - expect_is(df, "DataFrame") - expect_equal(count(df), 10) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - df <- jsonFile(sqlContext, jsonPathNa) - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") - insertInto(df, "people") - expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) - expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) - - schema <- structType(structField("name", "string"), structField("age", "integer"), - structField("height", "float")) - df2 <- createDataFrame(sqlContext, df.toRDD, schema) - df2AsDF <- as.DataFrame(sqlContext, df.toRDD, schema) - expect_equal(columns(df2), c("name", "age", "height")) - expect_equal(columns(df2AsDF), c("name", "age", "height")) - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(dtypes(df2AsDF), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) - expect_equal(collect(where(df2AsDF, df2$name == "Bob")), c("Bob", 16, 176.5)) - - localDF <- data.frame(name=c("John", "Smith", "Sarah"), - age=c(19, 23, 18), - height=c(164.10, 181.4, 173.7)) - df <- createDataFrame(sqlContext, localDF, schema) - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - expect_equal(columns(df), c("name", "age", "height")) - expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) -}) - -test_that("convert NAs to null type in DataFrames", { - rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(is.na(collect(df)[2, "a"])) - expect_equal(collect(df)[2, "b"], 4L) - - l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) - df <- createDataFrame(sqlContext, l) - expect_equal(collect(df)[2, "x"], 1L) - expect_true(is.na(collect(df)[2, "y"])) - - rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(is.na(collect(df)[2, "a"])) - expect_equal(collect(df)[2, "b"], 4) - - l <- data.frame(x = 1, y = c(1, NA_real_, 3)) - df <- createDataFrame(sqlContext, l) - expect_equal(collect(df)[2, "x"], 1) - expect_true(is.na(collect(df)[2, "y"])) - - l <- list("a", "b", NA, "d") - df <- createDataFrame(sqlContext, l) - expect_true(is.na(collect(df)[3, "_1"])) - expect_equal(collect(df)[4, "_1"], "d") - - l <- list("a", "b", NA_character_, "d") - df <- createDataFrame(sqlContext, l) - expect_true(is.na(collect(df)[3, "_1"])) - expect_equal(collect(df)[4, "_1"], "d") - - l <- list(TRUE, FALSE, NA, TRUE) - df <- createDataFrame(sqlContext, l) - expect_true(is.na(collect(df)[3, "_1"])) - expect_equal(collect(df)[4, "_1"], TRUE) -}) - -test_that("toDF", { - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- toDF(rdd, list("a", "b")) - expect_is(df, "DataFrame") - expect_equal(count(df), 10) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - df <- toDF(rdd) - expect_is(df, "DataFrame") - expect_equal(columns(df), c("_1", "_2")) - - schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE)) - df <- toDF(rdd, schema) - expect_is(df, "DataFrame") - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - - rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- toDF(rdd) - expect_is(df, "DataFrame") - expect_equal(count(df), 10) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) -}) - -test_that("create DataFrame from list or data.frame", { - l <- list(list(1, 2), list(3, 4)) - df <- createDataFrame(sqlContext, l, c("a", "b")) - expect_equal(columns(df), c("a", "b")) - - l <- list(list(a=1, b=2), list(a=3, b=4)) - df <- createDataFrame(sqlContext, l) - expect_equal(columns(df), c("a", "b")) - - a <- 1:3 - b <- c("a", "b", "c") - ldf <- data.frame(a, b) - df <- createDataFrame(sqlContext, ldf) - expect_equal(columns(df), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - expect_equal(count(df), 3) - ldf2 <- collect(df) - expect_equal(ldf$a, ldf2$a) -}) - -test_that("create DataFrame with different data types", { - l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), - f = as.POSIXct("2015-03-15 12:13:14.056")) - df <- createDataFrame(sqlContext, list(l)) - expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), - c("d", "string"), c("e", "date"), c("f", "timestamp"))) - expect_equal(count(df), 1) - expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) -}) - -test_that("create DataFrame with complex types", { - e <- new.env() - assign("n", 3L, envir = e) - - s <- listToStruct(list(a = "aa", b = 3L)) - - l <- list(as.list(1:10), list("a", "b"), e, s) - df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) - expect_equal(dtypes(df), list(c("a", "array"), - c("b", "array"), - c("c", "map"), - c("d", "struct"))) - expect_equal(count(df), 1) - ldf <- collect(df) - expect_equal(names(ldf), c("a", "b", "c", "d")) - expect_equal(ldf[1, 1][[1]], l[[1]]) - expect_equal(ldf[1, 2][[1]], l[[2]]) - - e <- ldf$c[[1]] - expect_equal(class(e), "environment") - expect_equal(ls(e), "n") - expect_equal(e$n, 3L) - - s <- ldf$d[[1]] - expect_equal(class(s), "struct") - expect_equal(s$a, "aa") - expect_equal(s$b, 3L) -}) - -# For test map type and struct type in DataFrame -mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", - "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", - "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") -mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") -writeLines(mockLinesMapType, mapTypeJsonPath) - -test_that("Collect DataFrame with complex types", { - # ArrayType - df <- jsonFile(sqlContext, complexTypeJsonPath) - - ldf <- collect(df) - expect_equal(nrow(ldf), 3) - expect_equal(ncol(ldf), 3) - expect_equal(names(ldf), c("c1", "c2", "c3")) - expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) - expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) - expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) - - # MapType - schema <- structType(structField("name", "string"), - structField("info", "map")) - df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) - expect_equal(dtypes(df), list(c("name", "string"), - c("info", "map"))) - ldf <- collect(df) - expect_equal(nrow(ldf), 3) - expect_equal(ncol(ldf), 2) - expect_equal(names(ldf), c("name", "info")) - expect_equal(ldf$name, c("Bob", "Alice", "David")) - bob <- ldf$info[[1]] - expect_equal(class(bob), "environment") - expect_equal(bob$age, 16) - expect_equal(bob$height, 176.5) - - # StructType - df <- jsonFile(sqlContext, mapTypeJsonPath) - expect_equal(dtypes(df), list(c("info", "struct"), - c("name", "string"))) - ldf <- collect(df) - expect_equal(nrow(ldf), 3) - expect_equal(ncol(ldf), 2) - expect_equal(names(ldf), c("info", "name")) - expect_equal(ldf$name, c("Bob", "Alice", "David")) - bob <- ldf$info[[1]] - expect_equal(class(bob), "struct") - expect_equal(bob$age, 16) - expect_equal(bob$height, 176.5) -}) - -test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - expect_is(df, "DataFrame") - expect_equal(count(df), 3) -}) - -test_that("jsonRDD() on a RDD with json string", { - rdd <- parallelize(sc, mockLines) - expect_equal(count(rdd), 3) - df <- jsonRDD(sqlContext, rdd) - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - - rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- jsonRDD(sqlContext, rdd2) - expect_is(df, "DataFrame") - expect_equal(count(df), 6) -}) - -test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlContext, jsonPath) - registerTempTable(df, "table1") - cacheTable(sqlContext, "table1") - uncacheTable(sqlContext, "table1") - clearCache(sqlContext) - dropTempTable(sqlContext, "table1") -}) - -test_that("test tableNames and tables", { - df <- jsonFile(sqlContext, jsonPath) - registerTempTable(df, "table1") - expect_equal(length(tableNames(sqlContext)), 1) - df <- tables(sqlContext) - expect_equal(count(df), 1) - dropTempTable(sqlContext, "table1") -}) - -test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - registerTempTable(df, "table1") - newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") - expect_is(newdf, "DataFrame") - expect_equal(count(newdf), 1) - dropTempTable(sqlContext, "table1") -}) - -test_that("insertInto() on a registered table", { - df <- read.df(sqlContext, jsonPath, "json") - write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(sqlContext, parquetPath, "parquet") - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - writeLines(lines, jsonPath2) - df2 <- read.df(sqlContext, jsonPath2, "json") - write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") - - registerTempTable(dfParquet, "table1") - insertInto(dfParquet2, "table1") - expect_equal(count(sql(sqlContext, "select * from table1")), 5) - expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") - dropTempTable(sqlContext, "table1") - - registerTempTable(dfParquet, "table1") - insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql(sqlContext, "select * from table1")), 2) - expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") - dropTempTable(sqlContext, "table1") -}) - -test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - registerTempTable(df, "table1") - tabledf <- table(sqlContext, "table1") - expect_is(tabledf, "DataFrame") - expect_equal(count(tabledf), 3) - dropTempTable(sqlContext, "table1") -}) - -test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) - testRDD <- toRDD(df) - expect_is(testRDD, "RDD") - expect_equal(count(testRDD), 3) -}) - -test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) - RDD1 <- toRDD(df) - RDD2 <- toRDD(df) - unioned <- unionRDD(RDD1, RDD2) - expect_is(unioned, "RDD") - expect_equal(SparkR:::getSerializedMode(unioned), "byte") - expect_equal(collect(unioned)[[2]]$name, "Andy") -}) - -test_that("union on mixed serialization types correctly returns a byte RRDD", { - # Byte RDD - nums <- 1:10 - rdd <- parallelize(sc, nums, 2L) - - # String RDD - textLines <- c("Michael", - "Andy, 30", - "Justin, 19") - textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp") - writeLines(textLines, textPath) - textRDD <- textFile(sc, textPath) - - df <- jsonFile(sqlContext, jsonPath) - dfRDD <- toRDD(df) - - unionByte <- unionRDD(rdd, dfRDD) - expect_is(unionByte, "RDD") - expect_equal(SparkR:::getSerializedMode(unionByte), "byte") - expect_equal(collect(unionByte)[[1]], 1) - expect_equal(collect(unionByte)[[12]]$name, "Andy") - - unionString <- unionRDD(textRDD, dfRDD) - expect_is(unionString, "RDD") - expect_equal(SparkR:::getSerializedMode(unionString), "byte") - expect_equal(collect(unionString)[[1]], "Michael") - expect_equal(collect(unionString)[[5]]$name, "Andy") -}) - -test_that("objectFile() works with row serialization", { - objectPath <- tempfile(pattern="spark-test", fileext=".tmp") - df <- jsonFile(sqlContext, jsonPath) - dfRDD <- toRDD(df) - saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) - objectIn <- objectFile(sc, objectPath) - - expect_is(objectIn, "RDD") - expect_equal(SparkR:::getSerializedMode(objectIn), "byte") - expect_equal(collect(objectIn)[[2]]$age, 30) -}) - -test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlContext, jsonPath) - testRDD <- lapply(df, function(row) { - row$newCol <- row$age + 5 - row - }) - expect_is(testRDD, "RDD") - collected <- collect(testRDD) - expect_equal(collected[[1]]$name, "Michael") - expect_equal(collected[[2]]$newCol, 35) -}) - -test_that("collect() returns a data.frame", { - df <- jsonFile(sqlContext, jsonPath) - rdf <- collect(df) - expect_true(is.data.frame(rdf)) - expect_equal(names(rdf)[1], "age") - expect_equal(nrow(rdf), 3) - expect_equal(ncol(rdf), 2) - - # collect() returns data correctly from a DataFrame with 0 row - df0 <- limit(df, 0) - rdf <- collect(df0) - expect_true(is.data.frame(rdf)) - expect_equal(names(rdf)[1], "age") - expect_equal(nrow(rdf), 0) - expect_equal(ncol(rdf), 2) -}) - -test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlContext, jsonPath) - dfLimited <- limit(df, 2) - expect_is(dfLimited, "DataFrame") - expect_equal(count(dfLimited), 2) -}) - -test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlContext, jsonPath) - expect_equal(nrow(collect(df)), nrow(take(df, 10))) - expect_equal(ncol(collect(df)), ncol(take(df, 10))) -}) - -test_that("collect() support Unicode characters", { - markUtf8 <- function(s) { - Encoding(s) <- "UTF-8" - s - } - - lines <- c("{\"name\":\"안녕하세요\"}", - "{\"name\":\"您好\", \"age\":30}", - "{\"name\":\"こんにちは\", \"age\":19}", - "{\"name\":\"Xin chào\"}") - - jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(lines, jsonPath) - - df <- read.df(sqlContext, jsonPath, "json") - rdf <- collect(df) - expect_true(is.data.frame(rdf)) - expect_equal(rdf$name[1], markUtf8("안녕하세요")) - expect_equal(rdf$name[2], markUtf8("您好")) - expect_equal(rdf$name[3], markUtf8("こんにちは")) - expect_equal(rdf$name[4], markUtf8("Xin chào")) - - df1 <- createDataFrame(sqlContext, rdf) - expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好")) -}) - -test_that("multiple pipeline transformations result in an RDD with the correct values", { - df <- jsonFile(sqlContext, jsonPath) - first <- lapply(df, function(row) { - row$age <- row$age + 5 - row - }) - second <- lapply(first, function(row) { - row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE - row - }) - expect_is(second, "RDD") - expect_equal(count(second), 3) - expect_equal(collect(second)[[2]]$age, 35) - expect_true(collect(second)[[2]]$testCol) - expect_false(collect(second)[[3]]$testCol) -}) - -test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - expect_false(df@env$isCached) - cache(df) - expect_true(df@env$isCached) - - unpersist(df) - expect_false(df@env$isCached) - - persist(df, "MEMORY_AND_DISK") - expect_true(df@env$isCached) - - unpersist(df) - expect_false(df@env$isCached) - - # make sure the data is collectable - expect_true(is.data.frame(collect(df))) -}) - -test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlContext, jsonPath) - testSchema <- schema(df) - expect_equal(length(testSchema$fields()), 2) - expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") - expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") - expect_equal(testSchema$fields()[[1]]$name(), "age") - - testTypes <- dtypes(df) - expect_equal(length(testTypes[[1]]), 2) - expect_equal(testTypes[[1]][1], "age") - - testCols <- columns(df) - expect_equal(length(testCols), 2) - expect_equal(testCols[2], "name") - - testNames <- names(df) - expect_equal(length(testNames), 2) - expect_equal(testNames[2], "name") -}) - -test_that("head() and first() return the correct data", { - df <- jsonFile(sqlContext, jsonPath) - testHead <- head(df) - expect_equal(nrow(testHead), 3) - expect_equal(ncol(testHead), 2) - - testHead2 <- head(df, 2) - expect_equal(nrow(testHead2), 2) - expect_equal(ncol(testHead2), 2) - - testFirst <- first(df) - expect_equal(nrow(testFirst), 1) - - # head() and first() return the correct data on - # a DataFrame with 0 row - df0 <- limit(df, 0) - - testHead <- head(df0) - expect_equal(nrow(testHead), 0) - expect_equal(ncol(testHead), 2) - - testFirst <- first(df0) - expect_equal(nrow(testFirst), 0) - expect_equal(ncol(testFirst), 2) -}) - -test_that("distinct() and unique on DataFrames", { - lines <- c("{\"name\":\"Michael\"}", - "{\"name\":\"Andy\", \"age\":30}", - "{\"name\":\"Justin\", \"age\":19}", - "{\"name\":\"Justin\", \"age\":19}") - jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(lines, jsonPathWithDup) - - df <- jsonFile(sqlContext, jsonPathWithDup) - uniques <- distinct(df) - expect_is(uniques, "DataFrame") - expect_equal(count(uniques), 3) - - uniques2 <- unique(df) - expect_is(uniques2, "DataFrame") - expect_equal(count(uniques2), 3) -}) - -test_that("sample on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - sampled <- sample(df, FALSE, 1.0) - expect_equal(nrow(collect(sampled)), count(df)) - expect_is(sampled, "DataFrame") - sampled2 <- sample(df, FALSE, 0.1) - expect_true(count(sampled2) < 3) - - # Also test sample_frac - sampled3 <- sample_frac(df, FALSE, 0.1) - expect_true(count(sampled3) < 3) -}) - -test_that("select operators", { - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - expect_is(df$name, "Column") - expect_is(df[[2]], "Column") - expect_is(df[["age"]], "Column") - - expect_is(df[,1], "DataFrame") - expect_equal(columns(df[,1]), c("name")) - expect_equal(columns(df[,"age"]), c("age")) - df2 <- df[,c("age", "name")] - expect_is(df2, "DataFrame") - expect_equal(columns(df2), c("age", "name")) - - df$age2 <- df$age - expect_equal(columns(df), c("name", "age", "age2")) - expect_equal(count(where(df, df$age2 == df$age)), 2) - df$age2 <- df$age * 2 - expect_equal(columns(df), c("name", "age", "age2")) - expect_equal(count(where(df, df$age2 == df$age * 2)), 2) - - df$age2 <- NULL - expect_equal(columns(df), c("name", "age")) - df$age3 <- NULL - expect_equal(columns(df), c("name", "age")) -}) - -test_that("select with column", { - df <- jsonFile(sqlContext, jsonPath) - df1 <- select(df, "name") - expect_equal(columns(df1), c("name")) - expect_equal(count(df1), 3) - - df2 <- select(df, df$age) - expect_equal(columns(df2), c("age")) - expect_equal(count(df2), 3) - - df3 <- select(df, lit("x")) - expect_equal(columns(df3), c("x")) - expect_equal(count(df3), 3) - expect_equal(collect(select(df3, "x"))[[1, 1]], "x") - - df4 <- select(df, c("name", "age")) - expect_equal(columns(df4), c("name", "age")) - expect_equal(count(df4), 3) - - expect_error(select(df, c("name", "age"), "name"), - "To select multiple columns, use a character vector or list for col") -}) - -test_that("subsetting", { - # jsonFile returns columns in random order - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - filtered <- df[df$age > 20,] - expect_equal(count(filtered), 1) - expect_equal(columns(filtered), c("name", "age")) - expect_equal(collect(filtered)$name, "Andy") - - df2 <- df[df$age == 19, 1] - expect_is(df2, "DataFrame") - expect_equal(count(df2), 1) - expect_equal(columns(df2), c("name")) - expect_equal(collect(df2)$name, "Justin") - - df3 <- df[df$age > 20, 2] - expect_equal(count(df3), 1) - expect_equal(columns(df3), c("age")) - - df4 <- df[df$age %in% c(19, 30), 1:2] - expect_equal(count(df4), 2) - expect_equal(columns(df4), c("name", "age")) - - df5 <- df[df$age %in% c(19), c(1,2)] - expect_equal(count(df5), 1) - expect_equal(columns(df5), c("name", "age")) - - df6 <- subset(df, df$age %in% c(30), c(1,2)) - expect_equal(count(df6), 1) - expect_equal(columns(df6), c("name", "age")) -}) - -test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - selected <- selectExpr(df, "age * 2") - expect_equal(names(selected), "(age * 2)") - expect_equal(collect(selected), collect(select(df, df$age * 2L))) - - selected2 <- selectExpr(df, "name as newName", "abs(age) as age") - expect_equal(names(selected2), c("newName", "age")) - expect_equal(count(selected2), 3) -}) - -test_that("expr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) -}) - -test_that("column calculation", { - df <- jsonFile(sqlContext, jsonPath) - d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_equal(names(d), c("age2")) - df2 <- select(df, lower(df$name), abs(df$age)) - expect_is(df2, "DataFrame") - expect_equal(count(df2), 3) -}) - -test_that("read.df() from json file", { - df <- read.df(sqlContext, jsonPath, "json") - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - - # Check if we can apply a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_is(df1, "DataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Run the same with loadDF - df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_is(df2, "DataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) -}) - -test_that("write.df() as parquet file", { - df <- read.df(sqlContext, jsonPath, "json") - write.df(df, parquetPath, "parquet", mode="overwrite") - df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_is(df2, "DataFrame") - expect_equal(count(df2), 3) -}) - -test_that("test HiveContext", { - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - df2 <- sql(hiveCtx, "select * from json") - expect_is(df2, "DataFrame") - expect_equal(count(df2), 3) - - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - saveAsTable(df, "json", "json", "append", path = jsonPath2) - df3 <- sql(hiveCtx, "select * from json") - expect_is(df3, "DataFrame") - expect_equal(count(df3), 6) -}) - -test_that("column operators", { - c <- column("a") - c2 <- (- c + 1 - 2) * 3 / 4.0 - c3 <- (c + c2 - c2) * c2 %% c2 - c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) - c5 <- c2 ^ c3 ^ c4 -}) - -test_that("column functions", { - c <- column("a") - c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) - c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) - c3 <- cosh(c) + count(c) + crc32(c) + exp(c) - c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) - c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) - c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) - c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) - c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) - c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) - c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) - c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) - c12 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c13 <- cumeDist() + ntile(1) - c14 <- denseRank() + percentRank() + rank() + rowNumber() - - # Test if base::rank() is exposed - expect_equal(class(rank())[[1]], "Column") - expect_equal(rank(1:3), as.numeric(c(1:3))) - - df <- jsonFile(sqlContext, jsonPath) - df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) - expect_equal(collect(df2)[[2, 1]], TRUE) - expect_equal(collect(df2)[[2, 2]], FALSE) - expect_equal(collect(df2)[[3, 1]], FALSE) - expect_equal(collect(df2)[[3, 2]], TRUE) - - df3 <- select(df, between(df$name, c("Apache", "Spark"))) - expect_equal(collect(df3)[[1, 1]], TRUE) - expect_equal(collect(df3)[[2, 1]], FALSE) - expect_equal(collect(df3)[[3, 1]], TRUE) - - df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) - expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") -}) -# -test_that("column binary mathfunctions", { - lines <- c("{\"a\":1, \"b\":5}", - "{\"a\":2, \"b\":6}", - "{\"a\":3, \"b\":7}", - "{\"a\":4, \"b\":8}") - jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlContext, jsonPathWithDup) - expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) - expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) - expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) - expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) - ## nolint start - expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) - expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) - expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) - expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) - ## nolint end - expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16) - expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) - expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) - expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") - expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01) - expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") - expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01) -}) - -test_that("string operators", { - df <- jsonFile(sqlContext, jsonPath) - expect_equal(count(where(df, like(df$name, "A%"))), 1) - expect_equal(count(where(df, startsWith(df$name, "A"))), 1) - expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") - expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") - expect_equal(collect(select(df, concat(df$name, lit(":"), df$age)))[[2, 1]], "Andy:30") - expect_equal(collect(select(df, concat_ws(":", df$name)))[[2, 1]], "Andy") - expect_equal(collect(select(df, concat_ws(":", df$name, df$age)))[[2, 1]], "Andy:30") - expect_equal(collect(select(df, instr(df$name, "i")))[, 1], c(2, 0, 5)) - expect_equal(collect(select(df, format_number(df$age, 2)))[2, 1], "30.00") - expect_equal(collect(select(df, sha1(df$name)))[2, 1], - "ab5a000e88b5d9d0fa2575f5c6263eb93452405d") - expect_equal(collect(select(df, sha2(df$name, 256)))[2, 1], - "80f2aed3c618c423ddf05a2891229fba44942d907173152442cf6591441ed6dc") - expect_equal(collect(select(df, format_string("Name:%s", df$name)))[2, 1], "Name:Andy") - expect_equal(collect(select(df, format_string("%s, %d", df$name, df$age)))[2, 1], "Andy, 30") - expect_equal(collect(select(df, regexp_extract(df$name, "(n.y)", 1)))[2, 1], "ndy") - expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") - - l2 <- list(list(a = "aaads")) - df2 <- createDataFrame(sqlContext, l2) - expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) - expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) - expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") - expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") - - l3 <- list(list(a = "a.b.c.d")) - df3 <- createDataFrame(sqlContext, l3) - expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") - expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") - expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") -}) - -test_that("date functions on a DataFrame", { - .originalTimeZone <- Sys.getenv("TZ") - Sys.setenv(TZ = "UTC") - l <- list(list(a = 1L, b = as.Date("2012-12-13")), - list(a = 2L, b = as.Date("2013-12-14")), - list(a = 3L, b = as.Date("2014-12-15"))) - df <- createDataFrame(sqlContext, l) - expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) - expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) - expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) - expect_equal(collect(select(df, year(df$b)))[, 1], c(2012, 2013, 2014)) - expect_equal(collect(select(df, month(df$b)))[, 1], c(12, 12, 12)) - expect_equal(collect(select(df, last_day(df$b)))[, 1], - c(as.Date("2012-12-31"), as.Date("2013-12-31"), as.Date("2014-12-31"))) - expect_equal(collect(select(df, next_day(df$b, "MONDAY")))[, 1], - c(as.Date("2012-12-17"), as.Date("2013-12-16"), as.Date("2014-12-22"))) - expect_equal(collect(select(df, date_format(df$b, "y")))[, 1], c("2012", "2013", "2014")) - expect_equal(collect(select(df, add_months(df$b, 3)))[, 1], - c(as.Date("2013-03-13"), as.Date("2014-03-14"), as.Date("2015-03-15"))) - expect_equal(collect(select(df, date_add(df$b, 1)))[, 1], - c(as.Date("2012-12-14"), as.Date("2013-12-15"), as.Date("2014-12-16"))) - expect_equal(collect(select(df, date_sub(df$b, 1)))[, 1], - c(as.Date("2012-12-12"), as.Date("2013-12-13"), as.Date("2014-12-14"))) - - l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), - list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) - df2 <- createDataFrame(sqlContext, l2) - expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) - expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) - expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], - c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) - expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], - c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) - expect_more_than(collect(select(df2, unix_timestamp()))[1, 1], 0) - expect_more_than(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) - expect_more_than(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) - - l3 <- list(list(a = 1000), list(a = -1000)) - df3 <- createDataFrame(sqlContext, l3) - result31 <- collect(select(df3, from_unixtime(df3$a))) - expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), - c(1, 2)) - result32 <- collect(select(df3, from_unixtime(df3$a, "yyyy"))) - expect_equal(grep("\\d{4}", result32[, 1]), c(1, 2)) - Sys.setenv(TZ = .originalTimeZone) -}) - -test_that("greatest() and least() on a DataFrame", { - l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) - expect_equal(collect(select(df, greatest(df$a, df$b)))[, 1], c(2, 4)) - expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) -}) - -test_that("when(), otherwise() and ifelse() on a DataFrame", { - l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) - expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, 1)))[, 1], c(NA, 1)) - expect_equal(collect(select(df, otherwise(when(df$a > 1, 1), 0)))[, 1], c(0, 1)) - expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) -}) - -test_that("group by", { - df <- jsonFile(sqlContext, jsonPath) - df1 <- agg(df, name = "max", age = "sum") - expect_equal(1, count(df1)) - df1 <- agg(df, age2 = max(df$age)) - expect_equal(1, count(df1)) - expect_equal(columns(df1), c("age2")) - - gd <- groupBy(df, "name") - expect_is(gd, "GroupedData") - df2 <- count(gd) - expect_is(df2, "DataFrame") - expect_equal(3, count(df2)) - - # Also test group_by, summarize, mean - gd1 <- group_by(df, "name") - expect_is(gd1, "GroupedData") - df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_is(df_summarized, "DataFrame") - expect_equal(3, count(df_summarized)) - - df3 <- agg(gd, age = "sum") - expect_is(df3, "DataFrame") - expect_equal(3, count(df3)) - - df3 <- agg(gd, age = sum(df$age)) - expect_is(df3, "DataFrame") - expect_equal(3, count(df3)) - expect_equal(columns(df3), c("name", "age")) - - df4 <- sum(gd, "age") - expect_is(df4, "DataFrame") - expect_equal(3, count(df4)) - expect_equal(3, count(mean(gd, "age"))) - expect_equal(3, count(max(gd, "age"))) -}) - -test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - sorted <- arrange(df, df$age) - expect_equal(collect(sorted)[1,2], "Michael") - - sorted2 <- arrange(df, "name", decreasing = FALSE) - expect_equal(collect(sorted2)[2,"age"], 19) - - sorted3 <- orderBy(df, asc(df$age)) - expect_true(is.na(first(sorted3)$age)) - expect_equal(collect(sorted3)[2, "age"], 19) - - sorted4 <- orderBy(df, desc(df$name)) - expect_equal(first(sorted4)$name, "Michael") - expect_equal(collect(sorted4)[3,"name"], "Andy") - - sorted5 <- arrange(df, "age", "name", decreasing = TRUE) - expect_equal(collect(sorted5)[1,2], "Andy") - - sorted6 <- arrange(df, "age","name", decreasing = c(T, F)) - expect_equal(collect(sorted6)[1,2], "Andy") - - sorted7 <- arrange(df, "name", decreasing = FALSE) - expect_equal(collect(sorted7)[2,"age"], 19) -}) - -test_that("filter() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - filtered <- filter(df, "age > 20") - expect_equal(count(filtered), 1) - expect_equal(collect(filtered)$name, "Andy") - filtered2 <- where(df, df$name != "Michael") - expect_equal(count(filtered2), 2) - expect_equal(collect(filtered2)$age[2], 19) - - # test suites for %in% - filtered3 <- filter(df, "age in (19)") - expect_equal(count(filtered3), 1) - filtered4 <- filter(df, "age in (19, 30)") - expect_equal(count(filtered4), 2) - filtered5 <- where(df, df$age %in% c(19)) - expect_equal(count(filtered5), 1) - filtered6 <- where(df, df$age %in% c(19, 30)) - expect_equal(count(filtered6), 2) -}) - -test_that("join() and merge() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - - mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", - "{\"name\":\"Andy\", \"test\": \"no\"}", - "{\"name\":\"Justin\", \"test\": \"yes\"}", - "{\"name\":\"Bob\", \"test\": \"yes\"}") - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlContext, jsonPath2) - - joined <- join(df, df2) - expect_equal(names(joined), c("age", "name", "name", "test")) - expect_equal(count(joined), 12) - - joined2 <- join(df, df2, df$name == df2$name) - expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_equal(count(joined2), 3) - - joined3 <- join(df, df2, df$name == df2$name, "rightouter") - expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_equal(count(joined3), 4) - expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) - - joined4 <- select(join(df, df2, df$name == df2$name, "outer"), - alias(df$age + 5, "newAge"), df$name, df2$test) - expect_equal(names(joined4), c("newAge", "name", "test")) - expect_equal(count(joined4), 4) - expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) - - joined5 <- join(df, df2, df$name == df2$name, "leftouter") - expect_equal(names(joined5), c("age", "name", "name", "test")) - expect_equal(count(joined5), 3) - expect_true(is.na(collect(orderBy(joined5, joined5$age))$age[1])) - - joined6 <- join(df, df2, df$name == df2$name, "inner") - expect_equal(names(joined6), c("age", "name", "name", "test")) - expect_equal(count(joined6), 3) - - joined7 <- join(df, df2, df$name == df2$name, "leftsemi") - expect_equal(names(joined7), c("age", "name")) - expect_equal(count(joined7), 3) - - joined8 <- join(df, df2, df$name == df2$name, "left_outer") - expect_equal(names(joined8), c("age", "name", "name", "test")) - expect_equal(count(joined8), 3) - expect_true(is.na(collect(orderBy(joined8, joined8$age))$age[1])) - - joined9 <- join(df, df2, df$name == df2$name, "right_outer") - expect_equal(names(joined9), c("age", "name", "name", "test")) - expect_equal(count(joined9), 4) - expect_true(is.na(collect(orderBy(joined9, joined9$age))$age[2])) - - merged <- merge(df, df2, by.x = "name", by.y = "name", all.x = TRUE, all.y = TRUE) - expect_equal(count(merged), 4) - expect_equal(names(merged), c("age", "name_x", "name_y", "test")) - expect_equal(collect(orderBy(merged, merged$name_x))$age[3], 19) - - merged <- merge(df, df2, suffixes = c("-X","-Y")) - expect_equal(count(merged), 3) - expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) - expect_equal(collect(orderBy(merged, merged$"name-X"))$age[1], 30) - - merged <- merge(df, df2, by = "name", suffixes = c("-X","-Y"), sort = FALSE) - expect_equal(count(merged), 3) - expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) - expect_equal(collect(orderBy(merged, merged$"name-Y"))$"name-X"[3], "Michael") - - merged <- merge(df, df2, by = "name", all = T, sort = T) - expect_equal(count(merged), 4) - expect_equal(names(merged), c("age", "name_x", "name_y", "test")) - expect_equal(collect(orderBy(merged, merged$"name_y"))$"name_x"[1], "Andy") - - merged <- merge(df, df2, by = NULL) - expect_equal(count(merged), 12) - expect_equal(names(merged), c("age", "name", "name", "test")) - - mockLines3 <- c("{\"name\":\"Michael\", \"name_y\":\"Michael\", \"test\": \"yes\"}", - "{\"name\":\"Andy\", \"name_y\":\"Andy\", \"test\": \"no\"}", - "{\"name\":\"Justin\", \"name_y\":\"Justin\", \"test\": \"yes\"}", - "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") - jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(mockLines3, jsonPath3) - df3 <- jsonFile(sqlContext, jsonPath3) - expect_error(merge(df, df3), - paste("The following column name: name_y occurs more than once in the 'DataFrame'.", - "Please use different suffixes for the intersected columns.", sep = "")) -}) - -test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlContext, jsonPath) - testRDD <- toJSON(df) - expect_is(testRDD, "RDD") - expect_equal(SparkR:::getSerializedMode(testRDD), "string") - expect_equal(collect(testRDD)[[1]], mockLines[1]) -}) - -test_that("showDF()", { - df <- jsonFile(sqlContext, jsonPath) - s <- capture.output(showDF(df)) - expected <- paste("+----+-------+\n", - "| age| name|\n", - "+----+-------+\n", - "|null|Michael|\n", - "| 30| Andy|\n", - "| 19| Justin|\n", - "+----+-------+\n", sep="") - expect_output(s , expected) -}) - -test_that("isLocal()", { - df <- jsonFile(sqlContext, jsonPath) - expect_false(isLocal(df)) -}) - -test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"Andy\", \"age\":30}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - writeLines(lines, jsonPath2) - df2 <- read.df(sqlContext, jsonPath2, "json") - - unioned <- arrange(unionAll(df, df2), df$age) - expect_is(unioned, "DataFrame") - expect_equal(count(unioned), 6) - expect_equal(first(unioned)$name, "Michael") - - unioned2 <- arrange(rbind(unioned, df, df2), df$age) - expect_is(unioned2, "DataFrame") - expect_equal(count(unioned2), 12) - expect_equal(first(unioned2)$name, "Michael") - - excepted <- arrange(except(df, df2), desc(df$age)) - expect_is(unioned, "DataFrame") - expect_equal(count(excepted), 2) - expect_equal(first(excepted)$name, "Justin") - - intersected <- arrange(intersect(df, df2), df$age) - expect_is(unioned, "DataFrame") - expect_equal(count(intersected), 1) - expect_equal(first(intersected)$name, "Andy") -}) - -test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlContext, jsonPath) - newDF <- withColumn(df, "newAge", df$age + 2) - expect_equal(length(columns(newDF)), 3) - expect_equal(columns(newDF)[3], "newAge") - expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) - - newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_equal(length(columns(newDF2)), 2) - expect_equal(columns(newDF2)[1], "newerAge") -}) - -test_that("mutate(), transform(), rename() and names()", { - df <- jsonFile(sqlContext, jsonPath) - newDF <- mutate(df, newAge = df$age + 2) - expect_equal(length(columns(newDF)), 3) - expect_equal(columns(newDF)[3], "newAge") - expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) - - newDF2 <- rename(df, newerAge = df$age) - expect_equal(length(columns(newDF2)), 2) - expect_equal(columns(newDF2)[1], "newerAge") - - names(newDF2) <- c("newerName", "evenNewerAge") - expect_equal(length(names(newDF2)), 2) - expect_equal(names(newDF2)[1], "newerName") - - transformedDF <- transform(df, newAge = -df$age, newAge2 = df$age / 2) - expect_equal(length(columns(transformedDF)), 4) - expect_equal(columns(transformedDF)[3], "newAge") - expect_equal(columns(transformedDF)[4], "newAge2") - expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30) - - # test if transform on local data frames works - # ensure the proper signature is used - otherwise this will fail to run - attach(airquality) - result <- transform(Ozone, logOzone = log(Ozone)) - expect_equal(nrow(result), 153) - expect_equal(ncol(result), 2) - detach(airquality) -}) - -test_that("write.df() on DataFrame and works with parquetFile", { - df <- jsonFile(sqlContext, jsonPath) - write.df(df, parquetPath, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlContext, parquetPath) - expect_is(parquetDF, "DataFrame") - expect_equal(count(df), count(parquetDF)) -}) - -test_that("parquetFile works with multiple input paths", { - df <- jsonFile(sqlContext, jsonPath) - write.df(df, parquetPath, "parquet", mode="overwrite") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.df(df, parquetPath2, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) - expect_is(parquetDF, "DataFrame") - expect_equal(count(parquetDF), count(df) * 2) - - # Test if varargs works with variables - saveMode <- "overwrite" - mergeSchema <- "true" - parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - write.df(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) -}) - -test_that("describe() and summarize() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - stats <- describe(df, "age") - expect_equal(collect(stats)[1, "summary"], "count") - expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "7.7781745930520225") - stats <- describe(df) - expect_equal(collect(stats)[4, "name"], "Andy") - expect_equal(collect(stats)[5, "age"], "30") - - stats2 <- summary(df) - expect_equal(collect(stats2)[4, "name"], "Andy") - expect_equal(collect(stats2)[5, "age"], "30") -}) - -test_that("dropna() and na.omit() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) - rows <- collect(df) - - # drop with columns - - expected <- rows[!is.na(rows$name),] - actual <- collect(dropna(df, cols = "name")) - expect_identical(expected, actual) - actual <- collect(na.omit(df, cols = "name")) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age),] - actual <- collect(dropna(df, cols = "age")) - row.names(expected) <- row.names(actual) - # identical on two dataframes does not work here. Don't know why. - # use identical on all columns as a workaround. - expect_identical(expected$age, actual$age) - expect_identical(expected$height, actual$height) - expect_identical(expected$name, actual$name) - actual <- collect(na.omit(df, cols = "age")) - - expected <- rows[!is.na(rows$age) & !is.na(rows$height),] - actual <- collect(dropna(df, cols = c("age", "height"))) - expect_identical(expected, actual) - actual <- collect(na.omit(df, cols = c("age", "height"))) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] - actual <- collect(dropna(df)) - expect_identical(expected, actual) - actual <- collect(na.omit(df)) - expect_identical(expected, actual) - - # drop with how - - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] - actual <- collect(dropna(df)) - expect_identical(expected, actual) - actual <- collect(na.omit(df)) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] - actual <- collect(dropna(df, "all")) - expect_identical(expected, actual) - actual <- collect(na.omit(df, "all")) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] - actual <- collect(dropna(df, "any")) - expect_identical(expected, actual) - actual <- collect(na.omit(df, "any")) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) & !is.na(rows$height),] - actual <- collect(dropna(df, "any", cols = c("age", "height"))) - expect_identical(expected, actual) - actual <- collect(na.omit(df, "any", cols = c("age", "height"))) - expect_identical(expected, actual) - - expected <- rows[!is.na(rows$age) | !is.na(rows$height),] - actual <- collect(dropna(df, "all", cols = c("age", "height"))) - expect_identical(expected, actual) - actual <- collect(na.omit(df, "all", cols = c("age", "height"))) - expect_identical(expected, actual) - - # drop with threshold - - expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] - actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_identical(expected, actual) - actual <- collect(na.omit(df, minNonNulls = 2, cols = c("age", "height"))) - expect_identical(expected, actual) - - expected <- rows[as.integer(!is.na(rows$age)) + - as.integer(!is.na(rows$height)) + - as.integer(!is.na(rows$name)) >= 3,] - actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_identical(expected, actual) - actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_identical(expected, actual) -}) - -test_that("fillna() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) - rows <- collect(df) - - # fill with value - - expected <- rows - expected$age[is.na(expected$age)] <- 50 - expected$height[is.na(expected$height)] <- 50.6 - actual <- collect(fillna(df, 50.6)) - expect_identical(expected, actual) - - expected <- rows - expected$name[is.na(expected$name)] <- "unknown" - actual <- collect(fillna(df, "unknown")) - expect_identical(expected, actual) - - expected <- rows - expected$age[is.na(expected$age)] <- 50 - actual <- collect(fillna(df, 50.6, "age")) - expect_identical(expected, actual) - - expected <- rows - expected$name[is.na(expected$name)] <- "unknown" - actual <- collect(fillna(df, "unknown", c("age", "name"))) - expect_identical(expected, actual) - - # fill with named list - - expected <- rows - expected$age[is.na(expected$age)] <- 50 - expected$height[is.na(expected$height)] <- 50.6 - expected$name[is.na(expected$name)] <- "unknown" - actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_identical(expected, actual) -}) - -test_that("crosstab() on a DataFrame", { - rdd <- lapply(parallelize(sc, 0:3), function(x) { - list(paste0("a", x %% 3), paste0("b", x %% 2)) - }) - df <- toDF(rdd, list("a", "b")) - ct <- crosstab(df, "a", "b") - ordered <- ct[order(ct$a_b),] - row.names(ordered) <- NULL - expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), - stringsAsFactors = FALSE, row.names = NULL) - expect_identical(expected, ordered) -}) - -test_that("cov() and corr() on a DataFrame", { - l <- lapply(c(0:9), function(x) { list(x, x * 2.0) }) - df <- createDataFrame(sqlContext, l, c("singles", "doubles")) - result <- cov(df, "singles", "doubles") - expect_true(abs(result - 55.0 / 3) < 1e-12) - - result <- corr(df, "singles", "doubles") - expect_true(abs(result - 1.0) < 1e-12) - result <- corr(df, "singles", "doubles", "pearson") - expect_true(abs(result - 1.0) < 1e-12) -}) - -test_that("freqItems() on a DataFrame", { - input <- 1:1000 - rdf <- data.frame(numbers = input, letters = as.character(input), - negDoubles = input * -1.0, stringsAsFactors = F) - rdf[ input %% 3 == 0, ] <- c(1, "1", -1) - df <- createDataFrame(sqlContext, rdf) - multiColResults <- freqItems(df, c("numbers", "letters"), support=0.1) - expect_true(1 %in% multiColResults$numbers[[1]]) - expect_true("1" %in% multiColResults$letters[[1]]) - singleColResult <- freqItems(df, "negDoubles", support=0.1) - expect_true(-1 %in% head(singleColResult$negDoubles)[[1]]) - - l <- lapply(c(0:99), function(i) { - if (i %% 2 == 0) { list(1L, -1.0) } - else { list(i, i * -1.0) }}) - df <- createDataFrame(sqlContext, l, c("a", "b")) - result <- freqItems(df, c("a", "b"), 0.4) - expect_identical(result[[1]], list(list(1L, 99L))) - expect_identical(result[[2]], list(list(-1, -99))) -}) - -test_that("sampleBy() on a DataFrame", { - l <- lapply(c(0:99), function(i) { as.character(i %% 3) }) - df <- createDataFrame(sqlContext, l, "key") - fractions <- list("0" = 0.1, "1" = 0.2) - sample <- sampleBy(df, "key", fractions, 0) - result <- collect(orderBy(count(groupBy(sample, "key")), "key")) - expect_identical(as.list(result[1, ]), list(key = "0", count = 2)) - expect_identical(as.list(result[2, ]), list(key = "1", count = 10)) -}) - -test_that("SQL error message is returned from JVM", { - retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) - expect_equal(grepl("Table not found: blah", retError), TRUE) -}) - -test_that("Method as.data.frame as a synonym for collect()", { - irisDF <- createDataFrame(sqlContext, iris) - expect_equal(as.data.frame(irisDF), collect(irisDF)) - irisDF2 <- irisDF[irisDF$Species == "setosa", ] - expect_equal(as.data.frame(irisDF2), collect(irisDF2)) -}) - -test_that("attach() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) - expect_error(age) - attach(df) - expect_is(age, "DataFrame") - expected_age <- data.frame(age = c(NA, 30, 19)) - expect_equal(head(age), expected_age) - stat <- summary(age) - expect_equal(collect(stat)[5, "age"], "30") - age <- age$age + 1 - expect_is(age, "Column") - rm(age) - stat2 <- summary(age) - expect_equal(collect(stat2)[5, "age"], "30") - detach("df") - stat3 <- summary(df[, "age"]) - expect_equal(collect(stat3)[5, "age"], "30") - expect_error(age) -}) - -unlink(parquetPath) -unlink(jsonPath) -unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R similarity index 100% rename from R/pkg/inst/tests/jarTest.R rename to R/pkg/inst/tests/testthat/jarTest.R diff --git a/R/pkg/inst/tests/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R similarity index 90% rename from R/pkg/inst/tests/packageInAJarTest.R rename to R/pkg/inst/tests/testthat/packageInAJarTest.R index 207a37a0cb47f..c26b28b78dee8 100644 --- a/R/pkg/inst/tests/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -25,6 +25,6 @@ run2 <- myfunc(-4L) sparkR.stop() -if(run1 != 6) quit(save = "no", status = 1) +if (run1 != 6) quit(save = "no", status = 1) -if(run2 != -3) quit(save = "no", status = 1) +if (run2 != -3) quit(save = "no", status = 1) diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R similarity index 100% rename from R/pkg/inst/tests/test_Serde.R rename to R/pkg/inst/tests/testthat/test_Serde.R diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R similarity index 84% rename from R/pkg/inst/tests/test_binaryFile.R rename to R/pkg/inst/tests/testthat/test_binaryFile.R index f2452ed97d2ea..976a7558a816d 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -23,8 +23,8 @@ sc <- sparkR.init() mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) rdd <- textFile(sc, fileName1, 1) @@ -37,7 +37,7 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1) @@ -49,8 +49,8 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) rdd <- textFile(sc, fileName1) @@ -73,8 +73,8 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") rdd1 <- parallelize(sc, "Spark is pretty.") saveAsObjectFile(rdd1, fileName1) diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R similarity index 94% rename from R/pkg/inst/tests/test_binary_function.R rename to R/pkg/inst/tests/testthat/test_binary_function.R index f054ac9a87d61..7bad4d2a7e106 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -31,7 +31,7 @@ test_that("union on two RDDs", { actual <- collect(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) @@ -74,10 +74,10 @@ test_that("zipPartitions() on RDDs", { actual <- collect(zipPartitions(rdd1, rdd2, rdd3, func = function(x, y, z) { list(list(x, y, z))} )) expect_equal(actual, - list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) + list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6)))) mockFile <- c("Spark is pretty.", "Spark is awesome.") - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName, 1) diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R similarity index 92% rename from R/pkg/inst/tests/test_broadcast.R rename to R/pkg/inst/tests/testthat/test_broadcast.R index bb86a5c922bde..8be6efc3dbed3 100644 --- a/R/pkg/inst/tests/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -25,7 +25,7 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { - randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcast(sc, randomMat) useBroadcast <- function(x) { @@ -37,7 +37,7 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { - randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100)) + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { sum(randomMat * x) diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/testthat/test_client.R similarity index 76% rename from R/pkg/inst/tests/test_client.R rename to R/pkg/inst/tests/testthat/test_client.R index 8a20991f89af8..a0664f32f31c1 100644 --- a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -34,3 +34,12 @@ test_that("no package specified doesn't add packages flag", { test_that("multiple packages don't produce a warning", { expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) }) + +test_that("sparkJars sparkPackages as character vectors", { + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", + c("com.databricks:spark-avro_2.10:2.0.1", + "com.databricks:spark-csv_2.10:1.3.0")) + expect_match(args, "--jars one.jar,two.jar,three.jar") + expect_match(args, + "--packages com.databricks:spark-avro_2.10:2.0.1,com.databricks:spark-csv_2.10:1.3.0") +}) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R new file mode 100644 index 0000000000000..6e06c974c291f --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -0,0 +1,138 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("test functions in sparkR.R") + +test_that("Check masked functions", { + # Check that we are not masking any new function from base, stats, testthat unexpectedly + masked <- conflicts(detail = TRUE)$`package:SparkR` + expect_true("describe" %in% masked) # only when with testthat.. + func <- lapply(masked, function(x) { capture.output(showMethods(x))[[1]] }) + funcSparkROrEmpty <- grepl("\\(package SparkR\\)$|^$", func) + maskedBySparkR <- masked[funcSparkROrEmpty] + namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", + "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", + "summary", "transform", "drop", "window") + expect_equal(length(maskedBySparkR), length(namesOfMasked)) + expect_equal(sort(maskedBySparkR), sort(namesOfMasked)) + # above are those reported as masked when `library(SparkR)` + # note that many of these methods are still callable without base:: or stats:: prefix + # there should be a test for each of these, except followings, which are currently "broken" + funcHasAny <- unlist(lapply(masked, function(x) { + any(grepl("=\"ANY\"", capture.output(showMethods(x)[-1]))) + })) + maskedCompletely <- masked[!funcHasAny] + namesOfMaskedCompletely <- c("cov", "filter", "sample") + expect_equal(length(maskedCompletely), length(namesOfMaskedCompletely)) + expect_equal(sort(maskedCompletely), sort(namesOfMaskedCompletely)) +}) + +test_that("repeatedly starting and stopping SparkR", { + for (i in 1:4) { + sc <- sparkR.init() + rdd <- parallelize(sc, 1:20, 2L) + expect_equal(count(rdd), 20) + sparkR.stop() + } +}) + +test_that("repeatedly starting and stopping SparkR SQL", { + for (i in 1:4) { + sc <- sparkR.init() + sqlContext <- sparkRSQL.init(sc) + df <- createDataFrame(sqlContext, data.frame(a = 1:20)) + expect_equal(count(df), 20) + sparkR.stop() + } +}) + +test_that("rdd GC across sparkR.stop", { + sparkR.stop() + sc <- sparkR.init() # sc should get id 0 + rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 + rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 + sparkR.stop() + + sc <- sparkR.init() # sc should get id 0 again + + # GC rdd1 before creating rdd3 and rdd2 after + rm(rdd1) + gc() + + rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now + rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now + + rm(rdd2) + gc() + + count(rdd3) + count(rdd4) +}) + +test_that("job group functions can be called", { + sc <- sparkR.init() + setJobGroup(sc, "groupId", "job description", TRUE) + cancelJobGroup(sc, "groupId") + clearJobGroup(sc) +}) + +test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { + e <- new.env() + e[["spark.driver.memory"]] <- "512m" + ops <- getClientModeSparkSubmitOpts("sparkrmain", e) + expect_equal("--driver-memory \"512m\" sparkrmain", ops) + + e[["spark.driver.memory"]] <- "5g" + e[["spark.driver.extraClassPath"]] <- "/opt/class_path" # nolint + e[["spark.driver.extraJavaOptions"]] <- "-XX:+UseCompressedOops -XX:+UseCompressedStrings" + e[["spark.driver.extraLibraryPath"]] <- "/usr/local/hadoop/lib" # nolint + e[["random"]] <- "skipthis" + ops2 <- getClientModeSparkSubmitOpts("sparkr-shell", e) + # nolint start + expect_equal(ops2, paste0("--driver-class-path \"/opt/class_path\" --driver-java-options \"", + "-XX:+UseCompressedOops -XX:+UseCompressedStrings\" --driver-library-path \"", + "/usr/local/hadoop/lib\" --driver-memory \"5g\" sparkr-shell")) + # nolint end + + e[["spark.driver.extraClassPath"]] <- "/" # too short + ops3 <- getClientModeSparkSubmitOpts("--driver-memory 4g sparkr-shell2", e) + # nolint start + expect_equal(ops3, paste0("--driver-java-options \"-XX:+UseCompressedOops ", + "-XX:+UseCompressedStrings\" --driver-library-path \"/usr/local/hadoop/lib\"", + " --driver-memory 4g sparkr-shell2")) + # nolint end +}) + +test_that("sparkJars sparkPackages as comma-separated strings", { + expect_warning(processSparkJars(" a, b ")) + jars <- suppressWarnings(processSparkJars(" a, b ")) + expect_equal(jars, c("a", "b")) + + jars <- suppressWarnings(processSparkJars(" abc ,, def ")) + expect_equal(jars, c("abc", "def")) + + jars <- suppressWarnings(processSparkJars(c(" abc ,, def ", "", "xyz", " ", "a,b"))) + expect_equal(jars, c("abc", "def", "xyz", "a", "b")) + + p <- processSparkPackages(c("ghi", "lmn")) + expect_equal(p, c("ghi", "lmn")) + + # check normalizePath + f <- dir()[[1]] + expect_that(processSparkJars(f), not(gives_warning())) + expect_match(processSparkJars(f), f) +}) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/testthat/test_includeJAR.R similarity index 94% rename from R/pkg/inst/tests/test_includeJAR.R rename to R/pkg/inst/tests/testthat/test_includeJAR.R index cc1faeabffe30..f89aa8e507fd5 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/testthat/test_includeJAR.R @@ -20,7 +20,7 @@ runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) - scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/testthat/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, args = c(jarPath, scriptPath), diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R similarity index 100% rename from R/pkg/inst/tests/test_includePackage.R rename to R/pkg/inst/tests/testthat/test_includePackage.R diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R new file mode 100644 index 0000000000000..a9dbd2bdc4cc0 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -0,0 +1,214 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("formula of glm", { + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + # dot minus and intercept vs native glm + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # feature interaction vs native glm + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # glm should work with long formula + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, + data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("glm and predict", { + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + # gaussian family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # poisson family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = poisson(link = identity)) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, + data = iris, family = poisson(link = identity)), iris)) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) +}) + +test_that("kmeans", { + newIris <- iris + newIris$Species <- NULL + training <- suppressWarnings(createDataFrame(sqlContext, newIris)) + + # Cache the DataFrame here to work around the bug SPARK-13178. + cache(training) + take(training, 1) + + model <- kmeans(x = training, centers = 2) + sample <- take(select(predict(model, training), "prediction"), 1) + expect_equal(typeof(sample$prediction), "integer") + expect_equal(sample$prediction, 1) + + # Test stats::kmeans is working + statsModel <- kmeans(x = newIris, centers = 2) + expect_equal(sort(unique(statsModel$cluster)), c(1, 2)) + + # Test fitted works on KMeans + fitted.model <- fitted(model) + expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1)) + + # Test summary works on KMeans + summary.model <- summary(model) + cluster <- summary.model$cluster + expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) +}) + +test_that("naiveBayes", { + # R code to reproduce the result. + # We do not support instance weights yet. So we ignore the frequencies. + # + #' library(e1071) + #' t <- as.data.frame(Titanic) + #' t1 <- t[t$Freq > 0, -5] + #' m <- naiveBayes(Survived ~ ., data = t1) + #' m + #' predict(m, t1) + # + # -- output of 'm' + # + # A-priori probabilities: + # Y + # No Yes + # 0.4166667 0.5833333 + # + # Conditional probabilities: + # Class + # Y 1st 2nd 3rd Crew + # No 0.2000000 0.2000000 0.4000000 0.2000000 + # Yes 0.2857143 0.2857143 0.2857143 0.1428571 + # + # Sex + # Y Male Female + # No 0.5 0.5 + # Yes 0.5 0.5 + # + # Age + # Y Child Adult + # No 0.2000000 0.8000000 + # Yes 0.4285714 0.5714286 + # + # -- output of 'predict(m, t1)' + # + # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No + # + + t <- as.data.frame(Titanic) + t1 <- t[t$Freq > 0, -5] + df <- suppressWarnings(createDataFrame(sqlContext, t1)) + m <- naiveBayes(Survived ~ ., data = df) + s <- summary(m) + expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6) + p <- collect(select(predict(m, df), "prediction")) + expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", + "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", + "Yes", "Yes", "No", "No")) + + # Test e1071::naiveBayes + if (requireNamespace("e1071", quietly = TRUE)) { + expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error())) + expect_equal(as.character(predict(m, t1[1, ])), "Yes") + } +}) + +test_that("survreg", { + # R code to reproduce the result. + # + #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), + #' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) + #' library(survival) + #' model <- survreg(Surv(time, status) ~ x + sex, rData) + #' summary(model) + #' predict(model, data) + # + # -- output of 'summary(model)' + # + # Value Std. Error z p + # (Intercept) 1.315 0.270 4.88 1.07e-06 + # x -0.190 0.173 -1.10 2.72e-01 + # sex -0.253 0.329 -0.77 4.42e-01 + # Log(scale) -1.160 0.396 -2.93 3.41e-03 + # + # -- output of 'predict(model, data)' + # + # 1 2 3 4 5 6 7 + # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269 + # + data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), + list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) + df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex")) + model <- survreg(Surv(time, status) ~ x + sex, df) + stats <- summary(model) + coefs <- as.vector(stats$coefficients[, 1]) + rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800) + expect_equal(coefs, rCoefs, tolerance = 1e-4) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "x", "sex", "Log(scale)"))) + p <- collect(select(predict(model, df), "prediction")) + expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035, + 2.390146, 2.891269, 2.891269), tolerance = 1e-4) + + # Test survival::survreg + if (requireNamespace("survival", quietly = TRUE)) { + rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), + x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) + expect_that( + model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), + not(throws_error())) + expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) + } +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R similarity index 100% rename from R/pkg/inst/tests/test_parallelize_collect.R rename to R/pkg/inst/tests/testthat/test_parallelize_collect.R diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R similarity index 89% rename from R/pkg/inst/tests/test_rdd.R rename to R/pkg/inst/tests/testthat/test_rdd.R index 71aed2bb9d6a8..b6c8e1dc6c1b7 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -28,8 +28,8 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - expect_equal(numPartitions(rdd), 2) - expect_equal(numPartitions(intRdd), 2) + expect_equal(getNumPartitions(rdd), 2) + expect_equal(getNumPartitions(intRdd), 2) }) test_that("first on RDD", { @@ -75,7 +75,7 @@ test_that("mapPartitions on RDD", { test_that("flatMap() on RDDs", { flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collect(flat) - expect_equal(actual, rep(intPairs, each=2)) + expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { @@ -223,14 +223,14 @@ test_that("takeSample() on RDDs", { s <- takeSample(data, TRUE, 100L, seed) expect_equal(length(s), 100L) # Chance of getting all distinct elements is astronomically low, so test we - # got < 100 + # got less than 100 expect_true(length(unique(s)) < 100L) } for (seed in 4:5) { s <- takeSample(data, TRUE, 200L, seed) expect_equal(length(s), 200L) # Chance of getting all distinct elements is still quite low, so test we - # got < 100 + # got less than 100 expect_true(length(unique(s)) < 100L) } }) @@ -245,9 +245,9 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { - l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) + l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collect(flatMapValues(l, function(x) { x })) - expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) + expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) # Generate x to x+1 for every value actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) @@ -304,18 +304,18 @@ test_that("repartition/coalesce on RDDs", { # repartition r1 <- repartition(rdd, 2) - expect_equal(numPartitions(r1), 2L) + expect_equal(getNumPartitions(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) r2 <- repartition(rdd, 6) - expect_equal(numPartitions(r2), 6L) + expect_equal(getNumPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) - expect_equal(numPartitions(r3), 1L) + expect_equal(getNumPartitions(r3), 1L) count <- length(collectPartition(r3, 0L)) expect_equal(count, 20) }) @@ -448,12 +448,12 @@ test_that("zipRDD() on RDDs", { list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) mockFile <- c("Spark is pretty.", "Spark is awesome.") - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName, 1) actual <- collect(zipRDD(rdd, rdd)) - expected <- lapply(mockFile, function(x) { list(x ,x) }) + expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) rdd1 <- parallelize(sc, 0:1, 1) @@ -484,7 +484,7 @@ test_that("cartesian() on RDDs", { expect_equal(actual, list()) mockFile <- c("Spark is pretty.", "Spark is awesome.") - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) @@ -523,19 +523,19 @@ test_that("subtract() on RDDs", { # subtract by an empty RDD rdd2 <- parallelize(sc, list()) actual <- collect(subtract(rdd1, rdd2)) - expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + expect_equal(as.list(sort(as.vector(actual, mode = "integer"))), l) rdd2 <- parallelize(sc, list(2, 4)) actual <- collect(subtract(rdd1, rdd2)) - expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + expect_equal(as.list(sort(as.vector(actual, mode = "integer"))), list(1, 1, 3)) l <- list("a", "a", "b", "b", "c", "d") rdd1 <- parallelize(sc, l) rdd2 <- parallelize(sc, list("b", "d")) actual <- collect(subtract(rdd1, rdd2)) - expect_equal(as.list(sort(as.vector(actual, mode="character"))), + expect_equal(as.list(sort(as.vector(actual, mode = "character"))), list("a", "a", "c")) }) @@ -585,53 +585,53 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { - rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) - rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collect(join(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) - rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) - rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4))) + rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3))) actual <- collect(join(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) - rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) - rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) + rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) actual <- collect(join(rdd1, rdd2, 2L)) expect_equal(actual, list()) - rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) - rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) + rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) actual <- collect(join(rdd1, rdd2, 2L)) expect_equal(actual, list()) }) test_that("leftOuterJoin() on pairwise RDDs", { - rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) - rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) + rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list("a",1), list("b",4))) - rdd2 <- parallelize(sc, list(list("a",2), list("a",3))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4))) + rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3))) actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) - rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) + rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) - rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) + rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) expect_equal(sortKeyValueList(actual), @@ -639,57 +639,57 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { - rdd1 <- parallelize(sc, list(list(1,2), list(1,3))) - rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) + rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list("a",2), list("a",3))) - rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3))) + rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4))) actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) - rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) + rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) - rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) - rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) + rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("fullOuterJoin() on pairwise RDDs", { - rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) - rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) + rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) - rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) + rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3), list("c", 1))) + rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - rdd1 <- parallelize(sc, list(list(1,1), list(2,2))) - rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) + rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) - rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) - rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) + rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) + rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), @@ -791,3 +791,11 @@ test_that("sampleByKey() on pairwise RDDs", { expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) }) + +test_that("Test correct concurrency of RRDD.compute()", { + rdd <- parallelize(sc, 1:1000, 100) + jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") + zrdd <- callJMethod(jrdd, "zip", jrdd) + count <- callJMethod(zrdd, "count") + expect_equal(count, 1000) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R similarity index 98% rename from R/pkg/inst/tests/test_shuffle.R rename to R/pkg/inst/tests/testthat/test_shuffle.R index adf0b91d25fe9..d3d0f8a24d01c 100644 --- a/R/pkg/inst/tests/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -176,8 +176,8 @@ test_that("partitionBy() partitions data correctly", { resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) - expected_first <- list(list(1, 100), list(2, 200)) # key < 3 - expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + expected_first <- list(list(1, 100), list(2, 200)) # key less than 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key greater than or equal 3 actual_first <- collectPartition(resultRDD, 0L) actual_second <- collectPartition(resultRDD, 1L) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R new file mode 100644 index 0000000000000..d747d4f83f24b --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -0,0 +1,1969 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("SparkSQL functions") + +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + +markUtf8 <- function(s) { + Encoding(s) <- "UTF-8" + s +} + +# Tests for SparkSQL functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") +parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") +writeLines(mockLines, jsonPath) + +# For test nafunctions, like dropna(), fillna(),... +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern = "sparkr-test", fileext = ".tmp") +writeLines(mockLinesNa, jsonPathNa) + +# For test complex types in DataFrame +mockLinesComplexType <- + c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", + "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", + "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 8.0, 9.0]}") +complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") +writeLines(mockLinesComplexType, complexTypeJsonPath) + +test_that("calling sparkRSQL.init returns existing SQL context", { + expect_equal(sparkRSQL.init(sc), sqlContext) +}) + +test_that("infer types and check types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), "array") + expect_equal(infer_type(list(1L, 2L)), "array") + expect_equal(infer_type(listToStruct(list(a = 1L, b = "2"))), "struct") + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") + + expect_equal(infer_type(as.raw(c(1, 2, 3))), "binary") +}) + +test_that("structType and structField", { + testField <- structField("a", "string") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") + expect_true(testField$nullable()) + + testSchema <- structType(testField, structField("b", "integer")) + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + dfAsDF <- as.DataFrame(sqlContext, rdd, list("a", "b")) + expect_is(df, "DataFrame") + expect_is(dfAsDF, "DataFrame") + expect_equal(count(df), 10) + expect_equal(count(dfAsDF), 10) + expect_equal(nrow(df), 10) + expect_equal(nrow(dfAsDF), 10) + expect_equal(ncol(df), 2) + expect_equal(ncol(dfAsDF), 2) + expect_equal(dim(df), c(10, 2)) + expect_equal(dim(dfAsDF), c(10, 2)) + expect_equal(columns(df), c("a", "b")) + expect_equal(columns(dfAsDF), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(dtypes(dfAsDF), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlContext, rdd) + dfAsDF <- as.DataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_is(dfAsDF, "DataFrame") + expect_equal(columns(df), c("_1", "_2")) + expect_equal(columns(dfAsDF), c("_1", "_2")) + + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) + df <- createDataFrame(sqlContext, rdd, schema) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df <- read.df(sqlContext, jsonPathNa, "json", schema) + df2 <- createDataFrame(sqlContext, toRDD(df), schema) + df2AsDF <- as.DataFrame(sqlContext, toRDD(df), schema) + expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(columns(df2AsDF), c("name", "age", "height")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(dtypes(df2AsDF), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(as.list(collect(where(df2, df2$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) + expect_equal(as.list(collect(where(df2AsDF, df2AsDF$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) + + localDF <- data.frame(name = c("John", "Smith", "Sarah"), + age = c(19L, 23L, 18L), + height = c(176.5, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(as.list(collect(where(df, df$name == "John"))), + list(name = "John", age = 19L, height = 176.5)) + + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + df <- read.df(hiveCtx, jsonPathNa, "json", schema) + invisible(insertInto(df, "people")) + expect_equal(collect(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"))$age, + c(16)) + expect_equal(collect(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"))$height, + c(176.5)) +}) + +test_that("convert NAs to null type in DataFrames", { + rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4L) + + l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1L) + expect_true(is.na(collect(df)[2, "y"])) + + rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4) + + l <- data.frame(x = 1, y = c(1, NA_real_, 3)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1) + expect_true(is.na(collect(df)[2, "y"])) + + l <- list("a", "b", NA, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list("a", "b", NA_character_, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list(TRUE, FALSE, NA, TRUE) + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], TRUE) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("_1", "_2")) + + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) + df <- toDF(rdd, schema) + expect_is(df, "DataFrame") + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlContext, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlContext, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) + + irisdf <- suppressWarnings(createDataFrame(sqlContext, iris)) + iris_collected <- collect(irisdf) + expect_equivalent(iris_collected[, -5], iris[, -5]) + expect_equal(iris_collected$Species, as.character(iris$Species)) + + mtcarsdf <- createDataFrame(sqlContext, mtcars) + expect_equivalent(collect(mtcarsdf), mtcars) + + bytes <- as.raw(c(1, 2, 3)) + df <- createDataFrame(sqlContext, list(list(bytes))) + expect_equal(collect(df)[[1]][[1]], bytes) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlContext, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +test_that("create DataFrame with complex types", { + e <- new.env() + assign("n", 3L, envir = e) + + s <- listToStruct(list(a = "aa", b = 3L)) + + l <- list(as.list(1:10), list("a", "b"), e, s) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"), + c("d", "struct"))) + expect_equal(count(df), 1) + ldf <- collect(df) + expect_equal(names(ldf), c("a", "b", "c", "d")) + expect_equal(ldf[1, 1][[1]], l[[1]]) + expect_equal(ldf[1, 2][[1]], l[[2]]) + + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) + + s <- ldf$d[[1]] + expect_equal(class(s), "struct") + expect_equal(s$a, "aa") + expect_equal(s$b, 3L) +}) + +test_that("create DataFrame from a data.frame with complex types", { + ldf <- data.frame(row.names = 1:2) + ldf$a_list <- list(list(1, 2), list(3, 4)) + ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) + + sdf <- createDataFrame(sqlContext, ldf) + collected <- collect(sdf) + + expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) + expect_equal(ldf$an_envir, collected$an_envir) +}) + +# For test map type and struct type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + +test_that("Collect DataFrame with complex types", { + # ArrayType + df <- read.json(sqlContext, complexTypeJsonPath) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 3) + expect_equal(names(ldf), c("c1", "c2", "c3")) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # StructType + df <- read.json(sqlContext, mapTypeJsonPath) + expect_equal(dtypes(df), list(c("info", "struct"), + c("name", "string"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("info", "name")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "struct") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) +}) + +test_that("read/write json files", { + # Test read.df + df <- read.df(sqlContext, jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_is(df1, "DataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_is(df2, "DataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json + df <- read.json(sqlContext, jsonPath) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Test write.df + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") + write.df(df, jsonPath2, "json", mode = "overwrite") + + # Test write.json + jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() works with multiple input paths + jsonDF1 <- read.json(sqlContext, c(jsonPath2, jsonPath3)) + expect_is(jsonDF1, "DataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath2, jsonPath3))) + expect_is(jsonDF2, "DataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) + unlink(jsonPath3) +}) + +test_that("jsonRDD() on a RDD with json string", { + rdd <- parallelize(sc, mockLines) + expect_equal(count(rdd), 3) + df <- suppressWarnings(jsonRDD(sqlContext, rdd)) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + rdd2 <- flatMap(rdd, function(x) c(x, x)) + df <- suppressWarnings(jsonRDD(sqlContext, rdd2)) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) +}) + +test_that("test cache, uncache and clearCache", { + df <- read.json(sqlContext, jsonPath) + registerTempTable(df, "table1") + cacheTable(sqlContext, "table1") + uncacheTable(sqlContext, "table1") + clearCache(sqlContext) + dropTempTable(sqlContext, "table1") +}) + +test_that("test tableNames and tables", { + df <- read.json(sqlContext, jsonPath) + registerTempTable(df, "table1") + expect_equal(length(tableNames(sqlContext)), 1) + df <- tables(sqlContext) + expect_equal(count(df), 1) + dropTempTable(sqlContext, "table1") +}) + +test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { + df <- read.json(sqlContext, jsonPath) + registerTempTable(df, "table1") + newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) + dropTempTable(sqlContext, "table1") +}) + +test_that("insertInto() on a registered table", { + df <- read.df(sqlContext, jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(sqlContext, parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- read.df(sqlContext, jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") + dropTempTable(sqlContext, "table1") + + registerTempTable(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") + dropTempTable(sqlContext, "table1") + + unlink(jsonPath2) + unlink(parquetPath2) +}) + +test_that("tableToDF() returns a new DataFrame", { + df <- read.json(sqlContext, jsonPath) + registerTempTable(df, "table1") + tabledf <- tableToDF(sqlContext, "table1") + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) + tabledf2 <- tableToDF(sqlContext, "table1") + expect_equal(count(tabledf2), 3) + dropTempTable(sqlContext, "table1") +}) + +test_that("toRDD() returns an RRDD", { + df <- read.json(sqlContext, jsonPath) + testRDD <- toRDD(df) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) +}) + +test_that("union on two RDDs created from DataFrames returns an RRDD", { + df <- read.json(sqlContext, jsonPath) + RDD1 <- toRDD(df) + RDD2 <- toRDD(df) + unioned <- unionRDD(RDD1, RDD2) + expect_is(unioned, "RDD") + expect_equal(getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") +}) + +test_that("union on mixed serialization types correctly returns a byte RRDD", { + # Byte RDD + nums <- 1:10 + rdd <- parallelize(sc, nums, 2L) + + # String RDD + textLines <- c("Michael", + "Andy, 30", + "Justin, 19") + textPath <- tempfile(pattern = "sparkr-textLines", fileext = ".tmp") + writeLines(textLines, textPath) + textRDD <- textFile(sc, textPath) + + df <- read.json(sqlContext, jsonPath) + dfRDD <- toRDD(df) + + unionByte <- unionRDD(rdd, dfRDD) + expect_is(unionByte, "RDD") + expect_equal(getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") + + unionString <- unionRDD(textRDD, dfRDD) + expect_is(unionString, "RDD") + expect_equal(getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") +}) + +test_that("objectFile() works with row serialization", { + objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") + df <- read.json(sqlContext, jsonPath) + dfRDD <- toRDD(df) + saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + objectIn <- objectFile(sc, objectPath) + + expect_is(objectIn, "RDD") + expect_equal(getSerializedMode(objectIn), "byte") + expect_equal(collect(objectIn)[[2]]$age, 30) +}) + +test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + df <- read.json(sqlContext, jsonPath) + testRDD <- lapply(df, function(row) { + row$newCol <- row$age + 5 + row + }) + expect_is(testRDD, "RDD") + collected <- collect(testRDD) + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) +}) + +test_that("collect() returns a data.frame", { + df <- read.json(sqlContext, jsonPath) + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) + + # collect() returns data correctly from a DataFrame with 0 row + df0 <- limit(df, 0) + rdf <- collect(df0) + expect_true(is.data.frame(rdf)) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 0) + expect_equal(ncol(rdf), 2) + + # collect() correctly handles multiple columns with same name + df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name")) + ldf <- collect(df) + expect_equal(names(ldf), c("name", "name")) +}) + +test_that("limit() returns DataFrame with the correct number of rows", { + df <- read.json(sqlContext, jsonPath) + dfLimited <- limit(df, 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) +}) + +test_that("collect() and take() on a DataFrame return the same number of rows and columns", { + df <- read.json(sqlContext, jsonPath) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) +}) + +test_that("collect() support Unicode characters", { + lines <- c("{\"name\":\"안녕하세요\"}", + "{\"name\":\"您好\", \"age\":30}", + "{\"name\":\"こんにちは\", \"age\":19}", + "{\"name\":\"Xin chào\"}") + + jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + writeLines(lines, jsonPath) + + df <- read.df(sqlContext, jsonPath, "json") + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_equal(rdf$name[1], markUtf8("안녕하세요")) + expect_equal(rdf$name[2], markUtf8("您好")) + expect_equal(rdf$name[3], markUtf8("こんにちは")) + expect_equal(rdf$name[4], markUtf8("Xin chào")) + + df1 <- createDataFrame(sqlContext, rdf) + expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好")) +}) + +test_that("multiple pipeline transformations result in an RDD with the correct values", { + df <- read.json(sqlContext, jsonPath) + first <- lapply(df, function(row) { + row$age <- row$age + 5 + row + }) + second <- lapply(first, function(row) { + row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE + row + }) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) + expect_true(collect(second)[[2]]$testCol) + expect_false(collect(second)[[3]]$testCol) +}) + +test_that("cache(), persist(), and unpersist() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + expect_false(df@env$isCached) + cache(df) + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + persist(df, "MEMORY_AND_DISK") + expect_true(df@env$isCached) + + unpersist(df) + expect_false(df@env$isCached) + + # make sure the data is collectable + expect_true(is.data.frame(collect(df))) +}) + +test_that("schema(), dtypes(), columns(), names() return the correct values/format", { + df <- read.json(sqlContext, jsonPath) + testSchema <- schema(df) + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") + + testTypes <- dtypes(df) + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") + + testCols <- columns(df) + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") + + testNames <- names(df) + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") +}) + +test_that("names() colnames() set the column names", { + df <- read.json(sqlContext, jsonPath) + names(df) <- c("col1", "col2") + expect_equal(colnames(df)[2], "col2") + + colnames(df) <- c("col3", "col4") + expect_equal(names(df)[1], "col3") + + expect_error(colnames(df) <- c("sepal.length", "sepal_width"), + "Colum names cannot contain the '.' symbol.") + expect_error(colnames(df) <- c(1, 2), "Invalid column names.") + expect_error(colnames(df) <- c("a"), + "Column names must have the same length as the number of columns in the dataset.") + expect_error(colnames(df) <- c("1", NA), "Column names cannot be NA.") + + # Note: if this test is broken, remove check for "." character on colnames<- method + irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) + expect_equal(names(irisDF)[1], "Sepal_Length") + + # Test base::colnames base::names + m2 <- cbind(1, 1:4) + expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2")) + colnames(m2) <- c("x", "Y") + expect_equal(colnames(m2), c("x", "Y")) + + z <- list(a = 1, b = "c", c = 1:3) + expect_equal(names(z)[3], "c") + names(z)[3] <- "c2" + expect_equal(names(z)[3], "c2") +}) + +test_that("head() and first() return the correct data", { + df <- read.json(sqlContext, jsonPath) + testHead <- head(df) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) + + testHead2 <- head(df, 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) + + testFirst <- first(df) + expect_equal(nrow(testFirst), 1) + + # head() and first() return the correct data on + # a DataFrame with 0 row + df0 <- limit(df, 0) + + testHead <- head(df0) + expect_equal(nrow(testHead), 0) + expect_equal(ncol(testHead), 2) + + testFirst <- first(df0) + expect_equal(nrow(testFirst), 0) + expect_equal(ncol(testFirst), 2) +}) + +test_that("distinct(), unique() and dropDuplicates() on DataFrames", { + lines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":19}") + jsonPathWithDup <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + writeLines(lines, jsonPathWithDup) + + df <- read.json(sqlContext, jsonPathWithDup) + uniques <- distinct(df) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) + + uniques2 <- unique(df) + expect_is(uniques2, "DataFrame") + expect_equal(count(uniques2), 3) + + # Test dropDuplicates() + df <- createDataFrame( + sqlContext, + list( + list(2, 1, 2), list(1, 1, 1), + list(1, 2, 1), list(2, 1, 2), + list(2, 2, 2), list(2, 2, 1), + list(2, 1, 1), list(1, 1, 2), + list(1, 2, 2), list(1, 2, 1)), + schema = c("key", "value1", "value2")) + result <- collect(dropDuplicates(df)) + expected <- rbind.data.frame( + c(1, 1, 1), c(1, 1, 2), c(1, 2, 1), + c(1, 2, 2), c(2, 1, 1), c(2, 1, 2), + c(2, 2, 1), c(2, 2, 2)) + names(expected) <- c("key", "value1", "value2") + expect_equivalent( + result[order(result$key, result$value1, result$value2), ], + expected) + + result <- collect(dropDuplicates(df, c("key", "value1"))) + expected <- rbind.data.frame( + c(1, 1, 1), c(1, 2, 1), c(2, 1, 2), c(2, 2, 2)) + names(expected) <- c("key", "value1", "value2") + expect_equivalent( + result[order(result$key, result$value1, result$value2), ], + expected) + + result <- collect(dropDuplicates(df, "key")) + expected <- rbind.data.frame( + c(1, 1, 1), c(2, 1, 2)) + names(expected) <- c("key", "value1", "value2") + expect_equivalent( + result[order(result$key, result$value1, result$value2), ], + expected) +}) + +test_that("sample on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + sampled <- sample(df, FALSE, 1.0) + expect_equal(nrow(collect(sampled)), count(df)) + expect_is(sampled, "DataFrame") + sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result + expect_true(count(sampled2) < 3) + + count1 <- count(sample(df, FALSE, 0.1, 0)) + count2 <- count(sample(df, FALSE, 0.1, 0)) + expect_equal(count1, count2) + + # Also test sample_frac + sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result + expect_true(count(sampled3) < 3) + + # nolint start + # Test base::sample is working + #expect_equal(length(sample(1:12)), 12) + # nolint end +}) + +test_that("select operators", { + df <- select(read.json(sqlContext, jsonPath), "name", "age") + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") + + expect_is(df[, 1], "DataFrame") + expect_equal(columns(df[, 1]), c("name")) + expect_equal(columns(df[, "age"]), c("age")) + df2 <- df[, c("age", "name")] + expect_is(df2, "DataFrame") + expect_equal(columns(df2), c("age", "name")) + + df$age2 <- df$age + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age)), 2) + df$age2 <- df$age * 2 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 2)), 2) +}) + +test_that("select with column", { + df <- read.json(sqlContext, jsonPath) + df1 <- select(df, "name") + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) + + df2 <- select(df, df$age) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) + + df3 <- select(df, lit("x")) + expect_equal(columns(df3), c("x")) + expect_equal(count(df3), 3) + expect_equal(collect(select(df3, "x"))[[1, 1]], "x") + + df4 <- select(df, c("name", "age")) + expect_equal(columns(df4), c("name", "age")) + expect_equal(count(df4), 3) + + expect_error(select(df, c("name", "age"), "name"), + "To select multiple columns, use a character vector or list for col") +}) + +test_that("drop column", { + df <- select(read.json(sqlContext, jsonPath), "name", "age") + df1 <- drop(df, "name") + expect_equal(columns(df1), c("age")) + + df$age2 <- df$age + df1 <- drop(df, c("name", "age")) + expect_equal(columns(df1), c("age2")) + + df1 <- drop(df, df$age) + expect_equal(columns(df1), c("name", "age2")) + + df$age2 <- NULL + expect_equal(columns(df), c("name", "age")) + df$age3 <- NULL + expect_equal(columns(df), c("name", "age")) + + # Test to make sure base::drop is not masked + expect_equal(drop(1:3 %*% 2:4), 20) +}) + +test_that("subsetting", { + # read.json returns columns in random order + df <- select(read.json(sqlContext, jsonPath), "name", "age") + filtered <- df[df$age > 20, ] + expect_equal(count(filtered), 1) + expect_equal(columns(filtered), c("name", "age")) + expect_equal(collect(filtered)$name, "Andy") + + df2 <- df[df$age == 19, 1] + expect_is(df2, "DataFrame") + expect_equal(count(df2), 1) + expect_equal(columns(df2), c("name")) + expect_equal(collect(df2)$name, "Justin") + + df3 <- df[df$age > 20, 2] + expect_equal(count(df3), 1) + expect_equal(columns(df3), c("age")) + + df4 <- df[df$age %in% c(19, 30), 1:2] + expect_equal(count(df4), 2) + expect_equal(columns(df4), c("name", "age")) + + df5 <- df[df$age %in% c(19), c(1, 2)] + expect_equal(count(df5), 1) + expect_equal(columns(df5), c("name", "age")) + + df6 <- subset(df, df$age %in% c(30), c(1, 2)) + expect_equal(count(df6), 1) + expect_equal(columns(df6), c("name", "age")) + + df7 <- subset(df, select = "name") + expect_equal(count(df7), 3) + expect_equal(columns(df7), c("name")) + + # Test base::subset is working + expect_equal(nrow(subset(airquality, Temp > 80, select = c(Ozone, Temp))), 68) +}) + +test_that("selectExpr() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_equal(names(selected), "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_equal(count(selected2), 3) +}) + +test_that("expr() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) +}) + +test_that("column calculation", { + df <- read.json(sqlContext, jsonPath) + d <- collect(select(df, alias(df$age + 1, "age2"))) + expect_equal(names(d), c("age2")) + df2 <- select(df, lower(df$name), abs(df$age)) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) +}) + +test_that("test HiveContext", { + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + df2 <- sql(hiveCtx, "select * from json") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) + + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + df3 <- sql(hiveCtx, "select * from json2") + expect_is(df3, "DataFrame") + expect_equal(count(df3), 3) + unlink(jsonPath2) + + hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) + df4 <- sql(hiveCtx, "select * from hivetestbl") + expect_is(df4, "DataFrame") + expect_equal(count(df4), 3) + unlink(hivetestDataPath) + + parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + invisible(saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)) + df5 <- sql(hiveCtx, "select * from parquetest") + expect_is(df5, "DataFrame") + expect_equal(count(df5), 3) + unlink(parquetDataPath) +}) + +test_that("column operators", { + c <- column("a") + c2 <- (- c + 1 - 2) * 3 / 4.0 + c3 <- (c + c2 - c2) * c2 %% c2 + c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) + c5 <- c2 ^ c3 ^ c4 +}) + +test_that("column functions", { + c <- column("a") + c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) + c3 <- cosh(c) + count(c) + crc32(c) + hash(c) + exp(c) + c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) + c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c) + c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) + c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + c12 <- variance(c) + c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) + c14 <- cume_dist() + ntile(1) + corr(c, c1) + c15 <- dense_rank() + percent_rank() + rank() + row_number() + c16 <- is.nan(c) + isnan(c) + isNaN(c) + c17 <- cov(c, c1) + cov("c", "c1") + covar_samp(c, c1) + covar_samp("c", "c1") + c18 <- covar_pop(c, c1) + covar_pop("c", "c1") + + # Test if base::is.nan() is exposed + expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) + + # Test if base::rank() is exposed + expect_equal(class(rank())[[1]], "Column") + expect_equal(rank(1:3), as.numeric(c(1:3))) + + df <- read.json(sqlContext, jsonPath) + df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) + expect_equal(collect(df2)[[2, 1]], TRUE) + expect_equal(collect(df2)[[2, 2]], FALSE) + expect_equal(collect(df2)[[3, 1]], FALSE) + expect_equal(collect(df2)[[3, 2]], TRUE) + + df3 <- select(df, between(df$name, c("Apache", "Spark"))) + expect_equal(collect(df3)[[1, 1]], TRUE) + expect_equal(collect(df3)[[2, 1]], FALSE) + expect_equal(collect(df3)[[3, 1]], TRUE) + + df4 <- select(df, countDistinct(df$age, df$name)) + expect_equal(collect(df4)[[1, 1]], 2) + + expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) + expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) + expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) + + df5 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") + + # Test array_contains() and sort_array() + df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] + expect_equal(result, c(TRUE, FALSE)) + + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + result <- collect(select(df, sort_array(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + + # Test that stats::lag is working + expect_equal(length(lag(ldeaths, 12)), 72) + + # Test struct() + df <- createDataFrame(sqlContext, + list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + schema = c("a", "b", "c")) + result <- collect(select(df, struct("a", "c"))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) + expect_equal(result, expected) + + result <- collect(select(df, struct(df$a, df$b))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) + expect_equal(result, expected) + + # Test encode(), decode() + bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c)) + df <- createDataFrame(sqlContext, + list(list(markUtf8("大千世界"), "utf-8", bytes)), + schema = c("a", "b", "c")) + result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) + expect_equal(result[[1]][[1]], bytes) + expect_equal(result[[2]], markUtf8("大千世界")) + + # Test first(), last() + df <- read.json(sqlContext, jsonPath) + expect_equal(collect(select(df, first(df$age)))[[1]], NA) + expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) + expect_equal(collect(select(df, first("age")))[[1]], NA) + expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30) + expect_equal(collect(select(df, last(df$age)))[[1]], 19) + expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) + expect_equal(collect(select(df, last("age")))[[1]], 19) + expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19) +}) + +test_that("column binary mathfunctions", { + lines <- c("{\"a\":1, \"b\":5}", + "{\"a\":2, \"b\":6}", + "{\"a\":3, \"b\":7}", + "{\"a\":4, \"b\":8}") + jsonPathWithDup <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + writeLines(lines, jsonPathWithDup) + df <- read.json(sqlContext, jsonPathWithDup) + expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + ## nolint start + expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) + ## nolint end + expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16) + expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) + expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) + expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") + expect_equal(collect(select(df, rand(1)))[1, 1], 0.134, tolerance = 0.01) + expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") + expect_equal(collect(select(df, randn(1)))[1, 1], -1.03, tolerance = 0.01) +}) + +test_that("string operators", { + df <- read.json(sqlContext, jsonPath) + expect_equal(count(where(df, like(df$name, "A%"))), 1) + expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") + expect_equal(collect(select(df, concat(df$name, lit(":"), df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, concat_ws(":", df$name)))[[2, 1]], "Andy") + expect_equal(collect(select(df, concat_ws(":", df$name, df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, instr(df$name, "i")))[, 1], c(2, 0, 5)) + expect_equal(collect(select(df, format_number(df$age, 2)))[2, 1], "30.00") + expect_equal(collect(select(df, sha1(df$name)))[2, 1], + "ab5a000e88b5d9d0fa2575f5c6263eb93452405d") + expect_equal(collect(select(df, sha2(df$name, 256)))[2, 1], + "80f2aed3c618c423ddf05a2891229fba44942d907173152442cf6591441ed6dc") + expect_equal(collect(select(df, format_string("Name:%s", df$name)))[2, 1], "Name:Andy") + expect_equal(collect(select(df, format_string("%s, %d", df$name, df$age)))[2, 1], "Andy, 30") + expect_equal(collect(select(df, regexp_extract(df$name, "(n.y)", 1)))[2, 1], "ndy") + expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") + + l2 <- list(list(a = "aaads")) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) + expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") # nolint + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") # nolint + + l3 <- list(list(a = "a.b.c.d")) + df3 <- createDataFrame(sqlContext, l3) + expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") + expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") + expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") +}) + +test_that("date functions on a DataFrame", { + .originalTimeZone <- Sys.getenv("TZ") + Sys.setenv(TZ = "UTC") + l <- list(list(a = 1L, b = as.Date("2012-12-13")), + list(a = 2L, b = as.Date("2013-12-14")), + list(a = 3L, b = as.Date("2014-12-15"))) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) + expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) + expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) + expect_equal(collect(select(df, year(df$b)))[, 1], c(2012, 2013, 2014)) + expect_equal(collect(select(df, month(df$b)))[, 1], c(12, 12, 12)) + expect_equal(collect(select(df, last_day(df$b)))[, 1], + c(as.Date("2012-12-31"), as.Date("2013-12-31"), as.Date("2014-12-31"))) + expect_equal(collect(select(df, next_day(df$b, "MONDAY")))[, 1], + c(as.Date("2012-12-17"), as.Date("2013-12-16"), as.Date("2014-12-22"))) + expect_equal(collect(select(df, date_format(df$b, "y")))[, 1], c("2012", "2013", "2014")) + expect_equal(collect(select(df, add_months(df$b, 3)))[, 1], + c(as.Date("2013-03-13"), as.Date("2014-03-14"), as.Date("2015-03-15"))) + expect_equal(collect(select(df, date_add(df$b, 1)))[, 1], + c(as.Date("2012-12-14"), as.Date("2013-12-15"), as.Date("2014-12-16"))) + expect_equal(collect(select(df, date_sub(df$b, 1)))[, 1], + c(as.Date("2012-12-12"), as.Date("2013-12-13"), as.Date("2014-12-14"))) + + l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), + list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) + expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) + expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) + expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) + expect_more_than(collect(select(df2, unix_timestamp()))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + + l3 <- list(list(a = 1000), list(a = -1000)) + df3 <- createDataFrame(sqlContext, l3) + result31 <- collect(select(df3, from_unixtime(df3$a))) + expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), + c(1, 2)) + result32 <- collect(select(df3, from_unixtime(df3$a, "yyyy"))) + expect_equal(grep("\\d{4}", result32[, 1]), c(1, 2)) + Sys.setenv(TZ = .originalTimeZone) +}) + +test_that("greatest() and least() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, greatest(df$a, df$b)))[, 1], c(2, 4)) + expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) +}) + +test_that("time windowing (window()) with all inputs", { + df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df$window <- window(df$t, "5 seconds", "5 seconds", "0 seconds") + local <- collect(df)$v + # Not checking time windows because of possible time zone issues. Just checking that the function + # works + expect_equal(local, c(1)) +}) + +test_that("time windowing (window()) with slide duration", { + df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df$window <- window(df$t, "5 seconds", "2 seconds") + local <- collect(df)$v + # Not checking time windows because of possible time zone issues. Just checking that the function + # works + expect_equal(local, c(1, 1)) +}) + +test_that("time windowing (window()) with start time", { + df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df$window <- window(df$t, "5 seconds", startTime = "2 seconds") + local <- collect(df)$v + # Not checking time windows because of possible time zone issues. Just checking that the function + # works + expect_equal(local, c(1)) +}) + +test_that("time windowing (window()) with just window duration", { + df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df$window <- window(df$t, "5 seconds") + local <- collect(df)$v + # Not checking time windows because of possible time zone issues. Just checking that the function + # works + expect_equal(local, c(1)) +}) + +test_that("when(), otherwise() and ifelse() on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, 1)))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, 1), 0)))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) +}) + +test_that("when(), otherwise() and ifelse() with column on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, lit(1))))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, lit(1)), lit(0))))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, lit(0), lit(1))))[, 1], c(1, 0)) +}) + +test_that("group by, agg functions", { + df <- read.json(sqlContext, jsonPath) + df1 <- agg(df, name = "max", age = "sum") + expect_equal(1, count(df1)) + df1 <- agg(df, age2 = max(df$age)) + expect_equal(1, count(df1)) + expect_equal(columns(df1), c("age2")) + + gd <- groupBy(df, "name") + expect_is(gd, "GroupedData") + df2 <- count(gd) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) + + # Also test group_by, summarize, mean + gd1 <- group_by(df, "name") + expect_is(gd1, "GroupedData") + df_summarized <- summarize(gd, mean_age = mean(df$age)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) + + df3 <- agg(gd, age = "stddev") + expect_is(df3, "DataFrame") + df3_local <- collect(df3) + expect_true(is.nan(df3_local[df3_local$name == "Andy", ][1, 2])) + + df4 <- agg(gd, sumAge = sum(df$age)) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(columns(df4), c("name", "sumAge")) + + df5 <- sum(gd, "age") + expect_is(df5, "DataFrame") + expect_equal(3, count(df5)) + + expect_equal(3, count(mean(gd))) + expect_equal(3, count(max(gd))) + expect_equal(30, collect(max(gd))[2, 2]) + expect_equal(1, collect(count(gd))[1, 2]) + + mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}", + "{\"name\":\"ID1\", \"value\": \"10\"}", + "{\"name\":\"ID1\", \"value\": \"22\"}", + "{\"name\":\"ID2\", \"value\": \"-3\"}") + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + writeLines(mockLines2, jsonPath2) + gd2 <- groupBy(read.json(sqlContext, jsonPath2), "name") + df6 <- agg(gd2, value = "sum") + df6_local <- collect(df6) + expect_equal(42, df6_local[df6_local$name == "ID1", ][1, 2]) + expect_equal(-3, df6_local[df6_local$name == "ID2", ][1, 2]) + + df7 <- agg(gd2, value = "stddev") + df7_local <- collect(df7) + expect_true(abs(df7_local[df7_local$name == "ID1", ][1, 2] - 6.928203) < 1e-6) + expect_true(is.nan(df7_local[df7_local$name == "ID2", ][1, 2])) + + mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":1}") + jsonPath3 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + writeLines(mockLines3, jsonPath3) + df8 <- read.json(sqlContext, jsonPath3) + gd3 <- groupBy(df8, "name") + gd3_local <- collect(sum(gd3)) + expect_equal(60, gd3_local[gd3_local$name == "Andy", ][1, 2]) + expect_equal(20, gd3_local[gd3_local$name == "Justin", ][1, 2]) + + expect_true(abs(collect(agg(df, sd(df$age)))[1, 1] - 7.778175) < 1e-6) + gd3_local <- collect(agg(gd3, var(df8$age))) + expect_equal(162, gd3_local[gd3_local$name == "Justin", ][1, 2]) + + # Test stats::sd, stats::var are working + expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) + expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + + unlink(jsonPath2) + unlink(jsonPath3) +}) + +test_that("arrange() and orderBy() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + sorted <- arrange(df, df$age) + expect_equal(collect(sorted)[1, 2], "Michael") + + sorted2 <- arrange(df, "name", decreasing = FALSE) + expect_equal(collect(sorted2)[2, "age"], 19) + + sorted3 <- orderBy(df, asc(df$age)) + expect_true(is.na(first(sorted3)$age)) + expect_equal(collect(sorted3)[2, "age"], 19) + + sorted4 <- orderBy(df, desc(df$name)) + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3, "name"], "Andy") + + sorted5 <- arrange(df, "age", "name", decreasing = TRUE) + expect_equal(collect(sorted5)[1, 2], "Andy") + + sorted6 <- arrange(df, "age", "name", decreasing = c(T, F)) + expect_equal(collect(sorted6)[1, 2], "Andy") + + sorted7 <- arrange(df, "name", decreasing = FALSE) + expect_equal(collect(sorted7)[2, "age"], 19) +}) + +test_that("filter() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + filtered <- filter(df, "age > 20") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") + filtered2 <- where(df, df$name != "Michael") + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) + + # test suites for %in% + filtered3 <- filter(df, "age in (19)") + expect_equal(count(filtered3), 1) + filtered4 <- filter(df, "age in (19, 30)") + expect_equal(count(filtered4), 2) + filtered5 <- where(df, df$age %in% c(19)) + expect_equal(count(filtered5), 1) + filtered6 <- where(df, df$age %in% c(19, 30)) + expect_equal(count(filtered6), 2) + + # Test stats::filter is working + #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint +}) + +test_that("join() and merge() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"test\": \"yes\"}") + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + writeLines(mockLines2, jsonPath2) + df2 <- read.json(sqlContext, jsonPath2) + + joined <- join(df, df2) + expect_equal(names(joined), c("age", "name", "name", "test")) + expect_equal(count(joined), 12) + expect_equal(names(collect(joined)), c("age", "name", "name", "test")) + + joined2 <- join(df, df2, df$name == df2$name) + expect_equal(names(joined2), c("age", "name", "name", "test")) + expect_equal(count(joined2), 3) + + joined3 <- join(df, df2, df$name == df2$name, "rightouter") + expect_equal(names(joined3), c("age", "name", "name", "test")) + expect_equal(count(joined3), 4) + expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) + + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), + alias(df$age + 5, "newAge"), df$name, df2$test) + expect_equal(names(joined4), c("newAge", "name", "test")) + expect_equal(count(joined4), 4) + expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) + + joined5 <- join(df, df2, df$name == df2$name, "leftouter") + expect_equal(names(joined5), c("age", "name", "name", "test")) + expect_equal(count(joined5), 3) + expect_true(is.na(collect(orderBy(joined5, joined5$age))$age[1])) + + joined6 <- join(df, df2, df$name == df2$name, "inner") + expect_equal(names(joined6), c("age", "name", "name", "test")) + expect_equal(count(joined6), 3) + + joined7 <- join(df, df2, df$name == df2$name, "leftsemi") + expect_equal(names(joined7), c("age", "name")) + expect_equal(count(joined7), 3) + + joined8 <- join(df, df2, df$name == df2$name, "left_outer") + expect_equal(names(joined8), c("age", "name", "name", "test")) + expect_equal(count(joined8), 3) + expect_true(is.na(collect(orderBy(joined8, joined8$age))$age[1])) + + joined9 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined9), c("age", "name", "name", "test")) + expect_equal(count(joined9), 4) + expect_true(is.na(collect(orderBy(joined9, joined9$age))$age[2])) + + merged <- merge(df, df2, by.x = "name", by.y = "name", all.x = TRUE, all.y = TRUE) + expect_equal(count(merged), 4) + expect_equal(names(merged), c("age", "name_x", "name_y", "test")) + expect_equal(collect(orderBy(merged, merged$name_x))$age[3], 19) + + merged <- merge(df, df2, suffixes = c("-X", "-Y")) + expect_equal(count(merged), 3) + expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) + expect_equal(collect(orderBy(merged, merged$"name-X"))$age[1], 30) + + merged <- merge(df, df2, by = "name", suffixes = c("-X", "-Y"), sort = FALSE) + expect_equal(count(merged), 3) + expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) + expect_equal(collect(orderBy(merged, merged$"name-Y"))$"name-X"[3], "Michael") + + merged <- merge(df, df2, by = "name", all = T, sort = T) + expect_equal(count(merged), 4) + expect_equal(names(merged), c("age", "name_x", "name_y", "test")) + expect_equal(collect(orderBy(merged, merged$"name_y"))$"name_x"[1], "Andy") + + merged <- merge(df, df2, by = NULL) + expect_equal(count(merged), 12) + expect_equal(names(merged), c("age", "name", "name", "test")) + + mockLines3 <- c("{\"name\":\"Michael\", \"name_y\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"name_y\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"name_y\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") + jsonPath3 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + writeLines(mockLines3, jsonPath3) + df3 <- read.json(sqlContext, jsonPath3) + expect_error(merge(df, df3), + paste("The following column name: name_y occurs more than once in the 'DataFrame'.", + "Please use different suffixes for the intersected columns.", sep = "")) + + unlink(jsonPath2) + unlink(jsonPath3) +}) + +test_that("toJSON() returns an RDD of the correct values", { + df <- read.json(sqlContext, jsonPath) + testRDD <- toJSON(df) + expect_is(testRDD, "RDD") + expect_equal(getSerializedMode(testRDD), "string") + expect_equal(collect(testRDD)[[1]], mockLines[1]) +}) + +test_that("showDF()", { + df <- read.json(sqlContext, jsonPath) + s <- capture.output(showDF(df)) + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + "| 19| Justin|\n", + "+----+-------+\n", sep = "") + expect_output(s, expected) +}) + +test_that("isLocal()", { + df <- read.json(sqlContext, jsonPath) + expect_false(isLocal(df)) +}) + +test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + writeLines(lines, jsonPath2) + df2 <- read.df(sqlContext, jsonPath2, "json") + + unioned <- arrange(unionAll(df, df2), df$age) + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") + + unioned2 <- arrange(rbind(unioned, df, df2), df$age) + expect_is(unioned2, "DataFrame") + expect_equal(count(unioned2), 12) + expect_equal(first(unioned2)$name, "Michael") + + excepted <- arrange(except(df, df2), desc(df$age)) + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") + + intersected <- arrange(intersect(df, df2), df$age) + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") + + # Test base::rbind is working + expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16) + + # Test base::intersect is working + expect_equal(length(intersect(1:20, 3:23)), 18) + + unlink(jsonPath2) +}) + +test_that("withColumn() and withColumnRenamed()", { + df <- read.json(sqlContext, jsonPath) + newDF <- withColumn(df, "newAge", df$age + 2) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + + # Replace existing column + newDF <- withColumn(df, "age", df$age + 2) + expect_equal(length(columns(newDF)), 2) + expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32) + + newDF2 <- withColumnRenamed(df, "age", "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") +}) + +test_that("mutate(), transform(), rename() and names()", { + df <- read.json(sqlContext, jsonPath) + newDF <- mutate(df, newAge = df$age + 2) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) + + newDF2 <- rename(df, newerAge = df$age) + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") + + names(newDF2) <- c("newerName", "evenNewerAge") + expect_equal(length(names(newDF2)), 2) + expect_equal(names(newDF2)[1], "newerName") + + transformedDF <- transform(df, newAge = -df$age, newAge2 = df$age / 2) + expect_equal(length(columns(transformedDF)), 4) + expect_equal(columns(transformedDF)[3], "newAge") + expect_equal(columns(transformedDF)[4], "newAge2") + expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30) + + # test if base::transform on local data frames works + # ensure the proper signature is used - otherwise this will fail to run + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + expect_equal(nrow(result), 153) + expect_equal(ncol(result), 2) + detach(airquality) +}) + +test_that("read/write Parquet files", { + df <- read.df(sqlContext, jsonPath, "json") + # Test write.df and read.df + write.df(df, parquetPath, "parquet", mode = "overwrite") + df2 <- read.df(sqlContext, parquetPath, "parquet") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) + + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(sqlContext, c(parquetPath2, parquetPath3)) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df) * 2) + parquetDF2 <- suppressWarnings(parquetFile(sqlContext, parquetPath2, parquetPath3)) + expect_is(parquetDF2, "DataFrame") + expect_equal(count(parquetDF2), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + unlink(parquetPath3) + unlink(parquetPath4) +}) + +test_that("read/write text files", { + # Test write.df and read.df + df <- read.df(sqlContext, jsonPath, "text") + expect_is(df, "DataFrame") + expect_equal(colnames(df), c("value")) + expect_equal(count(df), 3) + textPath <- tempfile(pattern = "textPath", fileext = ".txt") + write.df(df, textPath, "text", mode = "overwrite") + + # Test write.text and read.text + textPath2 <- tempfile(pattern = "textPath2", fileext = ".txt") + write.text(df, textPath2) + df2 <- read.text(sqlContext, c(textPath, textPath2)) + expect_is(df2, "DataFrame") + expect_equal(colnames(df2), c("value")) + expect_equal(count(df2), count(df) * 2) + + unlink(textPath) + unlink(textPath2) +}) + +test_that("describe() and summarize() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + stats <- describe(df, "age") + expect_equal(collect(stats)[1, "summary"], "count") + expect_equal(collect(stats)[2, "age"], "24.5") + expect_equal(collect(stats)[3, "age"], "7.7781745930520225") + stats <- describe(df) + expect_equal(collect(stats)[4, "name"], "Andy") + expect_equal(collect(stats)[5, "age"], "30") + + stats2 <- summary(df) + expect_equal(collect(stats2)[4, "name"], "Andy") + expect_equal(collect(stats2)[5, "age"], "30") + + # Test base::summary is working + expect_equal(length(summary(attenu, digits = 4)), 35) +}) + +test_that("dropna() and na.omit() on a DataFrame", { + df <- read.json(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name), ] + actual <- collect(dropna(df, cols = "name")) + expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = "name")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age), ] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) + actual <- collect(na.omit(df, cols = "age")) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height), ] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name), ] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name), ] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name), ] + actual <- collect(dropna(df, "all")) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "all")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name), ] + actual <- collect(dropna(df, "any")) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "any")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height), ] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height), ] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) + + # drop with threshold + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2, ] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3, ] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) + + # Test stats::na.omit is working + expect_equal(nrow(na.omit(data.frame(x = c(0, 10, NA)))), 2) +}) + +test_that("fillna() on a DataFrame", { + df <- read.json(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_identical(expected, actual) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_identical(expected, actual) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_identical(expected, actual) +}) + +test_that("crosstab() on a DataFrame", { + rdd <- lapply(parallelize(sc, 0:3), function(x) { + list(paste0("a", x %% 3), paste0("b", x %% 2)) + }) + df <- toDF(rdd, list("a", "b")) + ct <- crosstab(df, "a", "b") + ordered <- ct[order(ct$a_b), ] + row.names(ordered) <- NULL + expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), + stringsAsFactors = FALSE, row.names = NULL) + expect_identical(expected, ordered) +}) + +test_that("cov() and corr() on a DataFrame", { + l <- lapply(c(0:9), function(x) { list(x, x * 2.0) }) + df <- createDataFrame(sqlContext, l, c("singles", "doubles")) + result <- cov(df, "singles", "doubles") + expect_true(abs(result - 55.0 / 3) < 1e-12) + + result <- corr(df, "singles", "doubles") + expect_true(abs(result - 1.0) < 1e-12) + result <- corr(df, "singles", "doubles", "pearson") + expect_true(abs(result - 1.0) < 1e-12) + + # Test stats::cov is working + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) # nolint +}) + +test_that("freqItems() on a DataFrame", { + input <- 1:1000 + rdf <- data.frame(numbers = input, letters = as.character(input), + negDoubles = input * -1.0, stringsAsFactors = F) + rdf[ input %% 3 == 0, ] <- c(1, "1", -1) + df <- createDataFrame(sqlContext, rdf) + multiColResults <- freqItems(df, c("numbers", "letters"), support = 0.1) + expect_true(1 %in% multiColResults$numbers[[1]]) + expect_true("1" %in% multiColResults$letters[[1]]) + singleColResult <- freqItems(df, "negDoubles", support = 0.1) + expect_true(-1 %in% head(singleColResult$negDoubles)[[1]]) + + l <- lapply(c(0:99), function(i) { + if (i %% 2 == 0) { list(1L, -1.0) } + else { list(i, i * -1.0) }}) + df <- createDataFrame(sqlContext, l, c("a", "b")) + result <- freqItems(df, c("a", "b"), 0.4) + expect_identical(result[[1]], list(list(1L, 99L))) + expect_identical(result[[2]], list(list(-1, -99))) +}) + +test_that("sampleBy() on a DataFrame", { + l <- lapply(c(0:99), function(i) { as.character(i %% 3) }) + df <- createDataFrame(sqlContext, l, "key") + fractions <- list("0" = 0.1, "1" = 0.2) + sample <- sampleBy(df, "key", fractions, 0) + result <- collect(orderBy(count(groupBy(sample, "key")), "key")) + expect_identical(as.list(result[1, ]), list(key = "0", count = 3)) + expect_identical(as.list(result[2, ]), list(key = "1", count = 7)) +}) + +test_that("approxQuantile() on a DataFrame", { + l <- lapply(c(0:99), function(i) { i }) + df <- createDataFrame(sqlContext, l, "key") + quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) + expect_equal(quantiles[[1]], 50) + expect_equal(quantiles[[2]], 80) +}) + +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table or View not found", retError), TRUE) + expect_equal(grepl("blah", retError), TRUE) +}) + +irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) + +test_that("Method as.data.frame as a synonym for collect()", { + expect_equal(as.data.frame(irisDF), collect(irisDF)) + irisDF2 <- irisDF[irisDF$Species == "setosa", ] + expect_equal(as.data.frame(irisDF2), collect(irisDF2)) +}) + +test_that("attach() on a DataFrame", { + df <- read.json(sqlContext, jsonPath) + expect_error(age) + attach(df) + expect_is(age, "DataFrame") + expected_age <- data.frame(age = c(NA, 30, 19)) + expect_equal(head(age), expected_age) + stat <- summary(age) + expect_equal(collect(stat)[5, "age"], "30") + age <- age$age + 1 + expect_is(age, "Column") + rm(age) + stat2 <- summary(age) + expect_equal(collect(stat2)[5, "age"], "30") + detach("df") + stat3 <- summary(df[, "age"]) + expect_equal(collect(stat3)[5, "age"], "30") + expect_error(age) +}) + +test_that("with() on a DataFrame", { + df <- suppressWarnings(createDataFrame(sqlContext, iris)) + expect_error(Sepal_Length) + sum1 <- with(df, list(summary(Sepal_Length), summary(Sepal_Width))) + expect_equal(collect(sum1[[1]])[1, "Sepal_Length"], "150") + sum2 <- with(df, distinct(Sepal_Length)) + expect_equal(nrow(sum2), 35) +}) + +test_that("Method coltypes() to get and set R's data types of a DataFrame", { + expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) + + data <- data.frame(c1 = c(1, 2, 3), + c2 = c(T, F, T), + c3 = c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00")) + + schema <- structType(structField("c1", "byte"), + structField("c3", "boolean"), + structField("c4", "timestamp")) + + # Test primitive types + DF <- createDataFrame(sqlContext, data, schema) + expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) + + # Test complex types + x <- createDataFrame(sqlContext, list(list(as.environment( + list("a" = "b", "c" = "d", "e" = "f"))))) + expect_equal(coltypes(x), "map") + + df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + + df1 <- select(df, cast(df$age, "integer")) + coltypes(df) <- c("character", "integer") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) + value <- collect(df[, 2])[[3, 1]] + expect_equal(value, collect(df1)[[3, 1]]) + expect_equal(value, 22) + + coltypes(df) <- c(NA, "numeric") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) + + expect_error(coltypes(df) <- c("character"), + "Length of type vector should match the number of columns for DataFrame") + expect_error(coltypes(df) <- c("environment", "list"), + "Only atomic type is supported for column types") +}) + +test_that("Method str()", { + # Structure of Iris + iris2 <- iris + colnames(iris2) <- c("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Species") + iris2$col <- TRUE + irisDF2 <- createDataFrame(sqlContext, iris2) + + out <- capture.output(str(irisDF2)) + expect_equal(length(out), 7) + expect_equal(out[1], "'DataFrame': 6 variables:") + expect_equal(out[2], " $ Sepal_Length: num 5.1 4.9 4.7 4.6 5 5.4") + expect_equal(out[3], " $ Sepal_Width : num 3.5 3 3.2 3.1 3.6 3.9") + expect_equal(out[4], " $ Petal_Length: num 1.4 1.4 1.3 1.5 1.4 1.7") + expect_equal(out[5], " $ Petal_Width : num 0.2 0.2 0.2 0.2 0.2 0.4") + expect_equal(out[6], paste0(" $ Species : chr \"setosa\" \"setosa\" \"", + "setosa\" \"setosa\" \"setosa\" \"setosa\"")) + expect_equal(out[7], " $ col : logi TRUE TRUE TRUE TRUE TRUE TRUE") + + # A random dataset with many columns. This test is to check str limits + # the number of columns. Therefore, it will suffice to check for the + # number of returned rows + x <- runif(200, 1, 10) + df <- data.frame(t(as.matrix(data.frame(x, x, x, x, x, x, x, x, x)))) + DF <- createDataFrame(sqlContext, df) + out <- capture.output(str(DF)) + expect_equal(length(out), 103) + + # Test utils:::str + expect_equal(capture.output(utils:::str(iris)), capture.output(str(iris))) +}) + +unlink(parquetPath) +unlink(jsonPath) +unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/testthat/test_take.R similarity index 100% rename from R/pkg/inst/tests/test_take.R rename to R/pkg/inst/tests/testthat/test_take.R diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R similarity index 84% rename from R/pkg/inst/tests/test_textFile.R rename to R/pkg/inst/tests/testthat/test_textFile.R index a9cf83dbdbdb1..e64ef1bb31a3a 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -23,7 +23,7 @@ sc <- sparkR.init() mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) @@ -35,7 +35,7 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) @@ -45,7 +45,7 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) @@ -63,7 +63,7 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) # RDD @@ -77,8 +77,8 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) rdd <- textFile(sc, fileName1, 1L) @@ -91,7 +91,7 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) saveAsTextFile(rdd, fileName) @@ -102,8 +102,8 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) rdd <- textFile(sc, fileName1) @@ -127,8 +127,8 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { - fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") - fileName2 <- tempfile(pattern="spark-test", fileext=".tmp") + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") + fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) writeLines("Spark is awesome.", fileName2) @@ -140,7 +140,7 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R similarity index 94% rename from R/pkg/inst/tests/test_utils.R rename to R/pkg/inst/tests/testthat/test_utils.R index 12df4cf4f65b7..4218138f641d1 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -41,7 +41,7 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists test_that("serializeToBytes on RDD", { # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") - fileName <- tempfile(pattern="spark-test", fileext=".tmp") + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) @@ -86,8 +86,8 @@ test_that("cleanClosure on R functions", { f <- function(x) { defUse <- base::as.integer(x) + 1 # Test for access operators `::`. lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply. - l$field[1,1] <- 3 # Test for access operators `$`. - res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + l$field[1, 1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1, ] # Test for def-use chain of "defUse", and "" symbol. f(res) # Test for recursive calls. } newF <- cleanClosure(f) @@ -95,7 +95,9 @@ test_that("cleanClosure on R functions", { # TODO(shivaram): length(ls(env)) is 4 here for some reason and `lapply` is included in `env`. # Disabling this test till we debug this. # + # nolint start # expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + # nolint end expect_true("g" %in% ls(env)) expect_true("l" %in% ls(env)) expect_true("f" %in% ls(env)) @@ -130,7 +132,7 @@ test_that("cleanClosure on R functions", { expect_equal(actual, expected) # Test for broadcast variables. - a <- matrix(nrow=10, ncol=10, data=rnorm(100)) + a <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) aBroadcast <- broadcast(sc, a) normMultiply <- function(x) { norm(aBroadcast$value) * x } newnormMultiply <- SparkR:::cleanClosure(normMultiply) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3584b418a71a9..f55beac6c8c07 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -18,10 +18,11 @@ # Worker daemon rLibDir <- Sys.getenv("SPARKR_RLIBDIR") -script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") +dirs <- strsplit(rLibDir, ",")[[1]] +script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") # preload SparkR package, speedup worker -.libPaths(c(rLibDir, .libPaths())) +.libPaths(c(dirs, .libPaths())) suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 0c3b0d1f4be20..b6784dbae3203 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -35,10 +35,11 @@ bootTime <- currentTimeSecs() bootElap <- elapsedSecs() rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +dirs <- strsplit(rLibDir, ",")[[1]] # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require # SparkR namespace -.libPaths(c(rLibDir, .libPaths())) +.libPaths(c(dirs, .libPaths())) suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) @@ -54,7 +55,7 @@ serializer <- SparkR:::readString(inputCon) # Include packages as required packageNames <- unserialize(SparkR:::readRaw(inputCon)) for (pkg in packageNames) { - suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE)) + suppressPackageStartupMessages(library(as.character(pkg), character.only = TRUE)) } # read function dependencies diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 4f8a1ed2d83ef..1d04656ac2594 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -18,4 +18,7 @@ library(testthat) library(SparkR) +# Turn all warnings into errors +options("warn" = 2) + test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd06..9dcf0ace7d97e 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/README.md b/README.md index c0d6a946035a9..d5804d1a20b43 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). +For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) +and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). ## Interactive Scala Shell diff --git a/assembly/pom.xml b/assembly/pom.xml index 4b60ee00ffbe5..22cbac06cad61 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,22 +20,21 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-assembly_2.10 + spark-assembly_2.11 Spark Project Assembly http://spark.apache.org/ pom assembly - scala-${scala.binary.version} - spark-assembly-${project.version}-hadoop${hadoop.version}.jar - ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} + none + package @@ -44,11 +43,6 @@ spark-core_${scala.binary.version} ${project.version} - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - org.apache.spark spark-mllib_${scala.binary.version} @@ -74,6 +68,17 @@ spark-repl_${scala.binary.version} ${project.version} + + + + com.google.guava + guava + ${hadoop.deps.scope} + @@ -92,75 +97,26 @@ true - - - org.apache.maven.plugins - maven-antrun-plugin - - - package - - run - - - - - - - - - - - - - + org.apache.maven.plugins - maven-shade-plugin - - false - ${spark.jar} - - - *:* - - - - - *:* - - org/datanucleus/** - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - META-INF/services/org.apache.hadoop.fs.FileSystem - - - reference.conf - - - log4j.properties - - - - - - - + maven-antrun-plugin + + + package + + run + + + + + + + + + + + diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 711156337b7c3..009d4b92f406c 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -32,7 +32,7 @@ ${project.parent.basedir}/core/src/main/resources/org/apache/spark/ui/static/ - /ui-resources/org/apache/spark/ui/static + ui-resources/org/apache/spark/ui/static **/* @@ -41,7 +41,7 @@ ${project.parent.basedir}/sbin/ - /sbin + sbin **/* @@ -50,7 +50,7 @@ ${project.parent.basedir}/bin/ - /bin + bin **/* @@ -59,7 +59,7 @@ ${project.parent.basedir}/assembly/target/${spark.jar.dir} - / + ${spark.jar.basename} diff --git a/bagel/pom.xml b/bagel/pom.xml deleted file mode 100644 index 672e9469aec92..0000000000000 --- a/bagel/pom.xml +++ /dev/null @@ -1,64 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../pom.xml - - - org.apache.spark - spark-bagel_2.10 - - bagel - - jar - Spark Project Bagel - http://spark.apache.org/ - - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala deleted file mode 100644 index 8399033ac61ec..0000000000000 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ /dev/null @@ -1,318 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -object Bagel extends Logging { - val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK - - /** - * Runs a Bagel program. - * @param sc org.apache.spark.SparkContext to use for the program. - * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the - * Key will be the vertex id. - * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often - * this will be an empty array, i.e. sc.parallelize(Array[K, Message]()). - * @param combiner [[org.apache.spark.bagel.Combiner]] combines multiple individual messages to a - * given vertex into one message before sending (which often involves network - * I/O). - * @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices - * after each superstep and provides the result to each vertex in the next - * superstep. - * @param partitioner org.apache.spark.Partitioner partitions values by key - * @param numPartitions number of partitions across which to split the graph. - * Default is the default parallelism of the SparkContext - * @param storageLevel org.apache.spark.storage.StorageLevel to use for caching of - * intermediate RDDs in each superstep. Defaults to caching in memory. - * @param compute function that takes a Vertex, optional set of (possibly combined) messages to - * the Vertex, optional Aggregator and the current superstep, - * and returns a set of (Vertex, outgoing Messages) pairs - * @tparam K key - * @tparam V vertex type - * @tparam M message type - * @tparam C combiner - * @tparam A aggregator - * @return an RDD of (K, V) pairs representing the graph after completion of the program - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, - C: Manifest, A: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - aggregator: Option[Aggregator[V, A]], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL - )( - compute: (V, Option[C], Option[A], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism - - var superstep = 0 - var verts = vertices - var msgs = messages - var noActivity = false - var lastRDD: RDD[(K, (V, Array[M]))] = null - do { - logInfo("Starting superstep " + superstep + ".") - val startTime = System.currentTimeMillis - - val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKeyWithClassTag( - combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) - val grouped = combinedMsgs.groupWith(verts) - val superstep_ = superstep // Create a read-only copy of superstep for capture in closure - val (processed, numMsgs, numActiveVerts) = - comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) - if (lastRDD != null) { - lastRDD.unpersist(false) - } - lastRDD = processed - - val timeTaken = System.currentTimeMillis - startTime - logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) - - verts = processed.mapValues { case (vert, msgs) => vert } - msgs = processed.flatMap { - case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) - } - superstep += 1 - - noActivity = numMsgs == 0 && numActiveVerts == 0 - } while (!noActivity) - - verts - } - - /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default - * storage level */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M])): RDD[(K, V)] = run(sc, vertices, messages, - combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default - * org.apache.spark.HashPartitioner and default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, - DEFAULT_STORAGE_LEVEL)(compute) - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the - * default org.apache.spark.HashPartitioner - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], - * default org.apache.spark.HashPartitioner, - * [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], - * the default org.apache.spark.HashPartitioner - * and [[org.apache.spark.bagel.DefaultCombiner]] - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, Array[M], Nothing]( - sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, Array[M]](compute)) - } - - /** - * Aggregates the given vertices using the given aggregator, if it - * is specified. - */ - private def agg[K, V <: Vertex, A: Manifest]( - verts: RDD[(K, V)], - aggregator: Option[Aggregator[V, A]] - ): Option[A] = aggregator match { - case Some(a) => - Some(verts.map { - case (id, vert) => a.createAggregator(vert) - }.reduce(a.mergeAggregators(_, _))) - case None => None - } - - /** - * Processes the given vertex-message RDD using the compute - * function. Returns the processed RDD, the number of messages - * created, and the number of active vertices. - */ - private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( - sc: SparkContext, - grouped: RDD[(K, (Iterable[C], Iterable[V]))], - compute: (V, Option[C]) => (V, Array[M]), - storageLevel: StorageLevel - ): (RDD[(K, (V, Array[M]))], Int, Int) = { - var numMsgs = sc.accumulator(0) - var numActiveVerts = sc.accumulator(0) - val processed = grouped.mapValues(x => (x._1.iterator, x._2.iterator)) - .flatMapValues { - case (_, vs) if !vs.hasNext => None - case (c, vs) => { - val (newVert, newMsgs) = - compute(vs.next, - c.hasNext match { - case true => Some(c.next) - case false => None - } - ) - - numMsgs += newMsgs.size - if (newVert.active) { - numActiveVerts += 1 - } - - Some((newVert, newMsgs)) - } - }.persist(storageLevel) - - // Force evaluation of processed RDD for accurate performance measurements - processed.foreach(x => {}) - - (processed, numMsgs.value, numActiveVerts.value) - } - - /** - * Converts a compute function that doesn't take an aggregator to - * one that does, so it can be passed to Bagel.run. - */ - private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C]( - compute: (V, Option[C], Int) => (V, Array[M]) - ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { - (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => - compute(vert, msgs, superstep) - } -} - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -/** Default combiner that simply appends messages together (i.e. performs no aggregation) */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { - def createCombiner(msg: M): Array[M] = - Array(msg) - def mergeMsg(combiner: Array[M], msg: M): Array[M] = - combiner :+ msg - def mergeCombiners(a: Array[M], b: Array[M]): Array[M] = - a ++ b -} - -/** - * Represents a Bagel vertex. - * - * Subclasses may store state along with each vertex and must - * inherit from java.io.Serializable or scala.Serializable. - */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Vertex { - def active: Boolean -} - -/** - * Represents a Bagel message to a target vertex. - * - * Subclasses may contain a payload to deliver to the target vertex - * and must inherit from java.io.Serializable or scala.Serializable. - */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Message[K] { - def targetId: K -} diff --git a/bagel/src/main/scala/org/apache/spark/bagel/package-info.java b/bagel/src/main/scala/org/apache/spark/bagel/package-info.java deleted file mode 100644 index 81f26f276549f..0000000000000 --- a/bagel/src/main/scala/org/apache/spark/bagel/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Bagel: An implementation of Pregel in Spark. THIS IS DEPRECATED - use Spark's GraphX library. - */ -package org.apache.spark.bagel; \ No newline at end of file diff --git a/bagel/src/main/scala/org/apache/spark/bagel/package.scala b/bagel/src/main/scala/org/apache/spark/bagel/package.scala deleted file mode 100644 index 2fb1934579781..0000000000000 --- a/bagel/src/main/scala/org/apache/spark/bagel/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -/** - * Bagel: An implementation of Pregel in Spark. THIS IS DEPRECATED - use Spark's GraphX library. - */ -package object bagel diff --git a/bin/beeline.cmd b/bin/beeline.cmd index 8293f311029dd..02464bd088792 100644 --- a/bin/beeline.cmd +++ b/bin/beeline.cmd @@ -17,5 +17,4 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.hive.beeline.BeeLine %* +cmd /V /E /C "%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %* diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index 36d932c453b6f..0977025c2036e 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -27,7 +27,7 @@ if [%SPARK_ENV_LOADED%] == [] ( if not [%SPARK_CONF_DIR%] == [] ( set user_conf_dir=%SPARK_CONF_DIR% ) else ( - set user_conf_dir=%~dp0..\..\conf + set user_conf_dir=..\conf ) call :LoadSparkEnv @@ -35,8 +35,8 @@ if [%SPARK_ENV_LOADED%] == [] ( rem Setting SPARK_SCALA_VERSION if not already set. -set ASSEMBLY_DIR2=%SPARK_HOME%/assembly/target/scala-2.11 -set ASSEMBLY_DIR1=%SPARK_HOME%/assembly/target/scala-2.10 +set ASSEMBLY_DIR2="%SPARK_HOME%\assembly\target\scala-2.11" +set ASSEMBLY_DIR1="%SPARK_HOME%\assembly\target\scala-2.10" if [%SPARK_SCALA_VERSION%] == [] ( diff --git a/bin/pyspark b/bin/pyspark index 5eaa17d3c2016..a25749964e53e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -49,7 +49,11 @@ if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then # If IPython options are specified, assume user wants to run IPython # (for backwards-compatibility) PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS" - PYSPARK_DRIVER_PYTHON="ipython" + if [ -x "$(command -v jupyter)" ]; then + PYSPARK_DRIVER_PYTHON="jupyter" + else + PYSPARK_DRIVER_PYTHON="ipython" + fi elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}" fi @@ -67,7 +71,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.2-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark.cmd b/bin/pyspark.cmd index 7c26fbbac28b8..72d046a4ba2cf 100644 --- a/bin/pyspark.cmd +++ b/bin/pyspark.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running PySpark. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0pyspark2.cmd %* +cmd /V /E /C "%~dp0pyspark2.cmd" %* diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index a97d884f0bf39..cb788497ffc79 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -20,7 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -call %SPARK_HOME%\bin\load-spark-env.cmd +call "%SPARK_HOME%\bin\load-spark-env.cmd" set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. @@ -30,9 +30,9 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9.2-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py -call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %* +call "%SPARK_HOME%\bin\spark-submit2.cmd" pyspark-shell-main --name "PySparkShell" %* diff --git a/bin/run-example b/bin/run-example index e1b0d5789bed6..dd0e3c4120260 100755 --- a/bin/run-example +++ b/bin/run-example @@ -21,56 +21,5 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -EXAMPLES_DIR="${SPARK_HOME}"/examples - -. "${SPARK_HOME}"/bin/load-spark-env.sh - -if [ -n "$1" ]; then - EXAMPLE_CLASS="$1" - shift -else - echo "Usage: ./bin/run-example [example-args]" 1>&2 - echo " - set MASTER=XX to use a specific master" 1>&2 - echo " - can use abbreviated example class name relative to com.apache.spark.examples" 1>&2 - echo " (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL)" 1>&2 - exit 1 -fi - -if [ -f "${SPARK_HOME}/RELEASE" ]; then - JAR_PATH="${SPARK_HOME}/lib" -else - JAR_PATH="${EXAMPLES_DIR}/target/scala-${SPARK_SCALA_VERSION}" -fi - -JAR_COUNT=0 - -for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do - if [[ ! -e "$f" ]]; then - echo "Failed to find Spark examples assembly in ${SPARK_HOME}/lib or ${SPARK_HOME}/examples/target" 1>&2 - echo "You need to build Spark before running this program" 1>&2 - exit 1 - fi - SPARK_EXAMPLES_JAR="$f" - JAR_COUNT=$((JAR_COUNT+1)) -done - -if [ "$JAR_COUNT" -gt "1" ]; then - echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2 - ls "${JAR_PATH}"/spark-examples-*hadoop*.jar 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 -fi - -export SPARK_EXAMPLES_JAR - -EXAMPLE_MASTER=${MASTER:-"local[*]"} - -if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then - EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS" -fi - -exec "${SPARK_HOME}"/bin/spark-submit \ - --master $EXAMPLE_MASTER \ - --class $EXAMPLE_CLASS \ - "$SPARK_EXAMPLES_JAR" \ - "$@" +export _SPARK_CMD_USAGE="Usage: ./bin/run-example [options] example-class [example args]" +exec "${SPARK_HOME}"/bin/spark-submit run-example "$@" diff --git a/bin/run-example.cmd b/bin/run-example.cmd index 5b2d048d6ed50..f9b786e92b823 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -17,7 +17,6 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem This is the entry point for running a Spark example. To avoid polluting -rem the environment, it just launches a new cmd to do the real work. - -cmd /V /E /C %~dp0run-example2.cmd %* +set SPARK_HOME=%~dp0.. +set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] +cmd /V /E /C "%~dp0spark-submit.cmd" run-example %* diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd deleted file mode 100644 index c3e0221fb62e3..0000000000000 --- a/bin/run-example2.cmd +++ /dev/null @@ -1,88 +0,0 @@ -@echo off - -rem -rem Licensed to the Apache Software Foundation (ASF) under one or more -rem contributor license agreements. See the NOTICE file distributed with -rem this work for additional information regarding copyright ownership. -rem The ASF licenses this file to You under the Apache License, Version 2.0 -rem (the "License"); you may not use this file except in compliance with -rem the License. You may obtain a copy of the License at -rem -rem http://www.apache.org/licenses/LICENSE-2.0 -rem -rem Unless required by applicable law or agreed to in writing, software -rem distributed under the License is distributed on an "AS IS" BASIS, -rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -rem See the License for the specific language governing permissions and -rem limitations under the License. -rem - -set SCALA_VERSION=2.10 - -rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Export this as SPARK_HOME -set SPARK_HOME=%FWDIR% - -call %SPARK_HOME%\bin\load-spark-env.cmd - -rem Test that an argument was given -if not "x%1"=="x" goto arg_given - echo Usage: run-example ^ [example-args] - echo - set MASTER=XX to use a specific master - echo - can use abbreviated example class name relative to com.apache.spark.examples - echo (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL) - goto exit -:arg_given - -set EXAMPLES_DIR=%FWDIR%examples - -rem Figure out the JAR file that our examples were packaged into. -set SPARK_EXAMPLES_JAR= -if exist "%FWDIR%RELEASE" ( - for %%d in ("%FWDIR%lib\spark-examples*.jar") do ( - set SPARK_EXAMPLES_JAR=%%d - ) -) else ( - for %%d in ("%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\spark-examples*.jar") do ( - set SPARK_EXAMPLES_JAR=%%d - ) -) -if "x%SPARK_EXAMPLES_JAR%"=="x" ( - echo Failed to find Spark examples assembly JAR. - echo You need to build Spark before running this program. - goto exit -) - -rem Set master from MASTER environment variable if given -if "x%MASTER%"=="x" ( - set EXAMPLE_MASTER=local[*] -) else ( - set EXAMPLE_MASTER=%MASTER% -) - -rem If the EXAMPLE_CLASS does not start with org.apache.spark.examples, add that -set EXAMPLE_CLASS=%1 -set PREFIX=%EXAMPLE_CLASS:~0,25% -if not %PREFIX%==org.apache.spark.examples ( - set EXAMPLE_CLASS=org.apache.spark.examples.%EXAMPLE_CLASS% -) - -rem Get the tail of the argument list, to skip the first one. This is surprisingly -rem complicated on Windows. -set "ARGS=" -:top -shift -if "%~1" neq "" ( - set ARGS=%ARGS% "%~1" - goto :top -) -if defined ARGS set ARGS=%ARGS:~1% - -call "%FWDIR%bin\spark-submit.cmd" ^ - --master %EXAMPLE_MASTER% ^ - --class %EXAMPLE_CLASS% ^ - "%SPARK_EXAMPLES_JAR%" %ARGS% - -:exit diff --git a/bin/spark-class b/bin/spark-class index 87d06693af4fe..b2a36b9846780 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -35,41 +35,31 @@ else fi fi -# Find assembly jar -SPARK_ASSEMBLY_JAR= +# Find Spark jars. if [ -f "${SPARK_HOME}/RELEASE" ]; then - ASSEMBLY_DIR="${SPARK_HOME}/lib" + SPARK_JARS_DIR="${SPARK_HOME}/jars" else - ASSEMBLY_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION" + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" fi -GREP_OPTIONS= -num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" -if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" -a "$SPARK_PREPEND_CLASSES" != "1" ]; then - echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 - echo "You need to build Spark before running this program." 1>&2 +if [ ! -d "$SPARK_JARS_DIR" ] && [ -z "$SPARK_TESTING$SPARK_SQL_TESTING" ]; then + echo "Failed to find Spark jars directory ($SPARK_JARS_DIR)." 1>&2 + echo "You need to build Spark with the target \"package\" before running this program." 1>&2 exit 1 -fi -if [ -d "$ASSEMBLY_DIR" ]; then - ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" - if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 - echo "$ASSEMBLY_JARS" 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 - fi +else + LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*" fi -SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" - -LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" - # Add the launcher build dir to the classpath if requested. if [ -n "$SPARK_PREPEND_CLASSES" ]; then LAUNCH_CLASSPATH="${SPARK_HOME}/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" fi -export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" +# For tests +if [[ -n "$SPARK_TESTING" ]]; then + unset YARN_CONF_DIR + unset HADOOP_CONF_DIR +fi # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating diff --git a/bin/spark-class.cmd b/bin/spark-class.cmd index 19850db9e1e5d..3bf3d20cb57b5 100644 --- a/bin/spark-class.cmd +++ b/bin/spark-class.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running a Spark class. To avoid polluting rem the environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0spark-class2.cmd %* +cmd /V /E /C "%~dp0spark-class2.cmd" %* diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index db09fa27e51a6..db680218dc964 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -20,7 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -call %SPARK_HOME%\bin\load-spark-env.cmd +call "%SPARK_HOME%\bin\load-spark-env.cmd" rem Test that an argument was given if "x%1"=="x" ( @@ -28,33 +28,26 @@ if "x%1"=="x" ( exit /b 1 ) -rem Find assembly jar -set SPARK_ASSEMBLY_JAR=0 - +rem Find Spark jars. if exist "%SPARK_HOME%\RELEASE" ( - set ASSEMBLY_DIR=%SPARK_HOME%\lib + set SPARK_JARS_DIR="%SPARK_HOME%\jars" ) else ( - set ASSEMBLY_DIR=%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION% + set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%\jars" ) -for %%d in (%ASSEMBLY_DIR%\spark-assembly*hadoop*.jar) do ( - set SPARK_ASSEMBLY_JAR=%%d -) -if "%SPARK_ASSEMBLY_JAR%"=="0" ( - echo Failed to find Spark assembly JAR. +if not exist "%SPARK_JARS_DIR%"\ ( + echo Failed to find Spark jars directory. echo You need to build Spark before running this program. exit /b 1 ) -set LAUNCH_CLASSPATH=%SPARK_ASSEMBLY_JAR% +set LAUNCH_CLASSPATH=%SPARK_JARS_DIR%\* rem Add the launcher build dir to the classpath if requested. if not "x%SPARK_PREPEND_CLASSES%"=="x" ( - set LAUNCH_CLASSPATH=%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH% + set LAUNCH_CLASSPATH="%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH%" ) -set _SPARK_ASSEMBLY=%SPARK_ASSEMBLY_JAR% - rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java @@ -62,7 +55,7 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt -"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +"%RUNNER%" -cp "%LAUNCH_CLASSPATH%" org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 8f90ba5a0b3b8..991423da6ab99 100644 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running Spark shell. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0spark-shell2.cmd %* +cmd /V /E /C "%~dp0spark-shell2.cmd" %* diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index b9b0f510d7f5d..7b5d396be888c 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %* +"%SPARK_HOME%\bin\spark-submit2.cmd" --class org.apache.spark.repl.Main --name "Spark shell" %* diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index 8f3b84c7b971d..f301606933a95 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0spark-submit2.cmd %* +cmd /V /E /C "%~dp0spark-submit2.cmd" %* diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 651376e526928..49e350fa5c416 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,4 +24,4 @@ rem disable randomized hash for string in Python 3.3+ set PYTHONHASHSEED=0 set CLASS=org.apache.spark.deploy.SparkSubmit -%~dp0spark-class2.cmd %CLASS% %* +"%~dp0spark-class2.cmd" %CLASS% %* diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd index d7b60183ca8e0..1e5ea6a623219 100644 --- a/bin/sparkR.cmd +++ b/bin/sparkR.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running SparkR. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %~dp0sparkR2.cmd %* +cmd /V /E /C "%~dp0sparkR2.cmd" %* diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd index e47f22c7300bb..459b780e2ae33 100644 --- a/bin/sparkR2.cmd +++ b/bin/sparkR2.cmd @@ -20,7 +20,7 @@ rem rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. -call %SPARK_HOME%\bin\load-spark-env.cmd +call "%SPARK_HOME%\bin\load-spark-env.cmd" -call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %* +call "%SPARK_HOME%\bin\spark-submit2.cmd" sparkr-shell-main %* diff --git a/build/mvn b/build/mvn index 7603ea03deb73..eb42552fc499e 100755 --- a/build/mvn +++ b/build/mvn @@ -69,10 +69,11 @@ install_app() { # Install maven under the build/ folder install_mvn() { - local MVN_VERSION="3.3.3" + local MVN_VERSION="3.3.9" + local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua?action=download&filename='} install_app \ - "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "${APACHE_MIRROR}/maven/maven-3/${MVN_VERSION}/binaries" \ "apache-maven-${MVN_VERSION}-bin.tar.gz" \ "apache-maven-${MVN_VERSION}/bin/mvn" @@ -81,11 +82,13 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.5.3/bin/zinc" + local zinc_path="zinc-0.3.9/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + install_app \ - "http://downloads.typesafe.com/zinc/0.3.5.3" \ - "zinc-0.3.5.3.tgz" \ + "${TYPESAFE_MIRROR}/zinc/0.3.9" \ + "zinc-0.3.9.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" } @@ -98,9 +101,10 @@ install_scala() { local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | \ head -1 | cut -f2 -d'>' | cut -f1 -d'<'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} install_app \ - "http://downloads.typesafe.com/scala/${scala_version}" \ + "${TYPESAFE_MIRROR}/scala/${scala_version}" \ "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml new file mode 100644 index 0000000000000..bd507c2cb6c4b --- /dev/null +++ b/common/network-common/pom.xml @@ -0,0 +1,103 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-common_2.11 + jar + Spark Project Networking + http://spark.apache.org/ + + network-common + + + + + + io.netty + netty-all + + + + + org.slf4j + slf4j-api + provided + + + com.google.code.findbugs + jsr305 + + + com.google.guava + guava + compile + + + + + log4j + log4j + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + org.mockito + mockito-core + test + + + org.slf4j + slf4j-log4j12 + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + org.apache.maven.plugins + maven-jar-plugin + + + test-jar-on-test-compile + test-compile + + test-jar + + + + + + + diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java similarity index 89% rename from network/common/src/main/java/org/apache/spark/network/TransportContext.java rename to common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 43900e6f2c972..5320b28bc054c 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -43,7 +43,8 @@ /** * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to - * setup Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}. + * setup Netty Channel pipelines with a + * {@link org.apache.spark.network.server.TransportChannelHandler}. * * There are two communication protocols that the TransportClient provides, control-plane RPCs and * data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the @@ -59,15 +60,24 @@ public class TransportContext { private final TransportConf conf; private final RpcHandler rpcHandler; + private final boolean closeIdleConnections; private final MessageEncoder encoder; private final MessageDecoder decoder; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { + this(conf, rpcHandler, false); + } + + public TransportContext( + TransportConf conf, + RpcHandler rpcHandler, + boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); + this.closeIdleConnections = closeIdleConnections; } /** @@ -85,7 +95,13 @@ public TransportClientFactory createClientFactory() { /** Create a server which will attempt to bind to a specific port. */ public TransportServer createServer(int port, List bootstraps) { - return new TransportServer(this, port, rpcHandler, bootstraps); + return new TransportServer(this, null, port, rpcHandler, bootstraps); + } + + /** Create a server which will attempt to bind to a specific host and port. */ + public TransportServer createServer( + String host, int port, List bootstraps) { + return new TransportServer(this, host, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ @@ -144,7 +160,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs()); + conf.connectionTimeoutMs(), closeIdleConnections); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java similarity index 99% rename from network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java index 81bc8ec40fc82..162cf6da0dffe 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java @@ -32,7 +32,7 @@ /** * A FileRegion implementation that only creates the file descriptor when the region is being * transferred. This cannot be used with Epoll because there is no native support for it. - * + * * This is mostly copied from DefaultFileRegion implementation in Netty. In the future, we * should push this into Netty so the native Epoll transport can support this feature. */ diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java similarity index 90% rename from network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index a415db593a788..1861f8d7fd8f3 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -65,7 +65,11 @@ public abstract class ManagedBuffer { public abstract ManagedBuffer release(); /** - * Convert the buffer into an Netty object, used to write the data out. + * Convert the buffer into an Netty object, used to write the data out. The return value is either + * a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}. + * + * If this method returns a ByteBuf, then that buffer's reference count will be incremented and + * the caller will be responsible for releasing this new reference. */ public abstract Object convertToNetty() throws IOException; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java similarity index 95% rename from network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java index c806bfa45bef3..acc49d968c186 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -28,7 +28,7 @@ /** * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}. */ -public final class NettyManagedBuffer extends ManagedBuffer { +public class NettyManagedBuffer extends ManagedBuffer { private final ByteBuf buf; public NettyManagedBuffer(ByteBuf buf) { @@ -64,7 +64,7 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - return buf.duplicate(); + return buf.duplicate().retain(); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java similarity index 96% rename from network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java rename to common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java index f55b884bc45ce..631d767715256 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -28,7 +28,7 @@ /** * A {@link ManagedBuffer} backed by {@link ByteBuffer}. */ -public final class NioManagedBuffer extends ManagedBuffer { +public class NioManagedBuffer extends ManagedBuffer { private final ByteBuffer buf; public NioManagedBuffer(ByteBuffer buf) { diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java rename to common/network-common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java similarity index 77% rename from network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index 6ec960d795420..6afc63f71bb3d 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -17,13 +17,20 @@ package org.apache.spark.network.client; +import java.nio.ByteBuffer; + /** * Callback for the result of a single RPC. This will be invoked once with either success or * failure. */ public interface RpcResponseCallback { - /** Successful serialized result from server. */ - void onSuccess(byte[] response); + /** + * Successful serialized result from server. + * + * After `onSuccess` returns, `response` will be recycled and its content will become invalid. + * Please copy the content of `response` if you want to use it after `onSuccess` returns. + */ + void onSuccess(ByteBuffer response); /** Exception either propagated from server or raised on client side. */ void onFailure(Throwable e); diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java similarity index 91% rename from network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java rename to common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java index 093fada320cc3..d322aec28793e 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -21,9 +21,9 @@ import java.nio.ByteBuffer; /** - * Callback for streaming data. Stream data will be offered to the {@link onData(ByteBuffer)} - * method as it arrives. Once all the stream data is received, {@link onComplete()} will be - * called. + * Callback for streaming data. Stream data will be offered to the + * {@link #onData(String, ByteBuffer)} method as it arrives. Once all the stream data is received, + * {@link #onComplete(String)} will be called. *

* The network library guarantees that a single thread will call these methods at a time, but * different call may be made by different threads. diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java similarity index 86% rename from network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java rename to common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index 02230a00e69fc..b0e85bae7c309 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -30,13 +30,18 @@ */ class StreamInterceptor implements TransportFrameDecoder.Interceptor { + private final TransportResponseHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; + private long bytesRead; - private volatile long bytesRead; - - StreamInterceptor(String streamId, long byteCount, StreamCallback callback) { + StreamInterceptor( + TransportResponseHandler handler, + String streamId, + long byteCount, + StreamCallback callback) { + this.handler = handler; this.streamId = streamId; this.byteCount = byteCount; this.callback = callback; @@ -45,11 +50,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } @@ -65,8 +72,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); + handler.deactivateStream(); throw re; } else if (bytesRead == byteCount) { + handler.deactivateStream(); callback.onComplete(streamId); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java similarity index 85% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClient.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index a0ba223e340a2..64a83171e9e90 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -20,11 +20,13 @@ import java.io.Closeable; import java.io.IOException; import java.net.SocketAddress; +import java.nio.ByteBuffer; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -35,7 +37,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.StreamRequest; @@ -73,10 +77,12 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; @Nullable private String clientId; + private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { this.channel = Preconditions.checkNotNull(channel); this.handler = Preconditions.checkNotNull(handler); + this.timedOut = false; } public Channel getChannel() { @@ -84,7 +90,7 @@ public Channel getChannel() { } public boolean isActive() { - return channel.isOpen() || channel.isActive(); + return !timedOut && (channel.isOpen() || channel.isActive()); } public SocketAddress getSocketAddress() { @@ -203,8 +209,12 @@ public void operationComplete(ChannelFuture future) throws Exception { /** * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked * with the server's response or upon any failure. + * + * @param message The message to send. + * @param callback Callback to handle the RPC's reply. + * @return The RPC's id. */ - public void sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -212,7 +222,7 @@ public void sendRpc(byte[] message, final RpcResponseCallback callback) { final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -233,19 +243,25 @@ public void operationComplete(ChannelFuture future) throws Exception { } } }); + + return requestId; } /** * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. */ - public byte[] sendRpcSync(byte[] message, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); + public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); sendRpc(message, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { - result.set(response); + public void onSuccess(ByteBuffer response) { + ByteBuffer copy = ByteBuffer.allocate(response.remaining()); + copy.put(response); + // flip "copy" to make it readable + copy.flip(); + result.set(copy); } @Override @@ -263,6 +279,35 @@ public void onFailure(Throwable e) { } } + /** + * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the + * message, and no delivery guarantees are made. + * + * @param message The message to send. + */ + public void send(ByteBuffer message) { + channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message))); + } + + /** + * Removes any state associated with the given RPC. + * + * @param requestId The RPC id returned by {@link #sendRpc(ByteBuffer, RpcResponseCallback)}. + */ + public void removeRpcRequest(long requestId) { + handler.removeRpcRequest(requestId); + } + + /** Mark this channel as having timed out. */ + public void timeOut() { + this.timedOut = true; + } + + @VisibleForTesting + public TransportResponseHandler getHandler() { + return handler; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java similarity index 80% rename from network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4952ffb44bb8b..a27aaf2b277f7 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -64,7 +64,7 @@ private static class ClientPool { TransportClient[] clients; Object[] locks; - public ClientPool(int size) { + ClientPool(int size) { clients = new TransportClient[size]; locks = new Object[size]; for (int i = 0; i < size; i++) { @@ -94,7 +94,7 @@ public TransportClientFactory( this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); - this.connectionPool = new ConcurrentHashMap(); + this.connectionPool = new ConcurrentHashMap<>(); this.numConnectionsPerPeer = conf.numConnectionsPerPeer(); this.rand = new Random(); @@ -123,41 +123,76 @@ public TransportClientFactory( public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. - final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + // Use unresolved address here to avoid DNS resolution each time we creates a client. + final InetSocketAddress unresolvedAddress = + InetSocketAddress.createUnresolved(remoteHost, remotePort); // Create the ClientPool if we don't have it yet. - ClientPool clientPool = connectionPool.get(address); + ClientPool clientPool = connectionPool.get(unresolvedAddress); if (clientPool == null) { - connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer)); - clientPool = connectionPool.get(address); + connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer)); + clientPool = connectionPool.get(unresolvedAddress); } int clientIndex = rand.nextInt(numConnectionsPerPeer); TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; + // Make sure that the channel will not timeout by updating the last use time of the + // handler. Then check that the client is still alive, in case it timed out before + // this code was able to update things. + TransportChannelHandler handler = cachedClient.getChannel().pipeline() + .get(TransportChannelHandler.class); + synchronized (handler) { + handler.getResponseHandler().updateTimeOfLastRequest(); + } + + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", + cachedClient.getSocketAddress(), cachedClient); + return cachedClient; + } } // If we reach here, we don't have an existing connection open. Let's create a new one. // Multiple threads might race here to create new connections. Keep only one of them active. + final long preResolveHost = System.nanoTime(); + final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort); + final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; + if (hostResolveTimeMs > 2000) { + logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); + } else { + logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); + } + synchronized (clientPool.locks[clientIndex]) { cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null) { if (cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); + logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient); return cachedClient; } else { - logger.info("Found inactive connection to {}, creating a new one.", address); + logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress); } } - clientPool.clients[clientIndex] = createClient(address); + clientPool.clients[clientIndex] = createClient(resolvedAddress); return clientPool.clients[clientIndex]; } } + /** + * Create a completely new {@link TransportClient} to the given remote host / port. + * This connection is not pooled. + * + * As with {@link #createClient(String, int)}, this method is blocking. + */ + public TransportClient createUnmanagedClient(String remoteHost, int remotePort) + throws IOException { + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + return createClient(address); + } + /** Create a completely new {@link TransportClient} to the remote address. */ private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to " + address); @@ -171,8 +206,8 @@ private TransportClient createClient(InetSocketAddress address) throws IOExcepti .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) .option(ChannelOption.ALLOCATOR, pooledAllocator); - final AtomicReference clientRef = new AtomicReference(); - final AtomicReference channelRef = new AtomicReference(); + final AtomicReference clientRef = new AtomicReference<>(); + final AtomicReference channelRef = new AtomicReference<>(); bootstrap.handler(new ChannelInitializer() { @Override @@ -212,7 +247,7 @@ public void initChannel(SocketChannel ch) { } long postBootstrap = System.nanoTime(); - logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000); return client; diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java similarity index 81% rename from network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index ed3f36af58048..8a69223c88ee4 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,20 +57,21 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; private final Queue streamCallbacks; + private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; public TransportResponseHandler(Channel channel) { this.channel = channel; - this.outstandingFetches = new ConcurrentHashMap(); - this.outstandingRpcs = new ConcurrentHashMap(); - this.streamCallbacks = new ConcurrentLinkedQueue(); + this.outstandingFetches = new ConcurrentHashMap<>(); + this.outstandingRpcs = new ConcurrentHashMap<>(); + this.streamCallbacks = new ConcurrentLinkedQueue<>(); this.timeOfLastRequestNs = new AtomicLong(0); } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); } @@ -78,7 +80,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); } @@ -87,9 +89,15 @@ public void removeRpcRequest(long requestId) { } public void addStreamCallback(StreamCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); streamCallbacks.offer(callback); } + @VisibleForTesting + public void deactivateStream() { + streamActive = false; + } + /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. @@ -108,7 +116,11 @@ private void failOutstandingRequests(Throwable cause) { } @Override - public void channelUnregistered() { + public void channelActive() { + } + + @Override + public void channelInactive() { if (numOutstandingRequests() > 0) { String remoteAddress = NettyUtils.getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", @@ -128,7 +140,7 @@ public void exceptionCaught(Throwable cause) { } @Override - public void handle(ResponseMessage message) { + public void handle(ResponseMessage message) throws Exception { String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; @@ -136,11 +148,11 @@ public void handle(ResponseMessage message) { if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); - resp.body.release(); + resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body); - resp.body.release(); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { ChunkFetchFailure resp = (ChunkFetchFailure) message; @@ -158,10 +170,14 @@ public void handle(ResponseMessage message) { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.response.length); + resp.requestId, remoteAddress, resp.body().size()); } else { outstandingRpcs.remove(resp.requestId); - listener.onSuccess(resp.response); + try { + listener.onSuccess(resp.body().nioByteBuffer()); + } finally { + resp.body().release(); + } } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; @@ -177,14 +193,24 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount, - callback); - try { - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - frameDecoder.setInterceptor(interceptor); - } catch (Exception e) { - logger.error("Error installing stream handler.", e); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, + callback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } } } else { logger.error("Could not find callback for StreamResponse."); @@ -208,7 +234,8 @@ public void handle(ResponseMessage message) { /** Returns total number of outstanding requests (fetch requests + rpcs) */ public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size(); + return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + + (streamActive ? 1 : 0); } /** Returns the time in nanoseconds of when the last request was sent out. */ @@ -216,4 +243,9 @@ public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); } + /** Updates the time of the last request to the current system time. */ + public void updateTimeOfLastRequest() { + timeOfLastRequestNs.set(System.nanoTime()); + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java new file mode 100644 index 0000000000000..2924218c2f08b --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Abstract class for messages which optionally contain a body kept in a separate buffer. + */ +public abstract class AbstractMessage implements Message { + private final ManagedBuffer body; + private final boolean isBodyInFrame; + + protected AbstractMessage() { + this(null, false); + } + + protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) { + this.body = body; + this.isBodyInFrame = isBodyInFrame; + } + + @Override + public ManagedBuffer body() { + return body; + } + + @Override + public boolean isBodyInFrame() { + return isBodyInFrame; + } + + protected boolean equals(AbstractMessage other) { + return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java new file mode 100644 index 0000000000000..c362c92fc4f52 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Abstract class for response messages. + */ +public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage { + + protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) { + super(body, isBodyInFrame); + } + + public abstract ResponseMessage createFailureResponse(String error); +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java similarity index 96% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index f0363830b61ac..7b28a9a969486 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -23,7 +23,7 @@ /** * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. */ -public final class ChunkFetchFailure implements ResponseMessage { +public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage { public final StreamChunkId streamChunkId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java similarity index 95% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 5a173af54f618..26d063feb5fe3 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -24,7 +24,7 @@ * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class ChunkFetchRequest implements RequestMessage { +public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage { public final StreamChunkId streamChunkId; public ChunkFetchRequest(StreamChunkId streamChunkId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java similarity index 92% rename from network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index e6a7e9a8b4145..94c2ac9b20e43 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -30,7 +30,7 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ -public final class ChunkFetchSuccess extends ResponseWithBody { +public final class ChunkFetchSuccess extends AbstractResponseMessage { public final StreamChunkId streamChunkId; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { @@ -67,14 +67,14 @@ public static ChunkFetchSuccess decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(streamChunkId, body); + return Objects.hashCode(streamChunkId, body()); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && body.equals(o.body); + return streamChunkId.equals(o.streamChunkId) && super.equals(o); } return false; } @@ -83,7 +83,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("streamChunkId", streamChunkId) - .add("buffer", body) + .add("buffer", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Encodable.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java similarity index 92% rename from network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java index 9162d0b977f83..be217522367c5 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -17,8 +17,8 @@ package org.apache.spark.network.protocol; +import java.nio.charset.StandardCharsets; -import com.google.common.base.Charsets; import io.netty.buffer.ByteBuf; /** Provides a canonical set of Encoders for simple types. */ @@ -27,11 +27,11 @@ public class Encoders { /** Strings are encoded with their length followed by UTF-8 bytes. */ public static class Strings { public static int encodedLength(String s) { - return 4 + s.getBytes(Charsets.UTF_8).length; + return 4 + s.getBytes(StandardCharsets.UTF_8).length; } public static void encode(ByteBuf buf, String s) { - byte[] bytes = s.getBytes(Charsets.UTF_8); + byte[] bytes = s.getBytes(StandardCharsets.UTF_8); buf.writeInt(bytes.length); buf.writeBytes(bytes); } @@ -40,7 +40,7 @@ public static String decode(ByteBuf buf) { int length = buf.readInt(); byte[] bytes = new byte[length]; buf.readBytes(bytes); - return new String(bytes, Charsets.UTF_8); + return new String(bytes, StandardCharsets.UTF_8); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java similarity index 80% rename from network/common/src/main/java/org/apache/spark/network/protocol/Message.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index d01598c20f16f..434935a8ef2ad 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -19,20 +19,29 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + /** An on-the-wire transmittable message. */ public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); + /** An optional body for the message. */ + ManagedBuffer body(); + + /** Whether to include the body of the message in the same frame as the message. */ + boolean isBodyInFrame(); + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ - public static enum Type implements Encodable { + enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), - StreamRequest(6), StreamResponse(7), StreamFailure(8); + StreamRequest(6), StreamResponse(7), StreamFailure(8), + OneWayMessage(9), User(-1); private final byte id; - private Type(int id) { + Type(int id) { assert id < 128 : "Cannot have more than 128 message types"; this.id = (byte) id; } @@ -55,6 +64,8 @@ public static Type decode(ByteBuf buf) { case 6: return StreamRequest; case 7: return StreamResponse; case 8: return StreamFailure; + case 9: return OneWayMessage; + case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java similarity index 97% rename from network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 3c04048f3821a..074780f2b95ce 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -63,6 +63,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case RpcFailure: return RpcFailure.decode(in); + case OneWayMessage: + return OneWayMessage.decode(in); + case StreamRequest: return StreamRequest.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java new file mode 100644 index 0000000000000..664df57feca4f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class MessageEncoder extends MessageToMessageEncoder { + + private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + + /*** + * Encodes a Message by invoking its encode() method. For non-data messages, we will add one + * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. + * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the + * data to 'out', in order to enable zero-copy transfer. + */ + @Override + public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { + Object body = null; + long bodyLength = 0; + boolean isBodyInFrame = false; + + // If the message has a body, take it out to enable zero-copy transfer for the payload. + if (in.body() != null) { + try { + bodyLength = in.body().size(); + body = in.body().convertToNetty(); + isBodyInFrame = in.isBodyInFrame(); + } catch (Exception e) { + in.body().release(); + if (in instanceof AbstractResponseMessage) { + AbstractResponseMessage resp = (AbstractResponseMessage) in; + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? e.getMessage() : "null"; + logger.error(String.format("Error processing %s for client %s", + in, ctx.channel().remoteAddress()), e); + encode(ctx, resp.createFailureResponse(error), out); + } else { + throw e; + } + return; + } + } + + Message.Type msgType = in.type(); + // All messages have the frame length, message type, and message itself. The frame length + // may optionally include the length of the body data, depending on what message is being + // sent. + int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); + long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0); + ByteBuf header = ctx.alloc().heapBuffer(headerLength); + header.writeLong(frameLength); + msgType.encode(header); + in.encode(header); + assert header.writableBytes() == 0; + + if (body != null) { + // We transfer ownership of the reference on in.body() to MessageWithHeader. + // This reference will be freed when MessageWithHeader.deallocate() is called. + out.add(new MessageWithHeader(in.body(), header, body, bodyLength)); + } else { + out.add(header); + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java new file mode 100644 index 0000000000000..4f8781b42a0e4 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import javax.annotation.Nullable; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * A wrapper message that holds two separate pieces (a header and a body). + * + * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. + */ +class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + + @Nullable private final ManagedBuffer managedBuffer; + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + /** + * When the write buffer size is larger than this limit, I/O will be done in chunks of this size. + * The size should not be too large as it will waste underlying memory copy. e.g. If network + * avaliable buffer is smaller than this limit, the data cannot be sent within one single write + * operation while it still will make memory copy with this size. + */ + private static final int NIO_BUFFER_LIMIT = 256 * 1024; + + /** + * Construct a new MessageWithHeader. + * + * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to + * be passed in so that the buffer can be freed when this message is + * deallocated. Ownership of the caller's reference to this buffer is + * transferred to this class, so if the caller wants to continue to use the + * ManagedBuffer in other messages then they will need to call retain() on + * it before passing it to this constructor. This may be null if and only if + * `body` is a {@link FileRegion}. + * @param header the message header. + * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}. + * @param bodyLength the length of the message body, in bytes. + */ + MessageWithHeader( + @Nullable ManagedBuffer managedBuffer, + ByteBuf header, + Object body, + long bodyLength) { + Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, + "Body must be a ByteBuf or a FileRegion."); + this.managedBuffer = managedBuffer; + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + } + + @Override + public long count() { + return headerLength + bodyLength; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return totalBytesTransferred; + } + + /** + * This code is more complicated than you would think because we might require multiple + * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. + * + * The contract is that the caller will ensure position is properly set to the total number + * of bytes transferred so far (i.e. value returned by transfered()). + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) throws IOException { + Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); + // Bytes written for header in this call. + long writtenHeader = 0; + if (header.readableBytes() > 0) { + writtenHeader = copyByteBuf(header, target); + totalBytesTransferred += writtenHeader; + if (header.readableBytes() > 0) { + return writtenHeader; + } + } + + // Bytes written for body in this call. + long writtenBody = 0; + if (body instanceof FileRegion) { + writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); + } else if (body instanceof ByteBuf) { + writtenBody = copyByteBuf((ByteBuf) body, target); + } + totalBytesTransferred += writtenBody; + + return writtenHeader + writtenBody; + } + + @Override + protected void deallocate() { + header.release(); + ReferenceCountUtil.release(body); + if (managedBuffer != null) { + managedBuffer.release(); + } + } + + private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { + ByteBuffer buffer = buf.nioBuffer(); + int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? + target.write(buffer) : writeNioBuffer(target, buffer); + buf.skipBytes(written); + return written; + } + + private int writeNioBuffer( + WritableByteChannel writeCh, + ByteBuffer buf) throws IOException { + int originalLimit = buf.limit(); + int ret = 0; + + try { + int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); + buf.limit(buf.position() + ioSize); + ret = writeCh.write(buf); + } finally { + buf.limit(originalLimit); + } + + return ret; + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java new file mode 100644 index 0000000000000..f7ffb1bd49bb6 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * A RPC that does not expect a reply, which is handled by a remote + * {@link org.apache.spark.network.server.RpcHandler}. + */ +public final class OneWayMessage extends AbstractMessage implements RequestMessage { + + public OneWayMessage(ManagedBuffer body) { + super(body, true); + } + + @Override + public Type type() { return Type.OneWayMessage; } + + @Override + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 4; + } + + @Override + public void encode(ByteBuf buf) { + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + public static OneWayMessage decode(ByteBuf buf) { + // See comment in encodedLength(). + buf.readInt(); + return new OneWayMessage(new NettyManagedBuffer(buf.retain())); + } + + @Override + public int hashCode() { + return Objects.hashCode(body()); + } + + @Override + public boolean equals(Object other) { + if (other instanceof OneWayMessage) { + OneWayMessage o = (OneWayMessage) other; + return super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("body", body()) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java similarity index 94% rename from network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java index 31b15bb17a327..b85171ed6f3d1 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java @@ -17,8 +17,6 @@ package org.apache.spark.network.protocol; -import org.apache.spark.network.protocol.Message; - /** Messages from the client to the server. */ public interface RequestMessage extends Message { // token interface diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java similarity index 94% rename from network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java index 6edffd11cf1e2..194e6d9aa2bd4 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java @@ -17,8 +17,6 @@ package org.apache.spark.network.protocol; -import org.apache.spark.network.protocol.Message; - /** Messages from the server to the client. */ public interface ResponseMessage extends Message { // token interface diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java similarity index 96% rename from network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index 2dfc7876ba328..a76624ef5dc96 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -21,7 +21,7 @@ import io.netty.buffer.ByteBuf; /** Response to {@link RpcRequest} for a failed RPC. */ -public final class RpcFailure implements ResponseMessage { +public final class RpcFailure extends AbstractMessage implements ResponseMessage { public final long requestId; public final String errorString; diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java new file mode 100644 index 0000000000000..2b30920f0598d --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. + * This will correspond to a single + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). + */ +public final class RpcRequest extends AbstractMessage implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + + public RpcRequest(long requestId, ManagedBuffer message) { + super(message, true); + this.requestId = requestId; + } + + @Override + public Type type() { return Type.RpcRequest; } + + @Override + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + public static RpcRequest decode(ByteBuf buf) { + long requestId = buf.readLong(); + // See comment in encodedLength(). + buf.readInt(); + return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain())); + } + + @Override + public int hashCode() { + return Objects.hashCode(requestId, body()); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcRequest) { + RpcRequest o = (RpcRequest) other; + return requestId == o.requestId && super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("body", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java new file mode 100644 index 0000000000000..d73014ecd8506 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** Response to {@link RpcRequest} for a successful RPC. */ +public final class RpcResponse extends AbstractResponseMessage { + public final long requestId; + + public RpcResponse(long requestId, ManagedBuffer message) { + super(message, true); + this.requestId = requestId; + } + + @Override + public Type type() { return Type.RpcResponse; } + + @Override + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new RpcFailure(requestId, error); + } + + public static RpcResponse decode(ByteBuf buf) { + long requestId = buf.readLong(); + // See comment in encodedLength(). + buf.readInt(); + return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain())); + } + + @Override + public int hashCode() { + return Objects.hashCode(requestId, body()); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcResponse) { + RpcResponse o = (RpcResponse) other; + return requestId == o.requestId && super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("body", body()) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java similarity index 92% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java index e3dade2ebf905..258ef81c6783d 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -20,13 +20,10 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - /** * Message indicating an error when transferring a stream. */ -public final class StreamFailure implements ResponseMessage { +public final class StreamFailure extends AbstractMessage implements ResponseMessage { public final String streamId; public final String error; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java similarity index 92% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java index 821e8f53884d7..dc183c043ed9a 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -20,16 +20,13 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - /** * Request to stream data from the remote end. *

* The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before * the data can be streamed. */ -public final class StreamRequest implements RequestMessage { +public final class StreamRequest extends AbstractMessage implements RequestMessage { public final String streamId; public StreamRequest(String streamId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java similarity index 85% rename from network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java rename to common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index ac5ab9a323a11..87e212f3e157b 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -21,7 +21,6 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; /** * Response to {@link StreamRequest} when the stream has been successfully opened. @@ -30,15 +29,15 @@ * sender. The receiver is expected to set a temporary channel handler that will consume the * number of bytes this message says the stream has. */ -public final class StreamResponse extends ResponseWithBody { - public final String streamId; - public final long byteCount; +public final class StreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; - public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { - super(buffer, false); - this.streamId = streamId; - this.byteCount = byteCount; - } + public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + } @Override public Type type() { return Type.StreamResponse; } @@ -68,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId); + return Objects.hashCode(byteCount, streamId, body()); } @Override @@ -85,6 +84,7 @@ public String toString() { return Objects.toStringHelper(this) .add("streamId", streamId) .add("byteCount", byteCount) + .add("body", body()) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java similarity index 88% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 69923769d44b4..68381037d6891 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -28,6 +30,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,11 +73,12 @@ public void doBootstrap(TransportClient client, Channel channel) { while (!saslClient.isComplete()) { SaslMessage msg = new SaslMessage(appId, payload); - ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); + buf.writeBytes(msg.body().nioByteBuffer()); - byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); - payload = saslClient.response(response); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); @@ -88,6 +92,8 @@ public void doBootstrap(TransportClient client, Channel channel) { saslClient = null; logger.debug("Channel {} configured for SASL encryption.", client); } + } catch (IOException ioe) { + throw new RuntimeException(ioe); } finally { if (saslClient != null) { try { diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java similarity index 99% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 127335e4d35fb..3d71ebaa7ea0c 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -33,7 +33,6 @@ import io.netty.channel.FileRegion; import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.util.AbstractReferenceCounted; -import io.netty.util.ReferenceCountUtil; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.NettyUtils; diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 0000000000000..7331c2b481fb1 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.protocol.AbstractMessage; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. This appId allows a single SaslRpcHandler to multiplex different + * applications which may be using different sets of credentials. + */ +class SaslMessage extends AbstractMessage { + + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + + SaslMessage(String appId, byte[] message) { + this(appId, Unpooled.wrappedBuffer(message)); + } + + SaslMessage(String appId, ByteBuf message) { + super(new NettyManagedBuffer(message), true); + this.appId = appId; + } + + @Override + public Type type() { return Type.User; } + + @Override + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 1 + Encoders.Strings.encodedLength(appId) + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.Strings.encode(buf, appId); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); + } + + String appId = Encoders.Strings.decode(buf); + // See comment in encodedLength(). + buf.readInt(); + return new SaslMessage(appId, buf.retain()); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java similarity index 80% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 7033adb9cae6f..c41f5b6873f6c 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -17,8 +17,11 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -28,6 +31,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,14 +74,20 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); return; } - SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } if (saslServer == null) { // First message in the handshake, setup the necessary state. @@ -86,8 +96,14 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback conf.saslServerAlwaysEncrypt()); } - byte[] response = saslServer.response(saslMessage.payload); - callback.onSuccess(response); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. It's ok to change the channel pipeline here since we are processing an incoming @@ -108,15 +124,25 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, ByteBuffer message) { + delegate.receive(client, message); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); } @Override - public void connectionTerminated(TransportClient client) { + public void channelActive(TransportClient client) { + delegate.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { try { - delegate.connectionTerminated(client); + delegate.channelInactive(client); } finally { if (saslServer != null) { saslServer.dispose(); diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java similarity index 97% rename from network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java rename to common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 431cb67a2ae0b..b802a5af63c94 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -28,9 +28,9 @@ import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Map; -import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; @@ -187,14 +187,14 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback /* Encode a byte[] identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) - .toString(Charsets.UTF_8); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(StandardCharsets.UTF_8))) + .toString(StandardCharsets.UTF_8); } /** Encode a password as a base64-encoded char[] array. */ public static char[] encodePassword(String password) { Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); - return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) - .toString(Charsets.UTF_8).toCharArray(); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(StandardCharsets.UTF_8))) + .toString(StandardCharsets.UTF_8).toCharArray(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java similarity index 82% rename from network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java index b80c15106ecbd..4a1f28e9ffb31 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -26,11 +26,14 @@ */ public abstract class MessageHandler { /** Handles the receipt of a single message. */ - public abstract void handle(T message); + public abstract void handle(T message) throws Exception; + + /** Invoked when the channel this MessageHandler is on is active. */ + public abstract void channelActive(); /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); - /** Invoked when the channel this MessageHandler is on has been unregistered. */ - public abstract void channelUnregistered(); + /** Invoked when the channel this MessageHandler is on is inactive. */ + public abstract void channelInactive(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java similarity index 91% rename from network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 1502b7489e864..6ed61da5c7eff 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,5 +1,3 @@ -package org.apache.spark.network.server; - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -17,6 +15,10 @@ * limitations under the License. */ +package org.apache.spark.network.server; + +import java.nio.ByteBuffer; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -29,7 +31,7 @@ public NoOpRpcHandler() { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java similarity index 96% rename from network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java rename to common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index e671854da1cae..ae7e520b2f709 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -20,7 +20,6 @@ import java.util.Iterator; import java.util.Map; import java.util.Random; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; @@ -33,8 +32,8 @@ import org.apache.spark.network.client.TransportClient; /** - * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually - * fetched as chunks by the client. Each registered buffer is one chunk. + * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are + * individually fetched as chunks by the client. Each registered buffer is one chunk. */ public class OneForOneStreamManager extends StreamManager { private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); @@ -64,7 +63,7 @@ public OneForOneStreamManager() { // For debugging purposes, start with a random stream id to help identifying different streams. // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); - streams = new ConcurrentHashMap(); + streams = new ConcurrentHashMap<>(); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java new file mode 100644 index 0000000000000..a99c3015b0e05 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.nio.ByteBuffer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; + +/** + * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. + */ +public abstract class RpcHandler { + + private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); + + /** + * Receive a single RPC message. Any exception thrown while in this method will be sent back to + * the client in string form as a standard RPC failure. + * + * This method will not be called in parallel for a single TransportClient (i.e., channel). + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param message The serialized bytes of the RPC. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + */ + public abstract void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback); + + /** + * Returns the StreamManager which contains the state about which streams are currently being + * fetched by a TransportClient. + */ + public abstract StreamManager getStreamManager(); + + /** + * Receives an RPC message that does not expect a reply. The default implementation will + * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if + * any of the callback methods are called. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param message The serialized bytes of the RPC. + */ + public void receive(TransportClient client, ByteBuffer message) { + receive(client, message, ONE_WAY_CALLBACK); + } + + /** + * Invoked when the channel associated with the given client is active. + */ + public void channelActive(TransportClient client) { } + + /** + * Invoked when the channel associated with the given client is inactive. + * No further requests will come from this client. + */ + public void channelInactive(TransportClient client) { } + + public void exceptionCaught(Throwable cause, TransportClient client) { } + + private static class OneWayRpcCallback implements RpcResponseCallback { + + private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); + + @Override + public void onSuccess(ByteBuffer response) { + logger.warn("Response provided for one-way RPC."); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Error response provided for one-way RPC.", e); + } + + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java similarity index 97% rename from network/common/src/main/java/org/apache/spark/network/server/StreamManager.java rename to common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java index 3f0155957a140..07f161a29cfb8 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -54,6 +54,7 @@ public abstract class StreamManager { * {@link #getChunk(long, int)} method. * * @param streamId id of a stream that has been previously registered with the StreamManager. + * @return A managed buffer for the stream, or null if the stream was not found. */ public ManagedBuffer openStream(String streamId) { throw new UnsupportedOperationException(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java new file mode 100644 index 0000000000000..f2223379a9d24 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.util.NettyUtils; + +/** + * The single Transport-level Channel handler which is used for delegating requests to the + * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}. + * + * All channels created in the transport layer are bidirectional. When the Client initiates a Netty + * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server + * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server + * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the + * Client. + * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, + * for the Client's responses to the Server's requests. + * + * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}. + * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic + * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not + * timeout if the client is continuously sending but getting no responses, for simplicity. + */ +public class TransportChannelHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); + + private final TransportClient client; + private final TransportResponseHandler responseHandler; + private final TransportRequestHandler requestHandler; + private final long requestTimeoutNs; + private final boolean closeIdleConnections; + + public TransportChannelHandler( + TransportClient client, + TransportResponseHandler responseHandler, + TransportRequestHandler requestHandler, + long requestTimeoutMs, + boolean closeIdleConnections) { + this.client = client; + this.responseHandler = responseHandler; + this.requestHandler = requestHandler; + this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000; + this.closeIdleConnections = closeIdleConnections; + } + + public TransportClient getClient() { + return client; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), + cause); + requestHandler.exceptionCaught(cause); + responseHandler.exceptionCaught(cause); + ctx.close(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + try { + requestHandler.channelActive(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while registering channel", e); + } + try { + responseHandler.channelActive(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while registering channel", e); + } + super.channelRegistered(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + requestHandler.channelInactive(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while unregistering channel", e); + } + try { + responseHandler.channelInactive(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while unregistering channel", e); + } + super.channelUnregistered(ctx); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { + if (request instanceof RequestMessage) { + requestHandler.handle((RequestMessage) request); + } else { + responseHandler.handle((ResponseMessage) request); + } + } + + /** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}. */ + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof IdleStateEvent) { + IdleStateEvent e = (IdleStateEvent) evt; + // See class comment for timeout semantics. In addition to ensuring we only timeout while + // there are outstanding requests, we also do a secondary consistency check to ensure + // there's no race between the idle timeout and incrementing the numOutstandingRequests + // (see SPARK-7003). + // + // To avoid a race between TransportClientFactory.createClient() and this code which could + // result in an inactive client being returned, this needs to run in a synchronized block. + synchronized (this) { + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if " + + "this is wrong.", address, requestTimeoutNs / 1000 / 1000); + client.timeOut(); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + client.timeOut(); + ctx.close(); + } + } + } + } + ctx.fireUserEventTriggered(evt); + } + + public TransportResponseHandler getResponseHandler() { + return responseHandler; + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java similarity index 84% rename from network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 4f67bd573be21..bebe88ec5d503 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -25,15 +27,17 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.OneWayMessage; +import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamRequest; @@ -78,7 +82,12 @@ public void exceptionCaught(Throwable cause) { } @Override - public void channelUnregistered() { + public void channelActive() { + rpcHandler.channelActive(reverseClient); + } + + @Override + public void channelInactive() { if (streamManager != null) { try { streamManager.connectionTerminated(channel); @@ -86,7 +95,7 @@ public void channelUnregistered() { logger.error("StreamManager connectionTerminated() callback failed.", e); } } - rpcHandler.connectionTerminated(reverseClient); + rpcHandler.channelInactive(reverseClient); } @Override @@ -95,6 +104,8 @@ public void handle(RequestMessage request) { processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { processRpcRequest((RpcRequest) request); + } else if (request instanceof OneWayMessage) { + processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); } else { @@ -134,15 +145,20 @@ private void processStreamRequest(final StreamRequest req) { return; } - respond(new StreamResponse(req.streamId, buf.size(), buf)); + if (buf != null) { + respond(new StreamResponse(req.streamId, buf.size(), buf)); + } else { + respond(new StreamFailure(req.streamId, String.format( + "Stream '%s' was not found.", req.streamId))); + } } private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { - respond(new RpcResponse(req.requestId, response)); + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); } @Override @@ -153,6 +169,18 @@ public void onFailure(Throwable e) { } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } finally { + req.body().release(); + } + } + + private void processOneWayMessage(OneWayMessage req) { + try { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } finally { + req.body().release(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java similarity index 90% rename from network/common/src/main/java/org/apache/spark/network/server/TransportServer.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index f4fadb1ee3b8d..baae235e02205 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -55,9 +55,13 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + /** + * Creates a TransportServer that binds to the given host and the given port, or to any available + * if 0. If you don't want to bind to any special host, set "hostToBind" to null. + * */ public TransportServer( TransportContext context, + String hostToBind, int portToBind, RpcHandler appRpcHandler, List bootstraps) { @@ -67,7 +71,7 @@ public TransportServer( this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); try { - init(portToBind); + init(hostToBind, portToBind); } catch (RuntimeException e) { JavaUtils.closeQuietly(this); throw e; @@ -81,7 +85,7 @@ public int getPort() { return port; } - private void init(int portToBind) { + private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -120,7 +124,9 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + InetSocketAddress address = hostToBind == null ? + new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind); + channelFuture = bootstrap.bind(address); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java rename to common/network-common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java similarity index 95% rename from network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java index 36d655017fb0d..e097714bbc6de 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -24,7 +24,7 @@ public enum ByteUnit { TiB ((long) Math.pow(1024L, 4L)), PiB ((long) Math.pow(1024L, 5L)); - private ByteUnit(long multiplier) { + ByteUnit(long multiplier) { this.multiplier = multiplier; } @@ -33,8 +33,8 @@ private ByteUnit(long multiplier) { public long convertFrom(long d, ByteUnit u) { return u.convertTo(d, this); } - - // Convert the provided number (d) interpreted as this unit type to unit type (u). + + // Convert the provided number (d) interpreted as this unit type to unit type (u). public long convertTo(long d, ByteUnit u) { if (multiplier > u.multiplier) { long ratio = multiplier / u.multiplier; @@ -44,7 +44,7 @@ public long convertTo(long d, ByteUnit u) { } return d * ratio; } else { - // Perform operations in this order to avoid potential overflow + // Perform operations in this order to avoid potential overflow // when computing d * multiplier return d / (u.multiplier / multiplier); } @@ -54,14 +54,14 @@ public double toBytes(long d) { if (d < 0) { throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); } - return d * multiplier; + return d * multiplier; } - + public long toKiB(long d) { return convertTo(d, KiB); } public long toMiB(long d) { return convertTo(d, MiB); } public long toGiB(long d) { return convertTo(d, GiB); } public long toTiB(long d) { return convertTo(d, TiB); } public long toPiB(long d) { return convertTo(d, PiB); } - + private final long multiplier; } diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/IOMode.java rename to common/network-common/src/main/java/org/apache/spark/network/util/IOMode.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java similarity index 85% rename from network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java rename to common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 7d27439cfde7a..fbed2f053dc6c 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -21,11 +21,11 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; -import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import io.netty.buffer.Unpooled; @@ -68,7 +68,7 @@ public static int nonNegativeHash(Object obj) { * converted back to the same string through {@link #bytesToString(ByteBuffer)}. */ public static ByteBuffer stringToBytes(String s) { - return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); + return Unpooled.wrappedBuffer(s.getBytes(StandardCharsets.UTF_8)).nioBuffer(); } /** @@ -76,7 +76,7 @@ public static ByteBuffer stringToBytes(String s) { * converted back to the same byte buffer through {@link #stringToBytes(String)}. */ public static String bytesToString(ByteBuffer b) { - return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); + return Unpooled.wrappedBuffer(b).toString(StandardCharsets.UTF_8); } /* @@ -132,7 +132,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static final ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -159,43 +159,43 @@ private static boolean isSymlink(File file) throws IOException { .build(); /** - * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for - * internal use. If no suffix is provided a direct conversion is attempted. + * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count in the given unit. + * The unit is also considered the default if the given string does not specify a unit. */ - private static long parseTimeString(String str, TimeUnit unit) { + public static long timeStringAs(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); - + try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } - + long val = Long.parseLong(m.group(1)); String suffix = m.group(2); - + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); } - + // If suffix is valid use that, otherwise none was provided and use the default passed return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; - + throw new NumberFormatException(timeError + "\n" + e.getMessage()); } } - + /** * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If * no suffix is provided, the passed number is assumed to be in ms. */ public static long timeStringAsMs(String str) { - return parseTimeString(str, TimeUnit.MILLISECONDS); + return timeStringAs(str, TimeUnit.MILLISECONDS); } /** @@ -203,21 +203,20 @@ public static long timeStringAsMs(String str) { * no suffix is provided, the passed number is assumed to be in seconds. */ public static long timeStringAsSec(String str) { - return parseTimeString(str, TimeUnit.SECONDS); + return timeStringAs(str, TimeUnit.SECONDS); } - + /** - * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for - * internal use. If no suffix is provided a direct conversion of the provided default is - * attempted. + * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to the given. If no suffix is + * provided, a direct conversion to the provided unit is attempted. */ - private static long parseByteString(String str, ByteUnit unit) { + public static long byteStringAs(String str, ByteUnit unit) { String lower = str.toLowerCase().trim(); try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); - + if (m.matches()) { long val = Long.parseLong(m.group(1)); String suffix = m.group(2); @@ -228,31 +227,31 @@ private static long parseByteString(String str, ByteUnit unit) { } // If suffix is valid use that, otherwise none was provided and use the default passed - return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); } else if (fractionMatcher.matches()) { - throw new NumberFormatException("Fractional values are not supported. Input was: " + throw new NumberFormatException("Fractional values are not supported. Input was: " + fractionMatcher.group(1)); } else { - throw new NumberFormatException("Failed to parse byte string: " + str); + throw new NumberFormatException("Failed to parse byte string: " + str); } - + } catch (NumberFormatException e) { - String timeError = "Size must be specified as bytes (b), " + + String byteError = "Size must be specified as bytes (b), " + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + "E.g. 50b, 100k, or 250m."; - throw new NumberFormatException(timeError + "\n" + e.getMessage()); + throw new NumberFormatException(byteError + "\n" + e.getMessage()); } } /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for * internal use. - * + * * If no suffix is provided, the passed number is assumed to be in bytes. */ public static long byteStringAsBytes(String str) { - return parseByteString(str, ByteUnit.BYTE); + return byteStringAs(str, ByteUnit.BYTE); } /** @@ -262,9 +261,9 @@ public static long byteStringAsBytes(String str) { * If no suffix is provided, the passed number is assumed to be in kibibytes. */ public static long byteStringAsKb(String str) { - return parseByteString(str, ByteUnit.KiB); + return byteStringAs(str, ByteUnit.KiB); } - + /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for * internal use. @@ -272,7 +271,7 @@ public static long byteStringAsKb(String str) { * If no suffix is provided, the passed number is assumed to be in mebibytes. */ public static long byteStringAsMb(String str) { - return parseByteString(str, ByteUnit.MiB); + return byteStringAs(str, ByteUnit.MiB); } /** @@ -282,6 +281,22 @@ public static long byteStringAsMb(String str) { * If no suffix is provided, the passed number is assumed to be in gibibytes. */ public static long byteStringAsGb(String str) { - return parseByteString(str, ByteUnit.GiB); + return byteStringAs(str, ByteUnit.GiB); } + + /** + * Returns a byte array with the buffer's contents, trying to avoid copying the data if + * possible. + */ + public static byte[] bufferToArray(ByteBuffer buffer) { + if (buffer.hasArray() && buffer.arrayOffset() == 0 && + buffer.array().length == buffer.remaining()) { + return buffer.array(); + } else { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java rename to common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java similarity index 100% rename from network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java similarity index 97% rename from network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java rename to common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index caa7260bc8281..10de9d3a5caf6 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -31,8 +31,6 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.util.internal.PlatformDependent; /** diff --git a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java similarity index 95% rename from network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java rename to common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java index 5f20b70678d1e..f15ec8d294258 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java @@ -19,8 +19,6 @@ import java.util.NoSuchElementException; -import org.apache.spark.network.util.ConfigProvider; - /** Uses System properties to obtain config values. */ public class SystemPropertyConfigProvider extends ConfigProvider { @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java new file mode 100644 index 0000000000000..9f030da2b3cec --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import com.google.common.primitives.Ints; + +/** + * A central location that tracks all the settings we expose to users. + */ +public class TransportConf { + + private final String SPARK_NETWORK_IO_MODE_KEY; + private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; + private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; + private final String SPARK_NETWORK_IO_BACKLOG_KEY; + private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; + private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; + private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY; + private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY; + private final String SPARK_NETWORK_IO_SENDBUFFER_KEY; + private final String SPARK_NETWORK_SASL_TIMEOUT_KEY; + private final String SPARK_NETWORK_IO_MAXRETRIES_KEY; + private final String SPARK_NETWORK_IO_RETRYWAIT_KEY; + private final String SPARK_NETWORK_IO_LAZYFD_KEY; + + private final ConfigProvider conf; + + private final String module; + + public TransportConf(String module, ConfigProvider conf) { + this.module = module; + this.conf = conf; + SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); + SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); + SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); + SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); + SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); + SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); + SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads"); + SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer"); + SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer"); + SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout"); + SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries"); + SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait"); + SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD"); + } + + private String getConfKey(String suffix) { + return "spark." + module + "." + suffix; + } + + /** IO mode: nio or epoll */ + public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } + + /** If true, we will prefer allocating off-heap byte buffers within Netty. */ + public boolean preferDirectBufs() { + return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); + } + + /** Connect timeout in milliseconds. Default 120 secs. */ + public int connectionTimeoutMs() { + long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( + conf.get("spark.network.timeout", "120s")); + long defaultTimeoutMs = JavaUtils.timeStringAsSec( + conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; + return (int) defaultTimeoutMs; + } + + /** Number of concurrent connections between two nodes for fetching data. */ + public int numConnectionsPerPeer() { + return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); + } + + /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ + public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); } + + /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ + public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); } + + /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ + public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } + + /** + * Receive buffer size (SO_RCVBUF). + * Note: the optimal size for receive buffer and send buffer should be + * latency * network_bandwidth. + * Assuming latency = 1ms, network_bandwidth = 10Gbps + * buffer size should be ~ 1.25MB + */ + public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); } + + /** Send buffer size (SO_SNDBUF). */ + public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ + public int saslRTTimeoutMs() { + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; + } + + /** + * Max number of times we will try IO exceptions (such as connection timeouts) per request. + * If set to 0, we will not do any retries. + */ + public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); } + + /** + * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. + * Only relevant if maxIORetries > 0. + */ + public int ioRetryWaitTimeMs() { + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000; + } + + /** + * Minimum size of a block that we should start using memory map rather than reading in through + * normal IO operations. This prevents Spark from memory mapping very small blocks. In general, + * memory mapping has high overhead for blocks close to or below the page size of the OS. + */ + public int memoryMapBytes() { + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.storage.memoryMapThreshold", "2m"))); + } + + /** + * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are + * created only when data is going to be transferred. This can reduce the number of open files. + */ + public boolean lazyFileDescriptor() { + return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true); + } + + /** + * Maximum number of retries when binding to a port before giving up. + */ + public int portMaxRetries() { + return conf.getInt("spark.port.maxRetries", 16); + } + + /** + * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. + */ + public int maxSaslEncryptedBlockSize() { + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); + } + + /** + * Whether the server should enforce encryption on SASL-authenticated connections. + */ + public boolean saslServerAlwaysEncrypt() { + return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java new file mode 100644 index 0000000000000..fcec7dfd0c210 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.LinkedList; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + +/** + * A customized frame decoder that allows intercepting raw data. + *

+ * This behaves like Netty's frame decoder (with harcoded parameters that match this library's + * needs), except it allows an interceptor to be installed to read data directly before it's + * framed. + *

+ * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's + * decoded, instead of building as many frames as the current buffer allows and dispatching + * all of them. This allows a child handler to install an interceptor if needed. + *

+ * If an interceptor is installed, framing stops, and data is instead fed directly to the + * interceptor. When the interceptor indicates that it doesn't need to read any more data, + * framing resumes. Interceptors should not hold references to the data buffers provided + * to their handle() method. + */ +public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { + + public static final String HANDLER_NAME = "frameDecoder"; + private static final int LENGTH_SIZE = 8; + private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; + private static final int UNKNOWN_FRAME_SIZE = -1; + + private final LinkedList buffers = new LinkedList<>(); + private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); + + private long totalSize = 0; + private long nextFrameSize = UNKNOWN_FRAME_SIZE; + private volatile Interceptor interceptor; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + ByteBuf in = (ByteBuf) data; + buffers.add(in); + totalSize += in.readableBytes(); + + while (!buffers.isEmpty()) { + // First, feed the interceptor, and if it's still, active, try again. + if (interceptor != null) { + ByteBuf first = buffers.getFirst(); + int available = first.readableBytes(); + if (feedInterceptor(first)) { + assert !first.isReadable() : "Interceptor still active but buffer has data."; + } + + int read = available - first.readableBytes(); + if (read == available) { + buffers.removeFirst().release(); + } + totalSize -= read; + } else { + // Interceptor is not active, so try to decode one frame. + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } + ctx.fireChannelRead(frame); + } + } + } + + private long decodeFrameSize() { + if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { + return nextFrameSize; + } + + // We know there's enough data. If the first buffer contains all the data, great. Otherwise, + // hold the bytes for the frame length in a composite buffer until we have enough data to read + // the frame size. Normally, it should be rare to need more than one buffer to read the frame + // size. + ByteBuf first = buffers.getFirst(); + if (first.readableBytes() >= LENGTH_SIZE) { + nextFrameSize = first.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + if (!first.isReadable()) { + buffers.removeFirst().release(); + } + return nextFrameSize; + } + + while (frameLenBuf.readableBytes() < LENGTH_SIZE) { + ByteBuf next = buffers.getFirst(); + int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); + frameLenBuf.writeBytes(next, toRead); + if (!next.isReadable()) { + buffers.removeFirst().release(); + } + } + + nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + frameLenBuf.clear(); + return nextFrameSize; + } + + private ByteBuf decodeNext() throws Exception { + long frameSize = decodeFrameSize(); + if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { + return null; + } + + // Reset size for next frame. + nextFrameSize = UNKNOWN_FRAME_SIZE; + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + + // If the first buffer holds the entire frame, return it. + int remaining = (int) frameSize; + if (buffers.getFirst().readableBytes() >= remaining) { + return nextBufferForFrame(remaining); + } + + // Otherwise, create a composite buffer. + CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE); + while (remaining > 0) { + ByteBuf next = nextBufferForFrame(remaining); + remaining -= next.readableBytes(); + frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes()); + } + assert remaining == 0; + return frame; + } + + /** + * Takes the first buffer in the internal list, and either adjust it to fit in the frame + * (by taking a slice out of it) or remove it from the internal list. + */ + private ByteBuf nextBufferForFrame(int bytesToRead) { + ByteBuf buf = buffers.getFirst(); + ByteBuf frame; + + if (buf.readableBytes() > bytesToRead) { + frame = buf.retain().readSlice(bytesToRead); + totalSize -= bytesToRead; + } else { + frame = buf; + buffers.removeFirst(); + totalSize -= frame.readableBytes(); + } + + return frame; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + for (ByteBuf b : buffers) { + b.release(); + } + if (interceptor != null) { + interceptor.channelInactive(); + } + frameLenBuf.release(); + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (interceptor != null) { + interceptor.exceptionCaught(cause); + } + super.exceptionCaught(ctx, cause); + } + + public void setInterceptor(Interceptor interceptor) { + Preconditions.checkState(this.interceptor == null, "Already have an interceptor."); + this.interceptor = interceptor; + } + + /** + * @return Whether the interceptor is still active after processing the data. + */ + private boolean feedInterceptor(ByteBuf buf) throws Exception { + if (interceptor != null && !interceptor.handle(buf)) { + interceptor = null; + } + return interceptor != null; + } + + public interface Interceptor { + + /** + * Handles data received from the remote end. + * + * @param data Buffer containing data. + * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor. + */ + boolean handle(ByteBuf data) throws Exception; + + /** Called if an exception is thrown in the channel pipeline. */ + void exceptionCaught(Throwable cause) throws Exception; + + /** Called if the channel is closed and the interceptor is still installed. */ + void channelInactive() throws Exception; + + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java similarity index 92% rename from network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dfb7740344ed0..6d62eaf35d8cc 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -31,6 +31,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.google.common.io.Closeables; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -63,8 +64,6 @@ public class ChunkFetchIntegrationSuite { static ManagedBuffer bufferChunk; static ManagedBuffer fileChunk; - private TransportConf transportConf; - @BeforeClass public static void setUp() throws Exception { int bufSize = 100000; @@ -78,12 +77,17 @@ public static void setUp() throws Exception { testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); - byte[] fileContent = new byte[1024]; - new Random().nextBytes(fileContent); - fp.write(fileContent); - fp.close(); + boolean shouldSuppressIOException = true; + try { + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + shouldSuppressIOException = false; + } finally { + Closeables.close(fp, shouldSuppressIOException); + } - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { @@ -101,7 +105,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -117,12 +124,13 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { + bufferChunk.release(); server.close(); clientFactory.close(); testFile.delete(); } - class FetchResult { + static class FetchResult { public Set successChunks; public Set failedChunks; public List buffers; diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java similarity index 92% rename from network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 22b451fc0e60e..6c8dd742f4b64 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; @@ -81,9 +82,10 @@ private void testClientToServer(Message msg) { @Test public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - testClientToServer(new RpcRequest(12345, new byte[0])); - testClientToServer(new RpcRequest(12345, new byte[100])); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0))); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10))); testClientToServer(new StreamRequest("abcde")); + testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); } @Test @@ -92,8 +94,8 @@ public void responses() { testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); - testServerToClient(new RpcResponse(12345, new byte[0])); - testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100))); testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java new file mode 100644 index 0000000000000..959396bb8c268 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import com.google.common.collect.Maps; +import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.junit.*; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** + * Suite which ensures that requests that go without a response for the network timeout period are + * failed, and the connection closed. + * + * In this suite, we use 10 seconds as the connection timeout, with some slack given in the tests, + * to ensure stability in different test environments. + */ +public class RequestTimeoutIntegrationSuite { + + private TransportServer server; + private TransportClientFactory clientFactory; + + private StreamManager defaultManager; + private TransportConf conf; + + // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever. + private static final int FOREVER = 60 * 1000; + + @Before + public void setUp() throws Exception { + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.connectionTimeout", "10s"); + conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); + + defaultManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + throw new UnsupportedOperationException(); + } + }; + } + + @After + public void tearDown() { + if (server != null) { + server.close(); + } + if (clientFactory != null) { + clientFactory.close(); + } + } + + // Basic suite: First request completes quickly, and second waits for longer than network timeout. + @Test + public void timeoutInactiveRequests() throws Exception { + final Semaphore semaphore = new Semaphore(1); + final int responseSize = 16; + RpcHandler handler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + try { + semaphore.acquire(); + callback.onSuccess(ByteBuffer.allocate(responseSize)); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // First completes quickly (semaphore starts at 1). + TestCallback callback0 = new TestCallback(); + client.sendRpc(ByteBuffer.allocate(0), callback0); + callback0.latch.await(); + assertEquals(responseSize, callback0.successLength); + + // Second times out after 10 seconds, with slack. Must be IOException. + TestCallback callback1 = new TestCallback(); + client.sendRpc(ByteBuffer.allocate(0), callback1); + callback1.latch.await(60, TimeUnit.SECONDS); + assertNotNull(callback1.failure); + assertTrue(callback1.failure instanceof IOException); + + semaphore.release(); + } + + // A timeout will cause the connection to be closed, invalidating the current TransportClient. + // It should be the case that requesting a client from the factory produces a new, valid one. + @Test + public void timeoutCleanlyClosesClient() throws Exception { + final Semaphore semaphore = new Semaphore(0); + final int responseSize = 16; + RpcHandler handler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + try { + semaphore.acquire(); + callback.onSuccess(ByteBuffer.allocate(responseSize)); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + + // First request should eventually fail. + TransportClient client0 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback0 = new TestCallback(); + client0.sendRpc(ByteBuffer.allocate(0), callback0); + callback0.latch.await(); + assertTrue(callback0.failure instanceof IOException); + assertFalse(client0.isActive()); + + // Increment the semaphore and the second request should succeed quickly. + semaphore.release(2); + TransportClient client1 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback1 = new TestCallback(); + client1.sendRpc(ByteBuffer.allocate(0), callback1); + callback1.latch.await(); + assertEquals(responseSize, callback1.successLength); + assertNull(callback1.failure); + } + + // The timeout is relative to the LAST request sent, which is kinda weird, but still. + // This test also makes sure the timeout works for Fetch requests as well as RPCs. + @Test + public void furtherRequestsDelay() throws Exception { + final byte[] response = new byte[16]; + final StreamManager manager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); + return new NioManagedBuffer(ByteBuffer.wrap(response)); + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return manager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // Send one request, which will eventually fail. + TestCallback callback0 = new TestCallback(); + client.fetchChunk(0, 0, callback0); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + // Send a second request before the first has failed. + TestCallback callback1 = new TestCallback(); + client.fetchChunk(0, 1, callback1); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + // not complete yet, but should complete soon + assertEquals(-1, callback0.successLength); + assertNull(callback0.failure); + callback0.latch.await(60, TimeUnit.SECONDS); + assertTrue(callback0.failure instanceof IOException); + + // failed at same time as previous + assertTrue(callback1.failure instanceof IOException); + } + + /** + * Callback which sets 'success' or 'failure' on completion. + * Additionally notifies all waiters on this callback when invoked. + */ + static class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { + + int successLength = -1; + Throwable failure; + final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void onSuccess(ByteBuffer response) { + successLength = response.remaining(); + latch.countDown(); + } + + @Override + public void onFailure(Throwable e) { + failure = e; + latch.countDown(); + } + + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + try { + successLength = buffer.nioByteBuffer().remaining(); + } catch (IOException e) { + // weird + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + failure = e; + latch.countDown(); + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java similarity index 77% rename from network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 64b457b4b3f01..a7a99f3bfc707 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,14 +17,16 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.base.Charsets; import com.google.common.collect.Sets; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -39,6 +41,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -46,17 +49,21 @@ public class RpcIntegrationSuite { static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; + static List oneWayMsgs; @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - String msg = new String(message, Charsets.UTF_8); + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String msg = JavaUtils.bytesToString(message); String[] parts = msg.split("/"); if (parts[0].equals("hello")) { - callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); } else if (parts[0].equals("return error")) { callback.onFailure(new RuntimeException("Returned: " + parts[1])); } else if (parts[0].equals("throw error")) { @@ -64,12 +71,18 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, ByteBuffer message) { + oneWayMsgs.add(JavaUtils.bytesToString(message)); + } + @Override public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); + oneWayMsgs = new ArrayList<>(); } @AfterClass @@ -78,7 +91,7 @@ public static void tearDown() { clientFactory.close(); } - class RpcResult { + static class RpcResult { public Set successMessages; public Set errorMessages; } @@ -93,8 +106,9 @@ private RpcResult sendRPC(String ... commands) throws Exception { RpcResponseCallback callback = new RpcResponseCallback() { @Override - public void onSuccess(byte[] message) { - res.successMessages.add(new String(message, Charsets.UTF_8)); + public void onSuccess(ByteBuffer message) { + String response = JavaUtils.bytesToString(message); + res.successMessages.add(response); sem.release(); } @@ -106,7 +120,7 @@ public void onFailure(Throwable e) { }; for (String command : commands) { - client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + client.sendRpc(JavaUtils.stringToBytes(command), callback); } if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { @@ -158,6 +172,27 @@ public void sendSuccessAndFailure() throws Exception { assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); } + @Test + public void sendOneWayMessage() throws Exception { + final String message = "no reply"; + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.send(JavaUtils.stringToBytes(message)); + assertEquals(0, client.getHandler().numOutstandingRequests()); + + // Make sure the message arrives. + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) { + TimeUnit.MILLISECONDS.sleep(10); + } + + assertEquals(1, oneWayMsgs.size()); + assertEquals(message, oneWayMsgs.get(0)); + } finally { + client.close(); + } + } + private void assertErrorsContain(Set errors, Set contains) { assertEquals(contains.size(), errors.size()); diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java similarity index 91% rename from network/common/src/test/java/org/apache/spark/network/StreamSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index 6dcec831dec71..9c49556927f0b 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -51,13 +51,14 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" }; + private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; private static TransportServer server; private static TransportClientFactory clientFactory; private static File testFile; private static File tempDir; + private static ByteBuffer emptyBuffer; private static ByteBuffer smallBuffer; private static ByteBuffer largeBuffer; @@ -73,6 +74,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); smallBuffer = createBuffer(100); largeBuffer = createBuffer(100000); @@ -89,7 +91,7 @@ public static void setUp() throws Exception { fp.close(); } - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); final StreamManager streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { @@ -103,6 +105,8 @@ public ManagedBuffer openStream(String streamId) { return new NioManagedBuffer(largeBuffer); case "smallBuffer": return new NioManagedBuffer(smallBuffer); + case "emptyBuffer": + return new NioManagedBuffer(emptyBuffer); case "file": return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); default: @@ -112,7 +116,10 @@ public ManagedBuffer openStream(String streamId) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -138,6 +145,18 @@ public static void tearDown() { } } + @Test + public void testZeroLengthStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + @Test public void testSingleStream() throws Throwable { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); @@ -226,6 +245,11 @@ public void run() { outFile = File.createTempFile("data", ".tmp", tempDir); out = new FileOutputStream(outFile); break; + case "emptyBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = emptyBuffer; + break; default: throw new IllegalArgumentException(streamId); } diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java rename to common/network-common/src/test/java/org/apache/spark/network/TestManagedBuffer.java diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/common/network-common/src/test/java/org/apache/spark/network/TestUtils.java similarity index 100% rename from network/common/src/test/java/org/apache/spark/network/TestUtils.java rename to common/network-common/src/test/java/org/apache/spark/network/TestUtils.java diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java similarity index 78% rename from network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java rename to common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 35de5e57ccb98..44d16d54225e7 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -21,11 +21,13 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import com.google.common.collect.Maps; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -37,6 +39,7 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.ConfigProvider; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; @@ -50,7 +53,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -74,7 +77,7 @@ private void testClientReuse(final int maxConnections, boolean concurrent) Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); - TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); @@ -93,7 +96,7 @@ public void run() { try { TransportClient client = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assert (client.isActive()); + assertTrue(client.isActive()); clients.add(client); } catch (IOException e) { failed.incrementAndGet(); @@ -113,12 +116,14 @@ public void run() { attempts[i].join(); } - assert(failed.get() == 0); - assert(clients.size() == maxConnections); + Assert.assertEquals(0, failed.get()); + Assert.assertEquals(clients.size(), maxConnections); for (TransportClient client : clients) { client.close(); } + + factory.close(); } @Test @@ -177,4 +182,36 @@ public void closeBlockClientsWithFactory() throws IOException { assertFalse(c1.isActive()); assertFalse(c2.isActive()); } + + @Test + public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { + TransportConf conf = new TransportConf("shuffle", new ConfigProvider() { + + @Override + public String get(String name) { + if ("spark.shuffle.io.connectionTimeout".equals(name)) { + // We should make sure there is enough time for us to observe the channel is active + return "1s"; + } + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } + }); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); + TransportClientFactory factory = context.createClientFactory(); + try { + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(c1.isActive()); + long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds + while (c1.isActive() && System.currentTimeMillis() < expiredTime) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + } finally { + factory.close(); + } + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java new file mode 100644 index 0000000000000..128f7cba74350 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.nio.ByteBuffer; + +import io.netty.channel.Channel; +import io.netty.channel.local.LocalChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.util.TransportFrameDecoder; + +public class TransportResponseHandlerSuite { + @Test + public void handleSuccessfulFetch() throws Exception { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedFetch() throws Exception { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); + verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void clearAllOutstandingRequests() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(new StreamChunkId(1, 0), callback); + handler.addFetchRequest(new StreamChunkId(1, 1), callback); + handler.addFetchRequest(new StreamChunkId(1, 2), callback); + assertEquals(3, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); + handler.exceptionCaught(new Exception("duh duh duhhhh")); + + // should fail both b2 and b3 + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); + verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleSuccessfulRPC() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + // This response should be ignored. + handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7)))); + assertEquals(1, handler.numOutstandingRequests()); + + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp))); + verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10))); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedRPC() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(12345, "oh no")); + verify(callback, times(1)).onFailure((Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void testActiveStreams() throws Exception { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamResponse response = new StreamResponse("stream", 1234L, null); + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(response); + assertEquals(1, handler.numOutstandingRequests()); + handler.deactivateStream(); + assertEquals(0, handler.numOutstandingRequests()); + + StreamFailure failure = new StreamFailure("stream", "uh-oh"); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(failure); + assertEquals(0, handler.numOutstandingRequests()); + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java new file mode 100644 index 0000000000000..b341c5681e00c --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; +import org.junit.Test; +import org.mockito.Mockito; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.util.ByteArrayWritableChannel; + +public class MessageWithHeaderSuite { + + @Test + public void testSingleWrite() throws Exception { + testFileRegionBody(8, 8); + } + + @Test + public void testShortWrite() throws Exception { + testFileRegionBody(8, 1); + } + + @Test + public void testByteBufBody() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); + assertEquals(1, header.refCnt()); + assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); + ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer); + + Object body = managedBuf.convertToNetty(); + assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt()); + assertEquals(1, header.refCnt()); + + MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); + ByteBuf result = doWrite(msg, 1); + assertEquals(msg.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + assertEquals(84, result.readLong()); + + assertTrue(msg.release()); + assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt()); + assertEquals(0, header.refCnt()); + } + + @Test + public void testDeallocateReleasesManagedBuffer() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84)); + ByteBuf body = (ByteBuf) managedBuf.convertToNetty(); + assertEquals(2, body.refCnt()); + MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); + assertTrue(msg.release()); + Mockito.verify(managedBuf, Mockito.times(1)).release(); + assertEquals(0, body.refCnt()); + } + + private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { + ByteBuf header = Unpooled.copyLong(42); + int headerLength = header.readableBytes(); + TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); + MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count()); + + ByteBuf result = doWrite(msg, totalWrites / writesPerCall); + assertEquals(headerLength + region.count(), result.readableBytes()); + assertEquals(42, result.readLong()); + for (long i = 0; i < 8; i++) { + assertEquals(i, result.readLong()); + } + assertTrue(msg.release()); + } + + private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { + int writes = 0; + ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); + while (msg.transfered() < msg.count()) { + msg.transferTo(channel, msg.transfered()); + writes++; + } + assertTrue("Not enough writes!", minExpectedWrites <= writes); + return Unpooled.wrappedBuffer(channel.getData()); + } + + private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final int writeCount; + private final int writesPerCall; + private int written; + + TestFileRegion(int totalWrites, int writesPerCall) { + this.writeCount = totalWrites; + this.writesPerCall = writesPerCall; + } + + @Override + public long count() { + return 8 * writeCount; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 8 * written; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + for (int i = 0; i < writesPerCall; i++) { + ByteBuf buf = Unpooled.copyLong((position / 8) + i); + ByteBuffer nio = buf.nioBuffer(); + while (nio.remaining() > 0) { + target.write(nio); + } + buf.release(); + written++; + } + return 8 * writesPerCall; + } + + @Override + protected void deallocate() { + } + + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java similarity index 86% rename from network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java rename to common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 3469e84e7f4da..45cc03df435ac 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -21,10 +21,12 @@ import static org.mockito.Mockito.*; import java.io.File; -import java.nio.charset.StandardCharsets; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.Random; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -56,6 +58,7 @@ import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -122,39 +125,53 @@ public void testNonMatching() { } @Test - public void testSaslAuthentication() throws Exception { + public void testSaslAuthentication() throws Throwable { testBasicSasl(false); } @Test - public void testSaslEncryption() throws Exception { + public void testSaslEncryption() throws Throwable { testBasicSasl(true); } - private void testBasicSasl(boolean encrypt) throws Exception { + private void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocation) { - byte[] message = (byte[]) invocation.getArguments()[1]; + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", new String(message, StandardCharsets.UTF_8)); - cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8)); + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); return null; } }) .when(rpcHandler) - .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { - byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); - assertEquals("Pong", new String(response, StandardCharsets.UTF_8)); + ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", JavaUtils.bytesToString(response)); } finally { ctx.close(); // There should be 2 terminated events; one for the client, one for the server. - verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + Throwable error = null; + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (deadline > System.nanoTime()) { + try { + verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class)); + error = null; + break; + } catch (Throwable t) { + error = t; + TimeUnit.MILLISECONDS.sleep(10); + } + } + if (error != null) { + throw error; + } } } @@ -207,7 +224,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -242,7 +259,7 @@ public void testFileRegionEncryption() throws Exception { final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); StreamManager sm = mock(StreamManager.class); when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { @Override @@ -260,7 +277,7 @@ public ManagedBuffer answer(InvocationOnMock invocation) { ctx = new SaslTestCtx(rpcHandler, true, false); - final Object lock = new Object(); + final CountDownLatch lock = new CountDownLatch(1); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); doAnswer(new Answer() { @@ -268,17 +285,13 @@ public ManagedBuffer answer(InvocationOnMock invocation) { public Void answer(InvocationOnMock invocation) { response.set((ManagedBuffer) invocation.getArguments()[1]); response.get().retain(); - synchronized (lock) { - lock.notifyAll(); - } + lock.countDown(); return null; } }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); - synchronized (lock) { - ctx.client.fetchChunk(0, 0, callback); - lock.wait(10 * 1000); - } + ctx.client.fetchChunk(0, 0, callback); + lock.await(10, TimeUnit.SECONDS); verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); @@ -324,8 +337,8 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { SaslTestCtx ctx = null; try { ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); - ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); } catch (Exception e) { assertFalse(e.getCause() instanceof TimeoutException); @@ -346,13 +359,21 @@ public void testRpcHandlerDelegate() throws Exception { saslHandler.getStreamManager(); verify(handler).getStreamManager(); - saslHandler.connectionTerminated(null); - verify(handler).connectionTerminated(any(TransportClient.class)); + saslHandler.channelInactive(null); + verify(handler).channelInactive(any(TransportClient.class)); saslHandler.exceptionCaught(null, null); verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); } + @Test + public void testDelegates() throws Exception { + Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); + for (Method m : rpcHandlerMethods) { + SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes()); + } + } + private static class SaslTestCtx { final TransportClient client; @@ -368,7 +389,7 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java new file mode 100644 index 0000000000000..c647525d8f1bd --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import org.junit.Test; +import org.mockito.Mockito; + +import org.apache.spark.network.TestManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; + +public class OneForOneStreamManagerSuite { + + @Test + public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { + OneForOneStreamManager manager = new OneForOneStreamManager(); + List buffers = new ArrayList<>(); + TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10)); + TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); + buffers.add(buffer1); + buffers.add(buffer2); + long streamId = manager.registerStream("appId", buffers.iterator()); + + Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); + manager.registerChannel(dummyChannel, streamId); + + manager.connectionTerminated(dummyChannel); + + Mockito.verify(buffer1, Mockito.times(1)).release(); + Mockito.verify(buffer2, Mockito.times(1)).release(); + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java new file mode 100644 index 0000000000000..d4de4a941d480 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import org.junit.AfterClass; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +public class TransportFrameDecoderSuite { + + private static Random RND = new Random(); + + @AfterClass + public static void cleanup() { + RND = null; + } + + @Test + public void testFrameDecoding() throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + verifyAndCloseDecoder(decoder, ctx, data); + } + + @Test + public void testInterception() throws Exception { + final int interceptedReads = 3; + TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + + byte[] data = new byte[8]; + ByteBuf len = Unpooled.copyLong(8 + data.length); + ByteBuf dataBuf = Unpooled.wrappedBuffer(data); + + try { + decoder.setInterceptor(interceptor); + for (int i = 0; i < interceptedReads; i++) { + decoder.channelRead(ctx, dataBuf); + assertEquals(0, dataBuf.refCnt()); + dataBuf = Unpooled.wrappedBuffer(data); + } + decoder.channelRead(ctx, len); + decoder.channelRead(ctx, dataBuf); + verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); + verify(ctx).fireChannelRead(any(ByteBuffer.class)); + assertEquals(0, len.refCnt()); + assertEquals(0, dataBuf.refCnt()); + } finally { + release(len); + release(dataBuf); + } + } + + @Test + public void testRetainedFrames() throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + + final AtomicInteger count = new AtomicInteger(); + final List retained = new ArrayList<>(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { + retained.add(buf); + } else { + buf.release(); + } + return null; + } + }); + + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + try { + // Verify all retained buffers are readable. + for (ByteBuf b : retained) { + byte[] tmp = new byte[b.readableBytes()]; + b.readBytes(tmp); + b.release(); + } + verifyAndCloseDecoder(decoder, ctx, data); + } finally { + for (ByteBuf b : retained) { + release(b); + } + } + } + + @Test + public void testSplitLengthField() throws Exception { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + ByteBuf buf = Unpooled.buffer(frame.length + 8); + buf.writeLong(frame.length + 8); + buf.writeBytes(frame); + + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + try { + decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); + verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + decoder.channelRead(ctx, buf); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + assertEquals(0, buf.refCnt()); + } finally { + decoder.channelInactive(ctx); + release(buf); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testNegativeFrameSize() throws Exception { + testInvalidFrame(-1); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyFrame() throws Exception { + // 8 because frame size includes the frame length. + testInvalidFrame(8); + } + + @Test(expected = IllegalArgumentException.class) + public void testLargeFrame() throws Exception { + // Frame length includes the frame size field, so need to add a few more bytes. + testInvalidFrame(Integer.MAX_VALUE + 9); + } + + /** + * Creates a number of randomly sized frames and feed them to the given decoder, verifying + * that the frames were read. + */ + private ByteBuf createAndFeedFrames( + int frameCount, + TransportFrameDecoder decoder, + ChannelHandlerContext ctx) throws Exception { + ByteBuf data = Unpooled.buffer(); + for (int i = 0; i < frameCount; i++) { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + data.writeLong(frame.length + 8); + data.writeBytes(frame); + } + + try { + while (data.isReadable()) { + int size = RND.nextInt(4 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); + } + + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + } catch (Exception e) { + release(data); + throw e; + } + return data; + } + + private void verifyAndCloseDecoder( + TransportFrameDecoder decoder, + ChannelHandlerContext ctx, + ByteBuf data) throws Exception { + try { + decoder.channelInactive(ctx); + assertTrue("There shouldn't be dangling references to the data.", data.release()); + } finally { + release(data); + } + } + + private void testInvalidFrame(long size) throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ByteBuf frame = Unpooled.copyLong(size); + try { + decoder.channelRead(ctx, frame); + } finally { + release(frame); + } + } + + private ChannelHandlerContext mockChannelHandlerContext() { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + return null; + } + }); + return ctx; + } + + private void release(ByteBuf buf) { + if (buf.refCnt() > 0) { + buf.release(buf.refCnt()); + } + } + + private static class MockInterceptor implements TransportFrameDecoder.Interceptor { + + private int remainingReads; + + MockInterceptor(int readCount) { + this.remainingReads = readCount; + } + + @Override + public boolean handle(ByteBuf data) throws Exception { + data.readerIndex(data.readerIndex() + data.readableBytes()); + assertFalse(data.isReadable()); + remainingReads -= 1; + return remainingReads != 0; + } + + @Override + public void exceptionCaught(Throwable cause) throws Exception { + + } + + @Override + public void channelInactive() throws Exception { + + } + + } + +} diff --git a/network/common/src/test/resources/log4j.properties b/common/network-common/src/test/resources/log4j.properties similarity index 100% rename from network/common/src/test/resources/log4j.properties rename to common/network-common/src/test/resources/log4j.properties diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml new file mode 100644 index 0000000000000..810ec10ca05b3 --- /dev/null +++ b/common/network-shuffle/pom.xml @@ -0,0 +1,101 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-shuffle_2.11 + jar + Spark Project Shuffle Streaming Service + http://spark.apache.org/ + + network-shuffle + + + + + + org.apache.spark + spark-network-common_${scala.binary.version} + ${project.version} + + + + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-annotations + + + + + org.slf4j + slf4j-api + provided + + + com.google.guava + guava + + + + + org.apache.spark + spark-network-common_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + log4j + log4j + test + + + org.mockito + mockito-core + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java similarity index 95% rename from network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index 351c7930a900f..56a025c4d95d8 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -17,14 +17,12 @@ package org.apache.spark.network.sasl; -import java.lang.Override; import java.nio.ByteBuffer; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.util.JavaUtils; /** @@ -39,7 +37,7 @@ public class ShuffleSecretManager implements SecretKeyHolder { private static final String SPARK_SASL_USER = "sparkSaslUser"; public ShuffleSecretManager() { - shuffleSecretMap = new ConcurrentHashMap(); + shuffleSecretMap = new ConcurrentHashMap<>(); } /** diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java similarity index 94% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 3ddf5c3c39189..f8d03b3b9433a 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.annotations.VisibleForTesting; @@ -51,7 +52,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler { final ExternalShuffleBlockResolver blockManager; private final OneForOneStreamManager streamManager; - public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException { + public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) + throws IOException { this(new OneForOneStreamManager(), new ExternalShuffleBlockResolver(conf, registeredExecutorFile)); } @@ -66,8 +68,8 @@ public ExternalShuffleBlockHandler( } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } @@ -85,13 +87,13 @@ protected void handleMessage( } long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; checkAuth(client, msg.appId); blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); - callback.onSuccess(new byte[0]); + callback.onSuccess(ByteBuffer.wrap(new byte[0])); } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java similarity index 95% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0d4dd6afac769..ce5c68e85375e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.*; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; @@ -27,7 +28,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Charsets; import com.google.common.base.Objects; import com.google.common.collect.Maps; import org.fusesource.leveldbjni.JniDBFactory; @@ -152,7 +152,7 @@ public void registerExecutor( try { if (db != null) { byte[] key = dbAppExecKey(fullId); - byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8); + byte[] value = mapper.writeValueAsString(executorInfo).getBytes(StandardCharsets.UTF_8); db.put(key, value); } } catch (Exception e) { @@ -183,11 +183,10 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { - return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) - || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { + if ("sort".equals(executor.shuffleManager) || "tungsten-sort".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else if ("hash".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); } else { throw new UnsupportedOperationException( "Unsupported shuffle manager: " + executor.shuffleManager); @@ -351,7 +350,7 @@ private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException { // we stick a common prefix on all the keys so we can find them in the DB String appExecJson = mapper.writeValueAsString(appExecId); String key = (APP_KEY_PREFIX + ";" + appExecJson); - return key.getBytes(Charsets.UTF_8); + return key.getBytes(StandardCharsets.UTF_8); } private static AppExecId parseDbAppExecKey(String s) throws IOException { @@ -369,10 +368,10 @@ static ConcurrentMap reloadRegisteredExecutors(D ConcurrentMap registeredExecutors = Maps.newConcurrentMap(); if (db != null) { DBIterator itr = db.iterator(); - itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8)); + itr.seek(APP_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); while (itr.hasNext()) { Map.Entry e = itr.next(); - String key = new String(e.getKey(), Charsets.UTF_8); + String key = new String(e.getKey(), StandardCharsets.UTF_8); if (!key.startsWith(APP_KEY_PREFIX)) { break; } @@ -419,12 +418,14 @@ private static void storeVersion(DB db) throws IOException { public static class StoreVersion { - final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8); public final int major; public final int minor; - @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { + @JsonCreator public StoreVersion( + @JsonProperty("major") int major, + @JsonProperty("minor") int minor) { this.major = major; this.minor = minor; } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java similarity index 94% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ea6d248d66be3..58ca87d9d3b13 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.base.Preconditions; @@ -78,7 +79,7 @@ protected void checkInit() { @Override public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); List bootstraps = Lists.newArrayList(); if (saslEnabled) { bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); @@ -137,9 +138,13 @@ public void registerWithShuffleServer( String execId, ExecutorShuffleInfo executorInfo) throws IOException { checkInit(); - TransportClient client = clientFactory.createClient(host, port); - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + TransportClient client = clientFactory.createUnmanagedClient(host, port); + try { + ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + } finally { + client.close(); + } } @Override diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java similarity index 96% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index e653f5cb147ee..1b2ddbf1ed917 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.nio.ByteBuffer; import java.util.Arrays; import org.slf4j.Logger; @@ -89,11 +90,11 @@ public void start() { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { try { - streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java similarity index 99% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index 4bb0498e5d5aa..d81cf869ddb9e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -46,7 +46,7 @@ public class RetryingBlockFetcher { * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any * remaining blocks. */ - public static interface BlockFetchStarter { + public interface BlockFetchStarter { /** * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous * bootstrapping followed by fully asynchronous block fetching. diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java new file mode 100644 index 0000000000000..2add9c83a73d2 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.mesos; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.shuffle.ExternalShuffleClient; +import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.util.TransportConf; + +/** + * A client for talking to the external shuffle service in Mesos coarse-grained mode. + * + * This is used by the Spark driver to register with each external shuffle service on the cluster. + * The reason why the driver has to talk to the service is for cleaning up shuffle files reliably + * after the application exits. Mesos does not provide a great alternative to do this, so Spark + * has to detect this itself. + */ +public class MesosExternalShuffleClient extends ExternalShuffleClient { + private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class); + + private final ScheduledExecutorService heartbeaterThread = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("mesos-external-shuffle-client-heartbeater") + .build()); + + /** + * Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}. + * Please refer to docs on {@link ExternalShuffleClient} for more information. + */ + public MesosExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled, + boolean saslEncryptionEnabled) { + super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); + } + + public void registerDriverWithShuffleService( + String host, + int port, + long heartbeatTimeoutMs, + long heartbeatIntervalMs) throws IOException { + + checkInit(); + ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer(); + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs)); + } + + private class RegisterDriverCallback implements RpcResponseCallback { + private final TransportClient client; + private final long heartbeatIntervalMs; + + private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) { + this.client = client; + this.heartbeatIntervalMs = heartbeatIntervalMs; + } + + @Override + public void onSuccess(ByteBuffer response) { + heartbeaterThread.scheduleAtFixedRate( + new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS); + logger.info("Successfully registered app " + appId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register app " + appId + " with external shuffle service. " + + "Please manually remove shuffle data after driver exit. Error: " + e); + } + } + + @Override + public void close() { + heartbeaterThread.shutdownNow(); + super.close(); + } + + private class Heartbeater implements Runnable { + + private final TransportClient client; + + private Heartbeater(TransportClient client) { + this.client = client; + } + + @Override + public void run() { + // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout + client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java similarity index 88% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index fcb52363e632c..9af6759f5d5f3 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -17,11 +17,14 @@ package org.apache.spark.network.shuffle.protocol; +import java.nio.ByteBuffer; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import org.apache.spark.network.protocol.Encodable; import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -37,12 +40,13 @@ public abstract class BlockTransferMessage implements Encodable { protected abstract Type type(); /** Preceding every serialized message is its type, which allows us to deserialize it. */ - public static enum Type { - OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4); + public enum Type { + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), + HEARTBEAT(5); private final byte id; - private Type(int id) { + Type(int id) { assert id < 128 : "Cannot have more than 128 message types"; this.id = (byte) id; } @@ -53,7 +57,7 @@ private Type(int id) { // NB: Java does not support static methods in interfaces, so we must put this in a static class. public static class Decoder { /** Deserializes the 'type' byte followed by the message itself. */ - public static BlockTransferMessage fromByteArray(byte[] msg) { + public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { ByteBuf buf = Unpooled.wrappedBuffer(msg); byte type = buf.readByte(); switch (type) { @@ -62,18 +66,19 @@ public static BlockTransferMessage fromByteArray(byte[] msg) { case 2: return RegisterExecutor.decode(buf); case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); + case 5: return ShuffleServiceHeartbeat.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } } /** Serializes the 'type' byte followed by the message itself. */ - public byte[] toByteArray() { + public ByteBuffer toByteBuffer() { // Allow room for encoded message, plus the type byte ByteBuf buf = Unpooled.buffer(encodedLength() + 1); buf.writeByte(type().id); encode(buf); assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); - return buf.array(); + return buf.nioBuffer(); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java similarity index 100% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java new file mode 100644 index 0000000000000..d5f53ccb7f741 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol.mesos; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * A message sent from the driver to register with the MesosExternalShuffleService. + */ +public class RegisterDriver extends BlockTransferMessage { + private final String appId; + private final long heartbeatTimeoutMs; + + public RegisterDriver(String appId, long heartbeatTimeoutMs) { + this.appId = appId; + this.heartbeatTimeoutMs = heartbeatTimeoutMs; + } + + public String getAppId() { return appId; } + + public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; } + + @Override + protected Type type() { return Type.REGISTER_DRIVER; } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + buf.writeLong(heartbeatTimeoutMs); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, heartbeatTimeoutMs); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof RegisterDriver)) { + return false; + } + return Objects.equal(appId, ((RegisterDriver) o).appId); + } + + public static RegisterDriver decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + long heartbeatTimeout = buf.readLong(); + return new RegisterDriver(appId, heartbeatTimeout); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java new file mode 100644 index 0000000000000..b30bb9aed55b6 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol.mesos; + +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * A heartbeat sent from the driver to the MesosExternalShuffleService. + */ +public class ShuffleServiceHeartbeat extends BlockTransferMessage { + private final String appId; + + public ShuffleServiceHeartbeat(String appId) { + this.appId = appId; + } + + public String getAppId() { return appId; } + + @Override + protected Type type() { return Type.HEARTBEAT; } + + @Override + public int encodedLength() { return Encoders.Strings.encodedLength(appId); } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + } + + public static ShuffleServiceHeartbeat decode(ByteBuf buf) { + return new ShuffleServiceHeartbeat(Encoders.Strings.decode(buf)); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java similarity index 85% rename from network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c393a5e1e6810..5bf99241851e7 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,7 +18,9 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import com.google.common.collect.Lists; @@ -52,6 +54,7 @@ import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -59,7 +62,7 @@ public class SaslIntegrationSuite { // Use a long timeout to account for slow / overloaded build machines. In the normal case, // tests should finish way before the timeout expires. - private final static long TIMEOUT_MS = 10_000; + private static final long TIMEOUT_MS = 10_000; static TransportServer server; static TransportConf conf; @@ -70,7 +73,7 @@ public class SaslIntegrationSuite { @BeforeClass public static void beforeAll() throws IOException { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); context = new TransportContext(conf, new TestRpcHandler()); secretKeyHolder = mock(SecretKeyHolder.class); @@ -107,8 +110,8 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); - assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS); + assertEquals(msg, JavaUtils.bytesToString(resp)); } @Test @@ -136,7 +139,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -144,7 +147,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -195,40 +198,37 @@ public void testAppIsolation() throws Exception { final AtomicReference exception = new AtomicReference<>(); + final CountDownLatch blockFetchLatch = new CountDownLatch(1); BlockFetchingListener listener = new BlockFetchingListener() { @Override - public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) { - notifyAll(); + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + blockFetchLatch.countDown(); } - @Override - public synchronized void onBlockFetchFailure(String blockId, Throwable t) { + public void onBlockFetchFailure(String blockId, Throwable t) { exception.set(t); - notifyAll(); + blockFetchLatch.countDown(); } }; - String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" }; - OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0", - blockIds, listener); - synchronized (listener) { - fetcher.start(); - listener.wait(); - } + String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" }; + OneForOneBlockFetcher fetcher = + new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener); + fetcher.start(); + blockFetchLatch.await(); checkSecurityException(exception.get()); // Register an executor so that the next steps work. ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo( - new String[] { System.getProperty("java.io.tmpdir") }, 1, - "org.apache.spark.shuffle.sort.SortShuffleManager"); + new String[] { System.getProperty("java.io.tmpdir") }, 1, "sort"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); + client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); - StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); + StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); long streamId = stream.streamId; // Create a second client, authenticated with a different app ID, and try to read from @@ -239,24 +239,22 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { client2 = clientFactory2.createClient(TestUtils.getLocalHost(), blockServer.getPort()); + final CountDownLatch chunkReceivedLatch = new CountDownLatch(1); ChunkReceivedCallback callback = new ChunkReceivedCallback() { @Override - public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) { - notifyAll(); + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + chunkReceivedLatch.countDown(); } - @Override - public synchronized void onFailure(int chunkIndex, Throwable t) { + public void onFailure(int chunkIndex, Throwable t) { exception.set(t); - notifyAll(); + chunkReceivedLatch.countDown(); } }; exception.set(null); - synchronized (callback) { - client2.fetchChunk(streamId, 0, callback); - callback.wait(); - } + client2.fetchChunk(streamId, 0, callback); + chunkReceivedLatch.await(); checkSecurityException(exception.get()); } finally { if (client1 != null) { @@ -275,7 +273,7 @@ public synchronized void onFailure(int chunkIndex, Throwable t) { /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { callback.onSuccess(message); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java similarity index 98% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index d65de9ca550a3..86c8609e7070b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -36,7 +36,7 @@ public void serializeOpenShuffleBlocks() { } private void checkSerializeDeserialize(BlockTransferMessage msg) { - BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); assertEquals(msg, msg2); assertEquals(msg.hashCode(), msg2.hashCode()); assertEquals(msg.toString(), msg2.toString()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java similarity index 84% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e61390cf57061..c2e0b7447fb8b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -60,12 +60,12 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - verify(callback, times(1)).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } @SuppressWarnings("unchecked") @@ -77,17 +77,18 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + .toByteBuffer(); handler.receive(client, openBlocks, callback); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); - ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); StreamHandle handle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); assertEquals(2, handle.numChunks); @SuppressWarnings("unchecked") @@ -104,7 +105,7 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; + ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); @@ -112,7 +113,8 @@ public void testBadMessages() { // pass } - byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); + ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], + new byte[2]).toByteBuffer(); try { handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); @@ -120,7 +122,7 @@ public void testBadMessages() { // pass } - verify(callback, never()).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, never()).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java similarity index 79% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 3c6cb367dea46..d9b5f0261aaba 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CharStreams; @@ -34,15 +35,16 @@ import static org.junit.Assert.*; public class ExternalShuffleBlockResolverSuite { - static String sortBlock0 = "Hello!"; - static String sortBlock1 = "World!"; + private static final String sortBlock0 = "Hello!"; + private static final String sortBlock1 = "World!"; - static String hashBlock0 = "Elementary"; - static String hashBlock1 = "Tabular"; + private static final String hashBlock0 = "Elementary"; + private static final String hashBlock1 = "Tabular"; - static TestShuffleDataContext dataContext; + private static TestShuffleDataContext dataContext; - static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + private static final TransportConf conf = + new TransportConf("shuffle", new SystemPropertyConfigProvider()); @BeforeClass public static void beforeAll() throws IOException { @@ -50,10 +52,12 @@ public static void beforeAll() throws IOException { dataContext.create(); // Write some sort and hash data. - dataContext.insertSortShuffleData(0, 0, - new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); - dataContext.insertHashShuffleData(1, 0, - new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); + dataContext.insertSortShuffleData(0, 0, new byte[][] { + sortBlock0.getBytes(StandardCharsets.UTF_8), + sortBlock1.getBytes(StandardCharsets.UTF_8)}); + dataContext.insertHashShuffleData(1, 0, new byte[][] { + hashBlock0.getBytes(StandardCharsets.UTF_8), + hashBlock1.getBytes(StandardCharsets.UTF_8)}); } @AfterClass @@ -83,7 +87,7 @@ public void testBadRequests() throws IOException { // Nonexistent shuffle block resolver.registerExecutor("app0", "exec3", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); try { resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); fail("Should have failed"); @@ -96,17 +100,19 @@ public void testBadRequests() throws IOException { public void testSortShuffleBlocks() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); - String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + String block0 = CharStreams.toString( + new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); block0Stream.close(); assertEquals(sortBlock0, block0); InputStream block1Stream = resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); - String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + String block1 = CharStreams.toString( + new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); block1Stream.close(); assertEquals(sortBlock1, block1); } @@ -115,17 +121,19 @@ public void testSortShuffleBlocks() throws IOException { public void testHashShuffleBlocks() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + dataContext.createExecutorInfo("hash")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); - String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + String block0 = CharStreams.toString( + new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); block0Stream.close(); assertEquals(hashBlock0, block0); InputStream block1Stream = resolver.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); - String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + String block1 = CharStreams.toString( + new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); block1Stream.close(); assertEquals(hashBlock1, block1); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java similarity index 87% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 2f4f1d0df478b..43d0201405872 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Random; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; @@ -34,8 +35,8 @@ public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. - Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + private TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { @@ -123,27 +124,29 @@ public void cleanupOnlyRemovedApp() throws IOException { assertCleanedUp(dataContext1); } - private void assertStillThere(TestShuffleDataContext dataContext) { + private static void assertStillThere(TestShuffleDataContext dataContext) { for (String localDir : dataContext.localDirs) { assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); } } - private void assertCleanedUp(TestShuffleDataContext dataContext) { + private static void assertCleanedUp(TestShuffleDataContext dataContext) { for (String localDir : dataContext.localDirs) { assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); } } - private TestShuffleDataContext createSomeData() throws IOException { + private static TestShuffleDataContext createSomeData() throws IOException { Random rand = new Random(123); TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); dataContext.create(); - dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), - new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); - dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, - new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { + "ABC".getBytes(StandardCharsets.UTF_8), + "DEF".getBytes(StandardCharsets.UTF_8)}); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, new byte[][] { + "GHI".getBytes(StandardCharsets.UTF_8), + "JKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8)}); return dataContext; } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java similarity index 97% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a3f9a38b1aeb9..ecbbe7bfa3b11 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -49,8 +49,8 @@ public class ExternalShuffleIntegrationSuite { static String APP_ID = "app-id"; - static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; - static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + static String SORT_MANAGER = "sort"; + static String HASH_MANAGER = "hash"; // Executor 0 is sort-based static TestShuffleDataContext dataContext0; @@ -91,7 +91,7 @@ public static void beforeAll() throws IOException { dataContext1.create(); dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); @@ -109,7 +109,7 @@ public void afterEach() { handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); } - class FetchResult { + static class FetchResult { public Set successBlocks; public Set failedBlocks; public List buffers; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java similarity index 96% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index aa99efda94948..acc1168f83354 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -30,7 +30,6 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -39,7 +38,7 @@ public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); TransportServer server; @Before diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java similarity index 97% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index b35a6d685dd02..2590b9ce4c1f1 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -134,14 +134,14 @@ private BlockFetchingListener fetchBlocks(final LinkedHashMap() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( - (byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } - }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); // Respond to each chunk request with a single buffer from our blocks array. final AtomicInteger expectedChunkIndex = new AtomicInteger(0); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java similarity index 99% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 06e46f9241094..91882e3b3bcd5 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -254,7 +254,7 @@ private static void performInteractions(List> inte BlockFetchingListener listener) throws IOException { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; @@ -305,7 +305,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { } } - assert stub != null; + assertNotNull(stub); stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java similarity index 79% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java rename to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 3fdde054ab6c7..7ac1ca128aed0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.OutputStream; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -60,21 +61,28 @@ public void cleanup() { public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; - OutputStream dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + OutputStream dataStream = null; + DataOutputStream indexStream = null; + boolean suppressExceptionsDuringClose = true; - long offset = 0; - indexStream.writeLong(offset); - for (byte[] block : blocks) { - offset += block.length; - dataStream.write(block); + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + suppressExceptionsDuringClose = false; + } finally { + Closeables.close(dataStream, suppressExceptionsDuringClose); + Closeables.close(indexStream, suppressExceptionsDuringClose); } - - dataStream.close(); - indexStream.close(); } /** Creates reducer blocks in a hash-based data format within our local dirs. */ diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml new file mode 100644 index 0000000000000..bc83ef24c30ec --- /dev/null +++ b/common/network-yarn/pom.xml @@ -0,0 +1,148 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-yarn_2.11 + jar + Spark Project YARN Shuffle Service + http://spark.apache.org/ + + network-yarn + + provided + ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar + org/spark_project/ + + + + + + org.apache.spark + spark-network-shuffle_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + + + org.apache.hadoop + hadoop-client + + + org.slf4j + slf4j-api + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${shuffle.jar} + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + com.fasterxml.jackson + ${spark.shade.packageName}.com.fasterxml.jackson + + com.fasterxml.jackson.** + + + + + + + package + + shade + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + + + verify + + run + + + + + + + + + + + + + + + + + + Verifying dependency shading + + + + + + + + + + + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java similarity index 95% rename from network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java rename to common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 11ea7f3fd3cfe..4bc3c1a3c8a64 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -24,6 +24,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.server.api.*; import org.slf4j.Logger; @@ -118,9 +119,9 @@ protected void serviceInit(Configuration conf) { // an application was stopped while the NM was down, we expect yarn to call stopApplication() // when it comes back registeredExecutorFile = - findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); + findRegisteredExecutorFile(conf.getTrimmedStrings("yarn.nodemanager.local-dirs")); - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); @@ -191,12 +192,12 @@ public void stopContainer(ContainerTerminationContext context) { private File findRegisteredExecutorFile(String[] localDirs) { for (String dir: localDirs) { - File f = new File(dir, "registeredExecutors.ldb"); + File f = new File(new Path(dir).toUri().getPath(), "registeredExecutors.ldb"); if (f.exists()) { return f; } } - return new File(localDirs[0], "registeredExecutors.ldb"); + return new File(new Path(localDirs[0]).toUri().getPath(), "registeredExecutors.ldb"); } /** diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java similarity index 100% rename from network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java rename to common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml new file mode 100644 index 0000000000000..8bc1f52798941 --- /dev/null +++ b/common/sketch/pom.xml @@ -0,0 +1,73 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-sketch_2.11 + jar + Spark Project Sketch + http://spark.apache.org/ + + sketch + + + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + net.alchim31.maven + scala-maven-plugin + + + + -XDignore.symbol.file + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + -XDignore.symbol.file + + + + + + + diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java new file mode 100644 index 0000000000000..480a0a79db32d --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Arrays; + +final class BitArray { + private final long[] data; + private long bitCount; + + static int numWords(long numBits) { + if (numBits <= 0) { + throw new IllegalArgumentException("numBits must be positive, but got " + numBits); + } + long numWords = (long) Math.ceil(numBits / 64.0); + if (numWords > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits"); + } + return (int) numWords; + } + + BitArray(long numBits) { + this(new long[numWords(numBits)]); + } + + private BitArray(long[] data) { + this.data = data; + long bitCount = 0; + for (long word : data) { + bitCount += Long.bitCount(word); + } + this.bitCount = bitCount; + } + + /** Returns true if the bit changed value. */ + boolean set(long index) { + if (!get(index)) { + data[(int) (index >>> 6)] |= (1L << index); + bitCount++; + return true; + } + return false; + } + + boolean get(long index) { + return (data[(int) (index >>> 6)] & (1L << index)) != 0; + } + + /** Number of bits */ + long bitSize() { + return (long) data.length * Long.SIZE; + } + + /** Number of set bits (1s) */ + long cardinality() { + return bitCount; + } + + /** Combines the two BitArrays using bitwise OR. */ + void putAll(BitArray array) { + assert data.length == array.data.length : "BitArrays must be of equal length when merging"; + long bitCount = 0; + for (int i = 0; i < data.length; i++) { + data[i] |= array.data[i]; + bitCount += Long.bitCount(data[i]); + } + this.bitCount = bitCount; + } + + void writeTo(DataOutputStream out) throws IOException { + out.writeInt(data.length); + for (long datum : data) { + out.writeLong(datum); + } + } + + static BitArray readFrom(DataInputStream in) throws IOException { + int numWords = in.readInt(); + long[] data = new long[numWords]; + for (int i = 0; i < numWords; i++) { + data[i] = in.readLong(); + } + return new BitArray(data); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || !(other instanceof BitArray)) return false; + BitArray that = (BitArray) other; + return Arrays.equals(data, that.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java new file mode 100644 index 0000000000000..c0b425e729595 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A Bloom filter is a space-efficient probabilistic data structure that offers an approximate + * containment test with one-sided error: if it claims that an item is contained in it, this + * might be in error, but if it claims that an item is not contained in it, then this is + * definitely true. Currently supported data types include: + *

    + *
  • {@link Byte}
  • + *
  • {@link Short}
  • + *
  • {@link Integer}
  • + *
  • {@link Long}
  • + *
  • {@link String}
  • + *
+ * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that + * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that hasu + * not actually been put in the {@code BloomFilter}. + * + * The implementation is largely based on the {@code BloomFilter} class from Guava. + */ +public abstract class BloomFilter { + + public enum Version { + /** + * {@code BloomFilter} binary format version 1. All values written in big-endian order: + *
    + *
  • Version number, always 1 (32 bit)
  • + *
  • Number of hash functions (32 bit)
  • + *
  • Total number of words of the underlying bit array (32 bit)
  • + *
  • The words/longs (numWords * 64 bit)
  • + *
+ */ + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + int getVersionNumber() { + return versionNumber; + } + } + + /** + * Returns the probability that {@linkplain #mightContain(Object)} erroneously return {@code true} + * for an object that has not actually been put in the {@code BloomFilter}. + * + * Ideally, this number should be close to the {@code fpp} parameter passed in + * {@linkplain #create(long, double)}, or smaller. If it is significantly higher, it is usually + * the case that too many items (more than expected) have been put in the {@code BloomFilter}, + * degenerating it. + */ + public abstract double expectedFpp(); + + /** + * Returns the number of bits in the underlying bit array. + */ + public abstract long bitSize(); + + /** + * Puts an item into this {@code BloomFilter}. Ensures that subsequent invocations of + * {@linkplain #mightContain(Object)} with the same item will always return {@code true}. + * + * @return true if the bloom filter's bits changed as a result of this operation. If the bits + * changed, this is definitely the first time {@code object} has been added to the + * filter. If the bits haven't changed, this might be the first time {@code object} + * has been added to the filter. Note that {@code put(t)} always returns the + * opposite result to what {@code mightContain(t)} would have returned at the time + * it is called. + */ + public abstract boolean put(Object item); + + /** + * A specialized variant of {@link #put(Object)} that only supports {@code String} items. + */ + public abstract boolean putString(String item); + + /** + * A specialized variant of {@link #put(Object)} that only supports {@code long} items. + */ + public abstract boolean putLong(long item); + + /** + * A specialized variant of {@link #put(Object)} that only supports byte array items. + */ + public abstract boolean putBinary(byte[] item); + + /** + * Determines whether a given bloom filter is compatible with this bloom filter. For two + * bloom filters to be compatible, they must have the same bit size. + * + * @param other The bloom filter to check for compatibility. + */ + public abstract boolean isCompatible(BloomFilter other); + + /** + * Combines this bloom filter with another bloom filter by performing a bitwise OR of the + * underlying data. The mutations happen to this instance. Callers must ensure the + * bloom filters are appropriately sized to avoid saturating them. + * + * @param other The bloom filter to combine this bloom filter with. It is not mutated. + * @throws IncompatibleMergeException if {@code isCompatible(other) == false} + */ + public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException; + + /** + * Returns {@code true} if the element might have been put in this Bloom filter, + * {@code false} if this is definitely not the case. + */ + public abstract boolean mightContain(Object item); + + /** + * A specialized variant of {@link #mightContain(Object)} that only tests {@code String} items. + */ + public abstract boolean mightContainString(String item); + + /** + * A specialized variant of {@link #mightContain(Object)} that only tests {@code long} items. + */ + public abstract boolean mightContainLong(long item); + + /** + * A specialized variant of {@link #mightContain(Object)} that only tests byte array items. + */ + public abstract boolean mightContainBinary(byte[] item); + + /** + * Writes out this {@link BloomFilter} to an output stream in binary format. It is the caller's + * responsibility to close the stream. + */ + public abstract void writeTo(OutputStream out) throws IOException; + + /** + * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close + * the stream. + */ + public static BloomFilter readFrom(InputStream in) throws IOException { + return BloomFilterImpl.readFrom(in); + } + + /** + * Computes the optimal k (number of hashes per item inserted in Bloom filter), given the + * expected insertions and total number of bits in the Bloom filter. + * + * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula. + * + * @param n expected insertions (must be positive) + * @param m total number of bits in Bloom filter (must be positive) + */ + private static int optimalNumOfHashFunctions(long n, long m) { + // (m / n) * log(2), but avoid truncation due to division! + return Math.max(1, (int) Math.round((double) m / n * Math.log(2))); + } + + /** + * Computes m (total bits of Bloom filter) which is expected to achieve, for the specified + * expected insertions, the required false positive probability. + * + * See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the formula. + * + * @param n expected insertions (must be positive) + * @param p false positive rate (must be 0 < p < 1) + */ + private static long optimalNumOfBits(long n, double p) { + return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); + } + + static final double DEFAULT_FPP = 0.03; + + /** + * Creates a {@link BloomFilter} with the expected number of insertions and a default expected + * false positive probability of 3%. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. + */ + public static BloomFilter create(long expectedNumItems) { + return create(expectedNumItems, DEFAULT_FPP); + } + + /** + * Creates a {@link BloomFilter} with the expected number of insertions and expected false + * positive probability. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. + */ + public static BloomFilter create(long expectedNumItems, double fpp) { + if (fpp <= 0D || fpp >= 1D) { + throw new IllegalArgumentException( + "False positive probability must be within range (0.0, 1.0)" + ); + } + + return create(expectedNumItems, optimalNumOfBits(expectedNumItems, fpp)); + } + + /** + * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code numBits}, it will + * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter. + */ + public static BloomFilter create(long expectedNumItems, long numBits) { + if (expectedNumItems <= 0) { + throw new IllegalArgumentException("Expected insertions must be positive"); + } + + if (numBits <= 0) { + throw new IllegalArgumentException("Number of bits must be positive"); + } + + return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java new file mode 100644 index 0000000000000..92c28bcb56a5a --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.*; + +class BloomFilterImpl extends BloomFilter implements Serializable { + + private int numHashFunctions; + + private BitArray bits; + + BloomFilterImpl(int numHashFunctions, long numBits) { + this(new BitArray(numBits), numHashFunctions); + } + + private BloomFilterImpl(BitArray bits, int numHashFunctions) { + this.bits = bits; + this.numHashFunctions = numHashFunctions; + } + + private BloomFilterImpl() {} + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits); + } + + @Override + public int hashCode() { + return bits.hashCode() * 31 + numHashFunctions; + } + + @Override + public double expectedFpp() { + return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions); + } + + @Override + public long bitSize() { + return bits.bitSize(); + } + + @Override + public boolean put(Object item) { + if (item instanceof String) { + return putString((String) item); + } else if (item instanceof byte[]) { + return putBinary((byte[]) item); + } else { + return putLong(Utils.integralToLong(item)); + } + } + + @Override + public boolean putString(String item) { + return putBinary(Utils.getBytesFromUTF8String(item)); + } + + @Override + public boolean putBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public boolean mightContainString(String item) { + return mightContainBinary(Utils.getBytesFromUTF8String(item)); + } + + @Override + public boolean mightContainBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + + @Override + public boolean putLong(long item) { + // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n + // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. + // Note that `CountMinSketch` use a different strategy, it hash the input long element with + // every i to produce n hash values. + // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public boolean mightContainLong(long item) { + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); + + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + + @Override + public boolean mightContain(Object item) { + if (item instanceof String) { + return mightContainString((String) item); + } else if (item instanceof byte[]) { + return mightContainBinary((byte[]) item); + } else { + return mightContainLong(Utils.integralToLong(item)); + } + } + + @Override + public boolean isCompatible(BloomFilter other) { + if (other == null) { + return false; + } + + if (!(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + return this.bitSize() == that.bitSize() && this.numHashFunctions == that.numHashFunctions; + } + + @Override + public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException { + // Duplicates the logic of `isCompatible` here to provide better error message. + if (other == null) { + throw new IncompatibleMergeException("Cannot merge null bloom filter"); + } + + if (!(other instanceof BloomFilterImpl)) { + throw new IncompatibleMergeException( + "Cannot merge bloom filter of class " + other.getClass().getName() + ); + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + if (this.bitSize() != that.bitSize()) { + throw new IncompatibleMergeException("Cannot merge bloom filters with different bit size"); + } + + if (this.numHashFunctions != that.numHashFunctions) { + throw new IncompatibleMergeException( + "Cannot merge bloom filters with different number of hash functions" + ); + } + + this.bits.putAll(that.bits); + return this; + } + + @Override + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(Version.V1.getVersionNumber()); + dos.writeInt(numHashFunctions); + bits.writeTo(dos); + } + + private void readFrom0(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Bloom filter version number (" + version + ")"); + } + + this.numHashFunctions = dis.readInt(); + this.bits = BitArray.readFrom(dis); + } + + public static BloomFilterImpl readFrom(InputStream in) throws IOException { + BloomFilterImpl filter = new BloomFilterImpl(); + filter.readFrom0(in); + return filter; + } + + private void writeObject(ObjectOutputStream out) throws IOException { + writeTo(out); + } + + private void readObject(ObjectInputStream in) throws IOException { + readFrom0(in); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java new file mode 100644 index 0000000000000..40fa20c4a3e37 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in + * sub-linear space. Currently, supported data types include: + *
    + *
  • {@link Byte}
  • + *
  • {@link Short}
  • + *
  • {@link Integer}
  • + *
  • {@link Long}
  • + *
  • {@link String}
  • + *
+ * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters: + *
    + *
  1. relative error (or {@code eps}), and + *
  2. confidence (or {@code delta}) + *
+ * Suppose you want to estimate the number of times an element {@code x} has appeared in a data + * stream so far. With probability {@code delta}, the estimate of this frequency is within the + * range {@code true frequency <= estimate <= true frequency + eps * N}, where {@code N} is the + * total count of items have appeared the data stream so far. + * + * Under the cover, a {@link CountMinSketch} is essentially a two-dimensional {@code long} array + * with depth {@code d} and width {@code w}, where + *
    + *
  • {@code d = ceil(2 / eps)}
  • + *
  • {@code w = ceil(-log(1 - confidence) / log(2))}
  • + *
+ * + * This implementation is largely based on the {@code CountMinSketch} class from stream-lib. + */ +public abstract class CountMinSketch { + + public enum Version { + /** + * {@code CountMinSketch} binary format version 1. All values written in big-endian order: + *
    + *
  • Version number, always 1 (32 bit)
  • + *
  • Total count of added items (64 bit)
  • + *
  • Depth (32 bit)
  • + *
  • Width (32 bit)
  • + *
  • Hash functions (depth * 64 bit)
  • + *
  • + * Count table + *
      + *
    • Row 0 (width * 64 bit)
    • + *
    • Row 1 (width * 64 bit)
    • + *
    • ...
    • + *
    • Row {@code depth - 1} (width * 64 bit)
    • + *
    + *
  • + *
+ */ + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + int getVersionNumber() { + return versionNumber; + } + } + + /** + * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}. + */ + public abstract double relativeError(); + + /** + * Returns the confidence (or {@code delta}) of this {@link CountMinSketch}. + */ + public abstract double confidence(); + + /** + * Depth of this {@link CountMinSketch}. + */ + public abstract int depth(); + + /** + * Width of this {@link CountMinSketch}. + */ + public abstract int width(); + + /** + * Total count of items added to this {@link CountMinSketch} so far. + */ + public abstract long totalCount(); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void add(Object item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void add(Object item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addLong(long item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addLong(long item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addString(String item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addString(String item, long count); + + /** + * Increments {@code item}'s count by one. + */ + public abstract void addBinary(byte[] item); + + /** + * Increments {@code item}'s count by {@code count}. + */ + public abstract void addBinary(byte[] item, long count); + + /** + * Returns the estimated frequency of {@code item}. + */ + public abstract long estimateCount(Object item); + + /** + * Merges another {@link CountMinSketch} with this one in place. + * + * Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed + * can be merged. + * + * @exception IncompatibleMergeException if the {@code other} {@link CountMinSketch} has + * incompatible depth, width, relative-error, confidence, or random seed. + */ + public abstract CountMinSketch mergeInPlace(CountMinSketch other) + throws IncompatibleMergeException; + + /** + * Writes out this {@link CountMinSketch} to an output stream in binary format. It is the caller's + * responsibility to close the stream. + */ + public abstract void writeTo(OutputStream out) throws IOException; + + /** + * Reads in a {@link CountMinSketch} from an input stream. It is the caller's responsibility to + * close the stream. + */ + public static CountMinSketch readFrom(InputStream in) throws IOException { + return CountMinSketchImpl.readFrom(in); + } + + /** + * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random + * {@code seed}. + * + * @param depth depth of the Count-min Sketch, must be positive + * @param width width of the Count-min Sketch, must be positive + * @param seed random seed + */ + public static CountMinSketch create(int depth, int width, int seed) { + return new CountMinSketchImpl(depth, width, seed); + } + + /** + * Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence}, + * and random {@code seed}. + * + * @param eps relative error, must be positive + * @param confidence confidence, must be positive and less than 1.0 + * @param seed random seed + */ + public static CountMinSketch create(double eps, double confidence, int seed) { + return new CountMinSketchImpl(eps, confidence, seed); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java new file mode 100644 index 0000000000000..2acbb247b13cd --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Random; + +class CountMinSketchImpl extends CountMinSketch implements Serializable { + private static final long PRIME_MODULUS = (1L << 31) - 1; + + private int depth; + private int width; + private long[][] table; + private long[] hashA; + private long totalCount; + private double eps; + private double confidence; + + private CountMinSketchImpl() {} + + CountMinSketchImpl(int depth, int width, int seed) { + if (depth <= 0 || width <= 0) { + throw new IllegalArgumentException("Depth and width must be both positive"); + } + + this.depth = depth; + this.width = width; + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); + initTablesWith(depth, width, seed); + } + + CountMinSketchImpl(double eps, double confidence, int seed) { + if (eps <= 0D) { + throw new IllegalArgumentException("Relative error must be positive"); + } + + if (confidence <= 0D || confidence >= 1D) { + throw new IllegalArgumentException("Confidence must be within range (0.0, 1.0)"); + } + + // 2/w = eps ; w = 2/eps + // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence) + this.eps = eps; + this.confidence = confidence; + this.width = (int) Math.ceil(2 / eps); + this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2)); + initTablesWith(depth, width, seed); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof CountMinSketchImpl)) { + return false; + } + + CountMinSketchImpl that = (CountMinSketchImpl) other; + + return + this.depth == that.depth && + this.width == that.width && + this.totalCount == that.totalCount && + Arrays.equals(this.hashA, that.hashA) && + Arrays.deepEquals(this.table, that.table); + } + + @Override + public int hashCode() { + int hash = depth; + + hash = hash * 31 + width; + hash = hash * 31 + (int) (totalCount ^ (totalCount >>> 32)); + hash = hash * 31 + Arrays.hashCode(hashA); + hash = hash * 31 + Arrays.deepHashCode(table); + + return hash; + } + + private void initTablesWith(int depth, int width, int seed) { + this.table = new long[depth][width]; + this.hashA = new long[depth]; + Random r = new Random(seed); + // We're using a linear hash functions + // of the form (a*x+b) mod p. + // a,b are chosen independently for each hash function. + // However we can set b = 0 as all it does is shift the results + // without compromising their uniformity or independence with + // the other hashes. + for (int i = 0; i < depth; ++i) { + hashA[i] = r.nextInt(Integer.MAX_VALUE); + } + } + + @Override + public double relativeError() { + return eps; + } + + @Override + public double confidence() { + return confidence; + } + + @Override + public int depth() { + return depth; + } + + @Override + public int width() { + return width; + } + + @Override + public long totalCount() { + return totalCount; + } + + @Override + public void add(Object item) { + add(item, 1); + } + + @Override + public void add(Object item, long count) { + if (item instanceof String) { + addString((String) item, count); + } else { + addLong(Utils.integralToLong(item), count); + } + } + + @Override + public void addString(String item) { + addString(item, 1); + } + + @Override + public void addString(String item, long count) { + addBinary(Utils.getBytesFromUTF8String(item), count); + } + + @Override + public void addLong(long item) { + addLong(item, 1); + } + + @Override + public void addLong(long item, long count) { + if (count < 0) { + throw new IllegalArgumentException("Negative increments not implemented"); + } + + for (int i = 0; i < depth; ++i) { + table[i][hash(item, i)] += count; + } + + totalCount += count; + } + + @Override + public void addBinary(byte[] item) { + addBinary(item, 1); + } + + @Override + public void addBinary(byte[] item, long count) { + if (count < 0) { + throw new IllegalArgumentException("Negative increments not implemented"); + } + + int[] buckets = getHashBuckets(item, depth, width); + + for (int i = 0; i < depth; ++i) { + table[i][buckets[i]] += count; + } + + totalCount += count; + } + + private int hash(long item, int count) { + long hash = hashA[count] * item; + // A super fast way of computing x mod 2^p-1 + // See http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf + // page 149, right after Proposition 7. + hash += hash >> 32; + hash &= PRIME_MODULUS; + // Doing "%" after (int) conversion is ~2x faster than %'ing longs. + return ((int) hash) % width; + } + + private static int[] getHashBuckets(String key, int hashCount, int max) { + return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); + } + + private static int[] getHashBuckets(byte[] b, int hashCount, int max) { + int[] result = new int[hashCount]; + int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0); + int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1); + for (int i = 0; i < hashCount; i++) { + result[i] = Math.abs((hash1 + i * hash2) % max); + } + return result; + } + + @Override + public long estimateCount(Object item) { + if (item instanceof String) { + return estimateCountForStringItem((String) item); + } else { + return estimateCountForLongItem(Utils.integralToLong(item)); + } + } + + private long estimateCountForLongItem(long item) { + long res = Long.MAX_VALUE; + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][hash(item, i)]); + } + return res; + } + + private long estimateCountForStringItem(String item) { + long res = Long.MAX_VALUE; + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][buckets[i]]); + } + return res; + } + + @Override + public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException { + if (other == null) { + throw new IncompatibleMergeException("Cannot merge null estimator"); + } + + if (!(other instanceof CountMinSketchImpl)) { + throw new IncompatibleMergeException( + "Cannot merge estimator of class " + other.getClass().getName() + ); + } + + CountMinSketchImpl that = (CountMinSketchImpl) other; + + if (this.depth != that.depth) { + throw new IncompatibleMergeException("Cannot merge estimators of different depth"); + } + + if (this.width != that.width) { + throw new IncompatibleMergeException("Cannot merge estimators of different width"); + } + + if (!Arrays.equals(this.hashA, that.hashA)) { + throw new IncompatibleMergeException("Cannot merge estimators of different seed"); + } + + for (int i = 0; i < this.table.length; ++i) { + for (int j = 0; j < this.table[i].length; ++j) { + this.table[i][j] = this.table[i][j] + that.table[i][j]; + } + } + + this.totalCount += that.totalCount; + + return this; + } + + @Override + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(Version.V1.getVersionNumber()); + + dos.writeLong(this.totalCount); + dos.writeInt(this.depth); + dos.writeInt(this.width); + + for (int i = 0; i < this.depth; ++i) { + dos.writeLong(this.hashA[i]); + } + + for (int i = 0; i < this.depth; ++i) { + for (int j = 0; j < this.width; ++j) { + dos.writeLong(table[i][j]); + } + } + } + + public static CountMinSketchImpl readFrom(InputStream in) throws IOException { + CountMinSketchImpl sketch = new CountMinSketchImpl(); + sketch.readFrom0(in); + return sketch; + } + + private void readFrom0(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")"); + } + + this.totalCount = dis.readLong(); + this.depth = dis.readInt(); + this.width = dis.readInt(); + this.eps = 2.0 / width; + this.confidence = 1 - 1 / Math.pow(2, depth); + + this.hashA = new long[depth]; + for (int i = 0; i < depth; ++i) { + this.hashA[i] = dis.readLong(); + } + + this.table = new long[depth][width]; + for (int i = 0; i < depth; ++i) { + for (int j = 0; j < width; ++j) { + this.table[i][j] = dis.readLong(); + } + } + } + + private void writeObject(ObjectOutputStream out) throws IOException { + this.writeTo(out); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + this.readFrom0(in); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java new file mode 100644 index 0000000000000..64b567caa57c1 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/IncompatibleMergeException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +public class IncompatibleMergeException extends Exception { + public IncompatibleMergeException(String message) { + super(message); + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java new file mode 100644 index 0000000000000..a61ce4fb7241d --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +// This class is duplicated from `org.apache.spark.unsafe.hash.Murmur3_x86_32` to make sure +// spark-sketch has no external dependencies. +final class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + private final int seed; + + Murmur3_x86_32(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_32(seed=" + seed + ")"; + } + + public int hashInt(int input) { + return hashInt(input, seed); + } + + public static int hashInt(int input, int seed) { + int k1 = mixK1(input); + int h1 = mixH1(seed, k1); + + return fmix(h1, 4); + } + + public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + return fmix(h1, lengthInBytes); + } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + for (int i = lengthAligned; i < lengthInBytes; i++) { + int halfWord = Platform.getByte(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 4 == 0); + int h1 = seed; + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = Platform.getInt(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return h1; + } + + public int hashLong(long input) { + return hashLong(input, seed); + } + + public static int hashLong(long input, int seed) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java new file mode 100644 index 0000000000000..75d6a6beec408 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +// This class is duplicated from `org.apache.spark.unsafe.Platform` to make sure spark-sketch has no +// external dependencies. +final class Platform { + + private static final Unsafe _UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } + + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } + + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } + + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } + + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } + + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } + + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } + + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } + + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } + + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } + + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } + + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } + + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); + } + + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } + + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } + + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. + if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + _UNSAFE = unsafe; + + if (_UNSAFE != null) { + BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } +} diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java new file mode 100644 index 0000000000000..81461f03000a6 --- /dev/null +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Utils.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.nio.charset.StandardCharsets; + +class Utils { + public static byte[] getBytesFromUTF8String(String str) { + return str.getBytes(StandardCharsets.UTF_8); + } + + public static long integralToLong(Object i) { + long longValue; + + if (i instanceof Long) { + longValue = (Long) i; + } else if (i instanceof Integer) { + longValue = ((Integer) i).longValue(); + } else if (i instanceof Short) { + longValue = ((Short) i).longValue(); + } else if (i instanceof Byte) { + longValue = ((Byte) i).longValue(); + } else { + throw new IllegalArgumentException("Unsupported data type " + i.getClass().getName()); + } + + return longValue; + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala new file mode 100644 index 0000000000000..ff728f0ebcb85 --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch + +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class BitArraySuite extends FunSuite { // scalastyle:ignore funsuite + + test("error case when create BitArray") { + intercept[IllegalArgumentException](new BitArray(0)) + intercept[IllegalArgumentException](new BitArray(64L * Integer.MAX_VALUE + 1)) + } + + test("bitSize") { + assert(new BitArray(64).bitSize() == 64) + // BitArray is word-aligned, so 65~128 bits need 2 long to store, which is 128 bits. + assert(new BitArray(65).bitSize() == 128) + assert(new BitArray(127).bitSize() == 128) + assert(new BitArray(128).bitSize() == 128) + } + + test("set") { + val bitArray = new BitArray(64) + assert(bitArray.set(1)) + // Only returns true if the bit changed. + assert(!bitArray.set(1)) + assert(bitArray.set(2)) + } + + test("normal operation") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val bitArray = new BitArray(320) + val indexes = (1 to 100).map(_ => r.nextInt(320).toLong).distinct + + indexes.foreach(bitArray.set) + indexes.foreach(i => assert(bitArray.get(i))) + assert(bitArray.cardinality() == indexes.length) + } + + test("merge") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val bitArray1 = new BitArray(64 * 6) + val bitArray2 = new BitArray(64 * 6) + + val indexes1 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct + val indexes2 = (1 to 100).map(_ => r.nextInt(64 * 6).toLong).distinct + + indexes1.foreach(bitArray1.set) + indexes2.foreach(bitArray2.set) + + bitArray1.putAll(bitArray2) + indexes1.foreach(i => assert(bitArray1.get(i))) + indexes2.foreach(i => assert(bitArray1.get(i))) + assert(bitArray1.cardinality() == (indexes1 ++ indexes2).distinct.length) + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala new file mode 100644 index 0000000000000..a0408d2da4dff --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite + private final val EPSILON = 0.01 + + // Serializes and deserializes a given `BloomFilter`, then checks whether the deserialized + // version is equivalent to the original one. + private def checkSerDe(filter: BloomFilter): Unit = { + val out = new ByteArrayOutputStream() + filter.writeTo(out) + out.close() + + val in = new ByteArrayInputStream(out.toByteArray) + val deserialized = BloomFilter.readFrom(in) + in.close() + + assert(filter == deserialized) + } + + def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + test(s"accuracy - $typeName") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + val fpp = 0.05 + val numInsertion = numItems / 10 + + val allItems = Array.fill(numItems)(itemGen(r)) + + val filter = BloomFilter.create(numInsertion, fpp) + + // insert first `numInsertion` items. + allItems.take(numInsertion).foreach(filter.put) + + // false negative is not allowed. + assert(allItems.take(numInsertion).forall(filter.mightContain)) + + // The number of inserted items doesn't exceed `expectedNumItems`, so the `expectedFpp` + // should not be significantly higher than the one we passed in to create this bloom filter. + assert(filter.expectedFpp() - fpp < EPSILON) + + val errorCount = allItems.drop(numInsertion).count(filter.mightContain) + + // Also check the actual fpp is not significantly higher than we expected. + val actualFpp = errorCount.toDouble / (numItems - numInsertion) + assert(actualFpp - fpp < EPSILON) + + checkSerDe(filter) + } + } + + def testMergeInPlace[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + test(s"mergeInPlace - $typeName") { + // use a fixed seed to make the test predictable. + val r = new Random(37) + + val items1 = Array.fill(numItems / 2)(itemGen(r)) + val items2 = Array.fill(numItems / 2)(itemGen(r)) + + val filter1 = BloomFilter.create(numItems) + items1.foreach(filter1.put) + + val filter2 = BloomFilter.create(numItems) + items2.foreach(filter2.put) + + filter1.mergeInPlace(filter2) + + // After merge, `filter1` has `numItems` items which doesn't exceed `expectedNumItems`, so the + // `expectedFpp` should not be significantly higher than the default one. + assert(filter1.expectedFpp() - BloomFilter.DEFAULT_FPP < EPSILON) + + items1.foreach(i => assert(filter1.mightContain(i))) + items2.foreach(i => assert(filter1.mightContain(i))) + + checkSerDe(filter1) + } + } + + def testItemType[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = { + testAccuracy[T](typeName, numItems)(itemGen) + testMergeInPlace[T](typeName, numItems)(itemGen) + } + + testItemType[Byte]("Byte", 160) { _.nextInt().toByte } + + testItemType[Short]("Short", 1000) { _.nextInt().toShort } + + testItemType[Int]("Int", 100000) { _.nextInt() } + + testItemType[Long]("Long", 100000) { _.nextLong() } + + testItemType[String]("String", 100000) { r => r.nextString(r.nextInt(512)) } + + test("incompatible merge") { + intercept[IncompatibleMergeException] { + BloomFilter.create(1000).mergeInPlace(null) + } + + intercept[IncompatibleMergeException] { + val filter1 = BloomFilter.create(1000, 6400) + val filter2 = BloomFilter.create(1000, 3200) + filter1.mergeInPlace(filter2) + } + + intercept[IncompatibleMergeException] { + val filter1 = BloomFilter.create(1000, 6400) + val filter2 = BloomFilter.create(2000, 6400) + filter1.mergeInPlace(filter2) + } + } +} diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala new file mode 100644 index 0000000000000..b9c7f5c23a8fe --- /dev/null +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite + private val epsOfTotalCount = 0.0001 + + private val confidence = 0.99 + + private val seed = 42 + + // Serializes and deserializes a given `CountMinSketch`, then checks whether the deserialized + // version is equivalent to the original one. + private def checkSerDe(sketch: CountMinSketch): Unit = { + val out = new ByteArrayOutputStream() + sketch.writeTo(out) + + val in = new ByteArrayInputStream(out.toByteArray) + val deserialized = CountMinSketch.readFrom(in) + + assert(sketch === deserialized) + } + + def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + test(s"accuracy - $typeName") { + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) + + val numAllItems = 1000000 + val allItems = Array.fill(numAllItems)(itemGenerator(r)) + + val numSamples = numAllItems / 10 + val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems)) + + val exactFreq = { + val sampledItems = sampledItemIndices.map(allItems) + sampledItems.groupBy(identity).mapValues(_.length.toLong) + } + + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + checkSerDe(sketch) + + sampledItemIndices.foreach(i => sketch.add(allItems(i))) + checkSerDe(sketch) + + val probCorrect = { + val numErrors = allItems.map { item => + val count = exactFreq.getOrElse(item, 0L) + val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems + if (ratio > epsOfTotalCount) 1 else 0 + }.sum + + 1D - numErrors.toDouble / numAllItems + } + + assert( + probCorrect > confidence, + s"Confidence not reached: required $confidence, reached $probCorrect" + ) + } + } + + def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + test(s"mergeInPlace - $typeName") { + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) + + val numToMerge = 5 + val numItemsPerSketch = 100000 + val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { + itemGenerator(r) + } + + val sketches = perSketchItems.map { items => + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + checkSerDe(sketch) + + items.foreach(sketch.add) + checkSerDe(sketch) + + sketch + } + + val mergedSketch = sketches.reduce(_ mergeInPlace _) + checkSerDe(mergedSketch) + + val expectedSketch = { + val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + perSketchItems.foreach(_.foreach(sketch.add)) + sketch + } + + perSketchItems.foreach { + _.foreach { item => + assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item)) + } + } + } + } + + def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { + testAccuracy[T](typeName)(itemGenerator) + testMergeInPlace[T](typeName)(itemGenerator) + } + + testItemType[Byte]("Byte") { _.nextInt().toByte } + + testItemType[Short]("Short") { _.nextInt().toShort } + + testItemType[Int]("Int") { _.nextInt() } + + testItemType[Long]("Long") { _.nextLong() } + + testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } + + test("incompatible merge") { + intercept[IncompatibleMergeException] { + CountMinSketch.create(10, 10, 1).mergeInPlace(null) + } + + intercept[IncompatibleMergeException] { + val sketch1 = CountMinSketch.create(10, 20, 1) + val sketch2 = CountMinSketch.create(10, 20, 2) + sketch1.mergeInPlace(sketch2) + } + + intercept[IncompatibleMergeException] { + val sketch1 = CountMinSketch.create(10, 10, 1) + val sketch2 = CountMinSketch.create(10, 20, 2) + sketch1.mergeInPlace(sketch2) + } + } +} diff --git a/common/tags/README.md b/common/tags/README.md new file mode 100644 index 0000000000000..01e5126945eb7 --- /dev/null +++ b/common/tags/README.md @@ -0,0 +1 @@ +This module includes annotations in Java that are used to annotate test suites. diff --git a/common/tags/pom.xml b/common/tags/pom.xml new file mode 100644 index 0000000000000..8e702b4fefe8c --- /dev/null +++ b/common/tags/pom.xml @@ -0,0 +1,50 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-test-tags_2.11 + jar + Spark Project Test Tags + http://spark.apache.org/ + + test-tags + + + + + org.scalatest + scalatest_${scala.binary.version} + compile + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java new file mode 100644 index 0000000000000..0fecf3b8f979a --- /dev/null +++ b/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface DockerTest { } diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java similarity index 99% rename from tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java index 1b0c416b0fe4e..83279e5e93c0e 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java +++ b/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java similarity index 99% rename from tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java index 2a631bfc88cf0..108300168e173 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java +++ b/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml new file mode 100644 index 0000000000000..93b9580f26b86 --- /dev/null +++ b/common/unsafe/pom.xml @@ -0,0 +1,110 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-unsafe_2.11 + jar + Spark Project Unsafe + http://spark.apache.org/ + + unsafe + + + + + com.twitter + chill_${scala.binary.version} + + + + + com.google.code.findbugs + jsr305 + + + com.google.guava + guava + + + + + org.slf4j + slf4j-api + provided + + + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + org.mockito + mockito-core + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.commons + commons-lang3 + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + net.alchim31.maven + scala-maven-plugin + + + + -XDignore.symbol.file + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + -XDignore.symbol.file + + + + + + + diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java new file mode 100644 index 0000000000000..bdf52f32c6fe1 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import sun.misc.Cleaner; +import sun.misc.Unsafe; + +public final class Platform { + + private static final Unsafe _UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int SHORT_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int FLOAT_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + private static final boolean unaligned; + static { + boolean _unaligned; + // use reflection to access unaligned field + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + unalignedMethod.setAccessible(true); + _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); + } catch (Throwable t) { + // We at least know x86 and x64 support unaligned access. + String arch = System.getProperty("os.arch", ""); + //noinspection DynamicRegexReplaceableByCompiledPattern + _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64)$"); + } + unaligned = _unaligned; + } + + /** + * @return true when running JVM is having sun's Unsafe package available in it and underlying + * system having unaligned-access capability. + */ + public static boolean unaligned() { + return unaligned; + } + + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } + + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } + + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } + + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } + + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } + + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } + + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } + + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } + + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } + + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } + + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } + + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } + + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); + } + + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } + + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } + + public static long reallocateMemory(long address, long oldSize, long newSize) { + long newMemory = _UNSAFE.allocateMemory(newSize); + copyMemory(null, address, null, newMemory, oldSize); + freeMemory(address); + return newMemory; + } + + /** + * Uses internal JDK APIs to allocate a DirectByteBuffer while ignoring the JVM's + * MaxDirectMemorySize limit (the default limit is too low and we do not want to require users + * to increase it). + */ + @SuppressWarnings("unchecked") + public static ByteBuffer allocateDirectBuffer(int size) { + try { + Class cls = Class.forName("java.nio.DirectByteBuffer"); + Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE); + constructor.setAccessible(true); + Field cleanerField = cls.getDeclaredField("cleaner"); + cleanerField.setAccessible(true); + final long memory = allocateMemory(size); + ByteBuffer buffer = (ByteBuffer) constructor.newInstance(memory, size); + Cleaner cleaner = Cleaner.create(buffer, new Runnable() { + @Override + public void run() { + freeMemory(memory); + } + }); + cleanerField.set(buffer, cleaner); + return buffer; + } catch (Exception e) { + throwException(e); + } + throw new IllegalStateException("unreachable"); + } + + public static void setMemory(long address, byte value, long size) { + _UNSAFE.setMemory(address, size, value); + } + + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. + if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + _UNSAFE = unsafe; + + if (_UNSAFE != null) { + BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class); + INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class); + DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + SHORT_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + FLOAT_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java similarity index 88% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 74105050e4191..1a3cdff638264 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -39,7 +39,6 @@ public final class LongArray { private final long length; public LongArray(MemoryBlock memory) { - assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; this.memory = memory; this.baseObj = memory.getBaseObject(); @@ -51,6 +50,14 @@ public MemoryBlock memoryBlock() { return memory; } + public Object getBaseObject() { + return baseObj; + } + + public long getBaseOffset() { + return baseOffset; + } + /** * Returns the number of elements this array can hold. */ @@ -58,6 +65,15 @@ public long size() { return length; } + /** + * Fill this all with 0L. + */ + public void zeroOut() { + for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { + Platform.putLong(baseObj, off, 0); + } + } + /** * Sets the value at position {@code index}. */ diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java similarity index 97% rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index 7857bf66a72ad..c8c57381f332f 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -87,7 +87,8 @@ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidt * To iterate over the true bits in a BitSet, use the following loop: *
    * 
-   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
+   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0;
+   *    i = bs.nextSetBit(i + 1, sizeInWords)) {
    *    // operate on index i here
    *  }
    * 
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
new file mode 100644
index 0000000000000..5e7ee480cafd1
--- /dev/null
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.hash;
+
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * 32-bit Murmur3 hasher.  This is based on Guava's Murmur3_32HashFunction.
+ */
+public final class Murmur3_x86_32 {
+  private static final int C1 = 0xcc9e2d51;
+  private static final int C2 = 0x1b873593;
+
+  private final int seed;
+
+  public Murmur3_x86_32(int seed) {
+    this.seed = seed;
+  }
+
+  @Override
+  public String toString() {
+    return "Murmur3_32(seed=" + seed + ")";
+  }
+
+  public int hashInt(int input) {
+    return hashInt(input, seed);
+  }
+
+  public static int hashInt(int input, int seed) {
+    int k1 = mixK1(input);
+    int h1 = mixH1(seed, k1);
+
+    return fmix(h1, 4);
+  }
+
+  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
+    return hashUnsafeWords(base, offset, lengthInBytes, seed);
+  }
+
+  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
+    // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
+    assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
+    int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
+    return fmix(h1, lengthInBytes);
+  }
+
+  public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
+    assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
+    int lengthAligned = lengthInBytes - lengthInBytes % 4;
+    int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
+    for (int i = lengthAligned; i < lengthInBytes; i++) {
+      int halfWord = Platform.getByte(base, offset + i);
+      int k1 = mixK1(halfWord);
+      h1 = mixH1(h1, k1);
+    }
+    return fmix(h1, lengthInBytes);
+  }
+
+  private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
+    assert (lengthInBytes % 4 == 0);
+    int h1 = seed;
+    for (int i = 0; i < lengthInBytes; i += 4) {
+      int halfWord = Platform.getInt(base, offset + i);
+      int k1 = mixK1(halfWord);
+      h1 = mixH1(h1, k1);
+    }
+    return h1;
+  }
+
+  public int hashLong(long input) {
+    return hashLong(input, seed);
+  }
+
+  public static int hashLong(long input, int seed) {
+    int low = (int) input;
+    int high = (int) (input >>> 32);
+
+    int k1 = mixK1(low);
+    int h1 = mixH1(seed, k1);
+
+    k1 = mixK1(high);
+    h1 = mixH1(h1, k1);
+
+    return fmix(h1, 8);
+  }
+
+  private static int mixK1(int k1) {
+    k1 *= C1;
+    k1 = Integer.rotateLeft(k1, 15);
+    k1 *= C2;
+    return k1;
+  }
+
+  private static int mixH1(int h1, int k1) {
+    h1 ^= k1;
+    h1 = Integer.rotateLeft(h1, 13);
+    h1 = h1 * 5 + 0xe6546b64;
+    return h1;
+  }
+
+  // Finalization mix - force all bits of a hash block to avalanche
+  private static int fmix(int h1, int length) {
+    h1 ^= length;
+    h1 ^= h1 >>> 16;
+    h1 *= 0x85ebca6b;
+    h1 ^= h1 >>> 13;
+    h1 *= 0xc2b2ae35;
+    h1 ^= h1 >>> 16;
+    return h1;
+  }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
similarity index 94%
rename from unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
index 30e1758076361..62edf6c64bbc7 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
@@ -188,6 +188,11 @@ public static CalendarInterval fromSingleUnitString(String unit, String s)
             Integer.MIN_VALUE, Integer.MAX_VALUE);
           result = new CalendarInterval(month, 0L);
 
+        } else if (unit.equals("week")) {
+          long week = toLongWithRange("week", m.group(1),
+                  Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK);
+          result = new CalendarInterval(0, week * MICROS_PER_WEEK);
+
         } else if (unit.equals("day")) {
           long day = toLongWithRange("day", m.group(1),
             Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY);
@@ -206,6 +211,15 @@ public static CalendarInterval fromSingleUnitString(String unit, String s)
         } else if (unit.equals("second")) {
           long micros = parseSecondNano(m.group(1));
           result = new CalendarInterval(0, micros);
+
+        } else if (unit.equals("millisecond")) {
+          long millisecond = toLongWithRange("millisecond", m.group(1),
+                  Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI);
+          result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI);
+
+        } else if (unit.equals("microsecond")) {
+          long micros = Long.valueOf(m.group(1));
+          result = new CalendarInterval(0, micros);
         }
       } catch (Exception e) {
         throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
similarity index 96%
rename from unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index b7aecb5102ba6..54a54569240c0 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -21,11 +21,18 @@
 import java.io.*;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Map;
 
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 
 import static org.apache.spark.unsafe.Platform.*;
 
@@ -38,9 +45,10 @@
  * 

* Note: This is not designed for general use cases, should not be used outside SQL. */ -public final class UTF8String implements Comparable, Externalizable { +public final class UTF8String implements Comparable, Externalizable, KryoSerializable, + Cloneable { - // These are only updated by readExternal() + // These are only updated by readExternal() or read() @Nonnull private Object base; private long offset; @@ -98,15 +106,7 @@ public static UTF8String fromAddress(Object base, long offset, int numBytes) { * Creates an UTF8String from String. */ public static UTF8String fromString(String str) { - if (str == null) return null; - try { - return fromBytes(str.getBytes("utf-8")); - } catch (UnsupportedEncodingException e) { - // Turn the exception into unchecked so we can find out about it at runtime, but - // don't need to add lots of boilerplate code everywhere. - throwException(e); - return null; - } + return str == null ? null : fromBytes(str.getBytes(StandardCharsets.UTF_8)); } /** @@ -818,14 +818,7 @@ public UTF8String translate(Map dict) { @Override public String toString() { - try { - return new String(getBytes(), "utf-8"); - } catch (UnsupportedEncodingException e) { - // Turn the exception into unchecked so we can find out about it at runtime, but - // don't need to add lots of boilerplate code everywhere. - throwException(e); - return "unknown"; // we will never reach here. - } + return new String(getBytes(), StandardCharsets.UTF_8); } @Override @@ -895,9 +888,9 @@ public int levenshteinDistance(UTF8String other) { m = swap; } - int p[] = new int[n + 1]; - int d[] = new int[n + 1]; - int swap[]; + int[] p = new int[n + 1]; + int[] d = new int[n + 1]; + int[] swap; int i, i_bytes, j, j_bytes, num_bytes_j, cost; @@ -930,11 +923,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - int result = 1; - for (int i = 0; i < numBytes; i ++) { - result = 31 * result + getByte(i); - } - return result; + return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); } /** @@ -960,7 +949,7 @@ public UTF8String soundex() { // first character must be a letter return this; } - byte sx[] = {'0', '0', '0', '0'}; + byte[] sx = {'0', '0', '0', '0'}; sx[0] = b; int sxi = 1; int idx = b - 'A'; @@ -1003,4 +992,19 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept in.readFully((byte[]) base); } + @Override + public void write(Kryo kryo, Output out) { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + @Override + public void read(Kryo kryo, Input in) { + this.offset = BYTE_ARRAY_OFFSET; + this.numBytes = in.readInt(); + this.base = new byte[numBytes]; + in.read((byte[]) base); + } + } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java new file mode 100644 index 0000000000000..693ec6ec58dbd --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import org.junit.Assert; +import org.junit.Test; + +public class PlatformUtilSuite { + + @Test + public void overlappingCopyMemory() { + byte[] data = new byte[3 * 1024 * 1024]; + int size = 2 * 1024 * 1024; + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + + Platform.copyMemory(data, Platform.BYTE_ARRAY_OFFSET, data, Platform.BYTE_ARRAY_OFFSET, size); + for (int i = 0; i < data.length; ++i) { + Assert.assertEquals((byte)i, data[i]); + } + + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET + 1, + data, + Platform.BYTE_ARRAY_OFFSET, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)(i + 1), data[i]); + } + + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET, + data, + Platform.BYTE_ARRAY_OFFSET + 1, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)i, data[i + 1]); + } + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java similarity index 92% rename from unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index 5974cf91ff993..fb8e53b3348f3 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -34,5 +34,9 @@ public void basicTest() { Assert.assertEquals(2, arr.size()); Assert.assertEquals(1L, arr.get(0)); Assert.assertEquals(3L, arr.get(1)); + + arr.zeroOut(); + Assert.assertEquals(0L, arr.get(0)); + Assert.assertEquals(0L, arr.get(1)); } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java similarity index 98% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index e21ffdcff9abf..d4160ad029eb3 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.types; -import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.HashMap; @@ -30,9 +30,9 @@ public class UTF8StringSuite { - private static void checkBasic(String str, int len) throws UnsupportedEncodingException { + private static void checkBasic(String str, int len) { UTF8String s1 = fromString(str); - UTF8String s2 = fromBytes(str.getBytes("utf8")); + UTF8String s2 = fromBytes(str.getBytes(StandardCharsets.UTF_8)); assertEquals(s1.numChars(), len); assertEquals(s2.numChars(), len); @@ -51,7 +51,7 @@ private static void checkBasic(String str, int len) throws UnsupportedEncodingEx } @Test - public void basicTest() throws UnsupportedEncodingException { + public void basicTest() { checkBasic("", 0); checkBasic("hello", 5); checkBasic("大 千 世 界", 7); @@ -378,7 +378,7 @@ public void split() { assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), new UTF8String[]{fromString("ab"), fromString("def,ghi")})); } - + @Test public void levenshteinDistance() { assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8)); diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala similarity index 99% rename from unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala rename to common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 12a002befa0ac..8a6b9e3e4536d 100644 --- a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.types import org.apache.commons.lang3.StringUtils - import org.scalacheck.{Arbitrary, Gen} import org.scalatest.prop.GeneratorDrivenPropertyChecks // scalastyle:off @@ -194,7 +193,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty test("concat") { def concat(orgin: Seq[String]): String = - if (orgin.exists(_ == null)) null else orgin.mkString + if (orgin.contains(null)) null else orgin.mkString forAll { (inputs: Seq[String]) => assert(UTF8String.concat(inputs.map(toUTF8): _*) === toUTF8(inputs.mkString)) diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index f3046be54d7c6..ec1aa187dfb32 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -22,9 +22,14 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark_project.jetty=WARN +log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO log4j.logger.org.apache.parquet=ERROR diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index d6962e0da2f30..8a4f4e48335bd 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -57,39 +57,41 @@ # added to Java properties using -Dspark.metrics.conf=xxx if you want to # customize metrics system. You can also put the file in ${SPARK_HOME}/conf # and it will be loaded automatically. -# 5. MetricsServlet is added by default as a sink in master, worker and client -# driver, you can send http request "/metrics/json" to get a snapshot of all the -# registered metrics in json format. For master, requests "/metrics/master/json" and -# "/metrics/applications/json" can be sent seperately to get metrics snapshot of -# instance master and applications. MetricsServlet may not be configured by self. -# +# 5. The MetricsServlet sink is added by default as a sink in the master, +# worker and driver, and you can send HTTP requests to the "/metrics/json" +# endpoint to get a snapshot of all the registered metrics in JSON format. +# For master, requests to the "/metrics/master/json" and +# "/metrics/applications/json" endpoints can be sent separately to get +# metrics snapshots of the master instance and applications. This +# MetricsServlet does not have to be configured. ## List of available common sources and their properties. # org.apache.spark.metrics.source.JvmSource -# Note: Currently, JvmSource is the only available common source -# to add additionaly to an instance, to enable this, -# set the "class" option to its fully qulified class name (see examples below) +# Note: Currently, JvmSource is the only available common source. +# It can be added to an instance by setting the "class" option to its +# fully qualified class name (see examples below). ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink # Name: Default: Description: # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # org.apache.spark.metrics.sink.CSVSink # Name: Default: Description: # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # directory /tmp Where to store CSV files # org.apache.spark.metrics.sink.GangliaSink # Name: Default: Description: -# host NONE Hostname or multicast group of Ganglia server -# port NONE Port of Ganglia server(s) +# host NONE Hostname or multicast group of the Ganglia server, +# must be set +# port NONE Port of the Ganglia server(s), must be set # period 10 Poll period -# unit seconds Units of poll period +# unit seconds Unit of the poll period # ttl 1 TTL of messages sent by Ganglia # mode multicast Ganglia network mode ('unicast' or 'multicast') @@ -98,19 +100,21 @@ # org.apache.spark.metrics.sink.MetricsServlet # Name: Default: Description: # path VARIES* Path prefix from the web server root -# sample false Whether to show entire set of samples for histograms ('false' or 'true') +# sample false Whether to show entire set of samples for histograms +# ('false' or 'true') # -# * Default path is /metrics/json for all instances except the master. The master has two paths: +# * Default path is /metrics/json for all instances except the master. The +# master has two paths: # /metrics/applications/json # App information # /metrics/master/json # Master information # org.apache.spark.metrics.sink.GraphiteSink # Name: Default: Description: -# host NONE Hostname of Graphite server -# port NONE Port of Graphite server +# host NONE Hostname of the Graphite server, must be set +# port NONE Port of the Graphite server, must be set # period 10 Poll period -# unit seconds Units of poll period -# prefix EMPTY STRING Prefix to prepend to metric name +# unit seconds Unit of the poll period +# prefix EMPTY STRING Prefix to prepend to every metric's name # protocol tcp Protocol ("tcp" or "udp") to use ## Examples @@ -120,42 +124,42 @@ # Enable ConsoleSink for all instances by class name #*.sink.console.class=org.apache.spark.metrics.sink.ConsoleSink -# Polling period for ConsoleSink +# Polling period for the ConsoleSink #*.sink.console.period=10 - +# Unit of the polling period for the ConsoleSink #*.sink.console.unit=seconds -# Master instance overlap polling period +# Polling period for the ConsoleSink specific for the master instance #master.sink.console.period=15 - +# Unit of the polling period for the ConsoleSink specific for the master +# instance #master.sink.console.unit=seconds -# Enable CsvSink for all instances +# Enable CsvSink for all instances by class name #*.sink.csv.class=org.apache.spark.metrics.sink.CsvSink -# Polling period for CsvSink +# Polling period for the CsvSink #*.sink.csv.period=1 - +# Unit of the polling period for the CsvSink #*.sink.csv.unit=minutes # Polling directory for CsvSink #*.sink.csv.directory=/tmp/ -# Worker instance overlap polling period +# Polling period for the CsvSink specific for the worker instance #worker.sink.csv.period=10 - +# Unit of the polling period for the CsvSink specific for the worker instance #worker.sink.csv.unit=minutes # Enable Slf4jSink for all instances by class name #*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink -# Polling period for Slf4JSink +# Polling period for the Slf4JSink #*.sink.slf4j.period=1 - +# Unit of the polling period for the Slf4jSink #*.sink.slf4j.unit=minutes - -# Enable jvm source for instance master, worker, driver and executor +# Enable JvmSource for instance master, worker, driver and executor #master.source.jvm.class=org.apache.spark.metrics.source.JvmSource #worker.source.jvm.class=org.apache.spark.metrics.source.JvmSource @@ -163,4 +167,3 @@ #driver.source.jvm.class=org.apache.spark.metrics.source.JvmSource #executor.source.jvm.class=org.apache.spark.metrics.source.JvmSource - diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 771251f90ee36..a031cd6a722f9 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -41,7 +41,7 @@ # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) -# - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) +# - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: 'default') # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. # - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job. diff --git a/core/pom.xml b/core/pom.xml index 570a25cf325a2..7349ad35b9595 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-core_2.10 + spark-core_2.11 core @@ -51,6 +51,10 @@ com.twitter chill-java + + org.apache.xbean + xbean-asm5-shaded + org.apache.hadoop hadoop-client @@ -174,21 +178,12 @@ lz4 - commons-net - commons-net - - - ${akka.group} - akka-remote_${scala.binary.version} - - - ${akka.group} - akka-slf4j_${scala.binary.version} + org.roaringbitmap + RoaringBitmap - ${akka.group} - akka-testkit_${scala.binary.version} - test + commons-net + commons-net org.scala-lang @@ -197,7 +192,6 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.10 com.sun.jersey @@ -216,6 +210,10 @@ io.netty netty-all + + io.netty + netty + com.clearspring.analytics stream @@ -259,33 +257,6 @@ oro ${oro.version} - - org.tachyonproject - tachyon-client - 0.8.1 - - - org.apache.hadoop - hadoop-client - - - org.apache.curator - curator-client - - - org.apache.curator - curator-framework - - - org.apache.curator - curator-recipes - - - org.tachyonproject - tachyon-underfs-glusterfs - - - org.seleniumhq.selenium selenium-java @@ -342,7 +313,7 @@ net.sf.py4j py4j - 0.9 + 0.9.2 org.apache.spark diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java deleted file mode 100644 index fa9acf0a15b88..0000000000000 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark; - -import org.apache.spark.scheduler.*; - -/** - * Java clients should extend this class instead of implementing - * SparkListener directly. This is to prevent java clients - * from breaking when new events are added to the SparkListener - * trait. - * - * This is a concrete class instead of abstract to enforce - * new events get added to both the SparkListener and this adapter - * in lockstep. - */ -public class JavaSparkListener implements SparkListener { - - @Override - public void onStageCompleted(SparkListenerStageCompleted stageCompleted) { } - - @Override - public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { } - - @Override - public void onTaskStart(SparkListenerTaskStart taskStart) { } - - @Override - public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { } - - @Override - public void onTaskEnd(SparkListenerTaskEnd taskEnd) { } - - @Override - public void onJobStart(SparkListenerJobStart jobStart) { } - - @Override - public void onJobEnd(SparkListenerJobEnd jobEnd) { } - - @Override - public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { } - - @Override - public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { } - - @Override - public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { } - - @Override - public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { } - - @Override - public void onApplicationStart(SparkListenerApplicationStart applicationStart) { } - - @Override - public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { } - - @Override - public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { } - - @Override - public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } - - @Override - public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } - - @Override - public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } - -} diff --git a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java new file mode 100644 index 0000000000000..dc3e826475987 --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.Serializable; + +/** + * Exposes information about Spark Executors. + * + * This interface is not designed to be implemented outside of Spark. We may add additional methods + * which may break binary compatibility with outside implementations. + */ +public interface SparkExecutorInfo extends Serializable { + String host(); + int port(); + long cacheSize(); + int numRunningTasks(); +} diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 1214d05ba6063..97eed611e8f9a 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -28,7 +28,7 @@ * this was a concrete Scala class, default implementations of new event handlers would be inherited * from the SparkListener trait). */ -public class SparkFirehoseListener implements SparkListener { +public class SparkFirehoseListener implements SparkListenerInterface { public void onEvent(SparkListenerEvent event) { } @@ -118,4 +118,8 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java index d4c42b38ac224..0dd8fafbf2c82 100644 --- a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java +++ b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java @@ -62,5 +62,6 @@ public final JavaPairRDD union(JavaPairRDD... rdds) { // These methods take separate "first" and "rest" elements to avoid having the same type erasure public abstract JavaRDD union(JavaRDD first, List> rest); public abstract JavaDoubleRDD union(JavaDoubleRDD first, List rest); - public abstract JavaPairRDD union(JavaPairRDD first, List> rest); + public abstract JavaPairRDD union(JavaPairRDD first, List> + rest); } diff --git a/core/src/main/java/org/apache/spark/api/java/Optional.java b/core/src/main/java/org/apache/spark/api/java/Optional.java new file mode 100644 index 0000000000000..ca7babc3f01c7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/Optional.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + +import java.io.Serializable; + +import com.google.common.base.Preconditions; + +/** + *

Like {@code java.util.Optional} in Java 8, {@code scala.Option} in Scala, and + * {@code com.google.common.base.Optional} in Google Guava, this class represents a + * value of a given type that may or may not exist. It is used in methods that wish + * to optionally return a value, in preference to returning {@code null}.

+ * + *

In fact, the class here is a reimplementation of the essential API of both + * {@code java.util.Optional} and {@code com.google.common.base.Optional}. From + * {@code java.util.Optional}, it implements:

+ * + *
    + *
  • {@link #empty()}
  • + *
  • {@link #of(Object)}
  • + *
  • {@link #ofNullable(Object)}
  • + *
  • {@link #get()}
  • + *
  • {@link #orElse(Object)}
  • + *
  • {@link #isPresent()}
  • + *
+ * + *

From {@code com.google.common.base.Optional} it implements:

+ * + *
    + *
  • {@link #absent()}
  • + *
  • {@link #of(Object)}
  • + *
  • {@link #fromNullable(Object)}
  • + *
  • {@link #get()}
  • + *
  • {@link #or(Object)}
  • + *
  • {@link #orNull()}
  • + *
  • {@link #isPresent()}
  • + *
+ * + *

{@code java.util.Optional} itself is not used at this time because the + * project does not require Java 8. Using {@code com.google.common.base.Optional} + * has in the past caused serious library version conflicts with Guava that can't + * be resolved by shading. Hence this work-alike clone.

+ * + * @param type of value held inside + */ +public final class Optional implements Serializable { + + private static final Optional EMPTY = new Optional<>(); + + private final T value; + + private Optional() { + this.value = null; + } + + private Optional(T value) { + Preconditions.checkNotNull(value); + this.value = value; + } + + // java.util.Optional API (subset) + + /** + * @return an empty {@code Optional} + */ + public static Optional empty() { + @SuppressWarnings("unchecked") + Optional t = (Optional) EMPTY; + return t; + } + + /** + * @param value non-null value to wrap + * @return {@code Optional} wrapping this value + * @throws NullPointerException if value is null + */ + public static Optional of(T value) { + return new Optional<>(value); + } + + /** + * @param value value to wrap, which may be null + * @return {@code Optional} wrapping this value, which may be empty + */ + public static Optional ofNullable(T value) { + if (value == null) { + return empty(); + } else { + return of(value); + } + } + + /** + * @return the value wrapped by this {@code Optional} + * @throws NullPointerException if this is empty (contains no value) + */ + public T get() { + Preconditions.checkNotNull(value); + return value; + } + + /** + * @param other value to return if this is empty + * @return this {@code Optional}'s value if present, or else the given value + */ + public T orElse(T other) { + return value != null ? value : other; + } + + /** + * @return true iff this {@code Optional} contains a value (non-empty) + */ + public boolean isPresent() { + return value != null; + } + + // Guava API (subset) + // of(), get() and isPresent() are identically present in the Guava API + + /** + * @return an empty {@code Optional} + */ + public static Optional absent() { + return empty(); + } + + /** + * @param value value to wrap, which may be null + * @return {@code Optional} wrapping this value, which may be empty + */ + public static Optional fromNullable(T value) { + return ofNullable(value); + } + + /** + * @param other value to return if this is empty + * @return this {@code Optional}'s value if present, or else the given value + */ + public T or(T other) { + return value != null ? value : other; + } + + /** + * @return this {@code Optional}'s value if present, or else null + */ + public T orNull() { + return value; + } + + // Common methods + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Optional)) { + return false; + } + Optional other = (Optional) obj; + return value == null ? other.value == null : value.equals(other.value); + } + + @Override + public int hashCode() { + return value == null ? 0 : value.hashCode(); + } + + @Override + public String toString() { + return value == null ? "Optional.empty" : String.format("Optional[%s]", value); + } + +} diff --git a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java index 840a1bd93bfbb..3fcb52f615834 100644 --- a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java +++ b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java @@ -34,26 +34,13 @@ public class StorageLevels { public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, false, true, 2); public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, false, 2); - public static final StorageLevel OFF_HEAP = create(false, false, true, false, 1); + public static final StorageLevel OFF_HEAP = create(true, true, true, false, 1); /** * Create a new StorageLevel object. * @param useDisk saved to disk, if true - * @param useMemory saved to memory, if true - * @param deserialized saved as deserialized objects, if true - * @param replication replication factor - */ - @Deprecated - public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, - int replication) { - return StorageLevel.apply(useDisk, useMemory, false, deserialized, replication); - } - - /** - * Create a new StorageLevel object. - * @param useDisk saved to disk, if true - * @param useMemory saved to memory, if true - * @param useOffHeap saved to Tachyon, if true + * @param useMemory saved to on-heap memory, if true + * @param useOffHeap saved to off-heap memory, if true * @param deserialized saved as deserialized objects, if true * @param replication replication factor */ diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java new file mode 100644 index 0000000000000..07aebb75e8f4e --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values from 2 + * Datasets. + */ +public interface CoGroupFunction extends Serializable { + Iterator call(K key, Iterator left, Iterator right) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java index 57fd0a7a80494..576087b6f428e 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that returns zero or more records of type Double from each input record. */ public interface DoubleFlatMapFunction extends Serializable { - public Iterable call(T t) throws Exception; + Iterator call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java index 150144e0e418c..bf16f791f906a 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java @@ -23,5 +23,5 @@ * A function that returns Doubles, and can be used to construct DoubleRDDs. */ public interface DoubleFunction extends Serializable { - public double call(T t) throws Exception; + double call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java new file mode 100644 index 0000000000000..e8d999dd00135 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's filter function. + * + * If the function returns true, the element is discarded in the returned Dataset. + */ +public interface FilterFunction extends Serializable { + boolean call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index 23f5fdd43631b..2d8ea6d1a5a7e 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that returns zero or more output records from each input record. */ public interface FlatMapFunction extends Serializable { - public Iterable call(T t) throws Exception; + Iterator call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index c48e92f535ff5..fc97b63f825d0 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -18,10 +18,11 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; /** * A function that takes two inputs and returns zero or more output records. */ public interface FlatMapFunction2 extends Serializable { - public Iterable call(T1 t1, T2 t2) throws Exception; + Iterator call(T1 t1, T2 t2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java new file mode 100644 index 0000000000000..bae574ab5755d --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values. + */ +public interface FlatMapGroupsFunction extends Serializable { + Iterator call(K key, Iterator values) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java new file mode 100644 index 0000000000000..07e54b28fa12c --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's foreach function. + * + * Spark will invoke the call function on each element in the input Dataset. + */ +public interface ForeachFunction extends Serializable { + void call(T t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java new file mode 100644 index 0000000000000..4938a51bcd712 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a function used in Dataset's foreachPartition function. + */ +public interface ForeachPartitionFunction extends Serializable { + void call(Iterator t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/core/src/main/java/org/apache/spark/api/java/function/Function.java index d00551bb0add6..b9d9777a75651 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java @@ -25,5 +25,5 @@ * when mapping RDDs of other types. */ public interface Function extends Serializable { - public R call(T1 v1) throws Exception; + R call(T1 v1) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java index 38e410c5debe6..c86928dd05408 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -23,5 +23,5 @@ * A zero-argument function that returns an R. */ public interface Function0 extends Serializable { - public R call() throws Exception; + R call() throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function2.java b/core/src/main/java/org/apache/spark/api/java/function/Function2.java index 793caaa61ac5a..a975ce3c68192 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function2.java @@ -23,5 +23,5 @@ * A two-argument function that takes arguments of type T1 and T2 and returns an R. */ public interface Function2 extends Serializable { - public R call(T1 v1, T2 v2) throws Exception; + R call(T1 v1, T2 v2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function3.java b/core/src/main/java/org/apache/spark/api/java/function/Function3.java index b4151c3417df4..6eecfb645a663 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function3.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function3.java @@ -23,5 +23,5 @@ * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R. */ public interface Function3 extends Serializable { - public R call(T1 v1, T2 v2, T3 v3) throws Exception; + R call(T1 v1, T2 v2, T3 v3) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java new file mode 100644 index 0000000000000..9c35a22ca9d0f --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R. + */ +public interface Function4 extends Serializable { + R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java new file mode 100644 index 0000000000000..3ae6ef44898e1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a map function used in Dataset's map function. + */ +public interface MapFunction extends Serializable { + U call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java new file mode 100644 index 0000000000000..faa59eabc8b4f --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a map function used in GroupedDataset's mapGroup function. + */ +public interface MapGroupsFunction extends Serializable { + R call(K key, Iterator values) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java new file mode 100644 index 0000000000000..cf9945a215aff --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for function used in Dataset's mapPartitions. + */ +public interface MapPartitionsFunction extends Serializable { + Iterator call(Iterator input) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java index 691ef2eceb1f6..51eed2e67b9fa 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -18,6 +18,7 @@ package org.apache.spark.api.java.function; import java.io.Serializable; +import java.util.Iterator; import scala.Tuple2; @@ -26,5 +27,5 @@ * key-value pairs are represented as scala.Tuple2 objects. */ public interface PairFlatMapFunction extends Serializable { - public Iterable> call(T t) throws Exception; + Iterator> call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java index 99bf240a17225..2fdfa7184a3bd 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java @@ -26,5 +26,5 @@ * construct PairRDDs. */ public interface PairFunction extends Serializable { - public Tuple2 call(T t) throws Exception; + Tuple2 call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java new file mode 100644 index 0000000000000..ee092d0058f44 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for function used in Dataset's reduce. + */ +public interface ReduceFunction extends Serializable { + T call(T v1, T v2) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java index 2a10435b7523a..f30d42ee57966 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java @@ -23,5 +23,5 @@ * A function with no return value. */ public interface VoidFunction extends Serializable { - public void call(T t) throws Exception; + void call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java new file mode 100644 index 0000000000000..da9ae1c9c5cdc --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A two-argument function that takes arguments of type T1 and T2 with no return value. + */ +public interface VoidFunction2 extends Serializable { + void call(T1 v1, T2 v2) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java new file mode 100644 index 0000000000000..8783b5f56ebae --- /dev/null +++ b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java @@ -0,0 +1,261 @@ +package org.apache.spark.io; + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.EOFException; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.zip.Checksum; + +import net.jpountz.lz4.LZ4Exception; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4FastDecompressor; +import net.jpountz.util.SafeUtils; +import net.jpountz.xxhash.XXHashFactory; + +/** + * {@link InputStream} implementation to decode data written with + * {@link net.jpountz.lz4.LZ4BlockOutputStream}. This class is not thread-safe and does not + * support {@link #mark(int)}/{@link #reset()}. + * @see net.jpountz.lz4.LZ4BlockOutputStream + * + * This is based on net.jpountz.lz4.LZ4BlockInputStream + * + * changes: https://github.com/davies/lz4-java/commit/cc1fa940ac57cc66a0b937300f805d37e2bf8411 + * + * TODO: merge this into upstream + */ +public final class LZ4BlockInputStream extends FilterInputStream { + + // Copied from net.jpountz.lz4.LZ4BlockOutputStream + static final byte[] MAGIC = new byte[] { 'L', 'Z', '4', 'B', 'l', 'o', 'c', 'k' }; + static final int MAGIC_LENGTH = MAGIC.length; + + static final int HEADER_LENGTH = + MAGIC_LENGTH // magic bytes + + 1 // token + + 4 // compressed length + + 4 // decompressed length + + 4; // checksum + + static final int COMPRESSION_LEVEL_BASE = 10; + + static final int COMPRESSION_METHOD_RAW = 0x10; + static final int COMPRESSION_METHOD_LZ4 = 0x20; + + static final int DEFAULT_SEED = 0x9747b28c; + + private final LZ4FastDecompressor decompressor; + private final Checksum checksum; + private byte[] buffer; + private byte[] compressedBuffer; + private int originalLen; + private int o; + private boolean finished; + + /** + * Create a new {@link InputStream}. + * + * @param in the {@link InputStream} to poll + * @param decompressor the {@link LZ4FastDecompressor decompressor} instance to + * use + * @param checksum the {@link Checksum} instance to use, must be + * equivalent to the instance which has been used to + * write the stream + */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum) { + super(in); + this.decompressor = decompressor; + this.checksum = checksum; + this.buffer = new byte[0]; + this.compressedBuffer = new byte[HEADER_LENGTH]; + o = originalLen = 0; + finished = false; + } + + /** + * Create a new instance using {@link net.jpountz.xxhash.XXHash32} for checksuming. + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum) + * @see net.jpountz.xxhash.StreamingXXHash32#asChecksum() + */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) { + this(in, decompressor, + XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); + } + + /** + * Create a new instance which uses the fastest {@link LZ4FastDecompressor} available. + * @see LZ4Factory#fastestInstance() + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor) + */ + public LZ4BlockInputStream(InputStream in) { + this(in, LZ4Factory.fastestInstance().fastDecompressor()); + } + + @Override + public int available() throws IOException { + refill(); + return originalLen - o; + } + + @Override + public int read() throws IOException { + refill(); + if (finished) { + return -1; + } + return buffer[o++] & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + SafeUtils.checkRange(b, off, len); + refill(); + if (finished) { + return -1; + } + len = Math.min(len, originalLen - o); + System.arraycopy(buffer, o, b, off, len); + o += len; + return len; + } + + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + + @Override + public long skip(long n) throws IOException { + refill(); + if (finished) { + return -1; + } + final int skipped = (int) Math.min(n, originalLen - o); + o += skipped; + return skipped; + } + + private void refill() throws IOException { + if (finished || o < originalLen) { + return; + } + try { + readFully(compressedBuffer, HEADER_LENGTH); + } catch (EOFException e) { + finished = true; + return; + } + for (int i = 0; i < MAGIC_LENGTH; ++i) { + if (compressedBuffer[i] != MAGIC[i]) { + throw new IOException("Stream is corrupted"); + } + } + final int token = compressedBuffer[MAGIC_LENGTH] & 0xFF; + final int compressionMethod = token & 0xF0; + final int compressionLevel = COMPRESSION_LEVEL_BASE + (token & 0x0F); + if (compressionMethod != COMPRESSION_METHOD_RAW && compressionMethod != COMPRESSION_METHOD_LZ4) + { + throw new IOException("Stream is corrupted"); + } + final int compressedLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 1); + originalLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 5); + final int check = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 9); + assert HEADER_LENGTH == MAGIC_LENGTH + 13; + if (originalLen > 1 << compressionLevel + || originalLen < 0 + || compressedLen < 0 + || (originalLen == 0 && compressedLen != 0) + || (originalLen != 0 && compressedLen == 0) + || (compressionMethod == COMPRESSION_METHOD_RAW && originalLen != compressedLen)) { + throw new IOException("Stream is corrupted"); + } + if (originalLen == 0 && compressedLen == 0) { + if (check != 0) { + throw new IOException("Stream is corrupted"); + } + refill(); + return; + } + if (buffer.length < originalLen) { + buffer = new byte[Math.max(originalLen, buffer.length * 3 / 2)]; + } + switch (compressionMethod) { + case COMPRESSION_METHOD_RAW: + readFully(buffer, originalLen); + break; + case COMPRESSION_METHOD_LZ4: + if (compressedBuffer.length < originalLen) { + compressedBuffer = new byte[Math.max(compressedLen, compressedBuffer.length * 3 / 2)]; + } + readFully(compressedBuffer, compressedLen); + try { + final int compressedLen2 = + decompressor.decompress(compressedBuffer, 0, buffer, 0, originalLen); + if (compressedLen != compressedLen2) { + throw new IOException("Stream is corrupted"); + } + } catch (LZ4Exception e) { + throw new IOException("Stream is corrupted", e); + } + break; + default: + throw new AssertionError(); + } + checksum.reset(); + checksum.update(buffer, 0, originalLen); + if ((int) checksum.getValue() != check) { + throw new IOException("Stream is corrupted"); + } + o = 0; + } + + private void readFully(byte[] b, int len) throws IOException { + int read = 0; + while (read < len) { + final int r = in.read(b, read, len - read); + if (r < 0) { + throw new EOFException("Stream ended prematurely"); + } + read += r; + } + assert len == read; + } + + @Override + public boolean markSupported() { + return false; + } + + @SuppressWarnings("sync-override") + @Override + public void mark(int readlimit) { + // unsupported + } + + @SuppressWarnings("sync-override") + @Override + public void reset() throws IOException { + throw new IOException("mark/reset not supported"); + } + + @Override + public String toString() { + return getClass().getSimpleName() + "(in=" + in + + ", decompressor=" + decompressor + ", checksum=" + checksum + ")"; + } + +} diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 008799cc77395..36138cc9a297c 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -17,25 +17,25 @@ package org.apache.spark.memory; - import java.io.IOException; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; - /** * An memory consumer of TaskMemoryManager, which support spilling. + * + * Note: this only supports allocation / spilling of Tungsten memory. */ public abstract class MemoryConsumer { - private final TaskMemoryManager taskMemoryManager; + protected final TaskMemoryManager taskMemoryManager; private final long pageSize; - private long used; + protected long used; protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { this.taskMemoryManager = taskMemoryManager; this.pageSize = pageSize; - this.used = 0; } protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { @@ -66,6 +66,8 @@ public void spill() throws IOException { * * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill(). * + * Note: today, this only frees Tungsten-managed pages. + * * @param size the amount of memory should be released * @param trigger the MemoryConsumer that trigger this spilling * @return the amount of released memory in bytes @@ -74,26 +76,29 @@ public void spill() throws IOException { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Acquire `size` bytes memory. - * - * If there is not enough memory, throws OutOfMemoryError. + * Allocates a LongArray of `size`. */ - protected void acquireMemory(long size) { - long got = taskMemoryManager.acquireExecutionMemory(size, this); - if (got < size) { - taskMemoryManager.releaseExecutionMemory(got, this); + public LongArray allocateArray(long size) { + long required = size * 8L; + MemoryBlock page = taskMemoryManager.allocatePage(required, this); + if (page == null || page.size() < required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); } - used += got; + used += required; + return new LongArray(page); } /** - * Release `size` bytes memory. + * Frees a LongArray. */ - protected void releaseMemory(long size) { - used -= size; - taskMemoryManager.releaseExecutionMemory(size, this); + public void freeArray(LongArray array) { + freePage(array.memoryBlock()); } /** @@ -109,7 +114,7 @@ protected MemoryBlock allocatePage(long required) { long got = 0; if (page != null) { got = page.size(); - freePage(page); + taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); diff --git a/core/src/main/java/org/apache/spark/memory/MemoryMode.java b/core/src/main/java/org/apache/spark/memory/MemoryMode.java new file mode 100644 index 0000000000000..3a5e72d8aaec0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/MemoryMode.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import org.apache.spark.annotation.Private; + +@Private +public enum MemoryMode { + ON_HEAP, + OFF_HEAP +} diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 4230575446d31..9044bb4f4a44b 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -67,9 +67,9 @@ public class TaskMemoryManager { /** * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is - * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page - * size is limited by the maximum amount of data that can be stored in a long[] array, which is - * (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes. + * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's + * maximum page size is limited by the maximum amount of data that can be stored in a long[] + * array, which is (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes. */ public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L; @@ -103,19 +103,24 @@ public class TaskMemoryManager { * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. */ - private final boolean inHeap; + final MemoryMode tungstenMemoryMode; /** - * The size of memory granted to each consumer. + * Tracks spillable memory consumers. */ @GuardedBy("this") private final HashSet consumers; + /** + * The amount of memory that is acquired but not used. + */ + private long acquiredButNotUsed = 0L; + /** * Construct a new TaskMemoryManager. */ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { - this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap(); + this.tungstenMemoryMode = memoryManager.tungstenMemoryMode(); this.memoryManager = memoryManager; this.taskAttemptId = taskAttemptId; this.consumers = new HashSet<>(); @@ -127,23 +132,30 @@ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { * * @return number of bytes successfully granted (<= N). */ - public long acquireExecutionMemory(long required, MemoryConsumer consumer) { + public long acquireExecutionMemory( + long required, + MemoryMode mode, + MemoryConsumer consumer) { assert(required >= 0); + // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap + // memory here, then it may not make sense to spill since that would only end up freeing + // off-heap memory. This is subject to change, though, so it may be risky to make this + // optimization now in case we forget to undo it late when making changes. synchronized (this) { - long got = memoryManager.acquireExecutionMemory(required, taskAttemptId); + long got = memoryManager.acquireExecutionMemory(required, taskAttemptId, mode); - // try to release memory from other consumers first, then we can reduce the frequency of + // Try to release memory from other consumers first, then we can reduce the frequency of // spilling, avoid to have too many spilled files. if (got < required) { // Call spill() on other consumers to release memory for (MemoryConsumer c: consumers) { - if (c != null && c != consumer && c.getUsed() > 0) { + if (c != consumer && c.getUsed() > 0) { try { long released = c.spill(required - got, consumer); - if (released > 0) { - logger.info("Task {} released {} from {} for {}", taskAttemptId, + if (released > 0 && mode == tungstenMemoryMode) { + logger.debug("Task {} released {} from {} for {}", taskAttemptId, Utils.bytesToString(released), c, consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); if (got >= required) { break; } @@ -161,10 +173,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got < required && consumer != null) { try { long released = consumer.spill(required - got, consumer); - if (released > 0) { - logger.info("Task {} released {} from itself ({})", taskAttemptId, + if (released > 0 && mode == tungstenMemoryMode) { + logger.debug("Task {} released {} from itself ({})", taskAttemptId, Utils.bytesToString(released), consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); } } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); @@ -173,7 +185,9 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { } } - consumers.add(consumer); + if (consumer != null) { + consumers.add(consumer); + } logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); return got; } @@ -182,9 +196,9 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { /** * Release N bytes of execution memory for a MemoryConsumer. */ - public void releaseExecutionMemory(long size, MemoryConsumer consumer) { + public void releaseExecutionMemory(long size, MemoryMode mode, MemoryConsumer consumer) { logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer); - memoryManager.releaseExecutionMemory(size, taskAttemptId); + memoryManager.releaseExecutionMemory(size, taskAttemptId, mode); } /** @@ -193,11 +207,22 @@ public void releaseExecutionMemory(long size, MemoryConsumer consumer) { public void showMemoryUsage() { logger.info("Memory used in task " + taskAttemptId); synchronized (this) { + long memoryAccountedForByConsumers = 0; for (MemoryConsumer c: consumers) { - if (c.getUsed() > 0) { - logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed())); + long totalMemUsage = c.getUsed(); + memoryAccountedForByConsumers += totalMemUsage; + if (totalMemUsage > 0) { + logger.info("Acquired by " + c + ": " + Utils.bytesToString(totalMemUsage)); } } + long memoryNotAccountedFor = + memoryManager.getExecutionMemoryUsageForTask(taskAttemptId) - memoryAccountedForByConsumers; + logger.info( + "{} bytes of memory were used by task {} but are not associated with specific consumers", + memoryNotAccountedFor, taskAttemptId); + logger.info( + "{} bytes of memory are used for execution and {} bytes of memory are used for storage", + memoryManager.executionMemoryUsed(), memoryManager.storageMemoryUsed()); } } @@ -212,7 +237,8 @@ public long pageSizeBytes() { * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is * intended for allocating large blocks of Tungsten memory that will be shared between operators. * - * Returns `null` if there was not enough memory to allocate the page. + * Returns `null` if there was not enough memory to allocate the page. May return a page that + * contains fewer bytes than requested, so callers should verify the size of returned pages. */ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { if (size > MAXIMUM_PAGE_SIZE_BYTES) { @@ -220,7 +246,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } - long acquired = acquireExecutionMemory(size, consumer); + long acquired = acquireExecutionMemory(size, tungstenMemoryMode, consumer); if (acquired <= 0) { return null; } @@ -229,13 +255,26 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { synchronized (this) { pageNumber = allocatedPages.nextClearBit(0); if (pageNumber >= PAGE_TABLE_SIZE) { - releaseExecutionMemory(acquired, consumer); + releaseExecutionMemory(acquired, tungstenMemoryMode, consumer); throw new IllegalStateException( "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); } allocatedPages.set(pageNumber); } - final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired); + MemoryBlock page = null; + try { + page = memoryManager.tungstenMemoryAllocator().allocate(acquired); + } catch (OutOfMemoryError e) { + logger.warn("Failed to allocate a page ({} bytes), try again.", acquired); + // there is no enough memory actually, it means the actual free memory is smaller than + // MemoryManager thought, we should keep the acquired memory. + synchronized (this) { + acquiredButNotUsed += acquired; + allocatedPages.clear(pageNumber); + } + // this could trigger spilling to free some pages. + return allocatePage(size, consumer); + } page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { @@ -260,7 +299,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { } long pageSize = page.size(); memoryManager.tungstenMemoryAllocator().free(page); - releaseExecutionMemory(pageSize, consumer); + releaseExecutionMemory(pageSize, tungstenMemoryMode, consumer); } /** @@ -274,7 +313,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (!inHeap) { + if (tungstenMemoryMode == MemoryMode.OFF_HEAP) { // In off-heap mode, an offset is an absolute address that may require a full 64 bits to // encode. Due to our page size limitation, though, we can convert this into an offset that's // relative to the page's base offset; this relative offset will fit in 51 bits. @@ -291,7 +330,7 @@ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) @VisibleForTesting public static int decodePageNumber(long pagePlusOffsetAddress) { - return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + return (int) (pagePlusOffsetAddress >>> OFFSET_BITS); } private static long decodeOffset(long pagePlusOffsetAddress) { @@ -303,7 +342,7 @@ private static long decodeOffset(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public Object getPage(long pagePlusOffsetAddress) { - if (inHeap) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final MemoryBlock page = pageTable[pageNumber]; @@ -321,7 +360,7 @@ public Object getPage(long pagePlusOffsetAddress) { */ public long getOffsetInPage(long pagePlusOffsetAddress) { final long offsetInPage = decodeOffset(pagePlusOffsetAddress); - if (inHeap) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { return offsetInPage; } else { // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we @@ -349,11 +388,22 @@ public long cleanUpAllAllocatedMemory() { } consumers.clear(); } + + for (MemoryBlock page : pageTable) { + if (page != null) { + memoryManager.tungstenMemoryAllocator().free(page); + } + } + Arrays.fill(pageTable, null); + + // release the memory that is not used by any consumer. + memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode); + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); } /** - * Returns the memory consumption, in bytes, for the current task + * Returns the memory consumption, in bytes, for the current task. */ public long getMemoryConsumptionForThisTask() { return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index ee82d679935c0..7a60c3eb35740 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -98,7 +98,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { */ private boolean stopping = false; - public BypassMergeSortShuffleWriter( + BypassMergeSortShuffleWriter( BlockManager blockManager, IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, @@ -114,9 +114,8 @@ public BypassMergeSortShuffleWriter( this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); - this.writeMetrics = new ShuffleWriteMetrics(); - taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); - this.serializer = Serializer.getSerializer(dep.serializer()); + this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics(); + this.serializer = dep.serializer(); this.shuffleBlockResolver = shuffleBlockResolver; } @@ -125,7 +124,7 @@ public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { partitionLengths = new long[numPartitions]; - shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } @@ -138,12 +137,12 @@ public void write(Iterator> records) throws IOException { final File file = tempShuffleBlockIdPlusFile._2(); final BlockId blockId = tempShuffleBlockIdPlusFile._1(); partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, and can take a long time in aggregate when we open many files, so should be // included in the shuffle write time. - writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + writeMetrics.incWriteTime(System.nanoTime() - openStartTime); while (records.hasNext()) { final Product2 record = records.next(); @@ -155,9 +154,10 @@ public void write(Iterator> records) throws IOException { writer.commitAndClose(); } - partitionLengths = - writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId)); - shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + File tmp = Utils.tempFileWith(output); + partitionLengths = writePartitionedFile(tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @@ -184,22 +184,25 @@ private long[] writePartitionedFile(File outputFile) throws IOException { boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { - final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); - boolean copyThrewException = true; - try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - } - if (!partitionWriters[i].fileSegment().file().delete()) { - logger.error("Unable to delete file for partition {}", i); + final File file = partitionWriters[i].fileSegment().file(); + if (file.exists()) { + final FileInputStream in = new FileInputStream(file); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } } } threwException = false; } finally { Closeables.close(out, threwException); - writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; return lengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index f8f2b220e181d..f7a6c68be9156 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -17,8 +17,6 @@ package org.apache.spark.shuffle.sort; -import org.apache.spark.memory.TaskMemoryManager; - /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. *

@@ -28,7 +26,7 @@ *

* This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the - * 13-bit page numbers assigned by {@link TaskMemoryManager}), this + * 13-bit page numbers assigned by {@link org.apache.spark.memory.TaskMemoryManager}), this * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. *

* Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 400d8520019b9..3c2980e442ab7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -39,6 +39,7 @@ import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; @@ -83,9 +84,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager * itself). */ - private final LinkedList allocatedPages = new LinkedList(); + private final LinkedList allocatedPages = new LinkedList<>(); - private final LinkedList spills = new LinkedList(); + private final LinkedList spills = new LinkedList<>(); /** Peak memory used by this sorter so far, in bytes. **/ private long peakMemoryUsedBytes; @@ -95,7 +96,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { @Nullable private MemoryBlock currentPage = null; private long pageCursor = -1; - public ShuffleExternalSorter( + ShuffleExternalSorter( TaskMemoryManager memoryManager, BlockManager blockManager, TaskContext taskContext, @@ -114,8 +115,7 @@ public ShuffleExternalSorter( this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.writeMetrics = writeMetrics; - acquireMemory(initialSize * 8L); - this.inMemSorter = new ShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(this, initialSize); this.peakMemoryUsedBytes = getMemoryUsage(); } @@ -215,8 +215,6 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } } - inMemSorter.reset(); - if (!isLastFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter @@ -233,8 +231,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. - writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); - taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); + writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten()); } } @@ -255,6 +253,10 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { writeSortedFile(false); final long spillSize = freeMemory(); + inMemSorter.reset(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages, + // we might not be able to get memory for the pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); return spillSize; } @@ -301,9 +303,8 @@ private long freeMemory() { public void cleanupResources() { freeMemory(); if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(sorterMemoryUsage); } for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { @@ -321,26 +322,23 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - long needed = used + inMemSorter.getMemoryToExpand(); + LongArray array; try { - acquireMemory(needed); // could trigger spilling + // could trigger spilling + array = allocateArray(used / 8 * 2); } catch (OutOfMemoryError e) { // should have trigger spilling - assert(inMemSorter.hasSpaceForAnotherRecord()); + if (!inMemSorter.hasSpaceForAnotherRecord()) { + logger.error("Unable to grow the pointer array"); + throw e; + } return; } // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { - releaseMemory(needed); + freeArray(array); } else { - try { - inMemSorter.expandPointerArray(); - releaseMemory(used); - } catch (OutOfMemoryError oom) { - // Just in case that JVM had run out of memory - releaseMemory(needed); - spill(); - } + inMemSorter.expandPointerArray(array); } } } @@ -404,9 +402,8 @@ public SpillInfo[] closeAndGetSpills() throws IOException { // Do not count the final file towards the spill count. writeSortedFile(true); freeMemory(); - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(sorterMemoryUsage); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index e630575d1ae19..76b0e6a304ac2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -19,69 +19,86 @@ import java.util.Comparator; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; final class ShuffleInMemorySorter { - private final Sorter sorter; + private final Sorter sorter; private static final class SortComparator implements Comparator { @Override public int compare(PackedRecordPointer left, PackedRecordPointer right) { - return left.getPartitionId() - right.getPartitionId(); + int leftId = left.getPartitionId(); + int rightId = right.getPartitionId(); + return leftId < rightId ? -1 : (leftId > rightId ? 1 : 0); } } private static final SortComparator SORT_COMPARATOR = new SortComparator(); + private final MemoryConsumer consumer; + /** * An array of record pointers and partition ids that have been encoded by * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating * records. */ - private long[] array; + private LongArray array; /** * The position in the pointer array where new records can be inserted. */ private int pos = 0; - public ShuffleInMemorySorter(int initialSize) { + private int initialSize; + + ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { + this.consumer = consumer; assert (initialSize > 0); - this.array = new long[initialSize]; + this.initialSize = initialSize; + this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } + public void free() { + if (array != null) { + consumer.freeArray(array); + array = null; + } + } + public int numRecords() { return pos; } public void reset() { + if (consumer != null) { + consumer.freeArray(array); + this.array = consumer.allocateArray(initialSize); + } pos = 0; } - private int newLength() { - // Guard against overflow: - return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE; - } - - /** - * Returns the memory needed to expand - */ - public long getMemoryToExpand() { - return ((long) (newLength() - array.length)) * 8; - } - - public void expandPointerArray() { - final long[] oldArray = array; - array = new long[newLength()]; - System.arraycopy(oldArray, 0, array, 0, oldArray.length); + public void expandPointerArray(LongArray newArray) { + assert(newArray.size() > array.size()); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + array.size() * 8L + ); + consumer.freeArray(array); + array = newArray; } public boolean hasSpaceForAnotherRecord() { - return pos < array.length; + return pos < array.size(); } public long getMemoryUsage() { - return array.length * 8L; + return array.size() * 8L; } /** @@ -96,14 +113,9 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - if (array.length == Integer.MAX_VALUE) { - throw new IllegalStateException("Sort pointer array has reached maximum size"); - } else { - expandPointerArray(); - } + throw new IllegalStateException("There is no space for new record"); } - array[pos] = - PackedRecordPointer.packPointer(recordPointer, partitionId); + array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId)); pos++; } @@ -112,12 +124,12 @@ public void insertRecord(long recordPointer, int partitionId) { */ public static final class ShuffleSorterIterator { - private final long[] pointerArray; + private final LongArray pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public ShuffleSorterIterator(int numRecords, long[] pointerArray) { + ShuffleSorterIterator(int numRecords, LongArray pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -127,7 +139,7 @@ public boolean hasNext() { } public void loadNext() { - packedRecordPointer.set(pointerArray[position]); + packedRecordPointer.set(pointerArray.get(position)); position++; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 8a1e5aec6ff0e..8f4e3229976dc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,16 +17,19 @@ package org.apache.spark.shuffle.sort; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; -final class ShuffleSortDataFormat extends SortDataFormat { +final class ShuffleSortDataFormat extends SortDataFormat { public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); private ShuffleSortDataFormat() { } @Override - public PackedRecordPointer getKey(long[] data, int pos) { + public PackedRecordPointer getKey(LongArray data, int pos) { // Since we re-use keys, this method shouldn't be called. throw new UnsupportedOperationException(); } @@ -37,31 +40,38 @@ public PackedRecordPointer newKey() { } @Override - public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { - reuse.set(data[pos]); + public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) { + reuse.set(data.get(pos)); return reuse; } @Override - public void swap(long[] data, int pos0, int pos1) { - final long temp = data[pos0]; - data[pos0] = data[pos1]; - data[pos1] = temp; + public void swap(LongArray data, int pos0, int pos1) { + final long temp = data.get(pos0); + data.set(pos0, data.get(pos1)); + data.set(pos1, temp); } @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos] = src[srcPos]; + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos, src.get(srcPos)); } @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos, dst, dstPos, length); + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 8, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 8, + length * 8 + ); } @Override - public long[] allocate(int length) { - return new long[length]; + public LongArray allocate(int length) { + // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap. + return new LongArray(MemoryBlock.fromLongArray(new long[length])); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java index df9f7b7abe028..865def6b83c53 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java @@ -29,7 +29,7 @@ final class SpillInfo { final File file; final TempShuffleBlockId blockId; - public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { this.partitionLengths = new long[numPartitions]; this.file = file; this.blockId = blockId; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 6a0a89e81c321..0c5fb883a8326 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -25,7 +25,6 @@ import scala.Option; import scala.Product2; import scala.collection.JavaConverters; -import scala.collection.immutable.Map; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -41,19 +40,18 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; -import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; -import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -83,7 +81,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { - public MyByteArrayOutputStream(int size) { super(size); } + MyByteArrayOutputStream(int size) { super(size); } public byte[] getBuf() { return buf; } } @@ -109,7 +107,8 @@ public UnsafeShuffleWriter( if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions"); + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; @@ -117,10 +116,9 @@ public UnsafeShuffleWriter( this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); this.shuffleId = dep.shuffleId(); - this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); + this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); - this.writeMetrics = new ShuffleWriteMetrics(); - taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics(); this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); @@ -206,8 +204,10 @@ void closeAndWriteOutput() throws IOException { final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; final long[] partitionLengths; + final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final File tmp = Utils.tempFileWith(output); try { - partitionLengths = mergeSpills(spills); + partitionLengths = mergeSpills(spills, tmp); } finally { for (SpillInfo spill : spills) { if (spill.file.exists() && ! spill.file.delete()) { @@ -215,7 +215,7 @@ void closeAndWriteOutput() throws IOException { } } } - shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @@ -248,8 +248,7 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills) throws IOException { - final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = @@ -297,8 +296,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs // to be counted as shuffle write, but this will lead to double-counting of the final // SpillInfo's bytes. - writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); - writeMetrics.incShuffleBytesWritten(outputFile.length()); + writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incBytesWritten(outputFile.length()); return partitionLengths; } } catch (IOException e) { @@ -410,7 +409,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th spillInputChannelPositions[i] += actualBytesTransferred; bytesToTransfer -= actualBytesTransferred; } - writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); bytesWrittenToMergedFile += partitionLengthInSpill; partitionLengths[partition] += partitionLengthInSpill; } @@ -444,13 +443,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th @Override public Option stop(boolean success) { try { - // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) - Map> internalAccumulators = - taskContext.internalMetricsToAccumulators(); - if (internalAccumulators != null) { - internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) - .add(getPeakMemoryUsedBytes()); - } + taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); if (stopping) { return Option.apply(null); diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java index f19ed01d5aebf..9307eb93a5b20 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -28,8 +28,8 @@ public enum TaskSorting { DECREASING_RUNTIME("-runtime"); private final Set alternateNames; - private TaskSorting(String... names) { - alternateNames = new HashSet(); + TaskSorting(String... names) { + alternateNames = new HashSet<>(); for (String n: names) { alternateNames.add(n); } diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java index dc2aa30466cc6..5d0555a8c28e1 100644 --- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -42,34 +42,34 @@ public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream o public void write(int b) throws IOException { final long startTime = System.nanoTime(); outputStream.write(b); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } @Override public void write(byte[] b) throws IOException { final long startTime = System.nanoTime(); outputStream.write(b); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } @Override public void write(byte[] b, int off, int len) throws IOException { final long startTime = System.nanoTime(); outputStream.write(b, off, len); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } @Override public void flush() throws IOException { final long startTime = System.nanoTime(); outputStream.flush(); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } @Override public void close() throws IOException { final long startTime = System.nanoTime(); outputStream.close(); - writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + writeMetrics.incWriteTime(System.nanoTime() - startTime); } } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 6656fd1d0bc59..6807710f9fef1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -20,11 +20,11 @@ import javax.annotation.Nullable; import java.io.File; import java.io.IOException; -import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,13 +32,13 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter; @@ -56,17 +56,16 @@ * Bytes 4 to 8: len(k) * Bytes 8 to 8 + len(k): key data * Bytes 8 + len(k) to 8 + len(k) + len(v): value data + * Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair * * This means that the first four bytes store the entire record (key + value) length. This format - * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, + * is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, * so we can pass records from this map directly into the sorter to sort records in place. */ public final class BytesToBytesMap extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); - private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); - private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; private final TaskMemoryManager taskMemoryManager; @@ -134,7 +133,12 @@ public final class BytesToBytesMap extends MemoryConsumer { /** * Number of keys defined in the map. */ - private int numElements; + private int numKeys; + + /** + * Number of values defined in the map. A key could have multiple values. + */ + private int numValues; /** * The map will be expanded once the number of keys exceeds this threshold. @@ -166,12 +170,14 @@ public final class BytesToBytesMap extends MemoryConsumer { private long peakMemoryUsedBytes = 0L; private final BlockManager blockManager; + private final SerializerManager serializerManager; private volatile MapIterator destructiveIterator = null; private LinkedList spillWriters = new LinkedList<>(); public BytesToBytesMap( TaskMemoryManager taskMemoryManager, BlockManager blockManager, + SerializerManager serializerManager, int initialCapacity, double loadFactor, long pageSizeBytes, @@ -179,6 +185,7 @@ public BytesToBytesMap( super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; + this.serializerManager = serializerManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -212,6 +219,7 @@ public BytesToBytesMap( this( taskMemoryManager, SparkEnv.get() != null ? SparkEnv.get().blockManager() : null, + SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null, initialCapacity, 0.70, pageSizeBytes, @@ -221,7 +229,12 @@ public BytesToBytesMap( /** * Returns the number of keys defined in the map. */ - public int numElements() { return numElements; } + public int numKeys() { return numKeys; } + + /** + * Returns the number of values defined in the map. A key could have multiple values. + */ + public int numValues() { return numValues; } public final class MapIterator implements Iterator { @@ -273,7 +286,8 @@ private void advanceToNextPage() { } } try { - reader = spillWriters.getFirst().getReader(blockManager); + Closeables.close(reader, /* swallowIOException = */ false); + reader = spillWriters.getFirst().getReader(serializerManager); recordsInPage = -1; } catch (IOException e) { // Scala iterator does not handle exception @@ -308,7 +322,8 @@ public Location next() { if (currentPage != null) { int totalLength = Platform.getInt(pageBaseObject, offsetInPage); loc.with(currentPage, offsetInPage); - offsetInPage += 4 + totalLength; + // [total size] [key size] [key] [value] [pointer to next] + offsetInPage += 4 + totalLength + 8; recordsInPage --; return loc; } else { @@ -319,6 +334,11 @@ public Location next() { try { reader.loadNext(); } catch (IOException e) { + try { + reader.close(); + } catch(IOException e2) { + logger.error("Error while closing spill reader", e2); + } // Scala iterator does not handle exception Platform.throwException(e); } @@ -353,7 +373,7 @@ public long spill(long numBytes) throws IOException { while (numRecords > 0) { int length = Platform.getInt(base, offset); writer.write(base, offset + 4, length, 0); - offset += 4 + length; + offset += 4 + length + 8; numRecords--; } writer.close(); @@ -387,7 +407,7 @@ public void remove() { * `lookup()`, the behavior of the returned iterator is undefined. */ public MapIterator iterator() { - return new MapIterator(numElements, loc, false); + return new MapIterator(numValues, loc, false); } /** @@ -401,7 +421,7 @@ public MapIterator iterator() { * `lookup()`, the behavior of the returned iterator is undefined. */ public MapIterator destructiveIterator() { - return new MapIterator(numElements, loc, true); + return new MapIterator(numValues, loc, true); } /** @@ -411,7 +431,19 @@ public MapIterator destructiveIterator() { * This function always return the same {@link Location} instance to avoid object allocation. */ public Location lookup(Object keyBase, long keyOffset, int keyLength) { - safeLookup(keyBase, keyOffset, keyLength, loc); + safeLookup(keyBase, keyOffset, keyLength, loc, + Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42)); + return loc; + } + + /** + * Looks up a key, and return a {@link Location} handle that can be used to test existence + * and read/write values. + * + * This function always return the same {@link Location} instance to avoid object allocation. + */ + public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) { + safeLookup(keyBase, keyOffset, keyLength, loc, hash); return loc; } @@ -420,14 +452,13 @@ public Location lookup(Object keyBase, long keyOffset, int keyLength) { * * This is a thread-safe version of `lookup`, could be used by multiple threads. */ - public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) { + public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) { assert(longArray != null); if (enablePerfMetrics) { numKeyLookups++; } - final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength); - int pos = hashcode & mask; + int pos = hash & mask; int step = 1; while (true) { if (enablePerfMetrics) { @@ -435,22 +466,19 @@ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location l } if (longArray.get(pos * 2) == 0) { // This is a new key. - loc.with(pos, hashcode, false); + loc.with(pos, hash, false); return; } else { long stored = longArray.get(pos * 2 + 1); - if ((int) (stored) == hashcode) { + if ((int) (stored) == hash) { // Full hash code matches. Let's compare the keys for equality. - loc.with(pos, hashcode, true); + loc.with(pos, hash, true); if (loc.getKeyLength() == keyLength) { - final MemoryLocation keyAddress = loc.getKeyAddress(); - final Object storedkeyBase = keyAddress.getBaseObject(); - final long storedkeyOffset = keyAddress.getBaseOffset(); final boolean areEqual = ByteArrayMethods.arrayEquals( keyBase, keyOffset, - storedkeyBase, - storedkeyOffset, + loc.getKeyBase(), + loc.getKeyOffset(), keyLength ); if (areEqual) { @@ -478,13 +506,14 @@ public final class Location { private boolean isDefined; /** * The hashcode of the most recent key passed to - * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to - * avoid re-hashing the key when storing a value for that key. + * {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us + * to avoid re-hashing the key when storing a value for that key. */ private int keyHashcode; - private final MemoryLocation keyMemoryLocation = new MemoryLocation(); - private final MemoryLocation valueMemoryLocation = new MemoryLocation(); + private Object baseObject; // the base object for key and value + private long keyOffset; private int keyLength; + private long valueOffset; private int valueLength; /** @@ -498,18 +527,15 @@ private void updateAddressesAndSizes(long fullKeyAddress) { taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(final Object base, final long offset) { - long position = offset; - final int totalLength = Platform.getInt(base, position); - position += 4; - keyLength = Platform.getInt(base, position); - position += 4; + private void updateAddressesAndSizes(final Object base, long offset) { + baseObject = base; + final int totalLength = Platform.getInt(base, offset); + offset += 4; + keyLength = Platform.getInt(base, offset); + offset += 4; + keyOffset = offset; + valueOffset = offset + keyLength; valueLength = totalLength - keyLength - 4; - - keyMemoryLocation.setObjAndOffset(base, position); - - position += keyLength; - valueMemoryLocation.setObjAndOffset(base, position); } private Location with(int pos, int keyHashcode, boolean isDefined) { @@ -537,13 +563,28 @@ private Location with(MemoryBlock page, long offsetInPage) { private Location with(Object base, long offset, int length) { this.isDefined = true; this.memoryPage = null; + baseObject = base; + keyOffset = offset + 4; keyLength = Platform.getInt(base, offset); + valueOffset = offset + 4 + keyLength; valueLength = length - 4 - keyLength; - keyMemoryLocation.setObjAndOffset(base, offset + 4); - valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength); return this; } + /** + * Find the next pair that has the same key as current one. + */ + public boolean nextValue() { + assert isDefined; + long nextAddr = Platform.getLong(baseObject, valueOffset + valueLength); + if (nextAddr == 0) { + return false; + } else { + updateAddressesAndSizes(nextAddr); + return true; + } + } + /** * Returns the memory page that contains the current record. * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}. @@ -560,34 +601,44 @@ public boolean isDefined() { } /** - * Returns the address of the key defined at this position. - * This points to the first byte of the key data. - * Unspecified behavior if the key is not defined. - * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + * Returns the base object for key. */ - public MemoryLocation getKeyAddress() { + public Object getKeyBase() { assert (isDefined); - return keyMemoryLocation; + return baseObject; } /** - * Returns the length of the key defined at this position. - * Unspecified behavior if the key is not defined. + * Returns the offset for key. */ - public int getKeyLength() { + public long getKeyOffset() { assert (isDefined); - return keyLength; + return keyOffset; + } + + /** + * Returns the base object for value. + */ + public Object getValueBase() { + assert (isDefined); + return baseObject; + } + + /** + * Returns the offset for value. + */ + public long getValueOffset() { + assert (isDefined); + return valueOffset; } /** - * Returns the address of the value defined at this position. - * This points to the first byte of the value data. + * Returns the length of the key defined at this position. * Unspecified behavior if the key is not defined. - * For efficiency reasons, calls to this method always returns the same MemoryLocation object. */ - public MemoryLocation getValueAddress() { + public int getKeyLength() { assert (isDefined); - return valueMemoryLocation; + return keyLength; } /** @@ -600,10 +651,9 @@ public int getValueLength() { } /** - * Store a new key and value. This method may only be called once for a given key; if you want - * to update the value associated with a key, then you can directly manipulate the bytes stored - * at the value address. The return value indicates whether the put succeeded or whether it - * failed because additional memory could not be acquired. + * Append a new value for the key. This method could be called multiple times for a given key. + * The return value indicates whether the put succeeded or whether it failed because additional + * memory could not be acquired. *

* It is only valid to call this method immediately after calling `lookup()` using the same key. *

@@ -612,7 +662,7 @@ public int getValueLength() { *

*

* After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` - * will return information on the data stored by this `putNewKey` call. + * will return information on the data stored by this `append` call. *

*

* As an example usage, here's the proper way to store a new key: @@ -620,7 +670,7 @@ public int getValueLength() { *

      *   Location loc = map.lookup(keyBase, keyOffset, keyLength);
      *   if (!loc.isDefined()) {
-     *     if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
+     *     if (!loc.append(keyBase, keyOffset, keyLength, ...)) {
      *       // handle failure to grow map (by spilling, for example)
      *     }
      *   }
@@ -632,26 +682,23 @@ public int getValueLength() {
      * @return true if the put() was successful and false if the put() failed because memory could
      *         not be acquired.
      */
-    public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
-        Object valueBase, long valueOffset, int valueLength) {
-      assert (!isDefined) : "Can only set value once for a key";
-      assert (keyLength % 8 == 0);
-      assert (valueLength % 8 == 0);
-      assert(longArray != null);
+    public boolean append(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) {
+      assert (klen % 8 == 0);
+      assert (vlen % 8 == 0);
+      assert (longArray != null);
 
-
-      if (numElements == MAX_CAPACITY
+      if (numKeys == MAX_CAPACITY
         // The map could be reused from last spill (because of no enough memory to grow),
         // then we don't try to grow again if hit the `growthThreshold`.
-        || !canGrowArray && numElements > growthThreshold) {
+        || !canGrowArray && numKeys > growthThreshold) {
         return false;
       }
 
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
-      // (8 byte key length) (key) (value)
-      final long recordLength = 8 + keyLength + valueLength;
+      // (8 byte key length) (key) (value) (8 byte pointer to next value)
+      final long recordLength = 8 + klen + vlen + 8;
       if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
         if (!acquireNewPage(recordLength + 4L)) {
           return false;
@@ -662,30 +709,36 @@ public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
       final Object base = currentPage.getBaseObject();
       long offset = currentPage.getBaseOffset() + pageCursor;
       final long recordOffset = offset;
-      Platform.putInt(base, offset, keyLength + valueLength + 4);
-      Platform.putInt(base, offset + 4, keyLength);
+      Platform.putInt(base, offset, klen + vlen + 4);
+      Platform.putInt(base, offset + 4, klen);
       offset += 8;
-      Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
-      offset += keyLength;
-      Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
-
-      // --- Update bookkeeping data structures -----------------------------------------------------
+      Platform.copyMemory(kbase, koff, base, offset, klen);
+      offset += klen;
+      Platform.copyMemory(vbase, voff, base, offset, vlen);
+      offset += vlen;
+      // put this value at the beginning of the list
+      Platform.putLong(base, offset, isDefined ? longArray.get(pos * 2) : 0);
+
+      // --- Update bookkeeping data structures ----------------------------------------------------
       offset = currentPage.getBaseOffset();
       Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
       pageCursor += recordLength;
-      numElements++;
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
         currentPage, recordOffset);
       longArray.set(pos * 2, storedKeyAddress);
-      longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
-      isDefined = true;
+      numValues++;
+      if (!isDefined) {
+        numKeys++;
+        longArray.set(pos * 2 + 1, keyHashcode);
+        isDefined = true;
 
-      if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
-        try {
-          growAndRehash();
-        } catch (OutOfMemoryError oom) {
-          canGrowArray = false;
+        if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) {
+          try {
+            growAndRehash();
+          } catch (OutOfMemoryError oom) {
+            canGrowArray = false;
+          }
         }
       }
       return true;
@@ -724,11 +777,10 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
    */
   private void allocate(int capacity) {
     assert (capacity >= 0);
-    // The capacity needs to be divisible by 64 so that our bit set can be sized properly
     capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
     assert (capacity <= MAX_CAPACITY);
-    acquireMemory(capacity * 16);
-    longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
+    longArray = allocateArray(capacity * 2);
+    longArray.zeroOut();
 
     this.growthThreshold = (int) (capacity * loadFactor);
     this.mask = capacity - 1;
@@ -743,9 +795,8 @@ private void allocate(int capacity) {
   public void free() {
     updatePeakMemoryUsed();
     if (longArray != null) {
-      long used = longArray.memoryBlock().size();
+      freeArray(longArray);
       longArray = null;
-      releaseMemory(used);
     }
     Iterator dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
@@ -834,17 +885,19 @@ public int getNumDataPages() {
   /**
    * Returns the underline long[] of longArray.
    */
-  public long[] getArray() {
+  public LongArray getArray() {
     assert(longArray != null);
-    return (long[]) longArray.memoryBlock().getBaseObject();
+    return longArray;
   }
 
   /**
    * Reset this map to initialized state.
    */
   public void reset() {
-    numElements = 0;
-    Arrays.fill(getArray(), 0);
+    numKeys = 0;
+    numValues = 0;
+    longArray.zeroOut();
+
     while (dataPages.size() > 0) {
       MemoryBlock dataPage = dataPages.removeLast();
       freePage(dataPage);
@@ -887,7 +940,7 @@ void growAndRehash() {
       longArray.set(newPos * 2, keyPointer);
       longArray.set(newPos * 2 + 1, hashcode);
     }
-    releaseMemory(oldLongArray.memoryBlock().size());
+    freeArray(oldLongArray);
 
     if (enablePerfMetrics) {
       timeSpentResizingNs += System.nanoTime() - resizeStartTime;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index d2bf297c6c178..c2a8f429beca4 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -20,7 +20,6 @@
 import com.google.common.primitives.UnsignedLongs;
 
 import org.apache.spark.annotation.Private;
-import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.types.ByteArray;
 import org.apache.spark.unsafe.types.UTF8String;
 import org.apache.spark.util.Utils;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
index dbf6770e07391..de92b8db47131 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -17,11 +17,9 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
-import org.apache.spark.memory.TaskMemoryManager;
-
 final class RecordPointerAndKeyPrefix {
   /**
-   * A pointer to a record; see {@link TaskMemoryManager} for a
+   * A pointer to a record; see {@link org.apache.spark.memory.TaskMemoryManager} for a
    * description of how these addresses are encoded.
    */
   public long recordPointer;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index cba043bc48cc8..ef79b49083479 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -21,6 +21,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
+import java.util.Queue;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
@@ -30,8 +31,10 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.TaskCompletionListener;
 import org.apache.spark.util.Utils;
@@ -43,10 +46,13 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
 
+  @Nullable
   private final PrefixComparator prefixComparator;
+  @Nullable
   private final RecordComparator recordComparator;
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
+  private final SerializerManager serializerManager;
   private final TaskContext taskContext;
   private ShuffleWriteMetrics writeMetrics;
 
@@ -74,6 +80,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
   public static UnsafeExternalSorter createWithExistingInMemorySorter(
       TaskMemoryManager taskMemoryManager,
       BlockManager blockManager,
+      SerializerManager serializerManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
@@ -81,7 +88,8 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
       long pageSizeBytes,
       UnsafeInMemorySorter inMemorySorter) throws IOException {
     UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
-      taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+      serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
+        pageSizeBytes, inMemorySorter);
     sorter.spill(Long.MAX_VALUE, sorter);
     // The external sorter will be used to insert records, in-memory sorter is not needed.
     sorter.inMemSorter = null;
@@ -91,18 +99,20 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
   public static UnsafeExternalSorter create(
       TaskMemoryManager taskMemoryManager,
       BlockManager blockManager,
+      SerializerManager serializerManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes) {
-    return new UnsafeExternalSorter(taskMemoryManager, blockManager,
+    return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
   }
 
   private UnsafeExternalSorter(
       TaskMemoryManager taskMemoryManager,
       BlockManager blockManager,
+      SerializerManager serializerManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
@@ -112,20 +122,18 @@ private UnsafeExternalSorter(
     super(taskMemoryManager, pageSizeBytes);
     this.taskMemoryManager = taskMemoryManager;
     this.blockManager = blockManager;
+    this.serializerManager = serializerManager;
     this.taskContext = taskContext;
     this.recordComparator = recordComparator;
     this.prefixComparator = prefixComparator;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
     // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.fileBufferSizeBytes = 32 * 1024;
-    // TODO: metrics tracking + integration with shuffle write metrics
-    // need to connect the write metrics to task metrics so we count the spill IO somewhere.
-    this.writeMetrics = new ShuffleWriteMetrics();
+    this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics();
 
     if (existingInMemorySorter == null) {
-      this.inMemSorter =
-        new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
-      acquireMemory(inMemSorter.getMemoryUsage());
+      this.inMemSorter = new UnsafeInMemorySorter(
+        this, taskMemoryManager, recordComparator, prefixComparator, initialSize);
     } else {
       this.inMemSorter = existingInMemorySorter;
     }
@@ -192,14 +200,17 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
         spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
       }
       spillWriter.close();
-
-      inMemSorter.reset();
     }
 
     final long spillSize = freeMemory();
     // Note that this is more-or-less going to be a multiple of the page size, so wasted space in
     // pages will currently be counted as memory spilled even though that space isn't actually
     // written to disk. This also counts the space needed to store the sorter's pointer array.
+    inMemSorter.reset();
+    // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
+    // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages,
+    // we might not be able to get memory for the pointer array.
+
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
 
     return spillSize;
@@ -277,9 +288,8 @@ public void cleanupResources() {
       deleteSpillFiles();
       freeMemory();
       if (inMemSorter != null) {
-        long used = inMemSorter.getMemoryUsage();
+        inMemSorter.free();
         inMemSorter = null;
-        releaseMemory(used);
       }
     }
   }
@@ -293,26 +303,23 @@ private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       long used = inMemSorter.getMemoryUsage();
-      long needed = used + inMemSorter.getMemoryToExpand();
+      LongArray array;
       try {
-        acquireMemory(needed);  // could trigger spilling
+        // could trigger spilling
+        array = allocateArray(used / 8 * 2);
       } catch (OutOfMemoryError e) {
         // should have trigger spilling
-        assert(inMemSorter.hasSpaceForAnotherRecord());
+        if (!inMemSorter.hasSpaceForAnotherRecord()) {
+          logger.error("Unable to grow the pointer array");
+          throw e;
+        }
         return;
       }
       // check if spilling is triggered or not
       if (inMemSorter.hasSpaceForAnotherRecord()) {
-        releaseMemory(needed);
+        freeArray(array);
       } else {
-        try {
-          inMemSorter.expandPointerArray();
-          releaseMemory(used);
-        } catch (OutOfMemoryError oom) {
-          // Just in case that JVM had run out of memory
-          releaseMemory(needed);
-          spill();
-        }
+        inMemSorter.expandPointerArray(array);
       }
     }
   }
@@ -406,6 +413,7 @@ public void merge(UnsafeExternalSorter other) throws IOException {
    * after consuming this iterator.
    */
   public UnsafeSorterIterator getSortedIterator() throws IOException {
+    assert(recordComparator != null);
     if (spillWriters.isEmpty()) {
       assert(inMemSorter != null);
       readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -414,7 +422,7 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
       final UnsafeSorterSpillMerger spillMerger =
         new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
       for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
-        spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
+        spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
       }
       if (inMemSorter != null) {
         readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -434,9 +442,13 @@ class SpillableIterator extends UnsafeSorterIterator {
     private boolean loaded = false;
     private int numRecords = 0;
 
-    public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+    SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
       this.upstream = inMemIterator;
-      this.numRecords = inMemIterator.numRecordsLeft();
+      this.numRecords = inMemIterator.getNumRecords();
+    }
+
+    public int getNumRecords() {
+      return numRecords;
     }
 
     public long spill() throws IOException {
@@ -449,6 +461,7 @@ public long spill() throws IOException {
         UnsafeInMemorySorter.SortedIterator inMemIterator =
           ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
 
+        // Iterate over the records that have not been returned and spill them.
         final UnsafeSorterSpillWriter spillWriter =
           new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
         while (inMemIterator.hasNext()) {
@@ -460,13 +473,15 @@ public long spill() throws IOException {
         }
         spillWriter.close();
         spillWriters.add(spillWriter);
-        nextUpstream = spillWriter.getReader(blockManager);
+        nextUpstream = spillWriter.getReader(serializerManager);
 
         long released = 0L;
         synchronized (UnsafeExternalSorter.this) {
-          // release the pages except the one that is used
+          // release the pages except the one that is used. There can still be a caller that
+          // is accessing the current record. We free this page in that caller's next loadNext()
+          // call.
           for (MemoryBlock page : allocatedPages) {
-            if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) {
+            if (!loaded || page.getBaseObject() != upstream.getBaseObject()) {
               released += page.size();
               freePage(page);
             } else {
@@ -475,6 +490,12 @@ public long spill() throws IOException {
           }
           allocatedPages.clear();
         }
+
+        // in-memory sorter will not be used after spilling
+        assert(inMemSorter != null);
+        released += inMemSorter.getMemoryUsage();
+        inMemSorter.free();
+        inMemSorter = null;
         return released;
       }
     }
@@ -496,11 +517,6 @@ public void loadNext() throws IOException {
           }
           upstream = nextUpstream;
           nextUpstream = null;
-
-          assert(inMemSorter != null);
-          long used = inMemSorter.getMemoryUsage();
-          inMemSorter = null;
-          releaseMemory(used);
         }
         numRecords--;
         upstream.loadNext();
@@ -527,4 +543,81 @@ public long getKeyPrefix() {
       return upstream.getKeyPrefix();
     }
   }
+
+  /**
+   * Returns a iterator, which will return the rows in the order as inserted.
+   *
+   * It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   *
+   * TODO: support forced spilling
+   */
+  public UnsafeSorterIterator getIterator() throws IOException {
+    if (spillWriters.isEmpty()) {
+      assert(inMemSorter != null);
+      return inMemSorter.getSortedIterator();
+    } else {
+      LinkedList queue = new LinkedList<>();
+      for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+        queue.add(spillWriter.getReader(serializerManager));
+      }
+      if (inMemSorter != null) {
+        queue.add(inMemSorter.getSortedIterator());
+      }
+      return new ChainedIterator(queue);
+    }
+  }
+
+  /**
+   * Chain multiple UnsafeSorterIterator together as single one.
+   */
+  static class ChainedIterator extends UnsafeSorterIterator {
+
+    private final Queue iterators;
+    private UnsafeSorterIterator current;
+    private int numRecords;
+
+    ChainedIterator(Queue iterators) {
+      assert iterators.size() > 0;
+      this.numRecords = 0;
+      for (UnsafeSorterIterator iter: iterators) {
+        this.numRecords += iter.getNumRecords();
+      }
+      this.iterators = iterators;
+      this.current = iterators.remove();
+    }
+
+    @Override
+    public int getNumRecords() {
+      return numRecords;
+    }
+
+    @Override
+    public boolean hasNext() {
+      while (!current.hasNext() && !iterators.isEmpty()) {
+        current = iterators.remove();
+      }
+      return current.hasNext();
+    }
+
+    @Override
+    public void loadNext() throws IOException {
+      while (!current.hasNext() && !iterators.isEmpty()) {
+        current = iterators.remove();
+      }
+      current.loadNext();
+    }
+
+    @Override
+    public Object getBaseObject() { return current.getBaseObject(); }
+
+    @Override
+    public long getBaseOffset() { return current.getBaseOffset(); }
+
+    @Override
+    public int getRecordLength() { return current.getRecordLength(); }
+
+    @Override
+    public long getKeyPrefix() { return current.getKeyPrefix(); }
+  }
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index d57213b9b8bfc..01eae0e8dc14c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,8 +19,12 @@
 
 import java.util.Comparator;
 
+import org.apache.avro.reflect.Nullable;
+
+import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.util.collection.Sorter;
 
 /**
@@ -62,41 +66,70 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
     }
   }
 
+  private final MemoryConsumer consumer;
   private final TaskMemoryManager memoryManager;
-  private final Sorter sorter;
+  @Nullable
+  private final Sorter sorter;
+  @Nullable
   private final Comparator sortComparator;
 
   /**
    * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at
    * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
    */
-  private long[] array;
+  private LongArray array;
 
   /**
    * The position in the sort buffer where new records can be inserted.
    */
   private int pos = 0;
 
+  private long initialSize;
+
   public UnsafeInMemorySorter(
+    final MemoryConsumer consumer,
     final TaskMemoryManager memoryManager,
     final RecordComparator recordComparator,
     final PrefixComparator prefixComparator,
     int initialSize) {
-    this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]);
+    this(consumer, memoryManager, recordComparator, prefixComparator,
+      consumer.allocateArray(initialSize * 2));
   }
 
   public UnsafeInMemorySorter(
+    final MemoryConsumer consumer,
       final TaskMemoryManager memoryManager,
       final RecordComparator recordComparator,
       final PrefixComparator prefixComparator,
-      long[] array) {
-    this.array = array;
+      LongArray array) {
+    this.consumer = consumer;
     this.memoryManager = memoryManager;
-    this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
-    this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+    this.initialSize = array.size();
+    if (recordComparator != null) {
+      this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
+      this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+    } else {
+      this.sorter = null;
+      this.sortComparator = null;
+    }
+    this.array = array;
+  }
+
+  /**
+   * Free the memory used by pointer array.
+   */
+  public void free() {
+    if (consumer != null) {
+      consumer.freeArray(array);
+      array = null;
+    }
   }
 
   public void reset() {
+    if (consumer != null) {
+      consumer.freeArray(array);
+      this.array = consumer.allocateArray(initialSize);
+    }
     pos = 0;
   }
 
@@ -107,26 +140,26 @@ public int numRecords() {
     return pos / 2;
   }
 
-  private int newLength() {
-    return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
-  }
-
-  public long getMemoryToExpand() {
-    return (long) (newLength() - array.length) * 8L;
-  }
-
   public long getMemoryUsage() {
-    return array.length * 8L;
+    return array.size() * 8L;
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos + 2 <= array.length;
+    return pos + 2 <= array.size();
   }
 
-  public void expandPointerArray() {
-    final long[] oldArray = array;
-    array = new long[newLength()];
-    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
+  public void expandPointerArray(LongArray newArray) {
+    if (newArray.size() < array.size()) {
+      throw new OutOfMemoryError("Not enough memory to grow pointer array");
+    }
+    Platform.copyMemory(
+      array.getBaseObject(),
+      array.getBaseOffset(),
+      newArray.getBaseObject(),
+      newArray.getBaseOffset(),
+      array.size() * 8L);
+    consumer.freeArray(array);
+    array = newArray;
   }
 
   /**
@@ -138,36 +171,30 @@ public void expandPointerArray() {
    */
   public void insertRecord(long recordPointer, long keyPrefix) {
     if (!hasSpaceForAnotherRecord()) {
-      expandPointerArray();
+      throw new IllegalStateException("There is no space for new record");
     }
-    array[pos] = recordPointer;
+    array.set(pos, recordPointer);
     pos++;
-    array[pos] = keyPrefix;
+    array.set(pos, keyPrefix);
     pos++;
   }
 
-  public static final class SortedIterator extends UnsafeSorterIterator {
+  public final class SortedIterator extends UnsafeSorterIterator implements Cloneable {
 
-    private final TaskMemoryManager memoryManager;
-    private final int sortBufferInsertPosition;
-    private final long[] sortBuffer;
-    private int position = 0;
+    private final int numRecords;
+    private int position;
     private Object baseObject;
     private long baseOffset;
     private long keyPrefix;
     private int recordLength;
 
-    private SortedIterator(
-        TaskMemoryManager memoryManager,
-        int sortBufferInsertPosition,
-        long[] sortBuffer) {
-      this.memoryManager = memoryManager;
-      this.sortBufferInsertPosition = sortBufferInsertPosition;
-      this.sortBuffer = sortBuffer;
+    private SortedIterator(int numRecords) {
+      this.numRecords = numRecords;
+      this.position = 0;
     }
 
-    public SortedIterator clone () {
-      SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer);
+    public SortedIterator clone() {
+      SortedIterator iter = new SortedIterator(numRecords);
       iter.position = position;
       iter.baseObject = baseObject;
       iter.baseOffset = baseOffset;
@@ -177,22 +204,23 @@ public SortedIterator clone () {
     }
 
     @Override
-    public boolean hasNext() {
-      return position < sortBufferInsertPosition;
+    public int getNumRecords() {
+      return numRecords;
     }
 
-    public int numRecordsLeft() {
-      return (sortBufferInsertPosition - position) / 2;
+    @Override
+    public boolean hasNext() {
+      return position / 2 < numRecords;
     }
 
     @Override
     public void loadNext() {
       // This pointer points to a 4-byte record length, followed by the record's bytes
-      final long recordPointer = sortBuffer[position];
+      final long recordPointer = array.get(position);
       baseObject = memoryManager.getPage(recordPointer);
       baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4;  // Skip over record length
       recordLength = Platform.getInt(baseObject, baseOffset - 4);
-      keyPrefix = sortBuffer[position + 1];
+      keyPrefix = array.get(position + 1);
       position += 2;
     }
 
@@ -214,7 +242,9 @@ public void loadNext() {
    * {@code next()} will return the same mutable object.
    */
   public SortedIterator getSortedIterator() {
-    sorter.sort(array, 0, pos / 2, sortComparator);
-    return new SortedIterator(memoryManager, pos, array);
+    if (sorter != null) {
+      sorter.sort(array, 0, pos / 2, sortComparator);
+    }
+    return new SortedIterator(pos / 2);
   }
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
index d09c728a7a638..12fb62fb77f0f 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 /**
@@ -26,14 +29,14 @@
  * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at
  * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
  */
-final class UnsafeSortDataFormat extends SortDataFormat {
+final class UnsafeSortDataFormat extends SortDataFormat {
 
   public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
 
   private UnsafeSortDataFormat() { }
 
   @Override
-  public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+  public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
     // Since we re-use keys, this method shouldn't be called.
     throw new UnsupportedOperationException();
   }
@@ -44,37 +47,44 @@ public RecordPointerAndKeyPrefix newKey() {
   }
 
   @Override
-  public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
-    reuse.recordPointer = data[pos * 2];
-    reuse.keyPrefix = data[pos * 2 + 1];
+  public RecordPointerAndKeyPrefix getKey(LongArray data, int pos,
+                                          RecordPointerAndKeyPrefix reuse) {
+    reuse.recordPointer = data.get(pos * 2);
+    reuse.keyPrefix = data.get(pos * 2 + 1);
     return reuse;
   }
 
   @Override
-  public void swap(long[] data, int pos0, int pos1) {
-    long tempPointer = data[pos0 * 2];
-    long tempKeyPrefix = data[pos0 * 2 + 1];
-    data[pos0 * 2] = data[pos1 * 2];
-    data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
-    data[pos1 * 2] = tempPointer;
-    data[pos1 * 2 + 1] = tempKeyPrefix;
+  public void swap(LongArray data, int pos0, int pos1) {
+    long tempPointer = data.get(pos0 * 2);
+    long tempKeyPrefix = data.get(pos0 * 2 + 1);
+    data.set(pos0 * 2, data.get(pos1 * 2));
+    data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1));
+    data.set(pos1 * 2, tempPointer);
+    data.set(pos1 * 2 + 1, tempKeyPrefix);
   }
 
   @Override
-  public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
-    dst[dstPos * 2] = src[srcPos * 2];
-    dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+  public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
+    dst.set(dstPos * 2, src.get(srcPos * 2));
+    dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1));
   }
 
   @Override
-  public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
-    System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+  public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
+    Platform.copyMemory(
+      src.getBaseObject(),
+      src.getBaseOffset() + srcPos * 16,
+      dst.getBaseObject(),
+      dst.getBaseOffset() + dstPos * 16,
+      length * 16);
   }
 
   @Override
-  public long[] allocate(int length) {
+  public LongArray allocate(int length) {
     assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
-    return new long[length * 2];
+    // This is used as temporary buffer, it's fine to allocate from JVM heap.
+    return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
   }
 
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
index 16ac2e8d821ba..1b3167fcc250c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -32,4 +32,6 @@ public abstract class UnsafeSorterIterator {
   public abstract int getRecordLength();
 
   public abstract long getKeyPrefix();
+
+  public abstract int getNumRecords();
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 3874a9f9cbdb6..01aed95878cf6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -23,9 +23,10 @@
 
 final class UnsafeSorterSpillMerger {
 
+  private int numRecords = 0;
   private final PriorityQueue priorityQueue;
 
-  public UnsafeSorterSpillMerger(
+  UnsafeSorterSpillMerger(
       final RecordComparator recordComparator,
       final PrefixComparator prefixComparator,
       final int numSpills) {
@@ -44,7 +45,7 @@ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
         }
       }
     };
-    priorityQueue = new PriorityQueue(numSpills, comparator);
+    priorityQueue = new PriorityQueue<>(numSpills, comparator);
   }
 
   /**
@@ -56,9 +57,10 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOExcept
       // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator
       // does not return wrong result because hasNext will returns true
       // at least priorityQueue.size() times. If we allow n spillReaders in the
-      // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator.
+      // priorityQueue, we will have n extra empty records in the result of UnsafeSorterIterator.
       spillReader.loadNext();
       priorityQueue.add(spillReader);
+      numRecords += spillReader.getNumRecords();
     }
   }
 
@@ -67,6 +69,11 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
 
       private UnsafeSorterIterator spillReader;
 
+      @Override
+      public int getNumRecords() {
+        return numRecords;
+      }
+
       @Override
       public boolean hasNext() {
         return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 039e940a357ea..1d588c37c5db0 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -20,27 +20,25 @@
 import java.io.*;
 
 import com.google.common.io.ByteStreams;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import com.google.common.io.Closeables;
 
+import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
-import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.Platform;
 
 /**
  * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
  * of the file format).
  */
-public final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
-  private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
+public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
 
-  private final File file;
   private InputStream in;
   private DataInputStream din;
 
   // Variables that change with every record read:
   private int recordLength;
   private long keyPrefix;
+  private int numRecords;
   private int numRecordsRemaining;
 
   private byte[] arr = new byte[1024 * 1024];
@@ -48,15 +46,24 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
 
   public UnsafeSorterSpillReader(
-      BlockManager blockManager,
+      SerializerManager serializerManager,
       File file,
       BlockId blockId) throws IOException {
     assert (file.length() > 0);
-    this.file = file;
     final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
-    this.in = blockManager.wrapForCompression(blockId, bs);
-    this.din = new DataInputStream(this.in);
-    numRecordsRemaining = din.readInt();
+    try {
+      this.in = serializerManager.wrapForCompression(blockId, bs);
+      this.din = new DataInputStream(this.in);
+      numRecords = numRecordsRemaining = din.readInt();
+    } catch (IOException e) {
+      Closeables.close(bs, /* swallowIOException = */ true);
+      throw e;
+    }
+  }
+
+  @Override
+  public int getNumRecords() {
+    return numRecords;
   }
 
   @Override
@@ -75,12 +82,7 @@ public void loadNext() throws IOException {
     ByteStreams.readFully(in, arr, 0, recordLength);
     numRecordsRemaining--;
     if (numRecordsRemaining == 0) {
-      in.close();
-      if (!file.delete() && file.exists()) {
-        logger.warn("Unable to delete spill file {}", file.getPath());
-      }
-      in = null;
-      din = null;
+      close();
     }
   }
 
@@ -103,4 +105,16 @@ public int getRecordLength() {
   public long getKeyPrefix() {
     return keyPrefix;
   }
+
+  @Override
+  public void close() throws IOException {
+   if (in != null) {
+     try {
+       in.close();
+     } finally {
+       in = null;
+       din = null;
+     }
+   }
+  }
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 234e21140a1dd..9ba760e8422f4 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -20,6 +20,7 @@
 import java.io.File;
 import java.io.IOException;
 
+import org.apache.spark.serializer.SerializerManager;
 import scala.Tuple2;
 
 import org.apache.spark.executor.ShuffleWriteMetrics;
@@ -144,7 +145,7 @@ public File getFile() {
     return file;
   }
 
-  public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
-    return new UnsafeSorterSpillReader(blockManager, file, blockId);
+  public UnsafeSorterSpillReader getReader(SerializerManager serializerManager) throws IOException {
+    return new UnsafeSorterSpillReader(serializerManager, file, blockId);
   }
 }
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
deleted file mode 100644
index c85abc35b93bf..0000000000000
--- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties
+++ /dev/null
@@ -1,33 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Set everything to be logged to the console
-log4j.rootCategory=WARN, console
-log4j.appender.console=org.apache.log4j.ConsoleAppender
-log4j.appender.console.target=System.err
-log4j.appender.console.layout=org.apache.log4j.PatternLayout
-log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-
-# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark-project.jetty=WARN
-log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
-log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
-log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
-
-# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
-log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
-log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index d44cc85dcbd82..89a7963a86d98 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -22,9 +22,14 @@ log4j.appender.console.target=System.err
 log4j.appender.console.layout=org.apache.log4j.PatternLayout
 log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
 
+# Set the default spark-shell log level to WARN. When running the spark-shell, the
+# log level for this class is used to overwrite the root logger's log level, so that
+# the user can have different defaults for the shell and regular Spark apps.
+log4j.logger.org.apache.spark.repl.Main=WARN
+
 # Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark-project.jetty=WARN
-log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.spark_project.jetty=WARN
+log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
 log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
 log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
 
diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
index 2d9262b972a59..6fe8136c87ae0 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js
@@ -1,4 +1,5 @@
-/* This is a custom version of dagre-d3 on top of v0.4.3. The full list of commits can be found at http://github.com/andrewor14/dagre-d3/ */!function(e){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=e();else if("function"==typeof define&&define.amd)define([],e);else{var f;"undefined"!=typeof window?f=window:"undefined"!=typeof global?f=global:"undefined"!=typeof self&&(f=self),f.dagreD3=e()}}(function(){var define,module,exports;return function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time("  buildLayoutGraph",function(){return buildLayoutGraph(g)});time("  runLayout",function(){runLayout(layoutGraph,time)});time("  updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time("    makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time("    removeSelfEdges",function(){removeSelfEdges(g)});time("    acyclic",function(){acyclic.run(g)});time("    nestingGraph.run",function(){nestingGraph.run(g)});time("    rank",function(){rank(util.asNonCompoundGraph(g))});time("    injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time("    removeEmptyRanks",function(){removeEmptyRanks(g)});time("    nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time("    normalizeRanks",function(){normalizeRanks(g)});time("    assignRankMinMax",function(){assignRankMinMax(g)});time("    removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time("    normalize.run",function(){normalize.run(g)});time("    parentDummyChains",function(){
-parentDummyChains(g)});time("    addBorderSegments",function(){addBorderSegments(g)});time("    order",function(){order(g)});time("    insertSelfEdges",function(){insertSelfEdges(g)});time("    adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time("    position",function(){position(g)});time("    positionSelfEdges",function(){positionSelfEdges(g)});time("    removeBorderNodes",function(){removeBorderNodes(g)});time("    normalize.undo",function(){normalize.undo(g)});time("    fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time("    undoCoordinateSystem",function(){coordinateSystem.undo(g)});time("    translateGraph",function(){translateGraph(g)});time("    assignNodeIntersects",function(){assignNodeIntersects(g)});time("    reversePoints",function(){reversePointsForReversedEdges(g)});time("    acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdx0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time("  buildLayoutGraph",function(){return buildLayoutGraph(g)});time("  runLayout",function(){runLayout(layoutGraph,time)});time("  updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time("    makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time("    removeSelfEdges",function(){removeSelfEdges(g)});time("    acyclic",function(){acyclic.run(g)});time("    nestingGraph.run",function(){nestingGraph.run(g)});time("    rank",function(){rank(util.asNonCompoundGraph(g))});time("    injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time("    removeEmptyRanks",function(){removeEmptyRanks(g)});time("    nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time("    normalizeRanks",function(){normalizeRanks(g)});time("    assignRankMinMax",function(){assignRankMinMax(g)});time("    removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time("    normalize.run",function(){
+normalize.run(g)});time("    parentDummyChains",function(){parentDummyChains(g)});time("    addBorderSegments",function(){addBorderSegments(g)});time("    order",function(){order(g)});time("    insertSelfEdges",function(){insertSelfEdges(g)});time("    adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time("    position",function(){position(g)});time("    positionSelfEdges",function(){positionSelfEdges(g)});time("    removeBorderNodes",function(){removeBorderNodes(g)});time("    normalize.undo",function(){normalize.undo(g)});time("    fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time("    undoCoordinateSystem",function(){coordinateSystem.undo(g)});time("    translateGraph",function(){translateGraph(g)});time("    assignNodeIntersects",function(){assignNodeIntersects(g)});time("    reversePoints",function(){reversePointsForReversedEdges(g)});time("    acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" 	\f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r","	":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{};
+})}function enterEdge(t,g,edge){var v=edge.v,w=edge.w;if(!g.hasEdge(v,w)){v=edge.w;w=edge.v}var vLabel=t.node(v),wLabel=t.node(w),tailLabel=vLabel,flip=false;if(vLabel.lim>wLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" 	\f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r","	":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){
+return isObject(prototype)?nativeCreate(prototype):{}}if(!nativeCreate){baseCreate=function(){function Object(){}return function(prototype){if(isObject(prototype)){Object.prototype=prototype;var result=new Object;Object.prototype=null}return result||context.Object()}}()}function baseCreateCallback(func,thisArg,argCount){if(typeof func!="function"){return identity}if(typeof thisArg=="undefined"||!("prototype"in func)){return func}var bindData=func.__bindData__;if(typeof bindData=="undefined"){if(support.funcNames){bindData=!func.name}bindData=bindData||!support.funcDecomp;if(!bindData){var source=fnToString.call(func);if(!support.funcNames){bindData=!reFuncName.test(source)}if(!bindData){bindData=reThis.test(source);setBindData(func,bindData)}}}if(bindData===false||bindData!==true&&bindData[1]&1){return func}switch(argCount){case 1:return function(value){return func.call(thisArg,value)};case 2:return function(a,b){return func.call(thisArg,a,b)};case 3:return function(value,index,collection){return func.call(thisArg,value,index,collection)};case 4:return function(accumulator,value,index,collection){return func.call(thisArg,accumulator,value,index,collection)}}return bind(func,thisArg)}function baseCreateWrapper(bindData){var func=bindData[0],bitmask=bindData[1],partialArgs=bindData[2],partialRightArgs=bindData[3],thisArg=bindData[4],arity=bindData[5];var isBind=bitmask&1,isBindKey=bitmask&2,isCurry=bitmask&4,isCurryBound=bitmask&8,key=func;function bound(){var thisBinding=isBind?thisArg:this;if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(partialRightArgs||isCurry){args||(args=slice(arguments));if(partialRightArgs){push.apply(args,partialRightArgs)}if(isCurry&&args.length=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])=largeArraySize&&indexOf===baseIndexOf,result=[];if(isLarge){var cache=createCache(values);if(cache){indexOf=cacheIndexOf;values=cache}else{isLarge=false}}while(++index-1}})}}stackA.pop();stackB.pop();if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseMerge(object,source,callback,stackA,stackB){(isArray(source)?forEach:forOwn)(source,function(source,key){var found,isArr,result=source,value=object[key];if(source&&((isArr=isArray(source))||isPlainObject(source))){var stackLength=stackA.length;while(stackLength--){if(found=stackA[stackLength]==source){value=stackB[stackLength];break}}if(!found){var isShallow;if(callback){result=callback(value,source);if(isShallow=typeof result!="undefined"){value=result}}if(!isShallow){value=isArr?isArray(value)?value:[]:isPlainObject(value)?value:{}}stackA.push(source);stackB.push(value);if(!isShallow){baseMerge(value,source,callback,stackA,stackB)}}}else{if(callback){result=callback(value,source);if(typeof result=="undefined"){result=source}}if(typeof result!="undefined"){value=result}}object[key]=value})}function baseRandom(min,max){return min+floor(nativeRandom()*(max-min+1))}function baseUniq(array,isSorted,callback){var index=-1,indexOf=getIndexOf(),length=array?array.length:0,result=[];var isLarge=!isSorted&&length>=largeArraySize&&indexOf===baseIndexOf,seen=callback||isLarge?getArray():result;if(isLarge){var cache=createCache(seen);indexOf=cacheIndexOf;seen=cache}while(++index":">",'"':""","'":"'"};var htmlUnescapes=invert(htmlEscapes);var reEscapedHtml=RegExp("("+keys(htmlUnescapes).join("|")+")","g"),reUnescapedHtml=RegExp("["+keys(htmlEscapes).join("")+"]","g");var assign=function(object,source,guard){var index,iterable=object,result=iterable;if(!iterable)return result;var args=arguments,argsIndex=0,argsLength=typeof guard=="number"?2:args.length;if(argsLength>3&&typeof args[argsLength-2]=="function"){var callback=baseCreateCallback(args[--argsLength-1],args[argsLength--],2)}else if(argsLength>2&&typeof args[argsLength-1]=="function"){callback=args[--argsLength]}while(++argsIndex3&&typeof args[length-2]=="function"){var callback=baseCreateCallback(args[--length-1],args[length--],2)}else if(length>2&&typeof args[length-1]=="function"){callback=args[--length]}var sources=slice(arguments,1,length),index=-1,stackA=getArray(),stackB=getArray();while(++index-1}else if(typeof length=="number"){result=(isString(collection)?collection.indexOf(target,fromIndex):indexOf(collection,target,fromIndex))>-1}else{forOwn(collection,function(value){if(++index>=fromIndex){return!(result=value===target)}})}return result}var countBy=createAggregator(function(result,value,key){hasOwnProperty.call(result,key)?result[key]++:result[key]=1});function every(collection,callback,thisArg){var result=true;callback=lodash.createCallback(callback,thisArg,3);var index=-1,length=collection?collection.length:0;if(typeof length=="number"){while(++indexresult){result=value}}}else{callback=callback==null&&isString(collection)?charAtCallback:lodash.createCallback(callback,thisArg,3);forEach(collection,function(value,index,collection){var current=callback(value,index,collection);if(current>computed){computed=current;result=value}})}return result}function min(collection,callback,thisArg){var computed=Infinity,result=computed;if(typeof callback!="function"&&thisArg&&thisArg[callback]===collection){callback=null}if(callback==null&&isArray(collection)){var index=-1,length=collection.length;while(++index=largeArraySize&&createCache(argsIndex?args[argsIndex]:seen))}}var array=args[0],index=-1,length=array?array.length:0,result=[];outer:while(++index>>1;callback(array[mid])1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index1?arguments:arguments[0],index=-1,length=array?max(pluck(array,"length")):0,result=Array(length<0?0:length);while(++index2?createWrapper(func,17,slice(arguments,2),null,thisArg):createWrapper(func,1,null,null,thisArg)}function bindAll(object){var funcs=arguments.length>1?baseFlatten(arguments,true,false,1):functions(object),index=-1,length=funcs.length;while(++index2?createWrapper(key,19,slice(arguments,2),null,object):createWrapper(key,3,null,null,object)}function compose(){var funcs=arguments,length=funcs.length;while(length--){if(!isFunction(funcs[length])){throw new TypeError}}return function(){var args=arguments,length=funcs.length;while(length--){args=[funcs[length].apply(this,args)]}return args[0]}}function curry(func,arity){arity=typeof arity=="number"?arity:+arity||func.length;return createWrapper(func,4,null,null,null,arity)}function debounce(func,wait,options){var args,maxTimeoutId,result,stamp,thisArg,timeoutId,trailingCall,lastCalled=0,maxWait=false,trailing=true;if(!isFunction(func)){throw new TypeError}wait=nativeMax(0,wait)||0;if(options===true){var leading=true;trailing=false}else if(isObject(options)){leading=options.leading;maxWait="maxWait"in options&&(nativeMax(wait,options.maxWait)||0);trailing="trailing"in options?options.trailing:trailing}var delayed=function(){var remaining=wait-(now()-stamp);if(remaining<=0){if(maxTimeoutId){clearTimeout(maxTimeoutId)}var isCalled=trailingCall;maxTimeoutId=timeoutId=trailingCall=undefined;if(isCalled){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}}else{timeoutId=setTimeout(delayed,remaining)}};var maxDelayed=function(){if(timeoutId){clearTimeout(timeoutId)}maxTimeoutId=timeoutId=trailingCall=undefined;if(trailing||maxWait!==wait){lastCalled=now();result=func.apply(thisArg,args);if(!timeoutId&&!maxTimeoutId){args=thisArg=null}}};return function(){args=arguments;stamp=now();thisArg=this;trailingCall=trailing&&(timeoutId||!leading);if(maxWait===false){var leadingCall=leading&&!timeoutId}else{if(!maxTimeoutId&&!leading){lastCalled=stamp}var remaining=maxWait-(stamp-lastCalled),isCalled=remaining<=0;if(isCalled){if(maxTimeoutId){maxTimeoutId=clearTimeout(maxTimeoutId)}lastCalled=stamp;result=func.apply(thisArg,args)}else if(!maxTimeoutId){maxTimeoutId=setTimeout(maxDelayed,remaining)}}if(isCalled&&timeoutId){timeoutId=clearTimeout(timeoutId)}else if(!timeoutId&&wait!==maxWait){timeoutId=setTimeout(delayed,wait)}if(leadingCall){isCalled=true;result=func.apply(thisArg,args)}if(isCalled&&!timeoutId&&!maxTimeoutId){args=thisArg=null}return result}}function defer(func){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,1);return setTimeout(function(){func.apply(undefined,args)},1)}function delay(func,wait){if(!isFunction(func)){throw new TypeError}var args=slice(arguments,2);return setTimeout(function(){func.apply(undefined,args)},wait)}function memoize(func,resolver){if(!isFunction(func)){throw new TypeError}var memoized=function(){var cache=memoized.cache,key=resolver?resolver.apply(this,arguments):keyPrefix+arguments[0];return hasOwnProperty.call(cache,key)?cache[key]:cache[key]=func.apply(this,arguments)};memoized.cache={};return memoized}function once(func){var ran,result;if(!isFunction(func)){throw new TypeError}return function(){if(ran){return result}ran=true;result=func.apply(this,arguments);func=null;return result}}function partial(func){return createWrapper(func,16,slice(arguments,1))}function partialRight(func){return createWrapper(func,32,null,slice(arguments,1))}function throttle(func,wait,options){var leading=true,trailing=true;if(!isFunction(func)){throw new TypeError}if(options===false){leading=false}else if(isObject(options)){leading="leading"in options?options.leading:leading;trailing="trailing"in options?options.trailing:trailing}debounceOptions.leading=leading;debounceOptions.maxWait=wait;debounceOptions.trailing=trailing;return debounce(func,wait,debounceOptions)}function wrap(value,wrapper){return createWrapper(wrapper,16,[value])}function constant(value){return function(){return value}}function createCallback(func,thisArg,argCount){var type=typeof func;if(func==null||type=="function"){return baseCreateCallback(func,thisArg,argCount)}if(type!="object"){return property(func)}var props=keys(func),key=props[0],a=func[key];if(props.length==1&&a===a&&!isObject(a)){return function(object){var b=object[key];return a===b&&(a!==0||1/a==1/b)}}return function(object){var length=props.length,result=false;while(length--){if(!(result=baseIsEqual(object[props[length]],func[props[length]],null,true))){break}}return result}}function escape(string){return string==null?"":String(string).replace(reUnescapedHtml,escapeHtmlChar)}function identity(value){return value}function mixin(object,source,options){var chain=true,methodNames=source&&functions(source);if(!source||!options&&!methodNames.length){if(options==null){options=source}ctor=lodashWrapper;source=object;object=lodash;methodNames=functions(source)}if(options===false){chain=false}else if(isObject(options)&&"chain"in options){chain=options.chain}var ctor=object,isFunc=isFunction(ctor);forEach(methodNames,function(methodName){var func=object[methodName]=source[methodName];if(isFunc){ctor.prototype[methodName]=function(){var chainAll=this.__chain__,value=this.__wrapped__,args=[value];push.apply(args,arguments);var result=func.apply(object,args);if(chain||chainAll){if(value===result&&isObject(result)){return this}result=new ctor(result);result.__chain__=chainAll}return result}}})}function noConflict(){context._=oldDash;return this}function noop(){}var now=isNative(now=Date.now)&&now||function(){return(new Date).getTime()};var parseInt=nativeParseInt(whitespace+"08")==8?nativeParseInt:function(value,radix){return nativeParseInt(isString(value)?value.replace(reLeadingSpacesAndZeros,""):value,radix||0)};function property(key){return function(object){return object[key]}}function random(min,max,floating){var noMin=min==null,noMax=max==null;if(floating==null){if(typeof min=="boolean"&&noMax){floating=min;min=1}else if(!noMax&&typeof max=="boolean"){floating=max;noMax=true}}if(noMin&&noMax){max=1}min=+min||0;if(noMax){max=min;min=0}else{max=+max||0}if(floating||min%1||max%1){var rand=nativeRandom();return nativeMin(min+rand*(max-min+parseFloat("1e-"+((rand+"").length-1))),max)}return baseRandom(min,max)}function result(object,key){if(object){var value=object[key];return isFunction(value)?object[key]():value}}function template(text,data,options){var settings=lodash.templateSettings;text=String(text||"");options=defaults({},options,settings);var imports=defaults({},options.imports,settings.imports),importsKeys=keys(imports),importsValues=values(imports);var isEvaluating,index=0,interpolate=options.interpolate||reNoMatch,source="__p += '";var reDelimiters=RegExp((options.escape||reNoMatch).source+"|"+interpolate.source+"|"+(interpolate===reInterpolate?reEsTemplate:reNoMatch).source+"|"+(options.evaluate||reNoMatch).source+"|$","g");text.replace(reDelimiters,function(match,escapeValue,interpolateValue,esTemplateValue,evaluateValue,offset){interpolateValue||(interpolateValue=esTemplateValue);source+=text.slice(index,offset).replace(reUnescapedString,escapeStringChar);if(escapeValue){source+="' +\n__e("+escapeValue+") +\n'"}if(evaluateValue){isEvaluating=true;source+="';\n"+evaluateValue+";\n__p += '"}if(interpolateValue){source+="' +\n((__t = ("+interpolateValue+")) == null ? '' : __t) +\n'"}index=offset+match.length;return match});source+="';\n";var variable=options.variable,hasVariable=variable;if(!hasVariable){variable="obj";source="with ("+variable+") {\n"+source+"\n}\n"}source=(isEvaluating?source.replace(reEmptyStringLeading,""):source).replace(reEmptyStringMiddle,"$1").replace(reEmptyStringTrailing,"$1;");source="function("+variable+") {\n"+(hasVariable?"":variable+" || ("+variable+" = {});\n")+"var __t, __p = '', __e = _.escape"+(isEvaluating?", __j = Array.prototype.join;\n"+"function print() { __p += __j.call(arguments, '') }\n":";\n")+source+"return __p\n}";var sourceURL="\n/*\n//# sourceURL="+(options.sourceURL||"/lodash/template/source["+templateCounter++ +"]")+"\n*/";try{var result=Function(importsKeys,"return "+source+sourceURL).apply(undefined,importsValues)}catch(e){e.source=source;throw e}if(data){return result(data)}result.source=source;return result}function times(n,callback,thisArg){n=(n=+n)>-1?n:0;var index=-1,result=Array(n);callback=baseCreateCallback(callback,thisArg,1);while(++index tr > th {
+	padding-left: 18px;
+	padding-right: 18px;
+}
+
+table.dataTable th:active {
+	outline: none;
+}
+
+/* Scrolling */
+div.dataTables_scrollHead table {
+	margin-bottom: 0 !important;
+	border-bottom-left-radius: 0;
+	border-bottom-right-radius: 0;
+}
+
+div.dataTables_scrollHead table thead tr:last-child th:first-child,
+div.dataTables_scrollHead table thead tr:last-child td:first-child {
+	border-bottom-left-radius: 0 !important;
+	border-bottom-right-radius: 0 !important;
+}
+
+div.dataTables_scrollBody table {
+	border-top: none;
+	margin-top: 0 !important;
+	margin-bottom: 0 !important;
+}
+
+div.dataTables_scrollBody tbody tr:first-child th,
+div.dataTables_scrollBody tbody tr:first-child td {
+	border-top: none;
+}
+
+div.dataTables_scrollFoot table {
+	margin-top: 0 !important;
+	border-top: none;
+}
+
+/* Frustratingly the border-collapse:collapse used by Bootstrap makes the column
+   width calculations when using scrolling impossible to align columns. We have
+   to use separate
+ */
+table.table-bordered.dataTable {
+	border-collapse: separate !important;
+}
+table.table-bordered thead th,
+table.table-bordered thead td {
+	border-left-width: 0;
+	border-top-width: 0;
+}
+table.table-bordered tbody th,
+table.table-bordered tbody td {
+	border-left-width: 0;
+	border-bottom-width: 0;
+}
+table.table-bordered th:last-child,
+table.table-bordered td:last-child {
+	border-right-width: 0;
+}
+div.dataTables_scrollHead table.table-bordered {
+	border-bottom-width: 0;
+}
+
+
+
+
+/*
+ * TableTools styles
+ */
+.table.dataTable tbody tr.active td,
+.table.dataTable tbody tr.active th {
+	background-color: #08C;
+	color: white;
+}
+
+.table.dataTable tbody tr.active:hover td,
+.table.dataTable tbody tr.active:hover th {
+	background-color: #0075b0 !important;
+}
+
+.table.dataTable tbody tr.active th > a,
+.table.dataTable tbody tr.active td > a {
+	color: white;
+}
+
+.table-striped.dataTable tbody tr.active:nth-child(odd) td,
+.table-striped.dataTable tbody tr.active:nth-child(odd) th {
+	background-color: #017ebc;
+}
+
+table.DTTT_selectable tbody tr {
+	cursor: pointer;
+}
+
+div.DTTT .btn {
+	color: #333 !important;
+	font-size: 12px;
+}
+
+div.DTTT .btn:hover {
+	text-decoration: none !important;
+}
+
+ul.DTTT_dropdown.dropdown-menu {
+  z-index: 2003;
+}
+
+ul.DTTT_dropdown.dropdown-menu a {
+	color: #333 !important; /* needed only when demo_page.css is included */
+}
+
+ul.DTTT_dropdown.dropdown-menu li {
+	position: relative;
+}
+
+ul.DTTT_dropdown.dropdown-menu li:hover a {
+	background-color: #0088cc;
+	color: white !important;
+}
+
+div.DTTT_collection_background {
+	z-index: 2002;	
+}
+
+/* TableTools information display */
+div.DTTT_print_info {
+	position: fixed;
+	top: 50%;
+	left: 50%;
+	width: 400px;
+	height: 150px;
+	margin-left: -200px;
+	margin-top: -75px;
+	text-align: center;
+	color: #333;
+	padding: 10px 30px;
+	opacity: 0.95;
+
+	background-color: white;
+	border: 1px solid rgba(0, 0, 0, 0.2);
+	border-radius: 6px;
+	
+	-webkit-box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5);
+	        box-shadow: 0 3px 7px rgba(0, 0, 0, 0.5);
+}
+
+div.DTTT_print_info h6 {
+	font-weight: normal;
+	font-size: 28px;
+	line-height: 28px;
+	margin: 1em;
+}
+
+div.DTTT_print_info p {
+	font-size: 14px;
+	line-height: 20px;
+}
+
+div.dataTables_processing {
+    position: absolute;
+    top: 50%;
+    left: 50%;
+    width: 100%;
+    height: 60px;
+    margin-left: -50%;
+    margin-top: -25px;
+    padding-top: 20px;
+    padding-bottom: 20px;
+    text-align: center;
+    font-size: 1.2em;
+    background-color: white;
+    background: -webkit-gradient(linear, left top, right top, color-stop(0%, rgba(255,255,255,0)), color-stop(25%, rgba(255,255,255,0.9)), color-stop(75%, rgba(255,255,255,0.9)), color-stop(100%, rgba(255,255,255,0)));
+    background: -webkit-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+    background: -moz-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+    background: -ms-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+    background: -o-linear-gradient(left, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+    background: linear-gradient(to right, rgba(255,255,255,0) 0%, rgba(255,255,255,0.9) 25%, rgba(255,255,255,0.9) 75%, rgba(255,255,255,0) 100%);
+}
+
+
+
+/*
+ * FixedColumns styles
+ */
+div.DTFC_LeftHeadWrapper table,
+div.DTFC_LeftFootWrapper table,
+div.DTFC_RightHeadWrapper table,
+div.DTFC_RightFootWrapper table,
+table.DTFC_Cloned tr.even {
+    background-color: white;
+    margin-bottom: 0;
+}
+ 
+div.DTFC_RightHeadWrapper table ,
+div.DTFC_LeftHeadWrapper table {
+	border-bottom: none !important;
+    margin-bottom: 0 !important;
+    border-top-right-radius: 0 !important;
+    border-bottom-left-radius: 0 !important;
+    border-bottom-right-radius: 0 !important;
+}
+ 
+div.DTFC_RightHeadWrapper table thead tr:last-child th:first-child,
+div.DTFC_RightHeadWrapper table thead tr:last-child td:first-child,
+div.DTFC_LeftHeadWrapper table thead tr:last-child th:first-child,
+div.DTFC_LeftHeadWrapper table thead tr:last-child td:first-child {
+    border-bottom-left-radius: 0 !important;
+    border-bottom-right-radius: 0 !important;
+}
+ 
+div.DTFC_RightBodyWrapper table,
+div.DTFC_LeftBodyWrapper table {
+    border-top: none;
+    margin: 0 !important;
+}
+ 
+div.DTFC_RightBodyWrapper tbody tr:first-child th,
+div.DTFC_RightBodyWrapper tbody tr:first-child td,
+div.DTFC_LeftBodyWrapper tbody tr:first-child th,
+div.DTFC_LeftBodyWrapper tbody tr:first-child td {
+    border-top: none;
+}
+ 
+div.DTFC_RightFootWrapper table,
+div.DTFC_LeftFootWrapper table {
+    border-top: none;
+    margin-top: 0 !important;
+}
+
+
+/*
+ * FixedHeader styles
+ */
+div.FixedHeader_Cloned table {
+	margin: 0 !important
+}
+
diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js
new file mode 100644
index 0000000000000..f0d09b9d52668
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap.min.js
@@ -0,0 +1,8 @@
+/*!
+ DataTables Bootstrap 3 integration
+ ©2011-2014 SpryMedia Ltd - datatables.net/license
+*/
+(function(){var f=function(c,b){c.extend(!0,b.defaults,{dom:"<'row'<'col-sm-6'l><'col-sm-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-6'i><'col-sm-6'p>>",renderer:"bootstrap"});c.extend(b.ext.classes,{sWrapper:"dataTables_wrapper form-inline dt-bootstrap",sFilterInput:"form-control input-sm",sLengthSelect:"form-control input-sm"});b.ext.renderer.pageButton.bootstrap=function(g,f,p,k,h,l){var q=new b.Api(g),r=g.oClasses,i=g.oLanguage.oPaginate,d,e,o=function(b,f){var j,m,n,a,k=function(a){a.preventDefault();
+c(a.currentTarget).hasClass("disabled")||q.page(a.data.action).draw(!1)};j=0;for(m=f.length;j",{"class":r.sPageButton+" "+
+e,"aria-controls":g.sTableId,tabindex:g.iTabIndex,id:0===p&&"string"===typeof a?g.sTableId+"_"+a:null}).append(c("",{href:"#"}).html(d)).appendTo(b),g.oApi._fnBindAction(n,{action:a},k))}};o(c(f).empty().html('
           
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 6174fc11f83d8..ae16ce90c84b7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.deploy.master.ui
 
-import org.apache.spark.Logging
 import org.apache.spark.deploy.master.Master
-import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource, ApplicationInfo,
+import org.apache.spark.internal.Logging
+import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource,
   UIRoot}
 import org.apache.spark.ui.{SparkUI, WebUI}
 import org.apache.spark.ui.JettyUtils._
@@ -28,14 +28,17 @@ import org.apache.spark.ui.JettyUtils._
  * Web UI server for the standalone master.
  */
 private[master]
-class MasterWebUI(val master: Master, requestedPort: Int)
-  extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging
-  with UIRoot {
+class MasterWebUI(
+    val master: Master,
+    requestedPort: Int,
+    customMasterPage: Option[MasterPage] = None)
+  extends WebUI(master.securityMgr, master.securityMgr.getSSLOptions("standalone"),
+    requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot {
 
   val masterEndpointRef = master.self
   val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true)
 
-  val masterPage = new MasterPage(this)
+  val masterPage = customMasterPage.getOrElse(new MasterPage(this))
 
   initialize()
 
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
index 5d4e5b899dfdc..a057977eb0dd2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
@@ -19,11 +19,12 @@ package org.apache.spark.deploy.mesos
 
 import java.util.concurrent.CountDownLatch
 
+import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.deploy.mesos.ui.MesosClusterUI
 import org.apache.spark.deploy.rest.mesos.MesosRestServer
+import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.cluster.mesos._
-import org.apache.spark.util.SignalLogger
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.util.{ShutdownHookManager, Utils}
 
 /*
  * A dispatcher that is responsible for managing and launching drivers, and is intended to be
@@ -50,7 +51,7 @@ private[mesos] class MesosClusterDispatcher(
   extends Logging {
 
   private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host)
-  private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase()
+  private val recoveryMode = conf.get("spark.deploy.recoveryMode", "NONE").toUpperCase()
   logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode)
 
   private val engineFactory = recoveryMode match {
@@ -73,7 +74,7 @@ private[mesos] class MesosClusterDispatcher(
 
   def start(): Unit = {
     webUi.bind()
-    scheduler.frameworkUrl = webUi.activeWebUiUrl
+    scheduler.frameworkUrl = conf.get("spark.mesos.dispatcher.webui.url", webUi.activeWebUiUrl)
     scheduler.start()
     server.start()
   }
@@ -92,25 +93,22 @@ private[mesos] class MesosClusterDispatcher(
 
 private[mesos] object MesosClusterDispatcher extends Logging {
   def main(args: Array[String]) {
-    SignalLogger.register(log)
+    Utils.initDaemon(log)
     val conf = new SparkConf
     val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf)
     conf.setMaster(dispatcherArgs.masterUrl)
     conf.setAppName(dispatcherArgs.name)
     dispatcherArgs.zookeeperUrl.foreach { z =>
-      conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER")
-      conf.set("spark.mesos.deploy.zookeeper.url", z)
+      conf.set("spark.deploy.recoveryMode", "ZOOKEEPER")
+      conf.set("spark.deploy.zookeeper.url", z)
     }
     val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf)
     dispatcher.start()
-    val shutdownHook = new Thread() {
-      override def run() {
-        logInfo("Shutdown hook is shutting down dispatcher")
-        dispatcher.stop()
-        dispatcher.awaitShutdown()
-      }
+    ShutdownHookManager.addShutdownHook { () =>
+      logInfo("Shutdown hook is shutting down dispatcher")
+      dispatcher.stop()
+      dispatcher.awaitShutdown()
     }
-    Runtime.getRuntime.addShutdownHook(shutdownHook)
     dispatcher.awaitShutdown()
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
index 5accaf78d0a51..11e13441eeba6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.deploy.mesos
 
+import scala.annotation.tailrec
+
 import org.apache.spark.SparkConf
 import org.apache.spark.util.{IntParam, Utils}
 
@@ -34,6 +36,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
 
   propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)
 
+  @tailrec
   private def parse(args: List[String]): Unit = args match {
     case ("--host" | "-h") :: value :: tail =>
       Utils.checkHost(value, "Please use hostname " + value)
@@ -44,7 +47,7 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
       port = value
       parse(tail)
 
-    case ("--webui-port" | "-p") :: IntParam(value) :: tail =>
+    case ("--webui-port") :: IntParam(value) :: tail =>
       webUiPort = value
       parse(tail)
 
@@ -73,14 +76,13 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
     case ("--help") :: tail =>
       printUsageAndExit(0)
 
-    case Nil => {
+    case Nil =>
       if (masterUrl == null) {
         // scalastyle:off println
         System.err.println("--master is required")
         // scalastyle:on println
         printUsageAndExit(1)
       }
-    }
 
     case _ =>
       printUsageAndExit(1)
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
index 12337a940a414..6b297c4600a68 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
@@ -17,68 +17,90 @@
 
 package org.apache.spark.deploy.mesos
 
-import java.net.SocketAddress
+import java.nio.ByteBuffer
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 
-import scala.collection.mutable
+import scala.collection.JavaConverters._
 
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.deploy.ExternalShuffleService
+import org.apache.spark.internal.Logging
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
-import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver
+import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat}
 import org.apache.spark.network.util.TransportConf
+import org.apache.spark.util.ThreadUtils
 
 /**
  * An RPC endpoint that receives registration requests from Spark drivers running on Mesos.
  * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]].
  */
-private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf)
+private[mesos] class MesosExternalShuffleBlockHandler(
+    transportConf: TransportConf,
+    cleanerIntervalS: Long)
   extends ExternalShuffleBlockHandler(transportConf, null) with Logging {
 
-  // Stores a map of driver socket addresses to app ids
-  private val connectedApps = new mutable.HashMap[SocketAddress, String]
+  ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher")
+    .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervalS, TimeUnit.SECONDS)
+
+  // Stores a map of app id to app state (timeout value and last heartbeat)
+  private val connectedApps = new ConcurrentHashMap[String, AppState]()
 
   protected override def handleMessage(
       message: BlockTransferMessage,
       client: TransportClient,
       callback: RpcResponseCallback): Unit = {
     message match {
-      case RegisterDriverParam(appId) =>
+      case RegisterDriverParam(appId, appState) =>
+        val address = client.getSocketAddress
+        val timeout = appState.heartbeatTimeout
+        logInfo(s"Received registration request from app $appId (remote address $address, " +
+          s"heartbeat timeout $timeout ms).")
+        if (connectedApps.containsKey(appId)) {
+          logWarning(s"Received a registration request from app $appId, but it was already " +
+            s"registered")
+        }
+        connectedApps.put(appId, appState)
+        callback.onSuccess(ByteBuffer.allocate(0))
+      case Heartbeat(appId) =>
         val address = client.getSocketAddress
-        logDebug(s"Received registration request from app $appId (remote address $address).")
-        if (connectedApps.contains(address)) {
-          val existingAppId = connectedApps(address)
-          if (!existingAppId.equals(appId)) {
-            logError(s"A new app '$appId' has connected to existing address $address, " +
-              s"removing previously registered app '$existingAppId'.")
-            applicationRemoved(existingAppId, true)
-          }
+        Option(connectedApps.get(appId)) match {
+          case Some(existingAppState) =>
+            logTrace(s"Received ShuffleServiceHeartbeat from app '$appId' (remote " +
+              s"address $address).")
+            existingAppState.lastHeartbeat = System.nanoTime()
+          case None =>
+            logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " +
+              s"address $address, appId '$appId').")
         }
-        connectedApps(address) = appId
-        callback.onSuccess(new Array[Byte](0))
       case _ => super.handleMessage(message, client, callback)
     }
   }
 
-  /**
-   * On connection termination, clean up shuffle files written by the associated application.
-   */
-  override def connectionTerminated(client: TransportClient): Unit = {
-    val address = client.getSocketAddress
-    if (connectedApps.contains(address)) {
-      val appId = connectedApps(address)
-      logInfo(s"Application $appId disconnected (address was $address).")
-      applicationRemoved(appId, true /* cleanupLocalDirs */)
-      connectedApps.remove(address)
-    } else {
-      logWarning(s"Unknown $address disconnected.")
-    }
-  }
-
   /** An extractor object for matching [[RegisterDriver]] message. */
   private object RegisterDriverParam {
-    def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId)
+    def unapply(r: RegisterDriver): Option[(String, AppState)] =
+      Some((r.getAppId, new AppState(r.getHeartbeatTimeoutMs, System.nanoTime())))
+  }
+
+  private object Heartbeat {
+    def unapply(h: ShuffleServiceHeartbeat): Option[String] = Some(h.getAppId)
+  }
+
+  private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long)
+
+  private class CleanerThread extends Runnable {
+    override def run(): Unit = {
+      val now = System.nanoTime()
+      connectedApps.asScala.foreach { case (appId, appState) =>
+        if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) {
+          logInfo(s"Application $appId timed out. Removing shuffle files.")
+          connectedApps.remove(appId)
+          applicationRemoved(appId, true)
+        }
+      }
+    }
   }
 }
 
@@ -92,7 +114,8 @@ private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManage
 
   protected override def newShuffleBlockHandler(
       conf: TransportConf): ExternalShuffleBlockHandler = {
-    new MesosExternalShuffleBlockHandler(conf)
+    val cleanerIntervalS = this.conf.getTimeAsSeconds("spark.shuffle.cleaner.interval", "30s")
+    new MesosExternalShuffleBlockHandler(conf, cleanerIntervalS)
   }
 }
 
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
index e8ef60bd5428a..807835105ec3e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
@@ -23,10 +23,9 @@ import scala.xml.Node
 
 import org.apache.spark.deploy.Command
 import org.apache.spark.deploy.mesos.MesosDriverDescription
-import org.apache.spark.scheduler.cluster.mesos.{MesosClusterSubmissionState, MesosClusterRetryState}
+import org.apache.spark.scheduler.cluster.mesos.{MesosClusterRetryState, MesosClusterSubmissionState}
 import org.apache.spark.ui.{UIUtils, WebUIPage}
 
-
 private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") {
 
   override def render(request: HttpServletRequest): Seq[Node] = {
@@ -46,7 +45,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
     val schedulerHeaders = Seq("Scheduler property", "Value")
     val commandEnvHeaders = Seq("Command environment variable", "Value")
     val launchedHeaders = Seq("Launched property", "Value")
-    val commandHeaders = Seq("Comamnd property", "Value")
+    val commandHeaders = Seq("Command property", "Value")
     val retryHeaders = Seq("Last failed status", "Next retry time", "Retry count")
     val driverDescription = Iterable.apply(driverState.description)
     val submissionState = Iterable.apply(driverState.submissionState)
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
index 7419fa9699648..166f666fbcfdc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
@@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest
 import scala.xml.Node
 
 import org.apache.mesos.Protos.TaskStatus
+
 import org.apache.spark.deploy.mesos.MesosDriverDescription
 import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState
 import org.apache.spark.ui.{UIUtils, WebUIPage}
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala
index 3f693545a0349..baad098a0cd1f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.deploy.mesos.ui
 
-import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler
 import org.apache.spark.{SecurityManager, SparkConf}
-import org.apache.spark.ui.JettyUtils._
+import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler
 import org.apache.spark.ui.{SparkUI, WebUI}
+import org.apache.spark.ui.JettyUtils._
 
 /**
  * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]]
@@ -31,7 +31,7 @@ private[spark] class MesosClusterUI(
     conf: SparkConf,
     dispatcherPublicAddress: String,
     val scheduler: MesosClusterScheduler)
-  extends WebUI(securityManager, port, conf) {
+  extends WebUI(securityManager, securityManager.getSSLOptions("mesos"), port, conf) {
 
   initialize()
 
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
index 957a928bc402b..c5a5876a896cc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala
@@ -19,15 +19,19 @@ package org.apache.spark.deploy.rest
 
 import java.io.{DataOutputStream, FileNotFoundException}
 import java.net.{ConnectException, HttpURLConnection, SocketException, URL}
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.TimeoutException
 import javax.servlet.http.HttpServletResponse
 
 import scala.collection.mutable
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
 import scala.io.Source
 
 import com.fasterxml.jackson.core.JsonProcessingException
-import com.google.common.base.Charsets
 
-import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
+import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
+import org.apache.spark.internal.Logging
 import org.apache.spark.util.Utils
 
 /**
@@ -208,7 +212,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
     try {
       val out = new DataOutputStream(conn.getOutputStream)
       Utils.tryWithSafeFinally {
-        out.write(json.getBytes(Charsets.UTF_8))
+        out.write(json.getBytes(StandardCharsets.UTF_8))
       } {
         out.close()
       }
@@ -225,7 +229,8 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
    * Exposed for testing.
    */
   private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
-    try {
+    import scala.concurrent.ExecutionContext.Implicits.global
+    val responseFuture = Future {
       val dataStream =
         if (connection.getResponseCode == HttpServletResponse.SC_OK) {
           connection.getInputStream
@@ -251,11 +256,15 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
           throw new SubmitRestProtocolException(
             s"Message received from server was not a response:\n${unexpected.toJson}")
       }
-    } catch {
+    }
+
+    try { Await.result(responseFuture, 10.seconds) } catch {
       case unreachable @ (_: FileNotFoundException | _: SocketException) =>
         throw new SubmitRestConnectionException("Unable to connect to server", unreachable)
       case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
         throw new SubmitRestProtocolException("Malformed response received from server", malformed)
+      case timeout: TimeoutException =>
+        throw new SubmitRestConnectionException("No response from server", timeout)
     }
   }
 
@@ -374,7 +383,7 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
       logWarning(s"Unable to connect to server ${masterUrl}.")
       lostMasters += masterUrl
     }
-    lostMasters.size >= masters.size
+    lostMasters.size >= masters.length
   }
 }
 
@@ -404,13 +413,13 @@ private[spark] object RestSubmissionClient {
   }
 
   def main(args: Array[String]): Unit = {
-    if (args.size < 2) {
+    if (args.length < 2) {
       sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]")
       sys.exit(1)
     }
     val appResource = args(0)
     val mainClass = args(1)
-    val appArgs = args.slice(2, args.size)
+    val appArgs = args.slice(2, args.length)
     val conf = new SparkConf
     val env = filterSystemEnvironment(sys.env)
     run(appResource, mainClass, appArgs, conf, env)
@@ -420,8 +429,10 @@ private[spark] object RestSubmissionClient {
    * Filter non-spark environment variables from any environment.
    */
   private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = {
-    env.filter { case (k, _) =>
-      (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED") || k.startsWith("MESOS_")
+    env.filterKeys { k =>
+      // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345)
+      (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") ||
+        k.startsWith("MESOS_")
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
index 2e78d03e5c0cc..14244ea5714c6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
@@ -21,14 +21,16 @@ import java.net.InetSocketAddress
 import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
 
 import scala.io.Source
+
 import com.fasterxml.jackson.core.JsonProcessingException
 import org.eclipse.jetty.server.Server
-import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
+import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
 import org.eclipse.jetty.util.thread.QueuedThreadPool
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
+import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
+import org.apache.spark.internal.Logging
 import org.apache.spark.util.Utils
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index d5b9bcab1423f..c19296c7b3e00 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -20,11 +20,11 @@ package org.apache.spark.deploy.rest
 import java.io.File
 import javax.servlet.http.HttpServletResponse
 
-import org.apache.spark.deploy.ClientArguments._
+import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
 import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription}
+import org.apache.spark.deploy.ClientArguments._
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.util.Utils
-import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
 
 /**
  * A server that responds to requests submitted by the [[RestSubmissionClient]].
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
index 868cc35d06ef3..3b96488a129a9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala
@@ -23,13 +23,12 @@ import java.util.Date
 import java.util.concurrent.atomic.AtomicLong
 import javax.servlet.http.HttpServletResponse
 
+import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
 import org.apache.spark.deploy.Command
 import org.apache.spark.deploy.mesos.MesosDriverDescription
 import org.apache.spark.deploy.rest._
 import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler
 import org.apache.spark.util.Utils
-import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf}
-
 
 /**
  * A server that responds to requests submitted by the [[RestSubmissionClient]].
@@ -95,7 +94,7 @@ private[mesos] class MesosSubmitRequestServlet(
     val driverCores = sparkProperties.get("spark.driver.cores")
     val appArgs = request.appArgs
     val environmentVariables = request.environmentVariables
-    val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass)
+    val name = request.sparkProperties.getOrElse("spark.app.name", mainClass)
 
     // Construct driver description
     val conf = new SparkConf(false).setAll(sparkProperties)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index ce02ee203a4bd..cba4aaffe2caa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -22,14 +22,14 @@ import java.io.{File, FileOutputStream, InputStream, IOException}
 import scala.collection.JavaConverters._
 import scala.collection.Map
 
-import org.apache.spark.Logging
 import org.apache.spark.SecurityManager
 import org.apache.spark.deploy.Command
+import org.apache.spark.internal.Logging
 import org.apache.spark.launcher.WorkerCommandBuilder
 import org.apache.spark.util.Utils
 
 /**
- ** Utilities for running commands with the spark classpath.
+ * Utilities for running commands with the spark classpath.
  */
 private[deploy]
 object CommandUtils extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 89159ff5e2b3c..aad2e91b25554 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -18,20 +18,21 @@
 package org.apache.spark.deploy.worker
 
 import java.io._
+import java.nio.charset.StandardCharsets
 
 import scala.collection.JavaConverters._
 
-import com.google.common.base.Charsets.UTF_8
 import com.google.common.io.Files
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil}
 import org.apache.spark.deploy.DeployMessages.DriverStateChanged
 import org.apache.spark.deploy.master.DriverState
 import org.apache.spark.deploy.master.DriverState.DriverState
+import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.util.{Utils, Clock, SystemClock}
+import org.apache.spark.util.{Clock, SystemClock, Utils}
 
 /**
  * Manages the execution of one driver, including automatically restarting the driver on failure.
@@ -174,7 +175,7 @@ private[deploy] class DriverRunner(
       val stderr = new File(baseDir, "stderr")
       val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"")
       val header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40)
-      Files.append(header, stderr, UTF_8)
+      Files.append(header, stderr, StandardCharsets.UTF_8)
       CommandUtils.redirectStream(process.getErrorStream, stderr)
     }
     runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
@@ -217,7 +218,7 @@ private[deploy] class DriverRunner(
 }
 
 private[deploy] trait Sleeper {
-  def sleep(seconds: Int)
+  def sleep(seconds: Int): Unit
 }
 
 // Needed because ProcessBuilder is a final class and cannot be mocked
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 3aef0515cbf6e..06066248ea5d0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -18,16 +18,17 @@
 package org.apache.spark.deploy.worker
 
 import java.io._
+import java.nio.charset.StandardCharsets
 
 import scala.collection.JavaConverters._
 
-import com.google.common.base.Charsets.UTF_8
 import com.google.common.io.Files
 
-import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.{SecurityManager, SparkConf, Logging}
+import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
 import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.util.{ShutdownHookManager, Utils}
 import org.apache.spark.util.logging.FileAppender
 
@@ -60,6 +61,9 @@ private[deploy] class ExecutorRunner(
   private var stdoutAppender: FileAppender = null
   private var stderrAppender: FileAppender = null
 
+  // Timeout to wait for when trying to terminate an executor.
+  private val EXECUTOR_TERMINATE_TIMEOUT_MS = 10 * 1000
+
   // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might
   // make sense to remove this in the future.
   private var shutdownHook: AnyRef = null
@@ -71,6 +75,11 @@ private[deploy] class ExecutorRunner(
     workerThread.start()
     // Shutdown hook that kills actors on shutdown.
     shutdownHook = ShutdownHookManager.addShutdownHook { () =>
+      // It's possible that we arrive here before calling `fetchAndRunExecutor`, then `state` will
+      // be `ExecutorState.RUNNING`. In this case, we should set `state` to `FAILED`.
+      if (state == ExecutorState.RUNNING) {
+        state = ExecutorState.FAILED
+      }
       killProcess(Some("Worker shutting down")) }
   }
 
@@ -89,10 +98,17 @@ private[deploy] class ExecutorRunner(
       if (stderrAppender != null) {
         stderrAppender.stop()
       }
-      process.destroy()
-      exitCode = Some(process.waitFor())
+      exitCode = Utils.terminateProcess(process, EXECUTOR_TERMINATE_TIMEOUT_MS)
+      if (exitCode.isEmpty) {
+        logWarning("Failed to terminate process: " + process +
+          ". This process will likely be orphaned.")
+      }
+    }
+    try {
+      worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
+    } catch {
+      case e: IllegalStateException => logWarning(e.getMessage(), e)
     }
-    worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
   }
 
   /** Stop this executor runner, including killing the process it launched */
@@ -153,7 +169,7 @@ private[deploy] class ExecutorRunner(
       stdoutAppender = FileAppender(process.getInputStream, stdout, conf)
 
       val stderr = new File(executorDir, "stderr")
-      Files.write(header, stderr, UTF_8)
+      Files.write(header, stderr, StandardCharsets.UTF_8)
       stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
 
       // Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown)
@@ -163,16 +179,14 @@ private[deploy] class ExecutorRunner(
       val message = "Command exited with code " + exitCode
       worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)))
     } catch {
-      case interrupted: InterruptedException => {
+      case interrupted: InterruptedException =>
         logInfo("Runner thread for executor " + fullId + " interrupted")
         state = ExecutorState.KILLED
         killProcess(None)
-      }
-      case e: Exception => {
+      case e: Exception =>
         logError("Error running executor", e)
         state = ExecutorState.FAILED
         killProcess(Some(e.toString))
-      }
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index a45867e7680ec..449beb0811177 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker
 import java.io.File
 import java.io.IOException
 import java.text.SimpleDateFormat
-import java.util.{UUID, Date}
+import java.util.{Date, UUID}
 import java.util.concurrent._
 import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
 
@@ -29,15 +29,16 @@ import scala.concurrent.ExecutionContext
 import scala.util.{Failure, Random, Success}
 import scala.util.control.NonFatal
 
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState}
 import org.apache.spark.deploy.DeployMessages._
 import org.apache.spark.deploy.ExternalShuffleService
 import org.apache.spark.deploy.master.{DriverState, Master}
 import org.apache.spark.deploy.worker.ui.WorkerWebUI
+import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rpc._
-import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 private[deploy] class Worker(
     override val rpcEnv: RpcEnv,
@@ -45,7 +46,6 @@ private[deploy] class Worker(
     cores: Int,
     memory: Int,
     masterRpcAddresses: Array[RpcAddress],
-    systemName: String,
     endpointName: String,
     workDirPath: String = null,
     val conf: SparkConf,
@@ -101,7 +101,8 @@ private[deploy] class Worker(
   private var master: Option[RpcEndpointRef] = None
   private var activeMasterUrl: String = ""
   private[worker] var activeMasterWebUiUrl : String = ""
-  private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName)
+  private var workerWebUiUrl: String = ""
+  private val workerUri = RpcEndpointAddress(rpcEnv.address, endpointName).toString
   private var registered = false
   private var connected = false
   private val workerId = generateWorkerId()
@@ -146,12 +147,10 @@ private[deploy] class Worker(
   // A thread pool for registering with masters. Because registering with a master is a blocking
   // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
   // time so that we can register with all masters.
-  private val registerMasterThreadPool = new ThreadPoolExecutor(
-    0,
-    masterRpcAddresses.size, // Make sure we can register with all masters at the same time
-    60L, TimeUnit.SECONDS,
-    new SynchronousQueue[Runnable](),
-    ThreadUtils.namedThreadFactory("worker-register-master-threadpool"))
+  private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool(
+    "worker-register-master-threadpool",
+    masterRpcAddresses.length // Make sure we can register with all masters at the same time
+  )
 
   var coresUsed = 0
   var memoryUsed = 0
@@ -187,6 +186,9 @@ private[deploy] class Worker(
     shuffleService.startIfEnabled()
     webUi = new WorkerWebUI(this, workDir, webUiPort)
     webUi.bind()
+
+    val scheme = if (webUi.sslOptions.enabled) "https" else "http"
+    workerWebUiUrl = s"$scheme://$publicAddress:${webUi.boundPort}"
     registerWithMaster()
 
     metricsSystem.registerSource(workerSource)
@@ -211,8 +213,7 @@ private[deploy] class Worker(
         override def run(): Unit = {
           try {
             logInfo("Connecting to master " + masterAddress + "...")
-            val masterEndpoint =
-              rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
+            val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME)
             registerWithMaster(masterEndpoint)
           } catch {
             case ie: InterruptedException => // Cancelled
@@ -268,8 +269,7 @@ private[deploy] class Worker(
               override def run(): Unit = {
                 try {
                   logInfo("Connecting to master " + masterAddress + "...")
-                  val masterEndpoint =
-                    rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME)
+                  val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME)
                   registerWithMaster(masterEndpoint)
                 } catch {
                   case ie: InterruptedException => // Cancelled
@@ -341,7 +341,7 @@ private[deploy] class Worker(
 
   private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = {
     masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker(
-      workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress))
+      workerId, host, port, self, cores, memory, workerWebUiUrl))
       .onComplete {
         // This is a very fast action so we can use "ThreadUtils.sameThread"
         case Success(msg) =>
@@ -375,6 +375,11 @@ private[deploy] class Worker(
           }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
         }
 
+        val execs = executors.values.map { e =>
+          new ExecutorDescription(e.appId, e.execId, e.cores, e.state)
+        }
+        masterRef.send(WorkerLatestState(workerId, execs.toList, drivers.keys.toSeq))
+
       case RegisterWorkerFailed(message) =>
         if (!registered) {
           logError("Worker registration failed: " + message)
@@ -395,7 +400,7 @@ private[deploy] class Worker(
       // rpcEndpoint.
       // Copy ids so that it can be used in the cleanup thread.
       val appIds = executors.values.map(_.appId).toSet
-      val cleanupFuture = concurrent.future {
+      val cleanupFuture = concurrent.Future {
         val appDirs = workDir.listFiles()
         if (appDirs == null) {
           throw new IOException("ERROR: Failed to list files in " + appDirs)
@@ -446,13 +451,12 @@ private[deploy] class Worker(
           // Create local dirs for the executor. These are passed to the executor via the
           // SPARK_EXECUTOR_DIRS environment variable, and deleted by the Worker when the
           // application finishes.
-          val appLocalDirs = appDirectories.get(appId).getOrElse {
+          val appLocalDirs = appDirectories.getOrElse(appId,
             Utils.getOrCreateLocalRootDirs(conf).map { dir =>
               val appDir = Utils.createDirectory(dir, namePrefix = "executor")
               Utils.chmod700(appDir)
               appDir.getAbsolutePath()
-            }.toSeq
-          }
+            }.toSeq)
           appDirectories(appId) = appLocalDirs
           val manager = new ExecutorRunner(
             appId,
@@ -469,14 +473,14 @@ private[deploy] class Worker(
             executorDir,
             workerUri,
             conf,
-            appLocalDirs, ExecutorState.LOADING)
+            appLocalDirs, ExecutorState.RUNNING)
           executors(appId + "/" + execId) = manager
           manager.start()
           coresUsed += cores_
           memoryUsed += memory_
           sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None))
         } catch {
-          case e: Exception => {
+          case e: Exception =>
             logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
             if (executors.contains(appId + "/" + execId)) {
               executors(appId + "/" + execId).kill()
@@ -484,7 +488,6 @@ private[deploy] class Worker(
             }
             sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
               Some(e.toString), None))
-          }
         }
       }
 
@@ -505,7 +508,7 @@ private[deploy] class Worker(
         }
       }
 
-    case LaunchDriver(driverId, driverDesc) => {
+    case LaunchDriver(driverId, driverDesc) =>
       logInfo(s"Asked to launch driver $driverId")
       val driver = new DriverRunner(
         conf,
@@ -521,9 +524,8 @@ private[deploy] class Worker(
 
       coresUsed += driverDesc.cores
       memoryUsed += driverDesc.mem
-    }
 
-    case KillDriver(driverId) => {
+    case KillDriver(driverId) =>
       logInfo(s"Asked to kill driver $driverId")
       drivers.get(driverId) match {
         case Some(runner) =>
@@ -531,11 +533,9 @@ private[deploy] class Worker(
         case None =>
           logError(s"Asked to kill unknown driver $driverId")
       }
-    }
 
-    case driverStateChanged @ DriverStateChanged(driverId, state, exception) => {
+    case driverStateChanged @ DriverStateChanged(driverId, state, exception) =>
       handleDriverStateChanged(driverStateChanged)
-    }
 
     case ReregisterWithMaster =>
       reregisterWithMaster()
@@ -688,11 +688,11 @@ private[deploy] object Worker extends Logging {
   val ENDPOINT_NAME = "Worker"
 
   def main(argStrings: Array[String]) {
-    SignalLogger.register(log)
+    Utils.initDaemon(log)
     val conf = new SparkConf
     val args = new WorkerArguments(argStrings, conf)
     val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores,
-      args.memory, args.masters, args.workDir)
+      args.memory, args.masters, args.workDir, conf = conf)
     rpcEnv.awaitTermination()
   }
 
@@ -713,7 +713,7 @@ private[deploy] object Worker extends Logging {
     val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
     val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
     rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory,
-      masterAddresses, systemName, ENDPOINT_NAME, workDir, conf, securityMgr))
+      masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr))
     rpcEnv
   }
 
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 5181142c5f80e..777020d4d5c84 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -19,6 +19,8 @@ package org.apache.spark.deploy.worker
 
 import java.lang.management.ManagementFactory
 
+import scala.annotation.tailrec
+
 import org.apache.spark.util.{IntParam, MemoryParam, Utils}
 import org.apache.spark.SparkConf
 
@@ -63,6 +65,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
 
   checkWorkerMemory()
 
+  @tailrec
   private def parse(args: List[String]): Unit = args match {
     case ("--ip" | "-i") :: value :: tail =>
       Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -162,12 +165,11 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
       }
       // scalastyle:on classforname
     } catch {
-      case e: Exception => {
+      case e: Exception =>
         totalMb = 2*1024
         // scalastyle:off println
         System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
         // scalastyle:on println
-      }
     }
     // Leave out 1 GB for the operating system, but don't return a negative memory size
     math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB)
@@ -175,7 +177,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
 
   def checkWorkerMemory(): Unit = {
     if (memory <= 0) {
-      val message = "Memory can't be 0, missing a M or G on the end of the memory specification?"
+      val message = "Memory is below 1MB, or missing a M/G at the end of the memory specification?"
       throw new IllegalStateException(message)
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index ab56fde938bae..af29de3b0896e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.deploy.worker
 
-import org.apache.spark.Logging
+import org.apache.spark.internal.Logging
 import org.apache.spark.rpc._
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index 5a1d06eb87db9..e75c0cec4acc7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -18,19 +18,18 @@
 package org.apache.spark.deploy.worker.ui
 
 import java.io.File
-import java.net.URI
 import javax.servlet.http.HttpServletRequest
 
 import scala.xml.Node
 
-import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.internal.Logging
+import org.apache.spark.ui.{UIUtils, WebUIPage}
 import org.apache.spark.util.Utils
-import org.apache.spark.Logging
 import org.apache.spark.util.logging.RollingFileAppender
 
 private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging {
   private val worker = parent.worker
-  private val workDir = parent.workDir
+  private val workDir = new File(parent.workDir.toURI.normalize().getPath)
   private val supportedLogTypes = Set("stderr", "stdout")
 
   def renderLog(request: HttpServletRequest): String = {
@@ -108,20 +107,18 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
       }
 
     val content =
-      
-        
-          {linkToMaster}
-          
-
{backButton}
-
{range}
-
{nextButton}
-
-
-
-
{logText}
-
- - +
+ {linkToMaster} +
+
{backButton}
+
{range}
+
{nextButton}
+
+
+
+
{logText}
+
+
UIUtils.basicSparkPage(content, logType + " log page for " + pageName) } @@ -138,7 +135,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with } // Verify that the normalized path of the log directory is in the working directory - val normalizedUri = new URI(logDirectory).normalize() + val normalizedUri = new File(logDirectory).toURI.normalize() val normalizedLogDir = new File(normalizedUri.getPath) if (!Utils.isInDirectory(workDir, normalizedLogDir)) { return ("Error: invalid log directory " + logDirectory, 0, 0, 0) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index fd905feb97e92..8ebcbcb6a1738 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,16 +17,17 @@ package org.apache.spark.deploy.worker.ui +import javax.servlet.http.HttpServletRequest + import scala.xml.Node -import javax.servlet.http.HttpServletRequest import org.json4s.JValue -import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} +import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 1a0598e50dcf1..db696b04384bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker +import org.apache.spark.internal.Logging import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -34,7 +34,8 @@ class WorkerWebUI( val worker: Worker, val workDir: File, requestedPort: Int) - extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") + extends WebUI(worker.securityMgr, worker.securityMgr.getSSLOptions("standalone"), + requestedPort, worker.conf, name = "WorkerUI") with Logging { private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index c2ebf30596215..71b4ad160d679 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,32 +19,32 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer - -import org.apache.hadoop.conf.Configuration +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.util.{Failure, Success} -import org.apache.spark.rpc._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher +import org.apache.spark.internal.Logging +import org.apache.spark.rpc._ import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, - hostPort: String, cores: Int, userClassPath: Seq[URL], env: SparkEnv) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { + private[this] val stopping = new AtomicBoolean(false) var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None @@ -57,17 +57,14 @@ private[spark] class CoarseGrainedExecutorBackend( rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) - ref.ask[RegisterExecutorResponse]( - RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) + ref.ask[Boolean](RegisterExecutor(executorId, self, cores, extractLogUrls)) }(ThreadUtils.sameThread).onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" - case Success(msg) => Utils.tryLogNonFatalError { - Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse - } - case Failure(e) => { + case Success(msg) => + // Always receive `true`. Just ignore it + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) System.exit(1) - } }(ThreadUtils.sameThread) } @@ -106,19 +103,29 @@ private[spark] class CoarseGrainedExecutorBackend( } case StopExecutor => + stopping.set(true) logInfo("Driver commanded a shutdown") // Cannot shutdown here because an ack may need to be sent back to the caller. So send // a message to self to actually do the shutdown. self.send(Shutdown) case Shutdown => - executor.stop() - stop() - rpcEnv.shutdown() + stopping.set(true) + new Thread("CoarseGrainedExecutorBackend-stop-executor") { + override def run(): Unit = { + // executor.stop() will call `SparkEnv.stop()` which waits until RpcEnv stops totally. + // However, if `executor.stop()` runs in some thread of RpcEnv, RpcEnv won't be able to + // stop until `executor.stop()` returns, which becomes a dead-lock (See SPARK-14180). + // Therefore, we put this line in a new thread. + executor.stop() + } + }.start() } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (driver.exists(_.address == remoteAddress)) { + if (stopping.get()) { + logInfo(s"Driver from $remoteAddress disconnected during shutdown") + } else if (driver.exists(_.address == remoteAddress)) { logError(s"Driver $remoteAddress disassociated! Shutting down.") System.exit(1) } else { @@ -146,7 +153,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { workerUrl: Option[String], userClassPath: Seq[URL]) { - SignalLogger.register(log) + Utils.initDaemon(log) SparkHadoopUtil.get.runAsSparkUser { () => // Debug code @@ -186,14 +193,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) - // SparkEnv will set spark.executor.port if the rpc env is listening for incoming - // connections (e.g., if it's using akka). Otherwise, the executor is running in - // client mode only, and does not accept incoming connections. - val sparkHostPort = env.conf.getOption("spark.executor.port").map { port => - hostname + ":" + port - }.orNull env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( - env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) + env.rpcEnv, driverUrl, executorId, cores, userClassPath, env)) workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } @@ -251,13 +252,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) + System.exit(0) } private def printUsageAndExit() = { // scalastyle:off println System.err.println( """ - |"Usage: CoarseGrainedExecutorBackend [options] + |Usage: CoarseGrainedExecutorBackend [options] | | Options are: | --driver-url diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9e88d488c0379..9f94fdef24ebe 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer +import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ @@ -29,17 +30,20 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} +import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ +import org.apache.spark.util.io.ChunkedByteBuffer /** * Spark executor, backed by a threadpool to run tasks. * * This can be used with Mesos, YARN, and the standalone scheduler. - * An internal RPC interface (at the moment Akka) is used for communication with the driver, + * An internal RPC interface is used for communication with the driver, * except in the case of Mesos fine-grained mode. */ private[spark] class Executor( @@ -85,10 +89,6 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an RpcEndpoint for receiving RPCs from the driver - private val executorEndpoint = env.rpcEnv.setupEndpoint( - ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) - // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) @@ -100,9 +100,11 @@ private[spark] class Executor( // Set the classloader for serializer env.serializer.setDefaultClassLoader(replClassLoader) - // Akka's message frame size. If task result is bigger than this, we use the block manager + // Max size of direct result. If task result is bigger than this, we use the block manager // to send the result back. - private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + private val maxDirectResultSize = Math.min( + conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20), + RpcUtils.maxMessageSizeBytes(conf)) // Limit of bytes for total size of results (default is 1GB) private val maxResultSize = Utils.getMaxResultSize(conf) @@ -113,6 +115,23 @@ private[spark] class Executor( // Executor for the heartbeat task. private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + // must be initialized before running startDriverHeartbeat() + private val heartbeatReceiverRef = + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + + /** + * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` + * times, it should kill itself. The default value is 60. It means we will retry to send + * heartbeats about 10 minutes because the heartbeat interval is 10s. + */ + private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60) + + /** + * Count the failure times of heartbeat. It should only be accessed in the heartbeat thread. Each + * successful heartbeat will reset it to 0. + */ + private var heartbeatFailures = 0 + startDriverHeartbeater() def launchTask( @@ -136,7 +155,6 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() - env.rpcEnv.stop(executorEndpoint) heartbeater.shutdown() heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() @@ -189,9 +207,16 @@ private[spark] class Executor( startGCTime = computeTotalGcTime() try { - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) + val (taskFiles, taskJars, taskProps, taskBytes) = + Task.deserializeWithDependencies(serializedTask) + + // Must be set before updateDependencies() is called, in case fetching dependencies + // requires access to properties contained within (e.g. for access control). + Executor.taskDeserializationProps.set(taskProps) + updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.localProperties = taskProps task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, @@ -210,7 +235,7 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() var threwException = true - val (value, accumUpdates) = try { + val value = try { val res = task.run( taskAttemptId = taskId, attemptNumber = attemptNumber, @@ -218,7 +243,9 @@ private[spark] class Executor( threwException = false res } finally { + val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { @@ -227,6 +254,17 @@ private[spark] class Executor( logError(errMsg) } } + + if (releasedLocks.nonEmpty) { + val errMsg = + s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" + + releasedLocks.mkString("[", ", ", "]") + if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) { + throw new SparkException(errMsg) + } else { + logWarning(errMsg) + } + } } val taskFinish = System.currentTimeMillis() @@ -249,10 +287,12 @@ private[spark] class Executor( m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) - m.updateAccumulators() } - val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) + // Note: accumulator updates must be collected after TaskMetrics is updated + val accumUpdates = task.collectAccumulatorUpdates() + // TODO: do not serialize value twice + val directResult = new DirectTaskResult(valueBytes, accumUpdates) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit @@ -263,10 +303,12 @@ private[spark] class Executor( s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + s"dropping it.") ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) - } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + } else if (resultSize > maxDirectResultSize) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( - blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) + blockId, + new ChunkedByteBuffer(serializedDirectResult.duplicate()), + StorageLevel.MEMORY_AND_DISK_SER) logInfo( s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) @@ -287,7 +329,7 @@ private[spark] class Executor( logInfo(s"Executor killed $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - case cDE: CommitDeniedException => + case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskEndReason execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) @@ -297,21 +339,25 @@ private[spark] class Executor( // the default uncaught exception handler, which will terminate the Executor. logError(s"Exception in $taskName (TID $taskId)", t) - val metrics: Option[TaskMetrics] = Option(task).flatMap { task => - task.metrics.map { m => - m.setExecutorRunTime(System.currentTimeMillis() - taskStart) - m.setJvmGCTime(computeTotalGcTime() - startGCTime) - m.updateAccumulators() - m + // Collect latest accumulator values to report back to the driver + val accumulatorUpdates: Seq[AccumulableInfo] = + if (task != null) { + task.metrics.foreach { m => + m.setExecutorRunTime(System.currentTimeMillis() - taskStart) + m.setJvmGCTime(computeTotalGcTime() - startGCTime) + } + task.collectAccumulatorUpdates(taskFailed = true) + } else { + Seq.empty[AccumulableInfo] } - } + val serializedTaskEndReason = { try { - ser.serialize(new ExceptionFailure(t, metrics)) + ser.serialize(new ExceptionFailure(t, accumulatorUpdates)) } catch { case _: NotSerializableException => // t is not serializable so just send the stacktrace - ser.serialize(new ExceptionFailure(t, metrics, false)) + ser.serialize(new ExceptionFailure(t, accumulatorUpdates, preserveCause = false)) } } execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) @@ -365,9 +411,9 @@ private[spark] class Executor( val _userClassPathFirst: java.lang.Boolean = userClassPathFirst val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], - classOf[ClassLoader], classOf[Boolean]) - constructor.newInstance(conf, classUri, parent, _userClassPathFirst) + val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv], + classOf[String], classOf[ClassLoader], classOf[Boolean]) + constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -416,46 +462,40 @@ private[spark] class Executor( } } - private val heartbeatReceiverRef = - RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) - /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { - // list of (task id, metrics) to send back to the driver - val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() + // list of (task id, accumUpdates) to send back to the driver + val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulableInfo])]() val curGCTime = computeTotalGcTime() for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.foreach { metrics => - metrics.updateShuffleReadMetrics() - metrics.updateInputMetrics() + metrics.mergeShuffleReadMetrics() metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) - metrics.updateAccumulators() - - if (isLocal) { - // JobProgressListener will hold an reference of it during - // onExecutorMetricsUpdate(), then JobProgressListener can not see - // the changes of metrics any more, so make a deep copy of it - val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) - tasksMetrics += ((taskRunner.taskId, copiedMetrics)) - } else { - // It will be copied by serialization - tasksMetrics += ((taskRunner.taskId, metrics)) - } + accumUpdates += ((taskRunner.taskId, metrics.accumulatorUpdates())) } } } - val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) + val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) if (response.reregisterBlockManager) { logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } + heartbeatFailures = 0 } catch { - case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + case NonFatal(e) => + logWarning("Issue communicating with driver in heartbeater", e) + heartbeatFailures += 1 + if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + logError(s"Exit as unable to send heartbeats to driver " + + s"more than $HEARTBEAT_MAX_FAILURES times") + System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) + } } } @@ -474,3 +514,10 @@ private[spark] class Executor( heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } + +private[spark] object Executor { + // This is reserved for internal use by components that need to read task properties before a + // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be + // used instead. + val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties] +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala index e07cb31cbe4ba..7153323d01a0b 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala @@ -25,6 +25,6 @@ import org.apache.spark.TaskState.TaskState * A pluggable interface used by the Executor to send updates to the cluster scheduler. */ private[spark] trait ExecutorBackend { - def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) + def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala deleted file mode 100644 index cf362f8464735..0000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.Utils - -/** - * Driver -> Executor message to trigger a thread dump. - */ -private[spark] case object TriggerThreadDump - -/** - * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. - */ -private[spark] -class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case TriggerThreadDump => - context.reply(Utils.getThreadDump()) - } - -} - -object ExecutorEndpoint { - val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" -} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index ea36fb60bd540..99858f785600d 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -39,6 +39,12 @@ object ExecutorExitCode { /** ExternalBlockStore failed to create a local temporary directory after many attempts. */ val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55 + /** + * Executor is unable to send heartbeats to the driver more than + * "spark.executor.heartbeat.maxFailures" times. + */ + val HEARTBEAT_FAILURE = 56 + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -51,6 +57,8 @@ object ExecutorExitCode { // TODO: replace external block store with concrete implementation name case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR => "ExternalBlockStore failed to create a local temporary directory." + case HEARTBEAT_FAILURE => + "Unable to send heartbeats to driver." case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala new file mode 100644 index 0000000000000..83e11c5e236d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.{Accumulator, InternalAccumulator} +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * Method by which input data was read. Network means that the data was read over the network + * from a remote block manager (which may have stored the data on-disk or in-memory). + * Operations are not thread-safe. + */ +@DeveloperApi +object DataReadMethod extends Enumeration with Serializable { + type DataReadMethod = Value + val Memory, Disk, Hadoop, Network = Value +} + + +/** + * :: DeveloperApi :: + * A collection of accumulators that represents metrics about reading data from external systems. + */ +@DeveloperApi +class InputMetrics private ( + _bytesRead: Accumulator[Long], + _recordsRead: Accumulator[Long], + _readMethod: Accumulator[String]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.RECORDS_READ), + TaskMetrics.getAccum[String](accumMap, InternalAccumulator.input.READ_METHOD)) + } + + /** + * Create a new [[InputMetrics]] that is not associated with any particular task. + * + * This mainly exists because of SPARK-5225, where we are forced to use a dummy [[InputMetrics]] + * because we want to ignore metrics from a second read method. In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerInputMetrics]]. + */ + private[executor] def this() { + this(InternalAccumulator.createInputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]) + } + + /** + * Total number of bytes read. + */ + def bytesRead: Long = _bytesRead.localValue + + /** + * Total number of records read. + */ + def recordsRead: Long = _recordsRead.localValue + + /** + * The source from which this task reads its input. + */ + def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) + + private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v) + private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) + private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = _readMethod.setValue(v.toString) + +} diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index c9f18ebc7f0ea..680cfb733e9e6 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -21,15 +21,16 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} +import org.apache.mesos.protobuf.ByteString -import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} +import org.apache.spark.{SparkConf, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.Utils private[spark] class MesosExecutorBackend extends MesosExecutor @@ -121,7 +122,7 @@ private[spark] class MesosExecutorBackend */ private[spark] object MesosExecutorBackend extends Logging { def main(args: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) // Create a new Executor and start it running val runner = new MesosExecutorBackend() new MesosExecutorDriver(runner).run() diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala new file mode 100644 index 0000000000000..93f953846fe26 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.{Accumulator, InternalAccumulator} +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * Method by which output data was written. + * Operations are not thread-safe. + */ +@DeveloperApi +object DataWriteMethod extends Enumeration with Serializable { + type DataWriteMethod = Value + val Hadoop = Value +} + + +/** + * :: DeveloperApi :: + * A collection of accumulators that represents metrics about writing data to external systems. + */ +@DeveloperApi +class OutputMetrics private ( + _bytesWritten: Accumulator[Long], + _recordsWritten: Accumulator[Long], + _writeMethod: Accumulator[String]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.BYTES_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.RECORDS_WRITTEN), + TaskMetrics.getAccum[String](accumMap, InternalAccumulator.output.WRITE_METHOD)) + } + + /** + * Total number of bytes written. + */ + def bytesWritten: Long = _bytesWritten.localValue + + /** + * Total number of records written. + */ + def recordsWritten: Long = _recordsWritten.localValue + + /** + * The source to which this task writes its output. + */ + def writeMethod: DataWriteMethod.Value = DataWriteMethod.withName(_writeMethod.localValue) + + private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v) + private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v) + private[spark] def setWriteMethod(v: DataWriteMethod.Value): Unit = + _writeMethod.setValue(v.toString) + +} diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala new file mode 100644 index 0000000000000..71a24770b50ae --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.{Accumulator, InternalAccumulator} +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * A collection of accumulators that represent metrics about reading shuffle data. + * Operations are not thread-safe. + */ +@DeveloperApi +class ShuffleReadMetrics private ( + _remoteBlocksFetched: Accumulator[Int], + _localBlocksFetched: Accumulator[Int], + _remoteBytesRead: Accumulator[Long], + _localBytesRead: Accumulator[Long], + _fetchWaitTime: Accumulator[Long], + _recordsRead: Accumulator[Long]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.REMOTE_BLOCKS_FETCHED), + TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.LOCAL_BLOCKS_FETCHED), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.REMOTE_BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.LOCAL_BYTES_READ), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.FETCH_WAIT_TIME), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.RECORDS_READ)) + } + + /** + * Create a new [[ShuffleReadMetrics]] that is not associated with any particular task. + * + * This mainly exists for legacy reasons, because we use dummy [[ShuffleReadMetrics]] in + * many places only to merge their values together later. In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerTempShuffleReadMetrics]] followed by + * [[TaskMetrics.mergeShuffleReadMetrics]]. + */ + private[spark] def this() { + this(InternalAccumulator.createShuffleReadAccums().map { a => (a.name.get, a) }.toMap) + } + + /** + * Number of remote blocks fetched in this shuffle by this task. + */ + def remoteBlocksFetched: Int = _remoteBlocksFetched.localValue + + /** + * Number of local blocks fetched in this shuffle by this task. + */ + def localBlocksFetched: Int = _localBlocksFetched.localValue + + /** + * Total number of remote bytes read from the shuffle by this task. + */ + def remoteBytesRead: Long = _remoteBytesRead.localValue + + /** + * Shuffle data that was read from the local disk (as opposed to from a remote executor). + */ + def localBytesRead: Long = _localBytesRead.localValue + + /** + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. + */ + def fetchWaitTime: Long = _fetchWaitTime.localValue + + /** + * Total number of records read from the shuffle by this task. + */ + def recordsRead: Long = _recordsRead.localValue + + /** + * Total bytes fetched in the shuffle by this task (both remote and local). + */ + def totalBytesRead: Long = remoteBytesRead + localBytesRead + + /** + * Number of blocks fetched in this shuffle by this task (remote or local). + */ + def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched + + private[spark] def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.add(v) + private[spark] def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.add(v) + private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v) + private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v) + private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v) + private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) + + private[spark] def setRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.setValue(v) + private[spark] def setLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.setValue(v) + private[spark] def setRemoteBytesRead(v: Long): Unit = _remoteBytesRead.setValue(v) + private[spark] def setLocalBytesRead(v: Long): Unit = _localBytesRead.setValue(v) + private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v) + private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) + + /** + * Resets the value of the current metrics (`this`) and and merges all the independent + * [[ShuffleReadMetrics]] into `this`. + */ + private[spark] def setMergeValues(metrics: Seq[ShuffleReadMetrics]): Unit = { + _remoteBlocksFetched.setValue(_remoteBlocksFetched.zero) + _localBlocksFetched.setValue(_localBlocksFetched.zero) + _remoteBytesRead.setValue(_remoteBytesRead.zero) + _localBytesRead.setValue(_localBytesRead.zero) + _fetchWaitTime.setValue(_fetchWaitTime.zero) + _recordsRead.setValue(_recordsRead.zero) + metrics.foreach { metric => + _remoteBlocksFetched.add(metric.remoteBlocksFetched) + _localBlocksFetched.add(metric.localBlocksFetched) + _remoteBytesRead.add(metric.remoteBytesRead) + _localBytesRead.add(metric.localBytesRead) + _fetchWaitTime.add(metric.fetchWaitTime) + _recordsRead.add(metric.recordsRead) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala new file mode 100644 index 0000000000000..c7aaabb561bba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.{Accumulator, InternalAccumulator} +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * A collection of accumulators that represent metrics about writing shuffle data. + * Operations are not thread-safe. + */ +@DeveloperApi +class ShuffleWriteMetrics private ( + _bytesWritten: Accumulator[Long], + _recordsWritten: Accumulator[Long], + _writeTime: Accumulator[Long]) + extends Serializable { + + private[executor] def this(accumMap: Map[String, Accumulator[_]]) { + this( + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.BYTES_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.RECORDS_WRITTEN), + TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.WRITE_TIME)) + } + + /** + * Create a new [[ShuffleWriteMetrics]] that is not associated with any particular task. + * + * This mainly exists for legacy reasons, because we use dummy [[ShuffleWriteMetrics]] in + * many places only to merge their values together later. In the future, we should revisit + * whether this is needed. + * + * A better alternative is [[TaskMetrics.registerShuffleWriteMetrics]]. + */ + private[spark] def this() { + this(InternalAccumulator.createShuffleWriteAccums().map { a => (a.name.get, a) }.toMap) + } + + /** + * Number of bytes written for the shuffle by this task. + */ + def bytesWritten: Long = _bytesWritten.localValue + + /** + * Total number of records written to the shuffle by this task. + */ + def recordsWritten: Long = _recordsWritten.localValue + + /** + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds. + */ + def writeTime: Long = _writeTime.localValue + + private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v) + private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v) + private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v) + private[spark] def decBytesWritten(v: Long): Unit = { + _bytesWritten.setValue(bytesWritten - v) + } + private[spark] def decRecordsWritten(v: Long): Unit = { + _recordsWritten.setValue(recordsWritten - v) + } + + // Legacy methods for backward compatibility. + // TODO: remove these once we make this class private. + @deprecated("use bytesWritten instead", "2.0.0") + def shuffleBytesWritten: Long = bytesWritten + @deprecated("use writeTime instead", "2.0.0") + def shuffleWriteTime: Long = writeTime + @deprecated("use recordsWritten instead", "2.0.0") + def shuffleRecordsWritten: Long = recordsWritten + +} diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 42207a9553592..bda2a91d9d2ca 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,421 +17,392 @@ package org.apache.spark.executor -import java.io.{IOException, ObjectInputStream} -import java.util.concurrent.ConcurrentHashMap - +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.DataReadMethod.DataReadMethod +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} -import org.apache.spark.util.Utils + /** * :: DeveloperApi :: * Metrics tracked during the execution of a task. * - * This class is used to house metrics both for in-progress and completed tasks. In executors, - * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread - * reads it to send in-progress metrics, and the task thread reads it to send metrics along with - * the completed task. + * This class is wrapper around a collection of internal accumulators that represent metrics + * associated with a task. The local values of these accumulators are sent from the executor + * to the driver when the task completes. These values are then merged into the corresponding + * accumulator previously registered on the driver. * - * So, when adding new fields, take into consideration that the whole object can be serialized for - * shipping off at any time to consumers of the SparkListener interface. + * The accumulator updates are also sent to the driver periodically (on executor heartbeat) + * and when the task failed with an exception. The [[TaskMetrics]] object itself should never + * be sent to the driver. + * + * @param initialAccums the initial set of accumulators that this [[TaskMetrics]] depends on. + * Each accumulator in this initial set must be uniquely named and marked + * as internal. Additional accumulators registered later need not satisfy + * these requirements. */ @DeveloperApi -class TaskMetrics extends Serializable { - /** - * Host's name the task runs on - */ - private var _hostname: String = _ - def hostname: String = _hostname - private[spark] def setHostname(value: String) = _hostname = value - - /** - * Time taken on the executor to deserialize this task - */ - private var _executorDeserializeTime: Long = _ - def executorDeserializeTime: Long = _executorDeserializeTime - private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value +class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Serializable { + import InternalAccumulator._ + // Needed for Java tests + def this() { + this(InternalAccumulator.createAll()) + } /** - * Time the executor spends actually running the task (including fetching shuffle data) + * All accumulators registered with this task. */ - private var _executorRunTime: Long = _ - def executorRunTime: Long = _executorRunTime - private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value + private val accums = new ArrayBuffer[Accumulable[_, _]] + accums ++= initialAccums /** - * The number of bytes this task transmitted back to the driver as the TaskResult + * A map for quickly accessing the initial set of accumulators by name. */ - private var _resultSize: Long = _ - def resultSize: Long = _resultSize - private[spark] def setResultSize(value: Long) = _resultSize = value + private val initialAccumsMap: Map[String, Accumulator[_]] = { + val map = new mutable.HashMap[String, Accumulator[_]] + initialAccums.foreach { a => + val name = a.name.getOrElse { + throw new IllegalArgumentException( + "initial accumulators passed to TaskMetrics must be named") + } + require(a.isInternal, + s"initial accumulator '$name' passed to TaskMetrics must be marked as internal") + require(!map.contains(name), + s"detected duplicate accumulator name '$name' when constructing TaskMetrics") + map(name) = a + } + map.toMap + } + // Each metric is internally represented as an accumulator + private val _executorDeserializeTime = getAccum(EXECUTOR_DESERIALIZE_TIME) + private val _executorRunTime = getAccum(EXECUTOR_RUN_TIME) + private val _resultSize = getAccum(RESULT_SIZE) + private val _jvmGCTime = getAccum(JVM_GC_TIME) + private val _resultSerializationTime = getAccum(RESULT_SERIALIZATION_TIME) + private val _memoryBytesSpilled = getAccum(MEMORY_BYTES_SPILLED) + private val _diskBytesSpilled = getAccum(DISK_BYTES_SPILLED) + private val _peakExecutionMemory = getAccum(PEAK_EXECUTION_MEMORY) + private val _updatedBlockStatuses = + TaskMetrics.getAccum[Seq[(BlockId, BlockStatus)]](initialAccumsMap, UPDATED_BLOCK_STATUSES) /** - * Amount of time the JVM spent in garbage collection while executing this task + * Time taken on the executor to deserialize this task. */ - private var _jvmGCTime: Long = _ - def jvmGCTime: Long = _jvmGCTime - private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value + def executorDeserializeTime: Long = _executorDeserializeTime.localValue /** - * Amount of time spent serializing the task result + * Time the executor spends actually running the task (including fetching shuffle data). */ - private var _resultSerializationTime: Long = _ - def resultSerializationTime: Long = _resultSerializationTime - private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value + def executorRunTime: Long = _executorRunTime.localValue /** - * The number of in-memory bytes spilled by this task + * The number of bytes this task transmitted back to the driver as the TaskResult. */ - private var _memoryBytesSpilled: Long = _ - def memoryBytesSpilled: Long = _memoryBytesSpilled - private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value - private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value + def resultSize: Long = _resultSize.localValue /** - * The number of on-disk bytes spilled by this task + * Amount of time the JVM spent in garbage collection while executing this task. */ - private var _diskBytesSpilled: Long = _ - def diskBytesSpilled: Long = _diskBytesSpilled - private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value - private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value + def jvmGCTime: Long = _jvmGCTime.localValue /** - * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read - * are stored here. + * Amount of time spent serializing the task result. */ - private var _inputMetrics: Option[InputMetrics] = None - - def inputMetrics: Option[InputMetrics] = _inputMetrics + def resultSerializationTime: Long = _resultSerializationTime.localValue /** - * This should only be used when recreating TaskMetrics, not when updating input metrics in - * executors + * The number of in-memory bytes spilled by this task. */ - private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) { - _inputMetrics = inputMetrics - } + def memoryBytesSpilled: Long = _memoryBytesSpilled.localValue /** - * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much - * data was written are stored here. + * The number of on-disk bytes spilled by this task. */ - var outputMetrics: Option[OutputMetrics] = None + def diskBytesSpilled: Long = _diskBytesSpilled.localValue /** - * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. - * This includes read metrics aggregated over all the task's shuffle dependencies. + * Peak memory used by internal data structures created during shuffles, aggregations and + * joins. The value of this accumulator should be approximately the sum of the peak sizes + * across all such data structures created in this task. For SQL jobs, this only tracks all + * unsafe operators and ExternalSort. */ - private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None - - def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics + def peakExecutionMemory: Long = _peakExecutionMemory.localValue /** - * This should only be used when recreating TaskMetrics, not when updating read metrics in - * executors. + * Storage statuses of any blocks that have been updated as a result of this task. */ - private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) { - _shuffleReadMetrics = shuffleReadMetrics + def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.localValue + + // Setters and increment-ers + private[spark] def setExecutorDeserializeTime(v: Long): Unit = + _executorDeserializeTime.setValue(v) + private[spark] def setExecutorRunTime(v: Long): Unit = _executorRunTime.setValue(v) + private[spark] def setResultSize(v: Long): Unit = _resultSize.setValue(v) + private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v) + private[spark] def setResultSerializationTime(v: Long): Unit = + _resultSerializationTime.setValue(v) + private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v) + private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v) + private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v) + private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.add(v) + private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.setValue(v) + + /** + * Get a Long accumulator from the given map by name, assuming it exists. + * Note: this only searches the initial set of accumulators passed into the constructor. + */ + private[spark] def getAccum(name: String): Accumulator[Long] = { + TaskMetrics.getAccum[Long](initialAccumsMap, name) } - /** - * ShuffleReadMetrics per dependency for collecting independently while task is in progress. - */ - @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] = - new ArrayBuffer[ShuffleReadMetrics]() - /** - * If this task writes to shuffle output, metrics on the written shuffle data will be collected - * here - */ - var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None + /* ========================== * + | INPUT METRICS | + * ========================== */ - /** - * Storage statuses of any blocks that have been updated as a result of this task. - */ - var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None + private var _inputMetrics: Option[InputMetrics] = None /** - * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization - * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each - * dependency, and merge these metrics before reporting them to the driver. This method returns - * a ShuffleReadMetrics for a dependency and registers it for merging later. + * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted + * data, defined only in tasks with input. */ - private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized { - val readMetrics = new ShuffleReadMetrics() - depsShuffleReadMetrics += readMetrics - readMetrics - } + def inputMetrics: Option[InputMetrics] = _inputMetrics /** - * Returns the input metrics object that the task should use. Currently, if - * there exists an input metric with the same readMethod, we return that one - * so the caller can accumulate bytes read. If the readMethod is different - * than previously seen by this task, we return a new InputMetric but don't - * record it. - * - * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, - * we can store all the different inputMetrics (one per readMethod). + * Get or create a new [[InputMetrics]] associated with this task. */ - private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): InputMetrics = { + private[spark] def registerInputMetrics(readMethod: DataReadMethod.Value): InputMetrics = { synchronized { - _inputMetrics match { - case None => - val metrics = new InputMetrics(readMethod) - _inputMetrics = Some(metrics) - metrics - case Some(metrics @ InputMetrics(method)) if method == readMethod => - metrics - case Some(InputMetrics(method)) => - new InputMetrics(readMethod) + val metrics = _inputMetrics.getOrElse { + val metrics = new InputMetrics(initialAccumsMap) + metrics.setReadMethod(readMethod) + _inputMetrics = Some(metrics) + metrics } - } - } - - /** - * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. - */ - private[spark] def updateShuffleReadMetrics(): Unit = synchronized { - if (!depsShuffleReadMetrics.isEmpty) { - val merged = new ShuffleReadMetrics() - for (depMetrics <- depsShuffleReadMetrics) { - merged.incFetchWaitTime(depMetrics.fetchWaitTime) - merged.incLocalBlocksFetched(depMetrics.localBlocksFetched) - merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched) - merged.incRemoteBytesRead(depMetrics.remoteBytesRead) - merged.incLocalBytesRead(depMetrics.localBytesRead) - merged.incRecordsRead(depMetrics.recordsRead) + // If there already exists an InputMetric with the same read method, we can just return + // that one. Otherwise, if the read method is different from the one previously seen by + // this task, we return a new dummy one to avoid clobbering the values of the old metrics. + // In the future we should try to store input metrics from all different read methods at + // the same time (SPARK-5225). + if (metrics.readMethod == readMethod) { + metrics + } else { + val m = new InputMetrics + m.setReadMethod(readMethod) + m } - _shuffleReadMetrics = Some(merged) } } - private[spark] def updateInputMetrics(): Unit = synchronized { - inputMetrics.foreach(_.updateBytesRead()) - } - - @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { - in.defaultReadObject() - // Get the hostname from cached data, since hostname is the order of number of nodes in - // cluster, so using cached hostname will decrease the object number and alleviate the GC - // overhead. - _hostname = TaskMetrics.getCachedHostName(_hostname) - } - private var _accumulatorUpdates: Map[Long, Any] = Map.empty - @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null + /* ============================ * + | OUTPUT METRICS | + * ============================ */ - private[spark] def updateAccumulators(): Unit = synchronized { - _accumulatorUpdates = _accumulatorsUpdater() - } + private var _outputMetrics: Option[OutputMetrics] = None /** - * Return the latest updates of accumulators in this task. + * Metrics related to writing data externally (e.g. to a distributed filesystem), + * defined only in tasks with output. */ - def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates - - private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = { - _accumulatorsUpdater = accumulatorsUpdater - } -} - -private[spark] object TaskMetrics { - private val hostNameCache = new ConcurrentHashMap[String, String]() - - def empty: TaskMetrics = new TaskMetrics + def outputMetrics: Option[OutputMetrics] = _outputMetrics - def getCachedHostName(host: String): String = { - val canonicalHost = hostNameCache.putIfAbsent(host, host) - if (canonicalHost != null) canonicalHost else host + /** + * Get or create a new [[OutputMetrics]] associated with this task. + */ + private[spark] def registerOutputMetrics( + writeMethod: DataWriteMethod.Value): OutputMetrics = synchronized { + _outputMetrics.getOrElse { + val metrics = new OutputMetrics(initialAccumsMap) + metrics.setWriteMethod(writeMethod) + _outputMetrics = Some(metrics) + metrics + } } -} -/** - * :: DeveloperApi :: - * Method by which input data was read. Network means that the data was read over the network - * from a remote block manager (which may have stored the data on-disk or in-memory). - */ -@DeveloperApi -object DataReadMethod extends Enumeration with Serializable { - type DataReadMethod = Value - val Memory, Disk, Hadoop, Network = Value -} -/** - * :: DeveloperApi :: - * Method by which output data was written. - */ -@DeveloperApi -object DataWriteMethod extends Enumeration with Serializable { - type DataWriteMethod = Value - val Hadoop = Value -} + /* ================================== * + | SHUFFLE READ METRICS | + * ================================== */ -/** - * :: DeveloperApi :: - * Metrics about reading input data. - */ -@DeveloperApi -case class InputMetrics(readMethod: DataReadMethod.Value) { + private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None /** - * This is volatile so that it is visible to the updater thread. + * Metrics related to shuffle read aggregated across all shuffle dependencies. + * This is defined only if there are shuffle dependencies in this task. */ - @volatile @transient var bytesReadCallback: Option[() => Long] = None + def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics /** - * Total bytes read. + * Temporary list of [[ShuffleReadMetrics]], one per shuffle dependency. + * + * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization + * issues from readers in different threads, in-progress tasks use a [[ShuffleReadMetrics]] for + * each dependency and merge these metrics before reporting them to the driver. */ - private var _bytesRead: Long = _ - def bytesRead: Long = _bytesRead - def incBytesRead(bytes: Long): Unit = _bytesRead += bytes + @transient private lazy val tempShuffleReadMetrics = new ArrayBuffer[ShuffleReadMetrics] /** - * Total records read. + * Create a temporary [[ShuffleReadMetrics]] for a particular shuffle dependency. + * + * All usages are expected to be followed by a call to [[mergeShuffleReadMetrics]], which + * merges the temporary values synchronously. Otherwise, all temporary data collected will + * be lost. */ - private var _recordsRead: Long = _ - def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records + private[spark] def registerTempShuffleReadMetrics(): ShuffleReadMetrics = synchronized { + val readMetrics = new ShuffleReadMetrics + tempShuffleReadMetrics += readMetrics + readMetrics + } /** - * Invoke the bytesReadCallback and mutate bytesRead. + * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`. + * This is expected to be called on executor heartbeat and at the end of a task. */ - def updateBytesRead() { - bytesReadCallback.foreach { c => - _bytesRead = c() + private[spark] def mergeShuffleReadMetrics(): Unit = synchronized { + if (tempShuffleReadMetrics.nonEmpty) { + val metrics = new ShuffleReadMetrics(initialAccumsMap) + metrics.setMergeValues(tempShuffleReadMetrics) + _shuffleReadMetrics = Some(metrics) } } - /** - * Register a function that can be called to get up-to-date information on how many bytes the task - * has read from an input source. - */ - def setBytesReadCallback(f: Option[() => Long]) { - bytesReadCallback = f - } -} - -/** - * :: DeveloperApi :: - * Metrics about writing output data. - */ -@DeveloperApi -case class OutputMetrics(writeMethod: DataWriteMethod.Value) { - /** - * Total bytes written - */ - private var _bytesWritten: Long = _ - def bytesWritten: Long = _bytesWritten - private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value + /* =================================== * + | SHUFFLE WRITE METRICS | + * =================================== */ - /** - * Total records written - */ - private var _recordsWritten: Long = 0L - def recordsWritten: Long = _recordsWritten - private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value -} + private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None -/** - * :: DeveloperApi :: - * Metrics pertaining to shuffle data read in a given task. - */ -@DeveloperApi -class ShuffleReadMetrics extends Serializable { /** - * Number of remote blocks fetched in this shuffle by this task + * Metrics related to shuffle write, defined only in shuffle map stages. */ - private var _remoteBlocksFetched: Int = _ - def remoteBlocksFetched: Int = _remoteBlocksFetched - private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value - private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value + def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics /** - * Number of local blocks fetched in this shuffle by this task + * Get or create a new [[ShuffleWriteMetrics]] associated with this task. */ - private var _localBlocksFetched: Int = _ - def localBlocksFetched: Int = _localBlocksFetched - private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value - private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value + private[spark] def registerShuffleWriteMetrics(): ShuffleWriteMetrics = synchronized { + _shuffleWriteMetrics.getOrElse { + val metrics = new ShuffleWriteMetrics(initialAccumsMap) + _shuffleWriteMetrics = Some(metrics) + metrics + } + } - /** - * Time the task spent waiting for remote shuffle blocks. This only includes the time - * blocking on shuffle input data. For instance if block B is being fetched while the task is - * still not finished processing block A, it is not considered to be blocking on block B. - */ - private var _fetchWaitTime: Long = _ - def fetchWaitTime: Long = _fetchWaitTime - private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value - private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - /** - * Total number of remote bytes read from the shuffle by this task - */ - private var _remoteBytesRead: Long = _ - def remoteBytesRead: Long = _remoteBytesRead - private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value - private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value + /* ========================== * + | OTHER THINGS | + * ========================== */ - /** - * Shuffle data that was read from the local disk (as opposed to from a remote executor). - */ - private var _localBytesRead: Long = _ - def localBytesRead: Long = _localBytesRead - private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = { + accums += a + } /** - * Total bytes fetched in the shuffle by this task (both remote and local). + * Return the latest updates of accumulators in this task. + * + * The [[AccumulableInfo.update]] field is always defined and the [[AccumulableInfo.value]] + * field is always empty, since this represents the partial updates recorded in this task, + * not the aggregated value across multiple tasks. */ - def totalBytesRead: Long = _remoteBytesRead + _localBytesRead + def accumulatorUpdates(): Seq[AccumulableInfo] = { + accums.map { a => a.toInfo(Some(a.localValue), None) } + } - /** - * Number of blocks fetched in this shuffle by this task (remote or local) - */ - def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched + // If we are reconstructing this TaskMetrics on the driver, some metrics may already be set. + // If so, initialize all relevant metrics classes so listeners can access them downstream. + { + var (hasShuffleRead, hasShuffleWrite, hasInput, hasOutput) = (false, false, false, false) + initialAccums + .filter { a => a.localValue != a.zero } + .foreach { a => + a.name.get match { + case sr if sr.startsWith(SHUFFLE_READ_METRICS_PREFIX) => hasShuffleRead = true + case sw if sw.startsWith(SHUFFLE_WRITE_METRICS_PREFIX) => hasShuffleWrite = true + case in if in.startsWith(INPUT_METRICS_PREFIX) => hasInput = true + case out if out.startsWith(OUTPUT_METRICS_PREFIX) => hasOutput = true + case _ => + } + } + if (hasShuffleRead) { _shuffleReadMetrics = Some(new ShuffleReadMetrics(initialAccumsMap)) } + if (hasShuffleWrite) { _shuffleWriteMetrics = Some(new ShuffleWriteMetrics(initialAccumsMap)) } + if (hasInput) { _inputMetrics = Some(new InputMetrics(initialAccumsMap)) } + if (hasOutput) { _outputMetrics = Some(new OutputMetrics(initialAccumsMap)) } + } - /** - * Total number of records read from the shuffle by this task - */ - private var _recordsRead: Long = _ - def recordsRead: Long = _recordsRead - private[spark] def incRecordsRead(value: Long) = _recordsRead += value - private[spark] def decRecordsRead(value: Long) = _recordsRead -= value } /** - * :: DeveloperApi :: - * Metrics pertaining to shuffle data written in a given task. + * Internal subclass of [[TaskMetrics]] which is used only for posting events to listeners. + * Its purpose is to obviate the need for the driver to reconstruct the original accumulators, + * which might have been garbage-collected. See SPARK-13407 for more details. + * + * Instances of this class should be considered read-only and users should not call `inc*()` or + * `set*()` methods. While we could override the setter methods to throw + * UnsupportedOperationException, we choose not to do so because the overrides would quickly become + * out-of-date when new metrics are added. */ -@DeveloperApi -class ShuffleWriteMetrics extends Serializable { - /** - * Number of bytes written for the shuffle by this task - */ - @volatile private var _shuffleBytesWritten: Long = _ - def shuffleBytesWritten: Long = _shuffleBytesWritten - private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value - private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value +private[spark] class ListenerTaskMetrics( + initialAccums: Seq[Accumulator[_]], + accumUpdates: Seq[AccumulableInfo]) extends TaskMetrics(initialAccums) { + + override def accumulatorUpdates(): Seq[AccumulableInfo] = accumUpdates + + override private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = { + throw new UnsupportedOperationException("This TaskMetrics is read-only") + } +} + +private[spark] object TaskMetrics extends Logging { + + def empty: TaskMetrics = new TaskMetrics /** - * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds + * Get an accumulator from the given map by name, assuming it exists. */ - @volatile private var _shuffleWriteTime: Long = _ - def shuffleWriteTime: Long = _shuffleWriteTime - private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value - private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value + def getAccum[T](accumMap: Map[String, Accumulator[_]], name: String): Accumulator[T] = { + require(accumMap.contains(name), s"metric '$name' is missing") + val accum = accumMap(name) + try { + // Note: we can't do pattern matching here because types are erased by compile time + accum.asInstanceOf[Accumulator[T]] + } catch { + case e: ClassCastException => + throw new SparkException(s"accumulator $name was of unexpected type", e) + } + } /** - * Total number of records written to the shuffle by this task - */ - @volatile private var _shuffleRecordsWritten: Long = _ - def shuffleRecordsWritten: Long = _shuffleRecordsWritten - private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value - private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value - private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value + * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only. + * + * Executors only send accumulator updates back to the driver, not [[TaskMetrics]]. However, we + * need the latter to post task end events to listeners, so we need to reconstruct the metrics + * on the driver. + * + * This assumes the provided updates contain the initial set of accumulators representing + * internal task level metrics. + */ + def fromAccumulatorUpdates(accumUpdates: Seq[AccumulableInfo]): TaskMetrics = { + // Initial accumulators are passed into the TaskMetrics constructor first because these + // are required to be uniquely named. The rest of the accumulators from this task are + // registered later because they need not satisfy this requirement. + val definedAccumUpdates = accumUpdates.filter { info => info.update.isDefined } + val initialAccums = definedAccumUpdates + .filter { info => info.name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) } + .map { info => + val accum = InternalAccumulator.create(info.name.get) + accum.setValueAny(info.update.get) + accum + } + new ListenerTaskMetrics(initialAccums, definedAccumUpdates) + } + } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index 532850dd57716..978afaffab30b 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -19,11 +19,10 @@ package org.apache.spark.input import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging /** * Custom Input Format for reading and splitting flat binary files that contain records, @@ -36,7 +35,7 @@ private[spark] object FixedLengthBinaryInputFormat { /** Retrieves the record length property from a Hadoop configuration */ def getRecordLength(context: JobContext): Int = { - SparkHadoopUtil.get.getConfigurationFromJobContext(context).get(RECORD_LENGTH_PROPERTY).toInt + context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt } } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 67a96925da019..549395314ba61 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -20,11 +20,10 @@ package org.apache.spark.input import java.io.IOException import org.apache.hadoop.fs.FSDataInputStream -import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileSplit -import org.apache.spark.deploy.SparkHadoopUtil /** * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat. @@ -83,16 +82,16 @@ private[spark] class FixedLengthBinaryRecordReader // the actual file we will be reading from val file = fileSplit.getPath // job configuration - val job = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val conf = context.getConfiguration // check compression - val codec = new CompressionCodecFactory(job).getCodec(file) + val codec = new CompressionCodecFactory(conf).getCodec(file) if (codec != null) { throw new IOException("FixedLengthRecordReader does not support reading compressed files") } // get the record length recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) // get the filesystem - val fs = file.getFileSystem(job) + val fs = file.getFileSystem(conf) // open the File fileInputStream = fs.open(file) // seek to the splitStart position diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 280e7a5fe893c..18cb7631b3d4c 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -21,14 +21,12 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.JavaConverters._ -import com.google.common.io.{Closeables, ByteStreams} +import com.google.common.io.{ByteStreams, Closeables} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} -import org.apache.spark.deploy.SparkHadoopUtil - /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -43,9 +41,8 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum - val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong + val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum + val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong super.setMaxSplitSize(maxSplitSize) } @@ -135,8 +132,7 @@ class PortableDataStream( private val confBytes = { val baos = new ByteArrayOutputStream() - SparkHadoopUtil.get.getConfigurationFromJobContext(context). - write(new DataOutputStream(baos)) + context.getConfiguration.write(new DataOutputStream(baos)) baos.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 1ba34a11414a2..fa34f1e886c72 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -20,6 +20,7 @@ package org.apache.spark.input import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.InputSplit import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat @@ -33,14 +34,13 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext */ private[spark] class WholeTextFileInputFormat - extends CombineFileInputFormat[String, String] with Configurable { + extends CombineFileInputFormat[Text, Text] with Configurable { override protected def isSplitable(context: JobContext, file: Path): Boolean = false override def createRecordReader( split: InputSplit, - context: TaskAttemptContext): RecordReader[String, String] = { - + context: TaskAttemptContext): RecordReader[Text, Text] = { val reader = new ConfigurableCombineFileRecordReader(split, context, classOf[WholeTextFileRecordReader]) reader.setConf(getConf) @@ -53,7 +53,7 @@ private[spark] class WholeTextFileInputFormat */ def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index 31bde8a78f3c6..6b7f086678e93 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -17,17 +17,14 @@ package org.apache.spark.input -import org.apache.hadoop.conf.{Configuration, Configurable => HConfigurable} import com.google.common.io.{ByteStreams, Closeables} - +import org.apache.hadoop.conf.{Configurable => HConfigurable, Configuration} import org.apache.hadoop.io.Text import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.InputSplit -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader} import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.spark.deploy.SparkHadoopUtil - +import org.apache.hadoop.mapreduce.lib.input.{CombineFileRecordReader, CombineFileSplit} /** * A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface. @@ -49,17 +46,16 @@ private[spark] class WholeTextFileRecordReader( split: CombineFileSplit, context: TaskAttemptContext, index: Integer) - extends RecordReader[String, String] with Configurable { + extends RecordReader[Text, Text] with Configurable { private[this] val path = split.getPath(index) - private[this] val fs = path.getFileSystem( - SparkHadoopUtil.get.getConfigurationFromJobContext(context)) + private[this] val fs = path.getFileSystem(context.getConfiguration) // True means the current file has been processed, then skip it. private[this] var processed = false - private[this] val key = path.toString - private[this] var value: String = null + private[this] val key: Text = new Text(path.toString) + private[this] var value: Text = null override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} @@ -67,9 +63,9 @@ private[spark] class WholeTextFileRecordReader( override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey: String = key + override def getCurrentKey: Text = key - override def getCurrentValue: String = value + override def getCurrentValue: Text = value override def nextKeyValue(): Boolean = { if (!processed) { @@ -83,7 +79,7 @@ private[spark] class WholeTextFileRecordReader( ByteStreams.toByteArray(fileIn) } - value = new Text(innerBuffer).toString + value = new Text(innerBuffer) Closeables.close(fileIn, false) processed = true true diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala new file mode 100644 index 0000000000000..66a0cfec6296d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal + +import org.apache.log4j.{Level, LogManager, PropertyConfigurator} +import org.slf4j.{Logger, LoggerFactory} +import org.slf4j.impl.StaticLoggerBinder + +import org.apache.spark.util.Utils + +/** + * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows + * logging messages at different levels using methods that only evaluate parameters lazily if the + * log level is enabled. + */ +private[spark] trait Logging { + + // Make the log field transient so that objects with Logging can + // be serialized and used on another machine + @transient private var log_ : Logger = null + + // Method to get the logger name for this object + protected def logName = { + // Ignore trailing $'s in the class names for Scala objects + this.getClass.getName.stripSuffix("$") + } + + // Method to get or create the logger for this object + protected def log: Logger = { + if (log_ == null) { + initializeLogIfNecessary(false) + log_ = LoggerFactory.getLogger(logName) + } + log_ + } + + // Log methods that take only a String + protected def logInfo(msg: => String) { + if (log.isInfoEnabled) log.info(msg) + } + + protected def logDebug(msg: => String) { + if (log.isDebugEnabled) log.debug(msg) + } + + protected def logTrace(msg: => String) { + if (log.isTraceEnabled) log.trace(msg) + } + + protected def logWarning(msg: => String) { + if (log.isWarnEnabled) log.warn(msg) + } + + protected def logError(msg: => String) { + if (log.isErrorEnabled) log.error(msg) + } + + // Log methods that take Throwables (Exceptions/Errors) too + protected def logInfo(msg: => String, throwable: Throwable) { + if (log.isInfoEnabled) log.info(msg, throwable) + } + + protected def logDebug(msg: => String, throwable: Throwable) { + if (log.isDebugEnabled) log.debug(msg, throwable) + } + + protected def logTrace(msg: => String, throwable: Throwable) { + if (log.isTraceEnabled) log.trace(msg, throwable) + } + + protected def logWarning(msg: => String, throwable: Throwable) { + if (log.isWarnEnabled) log.warn(msg, throwable) + } + + protected def logError(msg: => String, throwable: Throwable) { + if (log.isErrorEnabled) log.error(msg, throwable) + } + + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + + protected def initializeLogIfNecessary(isInterpreter: Boolean): Unit = { + if (!Logging.initialized) { + Logging.initLock.synchronized { + if (!Logging.initialized) { + initializeLogging(isInterpreter) + } + } + } + } + + private def initializeLogging(isInterpreter: Boolean): Unit = { + // Don't use a logger in here, as this is itself occurring during initialization of a logger + // If Log4j 1.2 is being used, but is not initialized, load a default properties file + val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr + // This distinguishes the log4j 1.2 binding, currently + // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently + // org.apache.logging.slf4j.Log4jLoggerFactory + val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) + if (usingLog4j12) { + val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + // scalastyle:off println + if (!log4j12Initialized) { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } + } + + if (isInterpreter) { + // Use the repl's main class to define the default log level when running the shell, + // overriding the root logger's config if they're different. + val rootLogger = LogManager.getRootLogger() + val replLogger = LogManager.getLogger(logName) + val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) + if (replLevel != rootLogger.getEffectiveLevel()) { + System.err.printf("Setting default log level to \"%s\".\n", replLevel) + System.err.println("To adjust logging level use sc.setLogLevel(newLevel).") + rootLogger.setLevel(replLevel) + } + } + // scalastyle:on println + } + Logging.initialized = true + + // Force a call into slf4j to initialize it. Avoids this happening from multiple threads + // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html + log + } +} + +private object Logging { + @volatile private var initialized = false + val initLock = new Object() + try { + // We use reflection here to handle the case where users remove the + // slf4j-to-jul bridge order to route their logs to JUL. + val bridgeClass = Utils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") + bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) + val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] + if (!installed) { + bridgeClass.getMethod("install").invoke(null) + } + } catch { + case e: ClassNotFoundException => // can't log anything yet so just fail silently + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala new file mode 100644 index 0000000000000..5d50e3851a9f0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.config + +import java.util.concurrent.TimeUnit + +import org.apache.spark.network.util.{ByteUnit, JavaUtils} + +private object ConfigHelpers { + + def toNumber[T](s: String, converter: String => T, key: String, configType: String): T = { + try { + converter(s) + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be $configType, but was $s") + } + } + + def toBoolean(s: String, key: String): Boolean = { + try { + s.toBoolean + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException(s"$key should be boolean, but was $s") + } + } + + def stringToSeq[T](str: String, converter: String => T): Seq[T] = { + str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter) + } + + def seqToString[T](v: Seq[T], stringConverter: T => String): String = { + v.map(stringConverter).mkString(",") + } + + def timeFromString(str: String, unit: TimeUnit): Long = JavaUtils.timeStringAs(str, unit) + + def timeToString(v: Long, unit: TimeUnit): String = TimeUnit.MILLISECONDS.convert(v, unit) + "ms" + + def byteFromString(str: String, unit: ByteUnit): Long = { + val (input, multiplier) = + if (str.length() > 0 && str.charAt(0) == '-') { + (str.substring(1), -1) + } else { + (str, 1) + } + multiplier * JavaUtils.byteStringAs(input, unit) + } + + def byteToString(v: Long, unit: ByteUnit): String = unit.convertTo(v, ByteUnit.BYTE) + "b" + +} + +/** + * A type-safe config builder. Provides methods for transforming the input data (which can be + * used, e.g., for validation) and creating the final config entry. + * + * One of the methods that return a [[ConfigEntry]] must be called to create a config entry that + * can be used with [[SparkConf]]. + */ +private[spark] class TypedConfigBuilder[T]( + val parent: ConfigBuilder, + val converter: String => T, + val stringConverter: T => String) { + + import ConfigHelpers._ + + def this(parent: ConfigBuilder, converter: String => T) = { + this(parent, converter, Option(_).map(_.toString).orNull) + } + + /** Apply a transformation to the user-provided values of the config entry. */ + def transform(fn: T => T): TypedConfigBuilder[T] = { + new TypedConfigBuilder(parent, s => fn(converter(s)), stringConverter) + } + + /** Check that user-provided values for the config match a pre-defined set. */ + def checkValues(validValues: Set[T]): TypedConfigBuilder[T] = { + transform { v => + if (!validValues.contains(v)) { + throw new IllegalArgumentException( + s"The value of ${parent.key} should be one of ${validValues.mkString(", ")}, but was $v") + } + v + } + } + + /** Turns the config entry into a sequence of values of the underlying type. */ + def toSequence: TypedConfigBuilder[Seq[T]] = { + new TypedConfigBuilder(parent, stringToSeq(_, converter), seqToString(_, stringConverter)) + } + + /** Creates a [[ConfigEntry]] that does not have a default value. */ + def createOptional: OptionalConfigEntry[T] = { + val entry = new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc, + parent._public) + parent._onCreate.foreach(_(entry)) + entry + } + + /** Creates a [[ConfigEntry]] that has a default value. */ + def createWithDefault(default: T): ConfigEntry[T] = { + val transformedDefault = converter(stringConverter(default)) + val entry = new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter, + stringConverter, parent._doc, parent._public) + parent._onCreate.foreach(_(entry)) + entry + } + + /** + * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a + * [[String]] and must be a valid value for the entry. + */ + def createWithDefaultString(default: String): ConfigEntry[T] = { + val typedDefault = converter(default) + val entry = new ConfigEntryWithDefault[T](parent.key, typedDefault, converter, stringConverter, + parent._doc, parent._public) + parent._onCreate.foreach(_(entry)) + entry + } + +} + +/** + * Basic builder for Spark configurations. Provides methods for creating type-specific builders. + * + * @see TypedConfigBuilder + */ +private[spark] case class ConfigBuilder(key: String) { + + import ConfigHelpers._ + + private[config] var _public = true + private[config] var _doc = "" + private[config] var _onCreate: Option[ConfigEntry[_] => Unit] = None + + def internal(): ConfigBuilder = { + _public = false + this + } + + def doc(s: String): ConfigBuilder = { + _doc = s + this + } + + /** + * Registers a callback for when the config entry is finally instantiated. Currently used by + * SQLConf to keep track of SQL configuration entries. + */ + def onCreate(callback: ConfigEntry[_] => Unit): ConfigBuilder = { + _onCreate = Option(callback) + this + } + + def intConf: TypedConfigBuilder[Int] = { + new TypedConfigBuilder(this, toNumber(_, _.toInt, key, "int")) + } + + def longConf: TypedConfigBuilder[Long] = { + new TypedConfigBuilder(this, toNumber(_, _.toLong, key, "long")) + } + + def doubleConf: TypedConfigBuilder[Double] = { + new TypedConfigBuilder(this, toNumber(_, _.toDouble, key, "double")) + } + + def booleanConf: TypedConfigBuilder[Boolean] = { + new TypedConfigBuilder(this, toBoolean(_, key)) + } + + def stringConf: TypedConfigBuilder[String] = { + new TypedConfigBuilder(this, v => v) + } + + def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = { + new TypedConfigBuilder(this, timeFromString(_, unit), timeToString(_, unit)) + } + + def bytesConf(unit: ByteUnit): TypedConfigBuilder[Long] = { + new TypedConfigBuilder(this, byteFromString(_, unit), byteToString(_, unit)) + } + + def fallbackConf[T](fallback: ConfigEntry[T]): ConfigEntry[T] = { + new FallbackConfigEntry(key, _doc, _public, fallback) + } + +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala new file mode 100644 index 0000000000000..f7296b487c0e9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.config + +import org.apache.spark.SparkConf + +/** + * An entry contains all meta information for a configuration. + * + * @param key the key for the configuration + * @param defaultValue the default value for the configuration + * @param valueConverter how to convert a string to the value. It should throw an exception if the + * string does not have the required format. + * @param stringConverter how to convert a value to a string that the user can use it as a valid + * string value. It's usually `toString`. But sometimes, a custom converter + * is necessary. E.g., if T is List[String], `a, b, c` is better than + * `List(a, b, c)`. + * @param doc the documentation for the configuration + * @param isPublic if this configuration is public to the user. If it's `false`, this + * configuration is only used internally and we should not expose it to users. + * @tparam T the value type + */ +private[spark] abstract class ConfigEntry[T] ( + val key: String, + val valueConverter: String => T, + val stringConverter: T => String, + val doc: String, + val isPublic: Boolean) { + + def defaultValueString: String + + def readFrom(conf: SparkConf): T + + // This is used by SQLConf, since it doesn't use SparkConf to store settings and thus cannot + // use readFrom(). + def defaultValue: Option[T] = None + + override def toString: String = { + s"ConfigEntry(key=$key, defaultValue=$defaultValueString, doc=$doc, public=$isPublic)" + } +} + +private class ConfigEntryWithDefault[T] ( + key: String, + _defaultValue: T, + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean) + extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + + override def defaultValue: Option[T] = Some(_defaultValue) + + override def defaultValueString: String = stringConverter(_defaultValue) + + override def readFrom(conf: SparkConf): T = { + conf.getOption(key).map(valueConverter).getOrElse(_defaultValue) + } + +} + +/** + * A config entry that does not have a default value. + */ +private[spark] class OptionalConfigEntry[T]( + key: String, + val rawValueConverter: String => T, + val rawStringConverter: T => String, + doc: String, + isPublic: Boolean) + extends ConfigEntry[Option[T]](key, s => Some(rawValueConverter(s)), + v => v.map(rawStringConverter).orNull, doc, isPublic) { + + override def defaultValueString: String = "" + + override def readFrom(conf: SparkConf): Option[T] = conf.getOption(key).map(rawValueConverter) + +} + +/** + * A config entry whose default value is defined by another config entry. + */ +private class FallbackConfigEntry[T] ( + key: String, + doc: String, + isPublic: Boolean, + private val fallback: ConfigEntry[T]) + extends ConfigEntry[T](key, fallback.valueConverter, fallback.stringConverter, doc, isPublic) { + + override def defaultValueString: String = s"" + + override def readFrom(conf: SparkConf): T = { + conf.getOption(key).map(valueConverter).getOrElse(fallback.readFrom(conf)) + } + +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala new file mode 100644 index 0000000000000..94b50ee06520c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal + +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.network.util.ByteUnit + +package object config { + + private[spark] val DRIVER_CLASS_PATH = + ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.createOptional + + private[spark] val DRIVER_JAVA_OPTIONS = + ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS).stringConf.createOptional + + private[spark] val DRIVER_LIBRARY_PATH = + ConfigBuilder(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH).stringConf.createOptional + + private[spark] val DRIVER_USER_CLASS_PATH_FIRST = + ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false) + + private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") + .bytesConf(ByteUnit.MiB) + .createWithDefaultString("1g") + + private[spark] val EXECUTOR_CLASS_PATH = + ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional + + private[spark] val EXECUTOR_JAVA_OPTIONS = + ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional + + private[spark] val EXECUTOR_LIBRARY_PATH = + ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_LIBRARY_PATH).stringConf.createOptional + + private[spark] val EXECUTOR_USER_CLASS_PATH_FIRST = + ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false) + + private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") + .bytesConf(ByteUnit.MiB) + .createWithDefaultString("1g") + + private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal() + .booleanConf.createWithDefault(false) + + private[spark] val CPUS_PER_TASK = ConfigBuilder("spark.task.cpus").intConf.createWithDefault(1) + + private[spark] val DYN_ALLOCATION_MIN_EXECUTORS = + ConfigBuilder("spark.dynamicAllocation.minExecutors").intConf.createWithDefault(0) + + private[spark] val DYN_ALLOCATION_INITIAL_EXECUTORS = + ConfigBuilder("spark.dynamicAllocation.initialExecutors") + .fallbackConf(DYN_ALLOCATION_MIN_EXECUTORS) + + private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = + ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + + private[spark] val SHUFFLE_SERVICE_ENABLED = + ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) + + private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab") + .doc("Location of user's keytab.") + .stringConf.createOptional + + private[spark] val PRINCIPAL = ConfigBuilder("spark.yarn.principal") + .doc("Name of the Kerberos principal.") + .stringConf.createOptional + + private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances") + .intConf + .createOptional + + private[spark] val PY_FILES = ConfigBuilder("spark.submit.pyFiles") + .internal() + .stringConf + .toSequence + .createWithDefault(Nil) +} diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ca74eedf89be5..ae014becef755 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -17,10 +17,10 @@ package org.apache.spark.io -import java.io.{IOException, InputStream, OutputStream} +import java.io._ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} -import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} +import net.jpountz.lz4.LZ4BlockOutputStream import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf @@ -49,7 +49,8 @@ private[spark] object CompressionCodec { private val configKey = "spark.io.compression.codec" private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = { - codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] + (codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] + || codec.isInstanceOf[LZ4CompressionCodec]) } private val shortCompressionCodecNames = Map( @@ -92,12 +93,11 @@ private[spark] object CompressionCodec { } } - val FALLBACK_COMPRESSION_CODEC = "lzf" - val DEFAULT_COMPRESSION_CODEC = "snappy" + val FALLBACK_COMPRESSION_CODEC = "snappy" + val DEFAULT_COMPRESSION_CODEC = "lz4" val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq } - /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. @@ -149,12 +149,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { - - try { - Snappy.getNativeLibraryVersion - } catch { - case e: Error => throw new IllegalArgumentException(e) - } + val version = SnappyCompressionCodec.version override def compressedOutputStream(s: OutputStream): OutputStream = { val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt @@ -164,6 +159,19 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } +/** + * Object guards against memory leak bug in snappy-java library: + * (https://github.com/xerial/snappy-java/issues/131). + * Before a new version of the library, we only call the method once and cache the result. + */ +private final object SnappyCompressionCodec { + private lazy val version: String = try { + Snappy.getNativeLibraryVersion + } catch { + case e: Error => throw new IllegalArgumentException(e) + } +} + /** * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala index 3ea984c501e02..a5d41a1eeb479 100644 --- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -21,7 +21,7 @@ import java.net.{InetAddress, Socket} import org.apache.spark.SPARK_VERSION import org.apache.spark.launcher.LauncherProtocol._ -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} /** * A class that can be used to talk to a launcher server. Users should extend this class to @@ -88,12 +88,20 @@ private[spark] abstract class LauncherBackend { */ protected def onDisconnected() : Unit = { } + private def fireStopRequest(): Unit = { + val thread = LauncherBackend.threadFactory.newThread(new Runnable() { + override def run(): Unit = Utils.tryLogNonFatalError { + onStopRequest() + } + }) + thread.start() + } private class BackendConnection(s: Socket) extends LauncherConnection(s) { override protected def handle(m: Message): Unit = m match { case _: Stop => - onStopRequest() + fireStopRequest() case _ => throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}") diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index a2add61617281..31b9c5edf003f 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -37,7 +37,6 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm override def buildCommand(env: JMap[String, String]): JList[String] = { val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator)) - cmd.add(s"-Xms${memoryMb}M") cmd.add(s"-Xmx${memoryMb}M") command.javaOpts.foreach(cmd.add) CommandBuilderUtils.addPermGenSizeOpt(cmd) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f7298e8d5c62c..607283a306b8f 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -18,61 +18,13 @@ package org.apache.spark.mapred import java.io.IOException -import java.lang.reflect.Modifier -import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.executor.CommitDeniedException -import org.apache.spark.{Logging, SparkEnv, TaskContext} -import org.apache.spark.util.{Utils => SparkUtils} - -private[spark] -trait SparkHadoopMapRedUtil { - def newJobContext(conf: JobConf, jobId: JobID): JobContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", - "org.apache.hadoop.mapred.JobContext") - val ctor = klass.getDeclaredConstructor(classOf[JobConf], - classOf[org.apache.hadoop.mapreduce.JobID]) - // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private. - // Make it accessible if it's not in order to access it. - if (!Modifier.isPublic(ctor.getModifiers)) { - ctor.setAccessible(true) - } - ctor.newInstance(conf, jobId).asInstanceOf[JobContext] - } - - def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", - "org.apache.hadoop.mapred.TaskAttemptContext") - val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) - // See above - if (!Modifier.isPublic(ctor.getModifiers)) { - ctor.setAccessible(true) - } - ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] - } - - def newTaskAttemptID( - jtIdentifier: String, - jobId: Int, - isMap: Boolean, - taskId: Int, - attemptId: Int): TaskAttemptID = { - new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) - } - - private def firstAvailableClass(first: String, second: String): Class[_] = { - try { - SparkUtils.classForName(first) - } catch { - case e: ClassNotFoundException => - SparkUtils.classForName(second) - } - } -} +import org.apache.spark.internal.Logging object SparkHadoopMapRedUtil extends Logging { /** @@ -81,11 +33,8 @@ object SparkHadoopMapRedUtil extends Logging { * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for * details). * - * Output commit coordinator is only contacted when the following two configurations are both set - * to `true`: - * - * - `spark.speculation` - * - `spark.hadoop.outputCommitCoordination.enabled` + * Output commit coordinator is only used when `spark.hadoop.outputCommitCoordination.enabled` + * is set to true (which is the default). */ def commitTask( committer: MapReduceOutputCommitter, @@ -93,7 +42,7 @@ object SparkHadoopMapRedUtil extends Logging { jobId: Int, splitId: Int): Unit = { - val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID // Called after we have decided to commit def performCommit(): Unit = { @@ -112,11 +61,10 @@ object SparkHadoopMapRedUtil extends Logging { if (committer.needsTaskCommit(mrTaskContext)) { val shouldCoordinateWithDriver: Boolean = { val sparkConf = SparkEnv.get.conf - // We only need to coordinate with the driver if there are multiple concurrent task - // attempts, which should only occur if speculation is enabled - val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false) - // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs - sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + // We only need to coordinate with the driver if there are concurrent task attempts. + // Note that this could happen even when speculation is not enabled (e.g. see SPARK-8029). + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs. + sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true) } if (shouldCoordinateWithDriver) { diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala deleted file mode 100644 index 943ebcb7bd0a1..0000000000000 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mapreduce - -import java.lang.{Boolean => JBoolean, Integer => JInteger} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} -import org.apache.spark.util.Utils - -private[spark] -trait SparkHadoopMapReduceUtil { - def newJobContext(conf: Configuration, jobId: JobID): JobContext = { - val klass = firstAvailableClass( - "org.apache.hadoop.mapreduce.task.JobContextImpl", // hadoop2, hadoop2-yarn - "org.apache.hadoop.mapreduce.JobContext") // hadoop1 - val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[JobID]) - ctor.newInstance(conf, jobId).asInstanceOf[JobContext] - } - - def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass( - "org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl", // hadoop2, hadoop2-yarn - "org.apache.hadoop.mapreduce.TaskAttemptContext") // hadoop1 - val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[TaskAttemptID]) - ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] - } - - def newTaskAttemptID( - jtIdentifier: String, - jobId: Int, - isMap: Boolean, - taskId: Int, - attemptId: Int): TaskAttemptID = { - val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") - try { - // First, attempt to use the old-style constructor that takes a boolean isMap - // (not available in YARN) - val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean], - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), - new JInteger(attemptId)).asInstanceOf[TaskAttemptID] - } catch { - case exc: NoSuchMethodException => { - // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") - .asInstanceOf[Class[Enum[_]]] - val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if (isMap) "MAP" else "REDUCE") - val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), - new JInteger(attemptId)).asInstanceOf[TaskAttemptID] - } - } - } - - private def firstAvailableClass(first: String, second: String): Class[_] = { - try { - Utils.classForName(first) - } catch { - case e: ClassNotFoundException => - Utils.classForName(second) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala new file mode 100644 index 0000000000000..f8167074c6dfa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.internal.Logging + +/** + * Implements policies and bookkeeping for sharing a adjustable-sized pool of memory between tasks. + * + * Tries to ensure that each task gets a reasonable share of memory, instead of some task ramping up + * to a large amount first and then causing others to spill to disk repeatedly. + * + * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever this + * set changes. This is all done by synchronizing access to mutable state and using wait() and + * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across + * tasks was performed by the ShuffleMemoryManager. + * + * @param lock a [[MemoryManager]] instance to synchronize on + * @param memoryMode the type of memory tracked by this pool (on- or off-heap) + */ +private[memory] class ExecutionMemoryPool( + lock: Object, + memoryMode: MemoryMode + ) extends MemoryPool(lock) with Logging { + + private[this] val poolName: String = memoryMode match { + case MemoryMode.ON_HEAP => "on-heap execution" + case MemoryMode.OFF_HEAP => "off-heap execution" + } + + /** + * Map from taskAttemptId -> memory consumption in bytes + */ + @GuardedBy("lock") + private val memoryForTask = new mutable.HashMap[Long, Long]() + + override def memoryUsed: Long = lock.synchronized { + memoryForTask.values.sum + } + + /** + * Returns the memory consumption, in bytes, for the given task. + */ + def getMemoryUsageForTask(taskAttemptId: Long): Long = lock.synchronized { + memoryForTask.getOrElse(taskAttemptId, 0L) + } + + /** + * Try to acquire up to `numBytes` of memory for the given task and return the number of bytes + * obtained, or 0 if none can be allocated. + * + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. + * + * @param numBytes number of bytes to acquire + * @param taskAttemptId the task attempt acquiring memory + * @param maybeGrowPool a callback that potentially grows the size of this pool. It takes in + * one parameter (Long) that represents the desired amount of memory by + * which this pool should be expanded. + * @param computeMaxPoolSize a callback that returns the maximum allowable size of this pool + * at this given moment. This is not a field because the max pool + * size is variable in certain cases. For instance, in unified + * memory management, the execution pool can be expanded by evicting + * cached blocks, thereby shrinking the storage pool. + * + * @return the number of bytes granted to the task. + */ + private[memory] def acquireMemory( + numBytes: Long, + taskAttemptId: Long, + maybeGrowPool: Long => Unit = (additionalSpaceNeeded: Long) => Unit, + computeMaxPoolSize: () => Long = () => poolSize): Long = lock.synchronized { + assert(numBytes > 0, s"invalid number of bytes requested: $numBytes") + + // TODO: clean up this clunky method signature + + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory` + if (!memoryForTask.contains(taskAttemptId)) { + memoryForTask(taskAttemptId) = 0L + // This will later cause waiting tasks to wake up and check numTasks again + lock.notifyAll() + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). + // TODO: simplify this to limit each task to its own slot + while (true) { + val numActiveTasks = memoryForTask.keys.size + val curMem = memoryForTask(taskAttemptId) + + // In every iteration of this loop, we should first try to reclaim any borrowed execution + // space from storage. This is necessary because of the potential race condition where new + // storage blocks may steal the free execution memory that this task was waiting for. + maybeGrowPool(numBytes - memoryFree) + + // Maximum size the pool would have after potentially growing the pool. + // This is used to compute the upper bound of how much memory each task can occupy. This + // must take into account potential free memory as well as the amount this pool currently + // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management, + // we did not take into account space that could have been freed by evicting cached blocks. + val maxPoolSize = computeMaxPoolSize() + val maxMemoryPerTask = maxPoolSize / numActiveTasks + val minMemoryPerTask = poolSize / (2 * numActiveTasks) + + // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks + val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem)) + // Only give it as much memory as is free, which might be none if it reached 1 / numTasks + val toGrant = math.min(maxToGrant, memoryFree) + + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) { + logInfo(s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") + lock.wait() + } else { + memoryForTask(taskAttemptId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** + * Release `numBytes` of memory acquired by the given task. + */ + def releaseMemory(numBytes: Long, taskAttemptId: Long): Unit = lock.synchronized { + val curMem = memoryForTask.getOrElse(taskAttemptId, 0L) + var memoryToFree = if (curMem < numBytes) { + logWarning( + s"Internal error: release called on $numBytes bytes but task only has $curMem bytes " + + s"of memory from the $poolName pool") + curMem + } else { + numBytes + } + if (memoryForTask.contains(taskAttemptId)) { + memoryForTask(taskAttemptId) -= memoryToFree + if (memoryForTask(taskAttemptId) <= 0) { + memoryForTask.remove(taskAttemptId) + } + } + lock.notifyAll() // Notify waiters in acquireMemory() that memory has been freed + } + + /** + * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * @return the number of bytes freed. + */ + def releaseAllMemoryForTask(taskAttemptId: Long): Long = lock.synchronized { + val numBytesToFree = getMemoryUsageForTask(taskAttemptId) + releaseMemory(numBytesToFree, taskAttemptId) + numBytesToFree + } + +} diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index b0cf2696a397f..0210217e41bfe 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -19,14 +19,11 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import com.google.common.annotations.VisibleForTesting - -import org.apache.spark.util.Utils -import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging} -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.storage.BlockId +import org.apache.spark.storage.memory.MemoryStore +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.memory.MemoryAllocator @@ -36,65 +33,56 @@ import org.apache.spark.unsafe.memory.MemoryAllocator * In this context, execution memory refers to that used for computation in shuffles, joins, * sorts and aggregations, while storage memory refers to that used for caching and propagating * internal data across the cluster. There exists one MemoryManager per JVM. - * - * The MemoryManager abstract base class itself implements policies for sharing execution memory - * between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of - * some task ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory - * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever - * this set changes. This is all done by synchronizing access to mutable state and using wait() and - * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across - * tasks was performed by the ShuffleMemoryManager. */ -private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging { +private[spark] abstract class MemoryManager( + conf: SparkConf, + numCores: Int, + onHeapStorageMemory: Long, + onHeapExecutionMemory: Long) extends Logging { // -- Methods related to memory allocation policies and bookkeeping ------------------------------ - // The memory store used to evict cached blocks - private var _memoryStore: MemoryStore = _ - protected def memoryStore: MemoryStore = { - if (_memoryStore == null) { - throw new IllegalArgumentException("memory store not initialized yet") - } - _memoryStore - } + @GuardedBy("this") + protected val onHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.ON_HEAP) + @GuardedBy("this") + protected val offHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.OFF_HEAP) + @GuardedBy("this") + protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.ON_HEAP) + @GuardedBy("this") + protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.OFF_HEAP) - // Amount of execution/storage memory in use, accesses must be synchronized on `this` - @GuardedBy("this") protected var _executionMemoryUsed: Long = 0 - @GuardedBy("this") protected var _storageMemoryUsed: Long = 0 - // Map from taskAttemptId -> memory consumption in bytes - @GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]() + onHeapStorageMemoryPool.incrementPoolSize(onHeapStorageMemory) + onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) - /** - * Set the [[MemoryStore]] used by this manager to evict cached blocks. - * This must be set after construction due to initialization ordering constraints. - */ - final def setMemoryStore(store: MemoryStore): Unit = { - _memoryStore = store - } + protected[this] val maxOffHeapMemory = conf.getSizeAsBytes("spark.memory.offHeap.size", 0) + protected[this] val offHeapStorageMemory = + (maxOffHeapMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong + + offHeapExecutionMemoryPool.incrementPoolSize(maxOffHeapMemory - offHeapStorageMemory) + offHeapStorageMemoryPool.incrementPoolSize(offHeapStorageMemory) /** - * Total available memory for execution, in bytes. + * Total available memory for storage, in bytes. This amount can vary over time, depending on + * the MemoryManager implementation. + * In this model, this is equivalent to the amount of memory not occupied by execution. */ - def maxExecutionMemory: Long + def maxOnHeapStorageMemory: Long /** - * Total available memory for storage, in bytes. + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. */ - def maxStorageMemory: Long - - // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) + final def setMemoryStore(store: MemoryStore): Unit = synchronized { + onHeapStorageMemoryPool.setMemoryStore(store) + offHeapStorageMemoryPool.setMemoryStore(store) + } /** * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * * @return whether all N bytes were successfully granted. */ - def acquireStorageMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + def acquireStorageMemory(blockId: BlockId, numBytes: Long, memoryMode: MemoryMode): Boolean /** * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. @@ -102,197 +90,115 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte * This extra method allows subclasses to differentiate behavior between acquiring storage * memory and acquiring unroll memory. For instance, the memory management model in Spark * 1.5 and before places a limit on the amount of space that can be freed from unrolling. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. * * @return whether all N bytes were successfully granted. */ - def acquireUnrollMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - acquireStorageMemory(blockId, numBytes, evictedBlocks) - } - - /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return number of bytes successfully granted (<= N). - */ - @VisibleForTesting - private[memory] def doAcquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long + def acquireUnrollMemory(blockId: BlockId, numBytes: Long, memoryMode: MemoryMode): Boolean /** - * Try to acquire up to `numBytes` of execution memory for the current task and return the number - * of bytes obtained, or 0 if none can be allocated. + * Try to acquire up to `numBytes` of execution memory for the current task and return the + * number of bytes obtained, or 0 if none can be allocated. * * This call may block until there is enough free memory in some situations, to make sure each * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of * active tasks) before it is forced to spill. This can happen if the number of tasks increase * but an older task had a lot of memory already. - * - * Subclasses should override `doAcquireExecutionMemory` in order to customize the policies - * that control global sharing of memory between execution and storage. */ private[memory] - final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized { - assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - - // Add this task to the taskMemory map just so we can keep an accurate count of the number - // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire - if (!executionMemoryForTask.contains(taskAttemptId)) { - executionMemoryForTask(taskAttemptId) = 0L - // This will later cause waiting tasks to wake up and check numTasks again - notifyAll() - } - - // Once the cross-task memory allocation policy has decided to grant more memory to a task, - // this method is called in order to actually obtain that execution memory, potentially - // triggering eviction of storage memory: - def acquire(toGrant: Long): Long = synchronized { - val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) - } - executionMemoryForTask(taskAttemptId) += acquired - acquired - } - - // Keep looping until we're either sure that we don't want to grant this request (because this - // task would have more than 1 / numActiveTasks of the memory) or we have enough free - // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). - // TODO: simplify this to limit each task to its own slot - while (true) { - val numActiveTasks = executionMemoryForTask.keys.size - val curMem = executionMemoryForTask(taskAttemptId) - val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum - - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = - math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem)) - // Only give it as much memory as is free, which might be none if it reached 1 / numTasks - val toGrant = math.min(maxToGrant, freeMemory) - - if (curMem < maxExecutionMemory / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if ( - freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) { - return acquire(toGrant) - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free") - wait() - } - } else { - return acquire(toGrant) - } - } - 0L // Never reached - } - - @VisibleForTesting - private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized { - if (numBytes > _executionMemoryUsed) { - logWarning(s"Attempted to release $numBytes bytes of execution " + - s"memory when we only have ${_executionMemoryUsed} bytes") - _executionMemoryUsed = 0 - } else { - _executionMemoryUsed -= numBytes - } - } + def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long /** * Release numBytes of execution memory belonging to the given task. */ private[memory] - final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized { - val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L) - if (curMem < numBytes) { - if (Utils.isTesting) { - throw new SparkException( - s"Internal error: release called on $numBytes bytes but task only has $curMem") - } else { - logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem") - } - } - if (executionMemoryForTask.contains(taskAttemptId)) { - executionMemoryForTask(taskAttemptId) -= numBytes - if (executionMemoryForTask(taskAttemptId) <= 0) { - executionMemoryForTask.remove(taskAttemptId) - } - releaseExecutionMemory(numBytes) + def releaseExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Unit = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId) } - notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed } /** * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * * @return the number of bytes freed. */ private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized { - val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId) - releaseExecutionMemory(numBytesToFree, taskAttemptId) - numBytesToFree + onHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) + + offHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) } /** * Release N bytes of storage memory. */ - def releaseStorageMemory(numBytes: Long): Unit = synchronized { - if (numBytes > _storageMemoryUsed) { - logWarning(s"Attempted to release $numBytes bytes of storage " + - s"memory when we only have ${_storageMemoryUsed} bytes") - _storageMemoryUsed = 0 - } else { - _storageMemoryUsed -= numBytes + def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapStorageMemoryPool.releaseMemory(numBytes) + case MemoryMode.OFF_HEAP => offHeapStorageMemoryPool.releaseMemory(numBytes) } } /** * Release all storage memory acquired. */ - def releaseAllStorageMemory(): Unit = synchronized { - _storageMemoryUsed = 0 + final def releaseAllStorageMemory(): Unit = synchronized { + onHeapStorageMemoryPool.releaseAllMemory() + offHeapStorageMemoryPool.releaseAllMemory() } /** * Release N bytes of unroll memory. */ - def releaseUnrollMemory(numBytes: Long): Unit = synchronized { - releaseStorageMemory(numBytes) + final def releaseUnrollMemory(numBytes: Long, memoryMode: MemoryMode): Unit = synchronized { + releaseStorageMemory(numBytes, memoryMode) } /** * Execution memory currently in use, in bytes. */ final def executionMemoryUsed: Long = synchronized { - _executionMemoryUsed + onHeapExecutionMemoryPool.memoryUsed + offHeapExecutionMemoryPool.memoryUsed } /** * Storage memory currently in use, in bytes. */ final def storageMemoryUsed: Long = synchronized { - _storageMemoryUsed + onHeapStorageMemoryPool.memoryUsed + offHeapStorageMemoryPool.memoryUsed } /** * Returns the execution memory consumption, in bytes, for the given task. */ private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized { - executionMemoryForTask.getOrElse(taskAttemptId, 0L) + onHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) + + offHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) } // -- Fields related to Tungsten managed memory ------------------------------------------------- + /** + * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using + * sun.misc.Unsafe. + */ + final val tungstenMemoryMode: MemoryMode = { + if (conf.getBoolean("spark.memory.offHeap.enabled", false)) { + require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0, + "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true") + require(Platform.unaligned(), + "No support for unaligned Unsafe. Set spark.memory.offHeap.enabled to false.") + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } + } + /** * The default page size, in bytes. * @@ -306,21 +212,22 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case val safetyFactor = 16 - val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor) + val maxTungstenMemory: Long = tungstenMemoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.poolSize + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.poolSize + } + val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) conf.getSizeAsBytes("spark.buffer.pageSize", default) } - /** - * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using - * sun.misc.Unsafe. - */ - final val tungstenMemoryIsAllocatedInHeap: Boolean = - !conf.getBoolean("spark.unsafe.offHeap", false) - /** * Allocates memory for use by Unsafe/Tungsten code. */ - private[memory] final val tungstenMemoryAllocator: MemoryAllocator = - if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE + private[memory] final val tungstenMemoryAllocator: MemoryAllocator = { + tungstenMemoryMode match { + case MemoryMode.ON_HEAP => MemoryAllocator.HEAP + case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE + } + } } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala new file mode 100644 index 0000000000000..1b9edf9c43bda --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +/** + * Manages bookkeeping for an adjustable-sized region of memory. This class is internal to + * the [[MemoryManager]]. See subclasses for more details. + * + * @param lock a [[MemoryManager]] instance, used for synchronization. We purposely erase the type + * to `Object` to avoid programming errors, since this object should only be used for + * synchronization purposes. + */ +private[memory] abstract class MemoryPool(lock: Object) { + + @GuardedBy("lock") + private[this] var _poolSize: Long = 0 + + /** + * Returns the current size of the pool, in bytes. + */ + final def poolSize: Long = lock.synchronized { + _poolSize + } + + /** + * Returns the amount of free memory in the pool, in bytes. + */ + final def memoryFree: Long = lock.synchronized { + _poolSize - memoryUsed + } + + /** + * Expands the pool by `delta` bytes. + */ + final def incrementPoolSize(delta: Long): Unit = lock.synchronized { + require(delta >= 0) + _poolSize += delta + } + + /** + * Shrinks the pool by `delta` bytes. + */ + final def decrementPoolSize(delta: Long): Unit = lock.synchronized { + require(delta >= 0) + require(delta <= _poolSize) + require(_poolSize - delta >= memoryUsed) + _poolSize -= delta + } + + /** + * Returns the amount of used memory in this pool (in bytes). + */ + def memoryUsed: Long +} diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 9c2c2e90a2282..cbd0fa9ec2098 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -17,11 +17,8 @@ package org.apache.spark.memory -import scala.collection.mutable - import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus} - +import org.apache.spark.storage.BlockId /** * A [[MemoryManager]] that statically partitions the heap space into disjoint regions. @@ -32,10 +29,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus} */ private[spark] class StaticMemoryManager( conf: SparkConf, - override val maxExecutionMemory: Long, - override val maxStorageMemory: Long, + maxOnHeapExecutionMemory: Long, + override val maxOnHeapStorageMemory: Long, numCores: Int) - extends MemoryManager(conf, numCores) { + extends MemoryManager( + conf, + numCores, + maxOnHeapStorageMemory, + maxOnHeapExecutionMemory) { def this(conf: SparkConf, numCores: Int) { this( @@ -45,81 +46,59 @@ private[spark] class StaticMemoryManager( numCores) } - // Max number of bytes worth of blocks to evict when unrolling - private val maxMemoryToEvictForUnroll: Long = { - (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong - } + // The StaticMemoryManager does not support off-heap storage memory: + offHeapExecutionMemoryPool.incrementPoolSize(offHeapStorageMemoryPool.poolSize) + offHeapStorageMemoryPool.decrementPoolSize(offHeapStorageMemoryPool.poolSize) - /** - * Acquire N bytes of memory for execution. - * @return number of bytes successfully granted (<= N). - */ - override def doAcquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { - assert(numBytes >= 0) - assert(_executionMemoryUsed <= maxExecutionMemory) - val bytesToGrant = math.min(numBytes, maxExecutionMemory - _executionMemoryUsed) - _executionMemoryUsed += bytesToGrant - bytesToGrant + // Max number of bytes worth of blocks to evict when unrolling + private val maxUnrollMemory: Long = { + (maxOnHeapStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong } - /** - * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return whether all N bytes were successfully granted. - */ override def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - acquireStorageMemory(blockId, numBytes, numBytes, evictedBlocks) + memoryMode: MemoryMode): Boolean = synchronized { + require(memoryMode != MemoryMode.OFF_HEAP, + "StaticMemoryManager does not support off-heap storage memory") + if (numBytes > maxOnHeapStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxOnHeapStorageMemory bytes)") + false + } else { + onHeapStorageMemoryPool.acquireMemory(blockId, numBytes) + } } - /** - * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. - * - * This evicts at most M bytes worth of existing blocks, where M is a fraction of the storage - * space specified by `spark.storage.unrollFraction`. Blocks evicted in the process, if any, - * are added to `evictedBlocks`. - * - * @return whether all N bytes were successfully granted. - */ override def acquireUnrollMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - val currentUnrollMemory = memoryStore.currentUnrollMemory - val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) - val numBytesToFree = math.min(numBytes, maxNumBytesToFree) - acquireStorageMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + memoryMode: MemoryMode): Boolean = synchronized { + require(memoryMode != MemoryMode.OFF_HEAP, + "StaticMemoryManager does not support off-heap unroll memory") + val currentUnrollMemory = onHeapStorageMemoryPool.memoryStore.currentUnrollMemory + val freeMemory = onHeapStorageMemoryPool.memoryFree + // When unrolling, we will use all of the existing free memory, and, if necessary, + // some extra space freed from evicting cached blocks. We must place a cap on the + // amount of memory to be evicted by unrolling, however, otherwise unrolling one + // big block can blow away the entire cache. + val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory) + // Keep it within the range 0 <= X <= maxNumBytesToFree + val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory)) + onHeapStorageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree) } - /** - * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. - * - * @param blockId the ID of the block we are acquiring storage memory for - * @param numBytesToAcquire the size of this block - * @param numBytesToFree the size of space to be freed through evicting blocks - * @param evictedBlocks a holder for blocks evicted in the process - * @return whether all N bytes were successfully granted. - */ - private def acquireStorageMemory( - blockId: BlockId, - numBytesToAcquire: Long, - numBytesToFree: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - assert(numBytesToAcquire >= 0) - assert(numBytesToFree >= 0) - memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) - assert(_storageMemoryUsed <= maxStorageMemory) - val enoughMemory = _storageMemoryUsed + numBytesToAcquire <= maxStorageMemory - if (enoughMemory) { - _storageMemoryUsed += numBytesToAcquire + private[memory] + override def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) } - enoughMemory } - } @@ -135,7 +114,6 @@ private[spark] object StaticMemoryManager { (systemMaxMemory * memoryFraction * safetyFraction).toLong } - /** * Return the total amount of memory available for the execution region, in bytes. */ diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala new file mode 100644 index 0000000000000..0b552cabfc941 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import org.apache.spark.internal.Logging +import org.apache.spark.storage.BlockId +import org.apache.spark.storage.memory.MemoryStore + +/** + * Performs bookkeeping for managing an adjustable-size pool of memory that is used for storage + * (caching). + * + * @param lock a [[MemoryManager]] instance to synchronize on + * @param memoryMode the type of memory tracked by this pool (on- or off-heap) + */ +private[memory] class StorageMemoryPool( + lock: Object, + memoryMode: MemoryMode + ) extends MemoryPool(lock) with Logging { + + private[this] val poolName: String = memoryMode match { + case MemoryMode.ON_HEAP => "on-heap storage" + case MemoryMode.OFF_HEAP => "off-heap storage" + } + + @GuardedBy("lock") + private[this] var _memoryUsed: Long = 0L + + override def memoryUsed: Long = lock.synchronized { + _memoryUsed + } + + private var _memoryStore: MemoryStore = _ + def memoryStore: MemoryStore = { + if (_memoryStore == null) { + throw new IllegalStateException("memory store not initialized yet") + } + _memoryStore + } + + /** + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. + */ + final def setMemoryStore(store: MemoryStore): Unit = { + _memoryStore = store + } + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * + * @return whether all N bytes were successfully granted. + */ + def acquireMemory(blockId: BlockId, numBytes: Long): Boolean = lock.synchronized { + val numBytesToFree = math.max(0, numBytes - memoryFree) + acquireMemory(blockId, numBytes, numBytesToFree) + } + + /** + * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. + * + * @param blockId the ID of the block we are acquiring storage memory for + * @param numBytesToAcquire the size of this block + * @param numBytesToFree the amount of space to be freed through evicting blocks + * @return whether all N bytes were successfully granted. + */ + def acquireMemory( + blockId: BlockId, + numBytesToAcquire: Long, + numBytesToFree: Long): Boolean = lock.synchronized { + assert(numBytesToAcquire >= 0) + assert(numBytesToFree >= 0) + assert(memoryUsed <= poolSize) + if (numBytesToFree > 0) { + memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, memoryMode) + } + // NOTE: If the memory store evicts blocks, then those evictions will synchronously call + // back into this StorageMemoryPool in order to free memory. Therefore, these variables + // should have been updated. + val enoughMemory = numBytesToAcquire <= memoryFree + if (enoughMemory) { + _memoryUsed += numBytesToAcquire + } + enoughMemory + } + + def releaseMemory(size: Long): Unit = lock.synchronized { + if (size > _memoryUsed) { + logWarning(s"Attempted to release $size bytes of storage " + + s"memory when we only have ${_memoryUsed} bytes") + _memoryUsed = 0 + } else { + _memoryUsed -= size + } + } + + def releaseAllMemory(): Unit = lock.synchronized { + _memoryUsed = 0 + } + + /** + * Try to shrink the size of this storage memory pool by `spaceToFree` bytes. Return the number + * of bytes removed from the pool's capacity. + */ + def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { + // First, shrink the pool by reclaiming free memory: + val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree) + decrementPoolSize(spaceFreedByReleasingUnusedMemory) + val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory + if (remainingSpaceToFree > 0) { + // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: + val spaceFreedByEviction = + memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, memoryMode) + // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do + // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. + decrementPoolSize(spaceFreedByEviction) + spaceFreedByReleasingUnusedMemory + spaceFreedByEviction + } else { + spaceFreedByReleasingUnusedMemory + } + } +} diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index a3093030a0f93..fa9c021f70376 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -17,17 +17,14 @@ package org.apache.spark.memory -import scala.collection.mutable - import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockStatus, BlockId} - +import org.apache.spark.storage.BlockId /** * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that * either side can borrow memory from the other. * - * The region shared between execution and storage is a fraction of the total heap space + * The region shared between execution and storage is a fraction of (the total heap space - 300MB) * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary * within this space is further determined by `spark.memory.storageFraction` (default 0.5). * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default. @@ -41,105 +38,191 @@ import org.apache.spark.storage.{BlockStatus, BlockId} * The implication is that attempts to cache blocks may fail if execution has already eaten * up most of the storage space, in which case the new blocks will be evicted immediately * according to their respective storage levels. + * + * @param onHeapStorageRegionSize Size of the storage region, in bytes. + * This region is not statically reserved; execution can borrow from + * it if necessary. Cached blocks can be evicted only if actual + * storage memory usage exceeds this region. */ -private[spark] class UnifiedMemoryManager( +private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, - maxMemory: Long, + val maxHeapMemory: Long, + onHeapStorageRegionSize: Long, numCores: Int) - extends MemoryManager(conf, numCores) { + extends MemoryManager( + conf, + numCores, + onHeapStorageRegionSize, + maxHeapMemory - onHeapStorageRegionSize) { - def this(conf: SparkConf, numCores: Int) { - this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores) + private def assertInvariants(): Unit = { + assert(onHeapExecutionMemoryPool.poolSize + onHeapStorageMemoryPool.poolSize == maxHeapMemory) + assert( + offHeapExecutionMemoryPool.poolSize + offHeapStorageMemoryPool.poolSize == maxOffHeapMemory) } - /** - * Size of the storage region, in bytes. - * - * This region is not statically reserved; execution can borrow from it if necessary. - * Cached blocks can be evicted only if actual storage memory usage exceeds this region. - */ - private val storageRegionSize: Long = { - (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong - } + assertInvariants() - /** - * Total amount of memory, in bytes, not currently occupied by either execution or storage. - */ - private def totalFreeMemory: Long = synchronized { - assert(_executionMemoryUsed <= maxMemory) - assert(_storageMemoryUsed <= maxMemory) - assert(_executionMemoryUsed + _storageMemoryUsed <= maxMemory) - maxMemory - _executionMemoryUsed - _storageMemoryUsed - } - - /** - * Total available memory for execution, in bytes. - * In this model, this is equivalent to the amount of memory not occupied by storage. - */ - override def maxExecutionMemory: Long = synchronized { - maxMemory - _storageMemoryUsed - } - - /** - * Total available memory for storage, in bytes. - * In this model, this is equivalent to the amount of memory not occupied by execution. - */ - override def maxStorageMemory: Long = synchronized { - maxMemory - _executionMemoryUsed + override def maxOnHeapStorageMemory: Long = synchronized { + maxHeapMemory - onHeapExecutionMemoryPool.memoryUsed } /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Try to acquire up to `numBytes` of execution memory for the current task and return the + * number of bytes obtained, or 0 if none can be allocated. * - * This method evicts blocks only up to the amount of memory borrowed by storage. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return number of bytes successfully granted (<= N). + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. */ - private[memory] override def doAcquireExecutionMemory( + override private[memory] def acquireExecutionMemory( numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + assertInvariants() assert(numBytes >= 0) - val memoryBorrowedByStorage = math.max(0, _storageMemoryUsed - storageRegionSize) - // If there is not enough free memory AND storage has borrowed some execution memory, - // then evict as much memory borrowed by storage as needed to grant this request - val shouldEvictStorage = totalFreeMemory < numBytes && memoryBorrowedByStorage > 0 - if (shouldEvictStorage) { - val spaceToEnsure = math.min(numBytes, memoryBorrowedByStorage) - memoryStore.ensureFreeSpace(spaceToEnsure, evictedBlocks) + val (executionPool, storagePool, storageRegionSize, maxMemory) = memoryMode match { + case MemoryMode.ON_HEAP => ( + onHeapExecutionMemoryPool, + onHeapStorageMemoryPool, + onHeapStorageRegionSize, + maxHeapMemory) + case MemoryMode.OFF_HEAP => ( + offHeapExecutionMemoryPool, + offHeapStorageMemoryPool, + offHeapStorageMemory, + maxOffHeapMemory) } - val bytesToGrant = math.min(numBytes, totalFreeMemory) - _executionMemoryUsed += bytesToGrant - bytesToGrant + + /** + * Grow the execution pool by evicting cached blocks, thereby shrinking the storage pool. + * + * When acquiring memory for a task, the execution pool may need to make multiple + * attempts. Each attempt must be able to evict storage in case another task jumps in + * and caches a large block between the attempts. This is called once per attempt. + */ + def maybeGrowExecutionPool(extraMemoryNeeded: Long): Unit = { + if (extraMemoryNeeded > 0) { + // There is not enough free memory in the execution pool, so try to reclaim memory from + // storage. We can reclaim any free memory from the storage pool. If the storage pool + // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim + // the memory that storage has borrowed from execution. + val memoryReclaimableFromStorage = math.max( + storagePool.memoryFree, + storagePool.poolSize - storageRegionSize) + if (memoryReclaimableFromStorage > 0) { + // Only reclaim as much space as is necessary and available: + val spaceReclaimed = storagePool.shrinkPoolToFreeSpace( + math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) + executionPool.incrementPoolSize(spaceReclaimed) + } + } + } + + /** + * The size the execution pool would have after evicting storage memory. + * + * The execution memory pool divides this quantity among the active tasks evenly to cap + * the execution memory allocation for each task. It is important to keep this greater + * than the execution pool size, which doesn't take into account potential memory that + * could be freed by evicting storage. Otherwise we may hit SPARK-12155. + * + * Additionally, this quantity should be kept below `maxMemory` to arbitrate fairness + * in execution memory allocation across tasks, Otherwise, a task may occupy more than + * its fair share of execution memory, mistakenly thinking that other tasks can acquire + * the portion of storage memory that cannot be evicted. + */ + def computeMaxExecutionPoolSize(): Long = { + maxMemory - math.min(storagePool.memoryUsed, storageRegionSize) + } + + executionPool.acquireMemory( + numBytes, taskAttemptId, maybeGrowExecutionPool, computeMaxExecutionPoolSize) } - /** - * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return whether all N bytes were successfully granted. - */ override def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + memoryMode: MemoryMode): Boolean = synchronized { + assertInvariants() assert(numBytes >= 0) - memoryStore.ensureFreeSpace(blockId, numBytes, evictedBlocks) - val enoughMemory = totalFreeMemory >= numBytes - if (enoughMemory) { - _storageMemoryUsed += numBytes + val (executionPool, storagePool, maxMemory) = memoryMode match { + case MemoryMode.ON_HEAP => ( + onHeapExecutionMemoryPool, + onHeapStorageMemoryPool, + maxOnHeapStorageMemory) + case MemoryMode.OFF_HEAP => ( + offHeapExecutionMemoryPool, + offHeapStorageMemoryPool, + maxOffHeapMemory) + } + if (numBytes > maxMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxMemory bytes)") + return false + } + if (numBytes > storagePool.memoryFree) { + // There is not enough free memory in the storage pool, so try to borrow free memory from + // the execution pool. + val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, numBytes) + executionPool.decrementPoolSize(memoryBorrowedFromExecution) + storagePool.incrementPoolSize(memoryBorrowedFromExecution) } - enoughMemory + storagePool.acquireMemory(blockId, numBytes) } + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + memoryMode: MemoryMode): Boolean = synchronized { + acquireStorageMemory(blockId, numBytes, memoryMode) + } } -private object UnifiedMemoryManager { +object UnifiedMemoryManager { + + // Set aside a fixed amount of memory for non-storage, non-execution purposes. + // This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve + // sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then + // the memory used for execution and storage will be (1024 - 300) * 0.75 = 543MB by default. + private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024 + + def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { + val maxMemory = getMaxMemory(conf) + new UnifiedMemoryManager( + conf, + maxHeapMemory = maxMemory, + onHeapStorageRegionSize = + (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong, + numCores = numCores) + } /** * Return the total amount of memory shared between execution and storage, in bytes. */ private def getMaxMemory(conf: SparkConf): Long = { - val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val reservedMemory = conf.getLong("spark.testing.reservedMemory", + if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES) + val minSystemMemory = reservedMemory * 1.5 + if (systemMemory < minSystemMemory) { + throw new IllegalArgumentException(s"System memory $systemMemory must " + + s"be at least $minSystemMemory. Please increase heap size using the --driver-memory " + + s"option or spark.driver.memory in Spark configuration.") + } + // SPARK-12759 Check executor memory to fail fast if memory is insufficient + if (conf.contains("spark.executor.memory")) { + val executorMemory = conf.getSizeAsBytes("spark.executor.memory") + if (executorMemory < minSystemMemory) { + throw new IllegalArgumentException(s"Executor memory $executorMemory must be at least " + + s"$minSystemMemory. Please increase executor memory using the " + + s"--executor-memory option or spark.executor.memory in Spark configuration.") + } + } + val usableMemory = systemMemory - reservedMemory val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) - (systemMaxMemory * memoryFraction).toLong + (usableMemory * memoryFraction).toLong } } diff --git a/core/src/main/scala/org/apache/spark/memory/package.scala b/core/src/main/scala/org/apache/spark/memory/package.scala new file mode 100644 index 0000000000000..3d00cd9cb6377 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/package.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * This package implements Spark's memory management system. This system consists of two main + * components, a JVM-wide memory manager and a per-task manager: + * + * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. + * This component implements the policies for dividing the available memory across tasks and for + * allocating memory between storage (memory used caching and data transfer) and execution + * (memory used by computations, such as shuffles, joins, sorts, and aggregations). + * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual + * tasks. Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide + * MemoryManager. + * + * Internally, each of these components have additional abstractions for memory bookkeeping: + * + * - [[org.apache.spark.memory.MemoryConsumer]]s are clients of the TaskMemoryManager and + * correspond to individual operators and data structures within a task. The TaskMemoryManager + * receives memory allocation requests from MemoryConsumers and issues callbacks to consumers + * in order to trigger spilling when running low on memory. + * - [[org.apache.spark.memory.MemoryPool]]s are a bookkeeping abstraction used by the + * MemoryManager to track the division of memory between storage and execution. + * + * Diagrammatically: + * + * {{{ + * +-------------+ + * | MemConsumer |----+ +------------------------+ + * +-------------+ | +-------------------+ | MemoryManager | + * +--->| TaskMemoryManager |----+ | | + * +-------------+ | +-------------------+ | | +------------------+ | + * | MemConsumer |----+ | | | StorageMemPool | | + * +-------------+ +-------------------+ | | +------------------+ | + * | TaskMemoryManager |----+ | | + * +-------------------+ | | +------------------+ | + * +---->| |OnHeapExecMemPool | | + * * | | +------------------+ | + * * | | | + * +-------------+ * | | +------------------+ | + * | MemConsumer |----+ | | |OffHeapExecMemPool| | + * +-------------+ | +-------------------+ | | +------------------+ | + * +--->| TaskMemoryManager |----+ | | + * +-------------------+ +------------------------+ + * }}} + * + * + * There are two implementations of [[org.apache.spark.memory.MemoryManager]] which vary in how + * they handle the sizing of their memory pools: + * + * - [[org.apache.spark.memory.UnifiedMemoryManager]], the default in Spark 1.6+, enforces soft + * boundaries between storage and execution memory, allowing requests for memory in one region + * to be fulfilled by borrowing memory from the other. + * - [[org.apache.spark.memory.StaticMemoryManager]] enforces hard boundaries between storage + * and execution memory by statically partitioning Spark's memory and preventing storage and + * execution from borrowing memory from each other. This mode is retained only for legacy + * compatibility purposes. + */ +package object memory diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index dd2d325d87034..979782ea40fd6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -24,8 +24,9 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.matching.Regex +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf} private[spark] class MetricsConfig(conf: SparkConf) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index fdf76d312db3b..0fed991049dd3 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,16 +20,16 @@ package org.apache.spark.metrics import java.util.Properties import java.util.concurrent.TimeUnit -import org.apache.spark.util.Utils - import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.internal.Logging import org.apache.spark.metrics.sink.{MetricsServlet, Sink} import org.apache.spark.metrics.source.Source +import org.apache.spark.util.Utils /** * Spark Metrics System, created by specific "instance", combined by source, @@ -196,10 +196,9 @@ private[spark] class MetricsSystem private ( sinks += sink.asInstanceOf[Sink] } } catch { - case e: Exception => { + case e: Exception => logError("Sink class " + classPath + " cannot be instantiated") throw e - } } } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 2d25ebd66159f..22454e50b14b4 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -22,7 +22,7 @@ import java.util.Properties import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter} +import com.codahale.metrics.graphite.{Graphite, GraphiteReporter, GraphiteUDP} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index 2588fe2c9edb8..1992b42ac7f6b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -20,6 +20,7 @@ package org.apache.spark.metrics.sink import java.util.Properties import com.codahale.metrics.{JmxReporter, MetricRegistry} + import org.apache.spark.SecurityManager private[spark] class JmxSink(val property: Properties, val registry: MetricRegistry, diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 4193e1d21d3c1..68b58b8490641 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -19,7 +19,6 @@ package org.apache.spark.metrics.sink import java.util.Properties import java.util.concurrent.TimeUnit - import javax.servlet.http.HttpServletRequest import com.codahale.metrics.MetricRegistry @@ -27,7 +26,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.{SparkConf, SecurityManager} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 11dfcfe2f04e1..773e074336cb0 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -20,7 +20,7 @@ package org.apache.spark.metrics.sink import java.util.Properties import java.util.concurrent.TimeUnit -import com.codahale.metrics.{Slf4jReporter, MetricRegistry} +import com.codahale.metrics.{MetricRegistry, Slf4jReporter} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 1745d52c81923..8f83668d79029 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,6 +17,8 @@ package org.apache.spark.network +import scala.reflect.ClassTag + import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.{BlockId, StorageLevel} @@ -31,6 +33,18 @@ trait BlockDataManager { /** * Put the block locally, using the given storage level. + * + * Returns true if the block was stored and false if the put operation failed or the block + * already existed. + */ + def putBlockData( + blockId: BlockId, + data: ManagedBuffer, + level: StorageLevel, + classTag: ClassTag[_]): Boolean + + /** + * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. */ - def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit + def releaseLock(blockId: BlockId): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index dcbda5a8515dd..09ce012e4e692 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -20,13 +20,14 @@ package org.apache.spark.network import java.io.Closeable import java.nio.ByteBuffer -import scala.concurrent.{Promise, Await, Future} +import scala.concurrent.{Await, Future, Promise} import scala.concurrent.duration.Duration +import scala.reflect.ClassTag -import org.apache.spark.Logging -import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener} -import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { @@ -35,7 +36,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch * local blocks or put local blocks. */ - def init(blockDataManager: BlockDataManager) + def init(blockDataManager: BlockDataManager): Unit /** * Tear down the transfer service. @@ -76,7 +77,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockId: BlockId, blockData: ManagedBuffer, - level: StorageLevel): Future[Unit] + level: StorageLevel, + classTag: ClassTag[_]): Future[Unit] /** * A special case of [[fetchBlocks]], as it fetches only one block and is blocking. @@ -114,7 +116,9 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockId: BlockId, blockData: ManagedBuffer, - level: StorageLevel): Unit = { - Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf) + level: StorageLevel, + classTag: ClassTag[_]): Unit = { + val future = uploadBlock(hostname, port, execId, blockId, blockData, level, classTag) + Await.result(future, Duration.Inf) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 76968249fb625..2ed8a00df7023 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -20,8 +20,10 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer import scala.collection.JavaConverters._ +import scala.language.existentials +import scala.reflect.ClassTag -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} @@ -47,9 +49,9 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, - messageBytes: Array[Byte], + rpcMessage: ByteBuffer, responseContext: RpcResponseCallback): Unit = { - val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) + val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") message match { @@ -58,15 +60,20 @@ class NettyBlockRpcServer( openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) case uploadBlock: UploadBlock => - // StorageLevel is serialized as bytes using our JavaSerializer. - val level: StorageLevel = - serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) + // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. + val (level: StorageLevel, classTag: ClassTag[_]) = { + serializer + .newInstance() + .deserialize(ByteBuffer.wrap(uploadBlock.metadata)) + .asInstanceOf[(StorageLevel, ClassTag[_])] + } val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) - blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) - responseContext.onSuccess(new Array[Byte](0)) + val blockId = BlockId(uploadBlock.blockId) + blockManager.putBlockData(blockId, data, level, classTag) + responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 70a42f9045e6b..33a3219607749 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,17 +17,21 @@ package org.apache.spark.network.netty +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} +import scala.reflect.ClassTag import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -35,13 +39,17 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int) +private[spark] class NettyBlockTransferService( + conf: SparkConf, + securityManager: SecurityManager, + override val hostName: String, + numCores: Int) extends BlockTransferService { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ @@ -61,13 +69,13 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId - logInfo("Server created on " + server.getPort) + logInfo(s"Server created on ${hostName}:${server.getPort}") } /** Creates and binds the TransportServer, possibly trying multiple ports. */ private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { def startService(port: Int): (TransportServer, Int) = { - val server = transportContext.createServer(port, bootstraps.asJava) + val server = transportContext.createServer(hostName, port, bootstraps.asJava) (server, server.getPort) } @@ -105,8 +113,6 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage } } - override def hostName: String = Utils.localHostName() - override def port: Int = server.getPort override def uploadBlock( @@ -115,27 +121,21 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage execId: String, blockId: BlockId, blockData: ManagedBuffer, - level: StorageLevel): Future[Unit] = { + level: StorageLevel, + classTag: ClassTag[_]): Future[Unit] = { val result = Promise[Unit]() val client = clientFactory.createClient(hostname, port) - // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded - // using our binary protocol. - val levelBytes = serializer.newInstance().serialize(level).array() + // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. + // Everything else is encoded using our binary protocol. + val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag))) // Convert or copy nio buffer into array in order to serialize it. - val nioBuffer = blockData.nioByteBuffer() - val array = if (nioBuffer.hasArray) { - nioBuffer.array() - } else { - val data = new Array[Byte](nioBuffer.remaining()) - nioBuffer.get(data) - data - } + val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer, new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") result.success((): Unit) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index cef203006d685..86874e2067dd4 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -18,7 +18,7 @@ package org.apache.spark.network.netty import org.apache.spark.SparkConf -import org.apache.spark.network.util.{TransportConf, ConfigProvider} +import org.apache.spark.network.util.{ConfigProvider, TransportConf} /** * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor, @@ -40,23 +40,23 @@ object SparkTransportConf { /** * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param _conf the [[SparkConf]] + * @param module the module name * @param numUsableCores if nonzero, this will restrict the server and client threads to only * use the given number of cores, rather than all of the machine's cores. * This restriction will only occur if these properties are not already set. */ - def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = { val conf = _conf.clone // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily // assuming we have all the machine's cores). // NB: Only set if serverThreads/clientThreads not already set. val numThreads = defaultNumThreads(numUsableCores) - conf.set("spark.shuffle.io.serverThreads", - conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) - conf.set("spark.shuffle.io.clientThreads", - conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString) + conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString) - new TransportConf(new ConfigProvider { + new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) }) } diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 7515aad09db73..cc5e7ef3ae008 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.6.0-SNAPSHOT" + val SPARK_VERSION = "2.0.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala index d25452daf7606..b089bbd7e972e 100644 --- a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala +++ b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala @@ -38,7 +38,7 @@ private[spark] class ApproximateActionListener[T, U, R]( extends JobListener { val startTime = System.currentTimeMillis() - val totalTasks = rdd.partitions.size + val totalTasks = rdd.partitions.length var finishedTasks = 0 var failure: Option[Exception] = None // Set if the job has failed (permanently) var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index 48b9434153172..ab6aba6fc7d6a 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -21,5 +21,22 @@ package org.apache.spark.partial * A Double value with error bars and associated confidence. */ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { + override def toString(): String = "[%.3f, %.3f]".format(low, high) + + override def hashCode: Int = + this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode + + /** + * Note that consistent with Double, any NaN value will make equality false + */ + override def equals(that: Any): Boolean = + that match { + case that: BoundedDouble => + this.mean == that.mean && + this.confidence == that.confidence && + this.low == that.low && + this.high == that.high + case _ => false + } } diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala index 828bf96c2c0bd..55acb9ca64d3f 100644 --- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala +++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} +import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} /** * A utility class for caching Student's T distribution values for a given confidence level diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index 1753c2561b678..5fe33583166c3 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} +import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} import org.apache.spark.util.StatCounter @@ -29,8 +29,9 @@ import org.apache.spark.util.StatCounter private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { + // modified in merge var outputsMerged = 0 - var counter = new StatCounter + val counter = new StatCounter override def merge(outputId: Int, taskResult: StatCounter) { outputsMerged += 1 @@ -40,30 +41,39 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) override def currentResult(): BoundedDouble = { if (outputsMerged == totalOutputs) { new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum) - } else if (outputsMerged == 0) { + } else if (outputsMerged == 0 || counter.count == 0) { new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) } else { val p = outputsMerged.toDouble / totalOutputs val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = { - if (counter.count > 100) { + + val meanVar = counter.sampleVariance / counter.count + + // branch at this point because counter.count == 1 implies counter.sampleVariance == Nan + // and we don't want to ever return a bound of NaN + if (meanVar.isNaN || counter.count == 1) { + new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity) + } else { + val countVar = (counter.count + 1) * (1 - p) / (p * p) + val sumVar = (meanEstimate * meanEstimate * countVar) + + (countEstimate * countEstimate * meanVar) + + (meanVar * countVar) + val sumStdev = math.sqrt(sumVar) + val confFactor = if (counter.count > 100) { new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) } else { + // note that if this goes to 0, TDistribution will throw an exception. + // Hence special casing 1 above. val degreesOfFreedom = (counter.count - 1).toInt new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) } + + val low = sumEstimate - confFactor * sumStdev + val high = sumEstimate + confFactor * sumStdev + new BoundedDouble(sumEstimate, confidence, low, high) } - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - new BoundedDouble(sumEstimate, confidence, low, high) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ca1eb1f4e4a9a..c9ed12f4e1bd4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,13 +19,13 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.util.ThreadUtils - import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext +import scala.concurrent.{ExecutionContext, Future} import scala.reflect.ClassTag -import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter} +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils /** * A set of asynchronous RDD actions available through an implicit conversion. @@ -65,18 +65,26 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving the first num elements of the RDD. */ def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { - val f = new ComplexFutureAction[Seq[T]] - - f.run { - // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which - // is a cached thread pool. - val results = new ArrayBuffer[T](num) - val totalParts = self.partitions.length - var partsScanned = 0 - while (results.size < num && partsScanned < totalParts) { + val callSite = self.context.getCallSite + val localProperties = self.context.getLocalProperties + // Cached thread pool to handle aggregation of subtasks. + implicit val executionContext = AsyncRDDActions.futureExecutionContext + val results = new ArrayBuffer[T] + val totalParts = self.partitions.length + + /* + Recursively triggers jobs to scan partitions until either the requested + number of elements are retrieved, or the partitions to scan are exhausted. + This implementation is non-blocking, asynchronously handling the + results of each job and triggering the next job using callbacks on futures. + */ + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] = + if (results.size >= num || partsScanned >= totalParts) { + Future.successful(results.toSeq) + } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate it @@ -92,22 +100,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } val left = num - results.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val buf = new Array[Array[T]](p.size) - f.runJob(self, + self.context.setCallSite(callSite) + self.context.setLocalProperties(localProperties) + val job = jobSubmitter.submitJob(self, (it: Iterator[T]) => it.take(left).toArray, p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - - buf.foreach(results ++= _.take(num - results.size)) - partsScanned += numPartsToTry + job.flatMap { _ => + buf.foreach(results ++= _.take(num - results.size)) + continue(partsScanned + p.size) + } } - results.toSeq - }(AsyncRDDActions.futureExecutionContext) - f + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index aedced7408cde..be0cb175f5340 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -20,8 +20,10 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{ Configurable, Configuration } import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.JobContextImpl + +import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.StreamFileInputFormat -import org.apache.spark.{ Partition, SparkContext } private[spark] class BinaryFileRDD[T]( sc: SparkContext, @@ -40,7 +42,7 @@ private[spark] class BinaryFileRDD[T]( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fc1710fbad0a3..63d1d1767a8cb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.storage.{BlockId, BlockManager} -import scala.Some private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition { val index = idx @@ -36,9 +35,9 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.length).map(i => { + (0 until blockIds.length).map { i => new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] - }).toArray + }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 18e8cddbc40db..57108dcedcf0c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -50,7 +50,7 @@ class CartesianRDD[T: ClassTag, U: ClassTag]( sc: SparkContext, var rdd1 : RDD[T], var rdd2 : RDD[U]) - extends RDD[Pair[T, U]](sc, Nil) + extends RDD[(T, U)](sc, Nil) with Serializable { val numPartitionsInRdd2 = rdd2.partitions.length diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 935c3babd8ea1..7bc1eb043610a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -17,23 +17,24 @@ package org.apache.spark.rdd -import scala.language.existentials - import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils -import org.apache.spark.serializer.Serializer -/** The references to rdd and splitIndex are transient because redundant information is stored - * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from - * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the - * task closure. */ +/** + * The references to rdd and splitIndex are transient because redundant information is stored + * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from + * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the + * task closure. + */ private[spark] case class NarrowCoGroupSplitDep( @transient rdd: RDD[_], @transient splitIndex: Int, @@ -70,7 +71,7 @@ private[spark] class CoGroupPartition( * * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of * instantiating this directly. - + * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output */ @@ -87,11 +88,11 @@ class CoGroupedRDD[K: ClassTag]( private type CoGroupValue = (Any, Int) // Int is dependency number private type CoGroupCombiner = Array[CoGroup] - private var serializer: Option[Serializer] = None + private var serializer: Serializer = SparkEnv.get.serializer /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): CoGroupedRDD[K] = { - this.serializer = Option(serializer) + this.serializer = serializer this } @@ -154,8 +155,7 @@ class CoGroupedRDD[K: ClassTag]( } context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 7fbaadcea3a3b..368916a39e649 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -17,8 +17,8 @@ package org.apache.spark.rdd -import org.apache.spark.annotation.Experimental -import org.apache.spark.{TaskContext, Logging} +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.MeanEvaluator import org.apache.spark.partial.PartialResult @@ -103,7 +103,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. */ - def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = self.withScope { + def histogram(bucketCount: Int): (Array[Double], Array[Long]) = self.withScope { // Scala's built-in range has issues. See #SI-8782 def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { val span = max - min @@ -112,7 +112,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { // Compute the minimum and the maximum val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(Double.NegativeInfinity, - Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => + Double.PositiveInfinity)((e: Double, x: (Double, Double)) => (x._1.max(e), x._2.min(e)))) }.reduce { (maxmin1, maxmin2) => (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) @@ -142,7 +142,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched - * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. * buckets array must be at least two elements @@ -166,8 +166,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val counters = new Array[Long](buckets.length - 1) while (iter.hasNext) { bucketFunction(iter.next()) match { - case Some(x: Int) => {counters(x) += 1} - case _ => {} + case Some(x: Int) => counters(x) += 1 + case _ => // No-Op } } Iterator(counters) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index d841f05ec52cf..35d190b464ff4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,25 +17,26 @@ package org.apache.spark.rdd +import java.io.EOFException import java.text.SimpleDateFormat import java.util.Date -import java.io.EOFException import scala.collection.immutable.Map -import scala.reflect.ClassTag import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred.FileSplit import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.InputSplit import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter -import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID import org.apache.hadoop.mapred.lib.CombineFileSplit +import org.apache.hadoop.mapreduce.TaskType import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -43,10 +44,11 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod +import org.apache.spark.internal.Logging import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} -import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} +import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownHookManager, Utils} /** * A Spark split class that wraps around a Hadoop InputSplit. @@ -88,8 +90,8 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit) * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed - * variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job. - * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. + * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. + * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD * creates. * @param inputFormatClass Storage format of the data to be read. @@ -123,7 +125,7 @@ class HadoopRDD[K, V]( sc, sc.broadcast(new SerializableConfiguration(conf)) .asInstanceOf[Broadcast[SerializableConfiguration]], - None /* initLocalJobConfFuncOpt */, + initLocalJobConfFuncOpt = None, inputFormatClass, keyClass, valueClass, @@ -184,8 +186,9 @@ class HadoopRDD[K, V]( protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] - if (newInputFormat.isInstanceOf[Configurable]) { - newInputFormat.asInstanceOf[Configurable].setConf(conf) + newInputFormat match { + case c: Configurable => c.setConf(conf) + case _ => } newInputFormat } @@ -195,9 +198,6 @@ class HadoopRDD[K, V]( // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) val inputFormat = getInputFormat(jobConf) - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(jobConf) - } val inputSplits = inputFormat.getSplits(jobConf, minPartitions) val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { @@ -213,18 +213,32 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead + + // Sets the thread local variable for the file's name + split.inputSplit.value match { + case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) + case _ => InputFileNameHolder.unsetInputFileName() + } // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.inputSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None + val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). + def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) @@ -247,11 +261,15 @@ class HadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } (key, value) } override def close() { if (reader != null) { + InputFileNameHolder.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic @@ -266,8 +284,8 @@ class HadoopRDD[K, V]( } finally { reader = null } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() + if (getBytesReadCallback.isDefined) { + updateBytesRead() } else if (split.inputSplit.value.isInstanceOf[FileSplit] || split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, @@ -352,7 +370,7 @@ private[spark] object HadoopRDD extends Logging { def addLocalConfiguration(jobTrackerId: String, jobId: Int, splitId: Int, attemptId: Int, conf: JobConf) { val jobID = new JobID(jobTrackerId, jobId) - val taId = new TaskAttemptID(new TaskID(jobID, true, splitId), attemptId) + val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId) conf.set("mapred.tip.id", taId.getTaskID.toString) conf.set("mapred.task.id", taId.toString) @@ -404,7 +422,7 @@ private[spark] object HadoopRDD extends Logging { private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = { val out = ListBuffer[String]() - infos.foreach { loc => { + infos.foreach { loc => val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get. getLocation.invoke(loc).asInstanceOf[String] if (locationStr != "localhost") { @@ -416,7 +434,7 @@ private[spark] object HadoopRDD extends Logging { out += new HostTaskLocation(locationStr).toString } } - }} + } out.seq } } diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala new file mode 100644 index 0000000000000..108e9d2558190 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This holds file names of the current Spark task. This is used in HadoopRDD, + * FileScanRDD and InputFileName function in Spark SQL. + */ +private[spark] object InputFileNameHolder { + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 0c28f045e46e9..526138093d3ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -17,15 +17,16 @@ package org.apache.spark.rdd -import java.sql.{PreparedStatement, Connection, ResultSet} +import java.sql.{Connection, ResultSet} import scala.reflect.ClassTag +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.internal.Logging import org.apache.spark.util.NextIterator -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { override def index: Int = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala index bfe19195fcd37..503aa0dffc9f3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.{Partition, SparkContext, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext} import org.apache.spark.storage.RDDBlockId /** @@ -41,7 +41,7 @@ private[spark] class LocalCheckpointRDD[T: ClassTag]( extends CheckpointRDD[T](sc) { def this(rdd: RDD[T]) { - this(rdd.context, rdd.id, rdd.partitions.size) + this(rdd.context, rdd.id, rdd.partitions.length) } protected override def getPartitions: Array[Partition] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala index c115e0ff74d3c..56f53714cbe3a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -19,7 +19,8 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.Utils @@ -72,12 +73,6 @@ private[spark] object LocalRDDCheckpointData { * This method is idempotent. */ def transformStorageLevel(level: StorageLevel): StorageLevel = { - // If this RDD is to be cached off-heap, fail fast since we cannot provide any - // correctness guarantees about subsequent computations after the first one - if (level.useOffHeap) { - throw new SparkException("Local checkpointing is not compatible with off-heap caching.") - } - StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 4312d3a417759..e4587c96eae1c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -25,7 +25,7 @@ import org.apache.spark.{Partition, TaskContext} * An RDD that applies the provided function to every partition of the parent RDD. */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( - prev: RDD[T], + var prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false) extends RDD[U](prev) { @@ -36,4 +36,9 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context)) + + override def clearDependencies() { + super.clearDependencies() + prev = null + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 9c4b70844bdbe..3ccd616cbfd57 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -24,18 +24,19 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.input.WholeTextFileInputFormat import org.apache.spark._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.internal.Logging import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} private[spark] class NewHadoopPartition( rddId: Int, @@ -59,7 +60,6 @@ private[spark] class NewHadoopPartition( * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. - * @param conf The Hadoop configuration. */ @DeveloperApi class NewHadoopRDD[K, V]( @@ -68,9 +68,7 @@ class NewHadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], @transient private val _conf: Configuration) - extends RDD[(K, V)](sc, Nil) - with SparkHadoopMapReduceUtil - with Logging { + extends RDD[(K, V)](sc, Nil) with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) @@ -97,7 +95,13 @@ class NewHadoopRDD[K, V]( // issues, this cloning is disabled by default. NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { logDebug("Cloning Hadoop Configuration") - new Configuration(conf) + // The Configuration passed in is actually a JobConf and possibly contains credentials. + // To keep those credentials properly we have to create a new JobConf not a Configuration. + if (conf.isInstanceOf[JobConf]) { + new JobConf(conf) + } else { + new Configuration(conf) + } } } else { conf @@ -111,7 +115,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(_conf, jobId) + val jobContext = new JobContextImpl(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -126,28 +130,35 @@ class NewHadoopRDD[K, V]( logInfo("Input split: " + split.serializableHadoopSplit) val conf = getConf - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val existingBytesRead = inputMetrics.bytesRead // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None + val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). + def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } + val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -180,6 +191,9 @@ class NewHadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } (reader.getCurrentKey, reader.getCurrentValue) } @@ -199,8 +213,8 @@ class NewHadoopRDD[K, V]( } finally { reader = null } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() + if (getBytesReadCallback.isDefined) { + updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, @@ -282,32 +296,3 @@ private[spark] object NewHadoopRDD { } } } - -private[spark] class WholeTextFileRDD( - sc : SparkContext, - inputFormatClass: Class[_ <: WholeTextFileInputFormat], - keyClass: Class[String], - valueClass: Class[String], - conf: Configuration, - minPartitions: Int) - extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) { - - override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance - val conf = getConf - inputFormat match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val jobContext = newJobContext(conf, jobId) - inputFormat.setMinPartitions(jobContext, minPartitions) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } -} - diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index d71bb63000904..a5992022d0832 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -19,8 +19,9 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.{Logging, Partitioner, RangePartitioner} +import org.apache.spark.{Partitioner, RangePartitioner} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs where the key is sortable through @@ -45,8 +46,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag, P <: Product2[K, V] : ClassTag] @DeveloperApi() ( self: RDD[P]) - extends Logging with Serializable -{ + extends Logging with Serializable { private val ordering = implicitly[Ordering[K]] /** @@ -76,7 +76,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, } /** - * Returns an RDD containing only the elements in the the inclusive range `lower` to `upper`. + * Returns an RDD containing only the elements in the inclusive range `lower` to `upper`. * If the RDD has been partitioned using a `RangePartitioner`, then this operation can be * performed efficiently by only scanning the partitions that might contain matching elements. * Otherwise, a standard `filter` is applied to all partitions. @@ -86,12 +86,11 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) val rddToFilter: RDD[P] = self.partitioner match { - case Some(rp: RangePartitioner[K, V]) => { + case Some(rp: RangePartitioner[K, V]) => val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match { case (l, u) => Math.min(l, u) to Math.max(l, u) } PartitionPruningRDD.create(self, partitionIndicies.contains) - } case _ => self } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c6181902ace6d..085829af6eee7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import java.text.SimpleDateFormat import java.util.{Date, HashMap => JHashMap} -import scala.collection.{Map, mutable} +import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -33,15 +33,15 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, - RecordWriter => NewRecordWriter} +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptID, TaskType} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -53,10 +53,7 @@ import org.apache.spark.util.random.StratifiedSamplingUtils */ class PairRDDFunctions[K, V](self: RDD[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) - extends Logging - with SparkHadoopMapReduceUtil - with Serializable -{ + extends Logging with Serializable { /** * :: Experimental :: @@ -65,9 +62,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Note that V and C can be different -- for example, one might group an RDD of type * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). @@ -304,27 +301,27 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. */ def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = self.withScope { reduceByKey(new HashPartitioner(numPartitions), func) } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * Merge the values for each key using an associative and commutative reduce function. This will + * also perform the merging locally on each mapper before sending results to a reducer, similarly + * to a "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ * parallelism level. */ def reduceByKey(func: (V, V) => V): RDD[(K, V)] = self.withScope { @@ -332,9 +329,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative reduce function, but return the results - * immediately to the master as a Map. This will also perform the merging locally on each mapper - * before sending results to a reducer, similarly to a "combiner" in MapReduce. + * Merge the values for each key using an associative and commutative reduce function, but return + * the results immediately to the master as a Map. This will also perform the merging locally on + * each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope { val cleanedF = self.sparkContext.clean(func) @@ -363,12 +360,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) self.mapPartitions(reducePartition).reduce(mergeMaps).asScala } - /** Alias for reduceByKeyLocally */ - @deprecated("Use reduceByKeyLocally", "1.0.0") - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = self.withScope { - reduceByKeyLocally(func) - } - /** * Count the number of elements for each key, collecting the results to a local Map. * @@ -736,6 +727,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * Warning: this doesn't return a multimap (so if you have multiple values to the same key, only * one value per key is preserved in the map returned) + * + * @note this method should only be used if the resulting data is expected to be small, as + * all the data is loaded into the driver's memory. */ def collectAsMap(): Map[K, V] = self.withScope { val data = self.collect() @@ -985,11 +979,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) conf: Configuration = self.context.hadoopConfiguration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val job = new NewAPIHadoopJob(hadoopConf) + val job = NewAPIHadoopJob.getInstance(hadoopConf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfiguration = job.getConfiguration jobConfiguration.set("mapred.output.dir", path) saveAsNewAPIHadoopDataset(jobConfiguration) } @@ -1074,11 +1068,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val job = new NewAPIHadoopJob(hadoopConf) + val job = NewAPIHadoopJob.getInstance(hadoopConf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfiguration = job.getConfiguration val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1091,9 +1085,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId, context.attemptNumber) - val hadoopContext = newTaskAttemptContext(config, attemptId) + val hadoopContext = new TaskAttemptContextImpl(config, attemptId) val format = outfmt.newInstance format match { case c: Configurable => c.setConf(config) @@ -1102,31 +1096,32 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) + val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = + initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L - Utils.tryWithSafeFinally { + Utils.tryWithSafeFinallyAndFailureCallbacks { while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) // Update bytes written metric every few records - maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } - } { - writer.close(hadoopContext) - } + }(finallyBlock = writer.close(hadoopContext)) committer.commitTask(hadoopContext) - bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } - outputMetrics.setRecordsWritten(recordsWritten) + outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } 1 } : Int - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.MAP, 0, 0) + val jobTaskContext = new TaskAttemptContextImpl(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) // When speculation is on and output committer class name contains "Direct", we should warn @@ -1187,47 +1182,54 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) + val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = + initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() var recordsWritten = 0L - Utils.tryWithSafeFinally { + Utils.tryWithSafeFinallyAndFailureCallbacks { while (iter.hasNext) { val record = iter.next() writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) // Update bytes written metric every few records - maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } - } { - writer.close() - } + }(finallyBlock = writer.close()) writer.commit() - bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } - outputMetrics.setRecordsWritten(recordsWritten) + outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } } self.context.runJob(self, writeToFile) writer.commitJob() } - private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = { + // TODO: these don't seem like the right abstractions. + // We should abstract the duplicate code in a less awkward way. + + // return type: (output metrics, bytes written callback), defined only if the latter is defined + private def initHadoopOutputMetrics( + context: TaskContext): Option[(OutputMetrics, () => Long)] = { val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() - val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) - if (bytesWrittenCallback.isDefined) { - context.taskMetrics.outputMetrics = Some(outputMetrics) + bytesWrittenCallback.map { b => + (context.taskMetrics().registerOutputMetrics(DataWriteMethod.Hadoop), b) } - (outputMetrics, bytesWrittenCallback) } - private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], - outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { + private def maybeUpdateOutputMetrics( + outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)], + recordsWritten: Long): Unit = { if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { - bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) } - outputMetrics.setRecordsWritten(recordsWritten) + outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 582fa93afe34e..bb84e4af15b15 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -121,14 +121,14 @@ private object ParallelCollectionRDD { // Sequences need to be sliced at the same set of index positions for operations // like RDD.zip() to behave as expected def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = { - (0 until numSlices).iterator.map(i => { + (0 until numSlices).iterator.map { i => val start = ((i * length) / numSlices).toInt val end = (((i + 1) * length) / numSlices).toInt (start, end) - }) + } } seq match { - case r: Range => { + case r: Range => positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) => // If the range is inclusive, use inclusive range for the last slice if (r.isInclusive && index == numSlices - 1) { @@ -138,8 +138,7 @@ private object ParallelCollectionRDD { new Range(r.start + start * r.step, r.start + end * r.step, r.step) } }).toSeq.asInstanceOf[Seq[Seq[T]]] - } - case nr: NumericRange[_] => { + case nr: NumericRange[_] => // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) var r = nr @@ -149,14 +148,12 @@ private object ParallelCollectionRDD { r = r.drop(sliceSize) } slices - } - case _ => { + case _ => val array = seq.toArray // To prevent O(n^2) operations for List etc positions(array.length, numSlices).map({ case (start, end) => array.slice(start, end).toSeq }).toSeq - } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index d6a37e8cc5dac..0c6ddda52cee9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -65,7 +65,7 @@ class PartitionPruningRDD[T: ClassTag]( } override protected def getPartitions: Array[Partition] = - getDependencies.head.asInstanceOf[PruneDependency[T]].partitions + dependencies.head.asInstanceOf[PruneDependency[T]].partitions } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 9e3880714a79f..0abba15bec9f7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -68,9 +68,9 @@ class PartitionerAwareUnionRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { val numPartitions = partitioner.get.numPartitions - (0 until numPartitions).map(index => { + (0 until numPartitions).map { index => new PartitionerAwareUnionRDDPartition(rdds, index) - }).toArray + }.toArray } // Get the location where most of the partitions of parent RDDs are located @@ -78,11 +78,10 @@ class PartitionerAwareUnionRDD[T: ClassTag]( logDebug("Finding preferred location for " + this + ", partition " + s.index) val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents val locations = rdds.zip(parentPartitions).flatMap { - case (rdd, part) => { + case (rdd, part) => val parentLocations = currPrefLocs(rdd, part) logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations) parentLocations - } } val location = if (locations.isEmpty) { None diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index afbe566b76566..dd8e46ba0f122 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -22,12 +22,14 @@ import java.io.FilenameFilter import java.io.IOException import java.io.PrintWriter import java.util.StringTokenizer +import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag +import scala.util.control.NonFatal import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.util.Utils @@ -118,63 +120,98 @@ private[spark] class PipedRDD[T: ClassTag]( val proc = pb.start() val env = SparkEnv.get + val childThreadException = new AtomicReference[Throwable](null) // Start a thread to print the process's stderr to ours - new Thread("stderr reader for " + command) { - override def run() { - for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { - // scalastyle:off println - System.err.println(line) - // scalastyle:on println + new Thread(s"stderr reader for $command") { + override def run(): Unit = { + val err = proc.getErrorStream + try { + for (line <- Source.fromInputStream(err).getLines) { + // scalastyle:off println + System.err.println(line) + // scalastyle:on println + } + } catch { + case t: Throwable => childThreadException.set(t) + } finally { + err.close() } } }.start() // Start a thread to feed the process input from our parent's iterator - new Thread("stdin writer for " + command) { - override def run() { + new Thread(s"stdin writer for $command") { + override def run(): Unit = { TaskContext.setTaskContext(context) val out = new PrintWriter(proc.getOutputStream) - - // scalastyle:off println - // input the pipe context firstly - if (printPipeContext != null) { - printPipeContext(out.println(_)) - } - for (elem <- firstParent[T].iterator(split, context)) { - if (printRDDElement != null) { - printRDDElement(elem, out.println(_)) - } else { - out.println(elem) + try { + // scalastyle:off println + // input the pipe context firstly + if (printPipeContext != null) { + printPipeContext(out.println) + } + for (elem <- firstParent[T].iterator(split, context)) { + if (printRDDElement != null) { + printRDDElement(elem, out.println) + } else { + out.println(elem) + } } + // scalastyle:on println + } catch { + case t: Throwable => childThreadException.set(t) + } finally { + out.close() } - // scalastyle:on println - out.close() } }.start() // Return an iterator that read lines from the process's stdout val lines = Source.fromInputStream(proc.getInputStream).getLines() new Iterator[String] { - def next(): String = lines.next() - def hasNext: Boolean = { - if (lines.hasNext) { + def next(): String = { + if (!hasNext()) { + throw new NoSuchElementException() + } + lines.next() + } + + def hasNext(): Boolean = { + val result = if (lines.hasNext) { true } else { val exitStatus = proc.waitFor() + cleanup() if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) + throw new IllegalStateException(s"Subprocess exited with status $exitStatus. " + + s"Command ran: " + command.mkString(" ")) } + false + } + propagateChildException() + result + } - // cleanup task working directory if used - if (workInTaskDirectory) { - scala.util.control.Exception.ignoring(classOf[IOException]) { - Utils.deleteRecursively(new File(taskDirectory)) - } - logDebug("Removed task working directory " + taskDirectory) + private def cleanup(): Unit = { + // cleanup task working directory if used + if (workInTaskDirectory) { + scala.util.control.Exception.ignoring(classOf[IOException]) { + Utils.deleteRecursively(new File(taskDirectory)) } + logDebug(s"Removed task working directory $taskDirectory") + } + } - false + private def propagateChildException(): Unit = { + val t = childThreadException.get() + if (t != null) { + val commandRan = command.mkString(" ") + logError(s"Caught exception while running pipe() operator. Command ran: $commandRan. " + + s"Exception: ${t.getMessage}") + proc.destroy() + cleanup() + throw t } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 800ef53cbef07..36ff3bcaaec62 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -31,16 +31,17 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, +import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} /** @@ -85,17 +86,21 @@ abstract class RDD[T: ClassTag]( private def sc: SparkContext = { if (_sc == null) { throw new SparkException( - "RDD transformations and actions can only be invoked by the driver, not inside of other " + - "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " + - "the values transformation and count action cannot be performed inside of the rdd1.map " + - "transformation. For more information, see SPARK-5063.") + "This RDD lacks a SparkContext. It could happen in the following cases: \n(1) RDD " + + "transformations and actions are NOT invoked by the driver, but inside of other " + + "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid " + + "because the values transformation and count action cannot be performed inside of the " + + "rdd1.map transformation. For more information, see SPARK-5063.\n(2) When a Spark " + + "Streaming job recovers from checkpoint, this exception will be hit if a reference to " + + "an RDD not defined by the streaming job is used in DStream operations. For more " + + "information, See SPARK-13758.") } _sc } /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = - this(oneParent.context , List(new OneToOneDependency(oneParent))) + this(oneParent.context, List(new OneToOneDependency(oneParent))) private[spark] def conf = sc.conf // ======================================================================= @@ -112,6 +117,9 @@ abstract class RDD[T: ClassTag]( /** * Implemented by subclasses to return the set of partitions in this RDD. This method will only * be called once, so it is safe to implement a time-consuming computation in it. + * + * The partitions in this array must satisfy the following property: + * `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }` */ protected def getPartitions: Array[Partition] @@ -237,11 +245,21 @@ abstract class RDD[T: ClassTag]( checkpointRDD.map(_.partitions).getOrElse { if (partitions_ == null) { partitions_ = getPartitions + partitions_.zipWithIndex.foreach { case (partition, index) => + require(partition.index == index, + s"partitions($index).partition == ${partition.index}, but it should equal $index") + } } partitions_ } } + /** + * Returns the number of partitions of this RDD. + */ + @Since("1.6.0") + final def getNumPartitions: Int = partitions.length + /** * Get the preferred locations of a partition, taking into account whether the * RDD is checkpointed. @@ -259,7 +277,7 @@ abstract class RDD[T: ClassTag]( */ final def iterator(split: Partition, context: TaskContext): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) + getOrCompute(split, context) } else { computeOrReadCheckpoint(split, context) } @@ -301,6 +319,35 @@ abstract class RDD[T: ClassTag]( } } + /** + * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. + */ + private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = { + val blockId = RDDBlockId(id, partition.index) + var readCachedBlock = true + // This method is called on executors, so we need call SparkEnv.get instead of sc.env. + SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => { + readCachedBlock = false + computeOrReadCheckpoint(partition, context) + }) match { + case Left(blockResult) => + if (readCachedBlock) { + val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) + existingMetrics.incBytesRead(blockResult.bytes) + new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) { + override def next(): T = { + existingMetrics.incRecordsRead(1) + delegate.next() + } + } + } else { + new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) + } + case Right(iter) => + new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]]) + } + } + /** * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. @@ -468,6 +515,9 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator @@ -518,11 +568,7 @@ abstract class RDD[T: ClassTag]( * times (use `.distinct()` to eliminate them). */ def union(other: RDD[T]): RDD[T] = withScope { - if (partitioner.isDefined && other.partitioner == partitioner) { - new PartitionerAwareUnionRDD(sc, Array(this, other)) - } else { - new UnionRDD(sc, Array(this, other)) - } + sc.union(this, other) } /** @@ -673,7 +719,7 @@ abstract class RDD[T: ClassTag]( * An example of pipe the RDD data of groupBy() in a streaming way, * instead of constructing a huge String to concat all the elements: * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2){f(e)} + * for (e <- record._2) {f(e)} * @param separateWorkingDir Use separate working directories for each task. * @return the result RDD */ @@ -706,113 +752,38 @@ abstract class RDD[T: ClassTag]( } /** - * Return a new RDD by applying a function to each partition of this RDD, while tracking the index - * of the original partition. + * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a + * performance API to be used carefully only if we are sure that the RDD elements are + * serializable and don't require closure cleaning. * - * `preservesPartitioning` indicates whether the input function preserves the partitioner, which - * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. + * @param preservesPartitioning indicates whether the input function preserves the partitioner, + * which should be `false` unless this is a pair RDD and the input function doesn't modify + * the keys. */ - def mapPartitionsWithIndex[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], + private[spark] def mapPartitionsInternal[U: ClassTag]( + f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { - val cleanedF = sc.clean(f) new MapPartitionsRDD( this, - (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter), + (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter), preservesPartitioning) } /** - * :: DeveloperApi :: - * Return a new RDD by applying a function to each partition of this RDD. This is a variant of - * mapPartitions that also passes the TaskContext into the closure. + * Return a new RDD by applying a function to each partition of this RDD, while tracking the index + * of the original partition. * * `preservesPartitioning` indicates whether the input function preserves the partitioner, which * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ - @DeveloperApi - @deprecated("use TaskContext.get", "1.2.0") - def mapPartitionsWithContext[U: ClassTag]( - f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = withScope { - val cleanF = sc.clean(f) - val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter) - new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD, while tracking the index - * of the original partition. - */ - @deprecated("use mapPartitionsWithIndex", "0.7.0") - def mapPartitionsWithSplit[U: ClassTag]( + def mapPartitionsWithIndex[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { - mapPartitionsWithIndex(f, preservesPartitioning) - } - - /** - * Maps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex", "1.0.0") - def mapWith[A, U: ClassTag] - (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => U): RDD[U] = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.map(t => cleanF(t, a)) - }, preservesPartitioning) - } - - /** - * FlatMaps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0") - def flatMapWith[A, U: ClassTag] - (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => Seq[U]): RDD[U] = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.flatMap(t => cleanF(t, a)) - }, preservesPartitioning) - } - - /** - * Applies f to each element of this RDD, where f takes an additional parameter of type A. - * This additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0") - def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex { (index, iter) => - val a = cleanA(index) - iter.map(t => {cleanF(t, a); t}) - } - } - - /** - * Filters this RDD with p, where p takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex and filter", "1.0.0") - def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope { - val cleanP = sc.clean(p) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.filter(t => cleanP(t, a)) - }, preservesPartitioning = true) + val cleanedF = sc.clean(f) + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter), + preservesPartitioning) } /** @@ -898,6 +869,9 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) @@ -920,14 +894,6 @@ abstract class RDD[T: ClassTag]( (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) } - /** - * Return an array that contains all of the elements in this RDD. - */ - @deprecated("use collect", "1.0.0") - def toArray(): Array[T] = withScope { - collect() - } - /** * Return an RDD that contains all matching values by applying `f`. */ @@ -1037,7 +1003,7 @@ abstract class RDD[T: ClassTag]( /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". The function + * given associative function and a neutral "zero value". The function * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object * allocation; however, it should not modify t2. * @@ -1047,6 +1013,13 @@ abstract class RDD[T: ClassTag]( * apply the fold to each element sequentially in some defined ordering. For functions * that are not commutative, the result may differ from that of a fold applied to a * non-distributed collection. + * + * @param zeroValue the initial value for the accumulated result of each partition for the `op` + * operator, and also the initial value for the combine results from different + * partitions for the `op` operator - this will typically be the neutral + * element (e.g. `Nil` for list concatenation or `0` for summation) + * @param op an operator used to both accumulate results within a partition and combine results + * from different partitions */ def fold(zeroValue: T)(op: (T, T) => T): T = withScope { // Clone the zero value since we will also be serializing it as part of tasks @@ -1065,6 +1038,13 @@ abstract class RDD[T: ClassTag]( * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are * allowed to modify and return their first argument instead of creating a new U to avoid memory * allocation. + * + * @param zeroValue the initial value for the accumulated result of each partition for the + * `seqOp` operator, and also the initial value for the combine results from + * different partitions for the `combOp` operator - this will typically be the + * neutral element (e.g. `Nil` for list concatenation or `0` for summation) + * @param seqOp an operator used to accumulate results within a partition + * @param combOp an associative operator used to combine results from different partitions */ def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope { // Clone the zero value since we will also be serializing it as part of tasks @@ -1258,6 +1238,9 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @note due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ @@ -1271,7 +1254,7 @@ abstract class RDD[T: ClassTag]( while (buf.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate @@ -1286,11 +1269,11 @@ abstract class RDD[T: ClassTag]( } val left = num - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) - partsScanned += numPartsToTry + partsScanned += p.size } buf.toArray @@ -1309,7 +1292,8 @@ abstract class RDD[T: ClassTag]( /** * Returns the top k (largest) elements from this RDD as defined by the specified - * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: + * implicit Ordering[T] and maintains the ordering. This does the opposite of + * [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) * // returns Array(12) @@ -1318,6 +1302,9 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param num k, the number of top elements to return * @param ord the implicit ordering for T * @return an array of top elements @@ -1338,6 +1325,9 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * * @param num k, the number of elements to return * @param ord the implicit ordering for T * @return an array of top elements @@ -1597,6 +1587,15 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + // Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default, + // we stop as soon as we find the first such RDD, an optimization that allows us to write + // less data but is not safe for all workloads. E.g. in streaming we may checkpoint both + // an RDD and its parent in every batch, in which case the parent may never be checkpointed + // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). + private val checkpointAllMarkedAncestors = + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) + .map(_.toBoolean).getOrElse(false) + /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] @@ -1640,6 +1639,13 @@ abstract class RDD[T: ClassTag]( if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { + if (checkpointAllMarkedAncestors) { + // TODO We can collect all the RDDs that needs to be checkpointed, and then checkpoint + // them in parallel. + // Checkpoint parents first because our lineage will be truncated after we + // checkpoint ourselves + dependencies.foreach(_.rdd.doCheckpoint()) + } checkpointData.get.checkpoint() } else { dependencies.foreach(_.rdd.doCheckpoint()) @@ -1759,6 +1765,9 @@ abstract class RDD[T: ClassTag]( */ object RDD { + private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS = + "spark.checkpoint.checkpointAllMarkedAncestors" + // The following implicit functions were in SparkContext before 1.3 and users had to // `import SparkContext._` to enable them. Now we move them here to make the compiler find // them automatically. However, we still keep the old functions in SparkContext for backward diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 540cbd688b63b..53d69ba26811f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -25,7 +25,8 @@ import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import com.google.common.base.Objects -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging /** * A general, named code block representing an operation that instantiates RDDs. diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index a69be6a068bbf..fddb9353018a8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -20,12 +20,13 @@ package org.apache.spark.rdd import java.io.IOException import scala.reflect.ClassTag +import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -33,8 +34,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} */ private[spark] class ReliableCheckpointRDD[T: ClassTag]( sc: SparkContext, - val checkpointPath: String) - extends CheckpointRDD[T](sc) { + val checkpointPath: String, + _partitioner: Option[Partitioner] = None + ) extends CheckpointRDD[T](sc) { @transient private val hadoopConf = sc.hadoopConfiguration @transient private val cpath = new Path(checkpointPath) @@ -47,7 +49,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag]( /** * Return the path of the checkpoint directory this RDD reads data from. */ - override def getCheckpointFile: Option[String] = Some(checkpointPath) + override val getCheckpointFile: Option[String] = Some(checkpointPath) + + override val partitioner: Option[Partitioner] = { + _partitioner.orElse { + ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath) + } + } /** * Return partitions described by the files in the checkpoint directory. @@ -100,10 +108,52 @@ private[spark] object ReliableCheckpointRDD extends Logging { "part-%05d".format(partitionIndex) } + private def checkpointPartitionerFileName(): String = { + "_partitioner" + } + + /** + * Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD. + */ + def writeRDDToCheckpointDirectory[T: ClassTag]( + originalRDD: RDD[T], + checkpointDir: String, + blockSize: Int = -1): ReliableCheckpointRDD[T] = { + + val sc = originalRDD.sparkContext + + // Create the output path for the checkpoint + val checkpointDirPath = new Path(checkpointDir) + val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration) + if (!fs.mkdirs(checkpointDirPath)) { + throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath") + } + + // Save to file, and reload it as an RDD + val broadcastedConf = sc.broadcast( + new SerializableConfiguration(sc.hadoopConfiguration)) + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + sc.runJob(originalRDD, + writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _) + + if (originalRDD.partitioner.nonEmpty) { + writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) + } + + val newRDD = new ReliableCheckpointRDD[T]( + sc, checkpointDirPath.toString, originalRDD.partitioner) + if (newRDD.partitions.length != originalRDD.partitions.length) { + throw new SparkException( + s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + + s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})") + } + newRDD + } + /** - * Write this partition's values to a checkpoint file. + * Write a RDD partition's data to a checkpoint file. */ - def writeCheckpointFile[T: ClassTag]( + def writePartitionToCheckpointFile[T: ClassTag]( path: String, broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { @@ -125,7 +175,8 @@ private[spark] object ReliableCheckpointRDD extends Logging { fs.create(tempOutputPath, false, bufferSize) } else { // This is mainly for testing purpose - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + fs.create(tempOutputPath, false, bufferSize, + fs.getDefaultReplication(fs.getWorkingDirectory), blockSize) } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) @@ -151,6 +202,67 @@ private[spark] object ReliableCheckpointRDD extends Logging { } } + /** + * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort + * basis; any exception while writing the partitioner is caught, logged and ignored. + */ + private def writePartitionerToCheckpointDir( + sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = { + try { + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + Utils.tryWithSafeFinally { + serializeStream.writeObject(partitioner) + } { + serializeStream.close() + } + logDebug(s"Written partitioner to $partitionerFilePath") + } catch { + case NonFatal(e) => + logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath") + } + } + + + /** + * Read a partitioner from the given RDD checkpoint directory, if it exists. + * This is done on a best-effort basis; any exception while reading the partitioner is + * caught, logged and ignored. + */ + private def readCheckpointedPartitionerFile( + sc: SparkContext, + checkpointDirPath: String): Option[Partitioner] = { + try { + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(partitionerFilePath)) { + val fileInputStream = fs.open(partitionerFilePath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + val partitioner = Utils.tryWithSafeFinally[Partitioner] { + deserializeStream.readObject[Partitioner] + } { + deserializeStream.close() + } + logDebug(s"Read partitioner from $partitionerFilePath") + Some(partitioner) + } else { + logDebug("No partitioner file") + None + } + } catch { + case NonFatal(e) => + logWarning(s"Error reading partitioner from $checkpointDirPath, " + + s"partitioner will not be recovered which may lead to performance loss", e) + None + } + } + /** * Read the content of the specified checkpoint file. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 91cad6662e4d2..74f187642af21 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.spark._ -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.internal.Logging /** * An implementation of checkpointing that writes the RDD data to reliable storage. @@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v * This is called immediately after the first action invoked on this RDD has completed. */ protected override def doCheckpoint(): CheckpointRDD[T] = { - - // Create the output path for the checkpoint - val path = new Path(cpDir) - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException(s"Failed to create checkpoint path $cpDir") - } - - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) - // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) - rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) - val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) - if (newRDD.partitions.length != rdd.partitions.length) { - throw new SparkException( - s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + - s"number of partitions from original RDD $rdd(${rdd.partitions.length})") - } + val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir) // Optionally clean our checkpoint files if the reference is out of scope if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { @@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v } logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}") - newRDD } diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala deleted file mode 100644 index 9e8cee5331cf8..0000000000000 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import java.util.Random - -import scala.reflect.ClassTag - -import org.apache.commons.math3.distribution.PoissonDistribution - -import org.apache.spark.{Partition, TaskContext} - -@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0") -private[spark] -class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { - override val index: Int = prev.index -} - -@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0") -private[spark] class SampledRDD[T: ClassTag]( - prev: RDD[T], - withReplacement: Boolean, - frac: Double, - seed: Int) - extends RDD[T](prev) { - - override def getPartitions: Array[Partition] = { - val rg = new Random(seed) - firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt)) - } - - override def getPreferredLocations(split: Partition): Seq[String] = - firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev) - - override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { - val split = splitIn.asInstanceOf[SampledRDDPartition] - if (withReplacement) { - // For large datasets, the expected number of occurrences of each element in a sample with - // replacement is Poisson(frac). We use that to get a count for each element. - val poisson = new PoissonDistribution(frac) - poisson.reseedRandomGenerator(split.seed) - - firstParent[T].iterator(split.prev, context).flatMap { element => - val count = poisson.sample() - if (count == 0) { - Iterator.empty // Avoid object allocation when we return 0 items, which is quite often - } else { - Iterator.fill(count)(element) - } - } - } else { // Sampling without replacement - val rand = new Random(split.seed) - firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 4b5f15dd06b85..1311b481c7c71 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -16,14 +16,14 @@ */ package org.apache.spark.rdd -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} import org.apache.hadoop.io.Writable import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.SequenceFileOutputFormat -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, @@ -38,11 +38,6 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag extends Logging with Serializable { - @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0") - def this(self: RDD[(K, V)]) { - this(self, null, null) - } - private val keyWritableClass = if (_keyWritableClass == null) { // pre 1.3.0, we need to use Reflection to get the Writable class diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index a013c3f66a3a8..800b42505de10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -44,7 +44,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( part: Partitioner) extends RDD[(K, C)](prev.context, Nil) { - private var serializer: Option[Serializer] = None + private var userSpecifiedSerializer: Option[Serializer] = None private var keyOrdering: Option[Ordering[K]] = None @@ -54,7 +54,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = { - this.serializer = Option(serializer) + this.userSpecifiedSerializer = Option(serializer) this } @@ -77,6 +77,14 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( } override def getDependencies: Seq[Dependency[_]] = { + val serializer = userSpecifiedSerializer.getOrElse { + val serializerManager = SparkEnv.get.serializerManager + if (mapSideCombine) { + serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]]) + } else { + serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]]) + } + } List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine)) } @@ -86,7 +94,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } - override def getPreferredLocations(partition: Partition): Seq[String] = { + override protected def getPreferredLocations(partition: Partition): Seq[String] = { val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] tracker.getPreferredLocationsForShuffle(dep, partition.index) diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala deleted file mode 100644 index 264dae7f39085..0000000000000 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ /dev/null @@ -1,290 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import java.text.SimpleDateFormat -import java.util.Date - -import scala.reflect.ClassTag - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} - - -private[spark] class SqlNewHadoopPartition( - rddId: Int, - val index: Int, - rawSplit: InputSplit with Writable) - extends SparkPartition { - - val serializableHadoopSplit = new SerializableWritable(rawSplit) - - override def hashCode(): Int = 41 * (41 + rddId) + index -} - -/** - * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, - * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). - * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. - * 1. A shared broadcast Hadoop Configuration. - * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side - * to the shared Hadoop Configuration. - * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side - * and the executor side to the shared Hadoop Configuration. - * - * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. - */ -private[spark] class SqlNewHadoopRDD[V: ClassTag]( - sc : SparkContext, - broadcastedConf: Broadcast[SerializableConfiguration], - @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], - initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[Void, V]], - valueClass: Class[V]) - extends RDD[V](sc, Nil) - with SparkHadoopMapReduceUtil - with Logging { - - protected def getJob(): Job = { - val conf: Configuration = broadcastedConf.value.value - // "new Job" will make a copy of the conf. Then, it is - // safe to mutate conf properties with initLocalJobFuncOpt - // and initDriverSideJobFuncOpt. - val newJob = new Job(conf) - initLocalJobFuncOpt.map(f => f(newJob)) - newJob - } - - def getConf(isDriverSide: Boolean): Configuration = { - val job = getJob() - if (isDriverSide) { - initDriverSideJobFuncOpt.map(f => f(job)) - } - SparkHadoopUtil.get.getConfigurationFromJobContext(job) - } - - private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - formatter.format(new Date()) - } - - @transient protected val jobId = new JobID(jobTrackerId, id) - - override def getPartitions: Array[SparkPartition] = { - val conf = getConf(isDriverSide = true) - val inputFormat = inputFormatClass.newInstance - inputFormat match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val jobContext = newJobContext(conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } - - override def compute( - theSplit: SparkPartition, - context: TaskContext): Iterator[V] = { - val iter = new Iterator[V] { - val split = theSplit.asInstanceOf[SqlNewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf(isDriverSide = false) - - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) - - // Sets the thread local variable for the file's name - split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() - } - - // Find a function that will return the FileSystem bytes read by this thread. Do this before - // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } - } - inputMetrics.setBytesReadCallback(bytesReadCallback) - - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - private[this] var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => close()) - - private[this] var havePair = false - private[this] var finished = false - - override def hasNext: Boolean = { - if (context.isInterrupted) { - throw new TaskKilledException - } - if (!finished && !havePair) { - finished = !reader.nextKeyValue - if (finished) { - // Close and release the reader here; close() will also be called when the task - // completes, but for tasks that read from many files, it helps to release the - // resources early. - close() - } - havePair = !finished - } - !finished - } - - override def next(): V = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - havePair = false - if (!finished) { - inputMetrics.incRecordsRead(1) - } - reader.getCurrentValue - } - - private def close() { - if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() - // Close the reader and release it. Note: it's very important that we don't close the - // reader more than once, since that exposes us to MAPREDUCE-5918 when running against - // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic - // corruption issues when reading compressed input. - try { - reader.close() - } catch { - case e: Exception => - if (!ShutdownHookManager.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) - } - } finally { - reader = null - } - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) - } - } - } - } - } - iter - } - - override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { - val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value - val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => - try { - val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] - Some(HadoopRDD.convertSplitLocationInfo(infos)) - } catch { - case e : Exception => - logDebug("Failed to use InputSplit#getLocationInfo.", e) - None - } - case None => None - } - locs.getOrElse(split.getLocations.filter(_ != "localhost")) - } - - override def persist(storageLevel: StorageLevel): this.type = { - if (storageLevel.deserialized) { - logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + - " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + - " Use a map transformation to make copies of the records.") - } - super.persist(storageLevel) - } -} - -private[spark] object SqlNewHadoopRDD { - - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. - */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() - - /** - * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to - * the given function rather than the index of the partition. - */ - private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( - prev: RDD[T], - f: (InputSplit, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false) - extends RDD[U](prev) { - - override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - - override def getPartitions: Array[SparkPartition] = firstParent[T].partitions - - override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { - val partition = split.asInstanceOf[SqlNewHadoopPartition] - val inputSplit = partition.serializableHadoopSplit.value - f(inputSplit, firstParent[T].iterator(split, context)) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 25ec685eff5ab..a733eaa5d7e53 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -30,7 +30,6 @@ import org.apache.spark.Partitioner import org.apache.spark.ShuffleDependency import org.apache.spark.SparkEnv import org.apache.spark.TaskContext -import org.apache.spark.serializer.Serializer /** * An optimized version of cogroup for set difference/subtraction. @@ -54,13 +53,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { - private var serializer: Option[Serializer] = None - - /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ - def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = { - this.serializer = Option(serializer) - this - } override def getDependencies: Seq[Dependency[_]] = { def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]]) @@ -70,7 +62,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[T1, T2, Any](rdd, part, serializer) + new ShuffleDependency[T1, T2, Any](rdd, part) } } Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2)) diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala new file mode 100644 index 0000000000000..8e1baae796fc5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.{Text, Writable} +import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.task.JobContextImpl + +import org.apache.spark.{Partition, SparkContext} +import org.apache.spark.input.WholeTextFileInputFormat + +/** + * An RDD that reads a bunch of text files in, and each text file becomes one record. + */ +private[spark] class WholeTextFileRDD( + sc : SparkContext, + inputFormatClass: Class[_ <: WholeTextFileInputFormat], + keyClass: Class[Text], + valueClass: Class[Text], + conf: Configuration, + minPartitions: Int) + extends NewHadoopRDD[Text, Text](sc, inputFormatClass, keyClass, valueClass, conf) { + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + val conf = getConf + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = new JobContextImpl(conf, jobId) + inputFormat.setMinPartitions(jobContext, minPartitions) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 4333a679c8aae..3cb1231bd3477 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -54,7 +54,8 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( override def getPartitions: Array[Partition] = { val numParts = rdds.head.partitions.length if (!rdds.forall(rdd => rdd.partitions.length == numParts)) { - throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") + throw new IllegalArgumentException( + s"Can't zip RDDs with unequal numbers of partitions: ${rdds.map(_.partitions.length)}") } Array.tabulate[Partition](numParts) { i => val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i))) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala similarity index 87% rename from core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala rename to core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala index d2e94f943aba5..b9db60a7797d8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.rpc.netty +package org.apache.spark.rpc import org.apache.spark.SparkException -import org.apache.spark.rpc.RpcAddress /** * An address identifier for an RPC endpoint. @@ -26,10 +25,10 @@ import org.apache.spark.rpc.RpcAddress * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only * connection and can only be reached via the client that sent the endpoint reference. * - * @param rpcAddress The socket address of the endpint. + * @param rpcAddress The socket address of the endpoint. * @param name Name of the endpoint. */ -private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { +private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { require(name != null, "RpcEndpoint name must be provided.") @@ -44,7 +43,11 @@ private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val nam } } -private[netty] object RpcEndpointAddress { +private[spark] object RpcEndpointAddress { + + def apply(host: String, port: Int, name: String): RpcEndpointAddress = { + new RpcEndpointAddress(host, port, name) + } def apply(sparkUrl: String): RpcEndpointAddress = { try { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 623da3e9c11b8..994e18676ec49 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -20,8 +20,9 @@ package org.apache.spark.rpc import scala.concurrent.Future import scala.reflect.ClassTag +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.util.RpcUtils -import org.apache.spark.{SparkException, Logging, SparkConf} /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a560fd10cdf76..56683771335a6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,10 +17,14 @@ package org.apache.spark.rpc +import java.io.File +import java.nio.channels.ReadableByteChannel + import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.rpc.netty.NettyRpcEnvFactory +import org.apache.spark.util.RpcUtils /** @@ -29,15 +33,6 @@ import org.apache.spark.util.{RpcUtils, Utils} */ private[spark] object RpcEnv { - private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - val rpcEnvNames = Map( - "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", - "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "netty") - val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] - } - def create( name: String, host: String, @@ -45,9 +40,8 @@ private[spark] object RpcEnv { conf: SparkConf, securityManager: SecurityManager, clientMode: Boolean = false): RpcEnv = { - // Using Reflection to create the RpcEnv to avoid to depend on Akka directly val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode) - getRpcEnvFactory(conf).create(config) + new NettyRpcEnvFactory().create(config) } } @@ -95,12 +89,11 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { } /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName`. * This is a blocking action. */ - def setupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { - setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(RpcEndpointAddress(address, endpointName).toString) } /** @@ -121,19 +114,74 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def awaitTermination(): Unit - /** - * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of - * creating it manually because different [[RpcEnv]] may have different formats. - */ - def uriOf(systemName: String, address: RpcAddress, endpointName: String): String - /** * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T + + /** + * Return the instance of the file server used to serve files. This may be `null` if the + * RpcEnv is not operating in server mode. + */ + def fileServer: RpcEnvFileServer + + /** + * Open a channel to download a file from the given URI. If the URIs returned by the + * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to + * retrieve the files. + * + * @param uri URI with location of the file. + */ + def openChannel(uri: String): ReadableByteChannel + } +/** + * A server used by the RpcEnv to server files to other processes owned by the application. + * + * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or + * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`. + */ +private[spark] trait RpcEnvFileServer { + + /** + * Adds a file to be served by this RpcEnv. This is used to serve files from the driver + * to executors when they're stored on the driver's local file system. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addFile(file: File): String + + /** + * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using + * `SparkContext.addJar`. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addJar(file: File): String + + /** + * Adds a local directory to be served via this file server. + * + * @param baseUri Leading URI path (files can be retrieved by appending their relative + * path to this base URI). This cannot be "files" nor "jars". + * @param path Path to the local directory. + * @return URI for the root of the directory in the file server. + */ + def addDirectory(baseUri: String, path: File): String + + /** Validates and normalizes the base URI for directories. */ + protected def validateDirectoryUri(baseUri: String): String = { + val fixedBaseUri = "/" + baseUri.stripPrefix("/").stripSuffix("/") + require(fixedBaseUri != "/files" && fixedBaseUri != "/jars", + "Directory URI cannot be /files nor /jars.") + fixedBaseUri + } + +} private[spark] case class RpcEnvConfig( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala new file mode 100644 index 0000000000000..c296cc23f12b7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnvStoppedException.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc + +private[rpc] class RpcEnvStoppedException() + extends IllegalStateException("RpcEnv already stopped.") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 285786ebf9f1b..2950df62bf285 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -19,13 +19,12 @@ package org.apache.spark.rpc import java.util.concurrent.TimeoutException -import scala.concurrent.{Awaitable, Await} +import scala.concurrent.{Await, Awaitable} import scala.concurrent.duration._ import org.apache.spark.SparkConf import org.apache.spark.util.Utils - /** * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. */ @@ -120,7 +119,7 @@ private[spark] object RpcTimeout { // Find the first set property or use the default value with the first property val itr = timeoutPropList.iterator var foundProp: Option[(String, String)] = None - while (itr.hasNext && foundProp.isEmpty){ + while (itr.hasNext && foundProp.isEmpty) { val propKey = itr.next() conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala deleted file mode 100644 index 3fad595a0d0b0..0000000000000 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ /dev/null @@ -1,345 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.akka - -import java.util.concurrent.ConcurrentHashMap - -import scala.concurrent.Future -import scala.language.postfixOps -import scala.reflect.ClassTag -import scala.util.control.NonFatal - -import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} -import akka.event.Logging.Error -import akka.pattern.{ask => akkaAsk} -import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import akka.serialization.JavaSerializer - -import org.apache.spark.{SparkException, Logging, SparkConf} -import org.apache.spark.rpc._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} - -/** - * A RpcEnv implementation based on Akka. - * - * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and - * remove Akka from the dependencies. - */ -private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) - extends RpcEnv(conf) with Logging { - - private val defaultAddress: RpcAddress = { - val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress - // In some test case, ActorSystem doesn't bind to any address. - // So just use some default value since they are only some unit tests - RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) - } - - override val address: RpcAddress = defaultAddress - - /** - * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make - * [[RpcEndpoint.self]] work. - */ - private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() - - /** - * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` - */ - private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() - - private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { - endpointToRef.put(endpoint, endpointRef) - refToEndpoint.put(endpointRef, endpoint) - } - - private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { - val endpoint = refToEndpoint.remove(endpointRef) - if (endpoint != null) { - endpointToRef.remove(endpoint) - } - } - - /** - * Retrieve the [[RpcEndpointRef]] of `endpoint`. - */ - override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) - - override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use defered function because the Actor needs to use `endpointRef`. - // So `actorRef` should be created after assigning `endpointRef`. - val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { - - assert(endpointRef != null) - - override def preStart(): Unit = { - // Listen for remote client network events - context.system.eventStream.subscribe(self, classOf[AssociationEvent]) - safelyCall(endpoint) { - endpoint.onStart() - } - } - - override def receiveWithLogging: Receive = { - case AssociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) - } - - case DisassociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) - } - - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => - safelyCall(endpoint) { - endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) - } - - case e: AssociationEvent => - // TODO ignore? - - case m: AkkaMessage => - logDebug(s"Received RPC message: $m") - safelyCall(endpoint) { - processMessage(endpoint, m, sender) - } - - case AkkaFailure(e) => - safelyCall(endpoint) { - throw e - } - - case message: Any => { - logWarning(s"Unknown message: $message") - } - - } - - override def postStop(): Unit = { - unregisterEndpoint(endpoint.self) - safelyCall(endpoint) { - endpoint.onStop() - } - } - - }), name = name) - endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) - registerEndpoint(endpoint, endpointRef) - // Now actorRef can be created safely - endpointRef.init() - endpointRef - } - - private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { - val message = m.message - val needReply = m.needReply - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(new RpcCallContext { - override def sendFailure(e: Throwable): Unit = { - _sender ! AkkaFailure(e) - } - - override def reply(response: Any): Unit = { - _sender ! AkkaMessage(response, false) - } - - // Use "lazy" because most of RpcEndpoints don't need "senderAddress" - override lazy val senderAddress: RpcAddress = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address - }) - } else { - endpoint.receive - } - try { - pf.applyOrElse[Any, Unit](message, { message => - throw new SparkException(s"Unmatched message $message from ${_sender}") - }) - } catch { - case NonFatal(e) => - _sender ! AkkaFailure(e) - if (!needReply) { - // If the sender does not require a reply, it may not handle the exception. So we rethrow - // "e" to make sure it will be processed. - throw e - } - } - } - - /** - * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will - * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. - */ - private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) - } - } - } - } - - private def akkaAddressToRpcAddress(address: Address): RpcAddress = { - RpcAddress(address.host.getOrElse(defaultAddress.host), - address.port.getOrElse(defaultAddress.port)) - } - - override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). - // this is just in case there is a timeout from creating the future in resolveOne, we want the - // exception to indicate the conf that determines the timeout - recover(defaultLookupTimeout.addMessageIfTimeout) - } - - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { - AkkaUtils.address( - AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) - } - - override def shutdown(): Unit = { - actorSystem.shutdown() - } - - override def stop(endpoint: RpcEndpointRef): Unit = { - require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) - actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) - } - - override def awaitTermination(): Unit = { - actorSystem.awaitTermination() - } - - override def toString: String = s"${getClass.getSimpleName}($actorSystem)" - - override def deserialize[T](deserializationAction: () => T): T = { - JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { - deserializationAction() - } - } -} - -private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { - - def create(config: RpcEnvConfig): RpcEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - config.name, config.host, config.port, config.conf, config.securityManager) - actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.conf, boundPort) - } -} - -/** - * Monitor errors reported by Akka and log them. - */ -private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[Error]) - } - - override def receiveWithLogging: Actor.Receive = { - case Error(cause: Throwable, _, _, message: String) => logError(message, cause) - } -} - -private[akka] class AkkaRpcEndpointRef( - @transient private val defaultAddress: RpcAddress, - @transient private val _actorRef: () => ActorRef, - conf: SparkConf, - initInConstructor: Boolean) - extends RpcEndpointRef(conf) with Logging { - - def this( - defaultAddress: RpcAddress, - _actorRef: ActorRef, - conf: SparkConf) = { - this(defaultAddress, () => _actorRef, conf, true) - } - - lazy val actorRef = _actorRef() - - override lazy val address: RpcAddress = { - val akkaAddress = actorRef.path.address - RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), - akkaAddress.port.getOrElse(defaultAddress.port)) - } - - override lazy val name: String = actorRef.path.name - - private[akka] def init(): Unit = { - // Initialize the lazy vals - actorRef - address - name - } - - if (initInConstructor) { - init() - } - - override def send(message: Any): Unit = { - actorRef ! AkkaMessage(message, false) - } - - override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { - // The function will run in the calling thread, so it should be short and never block. - case msg @ AkkaMessage(message, reply) => - if (reply) { - logError(s"Receive $msg but the sender cannot reply") - Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) - } else { - Future.successful(message) - } - case AkkaFailure(e) => - Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T]. - recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) - } - - override def toString: String = s"${getClass.getSimpleName}($actorRef)" - - final override def equals(that: Any): Boolean = that match { - case other: AkkaRpcEndpointRef => actorRef == other.actorRef - case _ => false - } - - final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() -} - -/** - * A wrapper to `message` so that the receiver knows if the sender expects a reply. - * @param message - * @param needReply if the sender expects a reply message - */ -private[akka] case class AkkaMessage(message: Any, needReply: Boolean) - -/** - * A reply with the failure error from the receiver to the sender - */ -private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index eb25d6c7b721b..4f8fe018b432d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,14 +17,15 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.concurrent.Promise import scala.util.control.NonFatal -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils @@ -106,70 +107,61 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessage( - name, - _ => message, - () => { logWarning(s"Drop $message because $name has been stopped") }) + postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}")) } } /** Posts a message sent by a remote endpoint. */ def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new RemoteNettyRpcCallContext( - nettyEnv, sender, callback, message.senderAddress, message.needReply) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - callback.onFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } - - postMessage(message.receiver.name, createMessage, onEndpointStopped) + val rpcCallContext = + new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e)) } /** Posts a message sent by a local endpoint. */ def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - p.tryFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } + val rpcCallContext = + new LocalNettyRpcCallContext(message.senderAddress, p) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e)) + } - postMessage(message.receiver.name, createMessage, onEndpointStopped) + /** Posts a one-way message. */ + def postOneWayMessage(message: RequestMessage): Unit = { + postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), + (e) => throw e) } /** * Posts a message to a specific endpoint. * * @param endpointName name of the endpoint. - * @param createMessageFn function to create the message. + * @param message the message to post * @param callbackIfStopped callback function if the endpoint is stopped. */ private def postMessage( endpointName: String, - createMessageFn: NettyRpcEndpointRef => InboxMessage, - callbackIfStopped: () => Unit): Unit = { + message: InboxMessage, + callbackIfStopped: (Exception) => Unit): Unit = { val shouldCallOnStop = synchronized { val data = endpoints.get(endpointName) if (stopped || data == null) { true } else { - data.inbox.post(createMessageFn(data.ref)) + data.inbox.post(message) receivers.offer(data) false } } if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block - callbackIfStopped() + val error = if (stopped) { + new RpcEnvStoppedException() + } else { + new SparkException(s"Could not find $endpointName or it has been stopped.") + } + callbackIfStopped(error) } } @@ -201,7 +193,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { /** Thread pool used for dispatching messages. */ private val threadpool: ThreadPoolExecutor = { val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", - Runtime.getRuntime.availableProcessors()) + math.max(2, Runtime.getRuntime.availableProcessors())) val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") for (i <- 0 until numThreads) { pool.execute(new MessageLoop) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index c72b588db57fe..fffbd5cd44a23 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -21,18 +21,20 @@ import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal -import com.google.common.annotations.VisibleForTesting - -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} private[netty] sealed trait InboxMessage -private[netty] case class ContentMessage( +private[netty] case class OneWayMessage( + senderAddress: RpcAddress, + content: Any) extends InboxMessage + +private[netty] case class RpcMessage( senderAddress: RpcAddress, content: Any, - needReply: Boolean, context: NettyRpcCallContext) extends InboxMessage private[netty] case object OnStart extends InboxMessage @@ -98,29 +100,24 @@ private[netty] class Inbox( while (true) { safelyCall(endpoint) { message match { - case ContentMessage(_sender, content, needReply, context) => - // The partial function to call - val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive + case RpcMessage(_sender, content, context) => try { - pf.applyOrElse[Any, Unit](content, { msg => + endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg => throw new SparkException(s"Unsupported message $message from ${_sender}") }) - if (!needReply) { - context.finish() - } } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - context.sendFailure(e) - } else { - context.finish() - } + context.sendFailure(e) // Throw the exception -- this exception will be caught by the safelyCall function. // The endpoint's onError function will be called. throw e } + case OneWayMessage(_sender, content) => + endpoint.receive.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unsupported message $message from ${_sender}") + }) + case OnStart => endpoint.onStart() if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { @@ -193,8 +190,10 @@ private[netty] class Inbox( def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } - /** Called when we are dropping a message. Test cases override this to test message dropping. */ - @VisibleForTesting + /** + * Called when we are dropping a message. Test cases override this to test message dropping. + * Exposed for testing. + */ protected def onDrop(message: InboxMessage): Unit = { logWarning(s"Drop $message because $endpointRef is stopped") } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 21d5bb4923d1b..7dd7e610a28eb 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -19,53 +19,32 @@ package org.apache.spark.rpc.netty import scala.concurrent.Promise -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc.{RpcAddress, RpcCallContext} -private[netty] abstract class NettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, - override val senderAddress: RpcAddress, - needReply: Boolean) +private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress) extends RpcCallContext with Logging { protected def send(message: Any): Unit override def reply(response: Any): Unit = { - if (needReply) { - send(AskResponse(endpointRef, response)) - } else { - throw new IllegalStateException( - s"Cannot send $response to the sender because the sender does not expect a reply") - } + send(response) } override def sendFailure(e: Throwable): Unit = { - if (needReply) { - send(AskResponse(endpointRef, RpcFailure(e))) - } else { - logError(e.getMessage, e) - throw new IllegalStateException( - "Cannot send reply to the sender because the sender won't handle it") - } + send(RpcFailure(e)) } - def finish(): Unit = { - if (!needReply) { - send(Ack(endpointRef)) - } - } } /** * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. */ private[netty] class LocalNettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, senderAddress: RpcAddress, - needReply: Boolean, p: Promise[Any]) - extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + extends NettyRpcCallContext(senderAddress) { override protected def send(message: Any): Unit = { p.success(message) @@ -77,11 +56,9 @@ private[netty] class LocalNettyRpcCallContext( */ private[netty] class RemoteNettyRpcCallContext( nettyEnv: NettyRpcEnv, - endpointRef: NettyRpcEndpointRef, callback: RpcResponseCallback, - senderAddress: RpcAddress, - needReply: Boolean) - extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + senderAddress: RpcAddress) + extends NettyRpcCallContext(senderAddress) { override protected def send(message: Any): Unit = { val reply = nettyEnv.serialize(message) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 09093819bb22c..7f2192e1f5a70 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -17,22 +17,20 @@ package org.apache.spark.rpc.netty import java.io._ -import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer +import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy +import javax.annotation.Nullable -import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import scala.util.{DynamicVariable, Failure, Success} +import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal -import com.google.common.base.Preconditions -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.internal.Logging import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ import org.apache.spark.network.netty.SparkTransportConf @@ -48,26 +46,39 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = SparkTransportConf.fromSparkConf( - conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + private[netty] val transportConf = SparkTransportConf.fromSparkConf( + conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), + "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) + private val streamManager = new NettyStreamManager(this) + private val transportContext = new TransportContext(transportConf, - new NettyRpcHandler(dispatcher, this)) + new NettyRpcHandler(dispatcher, this, streamManager)) - private val clientFactory = { - val bootstraps: java.util.List[TransportClientBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) - } else { - java.util.Collections.emptyList[TransportClientBootstrap] - } - transportContext.createClientFactory(bootstraps) + private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } } + private val clientFactory = transportContext.createClientFactory(createClientBootstraps()) + + /** + * A separate client factory for file downloads. This avoids using the same RPC handler as + * the main RPC context, so that events caused by these clients are kept isolated from the + * main RPC traffic. + * + * It also allows for different configuration of certain properties, such as the number of + * connections per peer. + */ + @volatile private var fileDownloadFactory: TransportClientFactory = _ + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool @@ -104,7 +115,7 @@ private[netty] class NettyRpcEnv( } else { java.util.Collections.emptyList() } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(host, port, bootstraps) dispatcher.registerRpcEndpoint( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } @@ -139,7 +150,7 @@ private[netty] class NettyRpcEnv( private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = { if (receiver.client != null) { - receiver.client.sendRpc(message.content, message.createCallback(receiver.client)); + message.sendWith(receiver.client) } else { require(receiver.address != null, "Cannot send message to client endpoint with no listen address.") @@ -171,25 +182,14 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { // Message to a local RPC endpoint. - val promise = Promise[Any]() - dispatcher.postLocalMessage(message, promise) - promise.future.onComplete { - case Success(response) => - val ack = response.asInstanceOf[Ack] - logTrace(s"Received ack from ${ack.sender}") - case Failure(e) => - logWarning(s"Exception when sending $message", e) - }(ThreadUtils.sameThread) + try { + dispatcher.postOneWayMessage(message) + } catch { + case e: RpcEnvStoppedException => logWarning(e.getMessage) + } } else { // Message to a remote RPC endpoint. - postToOutbox(message.receiver, OutboxMessage(serialize(message), - (e) => { - logWarning(s"Exception when sending $message", e) - }, - (client, response) => { - val ack = deserialize[Ack](client, response) - logDebug(s"Receive ack from ${ack.sender}") - })) + postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) } } @@ -197,58 +197,66 @@ private[netty] class NettyRpcEnv( clientFactory.createClient(address.host, address.port) } - private[netty] def ask(message: RequestMessage): Future[Any] = { + private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = { val promise = Promise[Any]() val remoteAddr = message.receiver.address - if (remoteAddr == address) { - val p = Promise[Any]() - dispatcher.postLocalMessage(message, p) - p.future.onComplete { - case Success(response) => - val reply = response.asInstanceOf[AskResponse] - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - case Failure(e) => - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } + + def onFailure(e: Throwable): Unit = { + if (!promise.tryFailure(e)) { + logWarning(s"Ignored failure: $e") + } + } + + def onSuccess(reply: Any): Unit = reply match { + case RpcFailure(e) => onFailure(e) + case rpcReply => + if (!promise.trySuccess(rpcReply)) { + logWarning(s"Ignored message: $reply") + } + } + + try { + if (remoteAddr == address) { + val p = Promise[Any]() + p.future.onComplete { + case Success(response) => onSuccess(response) + case Failure(e) => onFailure(e) + }(ThreadUtils.sameThread) + dispatcher.postLocalMessage(message, p) + } else { + val rpcMessage = RpcOutboxMessage(serialize(message), + onFailure, + (client, response) => onSuccess(deserialize[Any](client, response))) + postToOutbox(message.receiver, rpcMessage) + promise.future.onFailure { + case _: TimeoutException => rpcMessage.onTimeout() + case _ => + }(ThreadUtils.sameThread) + } + + val timeoutCancelable = timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + promise.future.onComplete { v => + timeoutCancelable.cancel(true) }(ThreadUtils.sameThread) - } else { - postToOutbox(message.receiver, OutboxMessage(serialize(message), - (e) => { - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } - }, - (client, response) => { - val reply = deserialize[AskResponse](client, response) - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - })) + } catch { + case NonFatal(e) => + onFailure(e) } - promise.future + promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } - private[netty] def serialize(content: Any): Array[Byte] = { - val buffer = javaSerializerInstance.serialize(content) - java.util.Arrays.copyOfRange( - buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) + private[netty] def serialize(content: Any): ByteBuffer = { + javaSerializerInstance.serialize(content) } - private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = { + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => - javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + javaSerializerInstance.deserialize[T](bytes) } } } @@ -257,9 +265,6 @@ private[netty] class NettyRpcEnv( dispatcher.getRpcEndpointRef(endpoint) } - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new RpcEndpointAddress(address, endpointName).toString - override def shutdown(): Unit = { cleanup() } @@ -294,6 +299,9 @@ private[netty] class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } + if (fileDownloadFactory != null) { + fileDownloadFactory.close() + } } override def deserialize[T](deserializationAction: () => T): T = { @@ -302,6 +310,104 @@ private[netty] class NettyRpcEnv( } } + override def fileServer: RpcEnvFileServer = streamManager + + override def openChannel(uri: String): ReadableByteChannel = { + val parsedUri = new URI(uri) + require(parsedUri.getHost() != null, "Host name must be defined.") + require(parsedUri.getPort() > 0, "Port must be defined.") + require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") + + val pipe = Pipe.open() + val source = new FileDownloadChannel(pipe.source()) + try { + val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) + val callback = new FileDownloadCallback(pipe.sink(), source, client) + client.stream(parsedUri.getPath(), callback) + } catch { + case e: Exception => + pipe.sink().close() + source.close() + throw e + } + + source + } + + private def downloadClient(host: String, port: Int): TransportClient = { + if (fileDownloadFactory == null) synchronized { + if (fileDownloadFactory == null) { + val module = "files" + val prefix = "spark.rpc.io." + val clone = conf.clone() + + // Copy any RPC configuration that is not overridden in the spark.files namespace. + conf.getAll.foreach { case (key, value) => + if (key.startsWith(prefix)) { + val opt = key.substring(prefix.length()) + clone.setIfMissing(s"spark.$module.io.$opt", value) + } + } + + val ioThreads = clone.getInt("spark.files.io.threads", 1) + val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) + fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) + } + } + fileDownloadFactory.createClient(host, port) + } + + private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + + @volatile private var error: Throwable = _ + + def setError(e: Throwable): Unit = { + error = e + source.close() + } + + override def read(dst: ByteBuffer): Int = { + Try(source.read(dst)) match { + case Success(bytesRead) => bytesRead + case Failure(readErr) => + if (error != null) { + throw error + } else { + throw readErr + } + } + } + + override def close(): Unit = source.close() + + override def isOpen(): Boolean = source.isOpen() + + } + + private class FileDownloadCallback( + sink: WritableByteChannel, + source: FileDownloadChannel, + client: TransportClient) extends StreamCallback { + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.remaining() > 0) { + sink.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + sink.close() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + logDebug(s"Error downloading stream $streamId.", cause) + source.setError(cause) + sink.close() + } + + } + } private[netty] object NettyRpcEnv extends Logging { @@ -326,7 +432,7 @@ private[netty] object NettyRpcEnv extends Logging { } -private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { +private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { def create(config: RpcEnvConfig): RpcEnv = { val sparkConf = config.conf @@ -339,10 +445,10 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.startServer(actualPort) - (nettyEnv, actualPort) + (nettyEnv, nettyEnv.address.port) } try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1 } catch { case NonFatal(e) => nettyEnv.shutdown() @@ -372,7 +478,6 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { * @param conf Spark configuration. * @param endpointAddress The address where the endpoint is listening. * @param nettyEnv The RpcEnv associated with this ref. - * @param local Whether the referenced endpoint lives in the same process. */ private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, @@ -400,30 +505,17 @@ private[netty] class NettyRpcEndpointRef( override def name: String = _name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - val promise = Promise[Any]() - val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable { - override def run(): Unit = { - promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration)) - } - }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) - val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true)) - f.onComplete { v => - timeoutCancelable.cancel(true) - if (!promise.tryComplete(v)) { - logWarning(s"Ignore message $v") - } - }(ThreadUtils.sameThread) - promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) + nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout) } override def send(message: Any): Unit = { require(message != null, "Message is null") - nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false)) + nettyEnv.send(RequestMessage(nettyEnv.address, this, message)) } override def toString: String = s"NettyRpcEndpointRef(${_address})" - def toURI: URI = new URI(s"spark://${_address}") + def toURI: URI = new URI(_address.toString) final override def equals(that: Any): Boolean = that match { case other: NettyRpcEndpointRef => _address == other._address @@ -437,24 +529,7 @@ private[netty] class NettyRpcEndpointRef( * The message that is sent from the sender to the receiver. */ private[netty] case class RequestMessage( - senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean) - -/** - * The base trait for all messages that are sent back from the receiver to the sender. - */ -private[netty] trait ResponseMessage - -/** - * The reply for `ask` from the receiver side. - */ -private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any) - extends ResponseMessage - -/** - * A message to send back to the receiver side. It's necessary because [[TransportClient]] only - * clean the resources when it receives a reply. - */ -private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage + senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) /** * A response that indicates some failure happens in the receiver side. @@ -474,40 +549,60 @@ private[netty] case class RpcFailure(e: Throwable) * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( - dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + dispatcher: Dispatcher, + nettyEnv: NettyRpcEnv, + streamManager: StreamManager) extends RpcHandler with Logging { - // TODO: Can we add connection callback (channel registered) to the underlying framework? - // A variable to track whether we should dispatch the RemoteProcessConnected message. - private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() + // A variable to track the remote RpcEnv addresses of all clients + private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() override def receive( client: TransportClient, - message: Array[Byte], + message: ByteBuffer, callback: RpcResponseCallback): Unit = { + val messageToDispatch = internalReceive(client, message) + dispatcher.postRemoteMessage(messageToDispatch, callback) + } + + override def receive( + client: TransportClient, + message: ByteBuffer): Unit = { + val messageToDispatch = internalReceive(client, message) + dispatcher.postOneWayMessage(messageToDispatch) + } + + private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - if (clients.putIfAbsent(client, JBoolean.TRUE) == null) { - dispatcher.postToAll(RemoteProcessConnected(clientAddr)) - } val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) - val messageToDispatch = if (requestMessage.senderAddress == null) { - // Create a new message with the socket address of the client as the sender. - RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, - requestMessage.needReply) - } else { - requestMessage + if (requestMessage.senderAddress == null) { + // Create a new message with the socket address of the client as the sender. + RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) + } else { + // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for + // the listening address + val remoteEnvAddress = requestMessage.senderAddress + if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) { + dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) } - dispatcher.postRemoteMessage(messageToDispatch, callback) + requestMessage + } } - override def getStreamManager: StreamManager = new OneForOneStreamManager + override def getStreamManager: StreamManager = streamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessConnectionError for the remote RpcEnv listening address + val remoteEnvAddress = remoteAddresses.get(clientAddr) + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. // See java.net.Socket.getRemoteSocketAddress @@ -516,13 +611,25 @@ private[netty] class NettyRpcHandler( } } - override def connectionTerminated(client: TransportClient): Unit = { + override def channelActive(client: TransportClient): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + assert(addr != null) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + dispatcher.postToAll(RemoteProcessConnected(clientAddr)) + } + + override def channelInactive(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - clients.remove(client) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) + val remoteEnvAddress = remoteAddresses.remove(clientAddr) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessDisconnected for the remote RpcEnv listening address + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala new file mode 100644 index 0000000000000..afcb023a99daa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import java.io.File +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc.RpcEnvFileServer +import org.apache.spark.util.Utils + +/** + * StreamManager implementation for serving files from a NettyRpcEnv. + * + * Three kinds of resources can be registered in this manager, all backed by actual files: + * + * - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]]. + * - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]]. + * - arbitrary directories; all files under the directory become available through the manager, + * respecting the directory's hierarchy. + * + * Only streaming (openStream) is supported. + */ +private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) + extends StreamManager with RpcEnvFileServer { + + private val files = new ConcurrentHashMap[String, File]() + private val jars = new ConcurrentHashMap[String, File]() + private val dirs = new ConcurrentHashMap[String, File]() + + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + throw new UnsupportedOperationException() + } + + override def openStream(streamId: String): ManagedBuffer = { + val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) + val file = ftype match { + case "files" => files.get(fname) + case "jars" => jars.get(fname) + case other => + val dir = dirs.get(ftype) + require(dir != null, s"Invalid stream URI: $ftype not found.") + new File(dir, fname) + } + + if (file != null && file.isFile()) { + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } else { + null + } + } + + override def addFile(file: File): String = { + require(files.putIfAbsent(file.getName(), file) == null, + s"File ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}" + } + + override def addJar(file: File): String = { + require(jars.putIfAbsent(file.getName(), file) == null, + s"JAR ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}" + } + + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null, + s"URI '$fixedBaseUri' already registered.") + s"${rpcEnv.address.toSparkURL}$fixedBaseUri" + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 2f6817f2eb935..56499c639f292 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -17,31 +17,70 @@ package org.apache.spark.rpc.netty +import java.nio.ByteBuffer import java.util.concurrent.Callable import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcAddress, RpcEnvStoppedException} -private[netty] case class OutboxMessage(content: Array[Byte], - _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, Array[Byte]) => Unit) { +private[netty] sealed trait OutboxMessage { - def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() { - override def onFailure(e: Throwable): Unit = { - _onFailure(e) - } + def sendWith(client: TransportClient): Unit + + def onFailure(e: Throwable): Unit + +} + +private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage + with Logging { - override def onSuccess(response: Array[Byte]): Unit = { - _onSuccess(client, response) + override def sendWith(client: TransportClient): Unit = { + client.send(content) + } + + override def onFailure(e: Throwable): Unit = { + e match { + case e1: RpcEnvStoppedException => logWarning(e1.getMessage) + case e1: Throwable => logWarning(s"Failed to send one-way RPC.", e1) } } } +private[netty] case class RpcOutboxMessage( + content: ByteBuffer, + _onFailure: (Throwable) => Unit, + _onSuccess: (TransportClient, ByteBuffer) => Unit) + extends OutboxMessage with RpcResponseCallback { + + private var client: TransportClient = _ + private var requestId: Long = _ + + override def sendWith(client: TransportClient): Unit = { + this.client = client + this.requestId = client.sendRpc(content, this) + } + + def onTimeout(): Unit = { + require(client != null, "TransportClient has not yet been set.") + client.removeRpcRequest(requestId) + } + + override def onFailure(e: Throwable): Unit = { + _onFailure(e) + } + + override def onSuccess(response: ByteBuffer): Unit = { + _onSuccess(client, response) + } + +} + private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { outbox => // Give this an alias so we can use it more clearly in closures. @@ -82,7 +121,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { } } if (dropped) { - message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) } else { drainOutbox() } @@ -122,7 +161,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { try { val _client = synchronized { client } if (_client != null) { - _client.sendRpc(message.content, message.createCallback(_client)) + message.sendWith(_client) } else { assert(stopped == true) } @@ -195,7 +234,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message._onFailure(e) + message.onFailure(e) message = messages.poll() } assert(messages.isEmpty) @@ -229,7 +268,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) message = messages.poll() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 146cfb9ba8037..cedacad44afec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -19,47 +19,61 @@ package org.apache.spark.scheduler import org.apache.spark.annotation.DeveloperApi + /** * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. + * + * Note: once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. + * + * @param id accumulator ID + * @param name accumulator name + * @param update partial value from a task, may be None if used on driver to describe a stage + * @param value total accumulated value so far, maybe None if used on executors to describe a task + * @param internal whether this accumulator was internal + * @param countFailedValues whether to count this accumulator's partial value if the task failed + * @param metadata internal metadata associated with this accumulator, if any */ @DeveloperApi -class AccumulableInfo private[spark] ( - val id: Long, - val name: String, - val update: Option[String], // represents a partial update within a task - val value: String, - val internal: Boolean) { - - override def equals(other: Any): Boolean = other match { - case acc: AccumulableInfo => - this.id == acc.id && this.name == acc.name && - this.update == acc.update && this.value == acc.value && - this.internal == acc.internal - case _ => false - } +case class AccumulableInfo private[spark] ( + id: Long, + name: Option[String], + update: Option[Any], // represents a partial update within a task + value: Option[Any], + private[spark] val internal: Boolean, + private[spark] val countFailedValues: Boolean, + // TODO: use this to identify internal task metrics instead of encoding it in the name + private[spark] val metadata: Option[String] = None) - override def hashCode(): Int = { - val state = Seq(id, name, update, value, internal) - state.map(_.hashCode).reduceLeft(31 * _ + _) - } -} +/** + * A collection of deprecated constructors. This will be removed soon. + */ object AccumulableInfo { + + @deprecated("do not create AccumulableInfo", "2.0.0") def apply( id: Long, name: String, update: Option[String], value: String, internal: Boolean): AccumulableInfo = { - new AccumulableInfo(id, name, update, value, internal) + new AccumulableInfo( + id, Option(name), update, Option(value), internal, countFailedValues = false) } + @deprecated("do not create AccumulableInfo", "2.0.0") def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo(id, name, update, value, internal = false) + new AccumulableInfo( + id, Option(name), update, Option(value), internal = false, countFailedValues = false) } + @deprecated("do not create AccumulableInfo", "2.0.0") def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo(id, name, None, value, internal = false) + new AccumulableInfo( + id, Option(name), None, Option(value), internal = false, countFailedValues = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index a3d2db31301b3..949e88f606275 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -19,7 +19,6 @@ package org.apache.spark.scheduler import java.util.Properties -import org.apache.spark.TaskContext import org.apache.spark.util.CallSite /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a1f0fd05f661a..c27aad268d32a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -22,6 +22,7 @@ import java.util.Properties import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger +import scala.annotation.tailrec import scala.collection.Map import scala.collection.mutable.{HashMap, HashSet, Stack} import scala.concurrent.duration._ @@ -34,12 +35,14 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ -import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.util._ /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -130,7 +133,7 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) - private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() @@ -206,11 +209,10 @@ class DAGScheduler( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics): Unit = { + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo): Unit = { eventProcessLoop.post( - CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) + CompletionEvent(task, reason, result, accumUpdates, taskInfo)) } /** @@ -220,9 +222,10 @@ class DAGScheduler( */ def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) + // (taskId, stageId, stageAttemptId, accumUpdates) + accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { - listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } @@ -467,6 +470,7 @@ class DAGScheduler( * all of that stage's ancestors. */ private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { + @tailrec def updateJobIdStageIdMapsList(stages: List[Stage]) { if (stages.nonEmpty) { val s = stages.head @@ -541,8 +545,7 @@ class DAGScheduler( } /** - * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object - * can be used to block until the the job finishes executing or can be used to cancel the job. + * Submit an action job to the scheduler. * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD @@ -551,6 +554,11 @@ class DAGScheduler( * @param callSite where in the user program this job was called * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @return a JobWaiter object that can be used to block until the job finishes executing + * or can be used to cancel the job. + * + * @throws IllegalArgumentException when partitions ids are illegal */ def submitJob[T, U]( rdd: RDD[T], @@ -584,7 +592,7 @@ class DAGScheduler( /** * Run an action job on the given RDD and pass all the results to the resultHandler function as - * they arrive. Throws an exception if the job fials, or returns normally if successful. + * they arrive. * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD @@ -593,6 +601,8 @@ class DAGScheduler( * @param callSite where in the user program this job was called * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @throws Exception when the job fails */ def runJob[T, U]( rdd: RDD[T], @@ -603,11 +613,17 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - waiter.awaitResult() match { - case JobSucceeded => + // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`, + // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that + // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's + // safe to pass in null here. For more detail, see SPARK-13747. + val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] + waiter.completionFuture.ready(Duration.Inf)(awaitPermission) + waiter.completionFuture.value.get match { + case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - case JobFailed(exception: Exception) => + case scala.util.Failure(exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. @@ -646,7 +662,7 @@ class DAGScheduler( /** * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter - * can be used to block until the the job finishes executing or can be used to cancel the job. + * can be used to block until the job finishes executing or can be used to cancel the job. * This method is used for adaptive query planning, to run map stages and look at statistics * about their outputs before submitting downstream stages. * @@ -738,7 +754,7 @@ class DAGScheduler( } /** - * Check for waiting or failed stages which are now eligible for resubmission. + * Check for waiting stages which are now eligible for resubmission. * Ordinarily run on every iteration of the event loop. */ private def submitWaitingStages() { @@ -796,7 +812,8 @@ class DAGScheduler( private[scheduler] def cleanUpAfterSchedulerStop() { for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") + val error = + new SparkException(s"Job ${job.jobId} cancelled because SparkContext was shut down") job.listener.jobFailed(error) // Tell the listeners that all of the running stages have ended. Don't bother // cancelling the stages because if the DAG scheduler is stopped, the entire application @@ -933,14 +950,9 @@ class DAGScheduler( // First figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() - // Create internal accumulators if the stage has no accumulators initialized. - // Reset internal accumulators only if this stage is not partially submitted - // Otherwise, we may override existing accumulator values from some tasks - if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) { - stage.resetInternalAccumulators() - } - - val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull + // Use the scheduling pool, job group, description, etc. from an ActiveJob associated + // with this Stage + val properties = jobIdToActiveJob(jobId).properties runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -969,7 +981,7 @@ class DAGScheduler( case NonFatal(e) => stage.makeNewStageAttempt(partitionsToCompute.size) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) + abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return } @@ -989,9 +1001,10 @@ class DAGScheduler( // For ResultTask, serialize and broadcast (rdd, func). val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) } taskBinary = sc.broadcast(taskBinaryBytes) @@ -1004,7 +1017,7 @@ class DAGScheduler( // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) + abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return } @@ -1016,7 +1029,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.internalAccumulators) + taskBinary, part, locs, stage.latestInfo.internalAccumulators, properties) } case stage: ResultStage => @@ -1026,12 +1039,12 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, stage.internalAccumulators) + taskBinary, part, locs, id, properties, stage.latestInfo.internalAccumulators) } } } catch { case NonFatal(e) => - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) + abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e)) runningStages -= stage return } @@ -1041,7 +1054,7 @@ class DAGScheduler( stage.pendingPartitions ++= tasks.map(_.partitionId) logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -1061,39 +1074,40 @@ class DAGScheduler( } } - /** Merge updates from a task to our local accumulator values */ + /** + * Merge local values from a task into the corresponding accumulators previously registered + * here on the driver. + * + * Although accumulators themselves are not thread-safe, this method is called only from one + * thread, the one that runs the scheduling loop. This means we only handle one task + * completion event at a time so we don't need to worry about locking the accumulators. + * This still doesn't stop the caller from updating the accumulator outside the scheduler, + * but that's not our problem since there's nothing we can do about that. + */ private def updateAccumulators(event: CompletionEvent): Unit = { val task = event.task val stage = stageIdToStage(task.stageId) - if (event.accumUpdates != null) { - try { - Accumulators.add(event.accumUpdates) - - event.accumUpdates.foreach { case (id, partialValue) => - // In this instance, although the reference in Accumulators.originals is a WeakRef, - // it's guaranteed to exist since the event.accumUpdates Map exists - - val acc = Accumulators.originals(id).get match { - case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] - case None => throw new NullPointerException("Non-existent reference to Accumulator") - } - - // To avoid UI cruft, ignore cases where value wasn't updated - if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name.get - val value = s"${acc.value}" - stage.latestInfo.accumulables(id) = - new AccumulableInfo(id, name, None, value, acc.isInternal) - event.taskInfo.accumulables += - new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal) - } + try { + event.accumUpdates.foreach { ainfo => + assert(ainfo.update.isDefined, "accumulator from task should have a partial value") + val id = ainfo.id + val partialValue = ainfo.update.get + // Find the corresponding accumulator on the driver and update it + val acc: Accumulable[Any, Any] = Accumulators.get(id) match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => + throw new SparkException(s"attempted to access non-existent accumulator $id") + } + acc ++= partialValue + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) + event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value)) } - } catch { - // If we see an exception during accumulator update, just log the - // error and move on. - case e: Exception => - logError(s"Failed to update accumulators for $task", e) } + } catch { + case NonFatal(e) => + logError(s"Failed to update accumulators for task ${task.partitionId}", e) } } @@ -1103,6 +1117,7 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task + val taskId = event.taskInfo.id val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) @@ -1112,13 +1127,27 @@ class DAGScheduler( event.taskInfo.attemptNumber, // this is a task attempt number event.reason) - // The success case is dealt with separately below, since we need to compute accumulator - // updates before posting. - if (event.reason != Success) { - val attemptId = task.stageAttemptId - listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, - event.taskInfo, event.taskMetrics)) - } + // Reconstruct task metrics. Note: this may be null if the task has failed. + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulatorUpdates(event.accumUpdates) + } catch { + case NonFatal(e) => + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + // The stage may have already finished when we get this event -- eg. maybe it was a + // speculative task. It is important that we send the TaskEnd event in any case, so listeners + // are properly notified and can chose to handle it. For instance, some listeners are + // doing their own accounting and if they don't get the task end event they think + // tasks are still running when they really aren't. + listenerBus.post(SparkListenerTaskEnd( + stageId, task.stageAttemptId, taskType, event.reason, event.taskInfo, taskMetrics)) if (!stageIdToStage.contains(task.stageId)) { // Skip all the actions if the stage has been cancelled. @@ -1128,8 +1157,6 @@ class DAGScheduler( val stage = stageIdToStage(task.stageId) event.reason match { case Success => - listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, - event.reason, event.taskInfo, event.taskMetrics)) stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => @@ -1278,12 +1305,13 @@ class DAGScheduler( // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case exceptionFailure: ExceptionFailure => - // Do nothing here, left up to the TaskScheduler to decide how to handle user failures + // Tasks failed with exceptions might still have accumulator updates. + updateAccumulators(event) case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case other => + case _: ExecutorLostFailure | TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } @@ -1568,14 +1596,11 @@ class DAGScheduler( } def stop() { - logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() } - // Start the event thread and register the metrics source at the end of the constructor - env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } @@ -1627,7 +1652,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case GettingResultEvent(taskInfo) => dagScheduler.handleGetTaskResult(taskInfo) - case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => + case completion: CompletionEvent => dagScheduler.handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason, exception) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index dda3b6cc7f960..a3845c6acd774 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,11 +19,9 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.Map import scala.language.existentials import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite @@ -73,9 +71,8 @@ private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo) extends DAGSchedulerEvent private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 000a021a528cf..a7d06391176d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -19,19 +19,20 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI +import java.nio.charset.StandardCharsets import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.base.Charsets import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} import org.apache.hadoop.fs.permission.FsPermission import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION} +import org.apache.spark.{SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -77,14 +78,6 @@ private[spark] class EventLoggingListener( // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None - // The Hadoop APIs have changed over time, so we use reflection to figure out - // the correct method to use to flush a hadoop data stream. See SPARK-1518 - // for details. - private val hadoopFlushMethod = { - val cls = classOf[FSDataOutputStream] - scala.util.Try(cls.getMethod("hflush")).getOrElse(cls.getMethod("sync")) - } - private var writer: Option[PrintWriter] = None // For testing. Keep track of all JSON serialized events that have been logged. @@ -97,7 +90,7 @@ private[spark] class EventLoggingListener( * Creates the log file in the configured log directory. */ def start() { - if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDir) { + if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDirectory) { throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.") } @@ -147,7 +140,7 @@ private[spark] class EventLoggingListener( // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) - hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) + hadoopDataStream.foreach(_.hflush()) } if (testing) { loggedEvents += eventJson @@ -207,6 +200,12 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { + if (event.logEvent) { + logEvent(event, flushLogger = true) + } + } + /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. @@ -226,6 +225,13 @@ private[spark] class EventLoggingListener( } } fileSystem.rename(new Path(logPath + IN_PROGRESS), target) + // touch file to ensure modtime is current across those filesystems where rename() + // does not set it, -and which support setTimes(); it's a no-op on most object stores + try { + fileSystem.setTimes(target, System.currentTimeMillis(), -1) + } catch { + case e: Exception => logDebug(s"failed to set time of $target", e) + } } } @@ -234,8 +240,6 @@ private[spark] object EventLoggingListener extends Logging { // Suffix applied to the names of files still being written by applications. val IN_PROGRESS = ".inprogress" val DEFAULT_LOG_DIR = "/tmp/spark-events" - val SPARK_VERSION_KEY = "SPARK_VERSION" - val COMPRESSION_CODEC_KEY = "COMPRESSION_CODEC" private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) @@ -251,7 +255,7 @@ private[spark] object EventLoggingListener extends Logging { def initEventLog(logStream: OutputStream): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" - logStream.write(metadataJson.getBytes(Charsets.UTF_8)) + logStream.write(metadataJson.getBytes(StandardCharsets.UTF_8)) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 47a5cbff4930b..7e1197d742802 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -40,6 +40,8 @@ private[spark] object ExecutorExited { } } +private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed by driver.") + /** * A loss reason that means we don't yet know why the executor exited. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 0e438ab4366d9..a6b032cc0084c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -26,9 +26,9 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.util.ReflectionUtils -import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging /** * :: DeveloperApi :: @@ -57,11 +57,10 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl // Since we are not doing canonicalization of path, this can be wrong : like relative vs // absolute path .. which is fine, this is best case effort to remove duplicates - right ? override def equals(other: Any): Boolean = other match { - case that: InputFormatInfo => { + case that: InputFormatInfo => // not checking config - that should be fine, right ? this.inputFormatClazz == that.inputFormatClazz && this.path == that.path - } case _ => false } @@ -86,10 +85,9 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl } } catch { - case e: ClassNotFoundException => { + case e: ClassNotFoundException => throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) - } } } @@ -103,7 +101,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ org.apache.hadoop.mapreduce.InputFormat[_, _]] - val job = new Job(conf) + val job = Job.getInstance(conf) val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) @@ -157,7 +155,7 @@ object InputFormatInfo { b) Decrement the currently allocated containers on that host. c) Compute rack info for each host and update rack -> count map based on (b). d) Allocate nodes based on (c) - e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + e) On the allocation result, ensure that we don't allocate "too many" jobs on a single node (even if data locality on that is very high) : this is to prevent fragility of job if a single (or small set of) hosts go down. @@ -173,7 +171,7 @@ object InputFormatInfo { for (inputSplit <- formats) { val splits = inputSplit.findPreferredLocations() - for (split <- splits){ + for (split <- splits) { val location = split.hostLocation val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo]) set += split diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala index 50c2b9acd609c..e0f7c8f02132d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala @@ -23,6 +23,6 @@ package org.apache.spark.scheduler * job fails (and no further taskSucceeded events will happen). */ private[spark] trait JobListener { - def taskSucceeded(index: Int, result: Any) - def jobFailed(exception: Exception) + def taskSucceeded(index: Int, result: Any): Unit + def jobFailed(exception: Exception): Unit } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala deleted file mode 100644 index f96eb8ca0ae00..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import java.io.{File, FileNotFoundException, IOException, PrintWriter} -import java.text.SimpleDateFormat -import java.util.{Date, Properties} - -import scala.collection.mutable.HashMap - -import org.apache.spark._ -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics - -/** - * :: DeveloperApi :: - * A logger class to record runtime information for jobs in Spark. This class outputs one log file - * for each Spark job, containing tasks start/stop and shuffle information. JobLogger is a subclass - * of SparkListener, use addSparkListener to add JobLogger to a SparkContext after the SparkContext - * is created. Note that each JobLogger only works for one SparkContext - * - * NOTE: The functionality of this class is heavily stripped down to accommodate for a general - * refactor of the SparkListener interface. In its place, the EventLoggingListener is introduced - * to log application information as SparkListenerEvents. To enable this functionality, set - * spark.eventLog.enabled to true. - */ -@DeveloperApi -@deprecated("Log application information by setting spark.eventLog.enabled.", "1.0.0") -class JobLogger(val user: String, val logDirName: String) extends SparkListener with Logging { - - def this() = this(System.getProperty("user.name", ""), - String.valueOf(System.currentTimeMillis())) - - private val logDir = - if (System.getenv("SPARK_LOG_DIR") != null) { - System.getenv("SPARK_LOG_DIR") - } else { - "/tmp/spark-%s".format(user) - } - - private val jobIdToPrintWriter = new HashMap[Int, PrintWriter] - private val stageIdToJobId = new HashMap[Int, Int] - private val jobIdToStageIds = new HashMap[Int, Seq[Int]] - private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - } - - createLogDir() - - /** Create a folder for log files, the folder's name is the creation time of jobLogger */ - protected def createLogDir() { - val dir = new File(logDir + "/" + logDirName + "/") - if (dir.exists()) { - return - } - if (!dir.mkdirs()) { - // JobLogger should throw a exception rather than continue to construct this object. - throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/") - } - } - - /** - * Create a log file for one job - * @param jobId ID of the job - * @throws FileNotFoundException Fail to create log file - */ - protected def createLogWriter(jobId: Int) { - try { - val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobId) - jobIdToPrintWriter += (jobId -> fileWriter) - } catch { - case e: FileNotFoundException => e.printStackTrace() - } - } - - /** - * Close log file, and clean the stage relationship in stageIdToJobId - * @param jobId ID of the job - */ - protected def closeLogWriter(jobId: Int) { - jobIdToPrintWriter.get(jobId).foreach { fileWriter => - fileWriter.close() - jobIdToStageIds.get(jobId).foreach(_.foreach { stageId => - stageIdToJobId -= stageId - }) - jobIdToPrintWriter -= jobId - jobIdToStageIds -= jobId - } - } - - /** - * Build up the maps that represent stage-job relationships - * @param jobId ID of the job - * @param stageIds IDs of the associated stages - */ - protected def buildJobStageDependencies(jobId: Int, stageIds: Seq[Int]) = { - jobIdToStageIds(jobId) = stageIds - stageIds.foreach { stageId => stageIdToJobId(stageId) = jobId } - } - - /** - * Write info into log file - * @param jobId ID of the job - * @param info Info to be recorded - * @param withTime Controls whether to record time stamp before the info, default is true - */ - protected def jobLogInfo(jobId: Int, info: String, withTime: Boolean = true) { - var writeInfo = info - if (withTime) { - val date = new Date(System.currentTimeMillis()) - writeInfo = dateFormat.get.format(date) + ": " + info - } - // scalastyle:off println - jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) - // scalastyle:on println - } - - /** - * Write info into log file - * @param stageId ID of the stage - * @param info Info to be recorded - * @param withTime Controls whether to record time stamp before the info, default is true - */ - protected def stageLogInfo(stageId: Int, info: String, withTime: Boolean = true) { - stageIdToJobId.get(stageId).foreach(jobId => jobLogInfo(jobId, info, withTime)) - } - - /** - * Record task metrics into job log files, including execution info and shuffle metrics - * @param stageId Stage ID of the task - * @param status Status info of the task - * @param taskInfo Task description info - * @param taskMetrics Task running metrics - */ - protected def recordTaskMetrics(stageId: Int, status: String, - taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageId + - " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + - " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname - val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime - val gcTime = " GC_TIME=" + taskMetrics.jvmGCTime - val inputMetrics = taskMetrics.inputMetrics match { - case Some(metrics) => - " READ_METHOD=" + metrics.readMethod.toString + - " INPUT_BYTES=" + metrics.bytesRead - case None => "" - } - val outputMetrics = taskMetrics.outputMetrics match { - case Some(metrics) => - " OUTPUT_BYTES=" + metrics.bytesWritten - case None => "" - } - val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match { - case Some(metrics) => - " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + - " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + - " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + - " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + - " LOCAL_BYTES_READ=" + metrics.localBytesRead - case None => "" - } - val writeMetrics = taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => - " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + - " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime - case None => "" - } - stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics + - shuffleReadMetrics + writeMetrics) - } - - /** - * When stage is submitted, record stage submit info - * @param stageSubmitted Stage submitted event - */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { - val stageInfo = stageSubmitted.stageInfo - stageLogInfo(stageInfo.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format( - stageInfo.stageId, stageInfo.numTasks)) - } - - /** - * When stage is completed, record stage completion status - * @param stageCompleted Stage completed event - */ - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { - val stageId = stageCompleted.stageInfo.stageId - if (stageCompleted.stageInfo.failureReason.isEmpty) { - stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED") - } else { - stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED") - } - } - - /** - * When task ends, record task completion status and metrics - * @param taskEnd Task end event - */ - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val taskInfo = taskEnd.taskInfo - var taskStatus = "TASK_TYPE=%s".format(taskEnd.taskType) - val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty - taskEnd.reason match { - case Success => taskStatus += " STATUS=SUCCESS" - recordTaskMetrics(taskEnd.stageId, taskStatus, taskInfo, taskMetrics) - case Resubmitted => - taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + - " STAGE_ID=" + taskEnd.stageId - stageLogInfo(taskEnd.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) => - taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + - taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + - mapId + " REDUCE_ID=" + reduceId - stageLogInfo(taskEnd.stageId, taskStatus) - case _ => - } - } - - /** - * When job ends, recording job completion status and close log file - * @param jobEnd Job end event - */ - override def onJobEnd(jobEnd: SparkListenerJobEnd) { - val jobId = jobEnd.jobId - var info = "JOB_ID=" + jobId - jobEnd.jobResult match { - case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception) => - info += " STATUS=FAILED REASON=" - exception.getMessage.split("\\s+").foreach(info += _ + "_") - case _ => - } - jobLogInfo(jobId, info.substring(0, info.length - 1).toUpperCase) - closeLogWriter(jobId) - } - - /** - * Record job properties into job log file - * @param jobId ID of the job - * @param properties Properties of the job - */ - protected def recordJobProperties(jobId: Int, properties: Properties) { - if (properties != null) { - val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") - jobLogInfo(jobId, description, withTime = false) - } - } - - /** - * When job starts, record job property and stage graph - * @param jobStart Job start event - */ - override def onJobStart(jobStart: SparkListenerJobStart) { - val jobId = jobStart.jobId - val properties = jobStart.properties - createLogWriter(jobId) - recordJobProperties(jobId, properties) - buildJobStageDependencies(jobId, jobStart.stageIds) - jobLogInfo(jobId, "JOB_ID=" + jobId + " STATUS=STARTED") - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala index 4cd6cbe189aab..4a304a078d658 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala @@ -29,5 +29,4 @@ sealed trait JobResult @DeveloperApi case object JobSucceeded extends JobResult -@DeveloperApi private[spark] case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 382b09422a4a0..9012289f047c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,6 +17,12 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{Future, Promise} + +import org.apache.spark.internal.Logging + /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. @@ -26,19 +32,17 @@ private[spark] class JobWaiter[T]( val jobId: Int, totalTasks: Int, resultHandler: (Int, T) => Unit) - extends JobListener { - - private var finishedTasks = 0 - - // Is the job as a whole finished (succeeded or failed)? - @volatile - private var _jobFinished = totalTasks == 0 - - def jobFinished: Boolean = _jobFinished + extends JobListener with Logging { + private val finishedTasks = new AtomicInteger(0) // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. - private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + private val jobPromise: Promise[Unit] = + if (totalTasks == 0) Promise.successful(()) else Promise() + + def jobFinished: Boolean = jobPromise.isCompleted + + def completionFuture: Future[Unit] = jobPromise.future /** * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled @@ -49,29 +53,20 @@ private[spark] class JobWaiter[T]( dagScheduler.cancelJob(jobId) } - override def taskSucceeded(index: Int, result: Any): Unit = synchronized { - if (_jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + override def taskSucceeded(index: Int, result: Any): Unit = { + // resultHandler call must be synchronized in case resultHandler itself is not thread safe. + synchronized { + resultHandler(index, result.asInstanceOf[T]) } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - _jobFinished = true - jobResult = JobSucceeded - this.notifyAll() + if (finishedTasks.incrementAndGet() == totalTasks) { + jobPromise.success(()) } } - override def jobFailed(exception: Exception): Unit = synchronized { - _jobFinished = true - jobResult = JobFailed(exception) - this.notifyAll() - } - - def awaitResult(): JobResult = synchronized { - while (!_jobFinished) { - this.wait() + override def jobFailed(exception: Exception): Unit = { + if (!jobPromise.tryFailure(exception)) { + logWarning("Ignore failure", exception) } - return jobResult } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index be23056e7d423..1c21313d1cb17 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,24 +17,169 @@ package org.apache.spark.scheduler +import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.util.AsynchronousListenerBus +import scala.util.DynamicVariable + +import org.apache.spark.SparkContext +import org.apache.spark.util.Utils /** * Asynchronously passes SparkListenerEvents to registered SparkListeners. * - * Until start() is called, all posted events are only buffered. Only after this listener bus + * Until `start()` is called, all posted events are only buffered. Only after this listener bus * has started will events be actually propagated to all attached listeners. This listener bus - * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). + * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus - extends AsynchronousListenerBus[SparkListener, SparkListenerEvent]("SparkListenerBus") - with SparkListenerBus { +private[spark] class LiveListenerBus extends SparkListenerBus { + + self => + + import LiveListenerBus._ + + private var sparkContext: SparkContext = null + + // Cap the capacity of the event queue so we get an explicit error (rather than + // an OOM exception) if it's perpetually being added to more quickly than it's being drained. + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + + // Indicate if `start()` is called + private val started = new AtomicBoolean(false) + // Indicate if `stop()` is called + private val stopped = new AtomicBoolean(false) + + // Indicate if we are processing some event + // Guarded by `self` + private var processingEvent = false private val logDroppedEvent = new AtomicBoolean(false) - override def onDropEvent(event: SparkListenerEvent): Unit = { + // A counter that represents the number of events produced and consumed in the queue + private val eventLock = new Semaphore(0) + + private val listenerThread = new Thread(name) { + setDaemon(true) + override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { + LiveListenerBus.withinListenerThread.withValue(true) { + while (true) { + eventLock.acquire() + self.synchronized { + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } + } + } + } + } + } + + /** + * Start sending events to attached listeners. + * + * This first sends out all buffered events posted before this listener bus has started, then + * listens for any additional events asynchronously while the listener bus is still running. + * This should only be called once. + * + * @param sc Used to stop the SparkContext in case the listener thread dies. + */ + def start(sc: SparkContext): Unit = { + if (started.compareAndSet(false, true)) { + sparkContext = sc + listenerThread.start() + } else { + throw new IllegalStateException(s"$name already started!") + } + } + + def post(event: SparkListenerEvent): Unit = { + if (stopped.get) { + // Drop further events to make `listenerThread` exit ASAP + logError(s"$name has already stopped! Dropping event $event") + return + } + val eventAdded = eventQueue.offer(event) + if (eventAdded) { + eventLock.release() + } else { + onDropEvent(event) + } + } + + /** + * For testing only. Wait until there are no more events in the queue, or until the specified + * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue + * emptied. + * Exposed for testing. + */ + @throws(classOf[TimeoutException]) + def waitUntilEmpty(timeoutMillis: Long): Unit = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!queueIsEmpty) { + if (System.currentTimeMillis > finishTime) { + throw new TimeoutException( + s"The event queue is not empty after $timeoutMillis milliseconds") + } + /* Sleep rather than using wait/notify, because this is used only for testing and + * wait/notify add overhead in the general case. */ + Thread.sleep(10) + } + } + + /** + * For testing only. Return whether the listener daemon thread is still alive. + * Exposed for testing. + */ + def listenerThreadIsAlive: Boolean = listenerThread.isAlive + + /** + * Return whether the event queue is empty. + * + * The use of synchronized here guarantees that all events that once belonged to this queue + * have already been processed by all attached listeners, if this returns true. + */ + private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } + + /** + * Stop the listener bus. It will wait until the queued events have been processed, but drop the + * new events after stopping. + */ + def stop(): Unit = { + if (!started.get()) { + throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + } + if (stopped.compareAndSet(false, true)) { + // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know + // `stop` is called. + eventLock.release() + listenerThread.join() + } else { + // Keep quiet + } + } + + /** + * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be + * notified with the dropped events. + * + * Note: `onDropEvent` can be called in any thread. + */ + def onDropEvent(event: SparkListenerEvent): Unit = { if (logDroppedEvent.compareAndSet(false, true)) { // Only log the following message once to avoid duplicated annoying logs. logError("Dropping SparkListenerEvent because no remaining room in event queue. " + @@ -42,5 +187,13 @@ private[spark] class LiveListenerBus "the rate at which tasks are being started by the scheduler.") } } +} + +private[spark] object LiveListenerBus { + // Allows for Context to check whether stop() call is made within listener thread + val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) + /** The thread name of Spark listener bus */ + val name = "SparkListenerBus" } + diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 180c8d1827e13..b2e9a97129f08 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,8 +19,9 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} +import org.roaringbitmap.RoaringBitmap + import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.collection.BitSet import org.apache.spark.util.Utils /** @@ -121,8 +122,7 @@ private[spark] class CompressedMapStatus( /** * A [[MapStatus]] implementation that only stores the average size of non-empty blocks, - * plus a bitmap for tracking which blocks are empty. During serialization, this bitmap - * is compressed. + * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed * @param numNonEmptyBlocks the number of non-empty blocks @@ -132,7 +132,7 @@ private[spark] class CompressedMapStatus( private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, - private[this] var emptyBlocks: BitSet, + private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long) extends MapStatus with Externalizable { @@ -145,7 +145,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def location: BlockManagerId = loc override def getSizeForBlock(reduceId: Int): Long = { - if (emptyBlocks.get(reduceId)) { + if (emptyBlocks.contains(reduceId)) { 0 } else { avgSize @@ -160,7 +160,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) - emptyBlocks = new BitSet + emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() } @@ -176,15 +176,15 @@ private[spark] object HighlyCompressedMapStatus { // From a compression standpoint, it shouldn't matter whether we track empty or non-empty // blocks. From a performance standpoint, we benefit from tracking empty blocks because // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. + val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length - val emptyBlocks = new BitSet(totalNumBlocks) while (i < totalNumBlocks) { var size = uncompressedSizes(i) if (size > 0) { numNonEmptyBlocks += 1 totalSize += size } else { - emptyBlocks.set(i) + emptyBlocks.add(i) } i += 1 } @@ -193,6 +193,8 @@ private[spark] object HighlyCompressedMapStatus { } else { 0 } + emptyBlocks.trim() + emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 4d146678174f6..2dd453cd63973 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -20,7 +20,8 @@ package org.apache.spark.scheduler import scala.collection.mutable import org.apache.spark._ -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} private sealed trait OutputCommitCoordinationMessage extends Serializable diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 551e39a81b695..4cd13e2feaeb1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index c6d957b65f3fb..d32f5eb7bfe92 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -24,7 +24,7 @@ import scala.io.Source import com.fasterxml.jackson.core.JsonParseException import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.util.JsonProtocol /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index fb693721a9cb6..db6276f75d781 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,9 +17,9 @@ package org.apache.spark.scheduler -import java.nio.ByteBuffer - import java.io._ +import java.nio.ByteBuffer +import java.util.Properties import org.apache.spark._ import org.apache.spark.broadcast.Broadcast @@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each * partition of the given RDD. Once deserialized, the type should be * (RDD[T], (TaskContext, Iterator[T]) => U). @@ -38,6 +39,10 @@ import org.apache.spark.rdd.RDD * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). + * @param localProperties copy of thread-local properties set by the user on the driver side. + * @param _initialAccums initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. */ private[spark] class ResultTask[T, U]( stageId: Int, @@ -46,8 +51,9 @@ private[spark] class ResultTask[T, U]( partition: Partition, locs: Seq[TaskLocation], val outputId: Int, - internalAccumulators: Seq[Accumulator[Long]]) - extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + localProperties: Properties, + _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll()) + extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 6c5827f75e636..100ed76ecb6d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -22,7 +22,8 @@ import java.util.{NoSuchElementException, Properties} import scala.xml.XML -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** @@ -33,9 +34,9 @@ import org.apache.spark.util.Utils private[spark] trait SchedulableBuilder { def rootPool: Pool - def buildPools() + def buildPools(): Unit - def addTaskSetManager(manager: Schedulable, properties: Properties) + def addTaskSetManager(manager: Schedulable, properties: Properties): Unit } private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index f478f9982afef..b7cab7013ef6f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -18,25 +18,32 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.Properties import scala.language.existentials import org.apache.spark._ import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter /** -* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner -* specified in the ShuffleDependency). -* -* See [[org.apache.spark.scheduler.Task]] for more information. -* + * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner + * specified in the ShuffleDependency). + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling + * @param _initialAccums initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. + * @param localProperties copy of thread-local properties set by the user on the driver side. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -44,13 +51,14 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - internalAccumulators: Seq[Accumulator[Long]]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + _initialAccums: Seq[Accumulator[_]], + localProperties: Properties) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 896f1743332f1..080ea6c33a7dd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -18,19 +18,28 @@ package org.apache.spark.scheduler import java.util.Properties +import javax.annotation.Nullable import scala.collection.Map import scala.collection.mutable -import org.apache.spark.{Logging, TaskEndReason} +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.{SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} +import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Distribution, Utils} @DeveloperApi -sealed trait SparkListenerEvent +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent { + /* Whether output this event to the event log */ + protected[spark] def logEvent: Boolean = true +} @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -53,7 +62,8 @@ case class SparkListenerTaskEnd( taskType: String, reason: TaskEndReason, taskInfo: TaskInfo, - taskMetrics: TaskMetrics) + // may be null if the task has failed + @Nullable taskMetrics: TaskMetrics) extends SparkListenerEvent @DeveloperApi @@ -104,12 +114,12 @@ case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends /** * Periodic updates from executors. * @param execId executor id - * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics) + * @param accumUpdates sequence of (taskId, stageId, stageAttemptId, accumUpdates) */ @DeveloperApi case class SparkListenerExecutorMetricsUpdate( execId: String, - taskMetrics: Seq[(Long, Int, Int, TaskMetrics)]) + accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])]) extends SparkListenerEvent @DeveloperApi @@ -131,258 +141,162 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** - * :: DeveloperApi :: - * Interface for listening to events from the Spark scheduler. Note that this is an internal - * interface which might change in different Spark releases. Java clients should extend - * {@link JavaSparkListener} + * Interface for creating history listeners defined in other modules like SQL, which are used to + * rebuild the history UI. */ -@DeveloperApi -trait SparkListener { +private[spark] trait SparkHistoryListenerFactory { + /** + * Create listeners used to rebuild the history UI. + */ + def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] +} + + +/** + * Interface for listening to events from the Spark scheduler. Most applications should probably + * extend SparkListener or SparkFirehoseListener directly, rather than implementing this class. + * + * Note that this is an internal interface which might change in different Spark releases. + */ +private[spark] trait SparkListenerInterface { + /** * Called when a stage completes successfully or fails, with information on the completed stage. */ - def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { } + def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit /** * Called when a stage is submitted */ - def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } + def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit /** * Called when a task starts */ - def onTaskStart(taskStart: SparkListenerTaskStart) { } + def onTaskStart(taskStart: SparkListenerTaskStart): Unit /** * Called when a task begins remotely fetching its result (will not be called for tasks that do * not need to fetch the result remotely). */ - def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit /** * Called when a task ends */ - def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } + def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit /** * Called when a job starts */ - def onJobStart(jobStart: SparkListenerJobStart) { } + def onJobStart(jobStart: SparkListenerJobStart): Unit /** * Called when a job ends */ - def onJobEnd(jobEnd: SparkListenerJobEnd) { } + def onJobEnd(jobEnd: SparkListenerJobEnd): Unit /** * Called when environment properties have been updated */ - def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { } + def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate): Unit /** * Called when a new block manager has joined */ - def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { } + def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit /** * Called when an existing block manager has been removed */ - def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { } + def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit /** * Called when an RDD is manually unpersisted by the application */ - def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) { } + def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit /** * Called when the application starts */ - def onApplicationStart(applicationStart: SparkListenerApplicationStart) { } + def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit /** * Called when the application ends */ - def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { } + def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit /** * Called when the driver receives task metrics from an executor in a heartbeat. */ - def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { } + def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit /** * Called when the driver registers a new executor. */ - def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { } + def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit /** * Called when the driver removes an executor. */ - def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } + def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit /** * Called when the driver receives a block update info. */ - def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit + + /** + * Called when other events like SQL-specific events are posted. + */ + def onOtherEvent(event: SparkListenerEvent): Unit } + /** * :: DeveloperApi :: - * Simple SparkListener that logs a few summary statistics when each stage completes + * A default implementation for [[SparkListenerInterface]] that has no-op implementations for + * all callbacks. + * + * Note that this is an internal interface which might change in different Spark releases. */ @DeveloperApi -class StatsReportListener extends SparkListener with Logging { - - import org.apache.spark.scheduler.StatsReportListener._ - - private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]() - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val info = taskEnd.taskInfo - val metrics = taskEnd.taskMetrics - if (info != null && metrics != null) { - taskInfoMetrics += ((info, metrics)) - } - } - - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { - implicit val sc = stageCompleted - this.logInfo("Finished stage: " + stageCompleted.stageInfo) - showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) - - // Shuffle write - showBytesDistribution("shuffle bytes written:", - (_, metric) => metric.shuffleWriteMetrics.map(_.shuffleBytesWritten), taskInfoMetrics) - - // Fetch & I/O - showMillisDistribution("fetch wait time:", - (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics) - showBytesDistribution("remote bytes read:", - (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics) - showBytesDistribution("task result size:", - (_, metric) => Some(metric.resultSize), taskInfoMetrics) - - // Runtime breakdown - val runtimePcts = taskInfoMetrics.map { case (info, metrics) => - RuntimePercentage(info.duration, metrics) - } - showDistribution("executor (non-fetch) time pct: ", - Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%") - showDistribution("fetch wait time pct: ", - Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%") - showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%") - taskInfoMetrics.clear() - } +abstract class SparkListener extends SparkListenerInterface { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { } -} + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { } -private[spark] object StatsReportListener extends Logging { - - // For profiling, the extremes are more interesting - val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) - val probabilities = percentiles.map(_ / 100.0) - val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" - - def extractDoubleDistribution( - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], - getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = { - Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) }) - } - - // Is there some way to setup the types that I can get rid of this completely? - def extractLongDistribution( - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], - getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = { - extractDoubleDistribution( - taskInfoMetrics, - (info, metric) => { getMetric(info, metric).map(_.toDouble) }) - } - - def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { - val stats = d.statCounter - val quantiles = d.getQuantiles(probabilities).map(formatNumber) - logInfo(heading + stats) - logInfo(percentilesHeader) - logInfo("\t" + quantiles.mkString("\t")) - } - - def showDistribution( - heading: String, - dOpt: Option[Distribution], - formatNumber: Double => String) { - dOpt.foreach { d => showDistribution(heading, d, formatNumber)} - } - - def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { - def f(d: Double): String = format.format(d) - showDistribution(heading, dOpt, f _) - } - - def showDistribution( - heading: String, - format: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Double], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { - showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format) - } - - def showBytesDistribution( - heading: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Long], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { - showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) - } - - def showBytesDistribution(heading: String, dOpt: Option[Distribution]) { - dOpt.foreach { dist => showBytesDistribution(heading, dist) } - } - - def showBytesDistribution(heading: String, dist: Distribution) { - showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String) - } - - def showMillisDistribution(heading: String, dOpt: Option[Distribution]) { - showDistribution(heading, dOpt, - (d => StatsReportListener.millisToString(d.toLong)): Double => String) - } - - def showMillisDistribution( - heading: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Long], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { - showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) - } - - val seconds = 1000L - val minutes = seconds * 60 - val hours = minutes * 60 + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { } - /** - * Reformat a time interval in milliseconds to a prettier format for output - */ - def millisToString(ms: Long): String = { - val (size, units) = - if (ms > hours) { - (ms.toDouble / hours, "hours") - } else if (ms > minutes) { - (ms.toDouble / minutes, "min") - } else if (ms > seconds) { - (ms.toDouble / seconds, "s") - } else { - (ms.toDouble, "ms") - } - "%.1f %s".format(size, units) - } -} + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit = { } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { } + + override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate): Unit = { } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { } + + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = { } + + override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = { } + + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { } + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { } + + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { } + + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { } + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { } -private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) - -private object RuntimePercentage { - def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { - val denom = totalTime.toDouble - val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime) - val fetch = fetchTime.map(_ / denom) - val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom - val other = 1.0 - (exec + fetch.getOrElse(0d)) - RuntimePercentage(exec, fetch, other) - } + override def onOtherEvent(event: SparkListenerEvent): Unit = { } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 04afde33f5aad..471586ac0852a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -22,9 +22,12 @@ import org.apache.spark.util.ListenerBus /** * A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners */ -private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] { +private[spark] trait SparkListenerBus + extends ListenerBus[SparkListenerInterface, SparkListenerEvent] { - override def onPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = { + protected override def doPostEvent( + listener: SparkListenerInterface, + event: SparkListenerEvent): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => listener.onStageSubmitted(stageSubmitted) @@ -61,6 +64,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata + case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala index 1ce83485f024b..bc1431835e258 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala @@ -45,18 +45,17 @@ class SplitInfo( hashCode } - // This is practically useless since most of the Split impl's dont seem to implement equals :-( + // This is practically useless since most of the Split impl's don't seem to implement equals :-( // So unless there is identity equality between underlyingSplits, it will always fail even if it // is pointing to same block. override def equals(other: Any): Boolean = other match { - case that: SplitInfo => { + case that: SplitInfo => this.hostLocation == that.hostLocation && this.inputFormatClazz == that.inputFormatClazz && this.path == that.path && this.length == that.length && // other split specific checks (like start for FileSplit) this.underlyingSplit == that.underlyingSplit - } case _ => false } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 7ea24a217bd39..b6d4e39fe532a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite @@ -74,22 +75,6 @@ private[scheduler] abstract class Stage( val name: String = callSite.shortForm val details: String = callSite.longForm - private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty - - /** Internal accumulators shared across all tasks in this stage. */ - def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators - - /** - * Re-initialize the internal accumulators associated with this stage. - * - * This is called every time the stage is submitted, *except* when a subset of tasks - * belonging to this stage has already finished. Otherwise, reinitializing the internal - * accumulators here again will override partial values from the finished tasks. - */ - def resetInternalAccumulators(): Unit = { - _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) - } - /** * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized * here, before any attempts have actually been created, because the DAGScheduler uses this @@ -126,7 +111,8 @@ private[scheduler] abstract class Stage( numPartitionsToCompute: Int, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { _latestInfo = StageInfo.fromStage( - this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences) + this, nextAttemptId, Some(numPartitionsToCompute), + InternalAccumulator.createAll(rdd.sparkContext), taskLocalityPreferences) nextAttemptId += 1 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 24796c14300b1..0fd58c41cdceb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashMap +import org.apache.spark.Accumulator import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.RDDInfo @@ -35,6 +36,7 @@ class StageInfo( val rddInfos: Seq[RDDInfo], val parentIds: Seq[Int], val details: String, + val internalAccumulators: Seq[Accumulator[_]] = Seq.empty, private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None @@ -42,7 +44,11 @@ class StageInfo( var completionTime: Option[Long] = None /** If the stage failed, the reason why. */ var failureReason: Option[String] = None - /** Terminal values of accumulables updated during this stage. */ + + /** + * Terminal values of accumulables updated during this stage, including all the user-defined + * accumulators. + */ val accumulables = HashMap[Long, AccumulableInfo]() def stageFailed(reason: String) { @@ -75,6 +81,7 @@ private[spark] object StageInfo { stage: Stage, attemptId: Int, numTasks: Option[Int] = None, + internalAccumulators: Seq[Accumulator[_]] = Seq.empty, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) @@ -87,6 +94,7 @@ private[spark] object StageInfo { rddInfos, stage.parents.map(_.id), stage.details, + internalAccumulators, taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala new file mode 100644 index 0000000000000..309f4b806bf70 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Distribution, Utils} + + +/** + * :: DeveloperApi :: + * Simple SparkListener that logs a few summary statistics when each stage completes. + */ +@DeveloperApi +class StatsReportListener extends SparkListener with Logging { + + import org.apache.spark.scheduler.StatsReportListener._ + + private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]() + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val info = taskEnd.taskInfo + val metrics = taskEnd.taskMetrics + if (info != null && metrics != null) { + taskInfoMetrics += ((info, metrics)) + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { + implicit val sc = stageCompleted + this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") + showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) + + // Shuffle write + showBytesDistribution("shuffle bytes written:", + (_, metric) => metric.shuffleWriteMetrics.map(_.bytesWritten), taskInfoMetrics) + + // Fetch & I/O + showMillisDistribution("fetch wait time:", + (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics) + showBytesDistribution("remote bytes read:", + (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics) + showBytesDistribution("task result size:", + (_, metric) => Some(metric.resultSize), taskInfoMetrics) + + // Runtime breakdown + val runtimePcts = taskInfoMetrics.map { case (info, metrics) => + RuntimePercentage(info.duration, metrics) + } + showDistribution("executor (non-fetch) time pct: ", + Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%") + showDistribution("fetch wait time pct: ", + Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%") + showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%") + taskInfoMetrics.clear() + } + + private def getStatusDetail(info: StageInfo): String = { + val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("") + val timeTaken = info.submissionTime.map( + x => info.completionTime.getOrElse(System.currentTimeMillis()) - x + ).getOrElse("-") + + s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + + s"Took: $timeTaken msec" + } + +} + +private[spark] object StatsReportListener extends Logging { + + // For profiling, the extremes are more interesting + val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) + val probabilities = percentiles.map(_ / 100.0) + val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" + + def extractDoubleDistribution( + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], + getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = { + Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) }) + } + + // Is there some way to setup the types that I can get rid of this completely? + def extractLongDistribution( + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], + getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = { + extractDoubleDistribution( + taskInfoMetrics, + (info, metric) => { getMetric(info, metric).map(_.toDouble) }) + } + + def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { + val stats = d.statCounter + val quantiles = d.getQuantiles(probabilities).map(formatNumber) + logInfo(heading + stats) + logInfo(percentilesHeader) + logInfo("\t" + quantiles.mkString("\t")) + } + + def showDistribution( + heading: String, + dOpt: Option[Distribution], + formatNumber: Double => String) { + dOpt.foreach { d => showDistribution(heading, d, formatNumber)} + } + + def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { + def f(d: Double): String = format.format(d) + showDistribution(heading, dOpt, f _) + } + + def showDistribution( + heading: String, + format: String, + getMetric: (TaskInfo, TaskMetrics) => Option[Double], + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format) + } + + def showBytesDistribution( + heading: String, + getMetric: (TaskInfo, TaskMetrics) => Option[Long], + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) + } + + def showBytesDistribution(heading: String, dOpt: Option[Distribution]) { + dOpt.foreach { dist => showBytesDistribution(heading, dist) } + } + + def showBytesDistribution(heading: String, dist: Distribution) { + showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String) + } + + def showMillisDistribution(heading: String, dOpt: Option[Distribution]) { + showDistribution(heading, dOpt, + (d => StatsReportListener.millisToString(d.toLong)): Double => String) + } + + def showMillisDistribution( + heading: String, + getMetric: (TaskInfo, TaskMetrics) => Option[Long], + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) + } + + val seconds = 1000L + val minutes = seconds * 60 + val hours = minutes * 60 + + /** + * Reformat a time interval in milliseconds to a prettier format for output + */ + def millisToString(ms: Long): String = { + val (size, units) = + if (ms > hours) { + (ms.toDouble / hours, "hours") + } else if (ms > minutes) { + (ms.toDouble / minutes, "min") + } else if (ms > seconds) { + (ms.toDouble / seconds, "s") + } else { + (ms.toDouble, "ms") + } + "%.1f %s".format(size, units) + } +} + +private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) + +private object RuntimePercentage { + def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { + val denom = totalTime.toDouble + val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime) + val fetch = fetchTime.map(_ / denom) + val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom + val other = 1.0 - (exec + fetch.getOrElse(0d)) + RuntimePercentage(exec, fetch, other) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4fb32ba8cb188..1ff9d7795f42e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,24 +17,24 @@ package org.apache.spark.scheduler -import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.util.Properties import scala.collection.mutable.HashMap -import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils - +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** * A unit of execution. We have two kinds of Task's in Spark: - * - [[org.apache.spark.scheduler.ShuffleMapTask]] - * - [[org.apache.spark.scheduler.ResultTask]] + * + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] * * A Spark job consists of one or more stages. The very last stage in a job consists of multiple * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task @@ -42,56 +42,72 @@ import org.apache.spark.util.Utils * and divides the task output to multiple buckets (based on the task's partitioner). * * @param stageId id of the stage this task belongs to + * @param stageAttemptId attempt id of the stage this task belongs to * @param partitionId index of the number in the RDD + * @param initialAccumulators initial set of accumulators to be used in this task for tracking + * internal metrics. Other accumulators will be registered later when + * they are deserialized on the executors. + * @param localProperties copy of thread-local properties set by the user on the driver side. */ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { + val initialAccumulators: Seq[Accumulator[_]], + @transient var localProperties: Properties) extends Serializable { /** - * The key of the Map is the accumulator id and the value of the Map is the latest accumulator - * local value. - */ - type AccumulatorUpdates = Map[Long, Any] - - /** - * Called by [[Executor]] to run this task. + * Called by [[org.apache.spark.executor.Executor]] to run this task. * * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. * @param attemptNumber how many times this task has been attempted (0 for the first attempt) * @return the result of the task along with updates of Accumulators. */ final def run( - taskAttemptId: Long, - attemptNumber: Int, - metricsSystem: MetricsSystem) - : (T, AccumulatorUpdates) = { + taskAttemptId: Long, + attemptNumber: Int, + metricsSystem: MetricsSystem): T = { + SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, partitionId, taskAttemptId, attemptNumber, taskMemoryManager, + localProperties, metricsSystem, - internalAccumulators, - runningLocally = false) + initialAccumulators) TaskContext.setTaskContext(context) - context.taskMetrics.setHostname(Utils.localHostName()) - context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) } try { - (runTask(context), context.collectAccumulators()) + runTask(context) + } catch { + case e: Throwable => + // Catch all errors; run task failure callbacks, and rethrow the exception. + try { + context.markTaskFailed(e) + } catch { + case t: Throwable => + e.addSuppressed(t) + } + throw e } finally { + // Call the task completion callbacks. context.markTaskCompleted() try { Utils.tryLogNonFatalError { // Release memory used by this thread for unrolling blocks - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP) + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the future. + val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } } } finally { TaskContext.unset() @@ -136,6 +152,18 @@ private[spark] abstract class Task[T]( */ def executorDeserializeTime: Long = _executorDeserializeTime + /** + * Collect the latest values of accumulators used in this task. If the task failed, + * filter out the accumulators whose values should not be included on failures. + */ + def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulableInfo] = { + if (context != null) { + context.taskMetrics.accumulatorUpdates().filter { a => !taskFailed || a.countFailedValues } + } else { + Seq.empty[AccumulableInfo] + } + } + /** * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. This function should be idempotent so it can @@ -171,7 +199,7 @@ private[spark] object Task { serializer: SerializerInstance) : ByteBuffer = { - val out = new ByteArrayOutputStream(4096) + val out = new ByteBufferOutputStream(4096) val dataOut = new DataOutputStream(out) // Write currentFiles @@ -188,11 +216,16 @@ private[spark] object Task { dataOut.writeLong(timestamp) } + // Write the task properties separately so it is available before full task deserialization. + val propBytes = Utils.serialize(task.localProperties) + dataOut.writeInt(propBytes.length) + dataOut.write(propBytes) + // Write the task itself and finish dataOut.flush() - val taskBytes = serializer.serialize(task).array() - out.write(taskBytes) - ByteBuffer.wrap(out.toByteArray) + val taskBytes = serializer.serialize(task) + Utils.writeByteBuffer(taskBytes, out) + out.toByteBuffer } /** @@ -203,7 +236,7 @@ private[spark] object Task { * @return (taskFiles, taskJars, taskBytes) */ def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { + : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = { val in = new ByteBufferInputStream(serializedTask) val dataIn = new DataInputStream(in) @@ -222,8 +255,13 @@ private[spark] object Task { taskJars(dataIn.readUTF()) = dataIn.readLong() } + val propLength = dataIn.readInt() + val propBytes = new Array[Byte](propLength) + dataIn.readFully(propBytes, 0, propLength) + val taskProps = Utils.deserialize[Properties](propBytes) + // Create a sub-buffer for the rest of the data, which is the serialized Task object val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task - (taskFiles, taskJars, subBuffer) + (taskFiles, taskJars, taskProps, subBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index f113c2b1b8433..a42990addb9c4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -95,9 +95,6 @@ class TaskInfo( } } - @deprecated("Use attemptNumber", "1.6.0") - def attempt: Int = attemptNumber - def id: String = s"$index.$attemptNumber" def duration: Long = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index b82c7f3fa54f8..03135e63d7551 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -20,11 +20,9 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer -import scala.collection.Map -import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkEnv -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockId import org.apache.spark.util.Utils @@ -36,31 +34,24 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ -private[spark] -class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long, Any], - var metrics: TaskMetrics) +private[spark] class DirectTaskResult[T]( + var valueBytes: ByteBuffer, + var accumUpdates: Seq[AccumulableInfo]) extends TaskResult[T] with Externalizable { private var valueObjectDeserialized = false private var valueObject: T = _ - def this() = this(null.asInstanceOf[ByteBuffer], null, null) + def this() = this(null.asInstanceOf[ByteBuffer], null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - - out.writeInt(valueBytes.remaining); + out.writeInt(valueBytes.remaining) Utils.writeByteBuffer(valueBytes, out) - out.writeInt(accumUpdates.size) - for ((key, value) <- accumUpdates) { - out.writeLong(key) - out.writeObject(value) - } - out.writeObject(metrics) + accumUpdates.foreach(out.writeObject) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - val blen = in.readInt() val byteVal = new Array[Byte](blen) in.readFully(byteVal) @@ -70,13 +61,12 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long if (numUpdates == 0) { accumUpdates = null } else { - val _accumUpdates = mutable.Map[Long, Any]() + val _accumUpdates = new ArrayBuffer[AccumulableInfo] for (i <- 0 until numUpdates) { - _accumUpdates(in.readLong()) = in.readObject() + _accumUpdates += in.readObject.asInstanceOf[AccumulableInfo] } accumUpdates = _accumUpdates } - metrics = in.readObject().asInstanceOf[TaskMetrics] valueObjectDeserialized = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e2ee..ae7ef46abbf31 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -18,13 +18,14 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.concurrent.RejectedExecutionException +import java.util.concurrent.{ExecutorService, RejectedExecutionException} import scala.language.existentials import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.internal.Logging import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{ThreadUtils, Utils} @@ -35,9 +36,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul extends Logging { private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) - private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool( - THREADS, "task-result-getter") + // Exposed for testing. + protected val getTaskResultExecutor: ExecutorService = + ThreadUtils.newDaemonFixedThreadPool(THREADS, "task-result-getter") + + // Exposed for testing. protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { sparkEnv.closureSerializer.newInstance() @@ -45,7 +49,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } def enqueueSuccessfulTask( - taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { + taskSetManager: TaskSetManager, + tid: Long, + serializedData: ByteBuffer): Unit = { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { @@ -77,12 +83,24 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul return } val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( - serializedTaskResult.get) + serializedTaskResult.get.toByteBuffer) sparkEnv.blockManager.master.removeBlock(blockId) (deserializedResult, size) } - result.metrics.setResultSize(size) + // Set the task result size in the accumulator updates received from the executors. + // We need to do this here on the driver because if we did this on the executors then + // we would have to serialize the result again after updating the size. + result.accumUpdates = result.accumUpdates.map { a => + if (a.name == Some(InternalAccumulator.RESULT_SIZE)) { + assert(a.update == Some(0L), + "task result size should not have been set on the executors") + a.copy(update = Some(size.toLong)) + } else { + a + } + } + scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => @@ -103,19 +121,19 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { + val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + serializedData, loader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastrophic // if we can't deserialize the reason. - val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex: Exception => {} + case ex: Exception => // No-op } scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index cb9a3008107d7..647d44a0f0680 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -18,7 +18,6 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId /** @@ -52,7 +51,7 @@ private[spark] trait TaskScheduler { def submitTasks(taskSet: TaskSet): Unit // Cancel a stage. - def cancelTasks(stageId: Int, interruptThread: Boolean) + def cancelTasks(stageId: Int, interruptThread: Boolean): Unit // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit @@ -65,8 +64,10 @@ private[spark] trait TaskScheduler { * alive. Return true if the driver knows about the given block manager. Otherwise, return false, * indicating that the block manager should re-register. */ - def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean + def executorHeartbeatReceived( + execId: String, + accumUpdates: Array[(Long, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId): Boolean /** * Get an application ID associated with the job. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 43d7d80b7aae1..c3159188d9f03 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{TimerTask, Timer} +import java.util.{Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong @@ -30,11 +30,11 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.{ThreadUtils, Utils} -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.{ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. @@ -87,8 +87,10 @@ private[spark] class TaskSchedulerImpl( // Incrementing task IDs val nextTaskId = new AtomicLong(0) - // Which executor IDs we have executors on - val activeExecutorIds = new HashSet[String] + // Number of tasks running on each executor + private val executorIdToTaskCount = new HashMap[String, Int] + + def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host @@ -254,6 +256,7 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId + executorIdToTaskCount(execId) += 1 executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) @@ -282,7 +285,7 @@ private[spark] class TaskSchedulerImpl( var newExecAvail = false for (o <- offers) { executorIdToHost(o.executorId) = o.host - activeExecutorIds += o.executorId + executorIdToTaskCount.getOrElseUpdate(o.executorId, 0) if (!executorsByHost.contains(o.host)) { executorsByHost(o.host) = new HashSet[String]() executorAdded(o.executorId, o.host) @@ -331,7 +334,8 @@ private[spark] class TaskSchedulerImpl( if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { // We lost this entire executor, so remember that it's gone val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { + + if (executorIdToTaskCount.contains(execId)) { removeExecutor(execId, SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) failedExecutor = Some(execId) @@ -341,7 +345,11 @@ private[spark] class TaskSchedulerImpl( case Some(taskSet) => if (TaskState.isFinished(state)) { taskIdToTaskSetManager.remove(tid) - taskIdToExecutorId.remove(tid) + taskIdToExecutorId.remove(tid).foreach { execId => + if (executorIdToTaskCount.contains(execId)) { + executorIdToTaskCount(execId) -= 1 + } + } } if (state == TaskState.FINISHED) { taskSet.removeRunningTask(tid) @@ -374,17 +382,17 @@ private[spark] class TaskSchedulerImpl( */ override def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics + accumUpdates: Array[(Long, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { - - val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { - taskMetrics.flatMap { case (id, metrics) => + // (taskId, stageId, stageAttemptId, accumUpdates) + val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized { + accumUpdates.flatMap { case (id, updates) => taskIdToTaskSetManager.get(id).map { taskSetMgr => - (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics) + (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, updates) } } } - dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) + dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId) } def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized { @@ -462,26 +470,27 @@ private[spark] class TaskSchedulerImpl( var failedExecutor: Option[String] = None synchronized { - if (activeExecutorIds.contains(executorId)) { + if (executorIdToTaskCount.contains(executorId)) { val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { - executorIdToHost.get(executorId) match { - case Some(_) => - // If the host mapping still exists, it means we don't know the loss reason for the - // executor. So call removeExecutor() to update tasks running on that executor when - // the real loss reason is finally known. - removeExecutor(executorId, reason) - - case None => - // We may get multiple executorLost() calls with different loss reasons. For example, - // one may be triggered by a dropped connection from the slave while another may be a - // report of executor termination from Mesos. We produce log messages for both so we - // eventually report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) - } + executorIdToHost.get(executorId) match { + case Some(hostPort) => + // If the host mapping still exists, it means we don't know the loss reason for the + // executor. So call removeExecutor() to update tasks running on that executor when + // the real loss reason is finally known. + logExecutorLoss(executorId, hostPort, reason) + removeExecutor(executorId, reason) + + case None => + // We may get multiple executorLost() calls with different loss reasons. For example, + // one may be triggered by a dropped connection from the slave while another may be a + // report of executor termination from Mesos. We produce log messages for both so we + // eventually report the termination reason. + logError(s"Lost an executor $executorId (already removed): $reason") + } } } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock @@ -491,13 +500,26 @@ private[spark] class TaskSchedulerImpl( } } + private def logExecutorLoss( + executorId: String, + hostPort: String, + reason: ExecutorLossReason): Unit = reason match { + case LossReasonPending => + logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.") + case ExecutorKilled => + logInfo(s"Executor $executorId on $hostPort killed by driver.") + case _ => + logError(s"Lost executor $executorId on $hostPort: $reason") + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status * of any running tasks, since the loss reason defines whether we'll fail those tasks. */ private def removeExecutor(executorId: String, reason: ExecutorLossReason) { - activeExecutorIds -= executorId + executorIdToTaskCount -= executorId + val host = executorIdToHost(executorId) val execs = executorsByHost.getOrElse(host, new HashSet) execs -= executorId @@ -534,7 +556,11 @@ private[spark] class TaskSchedulerImpl( } def isExecutorAlive(execId: String): Boolean = synchronized { - activeExecutorIds.contains(execId) + executorIdToTaskCount.contains(execId) + } + + def isExecutorBusy(execId: String): Boolean = synchronized { + executorIdToTaskCount.getOrElse(execId, -1) > 0 } // By default, rack is unknown @@ -545,6 +571,11 @@ private[spark] class TaskSchedulerImpl( return } while (!backend.isReady) { + // Might take a while for backend to be ready if it is waiting on resources. + if (sc.stopped.get) { + // For example: the master removes the application for some reason + throw new IllegalStateException("Spark context stopped while waiting for backend") + } synchronized { this.wait(100) } @@ -597,10 +628,10 @@ private[spark] object TaskSchedulerImpl { while (found) { found = false for (key <- keyList) { - val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) + val containerList: ArrayBuffer[T] = map.getOrElse(key, null) assert(containerList != null) // Get the index'th entry for this host - if present - if (index < containerList.size){ + if (index < containerList.size) { retval += containerList.apply(index) found = true } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index be8526ba9b94f..517c8991aed78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -29,7 +29,7 @@ private[spark] class TaskSet( val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + stageAttemptId + val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 114468c48c44c..6e08cdd87a8d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -25,11 +25,11 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.math.{min, max} +import scala.math.{max, min} import scala.util.control.NonFatal import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -114,9 +114,14 @@ private[spark] class TaskSetManager( // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. + // back at the head of the stack. These collections may contain duplicates + // for two reasons: + // (1): Tasks are only removed lazily; when a task is launched, it remains + // in all the pending lists except the one that it was launched from. + // (2): Tasks may be re-added to these lists multiple times as a result + // of failures. + // Duplicates are handled in dequeueTaskFromList, which ensures that a + // task hasn't already started running before launching it. private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each host. Similar to pendingTasksForExecutor, @@ -179,41 +184,32 @@ private[spark] class TaskSetManager( /** Add a task to all the pending-task lists that it should be on. */ private def addPendingTask(index: Int) { - // Utility method that adds `index` to a list only if it's not already there - def addTo(list: ArrayBuffer[Int]) { - if (!list.contains(index)) { - list += index - } - } - for (loc <- tasks(index).preferredLocations) { loc match { case e: ExecutorCacheTaskLocation => - addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer)) - case e: HDFSCacheTaskLocation => { + pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index + case e: HDFSCacheTaskLocation => val exe = sched.getExecutorsAliveOnHost(loc.host) exe match { - case Some(set) => { + case Some(set) => for (e <- set) { - addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer)) + pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer) += index } logInfo(s"Pending task $index has a cached location at ${e.host} " + ", where there are executors " + set.mkString(",")) - } case None => logDebug(s"Pending task $index has a cached location at ${e.host} " + ", but there are no executors alive there.") } - } - case _ => Unit + case _ => } - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) + pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) + pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer) += index } } if (tasks(index).preferredLocations == Nil) { - addTo(pendingTasksWithNoPrefs) + pendingTasksWithNoPrefs += index } allPendingTasks += index // No point scanning this whole list to find the old task there @@ -340,7 +336,7 @@ private[spark] class TaskSetManager( if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { for (rack <- sched.getRackForHost(host)) { for (index <- speculatableTasks if canRunOnHost(index)) { - val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) + val racks = tasks(index).preferredLocations.map(_.host).flatMap(sched.getRackForHost) if (racks.contains(rack)) { speculatableTasks -= index return Some((index, TaskLocality.RACK_LOCAL)) @@ -439,7 +435,7 @@ private[spark] class TaskSetManager( } dequeueTask(execId, host, allowedLocality) match { - case Some((index, taskLocality, speculative)) => { + case Some((index, taskLocality, speculative)) => // Found a task; do some bookkeeping and return a task description val task = tasks(index) val taskId = sched.newTaskId() @@ -488,7 +484,6 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, taskName, index, serializedTask)) - } case _ => } } @@ -557,9 +552,9 @@ private[spark] class TaskSetManager( // Jump to the next locality level, and reset lastLaunchTime so that the next locality // wait timer doesn't immediately expire lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 - logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex)} after waiting for " + + logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex + 1)} after waiting for " + s"${localityWaits(currentLocalityIndex)}ms") + currentLocalityIndex += 1 } else { return myLocalityLevels(currentLocalityIndex) } @@ -608,7 +603,7 @@ private[spark] class TaskSetManager( } /** - * Marks the task as successful and notifies the DAGScheduler that a task has ended. + * Marks a task as successful and notifies the DAGScheduler that the task has ended. */ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) @@ -621,8 +616,7 @@ private[spark] class TaskSetManager( // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. // Note: "result.value()" only deserializes the value when it's called at the first time, so // here "result.value()" just returns the value and won't block other threads. - sched.dagScheduler.taskEnded( - tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) + sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) if (!successful(index)) { tasksSuccessful += 1 logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format( @@ -653,8 +647,7 @@ private[spark] class TaskSetManager( info.markFailed() val index = info.index copiesRunning(index) -= 1 - var taskMetrics : TaskMetrics = null - + var accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo] val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString val failureException: Option[Throwable] = reason match { @@ -669,7 +662,8 @@ private[spark] class TaskSetManager( None case ef: ExceptionFailure => - taskMetrics = ef.metrics.orNull + // ExceptionFailure's might have accumulator updates + accumUpdates = ef.accumUpdates if (ef.className == classOf[NotSerializableException].getName) { // If the task result wasn't serializable, there's no point in trying to re-execute it. logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying" @@ -705,7 +699,7 @@ private[spark] class TaskSetManager( ef.exception case e: ExecutorLostFailure if !e.exitCausedByApp => - logInfo(s"Task $tid failed because while it was being computed, its executor" + + logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None @@ -721,7 +715,7 @@ private[spark] class TaskSetManager( // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). put(info.executorId, clock.getTimeMillis()) - sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) + sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) addPendingTask(index) if (!isZombie && state != TaskState.KILLED && reason.isInstanceOf[TaskFailedReason] @@ -793,13 +787,15 @@ private[spark] class TaskSetManager( addPendingTask(index) // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. - sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) + sched.dagScheduler.taskEnded( + tasks(index), Resubmitted, null, Seq.empty[AccumulableInfo], info) } } } for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { val exitCausedByApp: Boolean = reason match { case exited: ExecutorExited => exited.exitCausedByApp + case ExecutorKilled => false case _ => true } handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, @@ -829,7 +825,7 @@ private[spark] class TaskSetManager( val time = clock.getTimeMillis() val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) + val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1)) val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index f3d0d85476772..46a829114ec86 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.ExecutorLossReason -import org.apache.spark.util.{SerializableBuffer, Utils} +import org.apache.spark.util.SerializableBuffer private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -30,6 +30,8 @@ private[spark] object CoarseGrainedClusterMessages { case object RetrieveSparkProps extends CoarseGrainedClusterMessage + case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage + // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage @@ -48,7 +50,6 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutor( executorId: String, executorRef: RpcEndpointRef, - hostPort: String, cores: Int, logUrls: Map[String, String]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f71d98feac050..8896391f9775f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -19,18 +19,20 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} +import org.apache.spark.internal.Logging import org.apache.spark.rpc._ -import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME -import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils} /** - * A scheduler backend that waits for coarse grained executors to connect to it through Akka. + * A scheduler backend that waits for coarse-grained executors to connect. * This backend holds onto each executor for the duration of the Spark job rather than relinquishing * executors whenever a task is done and asking the scheduler to launch a new executor for * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the @@ -42,43 +44,57 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed - var totalCoreCount = new AtomicInteger(0) + protected val totalCoreCount = new AtomicInteger(0) // Total number of executors that are currently registered - var totalRegisteredExecutors = new AtomicInteger(0) - val conf = scheduler.sc.conf - private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + protected val totalRegisteredExecutors = new AtomicInteger(0) + protected val conf = scheduler.sc.conf + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. - var minRegisteredRatio = + private val _minRegisteredRatio = math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) // Submit tasks after maxRegisteredWaitingTime milliseconds // if minRegisteredRatio has not yet been reached - val maxRegisteredWaitingTimeMs = + private val maxRegisteredWaitingTimeMs = conf.getTimeAsMs("spark.scheduler.maxRegisteredResourcesWaitingTime", "30s") - val createTime = System.currentTimeMillis() + private val createTime = System.currentTimeMillis() + // Accessing `executorDataMap` in `DriverEndpoint.receive/receiveAndReply` doesn't need any + // protection. But accessing `executorDataMap` out of `DriverEndpoint.receive/receiveAndReply` + // must be protected by `CoarseGrainedSchedulerBackend.this`. Besides, `executorDataMap` should + // only be modified in `DriverEndpoint.receive/receiveAndReply` with protection by + // `CoarseGrainedSchedulerBackend.this`. private val executorDataMap = new HashMap[String, ExecutorData] // Number of executors requested from the cluster manager that have not registered yet + @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 private val listenerBus = scheduler.sc.listenerBus - // Executors we have requested the cluster manager to kill that have not died yet - private val executorsPendingToRemove = new HashSet[String] + // Executors we have requested the cluster manager to kill that have not died yet; maps + // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't + // be considered an app-related failure). + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private val executorsPendingToRemove = new HashMap[String, Boolean] // A map to store hostname with its possible task number running on it + @GuardedBy("CoarseGrainedSchedulerBackend.this") protected var hostToLocalTaskCount: Map[String, Int] = Map.empty // The number of pending tasks which is locality required + @GuardedBy("CoarseGrainedSchedulerBackend.this") protected var localityAwareTasks = 0 - // Executors that have been lost, but for which we don't yet know the real exit reason. - protected val executorsPendingLossReason = new HashSet[String] + // The num of current max ExecutorId used to re-register appMaster + @volatile protected var currentExecutorIdCounter = 0 class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { + // Executors that have been lost, but for which we don't yet know the real exit reason. + protected val executorsPendingLossReason = new HashSet[String] + // If this DriverEndpoint is changed to support multiple threads, // then this may need to be changed so that we don't share the serializer // instance across threads @@ -132,9 +148,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => + case RegisterExecutor(executorId, executorRef, cores, logUrls) => if (executorDataMap.contains(executorId)) { - context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + context.reply(true) } else { // If the executor's rpc env is not listening for incoming connections, `hostPort` // will be null, and the client connection should be used to contact the executor. @@ -153,13 +170,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { executorDataMap.put(executorId, data) + if (currentExecutorIdCounter < executorId.toInt) { + currentExecutorIdCounter = executorId.toInt + } if (numPendingExecutors > 0) { numPendingExecutors -= 1 logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } + executorRef.send(RegisteredExecutor(executorAddress.host)) // Note: some tests expect the reply to come after we put the executor in the map - context.reply(RegisteredExecutor(executorAddress.host)) + context.reply(true) listenerBus.post( SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() @@ -177,6 +198,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp context.reply(true) case RemoveExecutor(executorId, reason) => + // We will remove the executor's state and cannot restore it. However, the connection + // between the driver and the executor may be still alive so that the executor won't exit + // automatically, so try to tell the executor to stop itself. See SPARK-13519. + executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) removeExecutor(executorId, reason) context.reply(true) @@ -222,14 +247,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { val serializedTask = ser.serialize(task) - if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + if (serializedTask.limit >= maxRpcMessageSize) { scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + - "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + - "spark.akka.frameSize or using broadcast variables for large values." - msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize, - AkkaUtils.reservedSizeBytes) + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + + "spark.rpc.message.maxSize or using broadcast variables for large values." + msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize) taskSetMgr.abort(msg) } catch { case e: Exception => logError("Exception in error callback", e) @@ -239,26 +263,30 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK + + logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + + s"${executorData.executorHost}.") + executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + private def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { + val killed = CoarseGrainedSchedulerBackend.this.synchronized { addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId - executorsPendingToRemove -= executorId executorsPendingLossReason -= executorId + executorsPendingToRemove.remove(executorId).getOrElse(false) } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, reason) + scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") @@ -269,7 +297,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Stop making resource offers for the given executor. The executor is marked as lost with * the loss reason still pending. * - * @return Whether executor was alive. + * @return Whether executor should be disabled */ protected def disableExecutor(executorId: String): Boolean = { val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { @@ -277,7 +305,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorsPendingLossReason += executorId true } else { - false + // Returns true for explicitly killed executors, we also need to get pending loss reasons; + // For others return false. + executorsPendingToRemove.contains(executorId) } } @@ -295,7 +325,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } var driverEndpoint: RpcEndpointRef = null - val taskIdsOnSlave = new HashMap[String, HashSet[String]] + + protected def minRegisteredRatio: Double = _minRegisteredRatio override def start() { val properties = new ArrayBuffer[(String, String)] @@ -306,7 +337,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // TODO (prashant) send conf instead of properties - driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) + driverEndpoint = createDriverEndpointRef(properties) + } + + protected def createDriverEndpointRef( + properties: ArrayBuffer[(String, String)]): RpcEndpointRef = { + rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) } protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { @@ -337,6 +373,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + /** + * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only + * be called in the yarn-client mode when AM re-registers after a failure. + * */ + protected def reset(): Unit = synchronized { + numPendingExecutors = 0 + executorsPendingToRemove.clear() + + // Remove all the lingering executors that should be removed but not yet. The reason might be + // because (1) disconnected event is not yet received; (2) executors die silently. + executorDataMap.toMap.foreach { case (eid, _) => + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered."))) + } + } + override def reviveOffers() { driverEndpoint.send(ReviveOffers) } @@ -378,7 +430,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Return the number of executors currently registered with this backend. */ - def numExistingExecutors: Int = executorDataMap.size + private def numExistingExecutors: Int = executorDataMap.size + + override def getExecutorIds(): Seq[String] = { + executorDataMap.keySet.toSeq + } /** * Request an additional number of executors from the cluster manager. @@ -448,20 +504,30 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. - * @return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. If list to kill is empty, it will return + * false. */ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { - killExecutors(executorIds, replace = false) + killExecutors(executorIds, replace = false, force = false) } /** * Request that the cluster manager kill the specified executors. * + * When asking the executor to be replaced, the executor loss is considered a failure, and + * killed tasks that are running on the executor will count towards the failure limits. If no + * replacement is being requested, then the tasks will not count towards the limit. + * * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones - * @return whether the kill request is acknowledged. + * @param force whether to force kill busy executors + * @return whether the kill request is acknowledged. If list to kill is empty, it will return + * false. */ - final def killExecutors(executorIds: Seq[String], replace: Boolean): Boolean = synchronized { + final def killExecutors( + executorIds: Seq[String], + replace: Boolean, + force: Boolean): Boolean = synchronized { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) unknownExecutors.foreach { id => @@ -469,8 +535,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // If an executor is already pending to be removed, do not kill it again (SPARK-9795) - val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) } - executorsPendingToRemove ++= executorsToKill + // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552) + val executorsToKill = knownExecutors + .filter { id => !executorsPendingToRemove.contains(id) } + .filter { id => force || !scheduler.isExecutorBusy(id) } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, @@ -482,7 +551,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp numPendingExecutors += knownExecutors.size } - doKillExecutors(executorsToKill) + !executorsToKill.isEmpty && doKillExecutors(executorsToKill) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 626a2b7d69abe..b25a4bfb501fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala deleted file mode 100644 index 641638a77d5f5..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import org.apache.hadoop.fs.{Path, FileSystem} - -import org.apache.spark.rpc.RpcAddress -import org.apache.spark.{Logging, SparkContext, SparkEnv} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler.TaskSchedulerImpl - -private[spark] class SimrSchedulerBackend( - scheduler: TaskSchedulerImpl, - sc: SparkContext, - driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) - with Logging { - - val tmpPath = new Path(driverFilePath + "_tmp") - val filePath = new Path(driverFilePath) - - val maxCores = conf.getInt("spark.simr.executor.cores", 1) - - override def start() { - super.start() - - val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, - RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - - val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) - val fs = FileSystem.get(conf) - val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - - logInfo("Writing to HDFS file: " + driverFilePath) - logInfo("Writing Akka address: " + driverUrl) - logInfo("Writing Spark UI Address: " + appUIAddress) - - // Create temporary file to prevent race condition where executors get empty driverUrl file - val temp = fs.create(tmpPath, true) - temp.writeUTF(driverUrl) - temp.writeInt(maxCores) - temp.writeUTF(appUIAddress) - temp.close() - - // "Atomic" rename - fs.rename(tmpPath, filePath) - } - - override def stop() { - val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) - val fs = FileSystem.get(conf) - if (!fs.delete(new Path(driverFilePath), false)) { - logWarning(s"error deleting ${driverFilePath}") - } - super.stop() - } - -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 05d9bc92f228b..85d002011d64c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,11 +19,12 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore -import org.apache.spark.rpc.RpcAddress -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} +import org.apache.spark.internal.Logging import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} +import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -54,9 +55,10 @@ private[spark] class SparkDeploySchedulerBackend( launcherBackend.connect() // The endpoint for executors to talk to us - val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, - RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + sc.conf.get("spark.driver.host"), + sc.conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", @@ -88,8 +90,16 @@ private[spark] class SparkDeploySchedulerBackend( args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, - command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) + // If we're using dynamic allocation, set our initial executor limit to 0 for now. + // ExecutorAllocationManager will send the real initial limit to the Master later. + val initialExecutorLimit = + if (Utils.isDynamicAllocationEnabled(conf)) { + Some(0) + } else { + None + } + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) @@ -191,17 +201,19 @@ private[spark] class SparkDeploySchedulerBackend( } private def stop(finalState: SparkAppHandle.State): Unit = synchronized { - stopping = true + try { + stopping = true - launcherBackend.setState(finalState) - launcherBackend.close() + super.stop() + client.stop() - super.stop() - client.stop() - - val callback = shutdownCallback - if (callback != null) { - callback(this) + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala deleted file mode 100644 index 80da37b09b590..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.cluster - -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Future, ExecutionContext} - -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.scheduler._ -import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{ThreadUtils, RpcUtils} - -import scala.util.control.NonFatal - -/** - * Abstract Yarn scheduler backend that contains common logic - * between the client and cluster Yarn scheduler backends. - */ -private[spark] abstract class YarnSchedulerBackend( - scheduler: TaskSchedulerImpl, - sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { - - if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { - minRegisteredRatio = 0.8 - } - - protected var totalExpectedExecutors = 0 - - private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv) - - private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint( - YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint) - - private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) - - /** - * Request executors from the ApplicationMaster by specifying the total number desired. - * This includes executors already pending or running. - */ - override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpointRef.askWithRetry[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) - } - - /** - * Request that the ApplicationMaster kill the specified executors. - */ - override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds)) - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio - } - - /** - * Add filters to the SparkUI. - */ - private def addWebUIFilter( - filterName: String, - filterParams: Map[String, String], - proxyBase: String): Unit = { - if (proxyBase != null && proxyBase.nonEmpty) { - System.setProperty("spark.ui.proxyBase", proxyBase) - } - - val hasFilter = - filterName != null && filterName.nonEmpty && - filterParams != null && filterParams.nonEmpty - if (hasFilter) { - logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") - conf.set("spark.ui.filters", filterName) - filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } - scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } - } - } - - override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { - new YarnDriverEndpoint(rpcEnv, properties) - } - - /** - * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. - * This endpoint communicates with the executors and queries the AM for an executor's exit - * status when the executor is disconnected. - */ - private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) - extends DriverEndpoint(rpcEnv, sparkProperties) { - - /** - * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint - * handles it by assuming the Executor was lost for a bad reason and removes the executor - * immediately. - * - * In YARN's case however it is crucial to talk to the application master and ask why the - * executor had exited. If the executor exited for some reason unrelated to the running tasks - * (e.g., preemption), according to the application master, then we pass that information down - * to the TaskSetManager to inform the TaskSetManager that tasks on that lost executor should - * not count towards a job failure. - */ - override def onDisconnected(rpcAddress: RpcAddress): Unit = { - addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) - } - } - } - } - - /** - * An [[RpcEndpoint]] that communicates with the ApplicationMaster. - */ - private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) - extends ThreadSafeRpcEndpoint with Logging { - private var amEndpoint: Option[RpcEndpointRef] = None - - private val askAmThreadPool = - ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") - implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) - - private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( - executorId: String, - executorRpcAddress: RpcAddress): Unit = { - amEndpoint match { - case Some(am) => - val lossReasonRequest = GetExecutorLossReason(executorId) - val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) - future onSuccess { - case reason: ExecutorLossReason => { - driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) - } - } - future onFailure { - case NonFatal(e) => { - logWarning(s"Attempted to get executor loss reason" + - s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + - s" but got no response. Marking as slave lost.", e) - driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) - } - case t => throw t - } - case None => - logWarning("Attempted to check for an executor loss reason" + - " before the AM has registered!") - } - } - - override def receive: PartialFunction[Any, Unit] = { - case RegisterClusterManager(am) => - logInfo(s"ApplicationMaster registered as $am") - amEndpoint = Option(am) - - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - - case RemoveExecutor(executorId, reason) => - logWarning(reason.toString) - removeExecutor(executorId, reason) - } - - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case r: RequestExecutors => - amEndpoint match { - case Some(am) => - Future { - context.reply(am.askWithRetry[Boolean](r)) - } onFailure { - case NonFatal(e) => - logError(s"Sending $r to AM was unsuccessful", e) - context.sendFailure(e) - } - case None => - logWarning("Attempted to request executors before the AM has registered!") - context.reply(false) - } - - case k: KillExecutors => - amEndpoint match { - case Some(am) => - Future { - context.reply(am.askWithRetry[Boolean](k)) - } onFailure { - case NonFatal(e) => - logError(s"Sending $k to AM was unsuccessful", e) - context.sendFailure(e) - } - case None => - logWarning("Attempted to kill executors before the AM has registered!") - context.reply(false) - } - } - - override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (amEndpoint.exists(_.address == remoteAddress)) { - logWarning(s"ApplicationMaster has disassociated: $remoteAddress") - } - } - - override def onStop(): Unit = { - askAmThreadPool.shutdownNow() - } - } -} - -private[spark] object YarnSchedulerBackend { - val ENDPOINT_NAME = "YarnScheduler" -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d10a77f8e5c78..50b452c72f8aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,20 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable +import scala.collection.mutable.{Buffer, HashMap, HashSet} -import com.google.common.collect.HashBiMap -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} -import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcEndpointAddress} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -60,28 +60,38 @@ private[spark] class CoarseMesosSchedulerBackend( // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + private[this] val shutdownTimeoutMS = + conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") + .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") + + // Synchronization protected by stateLock + private[this] var stopCalled: Boolean = false + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) // Cores we have acquired with each Mesos task ID - val coresByTaskId = new HashMap[Int, Int] + val coresByTaskId = new HashMap[String, Int] var totalCoresAcquired = 0 - val slaveIdsWithExecutors = new HashSet[String] - - // Maping from slave Id to hostname - private val slaveIdToHost = new HashMap[String, String] - - val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String] - // How many times tasks on each slave failed - val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int] + // SlaveID -> Slave + // This map accumulates entries for the duration of the job. Slaves are never deleted, because + // we need to maintain e.g. failure state and connection state. + private val slaves = new HashMap[String, Slave] /** - * The total number of executors we aim to have. Undefined when not using dynamic allocation - * and before the ExecutorAllocatorManager calls [[doRequestTotalExecutors]]. + * The total number of executors we aim to have. Undefined when not using dynamic allocation. + * Initially set to 0 when using dynamic allocation, the executor allocation manager will send + * the real initial limit later. */ - private var executorLimitOption: Option[Int] = None + private var executorLimitOption: Option[Int] = { + if (Utils.isDynamicAllocationEnabled(conf)) { + Some(0) + } else { + None + } + } /** * Return the current executor limit, which may be [[Int.MaxValue]] @@ -89,39 +99,46 @@ private[spark] class CoarseMesosSchedulerBackend( */ private[mesos] def executorLimit: Int = executorLimitOption.getOrElse(Int.MaxValue) - private val pendingRemovedSlaveIds = new HashSet[String] - // private lock object protecting mutable state above. Using the intrinsic lock // may lead to deadlocks since the superclass might also try to lock private val stateLock = new ReentrantLock - val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + val extraCoresPerExecutor = conf.getInt("spark.mesos.extra.cores", 0) // Offer constraints private val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) - // A client for talking to the external shuffle service, if it is a + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + + // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { - Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf), - securityManager, - securityManager.isAuthenticationEnabled(), - securityManager.isSaslEncryptionEnabled())) + Some(getShuffleClient()) } else { None } } + // This method is factored out for testability + protected def getShuffleClient(): MesosExternalShuffleClient = { + new MesosExternalShuffleClient( + SparkTransportConf.fromSparkConf(conf, "shuffle"), + securityManager, + securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled()) + } + var nextMesosTaskId = 0 @volatile var appId: String = _ - def newMesosTaskId(): Int = { + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 - id + id.toString } override def start() { @@ -132,11 +149,12 @@ private[spark] class CoarseMesosSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.ui.map(_.appUIAddress)) + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)) + ) startScheduler(driver) } - def createCommand(offer: Offer, numCores: Int, taskId: Int): CommandInfo = { + def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val executorSparkHome = conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) .getOrElse { @@ -175,12 +193,12 @@ private[spark] class CoarseMesosSchedulerBackend( .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) if (uri.isEmpty) { - val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath + val runScript = new File(executorSparkHome, "./bin/spark-class").getPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" .format(prefixEnv, runScript) + s" --driver-url $driverURL" + - s" --executor-id ${offer.getSlaveId.getValue}" + + s" --executor-id $taskId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -188,12 +206,11 @@ private[spark] class CoarseMesosSchedulerBackend( // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.get.split('/').last.split('.').head - val executorId = sparkExecutorId(offer.getSlaveId.getValue, taskId.toString) command.setValue( s"cd $basename*; $prefixEnv " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + s" --driver-url $driverURL" + - s" --executor-id $executorId" + + s" --executor-id $taskId" + s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") @@ -211,10 +228,10 @@ private[spark] class CoarseMesosSchedulerBackend( if (conf.contains("spark.testing")) { "driverURL" } else { - sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString } } @@ -241,105 +258,221 @@ private[spark] class CoarseMesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { - val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers.asScala) { + if (stopCalled) { + logDebug("Ignoring offers during shutdown") + // Driver should simply return a stopped status on race + // condition between this.stop() and completing here + offers.asScala.map(_.getId).foreach(d.declineOffer) + return + } + + logDebug(s"Received ${offers.size} resource offers.") + + val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => val offerAttributes = toAttributeMap(offer.getAttributesList) - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + } + + declineUnmatchedOffers(d, unmatchedOffers) + handleMatchedOffers(d, matchedOffers) + } + } + + private def declineUnmatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + for (offer <- offers) { + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + val filters = Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build() + + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" + + s" for $rejectOfferDurationForUnmetConstraints seconds") + + d.declineOffer(offer.getId, filters) + } + } + + /** + * Launches executors on accepted offers, and declines unused offers. Executors are launched + * round-robin on offers. + * + * @param d SchedulerDriver + * @param offers Mesos offers that match attribute constraints + */ + private def handleMatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + val tasks = buildMesosTasks(offers) + for (offer <- offers) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val offerMem = getResource(offer.getResourcesList, "mem") + val offerCpus = getResource(offer.getResourcesList, "cpus") + val id = offer.getId.getValue + + if (tasks.contains(offer.getId)) { // accept + val offerTasks = tasks(offer.getId) + + logDebug(s"Accepting offer: $id with attributes: $offerAttributes " + + s"mem: $offerMem cpu: $offerCpus. Launching ${offerTasks.size} Mesos tasks.") + + for (task <- offerTasks) { + val taskId = task.getTaskId + val mem = getResource(task.getResourcesList, "mem") + val cpus = getResource(task.getResourcesList, "cpus") + + logDebug(s"Launching Mesos task: ${taskId.getValue} with mem: $mem cpu: $cpus.") + } + + d.launchTasks( + Collections.singleton(offer.getId), + offerTasks.asJava) + } else { // decline + logDebug(s"Declining offer: $id with attributes: $offerAttributes " + + s"mem: $offerMem cpu: $offerCpus") + + d.declineOffer(offer.getId) + } + } + } + + /** + * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize + * per-task memory and IO, tasks are round-robin assigned to offers. + * + * @param offers Mesos offers that match attribute constraints + * @return A map from OfferID to a list of Mesos tasks to launch on that offer + */ + private def buildMesosTasks(offers: Buffer[Offer]): Map[OfferID, List[MesosTaskInfo]] = { + // offerID -> tasks + val tasks = new HashMap[OfferID, List[MesosTaskInfo]].withDefaultValue(Nil) + + // offerID -> resources + val remainingResources = mutable.Map(offers.map(offer => + (offer.getId.getValue, offer.getResourcesList)): _*) + + var launchTasks = true + + // TODO(mgummelt): combine offers for a single slave + // + // round-robin create executors on the available offers + while (launchTasks) { + launchTasks = false + + for (offer <- offers) { val slaveId = offer.getSlaveId.getValue - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus").toInt - val id = offer.getId.getValue - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - meetsConstraints && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - totalCoresAcquired += cpusToUse + val offerId = offer.getId.getValue + val resources = remainingResources(offerId) + + if (canLaunchTask(slaveId, resources)) { + // Create a task + launchTasks = true val taskId = newMesosTaskId() - taskIdToSlaveId.put(taskId, slaveId) - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - // Gather cpu resources from the available resources and use them in the task. - val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.getResourcesList, "cpus", cpusToUse) - val (_, memResourcesToUse) = - partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) + val offerCPUs = getResource(resources, "cpus").toInt + + val taskCPUs = executorCores(offerCPUs) + val taskMemory = executorMemory(sc) + + slaves.getOrElseUpdate(slaveId, new Slave(offer.getHostname)).taskIDs.add(taskId) + + val (afterCPUResources, cpuResourcesToUse) = + partitionResources(resources, "cpus", taskCPUs) + val (resourcesLeft, memResourcesToUse) = + partitionResources(afterCPUResources.asJava, "mem", taskMemory) + val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) + .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) .setName("Task " + taskId) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder) } - // accept the offer and launch the task - logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname - d.launchTasks( - Collections.singleton(offer.getId), - Collections.singleton(taskBuilder.build()), filters) - } else { - // Decline the offer - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.declineOffer(offer.getId) + tasks(offer.getId) ::= taskBuilder.build() + remainingResources(offerId) = resourcesLeft.asJava + totalCoresAcquired += taskCPUs + coresByTaskId(taskId) = taskCPUs } } } + tasks.toMap + } + + private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { + val offerMem = getResource(resources, "mem") + val offerCPUs = getResource(resources, "cpus").toInt + val cpus = executorCores(offerCPUs) + val mem = executorMemory(sc) + + cpus > 0 && + cpus <= offerCPUs && + cpus + totalCoresAcquired <= maxCores && + mem <= offerMem && + numExecutors() < executorLimit && + slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES } + private def executorCores(offerCPUs: Int): Int = { + sc.conf.getInt("spark.executor.cores", + math.min(offerCPUs, maxCores - totalCoresAcquired)) + } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val taskId = status.getTaskId.getValue.toInt - val state = status.getState - logInfo(s"Mesos task $taskId is now $state") - val slaveId: String = status.getSlaveId.getValue + val taskId = status.getTaskId.getValue + val slaveId = status.getSlaveId.getValue + val state = TaskState.fromMesos(status.getState) + + logInfo(s"Mesos task $taskId is now ${status.getState}") + stateLock.synchronized { + val slave = slaves(slaveId) + // If the shuffle service is enabled, have the driver register with each one of the // shuffle services. This allows the shuffle services to clean up state associated with // this application when the driver exits. There is currently not a great way to detect // this through Mesos, since the shuffle services are set up independently. - if (TaskState.fromMesos(state).equals(TaskState.RUNNING) && - slaveIdToHost.contains(slaveId) && - shuffleServiceEnabled) { + if (state.equals(TaskState.RUNNING) && + shuffleServiceEnabled && + !slave.shuffleRegistered) { assume(mesosExternalShuffleClient.isDefined, "External shuffle client was not instantiated even though shuffle service is enabled.") // TODO: Remove this and allow the MesosExternalShuffleService to detect // framework termination when new Mesos Framework HTTP API is available. val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337) - val hostname = slaveIdToHost.remove(slaveId).get + logDebug(s"Connecting to shuffle service on slave $slaveId, " + - s"host $hostname, port $externalShufflePort for app ${conf.getAppId}") + s"host ${slave.hostname}, port $externalShufflePort for app ${conf.getAppId}") + mesosExternalShuffleClient.get - .registerDriverWithShuffleService(hostname, externalShufflePort) + .registerDriverWithShuffleService( + slave.hostname, + externalShufflePort, + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"), + sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) + slave.shuffleRegistered = true } - if (TaskState.isFinished(TaskState.fromMesos(state))) { - val slaveId = taskIdToSlaveId.get(taskId) - slaveIdsWithExecutors -= slaveId - taskIdToSlaveId.remove(taskId) + if (TaskState.isFinished(state)) { // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores coresByTaskId -= taskId } // If it was a failure, mark the slave as failed for blacklisting purposes - if (TaskState.isFailed(TaskState.fromMesos(state))) { - failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 - if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { + if (TaskState.isFailed(state)) { + slave.taskFailures += 1 + + if (slave.taskFailures >= MAX_SLAVE_FAILURES) { logInfo(s"Blacklisting Mesos slave $slaveId due to too many failures; " + "is Spark installed on it?") } } - executorTerminated(d, slaveId, s"Executor finished with state $state") + executorTerminated(d, slaveId, taskId, s"Executor finished with state $state") // In case we'd rejected everything before but have now lost a node d.reviveOffers() } @@ -352,7 +485,35 @@ private[spark] class CoarseMesosSchedulerBackend( } override def stop() { - super.stop() + // Make sure we're not launching tasks during shutdown + stateLock.synchronized { + if (stopCalled) { + logWarning("Stop called multiple times, ignoring") + return + } + stopCalled = true + super.stop() + } + + // Wait for executors to report done, or else mesosDriver.stop() will forcefully kill them. + // See SPARK-12330 + val startTime = System.nanoTime() + + // slaveIdsWithExecutors has no memory barrier, so this is eventually consistent + while (numExecutors() > 0 && + System.nanoTime() - startTime < shutdownTimeoutMS * 1000L * 1000L) { + Thread.sleep(100) + } + + if (numExecutors() > 0) { + logWarning(s"Timed out waiting for ${numExecutors()} remaining executors " + + s"to terminate within $shutdownTimeoutMS ms. This may leave temporary files " + + "on the mesos nodes.") + } + + // Close the mesos external shuffle client if used + mesosExternalShuffleClient.foreach(_.close()) + if (mesosDriver != null) { mesosDriver.stop() } @@ -361,40 +522,26 @@ private[spark] class CoarseMesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} /** - * Called when a slave is lost or a Mesos task finished. Update local view on - * what tasks are running and remove the terminated slave from the list of pending - * slave IDs that we might have asked to be killed. It also notifies the driver - * that an executor was removed. + * Called when a slave is lost or a Mesos task finished. Updates local view on + * what tasks are running. It also notifies the driver that an executor was removed. */ - private def executorTerminated(d: SchedulerDriver, slaveId: String, reason: String): Unit = { + private def executorTerminated( + d: SchedulerDriver, + slaveId: String, + taskId: String, + reason: String): Unit = { stateLock.synchronized { - if (slaveIdsWithExecutors.contains(slaveId)) { - val slaveIdToTaskId = taskIdToSlaveId.inverse() - if (slaveIdToTaskId.containsKey(slaveId)) { - val taskId: Int = slaveIdToTaskId.get(slaveId) - taskIdToSlaveId.remove(taskId) - removeExecutor(sparkExecutorId(slaveId, taskId.toString), SlaveLost(reason)) - } - // TODO: This assumes one Spark executor per Mesos slave, - // which may no longer be true after SPARK-5095 - pendingRemovedSlaveIds -= slaveId - slaveIdsWithExecutors -= slaveId - } + removeExecutor(taskId, SlaveLost(reason)) + slaves(slaveId).taskIDs.remove(taskId) } } - private def sparkExecutorId(slaveId: String, taskId: String): String = { - s"$slaveId/$taskId" - } - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { logInfo(s"Mesos slave lost: ${slaveId.getValue}") - executorTerminated(d, slaveId.getValue, "Mesos slave lost: " + slaveId.getValue) } override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { - logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) - slaveLost(d, s) + logInfo("Mesos executor lost: %s".format(e.getValue)) } override def applicationId(): String = @@ -414,23 +561,26 @@ private[spark] class CoarseMesosSchedulerBackend( override def doKillExecutors(executorIds: Seq[String]): Boolean = { if (mesosDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") - return false - } - - val slaveIdToTaskId = taskIdToSlaveId.inverse() - for (executorId <- executorIds) { - val slaveId = executorId.split("/")(0) - if (slaveIdToTaskId.containsKey(slaveId)) { - mesosDriver.killTask( - TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) - pendingRemovedSlaveIds += slaveId - } else { - logWarning("Unable to find executor Id '" + executorId + "' in Mesos scheduler") + false + } else { + for (executorId <- executorIds) { + val taskId = TaskID.newBuilder().setValue(executorId).build() + mesosDriver.killTask(taskId) } + // no need to adjust `executorLimitOption` since the AllocationManager already communicated + // the desired limit through a call to `doRequestTotalExecutors`. + // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] + true } - // no need to adjust `executorLimitOption` since the AllocationManager already communicated - // the desired limit through a call to `doRequestTotalExecutors`. - // See [[o.a.s.scheduler.cluster.CoarseGrainedSchedulerBackend.killExecutors]] - true } + + private def numExecutors(): Int = { + slaves.values.map(_.taskIDs.size).sum + } +} + +private class Slave(val hostname: String) { + val taskIDs = new HashSet[String]() + var taskFailures = 0 + var shuffleRegistered = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index e0c547dce6d07..61ab3e87c5711 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -23,8 +23,9 @@ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.zookeeper.KeeperException.NoNodeException -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** @@ -53,9 +54,9 @@ private[spark] trait MesosClusterPersistenceEngine { * all of them reuses the same connection pool. */ private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) - extends MesosClusterPersistenceEngineFactory(conf) { + extends MesosClusterPersistenceEngineFactory(conf) with Logging { - lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + lazy val zk = SparkCuratorUtil.newClient(conf) def createEngine(path: String): MesosClusterPersistenceEngine = { new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) @@ -120,11 +121,10 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine( Some(Utils.deserialize[T](fileData)) } catch { case e: NoNodeException => None - case e: Exception => { + case e: Exception => logWarning("Exception while reading persisted file, deleting", e) zk.delete().forPath(zkPath) None - } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index a6d9374eb9e8c..73bd4c58e16fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -18,23 +18,22 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, Date, List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.mesos.{Scheduler, SchedulerDriver} +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason -import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} -import org.apache.mesos.{Scheduler, SchedulerDriver} + +import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils -import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} - /** * Tracks the current state of a Mesos Task that runs a Spark driver. @@ -126,7 +125,7 @@ private[spark] class MesosClusterScheduler( private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute private val schedulerState = engineFactory.createEngine("scheduler") - private val stateLock = new ReentrantLock() + private val stateLock = new Object() private val finishedDrivers = new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) private var frameworkId: String = null @@ -358,9 +357,10 @@ private[spark] class MesosClusterScheduler( val appJar = CommandInfo.URI.newBuilder() .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build() val builder = CommandInfo.newBuilder().addUris(appJar) - val entries = - (conf.getOption("spark.executor.extraLibraryPath").toList ++ - desc.command.libraryPathEntries) + val entries = conf.getOption("spark.executor.extraLibraryPath") + .map(path => Seq(path) ++ desc.command.libraryPathEntries) + .getOrElse(desc.command.libraryPathEntries) + val prefixEnv = if (!entries.isEmpty) { Utils.libraryPathEnvPrefix(entries) } else { @@ -395,7 +395,7 @@ private[spark] class MesosClusterScheduler( .getOrElse { throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } - val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getPath // Sandbox points to the current directory by default with Mesos. (cmdExecutable, ".") } @@ -423,6 +423,12 @@ private[spark] class MesosClusterScheduler( "--driver-cores", desc.cores.toString, "--driver-memory", s"${desc.mem}M") + val replicatedOptionsBlacklist = Set( + "spark.jars", // Avoids duplicate classes in classpath + "spark.submit.deployMode", // this would be set to `cluster`, but we need client + "spark.master" // this contains the address of the dispatcher, not master + ) + // Assume empty main class means we're running python if (!desc.command.mainClass.equals("")) { options ++= Seq("--class", desc.command.mainClass) @@ -440,12 +446,35 @@ private[spark] class MesosClusterScheduler( .mkString(",") options ++= Seq("--py-files", formattedFiles) } + desc.schedulerProperties + .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } + .foreach { case (key, value) => options ++= Seq("--conf", s"$key=${shellEscape(value)}") } options } - private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) { + /** + * Escape args for Unix-like shells, unless already quoted by the user. + * Based on: http://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html + * and http://www.grymoire.com/Unix/Quote.html + * @param value argument + * @return escaped argument + */ + private[scheduler] def shellEscape(value: String): String = { + val WrappedInQuotes = """^(".+"|'.+')$""".r + val ShellSpecialChars = (""".*([ '<>&|\?\*;!#\\(\)"$`]).*""").r + value match { + case WrappedInQuotes(c) => value // The user quoted his args, don't touch it! + case ShellSpecialChars(c) => "\"" + value.replaceAll("""(["`\$\\])""", """\\$1""") + "\"" + case _: String => value // Don't touch harmless strings + } + } + + private class ResourceOffer( + val offerId: OfferID, + val slaveId: SlaveID, + var resources: JList[Resource]) { override def toString(): String = { - s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem" + s"Offer id: ${offerId}, resources: ${resources}" } } @@ -464,27 +493,29 @@ private[spark] class MesosClusterScheduler( val driverMem = submission.mem logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") val offerOption = currentOffers.find { o => - o.cpu >= driverCpu && o.mem >= driverMem + getResource(o.resources, "cpus") >= driverCpu && + getResource(o.resources, "mem") >= driverMem } if (offerOption.isEmpty) { logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + s"cpu: $driverCpu, mem: $driverMem") } else { val offer = offerOption.get - offer.cpu -= driverCpu - offer.mem -= driverMem val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() - val cpuResource = createResource("cpus", driverCpu) - val memResource = createResource("mem", driverMem) + val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.resources, "cpus", driverCpu) + val (finalResources, memResourcesToUse) = + partitionResources(remainingResources.asJava, "mem", driverMem) val commandInfo = buildDriverCommand(submission) val appName = submission.schedulerProperties("spark.app.name") val taskInfo = TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for $appName") - .setSlaveId(offer.offer.getSlaveId) + .setSlaveId(offer.slaveId) .setCommand(commandInfo) - .addResources(cpuResource) - .addResources(memResource) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + offer.resources = finalResources.asJava submission.schedulerProperties.get("spark.mesos.executor.docker.image").foreach { image => val container = taskInfo.getContainerBuilder() val volumes = submission.schedulerProperties @@ -497,11 +528,11 @@ private[spark] class MesosClusterScheduler( container, image, volumes = volumes, portmaps = portmaps) taskInfo.setContainer(container.build()) } - val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) + val queuedTasks = tasks.getOrElseUpdate(offer.offerId, new ArrayBuffer[TaskInfo]) queuedTasks += taskInfo.build() - logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + + logTrace(s"Using offer ${offer.offerId.getValue} to launch driver " + submission.submissionId) - val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, + val newState = new MesosClusterSubmissionState(submission, taskId, offer.slaveId, None, new Date(), None) launchedDrivers(submission.submissionId) = newState launchedDriversState.persist(submission.submissionId, newState) @@ -511,14 +542,14 @@ private[spark] class MesosClusterScheduler( } override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { - val currentOffers = offers.asScala.map(o => - new ResourceOffer( - o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) - ).toList - logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") + logTrace(s"Received offers from Mesos: \n${offers.asScala.mkString("\n")}") val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() val currentTime = new Date() + val currentOffers = offers.asScala.map { + o => new ResourceOffer(o.getId, o.getSlaveId, o.getResourcesList) + }.toList + stateLock.synchronized { // We first schedule all the supervised drivers that are ready to retry. // This list will be empty if none of the drivers are marked as supervise. @@ -542,9 +573,10 @@ private[spark] class MesosClusterScheduler( tasks.foreach { case (offerId, taskInfos) => driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - offers.asScala - .filter(o => !tasks.keySet.contains(o.getId)) - .foreach(o => driver.declineOffer(o.getId)) + + for (o <- currentOffers if !tasks.contains(o.offerId)) { + driver.declineOffer(o.offerId) + } } private def copyBuffer( @@ -574,6 +606,7 @@ private[spark] class MesosClusterScheduler( override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} override def error(driver: SchedulerDriver, error: String): Unit = { logError("Error received: " + error) + markErr() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index aaffac604a885..1a94aee2ca30c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString + import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ @@ -63,6 +64,10 @@ private[spark] class MesosSchedulerBackend( private[this] val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + @volatile var appId: String = _ override def start() { @@ -73,7 +78,8 @@ private[spark] class MesosSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.ui.map(_.appUIAddress)) + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)) + ) startScheduler(driver) } @@ -120,7 +126,7 @@ private[spark] class MesosSchedulerBackend( val executorBackendName = classOf[MesosExecutorBackend].getName if (uri.isEmpty) { - val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath + val executorPath = new File(executorSparkHome, "/bin/spark-class").getPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to @@ -133,7 +139,7 @@ private[spark] class MesosSchedulerBackend( val (resourcesAfterCpu, usedCpuResources) = partitionResources(availableResources, "cpus", mesosExecutorCores) val (resourcesAfterMem, usedMemResources) = - partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) + partitionResources(resourcesAfterCpu.asJava, "mem", executorMemory(sc)) builder.addAllResources(usedCpuResources.asJava) builder.addAllResources(usedMemResources.asJava) @@ -212,29 +218,47 @@ private[spark] class MesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { - // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.asScala.partition { o => + // Fail first on offers with unmet constraints + val (offersMatchingConstraints, offersNotMatchingConstraints) = + offers.asScala.partition { o => + val offerAttributes = toAttributeMap(o.getAttributesList) + val meetsConstraints = + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + + // add some debug messaging + if (!meetsConstraints) { + val id = o.getId.getValue + logDebug(s"Declining offer: $id with attributes: $offerAttributes") + } + + meetsConstraints + } + + // These offers do not meet constraints. We don't need to see them again. + // Decline the offer for a long period of time. + offersNotMatchingConstraints.foreach { o => + d.declineOffer(o.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) + } + + // Of the matching constraints, see which ones give us enough memory and cores + val (usableOffers, unUsableOffers) = offersMatchingConstraints.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check if all constraints are satisfield - // 1. Attribute constraints - // 2. Memory requirements - // 3. CPU requirements - need at least 1 for executor, 1 for task - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) - val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + // check offers for + // 1. Memory requirements + // 2. CPU requirements - need at least 1 for executor, 1 for task + val meetsMemoryRequirements = mem >= executorMemory(sc) val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) - val meetsRequirements = - (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (meetsMemoryRequirements && meetsCPURequirements) || (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) - - // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" - val id = o.getId.getValue - logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + logDebug(s"$debugstr offer: ${o.getId.getValue} with attributes: " + + s"$offerAttributes mem: $mem cpu: $cpus") meetsRequirements } @@ -352,6 +376,7 @@ private[spark] class MesosSchedulerBackend( override def error(d: SchedulerDriver, message: String) { inClassLoader() { logError("Mesos error: " + message) + markErr() scheduler.error(message) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index e79c543a9de27..1b7ac172defb9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -20,7 +20,8 @@ package org.apache.spark.scheduler.cluster.mesos import org.apache.mesos.Protos.{ContainerInfo, Volume} import org.apache.mesos.Protos.ContainerInfo.DockerInfo -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging /** * A collection of utility functions which can be used by both the @@ -54,11 +55,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { Some(vol.setContainerPath(container_path) .setHostPath(host_path) .setMode(Volume.Mode.RO)) - case spec => { + case spec => logWarning(s"Unable to parse volume specs: $volumes. " + "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"") None - } } } .map { _.build() } @@ -89,11 +89,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { Some(portmap.setHostPort(host_port.toInt) .setContainerPort(container_port.toInt) .setProtocol(protocol)) - case spec => { + case spec => logWarning(s"Unable to parse port mapping specs: $portmaps. " + "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"") None - } } } .map { _.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 860c8e097b3b9..1e322ac679419 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -25,12 +25,13 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter -import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} -import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext} -import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper @@ -106,44 +107,56 @@ private[mesos] trait MesosSchedulerUtils extends Logging { registerLatch.await() return } + @volatile + var error: Option[Exception] = None + // We create a new thread that will block inside `mesosDriver.run` + // until the scheduler exists new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { setDaemon(true) - override def run() { - mesosDriver = newDriver try { + mesosDriver = newDriver val ret = mesosDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { - System.exit(1) + error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) + markErr() } } catch { - case e: Exception => { + case e: Exception => logError("driver.run() failed", e) - System.exit(1) - } + error = Some(e) + markErr() } } }.start() registerLatch.await() + + // propagate any error to the calling thread. This ensures that SparkContext creation fails + // without leaving a broken context that won't be able to schedule any tasks + error.foreach(throw _) } } - /** - * Signal that the scheduler has registered with Mesos. - */ - protected def getResource(res: JList[Resource], name: String): Double = { + def getResource(res: JList[Resource], name: String): Double = { // A resource can have multiple values in the offer since it can either be from // a specific role or wildcard. res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum } + /** + * Signal that the scheduler has registered with Mesos. + */ protected def markRegistered(): Unit = { registerLatch.countDown() } + protected def markErr(): Unit = { + registerLatch.countDown() + } + def createResource(name: String, amount: Double, role: Option[String] = None): Resource = { val builder = Resource.newBuilder() .setName(name) @@ -170,7 +183,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { var remain = amountToUse var requestedResources = new ArrayBuffer[Resource] val remainingResources = resources.asScala.map { - case r => { + case r => if (remain > 0 && r.getType == Value.Type.SCALAR && r.getScalar.getValue > 0.0 && @@ -182,7 +195,6 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } else { r } - } } // Filter any resource that has depleted. @@ -214,7 +226,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return */ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { - offerAttributes.asScala.map(attr => { + offerAttributes.asScala.map { attr => val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar case Value.Type.RANGES => attr.getRanges @@ -222,7 +234,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { case Value.Type.TEXT => attr.getText } (attr.getName, attrValue) - }).toMap + }.toMap } @@ -269,11 +281,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for * multiple values (comma separated). For example: * {{{ - * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * parseConstraintString("os:centos7;zone:us-east-1a,us-east-1b") * // would result in * * Map( - * "tachyon" -> Set("true"), + * "os" -> Set("centos7"), * "zone": -> Set("us-east-1a", "us-east-1b") * ) * }}} @@ -324,7 +336,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM * (whichever is larger) */ - def calculateTotalMemory(sc: SparkContext): Int = { + def executorMemory(sc: SparkContext): Int = { sc.conf.getInt("spark.mesos.executor.memoryOverhead", math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + sc.executorMemory @@ -336,4 +348,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } } + protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { + sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala index 5e7e6567a3e06..8370b61145e45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.mesos.protobuf.ByteString -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * Wrapper for serializing the data sent when launching Mesos tasks. diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index c633d860ae6e5..3473ef21b39a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -21,9 +21,10 @@ import java.io.File import java.net.URL import java.nio.ByteBuffer -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.internal.Logging import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 62f8aae7f2126..d17a7894fd8a8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import scala.collection.mutable @@ -29,7 +30,7 @@ import org.apache.avro.generic.{GenericData, GenericRecord} import org.apache.avro.io._ import org.apache.commons.io.IOUtils -import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec /** @@ -71,7 +72,7 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, { val bos = new ByteArrayOutputStream() val out = codec.compressedOutputStream(bos) - out.write(schema.toString.getBytes("UTF-8")) + out.write(schema.toString.getBytes(StandardCharsets.UTF_8)) out.close() bos.toByteArray }) @@ -81,9 +82,12 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) * seen values so to limit the number of times that decompression has to be done. */ def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { - val bis = new ByteArrayInputStream(schemaBytes.array()) + val bis = new ByteArrayInputStream( + schemaBytes.array(), + schemaBytes.arrayOffset() + schemaBytes.position(), + schemaBytes.remaining()) val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) - new Schema.Parser().parse(new String(bytes, "UTF-8")) + new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8)) }) /** diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index b463a71d5bd7d..8b72da2ee01b7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,8 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -69,7 +68,7 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa // scalastyle:on classforname } catch { case e: ClassNotFoundException => - JavaDeserializationStream.primitiveMappings.get(desc.getName).getOrElse(throw e) + JavaDeserializationStream.primitiveMappings.getOrElse(desc.getName, throw e) } } @@ -96,11 +95,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() + val bos = new ByteBufferOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteBuffer } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index bc51d4f2820c8..918ae376f6286 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{EOFException, IOException, InputStream, OutputStream} +import java.io._ import java.nio.ByteBuffer import javax.annotation.Nullable @@ -25,20 +25,21 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import com.esotericsoftware.kryo.{Kryo, KryoException} +import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} +import org.roaringbitmap.RoaringBitmap import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} -import org.apache.spark.util.collection.{BitSet, CompactBuffer} +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} +import org.apache.spark.util.collection.CompactBuffer /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. @@ -69,7 +70,9 @@ class KryoSerializer(conf: SparkConf) private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) - private val userRegistrator = conf.getOption("spark.kryo.registrator") + private val userRegistrators = conf.get("spark.kryo.registrator", "") + .split(',') + .filter(!_.isEmpty) private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) @@ -93,6 +96,9 @@ class KryoSerializer(conf: SparkConf) for (cls <- KryoSerializer.toRegister) { kryo.register(cls) } + for ((cls, ser) <- KryoSerializer.toRegisterSerializer) { + kryo.register(cls, ser) + } // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) @@ -101,7 +107,6 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) @@ -115,7 +120,7 @@ class KryoSerializer(conf: SparkConf) classesToRegister .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. - userRegistrator + userRegistrators .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } // scalastyle:on classforname @@ -303,7 +308,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -315,7 +320,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) @@ -352,7 +357,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ * serialization. */ trait KryoRegistrator { - def registerClasses(kryo: Kryo) + def registerClasses(kryo: Kryo): Unit } private[serializer] object KryoSerializer { @@ -362,7 +367,6 @@ private[serializer] object KryoSerializer { classOf[StorageLevel], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], - classOf[BitSet], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], @@ -371,6 +375,72 @@ private[serializer] object KryoSerializer { classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) + + private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]]( + classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() { + override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = { + bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output)) + } + override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = { + val ret = new RoaringBitmap + ret.deserialize(new KryoInputObjectInputBridge(kryo, input)) + ret + } + } + ) +} + +/** + * This is a bridge class to wrap KryoInput as an InputStream and ObjectInput. It forwards all + * methods of InputStream and ObjectInput to KryoInput. It's usually helpful when an API expects + * an InputStream or ObjectInput but you want to use Kryo. + */ +private[spark] class KryoInputObjectInputBridge( + kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput { + override def readLong(): Long = input.readLong() + override def readChar(): Char = input.readChar() + override def readFloat(): Float = input.readFloat() + override def readByte(): Byte = input.readByte() + override def readShort(): Short = input.readShort() + override def readUTF(): String = input.readString() // readString in kryo does utf8 + override def readInt(): Int = input.readInt() + override def readUnsignedShort(): Int = input.readShortUnsigned() + override def skipBytes(n: Int): Int = { + input.skip(n) + n + } + override def readFully(b: Array[Byte]): Unit = input.read(b) + override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len) + override def readLine(): String = throw new UnsupportedOperationException("readLine") + override def readBoolean(): Boolean = input.readBoolean() + override def readUnsignedByte(): Int = input.readByteUnsigned() + override def readDouble(): Double = input.readDouble() + override def readObject(): AnyRef = kryo.readClassAndObject(input) +} + +/** + * This is a bridge class to wrap KryoOutput as an OutputStream and ObjectOutput. It forwards all + * methods of OutputStream and ObjectOutput to KryoOutput. It's usually helpful when an API expects + * an OutputStream or ObjectOutput but you want to use Kryo. + */ +private[spark] class KryoOutputObjectOutputBridge( + kryo: Kryo, output: KryoOutput) extends FilterOutputStream(output) with ObjectOutput { + override def writeFloat(v: Float): Unit = output.writeFloat(v) + // There is no "readChars" counterpart, except maybe "readLine", which is not supported + override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars") + override def writeDouble(v: Double): Unit = output.writeDouble(v) + override def writeUTF(s: String): Unit = output.writeString(s) // writeString in kryo does UTF8 + override def writeShort(v: Int): Unit = output.writeShort(v) + override def writeInt(v: Int): Unit = output.writeInt(v) + override def writeBoolean(v: Boolean): Unit = output.writeBoolean(v) + override def write(b: Int): Unit = output.write(b) + override def write(b: Array[Byte]): Unit = output.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = output.write(b, off, len) + override def writeBytes(s: String): Unit = output.writeString(s) + override def writeChar(v: Int): Unit = output.writeChar(v.toChar) + override def writeLong(v: Long): Unit = output.writeLong(v) + override def writeByte(v: Int): Unit = output.writeByte(v) + override def writeObject(obj: AnyRef): Unit = kryo.writeClassAndObject(output, obj) } /** diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index a1b1e1631eafb..8daca6c390635 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -25,7 +25,7 @@ import scala.annotation.tailrec import scala.collection.mutable import scala.util.control.NonFatal -import org.apache.spark.Logging +import org.apache.spark.internal.Logging private[spark] object SerializationDebugger extends Logging { @@ -53,12 +53,13 @@ private[spark] object SerializationDebugger extends Logging { /** * Find the path leading to a not serializable object. This method is modeled after OpenJDK's * serialization mechanism, and handles the following cases: - * - primitives - * - arrays of primitives - * - arrays of non-primitive objects - * - Serializable objects - * - Externalizable objects - * - writeReplace + * + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index bd2704dc81871..cb95246d5b0ca 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -23,9 +23,9 @@ import javax.annotation.concurrent.NotThreadSafe import scala.reflect.ClassTag -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.SparkEnv import org.apache.spark.annotation.{DeveloperApi, Private} -import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} +import org.apache.spark.util.NextIterator /** * :: DeveloperApi :: @@ -100,18 +100,6 @@ abstract class Serializer { } -@DeveloperApi -object Serializer { - def getSerializer(serializer: Serializer): Serializer = { - if (serializer == null) SparkEnv.get.serializer else serializer - } - - def getSerializer(serializer: Option[Serializer]): Serializer = { - serializer.getOrElse(SparkEnv.get.serializer) - } -} - - /** * :: DeveloperApi :: * An instance of a serializer, for use by one thread at a time. @@ -200,10 +188,9 @@ abstract class DeserializationStream { try { (readKey[Any](), readValue[Any]()) } catch { - case eof: EOFException => { + case eof: EOFException => finished = true null - } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala new file mode 100644 index 0000000000000..745ef126913f5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{BufferedInputStream, BufferedOutputStream, InputStream, OutputStream} +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf +import org.apache.spark.io.CompressionCodec +import org.apache.spark.storage._ +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} + +/** + * Component which configures serialization and compression for various Spark components, including + * automatic selection of which [[Serializer]] to use for shuffles. + */ +private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) { + + private[this] val kryoSerializer = new KryoSerializer(conf) + + private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]] + private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = { + val primitiveClassTags = Set[ClassTag[_]]( + ClassTag.Boolean, + ClassTag.Byte, + ClassTag.Char, + ClassTag.Double, + ClassTag.Float, + ClassTag.Int, + ClassTag.Long, + ClassTag.Null, + ClassTag.Short + ) + val arrayClassTags = primitiveClassTags.map(_.wrap) + primitiveClassTags ++ arrayClassTags + } + + // Whether to compress broadcast variables that are stored + private[this] val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) + // Whether to compress shuffle output that are stored + private[this] val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) + // Whether to compress RDD partitions that are stored serialized + private[this] val compressRdds = conf.getBoolean("spark.rdd.compress", false) + // Whether to compress shuffle output temporarily spilled to disk + private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) + + /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay + * the initialization of the compression codec until it is first used. The reason is that a Spark + * program could be using a user-defined codec in a third party jar, which is loaded in + * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been + * loaded yet. */ + private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) + + private def canUseKryo(ct: ClassTag[_]): Boolean = { + primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag + } + + def getSerializer(ct: ClassTag[_]): Serializer = { + if (canUseKryo(ct)) { + kryoSerializer + } else { + defaultSerializer + } + } + + /** + * Pick the best serializer for shuffling an RDD of key-value pairs. + */ + def getSerializer(keyClassTag: ClassTag[_], valueClassTag: ClassTag[_]): Serializer = { + if (canUseKryo(keyClassTag) && canUseKryo(valueClassTag)) { + kryoSerializer + } else { + defaultSerializer + } + } + + private def shouldCompress(blockId: BlockId): Boolean = { + blockId match { + case _: ShuffleBlockId => compressShuffle + case _: BroadcastBlockId => compressBroadcast + case _: RDDBlockId => compressRdds + case _: TempLocalBlockId => compressShuffleSpill + case _: TempShuffleBlockId => compressShuffle + case _ => false + } + } + + /** + * Wrap an output stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s + } + + /** + * Wrap an input stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s + } + + /** Serializes into a stream. */ + def dataSerializeStream[T: ClassTag]( + blockId: BlockId, + outputStream: OutputStream, + values: Iterator[T]): Unit = { + val byteStream = new BufferedOutputStream(outputStream) + val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + } + + /** Serializes into a chunked byte buffer. */ + def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { + val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) + dataSerializeStream(blockId, bbos, values) + bbos.toChunkedByteBuffer + } + + /** + * Deserializes a InputStream into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserializeStream[T: ClassTag]( + blockId: BlockId, + inputStream: InputStream): Iterator[T] = { + val stream = new BufferedInputStream(inputStream) + getSerializer(implicitly[ClassTag[T]]) + .newInstance() + .deserializeStream(wrapForCompression(blockId, stream)) + .asIterator.asInstanceOf[Iterator[T]] + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala index b36c457d6d514..04e4cf88d7063 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala @@ -17,8 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner} -import org.apache.spark.serializer.Serializer +import org.apache.spark.ShuffleDependency /** * A basic ShuffleHandle implementation that just captures registerShuffle's parameters. diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index b0abda4a81b8d..876cdfaa87601 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -18,7 +18,8 @@ package org.apache.spark.shuffle import org.apache.spark._ -import org.apache.spark.serializer.Serializer +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -32,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { @@ -46,15 +48,15 @@ private[spark] class BlockStoreShuffleReader[K, C]( blockManager, mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) // Wrap the streams for compression based on configuration val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - blockManager.wrapForCompression(blockId, inputStream) + serializerManager.wrapForCompression(blockId, inputStream) } - val ser = Serializer.getSerializer(dep.serializer) - val serializerInstance = ser.newInstance() + val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream val recordIter = wrappedStreams.flatMap { wrappedStream => @@ -65,13 +67,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map(record => { + recordIter.map { record => readMetrics.incRecordsRead(1) record - }), - context.taskMetrics().updateShuffleReadMetrics()) + }, + context.taskMetrics().mergeShuffleReadMetrics()) // An interruptible iterator must be used here in order to support task cancellation val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) @@ -99,12 +101,11 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, // the ExternalSorter won't spill to disk. val sorter = - new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser)) + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index be184464e0ae9..b2d050b218f53 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,8 +17,8 @@ package org.apache.spark.shuffle -import org.apache.spark.storage.BlockManagerId import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index cd253a78c2b19..be1e84a2ba938 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,24 +17,25 @@ package org.apache.spark.shuffle -import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} import scala.collection.JavaConverters._ -import org.apache.spark.{Logging, SparkConf, SparkEnv} +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.util.Utils /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ - def releaseWriters(success: Boolean) + def releaseWriters(success: Boolean): Unit } /** @@ -46,7 +47,7 @@ private[spark] trait ShuffleWriterGroup { private[spark] class FileShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver with Logging { - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") private lazy val blockManager = SparkEnv.get.blockManager @@ -63,10 +64,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val completedMapTasks = new ConcurrentLinkedQueue[Int]() } - private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] - - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) + private val shuffleStates = new ConcurrentHashMap[ShuffleId, ShuffleState] /** * Get a ShuffleWriterGroup for the given map task, which will register it as complete @@ -75,31 +73,25 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) - private val shuffleState = shuffleStates(shuffleId) - + private val shuffleState: ShuffleState = { + // Note: we do _not_ want to just wrap this java ConcurrentHashMap into a Scala map and use + // .getOrElseUpdate() because that's actually NOT atomic. + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) + shuffleStates.get(shuffleId) + } val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() val writers: Array[DiskBlockObjectWriter] = { Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) - // Because of previous failures, the shuffle file may already exist on this machine. - // If so, remove it. - if (blockFile.exists) { - if (blockFile.delete()) { - logInfo(s"Removed existing shuffle file $blockFile") - } else { - logWarning(s"Failed to remove existing shuffle file $blockFile") - } - } - blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize, - writeMetrics) + val tmp = Utils.tempFileWith(blockFile) + blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics) } } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, so should be included in the shuffle write time. - writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) + writeMetrics.incWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { shuffleState.completedMapTasks.add(mapId) @@ -123,7 +115,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) /** Remove all the blocks / files related to a particular shuffle. */ private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { - shuffleStates.get(shuffleId) match { + Option(shuffleStates.get(shuffleId)) match { case Some(state) => for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) @@ -140,11 +132,5 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) - } - - override def stop() { - metadataCleaner.cancel() - } + override def stop(): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 5e4c2b5d0a5c4..94d8c0d0fd3e4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -21,14 +21,14 @@ import java.io._ import com.google.common.io.ByteStreams -import org.apache.spark.{SparkConf, SparkEnv, Logging} +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils -import IndexShuffleBlockResolver.NOOP_REDUCE_ID - /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. * Data of shuffle blocks from the same map task are stored in a single consolidated data file. @@ -40,12 +40,15 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData(). -private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver +private[spark] class IndexShuffleBlockResolver( + conf: SparkConf, + _blockManager: BlockManager = null) + extends ShuffleBlockResolver with Logging { - private lazy val blockManager = SparkEnv.get.blockManager + private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) @@ -74,14 +77,69 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } } + /** + * Check whether the given index and data files match each other. + * If so, return the partition lengths in the data file. Otherwise return null. + */ + private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { + // the index file should have `block + 1` longs as offset. + if (index.length() != (blocks + 1) * 8) { + return null + } + val lengths = new Array[Long](blocks) + // Read the lengths of blocks + val in = try { + new DataInputStream(new BufferedInputStream(new FileInputStream(index))) + } catch { + case e: IOException => + return null + } + try { + // Convert the offsets into lengths of each block + var offset = in.readLong() + if (offset != 0L) { + return null + } + var i = 0 + while (i < blocks) { + val off = in.readLong() + lengths(i) = off - offset + offset = off + i += 1 + } + } catch { + case e: IOException => + return null + } finally { + in.close() + } + + // the size of data file should match with index file + if (data.length() == lengths.sum) { + lengths + } else { + null + } + } + /** * Write an index file with the offsets of each block, plus a final offset at the end for the * end of the output file. This will be used by getBlockData to figure out where each block * begins and ends. + * + * It will commit the data and index file as an atomic operation, use the existing ones, or + * replace them with new ones. + * + * Note: the `lengths` will be updated to match the existing index file if use the existing ones. * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { + def writeIndexFileAndCommit( + shuffleId: Int, + mapId: Int, + lengths: Array[Long], + dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + val indexTmp = Utils.tempFileWith(indexFile) + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. var offset = 0L @@ -93,6 +151,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } { out.close() } + + val dataFile = getDataFile(shuffleId, mapId) + // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure + // the following check and rename are atomic. + synchronized { + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) + if (existingLengths != null) { + // Another attempt for the same task has already written our map outputs successfully, + // so just use the existing partition lengths and delete our temporary map outputs. + System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + if (dataTmp != null && dataTmp.exists()) { + dataTmp.delete() + } + indexTmp.delete() + } else { + // This is the first successful attempt in writing the map outputs for this task, + // so override any existing index and data files with the ones we wrote. + if (indexFile.exists()) { + indexFile.delete() + } + if (dataFile.exists()) { + dataFile.delete() + } + if (!indexTmp.renameTo(indexFile)) { + throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) + } + if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { + throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + } + } + } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 4342b0d598b16..d1ecbc1bf0178 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -17,7 +17,6 @@ package org.apache.spark.shuffle -import java.nio.ByteBuffer import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 978366d1a1d1b..364fad664e3a0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{TaskContext, ShuffleDependency} +import org.apache.spark.{ShuffleDependency, TaskContext} /** * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver @@ -28,6 +28,10 @@ import org.apache.spark.{TaskContext, ShuffleDependency} * boolean isDriver as parameters. */ private[spark] trait ShuffleManager { + + /** Return short name for the ShuffleManager */ + val shortName: String + /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ @@ -50,9 +54,9 @@ private[spark] trait ShuffleManager { context: TaskContext): ShuffleReader[K, C] /** - * Remove a shuffle's metadata from the ShuffleManager. - * @return true if the metadata removed successfully, otherwise false. - */ + * Remove a shuffle's metadata from the ShuffleManager. + * @return true if the metadata removed successfully, otherwise false. + */ def unregisterShuffle(shuffleId: Int): Boolean /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index d2e2fc4c110a7..6bb4ff94b546d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.hash import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ /** @@ -34,6 +35,8 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) + override val shortName: String = "hash" + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 41df70c602c30..9276d95012f2f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -17,10 +17,11 @@ package org.apache.spark.shuffle.hash +import java.io.IOException + import org.apache.spark._ -import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ import org.apache.spark.storage.DiskBlockObjectWriter @@ -40,13 +41,11 @@ private[spark] class HashShuffleWriter[K, V]( // we don't try deleting files, etc twice. private var stopping = false - private val writeMetrics = new ShuffleWriteMetrics() - metrics.shuffleWriteMetrics = Some(writeMetrics) + private val writeMetrics = metrics.registerShuffleWriteMetrics() private val blockManager = SparkEnv.get.blockManager - private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) - private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser, - writeMetrics) + private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits, + dep.serializer, writeMetrics) /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { @@ -106,6 +105,29 @@ private[spark] class HashShuffleWriter[K, V]( writer.commitAndClose() writer.fileSegment().length } + // rename all shuffle files to final paths + // Note: there is only one ShuffleBlockResolver in executor + shuffleBlockResolver.synchronized { + shuffle.writers.zipWithIndex.foreach { case (writer, i) => + val output = blockManager.diskBlockManager.getFile(writer.blockId) + if (sizes(i) > 0) { + if (output.exists()) { + // Use length of existing file and delete our own temporary one + sizes(i) = output.length() + writer.file.delete() + } else { + // Commit by renaming our temporary file to something the fetcher expects + if (!writer.file.renameTo(output)) { + throw new IOException(s"fail to rename ${writer.file} to $output") + } + } + } else { + if (output.exists()) { + output.delete() + } + } + } + } MapStatus(blockManager.shuffleServerId, sizes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 66b6bbc61fe8e..9bfd966e33581 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap import org.apache.spark._ -import org.apache.spark.serializer.Serializer +import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ /** @@ -79,6 +79,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + override val shortName: String = "sort" + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** @@ -184,10 +186,9 @@ private[spark] object SortShuffleManager extends Logging { def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { val shufId = dependency.shuffleId val numPartitions = dependency.partitioner.numPartitions - val serializer = Serializer.getSerializer(dependency.serializer) - if (!serializer.supportsRelocationOfSerializedObjects) { + if (!dependency.serializer.supportsRelocationOfSerializedObjects) { log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + - s"${serializer.getClass.getName}, does not support object relocation") + s"${dependency.serializer.getClass.getName}, does not support object relocation") false } else if (dependency.aggregator.isDefined) { log.debug( diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 808317b017a0f..8ab1cee2e842d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -18,10 +18,11 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ -import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} +import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -44,8 +45,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private var mapStatus: MapStatus = null - private val writeMetrics = new ShuffleWriteMetrics() - context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) + private val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics() /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { @@ -65,11 +65,11 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). - val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val tmp = Utils.tempFileWith(output) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, outputFile) - shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - + val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } @@ -92,8 +92,7 @@ private[spark] class SortShuffleWriter[K, V, C]( if (sorter != null) { val startTime = System.nanoTime() sorter.stop() - context.taskMetrics.shuffleWriteMetrics.foreach( - _.incShuffleWriteTime(System.nanoTime - startTime)) + writeMetrics.incWriteTime(System.nanoTime - startTime) sorter = null } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala index 645ede26a0879..5c03609e5e5e5 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -28,7 +28,7 @@ private[v1] class AllRDDResource(ui: SparkUI) { @GET def rddList(): Seq[RDDStorageInfo] = { - val storageStatusList = ui.storageListener.storageStatusList + val storageStatusList = ui.storageListener.activeStorageStatusList val rddInfos = ui.storageListener.rddInfoList rddInfos.map{rddInfo => AllRDDResource.getRDDStorageInfo(rddInfo.id, rddInfo, storageStatusList, @@ -44,7 +44,7 @@ private[spark] object AllRDDResource { rddId: Int, listener: StorageListener, includeDetails: Boolean): Option[RDDStorageInfo] = { - val storageStatusList = listener.storageStatusList + val storageStatusList = listener.activeStorageStatusList listener.rddInfoList.find { _.id == rddId }.map { rddInfo => getRDDStorageInfo(rddId, rddInfo, storageStatusList, includeDetails) } @@ -61,7 +61,7 @@ private[spark] object AllRDDResource { .flatMap { _.rddBlocksById(rddId) } .sortWith { _._1.name < _._1.name } .map { case (blockId, status) => - (blockId, status, blockLocations.get(blockId).getOrElse(Seq[String]("Unknown"))) + (blockId, status, blockLocations.getOrElse(blockId, Seq[String]("Unknown"))) } val dataDistribution = if (includeDetails) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 24a0b5220695c..f8d6e9fbbb90d 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.util.{Arrays, Date, List => JList} -import javax.ws.rs.{GET, PathParam, Produces, QueryParam} +import javax.ws.rs.{GET, Produces, QueryParam} import javax.ws.rs.core.MediaType import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} @@ -59,6 +59,15 @@ private[v1] object AllStagesResource { stageUiData: StageUIData, includeDetails: Boolean): StageData = { + val taskLaunchTimes = stageUiData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + + val firstTaskLaunchedTime: Option[Date] = + if (taskLaunchTimes.nonEmpty) { + Some(new Date(taskLaunchTimes.min)) + } else { + None + } + val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } ) } else { @@ -92,6 +101,9 @@ private[v1] object AllStagesResource { numCompleteTasks = stageUiData.numCompleteTasks, numFailedTasks = stageUiData.numFailedTasks, executorRunTime = stageUiData.executorRunTime, + submissionTime = stageInfo.submissionTime.map(new Date(_)), + firstTaskLaunchedTime, + completionTime = stageInfo.completionTime.map(new Date(_)), inputBytes = stageUiData.inputBytes, inputRecords = stageUiData.inputRecords, outputBytes = stageUiData.outputBytes, @@ -135,7 +147,7 @@ private[v1] object AllStagesResource { speculative = uiData.taskInfo.speculative, accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo }, errorMessage = uiData.errorMessage, - taskMetrics = uiData.taskMetrics.map { convertUiTaskMetrics } + taskMetrics = uiData.metrics.map { convertUiTaskMetrics } ) } @@ -143,7 +155,7 @@ private[v1] object AllStagesResource { allTaskData: Iterable[TaskUIData], quantiles: Array[Double]): TaskMetricDistributions = { - val rawMetrics = allTaskData.flatMap{_.taskMetrics}.toSeq + val rawMetrics = allTaskData.flatMap{_.metrics}.toSeq def metricQuantiles(f: InternalTaskMetrics => Double): IndexedSeq[Double] = Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) @@ -202,9 +214,9 @@ private[v1] object AllStagesResource { raw.shuffleWriteMetrics } def build: ShuffleWriteMetricDistributions = new ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.shuffleBytesWritten), - writeRecords = submetricQuantiles(_.shuffleRecordsWritten), - writeTime = submetricQuantiles(_.shuffleWriteTime) + writeBytes = submetricQuantiles(_.bytesWritten), + writeRecords = submetricQuantiles(_.recordsWritten), + writeTime = submetricQuantiles(_.writeTime) ) }.metricOption @@ -225,7 +237,8 @@ private[v1] object AllStagesResource { } def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = { - new AccumulableInfo(acc.id, acc.name, acc.update, acc.value) + new AccumulableInfo( + acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull) } def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = { @@ -271,9 +284,9 @@ private[v1] object AllStagesResource { def convertShuffleWriteMetrics(internal: InternalShuffleWriteMetrics): ShuffleWriteMetrics = { new ShuffleWriteMetrics( - bytesWritten = internal.shuffleBytesWritten, - writeTime = internal.shuffleWriteTime, - recordsWritten = internal.shuffleRecordsWritten + bytesWritten = internal.bytesWritten, + writeTime = internal.writeTime, + recordsWritten = internal.recordsWritten ) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 50b6ba67e9931..ba9cd711f18e2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -177,6 +177,12 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) } + + @Path("version") + def getVersion(): VersionResource = { + new VersionResource(uiRoot) + } + } private[spark] object ApiRootResource { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 17b521f3e1d41..0f30183682469 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -62,11 +62,22 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = app.id, name = app.name, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = app.attempts.map { internalAttemptInfo => new ApplicationAttemptInfo( attemptId = internalAttemptInfo.attemptId, startTime = new Date(internalAttemptInfo.startTime), endTime = new Date(internalAttemptInfo.endTime), + duration = + if (internalAttemptInfo.endTime > 0) { + internalAttemptInfo.endTime - internalAttemptInfo.startTime + } else { + 0 + }, + lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, completed = internalAttemptInfo.completed ) @@ -81,10 +92,21 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = internal.id, name = internal.desc.name, + coresGranted = Some(internal.coresGranted), + maxCores = internal.desc.maxCores, + coresPerExecutor = internal.desc.coresPerExecutor, + memoryPerExecutorMB = Some(internal.desc.memoryPerExecutorMB), attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(internal.startTime), endTime = new Date(internal.endTime), + duration = + if (internal.endTime > 0) { + internal.endTime - internal.startTime + } else { + 0 + }, + lastUpdated = new Date(internal.endTime), sparkUser = internal.desc.user, completed = completed )) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala index 22e21f0c62a29..c84022ddfeef0 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala @@ -23,8 +23,9 @@ import javax.ws.rs.core.{MediaType, Response, StreamingOutput} import scala.util.control.NonFatal -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging @Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) private[v1] class EventLogDownloadResource( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala index 8ad4656b4dada..6ca59c2f3caeb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala @@ -28,9 +28,13 @@ private[v1] class ExecutorListResource(ui: SparkUI) { @GET def executorList(): Seq[ExecutorSummary] = { val listener = ui.executorsListener - val storageStatusList = listener.storageStatusList - (0 until storageStatusList.size).map { statusId => - ExecutorsPage.getExecInfo(listener, statusId) + listener.synchronized { + // The follow codes should be protected by `listener` to make sure no executors will be + // removed before we query their status. See SPARK-12784. + val storageStatusList = listener.activeStorageStatusList + (0 until storageStatusList.size).map { statusId => + ExecutorsPage.getExecInfo(listener, statusId, isActive = true) + } } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index 202a5191ad57d..f6a9f9c5573db 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -19,6 +19,7 @@ package org.apache.spark.status.api.v1 import java.io.OutputStream import java.lang.annotation.Annotation import java.lang.reflect.Type +import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat import java.util.{Calendar, SimpleTimeZone} import javax.ws.rs.Produces @@ -68,7 +69,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ multivaluedMap: MultivaluedMap[String, AnyRef], outputStream: OutputStream): Unit = { t match { - case ErrorWrapper(err) => outputStream.write(err.getBytes("utf-8")) + case ErrorWrapper(err) => outputStream.write(err.getBytes(StandardCharsets.UTF_8)) case _ => mapper.writeValue(outputStream, t) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index b5ef72649e295..d7e6a8b589953 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -16,8 +16,8 @@ */ package org.apache.spark.status.api.v1 +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType -import javax.ws.rs.{Produces, PathParam, GET} @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class OneApplicationResource(uiRoot: UIRoot) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala index 6d8a60d480aed..653150385c732 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{PathParam, GET, Produces} +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.JobExecutionStatus @@ -30,7 +30,7 @@ private[v1] class OneJobResource(ui: SparkUI) { def oneJob(@PathParam("jobId") jobId: Int): JobData = { val statusToJobs: Seq[(JobExecutionStatus, Seq[JobUIData])] = AllJobsResource.getStatusToJobs(ui) - val jobOpt = statusToJobs.map {_._2} .flatten.find { jobInfo => jobInfo.jobId == jobId} + val jobOpt = statusToJobs.flatMap(_._2).find { jobInfo => jobInfo.jobId == jobId} jobOpt.map { job => AllJobsResource.convertJobData(job, ui.jobProgressListener, false) }.getOrElse { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index dfdc09c6caf3b..237aeac185877 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{PathParam, GET, Produces} +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.ui.SparkUI diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index f9812f06cf527..3e6d2942d0fbb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -33,7 +33,7 @@ private[v1] class OneStageResource(ui: SparkUI) { @GET @Path("") def stageData(@PathParam("stageId") stageId: Int): Seq[StageData] = { - withStage(stageId){ stageAttempts => + withStage(stageId) { stageAttempts => stageAttempts.map { stage => AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui, includeDetails = true) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala new file mode 100644 index 0000000000000..673da1ce36b57 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs._ +import javax.ws.rs.core.MediaType + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class VersionResource(ui: UIRoot) { + + @GET + def getVersionInfo(): VersionInfo = new VersionInfo( + org.apache.spark.SPARK_VERSION + ) + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 2bec64f2ef02b..ebbbf4814880f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -25,14 +25,24 @@ import org.apache.spark.JobExecutionStatus class ApplicationInfo private[spark]( val id: String, val name: String, + val coresGranted: Option[Int], + val maxCores: Option[Int], + val coresPerExecutor: Option[Int], + val memoryPerExecutorMB: Option[Int], val attempts: Seq[ApplicationAttemptInfo]) class ApplicationAttemptInfo private[spark]( val attemptId: Option[String], val startTime: Date, val endTime: Date, + val lastUpdated: Date, + val duration: Long, val sparkUser: String, - val completed: Boolean = false) + val completed: Boolean = false) { + def getStartTimeEpoch: Long = startTime.getTime + def getEndTimeEpoch: Long = endTime.getTime + def getLastUpdatedEpoch: Long = lastUpdated.getTime +} class ExecutorStageSummary private[spark]( val taskTime : Long, @@ -48,14 +58,18 @@ class ExecutorStageSummary private[spark]( class ExecutorSummary private[spark]( val id: String, val hostPort: String, + val isActive: Boolean, val rddBlocks: Int, val memoryUsed: Long, val diskUsed: Long, + val totalCores: Int, + val maxTasks: Int, val activeTasks: Int, val failedTasks: Int, val completedTasks: Int, val totalTasks: Int, val totalDuration: Long, + val totalGCTime: Long, val totalInputBytes: Long, val totalShuffleRead: Long, val totalShuffleWrite: Long, @@ -81,8 +95,6 @@ class JobData private[spark]( val numSkippedStages: Int, val numFailedStages: Int) -// Q: should Tachyon size go in here as well? currently the UI only shows it on the overall storage -// page ... does anybody pay attention to it? class RDDStorageInfo private[spark]( val id: Int, val name: String, @@ -111,11 +123,14 @@ class StageData private[spark]( val status: StageStatus, val stageId: Int, val attemptId: Int, - val numActiveTasks: Int , + val numActiveTasks: Int, val numCompleteTasks: Int, val numFailedTasks: Int, val executorRunTime: Long, + val submissionTime: Option[Date], + val firstTaskLaunchedTime: Option[Date], + val completionTime: Option[Date], val inputBytes: Long, val inputRecords: Long, @@ -226,3 +241,6 @@ class AccumulableInfo private[spark]( val name: String, val update: Option[String], val value: String) + +class VersionInfo private[spark]( + val spark: String) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala deleted file mode 100644 index 22fdf73e9d1f4..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.ConcurrentHashMap - -private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - // To save space, 'pending' and 'failed' are encoded as special sizes: - @volatile var size: Long = BlockInfo.BLOCK_PENDING - private def pending: Boolean = size == BlockInfo.BLOCK_PENDING - private def failed: Boolean = size == BlockInfo.BLOCK_FAILED - private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this) - - setInitThread() - - private def setInitThread() { - /* Set current thread as init thread - waitForReady will not block this thread - * (in case there is non trivial initialization which ends up calling waitForReady - * as part of initialization itself) */ - BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread()) - } - - /** - * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). - * Return true if the block is available, false otherwise. - */ - def waitForReady(): Boolean = { - if (pending && initThread != Thread.currentThread()) { - synchronized { - while (pending) { - this.wait() - } - } - } - !failed - } - - /** Mark this BlockInfo as ready (i.e. block is finished writing) */ - def markReady(sizeInBytes: Long) { - require(sizeInBytes >= 0, s"sizeInBytes was negative: $sizeInBytes") - assert(pending) - size = sizeInBytes - BlockInfo.blockInfoInitThreads.remove(this) - synchronized { - this.notifyAll() - } - } - - /** Mark this BlockInfo as ready but failed */ - def markFailure() { - assert(pending) - size = BlockInfo.BLOCK_FAILED - BlockInfo.blockInfoInitThreads.remove(this) - synchronized { - this.notifyAll() - } - } -} - -private object BlockInfo { - /* initThread is logically a BlockInfo field, but we store it here because - * it's only needed while this block is in the 'pending' state and we want - * to minimize BlockInfo's memory footprint. */ - private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread] - - private val BLOCK_PENDING: Long = -1L - private val BLOCK_FAILED: Long = -2L -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala new file mode 100644 index 0000000000000..ca53534b61c4a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.reflect.ClassTag + +import com.google.common.collect.ConcurrentHashMultiset + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging + + +/** + * Tracks metadata for an individual block. + * + * Instances of this class are _not_ thread-safe and are protected by locks in the + * [[BlockInfoManager]]. + * + * @param level the block's storage level. This is the requested persistence level, not the + * effective storage level of the block (i.e. if this is MEMORY_AND_DISK, then this + * does not imply that the block is actually resident in memory). + * @param classTag the block's [[ClassTag]], used to select the serializer + * @param tellMaster whether state changes for this block should be reported to the master. This + * is true for most blocks, but is false for broadcast blocks. + */ +private[storage] class BlockInfo( + val level: StorageLevel, + val classTag: ClassTag[_], + val tellMaster: Boolean) { + + /** + * The size of the block (in bytes) + */ + def size: Long = _size + def size_=(s: Long): Unit = { + _size = s + checkInvariants() + } + private[this] var _size: Long = 0 + + /** + * The number of times that this block has been locked for reading. + */ + def readerCount: Int = _readerCount + def readerCount_=(c: Int): Unit = { + _readerCount = c + checkInvariants() + } + private[this] var _readerCount: Int = 0 + + /** + * The task attempt id of the task which currently holds the write lock for this block, or + * [[BlockInfo.NON_TASK_WRITER]] if the write lock is held by non-task code, or + * [[BlockInfo.NO_WRITER]] if this block is not locked for writing. + */ + def writerTask: Long = _writerTask + def writerTask_=(t: Long): Unit = { + _writerTask = t + checkInvariants() + } + private[this] var _writerTask: Long = BlockInfo.NO_WRITER + + private def checkInvariants(): Unit = { + // A block's reader count must be non-negative: + assert(_readerCount >= 0) + // A block is either locked for reading or for writing, but not for both at the same time: + assert(_readerCount == 0 || _writerTask == BlockInfo.NO_WRITER) + } + + checkInvariants() +} + +private[storage] object BlockInfo { + + /** + * Special task attempt id constant used to mark a block's write lock as being unlocked. + */ + val NO_WRITER: Long = -1 + + /** + * Special task attempt id constant used to mark a block's write lock as being held by + * a non-task thread (e.g. by a driver thread or by unit test code). + */ + val NON_TASK_WRITER: Long = -1024 +} + +/** + * Component of the [[BlockManager]] which tracks metadata for blocks and manages block locking. + * + * The locking interface exposed by this class is readers-writer lock. Every lock acquisition is + * automatically associated with a running task and locks are automatically released upon task + * completion or failure. + * + * This class is thread-safe. + */ +private[storage] class BlockInfoManager extends Logging { + + private type TaskAttemptId = Long + + /** + * Used to look up metadata for individual blocks. Entries are added to this map via an atomic + * set-if-not-exists operation ([[lockNewBlockForWriting()]]) and are removed + * by [[removeBlock()]]. + */ + @GuardedBy("this") + private[this] val infos = new mutable.HashMap[BlockId, BlockInfo] + + /** + * Tracks the set of blocks that each task has locked for writing. + */ + @GuardedBy("this") + private[this] val writeLocksByTask = + new mutable.HashMap[TaskAttemptId, mutable.Set[BlockId]] + with mutable.MultiMap[TaskAttemptId, BlockId] + + /** + * Tracks the set of blocks that each task has locked for reading, along with the number of times + * that a block has been locked (since our read locks are re-entrant). + */ + @GuardedBy("this") + private[this] val readLocksByTask = + new mutable.HashMap[TaskAttemptId, ConcurrentHashMultiset[BlockId]] + + // ---------------------------------------------------------------------------------------------- + + // Initialization for special task attempt ids: + registerTask(BlockInfo.NON_TASK_WRITER) + + // ---------------------------------------------------------------------------------------------- + + /** + * Called at the start of a task in order to register that task with this [[BlockInfoManager]]. + * This must be called prior to calling any other BlockInfoManager methods from that task. + */ + def registerTask(taskAttemptId: TaskAttemptId): Unit = synchronized { + require(!readLocksByTask.contains(taskAttemptId), + s"Task attempt $taskAttemptId is already registered") + readLocksByTask(taskAttemptId) = ConcurrentHashMultiset.create() + } + + /** + * Returns the current task's task attempt id (which uniquely identifies the task), or + * [[BlockInfo.NON_TASK_WRITER]] if called by a non-task thread. + */ + private def currentTaskAttemptId: TaskAttemptId = { + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(BlockInfo.NON_TASK_WRITER) + } + + /** + * Lock a block for reading and return its metadata. + * + * If another task has already locked this block for reading, then the read lock will be + * immediately granted to the calling task and its lock count will be incremented. + * + * If another task has locked this block for writing, then this call will block until the write + * lock is released or will return immediately if `blocking = false`. + * + * A single task can lock a block multiple times for reading, in which case each lock will need + * to be released separately. + * + * @param blockId the block to lock. + * @param blocking if true (default), this call will block until the lock is acquired. If false, + * this call will return immediately if the lock acquisition fails. + * @return None if the block did not exist or was removed (in which case no lock is held), or + * Some(BlockInfo) (in which case the block is locked for reading). + */ + def lockForReading( + blockId: BlockId, + blocking: Boolean = true): Option[BlockInfo] = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to acquire read lock for $blockId") + do { + infos.get(blockId) match { + case None => return None + case Some(info) => + if (info.writerTask == BlockInfo.NO_WRITER) { + info.readerCount += 1 + readLocksByTask(currentTaskAttemptId).add(blockId) + logTrace(s"Task $currentTaskAttemptId acquired read lock for $blockId") + return Some(info) + } + } + if (blocking) { + wait() + } + } while (blocking) + None + } + + /** + * Lock a block for writing and return its metadata. + * + * If another task has already locked this block for either reading or writing, then this call + * will block until the other locks are released or will return immediately if `blocking = false`. + * + * If this is called by a task which already holds the block's exclusive write lock, then this + * method will throw an exception. + * + * @param blockId the block to lock. + * @param blocking if true (default), this call will block until the lock is acquired. If false, + * this call will return immediately if the lock acquisition fails. + * @return None if the block did not exist or was removed (in which case no lock is held), or + * Some(BlockInfo) (in which case the block is locked for writing). + */ + def lockForWriting( + blockId: BlockId, + blocking: Boolean = true): Option[BlockInfo] = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to acquire write lock for $blockId") + do { + infos.get(blockId) match { + case None => return None + case Some(info) => + if (info.writerTask == currentTaskAttemptId) { + throw new IllegalStateException( + s"Task $currentTaskAttemptId has already locked $blockId for writing") + } else if (info.writerTask == BlockInfo.NO_WRITER && info.readerCount == 0) { + info.writerTask = currentTaskAttemptId + writeLocksByTask.addBinding(currentTaskAttemptId, blockId) + logTrace(s"Task $currentTaskAttemptId acquired write lock for $blockId") + return Some(info) + } + } + if (blocking) { + wait() + } + } while (blocking) + None + } + + /** + * Throws an exception if the current task does not hold a write lock on the given block. + * Otherwise, returns the block's BlockInfo. + */ + def assertBlockIsLockedForWriting(blockId: BlockId): BlockInfo = synchronized { + infos.get(blockId) match { + case Some(info) => + if (info.writerTask != currentTaskAttemptId) { + throw new SparkException( + s"Task $currentTaskAttemptId has not locked block $blockId for writing") + } else { + info + } + case None => + throw new SparkException(s"Block $blockId does not exist") + } + } + + /** + * Get a block's metadata without acquiring any locks. This method is only exposed for use by + * [[BlockManager.getStatus()]] and should not be called by other code outside of this class. + */ + private[storage] def get(blockId: BlockId): Option[BlockInfo] = synchronized { + infos.get(blockId) + } + + /** + * Downgrades an exclusive write lock to a shared read lock. + */ + def downgradeLock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId downgrading write lock for $blockId") + val info = get(blockId).get + require(info.writerTask == currentTaskAttemptId, + s"Task $currentTaskAttemptId tried to downgrade a write lock that it does not hold on" + + s" block $blockId") + unlock(blockId) + val lockOutcome = lockForReading(blockId, blocking = false) + assert(lockOutcome.isDefined) + } + + /** + * Release a lock on the given block. + */ + def unlock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId") + val info = get(blockId).getOrElse { + throw new IllegalStateException(s"Block $blockId not found") + } + if (info.writerTask != BlockInfo.NO_WRITER) { + info.writerTask = BlockInfo.NO_WRITER + writeLocksByTask.removeBinding(currentTaskAttemptId, blockId) + } else { + assert(info.readerCount > 0, s"Block $blockId is not locked for reading") + info.readerCount -= 1 + val countsForTask = readLocksByTask(currentTaskAttemptId) + val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1 + assert(newPinCountForTask >= 0, + s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it") + } + notifyAll() + } + + /** + * Attempt to acquire the appropriate lock for writing a new block. + * + * This enforces the first-writer-wins semantics. If we are the first to write the block, + * then just go ahead and acquire the write lock. Otherwise, if another thread is already + * writing the block, then we wait for the write to finish before acquiring the read lock. + * + * @return true if the block did not already exist, false otherwise. If this returns false, then + * a read lock on the existing block will be held. If this returns true, a write lock on + * the new block will be held. + */ + def lockNewBlockForWriting( + blockId: BlockId, + newBlockInfo: BlockInfo): Boolean = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to put $blockId") + lockForReading(blockId) match { + case Some(info) => + // Block already exists. This could happen if another thread races with us to compute + // the same block. In this case, just keep the read lock and return. + false + case None => + // Block does not yet exist or is removed, so we are free to acquire the write lock + infos(blockId) = newBlockInfo + lockForWriting(blockId) + true + } + } + + /** + * Release all lock held by the given task, clearing that task's pin bookkeeping + * structures and updating the global pin counts. This method should be called at the + * end of a task (either by a task completion handler or in `TaskRunner.run()`). + * + * @return the ids of blocks whose pins were released + */ + def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = { + val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() + + val readLocks = synchronized { + readLocksByTask.remove(taskAttemptId).get + } + val writeLocks = synchronized { + writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) + } + + for (blockId <- writeLocks) { + infos.get(blockId).foreach { info => + assert(info.writerTask == taskAttemptId) + info.writerTask = BlockInfo.NO_WRITER + } + blocksWithReleasedLocks += blockId + } + readLocks.entrySet().iterator().asScala.foreach { entry => + val blockId = entry.getElement + val lockCount = entry.getCount + blocksWithReleasedLocks += blockId + synchronized { + get(blockId).foreach { info => + info.readerCount -= lockCount + assert(info.readerCount >= 0) + } + } + } + + synchronized { + notifyAll() + } + blocksWithReleasedLocks + } + + /** + * Returns the number of blocks tracked. + */ + def size: Int = synchronized { + infos.size + } + + /** + * Return the number of map entries in this pin counter's internal data structures. + * This is used in unit tests in order to detect memory leaks. + */ + private[storage] def getNumberOfMapEntries: Long = synchronized { + size + + readLocksByTask.size + + readLocksByTask.map(_._2.size()).sum + + writeLocksByTask.size + + writeLocksByTask.map(_._2.size).sum + } + + /** + * Returns an iterator over a snapshot of all blocks' metadata. Note that the individual entries + * in this iterator are mutable and thus may reflect blocks that are deleted while the iterator + * is being traversed. + */ + def entries: Iterator[(BlockId, BlockInfo)] = synchronized { + infos.toArray.toIterator + } + + /** + * Removes the given block and releases the write lock on it. + * + * This can only be called while holding a write lock on the given block. + */ + def removeBlock(blockId: BlockId): Unit = synchronized { + logTrace(s"Task $currentTaskAttemptId trying to remove block $blockId") + infos.get(blockId) match { + case Some(blockInfo) => + if (blockInfo.writerTask != currentTaskAttemptId) { + throw new IllegalStateException( + s"Task $currentTaskAttemptId called remove() on block $blockId without a write lock") + } else { + infos.remove(blockId) + blockInfo.readerCount = 0 + blockInfo.writerTask = BlockInfo.NO_WRITER + writeLocksByTask.removeBinding(currentTaskAttemptId, blockId) + } + case None => + throw new IllegalArgumentException( + s"Task $currentTaskAttemptId called remove() on non-existent block $blockId") + } + notifyAll() + } + + /** + * Delete all state. Called during shutdown. + */ + def clear(): Unit = synchronized { + infos.valuesIterator.foreach { blockInfo => + blockInfo.readerCount = 0 + blockInfo.writerTask = BlockInfo.NO_WRITER + } + infos.clear() + readLocksByTask.clear() + writeLocksByTask.clear() + notifyAll() + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c374b93766225..35a6c63ad193e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,35 +18,31 @@ package org.apache.spark.storage import java.io._ -import java.nio.{ByteBuffer, MappedByteBuffer} +import java.nio.ByteBuffer import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{ExecutionContext, Await, Future} +import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ -import scala.util.control.NonFatal +import scala.reflect.ClassTag import scala.util.Random - -import sun.nio.ch.DirectBuffer +import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} -import org.apache.spark.io.CompressionCodec -import org.apache.spark.memory.MemoryManager +import org.apache.spark.internal.Logging +import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.serializer.{SerializerInstance, Serializer} +import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager -import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.storage.memory._ +import org.apache.spark.unsafe.Platform import org.apache.spark.util._ - -private[spark] sealed trait BlockValues -private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues -private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues -private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues +import org.apache.spark.util.io.ChunkedByteBuffer /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( @@ -58,13 +54,13 @@ private[spark] class BlockResult( * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). * - * Note that #initialize() must be called before the BlockManager is usable. + * Note that [[initialize()]] must be called before the BlockManager is usable. */ private[spark] class BlockManager( executorId: String, rpcEnv: RpcEnv, val master: BlockManagerMaster, - defaultSerializer: Serializer, + serializerManager: SerializerManager, val conf: SparkConf, memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, @@ -72,33 +68,35 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) - extends BlockDataManager with Logging { + extends BlockDataManager with BlockEvictionHandler with Logging { + + private[spark] val externalShuffleServiceEnabled = + conf.getBoolean("spark.shuffle.service.enabled", false) - val diskBlockManager = new DiskBlockManager(this, conf) + val diskBlockManager = { + // Only perform cleanup if an external service is not serving our shuffle files. + val deleteFilesOnStop = + !externalShuffleServiceEnabled || executorId == SparkContext.DRIVER_IDENTIFIER + new DiskBlockManager(conf, deleteFilesOnStop) + } - private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + // Visible for testing + private[storage] val blockInfoManager = new BlockInfoManager private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) // Actual storage of where blocks are kept - private var externalBlockStoreInitialized = false - private[spark] val memoryStore = new MemoryStore(this, memoryManager) - private[spark] val diskStore = new DiskStore(this, diskBlockManager) - private[spark] lazy val externalBlockStore: ExternalBlockStore = { - externalBlockStoreInitialized = true - new ExternalBlockStore(this, executorId) - } + private[spark] val memoryStore = + new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) + private[spark] val diskStore = new DiskStore(conf, diskBlockManager) memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxStorageMemory` may actually vary over time. // However, since we use this only for reporting and logging, what we actually want here is // the absolute maximum value that `maxStorageMemory` can ever possibly reach. We may need // to revisit whether reporting this value as the "max" is intuitive to the user. - private val maxMemory = memoryManager.maxStorageMemory - - private[spark] - val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + private val maxMemory = memoryManager.maxOnHeapStorageMemory // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. @@ -123,21 +121,16 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled()) } else { blockTransferService } - // Whether to compress broadcast variables that are stored - private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) - // Whether to compress shuffle output that are stored - private val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) - // Whether to compress RDD partitions that are stored serialized - private val compressRdds = conf.getBoolean("spark.rdd.compress", false) - // Whether to compress shuffle output temporarily spilled to disk - private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) + // Max number of failures before this block manager refreshes the block locations from the driver + private val maxFailuresBeforeLocationRefresh = + conf.getInt("spark.block.failures.beforeLocationRefresh", 5) private val slaveEndpoint = rpcEnv.setupEndpoint( "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, @@ -148,23 +141,11 @@ private[spark] class BlockManager( private var asyncReregisterTask: Future[Unit] = null private val asyncReregisterLock = new Object - private val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) - private val broadcastCleaner = new MetadataCleaner( - MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf) - // Field related to peer block managers that are necessary for block replication @volatile private var cachedPeers: Seq[BlockManagerId] = _ private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay - * the initialization of the compression codec until it is first used. The reason is that a Spark - * program could be using a user-defined codec in a third party jar, which is loaded in - * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been - * loaded yet. */ - private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) - /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -201,7 +182,7 @@ private[spark] class BlockManager( val shuffleConfig = new ExecutorShuffleInfo( diskBlockManager.localDirs.map(_.toString), diskBlockManager.subDirsPerLocalDir, - shuffleManager.getClass.getName) + shuffleManager.shortName) val MAX_ATTEMPTS = 3 val SLEEP_TIME_SECS = 5 @@ -232,8 +213,8 @@ private[spark] class BlockManager( * will be made then. */ private def reportAllBlocks(): Unit = { - logInfo(s"Reporting ${blockInfo.size} blocks to the master.") - for ((blockId, info) <- blockInfo) { + logInfo(s"Reporting ${blockInfoManager.size} blocks to the master.") + for ((blockId, info) <- blockInfoManager.entries) { val status = getCurrentBlockStatus(blockId, info) if (!tryToReportBlockStatus(blockId, info, status)) { logError(s"Failed to report $blockId to master; giving up.") @@ -291,13 +272,9 @@ private[spark] class BlockManager( if (blockId.isShuffle) { shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { - val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) - .asInstanceOf[Option[ByteBuffer]] - if (blockBytesOpt.isDefined) { - val buffer = blockBytesOpt.get - new NioManagedBuffer(buffer) - } else { - throw new BlockNotFoundException(blockId.toString) + getLocalBytes(blockId) match { + case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) + case None => throw new BlockNotFoundException(blockId.toString) } } } @@ -305,8 +282,12 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. */ - override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { - putBytes(blockId, data.nioByteBuffer(), level) + override def putBlockData( + blockId: BlockId, + data: ManagedBuffer, + level: StorageLevel, + classTag: ClassTag[_]): Boolean = { + putBytes(blockId, new ChunkedByteBuffer(data.nioByteBuffer()), level)(classTag) } /** @@ -314,11 +295,10 @@ private[spark] class BlockManager( * NOTE: This is mainly for testing, and it doesn't fetch information from external block store. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { - blockInfo.get(blockId).map { info => + blockInfoManager.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L - // Assume that block is not in external block store - BlockStatus(info.level, memSize, diskSize, 0L) + BlockStatus(info.level, memSize = memSize, diskSize = diskSize) } } @@ -328,7 +308,12 @@ private[spark] class BlockManager( * may not know of). */ def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = { - (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq + // The `toArray` is necessary here in order to force the list to be materialized so that we + // don't try to serialize a lazy iterator when responding to client requests. + (blockInfoManager.entries.map(_._1) ++ diskBlockManager.getAllBlocks()) + .filter(filter) + .toArray + .toSeq } /** @@ -367,10 +352,8 @@ private[spark] class BlockManager( if (info.tellMaster) { val storageLevel = status.storageLevel val inMemSize = Math.max(status.memSize, droppedMemorySize) - val inExternalBlockStoreSize = status.externalBlockStoreSize val onDiskSize = status.diskSize - master.updateBlockInfo( - blockManagerId, blockId, storageLevel, inMemSize, onDiskSize, inExternalBlockStoreSize) + master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize) } else { true } @@ -385,20 +368,21 @@ private[spark] class BlockManager( info.synchronized { info.level match { case null => - BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) + BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) - val inExternalBlockStore = level.useOffHeap && externalBlockStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) val deserialized = if (inMem) level.deserialized else false - val replication = if (inMem || inExternalBlockStore || onDisk) level.replication else 1 - val storageLevel = - StorageLevel(onDisk, inMem, inExternalBlockStore, deserialized, replication) + val replication = if (inMem || onDisk) level.replication else 1 + val storageLevel = StorageLevel( + useDisk = onDisk, + useMemory = inMem, + useOffHeap = level.useOffHeap, + deserialized = deserialized, + replication = replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L - val externalBlockStoreSize = - if (inExternalBlockStore) externalBlockStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L - BlockStatus(storageLevel, memSize, diskSize, externalBlockStoreSize) + BlockStatus(storageLevel, memSize, diskSize) } } } @@ -414,17 +398,54 @@ private[spark] class BlockManager( } /** - * Get block from local block manager. + * Get block from local block manager as an iterator of Java objects. */ - def getLocal(blockId: BlockId): Option[BlockResult] = { + def getLocalValues(blockId: BlockId): Option[BlockResult] = { logDebug(s"Getting local block $blockId") - doGetLocal(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]] + blockInfoManager.lockForReading(blockId) match { + case None => + logDebug(s"Block $blockId was not found") + None + case Some(info) => + val level = info.level + logDebug(s"Level for block $blockId is $level") + if (level.useMemory && memoryStore.contains(blockId)) { + val iter: Iterator[Any] = if (level.deserialized) { + memoryStore.getValues(blockId).get + } else { + serializerManager.dataDeserializeStream( + blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag) + } + val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) + Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) + } else if (level.useDisk && diskStore.contains(blockId)) { + val iterToReturn: Iterator[Any] = { + val diskBytes = diskStore.getBytes(blockId) + if (level.deserialized) { + val diskValues = serializerManager.dataDeserializeStream( + blockId, + diskBytes.toInputStream(dispose = true))(info.classTag) + maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) + } else { + val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) + .map {_.toInputStream(dispose = false)} + .getOrElse { diskBytes.toInputStream(dispose = true) } + serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) + } + } + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) + Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) + } else { + releaseLock(blockId) + throw new SparkException(s"Block $blockId was not found even though it's read-locked") + } + } } /** * Get block from the local block manager as serialized bytes. */ - def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = { + def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work @@ -433,185 +454,125 @@ private[spark] class BlockManager( // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. Option( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + new ChunkedByteBuffer( + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())) } else { - doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) } } } - private def doGetLocal(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - // Double check to make sure the block is still there. There is a small chance that the - // block has been removed by removeBlock (which also synchronizes on the blockInfo object). - // Note that this only checks metadata tracking. If user intentionally deleted the block - // on disk or from off heap storage without using removeBlock, this conditional check will - // still pass but eventually we will get an exception because we can't find the block. - if (blockInfo.get(blockId).isEmpty) { - logWarning(s"Block $blockId had been removed") - return None - } - - // If another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning(s"Block $blockId was marked as failure.") - return None - } - - val level = info.level - logDebug(s"Level for block $blockId is $level") - - // Look for the block in memory - if (level.useMemory) { - logDebug(s"Getting block $blockId from memory") - val result = if (asBlockResult) { - memoryStore.getValues(blockId).map(new BlockResult(_, DataReadMethod.Memory, info.size)) - } else { - memoryStore.getBytes(blockId) - } - result match { - case Some(values) => - return result - case None => - logDebug(s"Block $blockId not found in memory") - } - } - - // Look for the block in external block store - if (level.useOffHeap) { - logDebug(s"Getting block $blockId from ExternalBlockStore") - if (externalBlockStore.contains(blockId)) { - val result = if (asBlockResult) { - externalBlockStore.getValues(blockId) - .map(new BlockResult(_, DataReadMethod.Memory, info.size)) - } else { - externalBlockStore.getBytes(blockId) - } - result match { - case Some(values) => - return result - case None => - logDebug(s"Block $blockId not found in ExternalBlockStore") - } - } - } - - // Look for block on disk, potentially storing it back in memory if required - if (level.useDisk) { - logDebug(s"Getting block $blockId from disk") - val bytes: ByteBuffer = diskStore.getBytes(blockId) match { - case Some(b) => b - case None => - throw new BlockException( - blockId, s"Block $blockId not found on disk, though it should be") - } - assert(0 == bytes.position()) - - if (!level.useMemory) { - // If the block shouldn't be stored in memory, we can just return it - if (asBlockResult) { - return Some(new BlockResult(dataDeserialize(blockId, bytes), DataReadMethod.Disk, - info.size)) - } else { - return Some(bytes) - } - } else { - // Otherwise, we also have to store something in the memory store - if (!level.deserialized || !asBlockResult) { - /* We'll store the bytes in memory if the block's storage level includes - * "memory serialized", or if it should be cached as objects in memory - * but we only requested its serialized bytes. */ - memoryStore.putBytes(blockId, bytes.limit, () => { - // https://issues.apache.org/jira/browse/SPARK-6076 - // If the file size is bigger than the free memory, OOM will happen. So if we cannot - // put it into MemoryStore, copyForMemory should not be created. That's why this - // action is put into a `() => ByteBuffer` and created lazily. - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - }) - bytes.rewind() - } - if (!asBlockResult) { - return Some(bytes) - } else { - val values = dataDeserialize(blockId, bytes) - if (level.deserialized) { - // Cache the values before returning them - val putResult = memoryStore.putIterator( - blockId, values, level, returnValues = true, allowPersistToDisk = false) - // The put may or may not have succeeded, depending on whether there was enough - // space to unroll the block. Either way, the put here should return an iterator. - putResult.data match { - case Left(it) => - return Some(new BlockResult(it, DataReadMethod.Disk, info.size)) - case _ => - // This only happens if we dropped the values back to disk (which is never) - throw new SparkException("Memory store did not return an iterator!") - } - } else { - return Some(new BlockResult(values, DataReadMethod.Disk, info.size)) - } - } - } - } + /** + * Get block from the local block manager as serialized bytes. + * + * Must be called while holding a read lock on the block. + * Releases the read lock upon exception; keeps the read lock upon successful return. + */ + private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = { + val level = info.level + logDebug(s"Level for block $blockId is $level") + // In order, try to read the serialized bytes from memory, then from disk, then fall back to + // serializing in-memory objects, and, finally, throw an exception if the block does not exist. + if (level.deserialized) { + // Try to avoid expensive serialization by reading a pre-serialized copy from disk: + if (level.useDisk && diskStore.contains(blockId)) { + // Note: we purposely do not try to put the block back into memory here. Since this branch + // handles deserialized blocks, this block may only be cached in memory as objects, not + // serialized bytes. Because the caller only requested bytes, it doesn't make sense to + // cache the block's deserialized objects since that caching may not have a payoff. + diskStore.getBytes(blockId) + } else if (level.useMemory && memoryStore.contains(blockId)) { + // The block was not found on disk, so serialize an in-memory copy: + serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get) + } else { + releaseLock(blockId) + throw new SparkException(s"Block $blockId was not found even though it's read-locked") + } + } else { // storage level is serialized + if (level.useMemory && memoryStore.contains(blockId)) { + memoryStore.getBytes(blockId).get + } else if (level.useDisk && diskStore.contains(blockId)) { + val diskBytes = diskStore.getBytes(blockId) + maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes) + } else { + releaseLock(blockId) + throw new SparkException(s"Block $blockId was not found even though it's read-locked") } - } else { - logDebug(s"Block $blockId not registered locally") } - None } /** * Get block from remote block managers. + * + * This does not acquire a lock on this block in this JVM. */ - def getRemote(blockId: BlockId): Option[BlockResult] = { - logDebug(s"Getting remote block $blockId") - doGetRemote(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]] + private def getRemoteValues(blockId: BlockId): Option[BlockResult] = { + getRemoteBytes(blockId).map { data => + val values = + serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true)) + new BlockResult(values, DataReadMethod.Network, data.size) + } } /** - * Get block from remote block managers as serialized bytes. + * Return a list of locations for the given block, prioritizing the local machine since + * multiple block managers can share the same host. */ - def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { - logDebug(s"Getting remote block $blockId as bytes") - doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { + val locs = Random.shuffle(master.getLocations(blockId)) + val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } + preferredLocs ++ otherLocs } - private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { + /** + * Get block from remote block managers as serialized bytes. + */ + def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") - val locations = Random.shuffle(master.getLocations(blockId)) - var numFetchFailures = 0 - for (loc <- locations) { + var runningFailureCount = 0 + var totalFailureCount = 0 + val locations = getLocations(blockId) + val maxFetchFailures = locations.size + var locationIterator = locations.iterator + while (locationIterator.hasNext) { + val loc = locationIterator.next() logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() } catch { case NonFatal(e) => - numFetchFailures += 1 - if (numFetchFailures == locations.size) { - // An exception is thrown while fetching this block from all locations - throw new BlockFetchException(s"Failed to fetch block from" + - s" ${locations.size} locations. Most recent failure cause:", e) - } else { - // This location failed, so we retry fetch from a different one by returning null here - logWarning(s"Failed to fetch remote block $blockId " + - s"from $loc (failed attempt $numFetchFailures)", e) - null + runningFailureCount += 1 + totalFailureCount += 1 + + if (totalFailureCount >= maxFetchFailures) { + // Give up trying anymore locations. Either we've tried all of the original locations, + // or we've refreshed the list of locations from the master, and have still + // hit failures after trying locations from the refreshed list. + throw new BlockFetchException(s"Failed to fetch block after" + + s" ${totalFailureCount} fetch failures. Most recent failure cause:", e) } + + logWarning(s"Failed to fetch remote block $blockId " + + s"from $loc (failed attempt $runningFailureCount)", e) + + // If there is a large number of executors then locations list can contain a + // large number of stale entries causing a large number of retries that may + // take a significant amount of time. To get rid of these stale entries + // we refresh the block locations after a certain number of fetch failures + if (runningFailureCount >= maxFailuresBeforeLocationRefresh) { + locationIterator = getLocations(blockId).iterator + logDebug(s"Refreshed locations from the driver " + + s"after ${runningFailureCount} fetch failures.") + runningFailureCount = 0 + } + + // This location failed, so we retry fetch from a different one by returning null here + null } if (data != null) { - if (asBlockResult) { - return Some(new BlockResult( - dataDeserialize(blockId, data), - DataReadMethod.Network, - data.limit())) - } else { - return Some(data) - } + return Some(new ChunkedByteBuffer(data)) } logDebug(s"The value of block $blockId is null") } @@ -621,14 +582,18 @@ private[spark] class BlockManager( /** * Get a block from the block manager (either local or remote). + * + * This acquires a read lock on the block if the block was stored locally and does not acquire + * any locks if the block was fetched from a remote block manager. The read lock will + * automatically be freed once the result's `data` iterator is fully consumed. */ def get(blockId: BlockId): Option[BlockResult] = { - val local = getLocal(blockId) + val local = getLocalValues(blockId) if (local.isDefined) { logInfo(s"Found block $blockId locally") return local } - val remote = getRemote(blockId) + val remote = getRemoteValues(blockId) if (remote.isDefined) { logInfo(s"Found block $blockId remotely") return remote @@ -636,14 +601,98 @@ private[spark] class BlockManager( None } - def putIterator( + /** + * Downgrades an exclusive write lock to a shared read lock. + */ + def downgradeLock(blockId: BlockId): Unit = { + blockInfoManager.downgradeLock(blockId) + } + + /** + * Release a lock on the given block. + */ + def releaseLock(blockId: BlockId): Unit = { + blockInfoManager.unlock(blockId) + } + + /** + * Registers a task with the BlockManager in order to initialize per-task bookkeeping structures. + */ + def registerTask(taskAttemptId: Long): Unit = { + blockInfoManager.registerTask(taskAttemptId) + } + + /** + * Release all locks for the given task. + * + * @return the blocks whose locks were released. + */ + def releaseAllLocksForTask(taskAttemptId: Long): Seq[BlockId] = { + blockInfoManager.releaseAllLocksForTask(taskAttemptId) + } + + /** + * Retrieve the given block if it exists, otherwise call the provided `makeIterator` method + * to compute the block, persist it, and return its values. + * + * @return either a BlockResult if the block was successfully cached, or an iterator if the block + * could not be cached. + */ + def getOrElseUpdate[T]( blockId: BlockId, - values: Iterator[Any], level: StorageLevel, - tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { + classTag: ClassTag[T], + makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = { + // Attempt to read the block from local or remote storage. If it's present, then we don't need + // to go through the local-get-or-put path. + get(blockId) match { + case Some(block) => + return Left(block) + case _ => + // Need to compute the block. + } + // Initially we hold no locks on this block. + doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match { + case None => + // doPut() didn't hand work back to us, so the block already existed or was successfully + // stored. Therefore, we now hold a read lock on the block. + val blockResult = getLocalValues(blockId).getOrElse { + // Since we held a read lock between the doPut() and get() calls, the block should not + // have been evicted, so get() not returning the block indicates some internal error. + releaseLock(blockId) + throw new SparkException(s"get() failed for block $blockId even though we held a lock") + } + // We already hold a read lock on the block from the doPut() call and getLocalValues() + // acquires the lock again, so we need to call releaseLock() here so that the net number + // of lock acquisitions is 1 (since the caller will only call release() once). + releaseLock(blockId) + Left(blockResult) + case Some(iter) => + // The put failed, likely because the data was too large to fit in memory and could not be + // dropped to disk. Therefore, we need to pass the input iterator back to the caller so + // that they can decide what to do with the values (e.g. process them without caching). + Right(iter) + } + } + + /** + * @return true if the block was stored or false if an error occurred. + */ + def putIterator[T: ClassTag]( + blockId: BlockId, + values: Iterator[T], + level: StorageLevel, + tellMaster: Boolean = true): Boolean = { require(values != null, "Values is null") - doPut(blockId, IteratorValues(values), level, tellMaster, effectiveStorageLevel) + doPutIterator(blockId, () => values, level, implicitly[ClassTag[T]], tellMaster) match { + case None => + true + case Some(iter) => + // Caller doesn't care about the iterator values, so we can close the iterator here + // to free resources earlier + iter.close() + false + } } /** @@ -657,224 +706,350 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) + val compressStream: OutputStream => OutputStream = + serializerManager.wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream, - syncWrites, writeMetrics) + syncWrites, writeMetrics, blockId) } /** - * Put a new block of values to the block manager. - * Return a list of blocks updated as a result of this put. + * Put a new block of serialized bytes to the block manager. + * + * @return true if the block was stored or false if an error occurred. */ - def putArray( + def putBytes[T: ClassTag]( blockId: BlockId, - values: Array[Any], + bytes: ChunkedByteBuffer, level: StorageLevel, - tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { - require(values != null, "Values is null") - doPut(blockId, ArrayValues(values), level, tellMaster, effectiveStorageLevel) + tellMaster: Boolean = true): Boolean = { + require(bytes != null, "Bytes is null") + doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster) } /** - * Put a new block of serialized bytes to the block manager. - * Return a list of blocks updated as a result of this put. + * Put the given bytes according to the given level in one of the block stores, replicating + * the values if necessary. + * + * If the block already exists, this method will not overwrite it. + * + * @param keepReadLock if true, this method will hold the read lock when it returns (even if the + * block already exists). If false, this method will hold no locks when it + * returns. + * @return true if the block was already present or if the put succeeded, false otherwise. */ - def putBytes( + private def doPutBytes[T]( blockId: BlockId, - bytes: ByteBuffer, + bytes: ChunkedByteBuffer, level: StorageLevel, + classTag: ClassTag[T], tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { - require(bytes != null, "Bytes is null") - doPut(blockId, ByteBufferValues(bytes), level, tellMaster, effectiveStorageLevel) + keepReadLock: Boolean = false): Boolean = { + doPut(blockId, level, classTag, tellMaster = tellMaster, keepReadLock = keepReadLock) { info => + val startTimeMs = System.currentTimeMillis + // Since we're storing bytes, initiate the replication before storing them locally. + // This is faster as data is already serialized and ready to send. + val replicationFuture = if (level.replication > 1) { + Future { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool + replicate(blockId, bytes, level, classTag) + }(futureExecutionContext) + } else { + null + } + + val size = bytes.size + + if (level.useMemory) { + // Put it in memory first, even if it also has useDisk set to true; + // We will drop it to disk later if the memory store can't hold it. + val putSucceeded = if (level.deserialized) { + val values = + serializerManager.dataDeserializeStream(blockId, bytes.toInputStream())(classTag) + memoryStore.putIteratorAsValues(blockId, values, classTag) match { + case Right(_) => true + case Left(iter) => + // If putting deserialized values in memory failed, we will put the bytes directly to + // disk, so we don't need this iterator and can close it to free resources earlier. + iter.close() + false + } + } else { + memoryStore.putBytes(blockId, size, level.memoryMode, () => bytes) + } + if (!putSucceeded && level.useDisk) { + logWarning(s"Persisting block $blockId to disk instead.") + diskStore.putBytes(blockId, bytes) + } + } else if (level.useDisk) { + diskStore.putBytes(blockId, bytes) + } + + val putBlockStatus = getCurrentBlockStatus(blockId, info) + val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid + if (blockWasSuccessfullyStored) { + // Now that the block is in either the memory, externalBlockStore, or disk store, + // tell the master about it. + info.size = size + if (tellMaster) { + reportBlockStatus(blockId, info, putBlockStatus) + } + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus))) + } + } + logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) + if (level.replication > 1) { + // Wait for asynchronous replication to finish + Await.ready(replicationFuture, Duration.Inf) + } + if (blockWasSuccessfullyStored) { + None + } else { + Some(bytes) + } + }.isEmpty } /** - * Put the given block according to the given level in one of the block stores, replicating - * the values if necessary. + * Helper method used to abstract common code from [[doPutBytes()]] and [[doPutIterator()]]. * - * The effective storage level refers to the level according to which the block will actually be - * handled. This allows the caller to specify an alternate behavior of doPut while preserving - * the original level specified by the user. + * @param putBody a function which attempts the actual put() and returns None on success + * or Some on failure. */ - private def doPut( + private def doPut[T]( blockId: BlockId, - data: BlockValues, level: StorageLevel, - tellMaster: Boolean = true, - effectiveStorageLevel: Option[StorageLevel] = None) - : Seq[(BlockId, BlockStatus)] = { + classTag: ClassTag[_], + tellMaster: Boolean, + keepReadLock: Boolean)(putBody: BlockInfo => Option[T]): Option[T] = { require(blockId != null, "BlockId is null") require(level != null && level.isValid, "StorageLevel is null or invalid") - effectiveStorageLevel.foreach { level => - require(level != null && level.isValid, "Effective StorageLevel is null or invalid") - } - - // Return value - val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - /* Remember the block's storage level so that we can correctly drop it to disk if it needs - * to be dropped right after it got put into memory. Note, however, that other threads will - * not be able to get() this block until we call markReady on its BlockInfo. */ val putBlockInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning(s"Block $blockId already exists on this machine; not re-adding it") - return updatedBlocks - } - // TODO: So the block info exists - but previous attempt to load it (?) failed. - // What do we do now ? Retry on it ? - oldBlockOpt.get + val newInfo = new BlockInfo(level, classTag, tellMaster) + if (blockInfoManager.lockNewBlockForWriting(blockId, newInfo)) { + newInfo } else { - tinfo + logWarning(s"Block $blockId already exists on this machine; not re-adding it") + if (!keepReadLock) { + // lockNewBlockForWriting returned a read lock on the existing block, so we must free it: + releaseLock(blockId) + } + return None } } val startTimeMs = System.currentTimeMillis - - /* If we're storing values and we need to replicate the data, we'll want access to the values, - * but because our put will read the whole iterator, there will be no values left. For the - * case where the put serializes data, we'll remember the bytes, above; but for the case where - * it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. */ - var valuesAfterPut: Iterator[Any] = null - - // Ditto for the bytes after the put - var bytesAfterPut: ByteBuffer = null - - // Size of the block in bytes - var size = 0L - - // The level we actually use to put the block - val putLevel = effectiveStorageLevel.getOrElse(level) - - // If we're storing bytes, then initiate the replication before storing them locally. - // This is faster as data is already serialized and ready to send. - val replicationFuture = data match { - case b: ByteBufferValues if putLevel.replication > 1 => - // Duplicate doesn't copy the bytes, but just creates a wrapper - val bufferView = b.buffer.duplicate() - Future { - // This is a blocking action and should run in futureExecutionContext which is a cached - // thread pool - replicate(blockId, bufferView, putLevel) - }(futureExecutionContext) - case _ => null + var blockWasSuccessfullyStored: Boolean = false + val result: Option[T] = try { + val res = putBody(putBlockInfo) + blockWasSuccessfullyStored = res.isEmpty + res + } finally { + if (blockWasSuccessfullyStored) { + if (keepReadLock) { + blockInfoManager.downgradeLock(blockId) + } else { + blockInfoManager.unlock(blockId) + } + } else { + blockInfoManager.removeBlock(blockId) + logWarning(s"Putting block $blockId failed") + } } - - putBlockInfo.synchronized { - logTrace("Put for block %s took %s to get into synchronized block" + if (level.replication > 1) { + logDebug("Putting block %s with replication took %s" .format(blockId, Utils.getUsedTimeMs(startTimeMs))) + } else { + logDebug("Putting block %s without replication took %s" + .format(blockId, Utils.getUsedTimeMs(startTimeMs))) + } + result + } - var marked = false - try { - // returnValues - Whether to return the values put - // blockStore - The type of storage to put these values into - val (returnValues, blockStore: BlockStore) = { - if (putLevel.useMemory) { - // Put it in memory first, even if it also has useDisk set to true; - // We will drop it to disk later if the memory store can't hold it. - (true, memoryStore) - } else if (putLevel.useOffHeap) { - // Use external block store - (false, externalBlockStore) - } else if (putLevel.useDisk) { - // Don't get back the bytes from put unless we replicate them - (putLevel.replication > 1, diskStore) - } else { - assert(putLevel == StorageLevel.NONE) - throw new BlockException( - blockId, s"Attempted to put block $blockId without specifying storage level!") + /** + * Put the given block according to the given level in one of the block stores, replicating + * the values if necessary. + * + * If the block already exists, this method will not overwrite it. + * + * @param keepReadLock if true, this method will hold the read lock when it returns (even if the + * block already exists). If false, this method will hold no locks when it + * returns. + * @return None if the block was already present or if the put succeeded, or Some(iterator) + * if the put failed. + */ + private def doPutIterator[T]( + blockId: BlockId, + iterator: () => Iterator[T], + level: StorageLevel, + classTag: ClassTag[T], + tellMaster: Boolean = true, + keepReadLock: Boolean = false): Option[PartiallyUnrolledIterator[T]] = { + doPut(blockId, level, classTag, tellMaster = tellMaster, keepReadLock = keepReadLock) { info => + val startTimeMs = System.currentTimeMillis + var iteratorFromFailedMemoryStorePut: Option[PartiallyUnrolledIterator[T]] = None + // Size of the block in bytes + var size = 0L + if (level.useMemory) { + // Put it in memory first, even if it also has useDisk set to true; + // We will drop it to disk later if the memory store can't hold it. + if (level.deserialized) { + memoryStore.putIteratorAsValues(blockId, iterator(), classTag) match { + case Right(s) => + size = s + case Left(iter) => + // Not enough space to unroll this block; drop to disk if applicable + if (level.useDisk) { + logWarning(s"Persisting block $blockId to disk instead.") + diskStore.put(blockId) { fileOutputStream => + serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag) + } + size = diskStore.getSize(blockId) + } else { + iteratorFromFailedMemoryStorePut = Some(iter) + } + } + } else { // !level.deserialized + memoryStore.putIteratorAsBytes(blockId, iterator(), classTag, level.memoryMode) match { + case Right(s) => + size = s + case Left(partiallySerializedValues) => + // Not enough space to unroll this block; drop to disk if applicable + if (level.useDisk) { + logWarning(s"Persisting block $blockId to disk instead.") + diskStore.put(blockId) { fileOutputStream => + partiallySerializedValues.finishWritingToStream(fileOutputStream) + } + size = diskStore.getSize(blockId) + } else { + iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator) + } } } - // Actually put the values - val result = data match { - case IteratorValues(iterator) => - blockStore.putIterator(blockId, iterator, putLevel, returnValues) - case ArrayValues(array) => - blockStore.putArray(blockId, array, putLevel, returnValues) - case ByteBufferValues(bytes) => - bytes.rewind() - blockStore.putBytes(blockId, bytes, putLevel) - } - size = result.size - result.data match { - case Left (newIterator) if putLevel.useMemory => valuesAfterPut = newIterator - case Right (newBytes) => bytesAfterPut = newBytes - case _ => - } - - // Keep track of which blocks are dropped from memory - if (putLevel.useMemory) { - result.droppedBlocks.foreach { updatedBlocks += _ } + } else if (level.useDisk) { + diskStore.put(blockId) { fileOutputStream => + serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) } + size = diskStore.getSize(blockId) + } - val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) - if (putBlockStatus.storageLevel != StorageLevel.NONE) { - // Now that the block is in either the memory, externalBlockStore, or disk store, - // let other threads read it, and tell the master about it. - marked = true - putBlockInfo.markReady(size) - if (tellMaster) { - reportBlockStatus(blockId, putBlockInfo, putBlockStatus) - } - updatedBlocks += ((blockId, putBlockStatus)) + val putBlockStatus = getCurrentBlockStatus(blockId, info) + val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid + if (blockWasSuccessfullyStored) { + // Now that the block is in either the memory, externalBlockStore, or disk store, + // tell the master about it. + info.size = size + if (tellMaster) { + reportBlockStatus(blockId, info, putBlockStatus) } - } finally { - // If we failed in putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (!marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - putBlockInfo.markFailure() - logWarning(s"Putting block $blockId failed") + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus))) } - } - } - logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) - - // Either we're storing bytes and we asynchronously started replication, or we're storing - // values and need to serialize and replicate them now: - if (putLevel.replication > 1) { - data match { - case ByteBufferValues(bytes) => - if (replicationFuture != null) { - Await.ready(replicationFuture, Duration.Inf) - } - case _ => + logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) + if (level.replication > 1) { val remoteStartTime = System.currentTimeMillis - // Serialize the block if not already done - if (bytesAfterPut == null) { - if (valuesAfterPut == null) { - throw new SparkException( - "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") - } - bytesAfterPut = dataSerialize(blockId, valuesAfterPut) + val bytesToReplicate = doGetLocalBytes(blockId, info) + try { + replicate(blockId, bytesToReplicate, level, classTag) + } finally { + bytesToReplicate.dispose() } - replicate(blockId, bytesAfterPut, putLevel) logDebug("Put block %s remotely took %s" .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) + } } + assert(blockWasSuccessfullyStored == iteratorFromFailedMemoryStorePut.isEmpty) + iteratorFromFailedMemoryStorePut } + } - BlockManager.dispose(bytesAfterPut) - - if (putLevel.replication > 1) { - logDebug("Putting block %s with replication took %s" - .format(blockId, Utils.getUsedTimeMs(startTimeMs))) + /** + * Attempts to cache spilled bytes read from disk into the MemoryStore in order to speed up + * subsequent reads. This method requires the caller to hold a read lock on the block. + * + * @return a copy of the bytes from the memory store if the put succeeded, otherwise None. + * If this returns bytes from the memory store then the original disk store bytes will + * automatically be disposed and the caller should not continue to use them. Otherwise, + * if this returns None then the original disk store bytes will be unaffected. + */ + private def maybeCacheDiskBytesInMemory( + blockInfo: BlockInfo, + blockId: BlockId, + level: StorageLevel, + diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = { + require(!level.deserialized) + if (level.useMemory) { + // Synchronize on blockInfo to guard against a race condition where two readers both try to + // put values read from disk into the MemoryStore. + blockInfo.synchronized { + if (memoryStore.contains(blockId)) { + diskBytes.dispose() + Some(memoryStore.getBytes(blockId).get) + } else { + val allocator = level.memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => { + // https://issues.apache.org/jira/browse/SPARK-6076 + // If the file size is bigger than the free memory, OOM will happen. So if we + // cannot put it into MemoryStore, copyForMemory should not be created. That's why + // this action is put into a `() => ChunkedByteBuffer` and created lazily. + diskBytes.copy(allocator) + }) + if (putSucceeded) { + diskBytes.dispose() + Some(memoryStore.getBytes(blockId).get) + } else { + None + } + } + } } else { - logDebug("Putting block %s without replication took %s" - .format(blockId, Utils.getUsedTimeMs(startTimeMs))) + None } + } - updatedBlocks + /** + * Attempts to cache spilled values read from disk into the MemoryStore in order to speed up + * subsequent reads. This method requires the caller to hold a read lock on the block. + * + * @return a copy of the iterator. The original iterator passed this method should no longer + * be used after this method returns. + */ + private def maybeCacheDiskValuesInMemory[T]( + blockInfo: BlockInfo, + blockId: BlockId, + level: StorageLevel, + diskIterator: Iterator[T]): Iterator[T] = { + require(level.deserialized) + val classTag = blockInfo.classTag.asInstanceOf[ClassTag[T]] + if (level.useMemory) { + // Synchronize on blockInfo to guard against a race condition where two readers both try to + // put values read from disk into the MemoryStore. + blockInfo.synchronized { + if (memoryStore.contains(blockId)) { + // Note: if we had a means to discard the disk iterator, we would do that here. + memoryStore.getValues(blockId).get + } else { + memoryStore.putIteratorAsValues(blockId, diskIterator, classTag) match { + case Left(iter) => + // The memory store put() failed, so it returned the iterator back to us: + iter + case Right(_) => + // The put() succeeded, so we can read the values back: + memoryStore.getValues(blockId).get + } + } + }.asInstanceOf[Iterator[T]] + } else { + diskIterator + } } /** @@ -897,14 +1072,22 @@ private[spark] class BlockManager( * Replicate block to another node. Not that this is a blocking call that returns after * the block has been replicated. */ - private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = { + private def replicate( + blockId: BlockId, + data: ChunkedByteBuffer, + level: StorageLevel, + classTag: ClassTag[_]): Unit = { val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) val numPeersToReplicateTo = level.replication - 1 val peersForReplication = new ArrayBuffer[BlockManagerId] val peersReplicatedTo = new ArrayBuffer[BlockManagerId] val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId] val tLevel = StorageLevel( - level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1) + useDisk = level.useDisk, + useMemory = level.useMemory, + useOffHeap = level.useOffHeap, + deserialized = level.deserialized, + replication = 1) val startTime = System.currentTimeMillis val random = new Random(blockId.hashCode) @@ -951,11 +1134,16 @@ private[spark] class BlockManager( case Some(peer) => try { val onePeerStartTime = System.currentTimeMillis - data.rewind() - logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") + logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel) - logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" + peer.host, + peer.port, + peer.executorId, + blockId, + new NettyManagedBuffer(data.toNetty), + tLevel, + classTag) + logTrace(s"Replicated $blockId of ${data.size} bytes to $peer in %s ms" .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer peersForReplication -= peer @@ -969,7 +1157,7 @@ private[spark] class BlockManager( failures += 1 replicationFailed = true peersFailedToReplicateTo += peer - if (failures > maxReplicationFailures) { // too many failures in replcating to peers + if (failures > maxReplicationFailures) { // too many failures in replicating to peers done = true } } @@ -978,7 +1166,7 @@ private[spark] class BlockManager( } } val timeTakeMs = (System.currentTimeMillis - startTime) - logDebug(s"Replicating $blockId of ${data.limit()} bytes to " + + logDebug(s"Replicating $blockId of ${data.size} bytes to " + s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms") if (peersReplicatedTo.size < numPeersToReplicateTo) { logWarning(s"Block $blockId replicated to only " + @@ -995,98 +1183,85 @@ private[spark] class BlockManager( /** * Write a block consisting of a single object. + * + * @return true if the block was stored or false if the block was already stored or an + * error occurred. */ - def putSingle( + def putSingle[T: ClassTag]( blockId: BlockId, - value: Any, + value: T, level: StorageLevel, - tellMaster: Boolean = true): Seq[(BlockId, BlockStatus)] = { + tellMaster: Boolean = true): Boolean = { putIterator(blockId, Iterator(value), level, tellMaster) } - def dropFromMemory( - blockId: BlockId, - data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { - dropFromMemory(blockId, () => data) - } - /** * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. * * If `data` is not put on disk, it won't be created. * - * Return the block status if the given block has been updated, else None. + * The caller of this method must hold a write lock on the block before calling this method. + * This method does not release the write lock. + * + * @return the block's new effective StorageLevel. */ - def dropFromMemory( + private[storage] override def dropFromMemory[T: ClassTag]( blockId: BlockId, - data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { - + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { logInfo(s"Dropping block $blockId from memory") - val info = blockInfo.get(blockId).orNull - - // If the block has not already been dropped - if (info != null) { - info.synchronized { - // required ? As of now, this will be invoked only for blocks which are ready - // But in case this changes in future, adding for consistency sake. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning(s"Block $blockId was marked as failure. Nothing to drop") - return None - } else if (blockInfo.get(blockId).isEmpty) { - logWarning(s"Block $blockId was already dropped.") - return None - } - var blockIsUpdated = false - val level = info.level - - // Drop to disk, if storage level requires - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo(s"Writing block $blockId to disk") - data() match { - case Left(elements) => - diskStore.putArray(blockId, elements, level, returnValues = false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) + val info = blockInfoManager.assertBlockIsLockedForWriting(blockId) + var blockIsUpdated = false + val level = info.level + + // Drop to disk, if storage level requires + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo(s"Writing block $blockId to disk") + data() match { + case Left(elements) => + diskStore.put(blockId) { fileOutputStream => + serializerManager.dataSerializeStream( + blockId, + fileOutputStream, + elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) } - blockIsUpdated = true - } + case Right(bytes) => + diskStore.putBytes(blockId, bytes) + } + blockIsUpdated = true + } - // Actually drop from memory store - val droppedMemorySize = - if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L - val blockIsRemoved = memoryStore.remove(blockId) - if (blockIsRemoved) { - blockIsUpdated = true - } else { - logWarning(s"Block $blockId could not be dropped from memory as it does not exist") - } + // Actually drop from memory store + val droppedMemorySize = + if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val blockIsRemoved = memoryStore.remove(blockId) + if (blockIsRemoved) { + blockIsUpdated = true + } else { + logWarning(s"Block $blockId could not be dropped from memory as it does not exist") + } - val status = getCurrentBlockStatus(blockId, info) - if (info.tellMaster) { - reportBlockStatus(blockId, info, status, droppedMemorySize) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. - blockInfo.remove(blockId) - } - if (blockIsUpdated) { - return Some(status) - } + val status = getCurrentBlockStatus(blockId, info) + if (info.tellMaster) { + reportBlockStatus(blockId, info, status, droppedMemorySize) + } + if (blockIsUpdated) { + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status))) } } - None + status.storageLevel } /** * Remove all blocks belonging to the given RDD. + * * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo(s"Removing RDD $rddId") - val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } @@ -1096,7 +1271,7 @@ private[spark] class BlockManager( */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { logDebug(s"Removing broadcast $broadcastId") - val blocksToRemove = blockInfo.keys.collect { + val blocksToRemove = blockInfoManager.entries.map(_._1).collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } @@ -1108,130 +1283,29 @@ private[spark] class BlockManager( */ def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { logDebug(s"Removing block $blockId") - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { + blockInfoManager.lockForWriting(blockId) match { + case None => + // The block has already been removed; do nothing. + logWarning(s"Asked to remove block $blockId, which does not exist") + case Some(info) => // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) val removedFromDisk = diskStore.remove(blockId) - val removedFromExternalBlockStore = - if (externalBlockStoreInitialized) externalBlockStore.remove(blockId) else false - if (!removedFromMemory && !removedFromDisk && !removedFromExternalBlockStore) { + if (!removedFromMemory && !removedFromDisk) { logWarning(s"Block $blockId could not be removed as it was not found in either " + "the disk, memory, or external block store") } - blockInfo.remove(blockId) + blockInfoManager.removeBlock(blockId) + val removeBlockStatus = getCurrentBlockStatus(blockId, info) if (tellMaster && info.tellMaster) { - val status = getCurrentBlockStatus(blockId, info) - reportBlockStatus(blockId, info, status) + reportBlockStatus(blockId, info, removeBlockStatus) } - } - } else { - // The block has already been removed; do nothing. - logWarning(s"Asked to remove block $blockId, which does not exist") - } - } - - private def dropOldNonBroadcastBlocks(cleanupTime: Long): Unit = { - logInfo(s"Dropping non broadcast blocks older than $cleanupTime") - dropOldBlocks(cleanupTime, !_.isBroadcast) - } - - private def dropOldBroadcastBlocks(cleanupTime: Long): Unit = { - logInfo(s"Dropping broadcast blocks older than $cleanupTime") - dropOldBlocks(cleanupTime, _.isBroadcast) - } - - private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)): Unit = { - val iterator = blockInfo.getEntrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) - if (time < cleanupTime && shouldDrop(id)) { - info.synchronized { - val level = info.level - if (level.useMemory) { memoryStore.remove(id) } - if (level.useDisk) { diskStore.remove(id) } - if (level.useOffHeap) { externalBlockStore.remove(id) } - iterator.remove() - logInfo(s"Dropped block $id") + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, removeBlockStatus))) } - val status = getCurrentBlockStatus(id, info) - reportBlockStatus(id, info, status) - } - } - } - - private def shouldCompress(blockId: BlockId): Boolean = { - blockId match { - case _: ShuffleBlockId => compressShuffle - case _: BroadcastBlockId => compressBroadcast - case _: RDDBlockId => compressRdds - case _: TempLocalBlockId => compressShuffleSpill - case _: TempShuffleBlockId => compressShuffle - case _ => false } } - /** - * Wrap an output stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s - } - - /** - * Wrap an input stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s - } - - /** Serializes into a stream. */ - def dataSerializeStream( - blockId: BlockId, - outputStream: OutputStream, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): Unit = { - val byteStream = new BufferedOutputStream(outputStream) - val ser = serializer.newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() - } - - /** Serializes into a byte buffer. */ - def dataSerialize( - blockId: BlockId, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new ByteArrayOutputStream(4096) - dataSerializeStream(blockId, byteStream, values, serializer) - ByteBuffer.wrap(byteStream.toByteArray) - } - - /** - * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of - * the iterator is reached. - */ - def dataDeserialize( - blockId: BlockId, - bytes: ByteBuffer, - serializer: Serializer = defaultSerializer): Iterator[Any] = { - bytes.rewind() - dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) - } - - /** - * Deserializes a InputStream into an iterator of values and disposes of it when the end of - * the iterator is reached. - */ - def dataDeserializeStream( - blockId: BlockId, - inputStream: InputStream, - serializer: Serializer = defaultSerializer): Iterator[Any] = { - val stream = new BufferedInputStream(inputStream) - serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator - } - def stop(): Unit = { blockTransferService.close() if (shuffleClient ne blockTransferService) { @@ -1240,38 +1314,17 @@ private[spark] class BlockManager( } diskBlockManager.stop() rpcEnv.stop(slaveEndpoint) - blockInfo.clear() + blockInfoManager.clear() memoryStore.clear() - diskStore.clear() - if (externalBlockStoreInitialized) { - externalBlockStore.clear() - } - metadataCleaner.cancel() - broadcastCleaner.cancel() futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } } -private[spark] object BlockManager extends Logging { +private[spark] object BlockManager { private val ID_GENERATOR = new IdGenerator - /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. - */ - def dispose(buffer: ByteBuffer): Unit = { - if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace(s"Unmapping $buffer") - if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { - buffer.asInstanceOf[DirectBuffer].cleaner().clean() - } - } - } - def blockIdsToHosts( blockIds: Array[BlockId], env: SparkEnv, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 69ac37511e730..cae7c9ed952f1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -44,7 +44,7 @@ class BlockManagerId private ( def executorId: String = executorId_ - if (null != host_){ + if (null != host_) { Utils.checkHost(host_, "Expected hostname") assert (port_ > 0) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala new file mode 100644 index 0000000000000..f66f942798550 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import org.apache.spark.util.io.ChunkedByteBuffer + +/** + * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]] + * so that the corresponding block's read lock can be released once this buffer's references + * are released. + * + * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks + * to the network layer's notion of retain / release counts. + */ +private[storage] class BlockManagerManagedBuffer( + blockInfoManager: BlockInfoManager, + blockId: BlockId, + chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) { + + override def retain(): ManagedBuffer = { + super.retain() + val locked = blockInfoManager.lockForReading(blockId, blocking = false) + assert(locked.isDefined) + this + } + + override def release(): ManagedBuffer = { + blockInfoManager.unlock(blockId) + super.release() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f45bff34d4dbc..c22d2e0fb61fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -19,12 +19,13 @@ package org.apache.spark.storage import scala.collection.Iterable import scala.collection.generic.CanBuildFrom -import scala.concurrent.{Await, Future} +import scala.concurrent.Future +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ThreadUtils, RpcUtils} +import org.apache.spark.util.{RpcUtils, ThreadUtils} private[spark] class BlockManagerMaster( @@ -54,11 +55,9 @@ class BlockManagerMaster( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long): Boolean = { + diskSize: Long): Boolean = { val res = driverEndpoint.askWithRetry[Boolean]( - UpdateBlockInfo(blockManagerId, blockId, storageLevel, - memSize, diskSize, externalBlockStoreSize)) + UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) logDebug(s"Updated info of block $blockId") res } @@ -87,8 +86,8 @@ class BlockManagerMaster( driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { - driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) + def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { + driverEndpoint.askWithRetry[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 7db6035553ae6..8fa12150114db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,14 +19,14 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.immutable.HashSet import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -60,10 +60,9 @@ class BlockManagerMasterEndpoint( register(blockManagerId, maxMemSize, slaveEndpoint) context.reply(true) - case _updateBlockInfo @ UpdateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize) => - context.reply(updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, externalBlockStoreSize)) + case _updateBlockInfo @ + UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + context.reply(updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)) listenerBus.post(SparkListenerBlockUpdated(BlockUpdatedInfo(_updateBlockInfo))) case GetLocations(blockId) => @@ -75,8 +74,8 @@ class BlockManagerMasterEndpoint( case GetPeers(blockManagerId) => context.reply(getPeers(blockManagerId)) - case GetRpcHostPortForExecutor(executorId) => - context.reply(getRpcHostPortForExecutor(executorId)) + case GetExecutorEndpointRef(executorId) => + context.reply(getExecutorEndpointRef(executorId)) case GetMemoryStatus => context.reply(memoryStatus) @@ -326,8 +325,7 @@ class BlockManagerMasterEndpoint( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long): Boolean = { + diskSize: Long): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { if (blockManagerId.isDriver && !isLocal) { @@ -344,8 +342,7 @@ class BlockManagerMasterEndpoint( return true } - blockManagerInfo(blockManagerId).updateBlockInfo( - blockId, storageLevel, memSize, diskSize, externalBlockStoreSize) + blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) var locations: mutable.HashSet[BlockManagerId] = null if (blockLocations.containsKey(blockId)) { @@ -388,15 +385,14 @@ class BlockManagerMasterEndpoint( } /** - * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its - * [[BlockManagerSlaveEndpoint]]. + * Returns an [[RpcEndpointRef]] of the [[BlockManagerSlaveEndpoint]] for sending RPC messages. */ - private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); info <- blockManagerInfo.get(blockManagerId) ) yield { - (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) + info.slaveEndpoint } } @@ -406,17 +402,13 @@ class BlockManagerMasterEndpoint( } @DeveloperApi -case class BlockStatus( - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long) { - def isCached: Boolean = memSize + diskSize + externalBlockStoreSize > 0 +case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) { + def isCached: Boolean = memSize + diskSize > 0 } @DeveloperApi object BlockStatus { - def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) + def empty: BlockStatus = BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L) } private[spark] class BlockManagerInfo( @@ -445,8 +437,7 @@ private[spark] class BlockManagerInfo( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long) { + diskSize: Long) { updateLastSeenMs() @@ -462,7 +453,7 @@ private[spark] class BlockManagerInfo( } if (storageLevel.isValid) { - /* isValid means it is either stored in-memory, on-disk or on-externalBlockStore. + /* isValid means it is either stored in-memory or on-disk. * The memSize here indicates the data size in or dropped from memory, * externalBlockStoreSize here indicates the data size in or dropped from externalBlockStore, * and the diskSize here indicates the data size in or dropped to disk. @@ -470,7 +461,7 @@ private[spark] class BlockManagerInfo( * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ var blockStatus: BlockStatus = null if (storageLevel.useMemory) { - blockStatus = BlockStatus(storageLevel, memSize, 0, 0) + blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0) _blocks.put(blockId, blockStatus) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( @@ -478,17 +469,11 @@ private[spark] class BlockManagerInfo( Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { - blockStatus = BlockStatus(storageLevel, 0, diskSize, 0) + blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } - if (storageLevel.useOffHeap) { - blockStatus = BlockStatus(storageLevel, 0, 0, externalBlockStoreSize) - _blocks.put(blockId, blockStatus) - logInfo("Added %s on ExternalBlockStore on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(externalBlockStoreSize))) - } if (!blockId.isBroadcast && blockStatus.isCached) { _cachedBlocks += blockId } @@ -506,11 +491,6 @@ private[spark] class BlockManagerInfo( logInfo("Removed %s on %s on disk (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) } - if (blockStatus.storageLevel.useOffHeap) { - logInfo("Removed %s on %s on externalBlockStore (size: %s)".format( - blockId, blockManagerId.hostPort, - Utils.bytesToString(blockStatus.externalBlockStoreSize))) - } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 376e9eb48843d..6bded92700504 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,6 +42,11 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + /** + * Driver -> Executor message to trigger a thread dump. + */ + case object TriggerThreadDump extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// @@ -58,12 +63,11 @@ private[spark] object BlockManagerMessages { var blockId: BlockId, var storageLevel: StorageLevel, var memSize: Long, - var diskSize: Long, - var externalBlockStoreSize: Long) + var diskSize: Long) extends ToBlockManagerMaster with Externalizable { - def this() = this(null, null, null, 0, 0, 0) // For deserialization only + def this() = this(null, null, null, 0, 0) // For deserialization only override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { blockManagerId.writeExternal(out) @@ -71,7 +75,6 @@ private[spark] object BlockManagerMessages { storageLevel.writeExternal(out) out.writeLong(memSize) out.writeLong(diskSize) - out.writeLong(externalBlockStoreSize) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -80,7 +83,6 @@ private[spark] object BlockManagerMessages { storageLevel = StorageLevel(in) memSize = in.readLong() diskSize = in.readLong() - externalBlockStoreSize = in.readLong() } } @@ -90,7 +92,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetExecutorEndpointRef(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index e749631bf6f19..d17ddbc162579 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,10 +19,11 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.ThreadUtils -import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.{MapOutputTracker, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * An RpcEndpoint to take commands from the master to execute options. For example, @@ -70,6 +71,9 @@ class BlockManagerSlaveEndpoint( case GetMatchingBlockIds(filter, _) => context.reply(blockManager.getMatchingBlockIds(filter)) + + case TriggerThreadDump => + context.reply(Utils.getThreadDump()) } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala index 2789e25b8d3ab..0a14fcadf53e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStatusListener.scala @@ -26,8 +26,7 @@ private[spark] case class BlockUIData( location: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long) + diskSize: Long) /** * The aggregated status of stream blocks in an executor @@ -41,8 +40,6 @@ private[spark] case class ExecutorStreamBlockStatus( def totalDiskSize: Long = blocks.map(_.diskSize).sum - def totalExternalBlockStoreSize: Long = blocks.map(_.externalBlockStoreSize).sum - def numStreamBlocks: Int = blocks.size } @@ -62,7 +59,6 @@ private[spark] class BlockStatusListener extends SparkListener { val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel val memSize = blockUpdated.blockUpdatedInfo.memSize val diskSize = blockUpdated.blockUpdatedInfo.diskSize - val externalBlockStoreSize = blockUpdated.blockUpdatedInfo.externalBlockStoreSize synchronized { // Drop the update info if the block manager is not registered @@ -74,8 +70,7 @@ private[spark] class BlockStatusListener extends SparkListener { blockManagerId.hostPort, storageLevel, memSize, - diskSize, - externalBlockStoreSize) + diskSize) ) } else { // If isValid is not true, it means we should drop the block. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala deleted file mode 100644 index 69985c9759e2d..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.Logging - -/** - * Abstract class to store blocks. - */ -private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { - - def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult - - /** - * Put in a block and, possibly, also return its content as either bytes or another Iterator. - * This is used to efficiently write the values to multiple locations (e.g. for replication). - * - * @return a PutResult that contains the size of the data, as well as the values put if - * returnValues is true (if not, the result's data field can be null) - */ - def putIterator( - blockId: BlockId, - values: Iterator[Any], - level: StorageLevel, - returnValues: Boolean): PutResult - - def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult - - /** - * Return the size of a block in bytes. - */ - def getSize(blockId: BlockId): Long - - def getBytes(blockId: BlockId): Option[ByteBuffer] - - def getValues(blockId: BlockId): Option[Iterator[Any]] - - /** - * Remove a block, if it exists. - * @param blockId the block to remove. - * @return True if the block was found and removed, False otherwise. - */ - def remove(blockId: BlockId): Boolean - - def contains(blockId: BlockId): Boolean - - def clear() { } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala index a5790e4454a89..e070bf658acb8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockUpdatedInfo.scala @@ -30,8 +30,7 @@ case class BlockUpdatedInfo( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long, - externalBlockStoreSize: Long) + diskSize: Long) private[spark] object BlockUpdatedInfo { @@ -41,7 +40,6 @@ private[spark] object BlockUpdatedInfo { updateBlockInfo.blockId, updateBlockInfo.storageLevel, updateBlockInfo.memSize, - updateBlockInfo.diskSize, - updateBlockInfo.externalBlockStoreSize) + updateBlockInfo.diskSize) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f7e84a2c2e14c..0666be2dcb019 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -17,27 +17,24 @@ package org.apache.spark.storage +import java.io.{File, IOException} import java.util.UUID -import java.io.{IOException, File} -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.SparkConf import org.apache.spark.executor.ExecutorExitCode +import org.apache.spark.internal.Logging import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and physical on-disk - * locations. By default, one block is mapped to one file with a name given by its BlockId. - * However, it is also possible to have a block map to only a segment of a file, by calling - * mapBlockToFileSegment(). + * locations. One block is mapped to one file with a name given by its BlockId. * * Block files are hashed among the directories listed in spark.local.dir (or in * SPARK_LOCAL_DIRS, if it's set). */ -private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf) - extends Logging { +private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolean) extends Logging { - private[spark] - val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) + private[spark] val subDirsPerLocalDir = conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid @@ -163,10 +160,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def doStop(): Unit = { - // Only perform cleanup if an external service is not serving our shuffle files. - // Also blockManagerId could be null if block manager is not initialized properly. - if (!blockManager.externalShuffleServiceEnabled || - (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { + if (deleteFilesOnStop) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 80d426fadc65e..ab97d2e4b8b78 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -17,12 +17,12 @@ package org.apache.spark.storage -import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} +import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.FileChannel -import org.apache.spark.Logging -import org.apache.spark.serializer.{SerializerInstance, SerializationStream} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.{SerializationStream, SerializerInstance} import org.apache.spark.util.Utils /** @@ -34,14 +34,15 @@ import org.apache.spark.util.Utils * reopened again. */ private[spark] class DiskBlockObjectWriter( - file: File, + val file: File, serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. - writeMetrics: ShuffleWriteMetrics) + writeMetrics: ShuffleWriteMetrics, + val blockId: BlockId = null) extends OutputStream with Logging { @@ -101,7 +102,7 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() val start = System.nanoTime() fos.getFD.sync() - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) + writeMetrics.incWriteTime(System.nanoTime() - start) } } { objOut.close() @@ -131,7 +132,7 @@ private[spark] class DiskBlockObjectWriter( close() finalPosition = file.length() // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + writeMetrics.incBytesWritten(finalPosition - reportedPosition) } else { finalPosition = file.length() } @@ -151,8 +152,8 @@ private[spark] class DiskBlockObjectWriter( // truncating the file to its initial position. try { if (initialized) { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) + writeMetrics.decBytesWritten(reportedPosition - initialPosition) + writeMetrics.decRecordsWritten(numRecordsWritten) objOut.flush() bs.flush() close() @@ -200,8 +201,9 @@ private[spark] class DiskBlockObjectWriter( */ def recordWritten(): Unit = { numRecordsWritten += 1 - writeMetrics.incShuffleRecordsWritten(1) + writeMetrics.incRecordsWritten(1) + // TODO: call updateBytesWritten() less frequently. if (numRecordsWritten % 32 == 0) { updateBytesWritten() } @@ -225,7 +227,7 @@ private[spark] class DiskBlockObjectWriter( */ private def updateBytesWritten() { val pos = channel.position() - writeMetrics.incShuffleBytesWritten(pos - reportedPosition) + writeMetrics.incBytesWritten(pos - reportedPosition) reportedPosition = pos } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index c008b9dc16327..ca23e2391ed02 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,144 +17,97 @@ package org.apache.spark.storage -import java.io.{IOException, File, FileOutputStream, RandomAccessFile} +import java.io.{FileOutputStream, IOException, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode -import org.apache.spark.Logging -import org.apache.spark.serializer.Serializer +import com.google.common.io.Closeables + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils +import org.apache.spark.util.io.ChunkedByteBuffer /** * Stores BlockManager blocks on disk. */ -private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) - extends BlockStore(blockManager) with Logging { +private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging { - val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") + private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") - override def getSize(blockId: BlockId): Long = { + def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = { - // So that we do not modify the input offsets ! - // duplicate does not copy buffer, so inexpensive - val bytes = _bytes.duplicate() + /** + * Invokes the provided callback function to write the specific block. + * + * @throws IllegalStateException if the block already exists in the disk store. + */ + def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = { + if (contains(blockId)) { + throw new IllegalStateException(s"Block $blockId is already present in the disk store") + } logDebug(s"Attempting to put block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) - val channel = new FileOutputStream(file).getChannel - Utils.tryWithSafeFinally { - while (bytes.remaining > 0) { - channel.write(bytes) + val fileOutputStream = new FileOutputStream(file) + var threwException: Boolean = true + try { + writeFunc(fileOutputStream) + threwException = false + } finally { + try { + Closeables.close(fileOutputStream, threwException) + } finally { + if (threwException) { + remove(blockId) + } } - } { - channel.close() } val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( - file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) - PutResult(bytes.limit(), Right(bytes.duplicate())) + file.getName, + Utils.bytesToString(file.length()), + finishTime - startTime)) } - override def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - putIterator(blockId, values.toIterator, level, returnValues) - } - - override def putIterator( - blockId: BlockId, - values: Iterator[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - - logDebug(s"Attempting to write values for block $blockId") - val startTime = System.currentTimeMillis - val file = diskManager.getFile(blockId) - val outputStream = new FileOutputStream(file) - try { + def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = { + put(blockId) { fileOutputStream => + val channel = fileOutputStream.getChannel Utils.tryWithSafeFinally { - blockManager.dataSerializeStream(blockId, outputStream, values) + bytes.writeFully(channel) } { - // Close outputStream here because it should be closed before file is deleted. - outputStream.close() + channel.close() } - } catch { - case e: Throwable => - if (file.exists()) { - if (!file.delete()) { - logWarning(s"Error deleting ${file}") - } - } - throw e - } - - val length = file.length - - val timeTaken = System.currentTimeMillis - startTime - logDebug("Block %s stored as %s file on disk in %d ms".format( - file.getName, Utils.bytesToString(length), timeTaken)) - - if (returnValues) { - // Return a byte buffer for the contents of the file - val buffer = getBytes(blockId).get - PutResult(length, Right(buffer)) - } else { - PutResult(length, null) } } - private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { + def getBytes(blockId: BlockId): ChunkedByteBuffer = { + val file = diskManager.getFile(blockId.name) val channel = new RandomAccessFile(file, "r").getChannel Utils.tryWithSafeFinally { // For small files, directly read rather than memory map - if (length < minMemoryMapBytes) { - val buf = ByteBuffer.allocate(length.toInt) - channel.position(offset) + if (file.length < minMemoryMapBytes) { + val buf = ByteBuffer.allocate(file.length.toInt) + channel.position(0) while (buf.remaining() != 0) { if (channel.read(buf) == -1) { throw new IOException("Reached EOF before filling buffer\n" + - s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") + s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") } } buf.flip() - Some(buf) + new ChunkedByteBuffer(buf) } else { - Some(channel.map(MapMode.READ_ONLY, offset, length)) + new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)) } } { channel.close() } } - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val file = diskManager.getFile(blockId.name) - getBytes(file, 0, file.length) - } - - def getBytes(segment: FileSegment): Option[ByteBuffer] = { - getBytes(segment.file, segment.offset, segment.length) - } - - override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) - } - - /** - * A version of getValues that allows a custom serializer. This is used as part of the - * shuffle short-circuit code. - */ - def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - // TODO: Should bypass getBytes and use a stream based implementation, so that - // we won't use a lot of memory during e.g. external sort merge. - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) - } - - override def remove(blockId: BlockId): Boolean = { + def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() @@ -167,7 +120,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } } - override def contains(blockId: BlockId): Boolean = { + def contains(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) file.exists() } diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala deleted file mode 100644 index f39325a12d244..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -/** - * An abstract class that the concrete external block manager has to inherit. - * The class has to have a no-argument constructor, and will be initialized by init, - * which is invoked by ExternalBlockStore. The main input parameter is blockId for all - * the methods, which is the unique identifier for Block in one Spark application. - * - * The underlying external block manager should avoid any name space conflicts among multiple - * Spark applications. For example, creating different directory for different applications - * by randomUUID - * - */ -private[spark] abstract class ExternalBlockManager { - - protected var blockManager: BlockManager = _ - - override def toString: String = {"External Block Store"} - - /** - * Initialize a concrete block manager implementation. Subclass should initialize its internal - * data structure, e.g, file system, in this function, which is invoked by ExternalBlockStore - * right after the class is constructed. The function should throw IOException on failure - * - * @throws java.io.IOException if there is any file system failure during the initialization. - */ - def init(blockManager: BlockManager, executorId: String): Unit = { - this.blockManager = blockManager - } - - /** - * Drop the block from underlying external block store, if it exists.. - * @return true on successfully removing the block - * false if the block could not be removed as it was not found - * - * @throws java.io.IOException if there is any file system failure in removing the block. - */ - def removeBlock(blockId: BlockId): Boolean - - /** - * Used by BlockManager to check the existence of the block in the underlying external - * block store. - * @return true if the block exists. - * false if the block does not exists. - * - * @throws java.io.IOException if there is any file system failure in checking - * the block existence. - */ - def blockExists(blockId: BlockId): Boolean - - /** - * Put the given block to the underlying external block store. Note that in normal case, - * putting a block should never fail unless something wrong happens to the underlying - * external block store, e.g., file system failure, etc. In this case, IOException - * should be thrown. - * - * @throws java.io.IOException if there is any file system failure in putting the block. - */ - def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit - - def putValues(blockId: BlockId, values: Iterator[_]): Unit = { - val bytes = blockManager.dataSerialize(blockId, values) - putBytes(blockId, bytes) - } - - /** - * Retrieve the block bytes. - * @return Some(ByteBuffer) if the block bytes is successfully retrieved - * None if the block does not exist in the external block store. - * - * @throws java.io.IOException if there is any file system failure in getting the block. - */ - def getBytes(blockId: BlockId): Option[ByteBuffer] - - /** - * Retrieve the block data. - * @return Some(Iterator[Any]) if the block data is successfully retrieved - * None if the block does not exist in the external block store. - * - * @throws java.io.IOException if there is any file system failure in getting the block. - */ - def getValues(blockId: BlockId): Option[Iterator[_]] = { - getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) - } - - /** - * Get the size of the block saved in the underlying external block store, - * which is saved before by putBytes. - * @return size of the block - * 0 if the block does not exist - * - * @throws java.io.IOException if there is any file system failure in getting the block size. - */ - def getSize(blockId: BlockId): Long - - /** - * Clean up any information persisted in the underlying external block store, - * e.g., the directory, files, etc,which is invoked by the shutdown hook of ExternalBlockStore - * during system shutdown. - * - */ - def shutdown() -} diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala deleted file mode 100644 index db965d54bafd6..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -import scala.util.control.NonFatal - -import org.apache.spark.Logging -import org.apache.spark.util.Utils - - -/** - * Stores BlockManager blocks on ExternalBlockStore. - * We capture any potential exception from underlying implementation - * and return with the expected failure value - */ -private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: String) - extends BlockStore(blockManager: BlockManager) with Logging { - - lazy val externalBlockManager: Option[ExternalBlockManager] = createBlkManager() - - logInfo("ExternalBlockStore started") - - override def getSize(blockId: BlockId): Long = { - try { - externalBlockManager.map(_.getSize(blockId)).getOrElse(0) - } catch { - case NonFatal(t) => - logError(s"Error in getSize($blockId)", t) - 0L - } - } - - override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { - putIntoExternalBlockStore(blockId, bytes, returnValues = true) - } - - override def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - putIntoExternalBlockStore(blockId, values.toIterator, returnValues) - } - - override def putIterator( - blockId: BlockId, - values: Iterator[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - putIntoExternalBlockStore(blockId, values, returnValues) - } - - private def putIntoExternalBlockStore( - blockId: BlockId, - values: Iterator[_], - returnValues: Boolean): PutResult = { - logTrace(s"Attempting to put block $blockId into ExternalBlockStore") - // we should never hit here if externalBlockManager is None. Handle it anyway for safety. - try { - val startTime = System.currentTimeMillis - if (externalBlockManager.isDefined) { - externalBlockManager.get.putValues(blockId, values) - val size = getSize(blockId) - val data = if (returnValues) { - Left(getValues(blockId).get) - } else { - null - } - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( - blockId, Utils.bytesToString(size), finishTime - startTime)) - PutResult(size, data) - } else { - logError(s"Error in putValues($blockId): no ExternalBlockManager has been configured") - PutResult(-1, null, Seq((blockId, BlockStatus.empty))) - } - } catch { - case NonFatal(t) => - logError(s"Error in putValues($blockId)", t) - PutResult(-1, null, Seq((blockId, BlockStatus.empty))) - } - } - - private def putIntoExternalBlockStore( - blockId: BlockId, - bytes: ByteBuffer, - returnValues: Boolean): PutResult = { - logTrace(s"Attempting to put block $blockId into ExternalBlockStore") - // we should never hit here if externalBlockManager is None. Handle it anyway for safety. - try { - val startTime = System.currentTimeMillis - if (externalBlockManager.isDefined) { - val byteBuffer = bytes.duplicate() - byteBuffer.rewind() - externalBlockManager.get.putBytes(blockId, byteBuffer) - val size = bytes.limit() - val data = if (returnValues) { - Right(bytes) - } else { - null - } - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( - blockId, Utils.bytesToString(size), finishTime - startTime)) - PutResult(size, data) - } else { - logError(s"Error in putBytes($blockId): no ExternalBlockManager has been configured") - PutResult(-1, null, Seq((blockId, BlockStatus.empty))) - } - } catch { - case NonFatal(t) => - logError(s"Error in putBytes($blockId)", t) - PutResult(-1, null, Seq((blockId, BlockStatus.empty))) - } - } - - // We assume the block is removed even if exception thrown - override def remove(blockId: BlockId): Boolean = { - try { - externalBlockManager.map(_.removeBlock(blockId)).getOrElse(true) - } catch { - case NonFatal(t) => - logError(s"Error in removeBlock($blockId)", t) - true - } - } - - override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - try { - externalBlockManager.flatMap(_.getValues(blockId)) - } catch { - case NonFatal(t) => - logError(s"Error in getValues($blockId)", t) - None - } - } - - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - try { - externalBlockManager.flatMap(_.getBytes(blockId)) - } catch { - case NonFatal(t) => - logError(s"Error in getBytes($blockId)", t) - None - } - } - - override def contains(blockId: BlockId): Boolean = { - try { - val ret = externalBlockManager.map(_.blockExists(blockId)).getOrElse(false) - if (!ret) { - logInfo(s"Remove block $blockId") - blockManager.removeBlock(blockId, true) - } - ret - } catch { - case NonFatal(t) => - logError(s"Error in getBytes($blockId)", t) - false - } - } - - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("ExternalBlockStore shutdown hook") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - externalBlockManager.map(_.shutdown()) - } - }) - } - - // Create concrete block manager and fall back to Tachyon by default for backward compatibility. - private def createBlkManager(): Option[ExternalBlockManager] = { - val clsName = blockManager.conf.getOption(ExternalBlockStore.BLOCK_MANAGER_NAME) - .getOrElse(ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) - - try { - val instance = Utils.classForName(clsName) - .newInstance() - .asInstanceOf[ExternalBlockManager] - instance.init(blockManager, executorId) - addShutdownHook(); - Some(instance) - } catch { - case NonFatal(t) => - logError("Cannot initialize external block store", t) - None - } - } -} - -private[spark] object ExternalBlockStore extends Logging { - val MAX_DIR_CREATION_ATTEMPTS = 10 - val SUB_DIRS_PER_DIR = "64" - val BASE_DIR = "spark.externalBlockStore.baseDir" - val FOLD_NAME = "spark.externalBlockStore.folderName" - val MASTER_URL = "spark.externalBlockStore.url" - val BLOCK_MANAGER_NAME = "spark.externalBlockStore.blockManager" - val DEFAULT_BLOCK_MANAGER_NAME = "org.apache.spark.storage.TachyonBlockManager" -} diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala deleted file mode 100644 index 4dbac388e098b..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ /dev/null @@ -1,625 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer -import java.util.LinkedHashMap - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.TaskContext -import org.apache.spark.memory.MemoryManager -import org.apache.spark.util.{SizeEstimator, Utils} -import org.apache.spark.util.collection.SizeTrackingVector - -private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) - -/** - * Stores blocks in memory, either as Arrays of deserialized Java objects or as - * serialized ByteBuffers. - */ -private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: MemoryManager) - extends BlockStore(blockManager) { - - // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and - // acquiring or releasing unroll memory, must be synchronized on `memoryManager`! - - private val conf = blockManager.conf - private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true) - - // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) - // All accesses of this map are assumed to have manually synchronized on `memoryManager` - private val unrollMemoryMap = mutable.HashMap[Long, Long]() - // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. - // Pending unroll memory refers to the intermediate memory occupied by a task - // after the unroll but before the actual putting of the block in the cache. - // This chunk of memory is expected to be released *as soon as* we finish - // caching the corresponding block as opposed to until after the task finishes. - // This is only used if a block is successfully unrolled in its entirety in - // memory (SPARK-4777). - private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]() - - // Initial memory to request before unrolling any block - private val unrollMemoryThreshold: Long = - conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) - - /** Total amount of memory available for storage, in bytes. */ - private def maxMemory: Long = memoryManager.maxStorageMemory - - if (maxMemory < unrollMemoryThreshold) { - logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " + - s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " + - s"memory. Please configure Spark with more memory.") - } - - logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) - - /** Total storage memory used including unroll memory, in bytes. */ - private def memoryUsed: Long = memoryManager.storageMemoryUsed - - /** - * Amount of storage memory, in bytes, used for caching blocks. - * This does not include memory used for unrolling. - */ - private def blocksMemoryUsed: Long = memoryManager.synchronized { - memoryUsed - currentUnrollMemory - } - - override def getSize(blockId: BlockId): Long = { - entries.synchronized { - entries.get(blockId).size - } - } - - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = { - // Work on a duplicate - since the original input might be used elsewhere. - val bytes = _bytes.duplicate() - bytes.rewind() - if (level.deserialized) { - val values = blockManager.dataDeserialize(blockId, bytes) - putIterator(blockId, values, level, returnValues = true) - } else { - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) - PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) - } - } - - /** - * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and - * put it into MemoryStore. Otherwise, the ByteBuffer won't be created. - * - * The caller should guarantee that `size` is correct. - */ - def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { - // Work on a duplicate - since the original input might be used elsewhere. - lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false, droppedBlocks) - val data = - if (putSuccess) { - assert(bytes.limit == size) - Right(bytes.duplicate()) - } else { - null - } - PutResult(size, data, droppedBlocks) - } - - override def putArray( - blockId: BlockId, - values: Array[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - if (level.deserialized) { - val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - tryToPut(blockId, values, sizeEstimate, deserialized = true, droppedBlocks) - PutResult(sizeEstimate, Left(values.iterator), droppedBlocks) - } else { - val bytes = blockManager.dataSerialize(blockId, values.iterator) - tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) - PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) - } - } - - override def putIterator( - blockId: BlockId, - values: Iterator[Any], - level: StorageLevel, - returnValues: Boolean): PutResult = { - putIterator(blockId, values, level, returnValues, allowPersistToDisk = true) - } - - /** - * Attempt to put the given block in memory store. - * - * There may not be enough space to fully unroll the iterator in memory, in which case we - * optionally drop the values to disk if - * (1) the block's storage level specifies useDisk, and - * (2) `allowPersistToDisk` is true. - * - * One scenario in which `allowPersistToDisk` is false is when the BlockManager reads a block - * back from disk and attempts to cache it in memory. In this case, we should not persist the - * block back on disk again, as it is already in disk store. - */ - private[storage] def putIterator( - blockId: BlockId, - values: Iterator[Any], - level: StorageLevel, - returnValues: Boolean, - allowPersistToDisk: Boolean): PutResult = { - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val unrolledValues = unrollSafely(blockId, values, droppedBlocks) - unrolledValues match { - case Left(arrayValues) => - // Values are fully unrolled in memory, so store them as an array - val res = putArray(blockId, arrayValues, level, returnValues) - droppedBlocks ++= res.droppedBlocks - PutResult(res.size, res.data, droppedBlocks) - case Right(iteratorValues) => - // Not enough space to unroll this block; drop to disk if applicable - if (level.useDisk && allowPersistToDisk) { - logWarning(s"Persisting block $blockId to disk instead.") - val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues) - PutResult(res.size, res.data, droppedBlocks) - } else { - PutResult(0, Left(iteratorValues), droppedBlocks) - } - } - } - - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val entry = entries.synchronized { - entries.get(blockId) - } - if (entry == null) { - None - } else if (entry.deserialized) { - Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[Array[Any]].iterator)) - } else { - Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data - } - } - - override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - val entry = entries.synchronized { - entries.get(blockId) - } - if (entry == null) { - None - } else if (entry.deserialized) { - Some(entry.value.asInstanceOf[Array[Any]].iterator) - } else { - val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data - Some(blockManager.dataDeserialize(blockId, buffer)) - } - } - - override def remove(blockId: BlockId): Boolean = memoryManager.synchronized { - val entry = entries.synchronized { entries.remove(blockId) } - if (entry != null) { - memoryManager.releaseStorageMemory(entry.size) - logDebug(s"Block $blockId of size ${entry.size} dropped " + - s"from memory (free ${maxMemory - blocksMemoryUsed})") - true - } else { - false - } - } - - override def clear(): Unit = memoryManager.synchronized { - entries.synchronized { - entries.clear() - } - unrollMemoryMap.clear() - pendingUnrollMemoryMap.clear() - memoryManager.releaseAllStorageMemory() - logInfo("MemoryStore cleared") - } - - /** - * Unroll the given block in memory safely. - * - * The safety of this operation refers to avoiding potential OOM exceptions caused by - * unrolling the entirety of the block in memory at once. This is achieved by periodically - * checking whether the memory restrictions for unrolling blocks are still satisfied, - * stopping immediately if not. This check is a safeguard against the scenario in which - * there is not enough free memory to accommodate the entirety of a single block. - * - * This method returns either an array with the contents of the entire block or an iterator - * containing the values of the block (if the array would have exceeded available memory). - */ - def unrollSafely( - blockId: BlockId, - values: Iterator[Any], - droppedBlocks: ArrayBuffer[(BlockId, BlockStatus)]) - : Either[Array[Any], Iterator[Any]] = { - - // Number of elements unrolled so far - var elementsUnrolled = 0 - // Whether there is still enough memory for us to continue unrolling this block - var keepUnrolling = true - // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. - val initialMemoryThreshold = unrollMemoryThreshold - // How often to check whether we need to request more memory - val memoryCheckPeriod = 16 - // Memory currently reserved by this task for this particular unrolling operation - var memoryThreshold = initialMemoryThreshold - // Memory to request as a multiple of current vector size - val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this task, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisTask - // Underlying vector for unrolling the block - var vector = new SizeTrackingVector[Any] - - // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, droppedBlocks) - - if (!keepUnrolling) { - logWarning(s"Failed to reserve initial memory threshold of " + - s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") - } - - // Unroll this block safely, checking whether we have exceeded our threshold periodically - try { - while (values.hasNext && keepUnrolling) { - vector += values.next() - if (elementsUnrolled % memoryCheckPeriod == 0) { - // If our vector's size has exceeded the threshold, request more memory - val currentSize = vector.estimateSize() - if (currentSize >= memoryThreshold) { - val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong - keepUnrolling = reserveUnrollMemoryForThisTask( - blockId, amountToRequest, droppedBlocks) - // New threshold is currentSize * memoryGrowthFactor - memoryThreshold += amountToRequest - } - } - elementsUnrolled += 1 - } - - if (keepUnrolling) { - // We successfully unrolled the entirety of this block - Left(vector.toArray) - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, vector.estimateSize()) - Right(vector.iterator ++ values) - } - - } finally { - // If we return an array, the values returned here will be cached in `tryToPut` later. - // In this case, we should release the memory only after we cache the block there. - if (keepUnrolling) { - val taskAttemptId = currentTaskAttemptId() - memoryManager.synchronized { - // Since we continue to hold onto the array until we actually cache it, we cannot - // release the unroll memory yet. Instead, we transfer it to pending unroll memory - // so `tryToPut` can further transfer it to normal storage memory later. - // TODO: we can probably express this without pending unroll memory (SPARK-10907) - val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved - unrollMemoryMap(taskAttemptId) -= amountToTransferToPending - pendingUnrollMemoryMap(taskAttemptId) = - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending - } - } else { - // Otherwise, if we return an iterator, we can only release the unroll memory when - // the task finishes since we don't know when the iterator will be consumed. - } - } - } - - /** - * Return the RDD ID that a given block ID is from, or None if it is not an RDD block. - */ - private def getRddId(blockId: BlockId): Option[Int] = { - blockId.asRDDId.map(_.rddId) - } - - private def tryToPut( - blockId: BlockId, - value: Any, - size: Long, - deserialized: Boolean, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - tryToPut(blockId, () => value, size, deserialized, droppedBlocks) - } - - /** - * Try to put in a set of values, if we can free up enough space. The value should either be - * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size - * must also be passed by the caller. - * - * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be - * created to avoid OOM since it may be a big ByteBuffer. - * - * Synchronize on `memoryManager` to ensure that all the put requests and its associated block - * dropping is done by only on thread at a time. Otherwise while one thread is dropping - * blocks to free memory for one block, another thread may use up the freed space for - * another block. - * - * All blocks evicted in the process, if any, will be added to `droppedBlocks`. - * - * @return whether put was successful. - */ - private def tryToPut( - blockId: BlockId, - value: () => Any, - size: Long, - deserialized: Boolean, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - - /* TODO: Its possible to optimize the locking by locking entries only when selecting blocks - * to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has - * been released, it must be ensured that those to-be-dropped blocks are not double counted - * for freeing up more space for another block that needs to be put. Only then the actually - * dropping of blocks (and writing to disk if necessary) can proceed in parallel. */ - - memoryManager.synchronized { - // Note: if we have previously unrolled this block successfully, then pending unroll - // memory should be non-zero. This is the amount that we already reserved during the - // unrolling process. In this case, we can just reuse this space to cache our block. - // The synchronization on `memoryManager` here guarantees that the release and acquire - // happen atomically. This relies on the assumption that all memory acquisitions are - // synchronized on the same lock. - releasePendingUnrollMemoryForThisTask() - val enoughMemory = memoryManager.acquireStorageMemory(blockId, size, droppedBlocks) - if (enoughMemory) { - // We acquired enough memory for the block, so go ahead and put it - val entry = new MemoryEntry(value(), size, deserialized) - entries.synchronized { - entries.put(blockId, entry) - } - val valuesOrBytes = if (deserialized) "values" else "bytes" - logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( - blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) - } else { - // Tell the block manager that we couldn't put it in memory so that it can drop it to - // disk if the block allows disk storage. - lazy val data = if (deserialized) { - Left(value().asInstanceOf[Array[Any]]) - } else { - Right(value().asInstanceOf[ByteBuffer].duplicate()) - } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) - droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } - } - enoughMemory - } - } - - /** - * Try to free up a given amount of space by evicting existing blocks. - * - * @param space the amount of memory to free, in bytes - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private[spark] def ensureFreeSpace( - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - ensureFreeSpace(None, space, droppedBlocks) - } - - /** - * Try to free up a given amount of space to store a block by evicting existing ones. - * - * @param space the amount of memory to free, in bytes - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private[spark] def ensureFreeSpace( - blockId: BlockId, - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - ensureFreeSpace(Some(blockId), space, droppedBlocks) - } - - /** - * Try to free up a given amount of space to store a particular block, but can fail if - * either the block is bigger than our memory or it would require replacing another block - * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that - * don't fit into memory that we want to avoid). - * - * @param blockId the ID of the block we are freeing space for, if any - * @param space the size of this block - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private def ensureFreeSpace( - blockId: Option[BlockId], - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - memoryManager.synchronized { - val freeMemory = maxMemory - memoryUsed - val rddToAdd = blockId.flatMap(getRddId) - val selectedBlocks = new ArrayBuffer[BlockId] - var selectedMemory = 0L - - logInfo(s"Ensuring $space bytes of free space " + - blockId.map { id => s"for block $id" }.getOrElse("") + - s"(free: $freeMemory, max: $maxMemory)") - - // Fail fast if the block simply won't fit - if (space > maxMemory) { - logInfo("Will not " + blockId.map { id => s"store $id" }.getOrElse("free memory") + - s" as the required space ($space bytes) exceeds our memory limit ($maxMemory bytes)") - return false - } - - // No need to evict anything if there is already enough free space - if (freeMemory >= space) { - return true - } - - // This is synchronized to ensure that the set of entries is not changed - // (because of getValue or getBytes) while traversing the iterator, as that - // can lead to exceptions. - entries.synchronized { - val iterator = entries.entrySet().iterator() - while (freeMemory + selectedMemory < space && iterator.hasNext) { - val pair = iterator.next() - val blockId = pair.getKey - if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { - selectedBlocks += blockId - selectedMemory += pair.getValue.size - } - } - } - - if (freeMemory + selectedMemory >= space) { - logInfo(s"${selectedBlocks.size} blocks selected for dropping") - for (blockId <- selectedBlocks) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one task should be dropping - // blocks and removing entries. However the check is still here for - // future safety. - if (entry != null) { - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[Array[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) - } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) - droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } - } - } - true - } else { - blockId.foreach { id => - logInfo(s"Will not store $id as it would require dropping another block " + - "from the same RDD") - } - false - } - } - } - - override def contains(blockId: BlockId): Boolean = { - entries.synchronized { entries.containsKey(blockId) } - } - - private def currentTaskAttemptId(): Long = { - // In case this is called on the driver, return an invalid task attempt id. - Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) - } - - /** - * Reserve memory for unrolling the given block for this task. - * @return whether the request is granted. - */ - def reserveUnrollMemoryForThisTask( - blockId: BlockId, - memory: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - memoryManager.synchronized { - val success = memoryManager.acquireUnrollMemory(blockId, memory, droppedBlocks) - if (success) { - val taskAttemptId = currentTaskAttemptId() - unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory - } - success - } - } - - /** - * Release memory used by this task for unrolling blocks. - * If the amount is not specified, remove the current task's allocation altogether. - */ - def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { - val taskAttemptId = currentTaskAttemptId() - memoryManager.synchronized { - if (unrollMemoryMap.contains(taskAttemptId)) { - val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) - if (memoryToRelease > 0) { - unrollMemoryMap(taskAttemptId) -= memoryToRelease - if (unrollMemoryMap(taskAttemptId) == 0) { - unrollMemoryMap.remove(taskAttemptId) - } - memoryManager.releaseUnrollMemory(memoryToRelease) - } - } - } - } - - /** - * Release pending unroll memory of current unroll successful block used by this task - */ - def releasePendingUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { - val taskAttemptId = currentTaskAttemptId() - memoryManager.synchronized { - if (pendingUnrollMemoryMap.contains(taskAttemptId)) { - val memoryToRelease = math.min(memory, pendingUnrollMemoryMap(taskAttemptId)) - if (memoryToRelease > 0) { - pendingUnrollMemoryMap(taskAttemptId) -= memoryToRelease - if (pendingUnrollMemoryMap(taskAttemptId) == 0) { - pendingUnrollMemoryMap.remove(taskAttemptId) - } - memoryManager.releaseUnrollMemory(memoryToRelease) - } - } - } - } - - /** - * Return the amount of memory currently occupied for unrolling blocks across all tasks. - */ - def currentUnrollMemory: Long = memoryManager.synchronized { - unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum - } - - /** - * Return the amount of memory currently occupied for unrolling blocks by this task. - */ - def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized { - unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) - } - - /** - * Return the number of tasks currently unrolling blocks. - */ - private def numTasksUnrolling: Int = memoryManager.synchronized { unrollMemoryMap.keys.size } - - /** - * Log information about current memory usage. - */ - private def logMemoryUsage(): Unit = { - logInfo( - s"Memory use = ${Utils.bytesToString(blocksMemoryUsed)} (blocks) + " + - s"${Utils.bytesToString(currentUnrollMemory)} (scratch space shared across " + - s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(memoryUsed)}. " + - s"Storage limit = ${Utils.bytesToString(maxMemory)}." - ) - } - - /** - * Log a warning for failing to unroll a block. - * - * @param blockId ID of the block we are trying to unroll. - * @param finalVectorSize Final size of the vector before unrolling failed. - */ - private def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { - logWarning( - s"Not enough space to cache $blockId in memory! " + - s"(computed ${Utils.bytesToString(finalVectorSize)} so far)" - ) - logMemoryUsage() - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/PutResult.scala b/core/src/main/scala/org/apache/spark/storage/PutResult.scala deleted file mode 100644 index f0eac7594ecf6..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/PutResult.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -/** - * Result of adding a block into a BlockStore. This case class contains a few things: - * (1) The estimated size of the put, - * (2) The values put if the caller asked for them to be returned (e.g. for chaining - * replication), and - * (3) A list of blocks dropped as a result of this put. This is always empty for DiskStore. - */ -private[spark] case class PutResult( - size: Long, - data: Either[Iterator[_], ByteBuffer], - droppedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 96062626b5045..083d78b59ebee 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -18,7 +18,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDDOperationScope, RDD} +import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @DeveloperApi @@ -28,6 +28,7 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], + val callSite: String = "", val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { @@ -36,15 +37,14 @@ class RDDInfo( var diskSize = 0L var externalBlockStoreSize = 0L - def isCached: Boolean = - (memSize + diskSize + externalBlockStoreSize > 0) && numCachedPartitions > 0 + def isCached: Boolean = (memSize + diskSize > 0) && numCachedPartitions > 0 override def toString: String = { import Utils.bytesToString ("RDD \"%s\" (%d) StorageLevel: %s; CachedPartitions: %d; TotalPartitions: %d; " + - "MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s").format( + "MemorySize: %s; DiskSize: %s").format( name, id, storageLevel.toString, numCachedPartitions, numPartitions, - bytesToString(memSize), bytesToString(externalBlockStoreSize), bytesToString(diskSize)) + bytesToString(memSize), bytesToString(diskSize)) } override def compare(that: RDDInfo): Int = { @@ -56,6 +56,7 @@ private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) - new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, + rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 0d0448feb5b06..4ec5b4bbb07cb 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,11 +19,13 @@ package org.apache.spark.storage import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.control.NonFatal -import org.apache.spark.{Logging, SparkException, TaskContext} +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.shuffle.FetchFailedException @@ -36,7 +38,7 @@ import org.apache.spark.util.Utils * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks * in a pipelined fashion as they are received. * - * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid * using too much memory. * * @param context [[TaskContext]], used for metrics update @@ -46,6 +48,7 @@ import org.apache.spark.util.Utils * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -53,7 +56,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - maxBytesInFlight: Long) + maxBytesInFlight: Long, + maxReqsInFlight: Int) extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -67,7 +71,7 @@ final class ShuffleBlockFetcherIterator( private[this] var numBlocksToFetch = 0 /** - * The number of blocks proccessed by the caller. The iterator is exhausted when + * The number of blocks processed by the caller. The iterator is exhausted when * [[numBlocksProcessed]] == [[numBlocksToFetch]]. */ private[this] var numBlocksProcessed = 0 @@ -101,13 +105,17 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + private[this] val shuffleMetrics = context.taskMetrics().registerTempShuffleReadMetrics() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. */ - @volatile private[this] var isZombie = false + @GuardedBy("this") + private[this] var isZombie = false initialize() @@ -116,7 +124,7 @@ final class ShuffleBlockFetcherIterator( private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { - case SuccessFetchResult(_, _, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf, _) => buf.release() case _ => } currentResult = null @@ -126,14 +134,21 @@ final class ShuffleBlockFetcherIterator( * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. */ private[this] def cleanup() { - isZombie = true + synchronized { + isZombie = true + } releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, _, buf) => buf.release() + case SuccessFetchResult(_, address, _, buf, _) => + if (address != blockManager.blockManagerId) { + shuffleMetrics.incRemoteBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(1) + } + buf.release() case _ => } } @@ -143,9 +158,11 @@ final class ShuffleBlockFetcherIterator( logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) bytesInFlight += req.size + reqsInFlight += 1 // so we can look up the size of each blockID val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap + val remainingBlocks = new HashSet[String]() ++= sizeMap.keys val blockIds = req.blocks.map(_._1.toString) val address = req.address @@ -154,13 +171,16 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { // Only add the buffer to results queue if the iterator is not zombie, // i.e. cleanup() has not been called yet. - if (!isZombie) { - // Increment the ref count because we need to pass this to a different thread. - // This needs to be released after use. - buf.retain() - results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) - shuffleMetrics.incRemoteBytesRead(buf.size) - shuffleMetrics.incRemoteBlocksFetched(1) + ShuffleBlockFetcherIterator.this.synchronized { + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + remainingBlocks -= blockId + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + } } logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } @@ -239,7 +259,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -258,6 +278,9 @@ final class ShuffleBlockFetcherIterator( val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(remoteRequests) + assert ((0 == reqsInFlight) == (0 == bytesInFlight), + "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight + + ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight) // Send out initial requests for blocks, up to our maxBytesInFlight fetchUpToMaxBytes() @@ -289,7 +312,16 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size + case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + shuffleMetrics.incRemoteBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(1) + } + bytesInFlight -= size + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } case _ => } // Send fetch requests up to maxBytesInFlight @@ -299,7 +331,7 @@ final class ShuffleBlockFetcherIterator( case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) - case SuccessFetchResult(blockId, address, _, buf) => + case SuccessFetchResult(blockId, address, _, buf, _) => try { (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) } catch { @@ -312,7 +344,9 @@ final class ShuffleBlockFetcherIterator( private def fetchUpToMaxBytes(): Unit = { // Send fetch requests up to maxBytesInFlight while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) { sendRequest(fetchRequests.dequeue()) } } @@ -329,7 +363,7 @@ final class ShuffleBlockFetcherIterator( } /** - * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() */ private class BufferReleasingInputStream( private val delegate: InputStream, @@ -390,13 +424,14 @@ object ShuffleBlockFetcherIterator { * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. * @param buf [[ManagedBuffer]] for the content. + * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ private[storage] case class SuccessFetchResult( blockId: BlockId, address: BlockManagerId, size: Long, - buf: ManagedBuffer) - extends FetchResult { + buf: ManagedBuffer, + isNetworkReqDone: Boolean) extends FetchResult { require(buf != null) require(size >= 0) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 703bce3e6b85b..216ec0793492f 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -21,6 +21,7 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.memory.MemoryMode import org.apache.spark.util.Utils /** @@ -59,10 +60,12 @@ class StorageLevel private( assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") if (useOffHeap) { - require(!useDisk, "Off-heap storage level does not support using disk") - require(!useMemory, "Off-heap storage level does not support using heap memory") require(!deserialized, "Off-heap storage level does not support deserialized storage") - require(replication == 1, "Off-heap storage level does not support multiple replication") + } + + private[spark] def memoryMode: MemoryMode = { + if (useOffHeap) MemoryMode.OFF_HEAP + else MemoryMode.ON_HEAP } override def clone(): StorageLevel = { @@ -80,7 +83,7 @@ class StorageLevel private( false } - def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0) + def isValid: Boolean = (useMemory || useDisk) && (replication > 0) def toInt: Int = { var ret = 0 @@ -117,7 +120,8 @@ class StorageLevel private( private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) override def toString: String = { - s"StorageLevel($useDisk, $useMemory, $useOffHeap, $deserialized, $replication)" + s"StorageLevel(disk=$useDisk, memory=$useMemory, offheap=$useOffHeap, " + + s"deserialized=$deserialized, replication=$replication)" } override def hashCode(): Int = toInt * 41 + replication @@ -125,8 +129,9 @@ class StorageLevel private( def description: String = { var result = "" result += (if (useDisk) "Disk " else "") - result += (if (useMemory) "Memory " else "") - result += (if (useOffHeap) "ExternalBlockStore " else "") + if (useMemory) { + result += (if (useOffHeap) "Memory (off heap) " else "Memory ") + } result += (if (deserialized) "Deserialized " else "Serialized ") result += s"${replication}x Replicated" result @@ -150,7 +155,7 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2) - val OFF_HEAP = new StorageLevel(false, false, true, false) + val OFF_HEAP = new StorageLevel(true, true, true, false, 1) /** * :: DeveloperApi :: @@ -175,7 +180,7 @@ object StorageLevel { /** * :: DeveloperApi :: - * Create a new StorageLevel object without setting useOffHeap. + * Create a new StorageLevel object. */ @DeveloperApi def apply( @@ -190,7 +195,7 @@ object StorageLevel { /** * :: DeveloperApi :: - * Create a new StorageLevel object. + * Create a new StorageLevel object without setting useOffHeap. */ @DeveloperApi def apply( diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index ec711480ebf30..3008520f61c3f 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import scala.collection.mutable +import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -29,14 +30,20 @@ import org.apache.spark.scheduler._ * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi -class StorageStatusListener extends SparkListener { +class StorageStatusListener(conf: SparkConf) extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() + private[storage] val deadExecutorStorageStatus = new mutable.ListBuffer[StorageStatus]() + private[this] val retainedDeadExecutors = conf.getInt("spark.ui.retainedDeadExecutors", 100) def storageStatusList: Seq[StorageStatus] = synchronized { executorIdToStorageStatus.values.toSeq } + def deadStorageStatusList: Seq[StorageStatus] = synchronized { + deadExecutorStorageStatus.toSeq + } + /** Update storage status list to reflect updated block statuses */ private def updateStorageStatus(execId: String, updatedBlocks: Seq[(BlockId, BlockStatus)]) { executorIdToStorageStatus.get(execId).foreach { storageStatus => @@ -59,17 +66,6 @@ class StorageStatusListener extends SparkListener { } } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - val info = taskEnd.taskInfo - val metrics = taskEnd.taskMetrics - if (info != null && metrics != null) { - val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - if (updatedBlocks.length > 0) { - updateStorageStatus(info.executorId, updatedBlocks) - } - } - } - override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { updateStorageStatus(unpersistRDD.rddId) } @@ -87,8 +83,22 @@ class StorageStatusListener extends SparkListener { override def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { synchronized { val executorId = blockManagerRemoved.blockManagerId.executorId - executorIdToStorageStatus.remove(executorId) + executorIdToStorageStatus.remove(executorId).foreach { status => + deadExecutorStorageStatus += status + } + if (deadExecutorStorageStatus.size > retainedDeadExecutors) { + deadExecutorStorageStatus.trimStart(1) + } } } + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + val executorId = blockUpdated.blockUpdatedInfo.blockManagerId.executorId + val blockId = blockUpdated.blockUpdatedInfo.blockId + val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel + val memSize = blockUpdated.blockUpdatedInfo.memSize + val diskSize = blockUpdated.blockUpdatedInfo.diskSize + val blockStatus = BlockStatus(storageLevel, memSize, diskSize) + updateStorageStatus(executorId, Seq((blockId, blockStatus))) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index c4ac30092f807..fb9941bbd9e0f 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -17,10 +17,15 @@ package org.apache.spark.storage +import java.nio.{ByteBuffer, MappedByteBuffer} + import scala.collection.Map import scala.collection.mutable +import sun.nio.ch.DirectBuffer + import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging /** * :: DeveloperApi :: @@ -48,14 +53,14 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { * non-RDD blocks for the same reason. In particular, RDD storage information is stored * in a map indexed by the RDD ID to the following 4-tuple: * - * (memory size, disk size, off-heap size, storage level) + * (memory size, disk size, storage level) * * We assume that all the blocks that belong to the same RDD have the same storage level. * This field is not relevant to non-RDD blocks, however, so the storage information for * non-RDD blocks contains only the first 3 fields (in the same order). */ - private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, Long, StorageLevel)] - private var _nonRddStorageInfo: (Long, Long, Long) = (0L, 0L, 0L) + private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, StorageLevel)] + private var _nonRddStorageInfo: (Long, Long) = (0L, 0L) /** Create a storage status with an initial set of blocks, leaving the source unmodified. */ def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) { @@ -82,9 +87,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks } /** Return the blocks that belong to the given RDD stored in this block manager. */ - def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = { - _rddBlocks.get(rddId).getOrElse(Map.empty) - } + def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = _rddBlocks.getOrElse(rddId, Map.empty) /** Add the given block to this storage status. If it already exists, overwrite it. */ private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { @@ -143,7 +146,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { def getBlock(blockId: BlockId): Option[BlockStatus] = { blockId match { case RDDBlockId(rddId, _) => - _rddBlocks.get(rddId).map(_.get(blockId)).flatten + _rddBlocks.get(rddId).flatMap(_.get(blockId)) case _ => _nonRddBlocks.get(blockId) } @@ -172,25 +175,22 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { def memRemaining: Long = maxMem - memUsed /** Return the memory used by this block manager. */ - def memUsed: Long = _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + def memUsed: Long = _nonRddStorageInfo._1 + cacheSize + + /** Return the memory used by caching RDDs */ + def cacheSize: Long = _rddBlocks.keys.toSeq.map(memUsedByRdd).sum /** Return the disk space used by this block manager. */ def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum - /** Return the off-heap space used by this block manager. */ - def offHeapUsed: Long = _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum - /** Return the memory used by the given RDD in this block manager in O(1) time. */ def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) /** Return the disk space used by the given RDD in this block manager in O(1) time. */ def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L) - /** Return the off-heap space used by the given RDD in this block manager in O(1) time. */ - def offHeapUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._3).getOrElse(0L) - /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._4) + def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._3) /** * Update the relevant storage info, taking into account any existing status for this block. @@ -199,41 +199,53 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { val oldBlockStatus = getBlock(blockId).getOrElse(BlockStatus.empty) val changeInMem = newBlockStatus.memSize - oldBlockStatus.memSize val changeInDisk = newBlockStatus.diskSize - oldBlockStatus.diskSize - val changeInExternalBlockStore = - newBlockStatus.externalBlockStoreSize - oldBlockStatus.externalBlockStoreSize val level = newBlockStatus.storageLevel // Compute new info from old info - val (oldMem, oldDisk, oldExternalBlockStore) = blockId match { + val (oldMem, oldDisk) = blockId match { case RDDBlockId(rddId, _) => _rddStorageInfo.get(rddId) - .map { case (mem, disk, externalBlockStore, _) => (mem, disk, externalBlockStore) } - .getOrElse((0L, 0L, 0L)) + .map { case (mem, disk, _) => (mem, disk) } + .getOrElse((0L, 0L)) case _ => _nonRddStorageInfo } val newMem = math.max(oldMem + changeInMem, 0L) val newDisk = math.max(oldDisk + changeInDisk, 0L) - val newExternalBlockStore = math.max(oldExternalBlockStore + changeInExternalBlockStore, 0L) // Set the correct info blockId match { case RDDBlockId(rddId, _) => // If this RDD is no longer persisted, remove it - if (newMem + newDisk + newExternalBlockStore == 0) { + if (newMem + newDisk == 0) { _rddStorageInfo.remove(rddId) } else { - _rddStorageInfo(rddId) = (newMem, newDisk, newExternalBlockStore, level) + _rddStorageInfo(rddId) = (newMem, newDisk, level) } case _ => - _nonRddStorageInfo = (newMem, newDisk, newExternalBlockStore) + _nonRddStorageInfo = (newMem, newDisk) } } } /** Helper methods for storage-related objects. */ -private[spark] object StorageUtils { +private[spark] object StorageUtils extends Logging { + + /** + * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that + * might cause errors if one attempts to read from the unmapped buffer, but it's better than + * waiting for the GC to find it because that could lead to huge numbers of open files. There's + * unfortunately no standard API to do this. + */ + def dispose(buffer: ByteBuffer): Unit = { + if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { + logTrace(s"Unmapping $buffer") + if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { + buffer.asInstanceOf[DirectBuffer].cleaner().clean() + } + } + } /** * Update the given list of RDDInfo with the given list of storage statuses. @@ -248,13 +260,11 @@ private[spark] object StorageUtils { val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum val memSize = statuses.map(_.memUsedByRdd(rddId)).sum val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum - val externalBlockStoreSize = statuses.map(_.offHeapUsedByRdd(rddId)).sum rddInfo.storageLevel = storageLevel rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize rddInfo.diskSize = diskSize - rddInfo.externalBlockStoreSize = externalBlockStoreSize } } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala deleted file mode 100644 index 22878783fca67..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ /dev/null @@ -1,253 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.io.IOException -import java.nio.ByteBuffer -import java.text.SimpleDateFormat -import java.util.{Date, Random} - -import scala.util.control.NonFatal - -import com.google.common.io.ByteStreams - -import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} -import tachyon.conf.TachyonConf -import tachyon.TachyonURI - -import org.apache.spark.Logging -import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.{ShutdownHookManager, Utils} - - -/** - * Creates and maintains the logical mapping between logical blocks and tachyon fs locations. By - * default, one block is mapped to one file with a name given by its BlockId. - * - */ -private[spark] class TachyonBlockManager() extends ExternalBlockManager with Logging { - - var rootDirs: String = _ - var master: String = _ - var client: tachyon.client.TachyonFS = _ - private var subDirsPerTachyonDir: Int = _ - - // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; - // then, inside this directory, create multiple subdirectories that we will hash files into, - // in order to avoid having really large inodes at the top level in Tachyon. - private var tachyonDirs: Array[TachyonFile] = _ - private var subDirs: Array[Array[tachyon.client.TachyonFile]] = _ - - - override def init(blockManager: BlockManager, executorId: String): Unit = { - super.init(blockManager, executorId) - val storeDir = blockManager.conf.get(ExternalBlockStore.BASE_DIR, "/tmp_spark_tachyon") - val appFolderName = blockManager.conf.get(ExternalBlockStore.FOLD_NAME) - - rootDirs = s"$storeDir/$appFolderName/$executorId" - master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998") - client = if (master != null && master != "") { - TachyonFS.get(new TachyonURI(master), new TachyonConf()) - } else { - null - } - // original implementation call System.exit, we change it to run without extblkstore support - if (client == null) { - logError("Failed to connect to the Tachyon as the master address is not configured") - throw new IOException("Failed to connect to the Tachyon as the master " + - "address is not configured") - } - subDirsPerTachyonDir = blockManager.conf.get("spark.externalBlockStore.subDirectories", - ExternalBlockStore.SUB_DIRS_PER_DIR).toInt - - // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; - // then, inside this directory, create multiple subdirectories that we will hash files into, - // in order to avoid having really large inodes at the top level in Tachyon. - tachyonDirs = createTachyonDirs() - subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) - } - - override def toString: String = {"ExternalBlockStore-Tachyon"} - - override def removeBlock(blockId: BlockId): Boolean = { - val file = getFile(blockId) - if (fileExists(file)) { - removeFile(file) - } else { - false - } - } - - override def blockExists(blockId: BlockId): Boolean = { - val file = getFile(blockId) - fileExists(file) - } - - override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { - val file = getFile(blockId) - val os = file.getOutStream(WriteType.TRY_CACHE) - try { - os.write(bytes.array()) - } catch { - case NonFatal(e) => - logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) - os.cancel() - } finally { - os.close() - } - } - - override def putValues(blockId: BlockId, values: Iterator[_]): Unit = { - val file = getFile(blockId) - val os = file.getOutStream(WriteType.TRY_CACHE) - try { - blockManager.dataSerializeStream(blockId, os, values) - } catch { - case NonFatal(e) => - logWarning(s"Failed to put values of block $blockId into Tachyon", e) - os.cancel() - } finally { - os.close() - } - } - - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val file = getFile(blockId) - if (file == null || file.getLocationHosts.size == 0) { - return None - } - val is = file.getInStream(ReadType.CACHE) - try { - val size = file.length - val bs = new Array[Byte](size.asInstanceOf[Int]) - ByteStreams.readFully(is, bs) - Some(ByteBuffer.wrap(bs)) - } catch { - case NonFatal(e) => - logWarning(s"Failed to get bytes of block $blockId from Tachyon", e) - None - } finally { - is.close() - } - } - - override def getValues(blockId: BlockId): Option[Iterator[_]] = { - val file = getFile(blockId) - if (file == null || file.getLocationHosts().size() == 0) { - return None - } - val is = file.getInStream(ReadType.CACHE) - Option(is).map { is => - blockManager.dataDeserializeStream(blockId, is) - } - } - - override def getSize(blockId: BlockId): Long = { - getFile(blockId.name).length - } - - def removeFile(file: TachyonFile): Boolean = { - client.delete(new TachyonURI(file.getPath()), false) - } - - def fileExists(file: TachyonFile): Boolean = { - client.exist(new TachyonURI(file.getPath())) - } - - def getFile(filename: String): TachyonFile = { - // Figure out which tachyon directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(filename) - val dirId = hash % tachyonDirs.length - val subDirId = (hash / tachyonDirs.length) % subDirsPerTachyonDir - - // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") - client.mkdir(path) - val newDir = client.getFile(path) - subDirs(dirId)(subDirId) = newDir - newDir - } - } - } - val filePath = new TachyonURI(s"$subDir/$filename") - if(!client.exist(filePath)) { - client.createFile(filePath) - } - val file = client.getFile(filePath) - file - } - - def getFile(blockId: BlockId): TachyonFile = getFile(blockId.name) - - // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore. - private def createTachyonDirs(): Array[TachyonFile] = { - logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'") - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map { rootDir => - var foundLocalDir = false - var tachyonDir: TachyonFile = null - var tachyonDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < ExternalBlockStore.MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") - if (!client.exist(path)) { - foundLocalDir = client.mkdir(path) - tachyonDir = client.getFile(path) - } - } catch { - case NonFatal(e) => - logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) - } - } - if (!foundLocalDir) { - logError("Failed " + ExternalBlockStore.MAX_DIR_CREATION_ATTEMPTS - + " attempts to create tachyon dir in " + rootDir) - System.exit(ExecutorExitCode.EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR) - } - logInfo("Created tachyon directory at " + tachyonDir) - tachyonDir - } - } - - override def shutdown() { - logDebug("Shutdown hook called") - tachyonDirs.foreach { tachyonDir => - try { - if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { - Utils.deleteRecursively(tachyonDir, client) - } - } catch { - case NonFatal(e) => - logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) - } - } - client.close() - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala new file mode 100644 index 0000000000000..99be4de0658cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -0,0 +1,791 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage.memory + +import java.io.OutputStream +import java.nio.ByteBuffer +import java.util.LinkedHashMap + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import com.google.common.io.ByteStreams + +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.memory.{MemoryManager, MemoryMode} +import org.apache.spark.serializer.{SerializationStream, SerializerManager} +import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} +import org.apache.spark.unsafe.Platform +import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} +import org.apache.spark.util.collection.SizeTrackingVector +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} + +private sealed trait MemoryEntry[T] { + def size: Long + def memoryMode: MemoryMode + def classTag: ClassTag[T] +} +private case class DeserializedMemoryEntry[T]( + value: Array[T], + size: Long, + classTag: ClassTag[T]) extends MemoryEntry[T] { + val memoryMode: MemoryMode = MemoryMode.ON_HEAP +} +private case class SerializedMemoryEntry[T]( + buffer: ChunkedByteBuffer, + memoryMode: MemoryMode, + classTag: ClassTag[T]) extends MemoryEntry[T] { + def size: Long = buffer.size +} + +private[storage] trait BlockEvictionHandler { + /** + * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory + * store reaches its limit and needs to free up space. + * + * If `data` is not put on disk, it won't be created. + * + * The caller of this method must hold a write lock on the block before calling this method. + * This method does not release the write lock. + * + * @return the block's new effective StorageLevel. + */ + private[storage] def dropFromMemory[T: ClassTag]( + blockId: BlockId, + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel +} + +/** + * Stores blocks in memory, either as Arrays of deserialized Java objects or as + * serialized ByteBuffers. + */ +private[spark] class MemoryStore( + conf: SparkConf, + blockInfoManager: BlockInfoManager, + serializerManager: SerializerManager, + memoryManager: MemoryManager, + blockEvictionHandler: BlockEvictionHandler) + extends Logging { + + // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and + // acquiring or releasing unroll memory, must be synchronized on `memoryManager`! + + private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true) + + // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) + // All accesses of this map are assumed to have manually synchronized on `memoryManager` + private val onHeapUnrollMemoryMap = mutable.HashMap[Long, Long]() + // Note: off-heap unroll memory is only used in putIteratorAsBytes() because off-heap caching + // always stores serialized values. + private val offHeapUnrollMemoryMap = mutable.HashMap[Long, Long]() + + // Initial memory to request before unrolling any block + private val unrollMemoryThreshold: Long = + conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + + /** Total amount of memory available for storage, in bytes. */ + private def maxMemory: Long = memoryManager.maxOnHeapStorageMemory + + if (maxMemory < unrollMemoryThreshold) { + logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " + + s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " + + s"memory. Please configure Spark with more memory.") + } + + logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) + + /** Total storage memory used including unroll memory, in bytes. */ + private def memoryUsed: Long = memoryManager.storageMemoryUsed + + /** + * Amount of storage memory, in bytes, used for caching blocks. + * This does not include memory used for unrolling. + */ + private def blocksMemoryUsed: Long = memoryManager.synchronized { + memoryUsed - currentUnrollMemory + } + + def getSize(blockId: BlockId): Long = { + entries.synchronized { + entries.get(blockId).size + } + } + + /** + * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and + * put it into MemoryStore. Otherwise, the ByteBuffer won't be created. + * + * The caller should guarantee that `size` is correct. + * + * @return true if the put() succeeded, false otherwise. + */ + def putBytes[T: ClassTag]( + blockId: BlockId, + size: Long, + memoryMode: MemoryMode, + _bytes: () => ChunkedByteBuffer): Boolean = { + require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") + if (memoryManager.acquireStorageMemory(blockId, size, memoryMode)) { + // We acquired enough memory for the block, so go ahead and put it + val bytes = _bytes() + assert(bytes.size == size) + val entry = new SerializedMemoryEntry[T](bytes, memoryMode, implicitly[ClassTag[T]]) + entries.synchronized { + entries.put(blockId, entry) + } + logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( + blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) + true + } else { + false + } + } + + /** + * Attempt to put the given block in memory store as values. + * + * It's possible that the iterator is too large to materialize and store in memory. To avoid + * OOM exceptions, this method will gradually unroll the iterator while periodically checking + * whether there is enough free memory. If the block is successfully materialized, then the + * temporary unroll memory used during the materialization is "transferred" to storage memory, + * so we won't acquire more memory than is actually needed to store the block. + * + * @return in case of success, the estimated the estimated size of the stored data. In case of + * failure, return an iterator containing the values of the block. The returned iterator + * will be backed by the combination of the partially-unrolled block and the remaining + * elements of the original input iterator. The caller must either fully consume this + * iterator or call `close()` on it in order to free the storage memory consumed by the + * partially-unrolled block. + */ + private[storage] def putIteratorAsValues[T]( + blockId: BlockId, + values: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + + require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") + + // Number of elements unrolled so far + var elementsUnrolled = 0 + // Whether there is still enough memory for us to continue unrolling this block + var keepUnrolling = true + // Initial per-task memory to request for unrolling blocks (bytes). + val initialMemoryThreshold = unrollMemoryThreshold + // How often to check whether we need to request more memory + val memoryCheckPeriod = 16 + // Memory currently reserved by this task for this particular unrolling operation + var memoryThreshold = initialMemoryThreshold + // Memory to request as a multiple of current vector size + val memoryGrowthFactor = 1.5 + // Keep track of unroll memory used by this particular block / putIterator() operation + var unrollMemoryUsedByThisBlock = 0L + // Underlying vector for unrolling the block + var vector = new SizeTrackingVector[T]()(classTag) + + // Request enough memory to begin unrolling + keepUnrolling = + reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP) + + if (!keepUnrolling) { + logWarning(s"Failed to reserve initial memory threshold of " + + s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") + } else { + unrollMemoryUsedByThisBlock += initialMemoryThreshold + } + + // Unroll this block safely, checking whether we have exceeded our threshold periodically + while (values.hasNext && keepUnrolling) { + vector += values.next() + if (elementsUnrolled % memoryCheckPeriod == 0) { + // If our vector's size has exceeded the threshold, request more memory + val currentSize = vector.estimateSize() + if (currentSize >= memoryThreshold) { + val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong + keepUnrolling = + reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + // New threshold is currentSize * memoryGrowthFactor + memoryThreshold += amountToRequest + } + } + elementsUnrolled += 1 + } + + if (keepUnrolling) { + // We successfully unrolled the entirety of this block + val arrayValues = vector.toArray + vector = null + val entry = + new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag) + val size = entry.size + def transferUnrollToStorage(amount: Long): Unit = { + // Synchronize so that transfer is atomic + memoryManager.synchronized { + releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount) + val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP) + assert(success, "transferring unroll memory to storage memory failed") + } + } + // Acquire storage memory if necessary to store this block in memory. + val enoughStorageMemory = { + if (unrollMemoryUsedByThisBlock <= size) { + val acquiredExtra = + memoryManager.acquireStorageMemory( + blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP) + if (acquiredExtra) { + transferUnrollToStorage(unrollMemoryUsedByThisBlock) + } + acquiredExtra + } else { // unrollMemoryUsedByThisBlock > size + // If this task attempt already owns more unroll memory than is necessary to store the + // block, then release the extra memory that will not be used. + val excessUnrollMemory = unrollMemoryUsedByThisBlock - size + releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory) + transferUnrollToStorage(size) + true + } + } + if (enoughStorageMemory) { + entries.synchronized { + entries.put(blockId, entry) + } + logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format( + blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) + Right(size) + } else { + assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask, + "released too much unroll memory") + Left(new PartiallyUnrolledIterator( + this, + unrollMemoryUsedByThisBlock, + unrolled = arrayValues.toIterator, + rest = Iterator.empty)) + } + } else { + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, vector.estimateSize()) + Left(new PartiallyUnrolledIterator( + this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values)) + } + } + + /** + * Attempt to put the given block in memory store as bytes. + * + * It's possible that the iterator is too large to materialize and store in memory. To avoid + * OOM exceptions, this method will gradually unroll the iterator while periodically checking + * whether there is enough free memory. If the block is successfully materialized, then the + * temporary unroll memory used during the materialization is "transferred" to storage memory, + * so we won't acquire more memory than is actually needed to store the block. + * + * @return in case of success, the estimated the estimated size of the stored data. In case of + * failure, return a handle which allows the caller to either finish the serialization + * by spilling to disk or to deserialize the partially-serialized block and reconstruct + * the original input iterator. The caller must either fully consume this result + * iterator or call `discard()` on it in order to free the storage memory consumed by the + * partially-unrolled block. + */ + private[storage] def putIteratorAsBytes[T]( + blockId: BlockId, + values: Iterator[T], + classTag: ClassTag[T], + memoryMode: MemoryMode): Either[PartiallySerializedBlock[T], Long] = { + + require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") + + val allocator = memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + + // Whether there is still enough memory for us to continue unrolling this block + var keepUnrolling = true + // Initial per-task memory to request for unrolling blocks (bytes). + val initialMemoryThreshold = unrollMemoryThreshold + // Keep track of unroll memory used by this particular block / putIterator() operation + var unrollMemoryUsedByThisBlock = 0L + // Underlying buffer for unrolling the block + val redirectableStream = new RedirectableOutputStream + val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator) + redirectableStream.setOutputStream(bbos) + val serializationStream: SerializationStream = { + val ser = serializerManager.getSerializer(classTag).newInstance() + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) + } + + // Request enough memory to begin unrolling + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) + + if (!keepUnrolling) { + logWarning(s"Failed to reserve initial memory threshold of " + + s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") + } else { + unrollMemoryUsedByThisBlock += initialMemoryThreshold + } + + def reserveAdditionalMemoryIfNecessary(): Unit = { + if (bbos.size > unrollMemoryUsedByThisBlock) { + val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } + } + + // Unroll this block safely, checking whether we have exceeded our threshold + while (values.hasNext && keepUnrolling) { + serializationStream.writeObject(values.next())(classTag) + reserveAdditionalMemoryIfNecessary() + } + + // Make sure that we have enough memory to store the block. By this point, it is possible that + // the block's actual memory usage has exceeded the unroll memory by a small amount, so we + // perform one final call to attempt to allocate additional memory if necessary. + if (keepUnrolling) { + serializationStream.close() + reserveAdditionalMemoryIfNecessary() + } + + if (keepUnrolling) { + val entry = SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) + // Synchronize so that transfer is atomic + memoryManager.synchronized { + releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) + val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) + assert(success, "transferring unroll memory to storage memory failed") + } + entries.synchronized { + entries.put(blockId, entry) + } + logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( + blockId, Utils.bytesToString(entry.size), Utils.bytesToString(blocksMemoryUsed))) + Right(entry.size) + } else { + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, bbos.size) + Left( + new PartiallySerializedBlock( + this, + serializerManager, + blockId, + serializationStream, + redirectableStream, + unrollMemoryUsedByThisBlock, + memoryMode, + bbos.toChunkedByteBuffer, + values, + classTag)) + } + } + + def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + val entry = entries.synchronized { entries.get(blockId) } + entry match { + case null => None + case e: DeserializedMemoryEntry[_] => + throw new IllegalArgumentException("should only call getBytes on serialized blocks") + case SerializedMemoryEntry(bytes, _, _) => Some(bytes) + } + } + + def getValues(blockId: BlockId): Option[Iterator[_]] = { + val entry = entries.synchronized { entries.get(blockId) } + entry match { + case null => None + case e: SerializedMemoryEntry[_] => + throw new IllegalArgumentException("should only call getValues on deserialized blocks") + case DeserializedMemoryEntry(values, _, _) => + val x = Some(values) + x.map(_.iterator) + } + } + + def remove(blockId: BlockId): Boolean = memoryManager.synchronized { + val entry = entries.synchronized { + entries.remove(blockId) + } + if (entry != null) { + entry match { + case SerializedMemoryEntry(buffer, _, _) => buffer.dispose() + case _ => + } + memoryManager.releaseStorageMemory(entry.size, entry.memoryMode) + logDebug(s"Block $blockId of size ${entry.size} dropped " + + s"from memory (free ${maxMemory - blocksMemoryUsed})") + true + } else { + false + } + } + + def clear(): Unit = memoryManager.synchronized { + entries.synchronized { + entries.clear() + } + onHeapUnrollMemoryMap.clear() + offHeapUnrollMemoryMap.clear() + memoryManager.releaseAllStorageMemory() + logInfo("MemoryStore cleared") + } + + /** + * Return the RDD ID that a given block ID is from, or None if it is not an RDD block. + */ + private def getRddId(blockId: BlockId): Option[Int] = { + blockId.asRDDId.map(_.rddId) + } + + /** + * Try to evict blocks to free up a given amount of space to store a particular block. + * Can fail if either the block is bigger than our memory or it would require replacing + * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for + * RDDs that don't fit into memory that we want to avoid). + * + * @param blockId the ID of the block we are freeing space for, if any + * @param space the size of this block + * @param memoryMode the type of memory to free (on- or off-heap) + * @return the amount of memory (in bytes) freed by eviction + */ + private[spark] def evictBlocksToFreeSpace( + blockId: Option[BlockId], + space: Long, + memoryMode: MemoryMode): Long = { + assert(space > 0) + memoryManager.synchronized { + var freedMemory = 0L + val rddToAdd = blockId.flatMap(getRddId) + val selectedBlocks = new ArrayBuffer[BlockId] + def blockIsEvictable(blockId: BlockId, entry: MemoryEntry[_]): Boolean = { + entry.memoryMode == memoryMode && (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) + } + // This is synchronized to ensure that the set of entries is not changed + // (because of getValue or getBytes) while traversing the iterator, as that + // can lead to exceptions. + entries.synchronized { + val iterator = entries.entrySet().iterator() + while (freedMemory < space && iterator.hasNext) { + val pair = iterator.next() + val blockId = pair.getKey + val entry = pair.getValue + if (blockIsEvictable(blockId, entry)) { + // We don't want to evict blocks which are currently being read, so we need to obtain + // an exclusive write lock on blocks which are candidates for eviction. We perform a + // non-blocking "tryLock" here in order to ignore blocks which are locked for reading: + if (blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) { + selectedBlocks += blockId + freedMemory += pair.getValue.size + } + } + } + } + + def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = { + val data = entry match { + case DeserializedMemoryEntry(values, _, _) => Left(values) + case SerializedMemoryEntry(buffer, _, _) => Right(buffer) + } + val newEffectiveStorageLevel = + blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag) + if (newEffectiveStorageLevel.isValid) { + // The block is still present in at least one store, so release the lock + // but don't delete the block info + blockInfoManager.unlock(blockId) + } else { + // The block isn't present in any store, so delete the block info so that the + // block can be stored again + blockInfoManager.removeBlock(blockId) + } + } + + if (freedMemory >= space) { + logInfo(s"${selectedBlocks.size} blocks selected for dropping " + + s"(${Utils.bytesToString(freedMemory)} bytes)") + for (blockId <- selectedBlocks) { + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one task should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + dropBlock(blockId, entry) + } + } + logInfo(s"After dropping ${selectedBlocks.size} blocks, " + + s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}") + freedMemory + } else { + blockId.foreach { id => + logInfo(s"Will not store $id") + } + selectedBlocks.foreach { id => + blockInfoManager.unlock(id) + } + 0L + } + } + } + + def contains(blockId: BlockId): Boolean = { + entries.synchronized { entries.containsKey(blockId) } + } + + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + + /** + * Reserve memory for unrolling the given block for this task. + * + * @return whether the request is granted. + */ + def reserveUnrollMemoryForThisTask( + blockId: BlockId, + memory: Long, + memoryMode: MemoryMode): Boolean = { + memoryManager.synchronized { + val success = memoryManager.acquireUnrollMemory(blockId, memory, memoryMode) + if (success) { + val taskAttemptId = currentTaskAttemptId() + val unrollMemoryMap = memoryMode match { + case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap + case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap + } + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory + } + success + } + } + + /** + * Release memory used by this task for unrolling blocks. + * If the amount is not specified, remove the current task's allocation altogether. + */ + def releaseUnrollMemoryForThisTask(memoryMode: MemoryMode, memory: Long = Long.MaxValue): Unit = { + val taskAttemptId = currentTaskAttemptId() + memoryManager.synchronized { + val unrollMemoryMap = memoryMode match { + case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap + case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap + } + if (unrollMemoryMap.contains(taskAttemptId)) { + val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) + if (memoryToRelease > 0) { + unrollMemoryMap(taskAttemptId) -= memoryToRelease + if (unrollMemoryMap(taskAttemptId) == 0) { + unrollMemoryMap.remove(taskAttemptId) + } + memoryManager.releaseUnrollMemory(memoryToRelease, memoryMode) + } + } + } + } + + /** + * Return the amount of memory currently occupied for unrolling blocks across all tasks. + */ + def currentUnrollMemory: Long = memoryManager.synchronized { + onHeapUnrollMemoryMap.values.sum + offHeapUnrollMemoryMap.values.sum + } + + /** + * Return the amount of memory currently occupied for unrolling blocks by this task. + */ + def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized { + onHeapUnrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) + + offHeapUnrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) + } + + /** + * Return the number of tasks currently unrolling blocks. + */ + private def numTasksUnrolling: Int = memoryManager.synchronized { + (onHeapUnrollMemoryMap.keys ++ offHeapUnrollMemoryMap.keys).toSet.size + } + + /** + * Log information about current memory usage. + */ + private def logMemoryUsage(): Unit = { + logInfo( + s"Memory use = ${Utils.bytesToString(blocksMemoryUsed)} (blocks) + " + + s"${Utils.bytesToString(currentUnrollMemory)} (scratch space shared across " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(memoryUsed)}. " + + s"Storage limit = ${Utils.bytesToString(maxMemory)}." + ) + } + + /** + * Log a warning for failing to unroll a block. + * + * @param blockId ID of the block we are trying to unroll. + * @param finalVectorSize Final size of the vector before unrolling failed. + */ + private def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { + logWarning( + s"Not enough space to cache $blockId in memory! " + + s"(computed ${Utils.bytesToString(finalVectorSize)} so far)" + ) + logMemoryUsage() + } +} + +/** + * The result of a failed [[MemoryStore.putIteratorAsValues()]] call. + * + * @param memoryStore the memoryStore, used for freeing memory. + * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. + * @param unrolled an iterator for the partially-unrolled values. + * @param rest the rest of the original iterator passed to + * [[MemoryStore.putIteratorAsValues()]]. + */ +private[storage] class PartiallyUnrolledIterator[T]( + memoryStore: MemoryStore, + unrollMemory: Long, + unrolled: Iterator[T], + rest: Iterator[T]) + extends Iterator[T] { + + private[this] var unrolledIteratorIsConsumed: Boolean = false + private[this] var iter: Iterator[T] = { + val completionIterator = CompletionIterator[T, Iterator[T]](unrolled, { + unrolledIteratorIsConsumed = true + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) + }) + completionIterator ++ rest + } + + override def hasNext: Boolean = iter.hasNext + override def next(): T = iter.next() + + /** + * Called to dispose of this iterator and free its memory. + */ + def close(): Unit = { + if (!unrolledIteratorIsConsumed) { + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) + unrolledIteratorIsConsumed = true + } + iter = null + } +} + +/** + * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink. + */ +private class RedirectableOutputStream extends OutputStream { + private[this] var os: OutputStream = _ + def setOutputStream(s: OutputStream): Unit = { os = s } + override def write(b: Int): Unit = os.write(b) + override def write(b: Array[Byte]): Unit = os.write(b) + override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len) + override def flush(): Unit = os.flush() + override def close(): Unit = os.close() +} + +/** + * The result of a failed [[MemoryStore.putIteratorAsBytes()]] call. + * + * @param memoryStore the MemoryStore, used for freeing memory. + * @param serializerManager the SerializerManager, used for deserializing values. + * @param blockId the block id. + * @param serializationStream a serialization stream which writes to [[redirectableOutputStream]]. + * @param redirectableOutputStream an OutputStream which can be redirected to a different sink. + * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. + * @param memoryMode whether the unroll memory is on- or off-heap + * @param unrolled a byte buffer containing the partially-serialized values. + * @param rest the rest of the original iterator passed to + * [[MemoryStore.putIteratorAsValues()]]. + * @param classTag the [[ClassTag]] for the block. + */ +private[storage] class PartiallySerializedBlock[T]( + memoryStore: MemoryStore, + serializerManager: SerializerManager, + blockId: BlockId, + serializationStream: SerializationStream, + redirectableOutputStream: RedirectableOutputStream, + unrollMemory: Long, + memoryMode: MemoryMode, + unrolled: ChunkedByteBuffer, + rest: Iterator[T], + classTag: ClassTag[T]) { + + // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of + // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task + // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. + // The dispose() method is idempotent, so it's safe to call it unconditionally. + Option(TaskContext.get()).foreach { taskContext => + taskContext.addTaskCompletionListener { _ => + // When a task completes, its unroll memory will automatically be freed. Thus we do not call + // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. + unrolled.dispose() + } + } + + /** + * Called to dispose of this block and free its memory. + */ + def discard(): Unit = { + try { + // We want to close the output stream in order to free any resources associated with the + // serializer itself (such as Kryo's internal buffers). close() might cause data to be + // written, so redirect the output stream to discard that data. + redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream()) + serializationStream.close() + } finally { + unrolled.dispose() + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + } + } + + /** + * Finish writing this block to the given output stream by first writing the serialized values + * and then serializing the values from the original input iterator. + */ + def finishWritingToStream(os: OutputStream): Unit = { + // `unrolled`'s underlying buffers will be freed once this input stream is fully read: + ByteStreams.copy(unrolled.toInputStream(dispose = true), os) + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + redirectableOutputStream.setOutputStream(os) + while (rest.hasNext) { + serializationStream.writeObject(rest.next())(classTag) + } + serializationStream.close() + } + + /** + * Returns an iterator over the values in this block by first deserializing the serialized + * values and then consuming the rest of the original input iterator. + * + * If the caller does not plan to fully consume the resulting iterator then they must call + * `close()` on it to free its resources. + */ + def valuesIterator: PartiallyUnrolledIterator[T] = { + // `unrolled`'s underlying buffers will be freed once this input stream is fully read: + val unrolledIter = serializerManager.dataDeserializeStream( + blockId, unrolled.toInputStream(dispose = true))(classTag) + new PartiallyUnrolledIterator( + memoryStore, + unrollMemory, + unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()), + rest = rest) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 77c0bc8b5360a..2719e1ee98ba4 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui import java.util.{Timer, TimerTask} import org.apache.spark._ +import org.apache.spark.internal.Logging /** * ConsoleProgressBar shows the progress of stages in the next line of the console. It poll the @@ -28,7 +29,7 @@ import org.apache.spark._ * of them will be combined together, showed in one line. */ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { - // Carrige return + // Carriage return val CR = '\r' // Update period of progress bar, in milliseconds val UPDATE_PERIOD = 200L @@ -63,7 +64,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { return } val stageIds = sc.statusTracker.getActiveStageIds() - val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) + val stages = stageIds.flatMap(sc.statusTracker.getStageInfo).filter(_.numTasks() > 1) .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) if (stages.length > 0) { show(now, stages.take(3)) // display at most 3 stages in same time diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index b796a44fe01ac..119165f724f59 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -17,21 +17,26 @@ package org.apache.spark.ui -import java.net.{InetSocketAddress, URL} +import java.net.{URI, URL} import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} +import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.xml.Node -import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.{Connector, Request, Server} import org.eclipse.jetty.server.handler._ +import org.eclipse.jetty.server.nio.SelectChannelConnector +import org.eclipse.jetty.server.ssl.SslSelectChannelConnector import org.eclipse.jetty.servlet._ +import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.QueuedThreadPool import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SSLOptions} +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** @@ -224,24 +229,67 @@ private[spark] object JettyUtils extends Logging { def startJettyServer( hostName: String, port: Int, + sslOptions: SSLOptions, handlers: Seq[ServletContextHandler], conf: SparkConf, serverName: String = ""): ServerInfo = { + val collection = new ContextHandlerCollection addFilters(handlers, conf) - val collection = new ContextHandlerCollection val gzipHandlers = handlers.map { h => val gzipHandler = new GzipHandler gzipHandler.setHandler(h) gzipHandler } - collection.setHandlers(gzipHandlers.toArray) // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(hostName, currentPort)) + val server = new Server + val connectors = new ArrayBuffer[Connector] + // Create a connector on port currentPort to listen for HTTP requests + val httpConnector = new SelectChannelConnector() + httpConnector.setPort(currentPort) + connectors += httpConnector + + sslOptions.createJettySslContextFactory().foreach { factory => + // If the new port wraps around, do not try a privileged port. + val securePort = + if (currentPort != 0) { + (currentPort + 400 - 1024) % (65536 - 1024) + 1024 + } else { + 0 + } + val scheme = "https" + // Create a connector on port securePort to listen for HTTPS requests + val connector = new SslSelectChannelConnector(factory) + connector.setPort(securePort) + connectors += connector + + // redirect the HTTP requests to HTTPS port + collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) + } + + gzipHandlers.foreach(collection.addHandler) + connectors.foreach(_.setHost(hostName)) + // As each acceptor and each selector will use one thread, the number of threads should at + // least be the number of acceptors and selectors plus 1. (See SPARK-13776) + var minThreads = 1 + connectors.foreach { c => + // Currently we only use "SelectChannelConnector" + val connector = c.asInstanceOf[SelectChannelConnector] + // Limit the max acceptor number to 8 so that we don't waste a lot of threads + connector.setAcceptors(math.min(connector.getAcceptors, 8)) + // The number of selectors always equals to the number of acceptors + minThreads += connector.getAcceptors * 2 + } + server.setConnectors(connectors.toArray) + val pool = new QueuedThreadPool + if (serverName.nonEmpty) { + pool.setName(serverName) + } + pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) pool.setDaemon(true) server.setThreadPool(pool) val errorHandler = new ErrorHandler() @@ -262,9 +310,56 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } + + private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { + val redirectHandler: ContextHandler = new ContextHandler + redirectHandler.setContextPath("/") + redirectHandler.setHandler(new AbstractHandler { + override def handle( + target: String, + baseRequest: Request, + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + if (baseRequest.isSecure) { + return + } + val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort, + baseRequest.getRequestURI, baseRequest.getQueryString) + response.setContentLength(0) + response.encodeRedirectURL(httpsURI) + response.sendRedirect(httpsURI) + baseRequest.setHandled(true) + } + }) + redirectHandler + } + + // Create a new URI from the arguments, handling IPv6 host encoding and default ports. + private def createRedirectURI( + scheme: String, server: String, port: Int, path: String, query: String) = { + val redirectServer = if (server.contains(":") && !server.startsWith("[")) { + s"[${server}]" + } else { + server + } + val authority = s"$redirectServer:$port" + new URI(scheme, authority, path, query, null).toString + } + } private[spark] case class ServerInfo( server: Server, boundPort: Int, - rootHandler: ContextHandlerCollection) + rootHandler: ContextHandlerCollection) { + + def stop(): Unit = { + server.stop() + // Stop the ThreadPool if it supports stop() method (through LifeCycle). + // It is needed because stopping the Server won't stop the ThreadPool it uses. + val threadPool = server.getThreadPool + if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) { + threadPool.asInstanceOf[LifeCycle].stop + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 6e2375477a688..9b6ed8cbbef10 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -17,8 +17,15 @@ package org.apache.spark.ui +import java.net.URLDecoder + +import scala.collection.JavaConverters._ import scala.xml.{Node, Unparsed} +import com.google.common.base.Splitter + +import org.apache.spark.util.Utils + /** * A data source that provides data for a page. * @@ -71,6 +78,12 @@ private[ui] trait PagedTable[T] { def tableCssClass: String + def pageSizeFormField: String + + def prevPageSizeFormField: String + + def pageNumberFormField: String + def dataSource: PagedDataSource[T] def headers: Seq[Node] @@ -95,7 +108,12 @@ private[ui] trait PagedTable[T] { val PageData(totalPages, _) = _dataSource.pageData(1)
{pageNavigation(1, _dataSource.pageSize, totalPages)} -
{e.getMessage}
+
+

Error while rendering table:

+
+              {Utils.exceptionString(e)}
+            
+
} } @@ -151,36 +169,56 @@ private[ui] trait PagedTable[T] { // The current page should be disabled so that it cannot be clicked.
  • {p}
  • } else { -
  • {p}
  • +
  • {p}
  • + } + } + + val hiddenFormFields = { + if (goButtonFormPath.contains('?')) { + val querystring = goButtonFormPath.split("\\?", 2)(1) + Splitter + .on('&') + .trimResults() + .withKeyValueSeparator("=") + .split(querystring) + .asScala + .filterKeys(_ != pageSizeFormField) + .filterKeys(_ != prevPageSizeFormField) + .filterKeys(_ != pageNumberFormField) + .mapValues(URLDecoder.decode(_, "UTF-8")) + .map { case (k, v) => + + } + } else { + Seq.empty } } - val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction - // When clicking the "Go" button, it will call this javascript method and then call - // "goButtonJsFuncName" - val formJs = - s"""$$(function(){ - | $$( "#form-$tableId-page" ).submit(function(event) { - | var page = $$("#form-$tableId-page-no").val() - | var pageSize = $$("#form-$tableId-page-size").val() - | pageSize = pageSize ? pageSize: 100; - | if (page != "") { - | ${goButtonJsFuncName}(page, pageSize); - | } - | event.preventDefault(); - | }); - |}); - """.stripMargin
    + method="get" + action={Unparsed(goButtonFormPath)} + class="form-inline pull-right" + style="margin-bottom: 0px;"> + + {hiddenFormFields} - + + + id={s"form-$tableId-page-size"} + name={pageSizeFormField} + value={pageSize.toString} + class="span1" /> +
    @@ -189,7 +227,7 @@ private[ui] trait PagedTable[T] {
    - } } @@ -239,10 +272,7 @@ private[ui] trait PagedTable[T] { def pageLink(page: Int): String /** - * Only the implementation knows how to create the url with a page number and the page size, so we - * leave this one to the implementation. The implementation should create a JavaScript method that - * accepts a page number along with the page size and jumps to the page. The return value is this - * method name and its JavaScript codes. + * Returns the submission path for the "go to page #" form. */ - def goButtonJavascriptFunction: (String, String) + def goButtonFormPath: String } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 99085ada9f0af..39155ff2649ec 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,19 +17,23 @@ package org.apache.spark.ui -import java.util.Date +import java.util.{Date, ServiceLoader} +import scala.collection.JavaConverters._ + +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} -import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} -import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} -import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener +import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.util.Utils /** * Top level user interface for a Spark application. @@ -47,7 +51,8 @@ private[spark] class SparkUI private ( var appName: String, val basePath: String, val startTime: Long) - extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") + extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf), + conf, basePath, "SparkUI") with Logging with UIRoot { @@ -75,6 +80,10 @@ private[spark] class SparkUI private ( } initialize() + def getSparkUser: String = { + environmentListener.systemProperties.toMap.get("user.name").getOrElse("") + } + def getAppName: String = appName def setAppId(id: String): Unit = { @@ -102,10 +111,16 @@ private[spark] class SparkUI private ( Iterator(new ApplicationInfo( id = appId, name = appName, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(startTime), endTime = new Date(-1), + duration = 0, + lastUpdated = new Date(startTime), sparkUser = "", completed = false )) @@ -150,7 +165,16 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create( + None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, sparkUI) + listeners.foreach(listenerBus.addListener) + } + sparkUI } /** @@ -177,8 +201,8 @@ private[spark] object SparkUI { } val environmentListener = new EnvironmentListener - val storageStatusListener = new StorageStatusListener - val executorsListener = new ExecutorsListener(storageStatusListener) + val storageStatusListener = new StorageStatusListener(conf) + val executorsListener = new ExecutorsListener(storageStatusListener, conf) val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index cb122eaed83d1..2d2d80be4aabe 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -87,4 +87,7 @@ private[spark] object ToolTips { multiple operations (e.g. two map() functions) if they can be pipelined. Some operations also create multiple RDDs internally. Cached RDDs are shown in green. """ + + val TASK_TIME = + "Shaded red when garbage collection (GC) time is over 10% of task time" } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 25dcb604d9e5f..28d277df4ae12 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -24,7 +25,7 @@ import scala.util.control.NonFatal import scala.xml._ import scala.xml.transform.{RewriteRule, RuleTransformer} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.ui.scope.RDDOperationGraph /** Utility functions for generating XML pages with spark content. */ @@ -177,6 +178,20 @@ private[spark] object UIUtils extends Logging { } + def dataTablesHeaderNodes: Seq[Node] = { + + + + + + + + + + } + /** Returns a spark page with correctly formatted headers */ def headerSparkPage( title: String, @@ -210,10 +225,10 @@ private[spark] object UIUtils extends Logging { {org.apache.spark.SPARK_VERSION} - +
    @@ -232,10 +247,14 @@ private[spark] object UIUtils extends Logging { } /** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */ - def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = { + def basicSparkPage( + content: => Seq[Node], + title: String, + useDataTables: Boolean = false): Seq[Node] = { {commonHeaderNodes} + {if (useDataTables) dataTablesHeaderNodes else Seq.empty} {title} @@ -319,7 +338,9 @@ private[spark] object UIUtils extends Logging { skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) + // started + completed can be > total when there are speculative tasks + val boundedStarted = math.min(started, total - completed) + val startWidth = "width: %s%%".format((boundedStarted.toDouble/total)*100)
    @@ -387,13 +408,6 @@ private[spark] object UIUtils extends Logging { } - /** Return a script element that automatically expands the DAG visualization on page load. */ - def expandDagVizOnLoad(forJob: Boolean): Seq[Node] = { - - } - /** * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML * and make sure that it only contains anchors with root-relative links. Otherwise, @@ -402,8 +416,16 @@ private[spark] object UIUtils extends Logging { * Note: In terms of security, only anchor tags with root relative links are supported. So any * attempts to embed links outside Spark UI, or other tags like } private def createExecutorTable() : Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 2cad0a796913e..bd4797ae8e0c5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -18,11 +18,10 @@ package org.apache.spark.ui.jobs import java.util.Date +import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, HashMap, ListBuffer} -import scala.xml.{NodeSeq, Node, Unparsed, Utility} - -import javax.servlet.http.HttpServletRequest +import scala.xml.{Node, NodeSeq, Unparsed, Utility} import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler.StageInfo @@ -123,7 +122,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'Removed at ${UIUtils.formatDate(new Date(event.finishTime.get))}' + | '${ if (event.finishReason.isDefined) { - s"""
    Reason: ${event.finishReason.get}""" + s"""
    Reason: ${event.finishReason.get.replace("\n", " ")}""" } else { "" } @@ -204,7 +203,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { // This could be empty if the JobProgressListener hasn't received information about the // stage or if the stage information has been garbage collected listener.stageIdToInfo.getOrElse(stageId, - new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown")) + new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown", Seq.empty)) } val activeStages = Buffer[StageInfo]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 77d034fa5ba2c..13f5f84d06feb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -21,11 +21,10 @@ import java.util.concurrent.TimeoutException import scala.collection.mutable.{HashMap, HashSet, ListBuffer} -import com.google.common.annotations.VisibleForTesting - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId @@ -327,12 +326,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { val taskInfo = taskStart.taskInfo if (taskInfo != null) { + val metrics = new TaskMetrics val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), { logWarning("Task start for unknown stage " + taskStart.stageId) new StageUIData }) stageData.numActiveTasks += 1 - stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) + stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo, Some(metrics))) } for ( activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); @@ -375,28 +375,34 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { execSummary.taskTime += info.duration stageData.numActiveTasks -= 1 - val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = + val (errorMessage, accums): (Option[String], Seq[AccumulableInfo]) = taskEnd.reason match { case org.apache.spark.Success => stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 - (None, Option(taskEnd.taskMetrics)) - case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics + (None, taskEnd.taskMetrics.accumulatorUpdates()) + case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates stageData.numFailedTasks += 1 - (Some(e.toErrorString), e.metrics) - case e: TaskFailedReason => // All other failure cases + (Some(e.toErrorString), e.accumUpdates) + case e: TaskFailedReason => // All other failure cases stageData.numFailedTasks += 1 - (Some(e.toErrorString), None) + (Some(e.toErrorString), Seq.empty[AccumulableInfo]) } - if (!metrics.isEmpty) { - val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) - updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics) + val taskMetrics = + if (accums.nonEmpty) { + Some(TaskMetrics.fromAccumulatorUpdates(accums)) + } else { + None + } + taskMetrics.foreach { m => + val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics) + updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) } val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) taskData.taskInfo = info - taskData.taskMetrics = metrics + taskData.metrics = taskMetrics taskData.errorMessage = errorMessage for ( @@ -428,14 +434,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) val shuffleWriteDelta = - (taskMetrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L)) + (taskMetrics.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.bytesWritten).getOrElse(0L)) stageData.shuffleWriteBytes += shuffleWriteDelta execSummary.shuffleWrite += shuffleWriteDelta val shuffleWriteRecordsDelta = - (taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleRecordsWritten).getOrElse(0L)) + (taskMetrics.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L) + - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.recordsWritten).getOrElse(0L)) stageData.shuffleWriteRecords += shuffleWriteRecordsDelta execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta @@ -491,19 +497,18 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { - for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) { + for ((taskId, sid, sAttempt, accumUpdates) <- executorMetricsUpdate.accumUpdates) { val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), { logWarning("Metrics update for task in unknown stage " + sid) new StageUIData }) val taskData = stageData.taskData.get(taskId) - taskData.map { t => + val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates) + taskData.foreach { t => if (!t.taskInfo.finished) { - updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics, - t.taskMetrics) - + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics) // Overwrite task metrics - t.taskMetrics = Some(taskMetrics) + t.metrics = Some(metrics) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 77ca60b000a9b..7b00b558d591a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -29,7 +29,9 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { val operationGraphListener = parent.operationGraphListener def isFairScheduler: Boolean = - jobProgresslistener.schedulingMode.exists(_ == SchedulingMode.FAIR) + jobProgresslistener.schedulingMode == Some(SchedulingMode.FAIR) + + def getSparkUser: String = parent.getSparkUser attachPage(new AllJobsPage(this)) attachPage(new JobPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index f3e0b38523f32..6cd25919ca5fd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.scheduler.StageInfo -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing specific pool details */ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { @@ -31,8 +31,11 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val poolName = request.getParameter("poolname") - require(poolName != null && poolName.nonEmpty, "Missing poolname parameter") + val poolName = Option(request.getParameter("poolname")).map { poolname => + UIUtils.decodeURLParameter(poolname) + }.getOrElse { + throw new IllegalArgumentException(s"Missing poolname parameter") + } val poolToActiveStages = listener.poolToActiveStages val activeStages = poolToActiveStages.get(poolName) match { @@ -44,7 +47,9 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { killEnabled = parent.killEnabled) // For now, pool information is only accessible in live UIs - val pools = sc.map(_.getPoolForName(poolName).get).toSeq + val pools = sc.map(_.getPoolForName(poolName).getOrElse { + throw new IllegalArgumentException(s"Unknown poolname: $poolName") + }).toSeq val poolTable = new PoolTable(pools, parent) val content = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 9ba2af54dacf4..ea02968733cac 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import java.net.URLEncoder + import scala.collection.mutable.HashMap import scala.xml.Node @@ -59,7 +61,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { case None => 0 } val href = "%s/stages/pool?poolname=%s" - .format(UIUtils.prependBaseUri(parent.basePath), p.name) + .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) {p.name} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 51425e599e748..8a44bbd9fcd57 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -28,10 +28,10 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.{InternalAccumulator, SparkConf} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} +import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ -import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { @@ -70,6 +70,21 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { private val displayPeakExecutionMemory = parent.conf.getBoolean("spark.sql.unsafe.enabled", true) + private def getLocalitySummaryString(stageData: StageUIData): String = { + val localities = stageData.taskData.values.map(_.taskInfo.taskLocality) + val localityCounts = localities.groupBy(identity).mapValues(_.size) + val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => + val localityName = locality match { + case TaskLocality.PROCESS_LOCAL => "Process local" + case TaskLocality.NODE_LOCAL => "Node local" + case TaskLocality.RACK_LOCAL => "Rack local" + case TaskLocality.ANY => "Any" + } + s"$localityName: $count" + } + localityNamesAndCounts.sorted.mkString("; ") + } + def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { val parameterId = request.getParameter("id") @@ -82,15 +97,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val parameterTaskSortColumn = request.getParameter("task.sort") val parameterTaskSortDesc = request.getParameter("task.desc") val parameterTaskPageSize = request.getParameter("task.pageSize") + val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize") val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) - val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index") + val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => + UIUtils.decodeURLParameter(sortColumn) + }.getOrElse("Index") val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) - - // If this is set, expand the dag visualization by default - val expandDagVizParam = request.getParameter("expandDagViz") - val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean + val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) val stageId = parameterId.toInt val stageAttemptId = parameterAttempt.toInt @@ -129,6 +144,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Total Time Across All Tasks: {UIUtils.formatDuration(stageData.executorRunTime)} +
  • + Locality Level Summary: + {getLocalitySummaryString(stageData)} +
  • {if (stageData.hasInput) {
  • Input Size / Records: @@ -240,21 +259,27 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val dagViz = UIUtils.showDagVizForStage( stageId, operationGraphListener.getOperationGraphForStage(stageId)) - val maybeExpandDagViz: Seq[Node] = - if (expandDagViz) { - UIUtils.expandDagVizOnLoad(forJob = false) - } else { - Seq.empty - } - val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") - def accumulableRow(acc: AccumulableInfo): Elem = - {acc.name}{acc.value} + def accumulableRow(acc: AccumulableInfo): Seq[Node] = { + (acc.name, acc.value) match { + case (Some(name), Some(value)) => {name}{value} + case _ => Seq.empty[Node] + } + } val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, externalAccumulables.toSeq) + val page: Int = { + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. + if (taskPageSize <= taskPrevPageSize) { + taskPage + } else { + 1 + } + } val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( @@ -273,10 +298,17 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { sortColumn = taskSortColumn, desc = taskSortDesc ) - (_taskTable, _taskTable.table(taskPage)) + (_taskTable, _taskTable.table(page)) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - (null,
    {e.getMessage}
    ) + val errorMessage = +
    +

    Error while rendering stage table:

    +
    +                {Utils.exceptionString(e)}
    +              
    +
    + (null, errorMessage) } val jsForScrollingDownToTaskTable = @@ -298,7 +330,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else taskTable.dataSource.slicedTaskIds // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) + val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined) val summaryTable: Option[Seq[Node]] = if (validTasks.size == 0) { @@ -316,8 +348,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) } - val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.executorDeserializeTime.toDouble + val deserializationTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.executorDeserializeTime.toDouble } val deserializationQuantiles = @@ -327,13 +359,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(deserializationTimes) - val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.executorRunTime.toDouble + val serviceTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.executorRunTime.toDouble } val serviceQuantiles = Duration +: getFormattedTimeQuantiles(serviceTimes) - val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.jvmGCTime.toDouble + val gcTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.jvmGCTime.toDouble } val gcQuantiles = @@ -342,8 +374,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gcTimes) - val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.resultSerializationTime.toDouble + val serializationTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.resultSerializationTime.toDouble } val serializationQuantiles = @@ -353,8 +385,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) - val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - getGettingResultTime(info, currentTime).toDouble + val gettingResultTimes = validTasks.map { taskUIData: TaskUIData => + getGettingResultTime(taskUIData.taskInfo, currentTime).toDouble } val gettingResultQuantiles = @@ -365,12 +397,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gettingResultTimes) - val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => - info.accumulables - .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.update.getOrElse("0").toLong } - .getOrElse(0L) - .toDouble + val peakExecutionMemory = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.peakExecutionMemory.toDouble } val peakExecutionMemoryQuantiles = { @@ -384,8 +412,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). - val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - getSchedulerDelay(info, metrics.get, currentTime).toDouble + val schedulerDelays = validTasks.map { taskUIData: TaskUIData => + getSchedulerDelay(taskUIData.taskInfo, taskUIData.metrics.get, currentTime).toDouble } val schedulerDelayTitle = Scheduler Delay @@ -399,30 +427,30 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { ) } - val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble + val inputSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble + val inputRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble } val inputQuantiles = Input Size / Records +: getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) - val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + val outputSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val outputRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + val outputRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble } val outputQuantiles = Output Size / Records +: getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) - val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + val shuffleReadBlockedTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble } val shuffleReadBlockedQuantiles = @@ -433,11 +461,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(shuffleReadBlockedTimes) - val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble + val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble } - val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble + val shuffleReadTotalRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble } val shuffleReadTotalQuantiles = @@ -448,8 +476,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } val shuffleReadRemoteQuantiles = @@ -460,25 +488,25 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedSizeQuantiles(shuffleReadRemoteSizes) - val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble + val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleWriteMetrics.map(_.shuffleRecordsWritten).getOrElse(0L).toDouble + val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble } val shuffleWriteQuantiles = Shuffle Write Size / Records +: getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) - val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.memoryBytesSpilled.toDouble + val memoryBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.memoryBytesSpilled.toDouble } val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: getFormattedSizeQuantiles(memoryBytesSpilledSizes) - val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.diskBytesSpilled.toDouble + val diskBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.diskBytesSpilled.toDouble } val diskBytesSpilledQuantiles = Shuffle spill (disk) +: getFormattedSizeQuantiles(diskBytesSpilledSizes) @@ -539,7 +567,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val content = summary ++ dagViz ++ - maybeExpandDagViz ++ showAdditionalMetrics ++ makeTimeline( // Only show the tasks in the table @@ -574,13 +601,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100 - val metricsOpt = taskUIData.taskMetrics + val metricsOpt = taskUIData.metrics val shuffleReadTime = metricsOpt.flatMap(_.shuffleReadMetrics.map(_.fetchWaitTime)).getOrElse(0L) val shuffleReadTimeProportion = toProportion(shuffleReadTime) val shuffleWriteTime = (metricsOpt.flatMap(_.shuffleWriteMetrics - .map(_.shuffleWriteTime)).getOrElse(0L) / 1e6).toLong + .map(_.writeTime)).getOrElse(0L) / 1e6).toLong val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) @@ -841,7 +868,8 @@ private[ui] class TaskDataSource( def slicedTaskIds: Set[Long] = _slicedTaskIds private def taskRow(taskData: TaskUIData): TaskTableRowData = { - val TaskUIData(info, metrics, errorMessage) = taskData + val info = taskData.taskInfo + val metrics = taskData.metrics val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) @@ -852,15 +880,15 @@ private[ui] class TaskDataSource( val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = getGettingResultTime(info, currentTime) - val (taskInternalAccumulables, taskExternalAccumulables) = - info.accumulables.partition(_.internal) - val externalAccumulableReadable = taskExternalAccumulables.map { acc => - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") - } - val peakExecutionMemoryUsed = taskInternalAccumulables - .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.update.getOrElse("0").toLong } - .getOrElse(0L) + val externalAccumulableReadable = info.accumulables + .filterNot(_.internal) + .flatMap { a => + (a.name, a.update) match { + case (Some(name), Some(update)) => Some(StringEscapeUtils.escapeHtml4(s"$name: $update")) + case _ => None + } + } + val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) @@ -891,13 +919,13 @@ private[ui] class TaskDataSource( val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten).getOrElse(0L) + val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L) val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("") val shuffleWriteRecords = maybeShuffleWrite - .map(_.shuffleRecordsWritten.toString).getOrElse("") + .map(_.recordsWritten.toString).getOrElse("") - val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) + val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.writeTime) val writeTimeSortable = maybeWriteTime.getOrElse(0L) val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => if (ms == 0) "" else UIUtils.formatDuration(ms) @@ -987,7 +1015,7 @@ private[ui] class TaskDataSource( shuffleRead, shuffleWrite, bytesSpilled, - errorMessage.getOrElse("")) + taskData.errorMessage.getOrElse("")) } /** @@ -1196,7 +1224,14 @@ private[ui] class TaskPagedTable( override def tableId: String = "task-table" - override def tableCssClass: String = "table table-bordered table-condensed table-striped" + override def tableCssClass: String = + "table table-bordered table-condensed table-striped table-head-clickable" + + override def pageSizeFormField: String = "task.pageSize" + + override def prevPageSizeFormField: String = "task.prevPageSize" + + override def pageNumberFormField: String = "task.page" override val dataSource: TaskDataSource = new TaskDataSource( data, @@ -1213,24 +1248,16 @@ private[ui] class TaskPagedTable( override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" + - s"&task.pageSize=${pageSize}" + basePath + + s"&$pageNumberFormField=$page" + + s"&task.sort=$encodedSortColumn" + + s"&task.desc=$desc" + + s"&$pageSizeFormField=$pageSize" } - override def goButtonJavascriptFunction: (String, String) = { - val jsFuncName = "goToTaskPage" + override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - val jsFunc = s""" - |currentTaskPageSize = ${pageSize} - |function goToTaskPage(page, pageSize) { - | // Set page to 1 if the page size changes - | page = pageSize == currentTaskPageSize ? page : 1; - | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" + - | "&task.page=" + page + "&task.pageSize=" + pageSize; - | window.location.href = url; - |} - """.stripMargin - (jsFuncName, jsFunc) + s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" } def headers: Seq[Node] = { @@ -1279,21 +1306,27 @@ private[ui] class TaskPagedTable( val headerRow: Seq[Node] = { taskHeadersAndCssClasses.map { case (header, cssClass) => if (header == sortColumn) { - val headerLink = - s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" + - s"&task.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") + val headerLink = Unparsed( + basePath + + s"&task.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&task.desc=${!desc}" + + s"&task.pageSize=$pageSize") val arrow = if (desc) "▾" else "▴" // UP or DOWN - - {header} -  {Unparsed(arrow)} + + + {header} +  {Unparsed(arrow)} + } else { - val headerLink = - s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") - - {header} + val headerLink = Unparsed( + basePath + + s"&task.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&task.pageSize=$pageSize") + + + {header} + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index ea806d09b6009..2a1c3c1a50ec9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -145,9 +145,22 @@ private[ui] class StageTableBase( case None => "Unknown" } val finishTime = s.completionTime.getOrElse(System.currentTimeMillis) - val duration = s.submissionTime.map { t => - if (finishTime > t) finishTime - t else System.currentTimeMillis - t - } + + // The submission time for a stage is misleading because it counts the time + // the stage waits to be launched. (SPARK-10930) + val taskLaunchTimes = + stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + val duration: Option[Long] = + if (taskLaunchTimes.nonEmpty) { + val startTime = taskLaunchTimes.min + if (finishTime > startTime) { + Some(finishTime - startTime) + } else { + Some(System.currentTimeMillis() - startTime) + } + } else { + None + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index 5989f0035b270..bd5f16d25b477 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -34,7 +34,7 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" attachPage(new StagePage(this)) attachPage(new PoolPage(this)) - def isFairScheduler: Boolean = progressListener.schedulingMode.exists(_ == SchedulingMode.FAIR) + def isFairScheduler: Boolean = progressListener.schedulingMode == Some(SchedulingMode.FAIR) def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index f008d40180611..b454ef1b204b2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -17,14 +17,14 @@ package org.apache.spark.ui.jobs +import scala.collection.mutable +import scala.collection.mutable.HashMap + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.collection.OpenHashSet -import scala.collection.mutable -import scala.collection.mutable.HashMap - private[spark] object UIData { class ExecutorSummary { @@ -105,12 +105,12 @@ private[spark] object UIData { /** * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. */ - case class TaskUIData( + class TaskUIData( var taskInfo: TaskInfo, - var taskMetrics: Option[TaskMetrics] = None, + var metrics: Option[TaskMetrics] = None, var errorMessage: Option[String] = None) - case class ExecutorUIData( + class ExecutorUIData( val startTime: Long, var finishTime: Option[Long] = None, var finishReason: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 81f168a447ead..bb6b663f1ead3 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -18,9 +18,11 @@ package org.apache.spark.ui.scope import scala.collection.mutable -import scala.collection.mutable.{StringBuilder, ListBuffer} +import scala.collection.mutable.{ListBuffer, StringBuilder} -import org.apache.spark.Logging +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.StageInfo import org.apache.spark.storage.StorageLevel @@ -38,7 +40,7 @@ private[ui] case class RDDOperationGraph( rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean) +private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) /** * A directed edge connecting two nodes in an RDDOperationGraph. @@ -104,8 +106,8 @@ private[ui] object RDDOperationGraph extends Logging { edges ++= rdd.parentIds.map { parentId => RDDOperationEdge(parentId, rdd.id) } // TODO: differentiate between the intention to cache an RDD and whether it's actually cached - val node = nodes.getOrElseUpdate( - rdd.id, RDDOperationNode(rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE)) + val node = nodes.getOrElseUpdate(rdd.id, RDDOperationNode( + rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE, rdd.callSite)) if (rdd.scope.isEmpty) { // This RDD has no encompassing scope, so we put it directly in the root cluster @@ -129,7 +131,11 @@ private[ui] object RDDOperationGraph extends Logging { } } // Attach the outermost cluster to the root cluster, and the RDD to the innermost cluster - rddClusters.headOption.foreach { cluster => rootCluster.attachChildCluster(cluster) } + rddClusters.headOption.foreach { cluster => + if (!rootCluster.childClusters.contains(cluster)) { + rootCluster.attachChildCluster(cluster) + } + } rddClusters.lastOption.foreach { cluster => cluster.attachChildNode(node) } } } @@ -177,7 +183,8 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { - s"""${node.id} [label="${node.name} [${node.id}]"]""" + val label = s"${node.name} [${node.id}]\n${node.callsite}" + s"""${node.id} [label="${StringEscapeUtils.escapeJava(label)}"]""" } /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. */ @@ -186,7 +193,7 @@ private[ui] object RDDOperationGraph extends Logging { cluster: RDDOperationCluster, indent: String): Unit = { subgraph.append(indent).append(s"subgraph cluster${cluster.id} {\n") - subgraph.append(indent).append(s""" label="${cluster.name}";\n""") + .append(indent).append(s""" label="${StringEscapeUtils.escapeJava(cluster.name)}";\n""") cluster.childNodes.foreach { node => subgraph.append(indent).append(s" ${makeDotNode(node)};\n") } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala index 89119cd3579ef..bcae56e2f114c 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala @@ -52,9 +52,8 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen * An empty list is returned if one or more of its stages has been cleaned up. */ def getOperationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = synchronized { - val skippedStageIds = jobIdToSkippedStageIds.get(jobId).getOrElse(Seq.empty) - val graphs = jobIdToStageIds.get(jobId) - .getOrElse(Seq.empty) + val skippedStageIds = jobIdToSkippedStageIds.getOrElse(jobId, Seq.empty) + val graphs = jobIdToStageIds.getOrElse(jobId, Seq.empty) .flatMap { sid => stageIdToGraph.get(sid) } // Mark any skipped stages as such graphs.foreach { g => diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index fd6cc3ed759b3..606d15d599e81 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -38,11 +38,13 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val parameterBlockSortColumn = request.getParameter("block.sort") val parameterBlockSortDesc = request.getParameter("block.desc") val parameterBlockPageSize = request.getParameter("block.pageSize") + val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize") val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") val blockSortDesc = Option(parameterBlockSortDesc).map(_.toBoolean).getOrElse(false) val blockPageSize = Option(parameterBlockPageSize).map(_.toInt).getOrElse(100) + val blockPrevPageSize = Option(parameterBlockPrevPageSize).map(_.toInt).getOrElse(blockPageSize) val rddId = parameterId.toInt val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) @@ -56,17 +58,26 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { rddStorageInfo.dataDistribution.get, id = Some("rdd-storage-by-worker-table")) // Block table - val (blockTable, blockTableHTML) = try { + val page: Int = { + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. + if (blockPageSize <= blockPrevPageSize) { + blockPage + } else { + 1 + } + } + val blockTableHTML = try { val _blockTable = new BlockPagedTable( UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, blockSortDesc) - (_blockTable, _blockTable.table(blockPage)) + _blockTable.table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - (null,
    {e.getMessage}
    ) +
    {e.getMessage}
    } val jsForScrollingDownToBlockTable = @@ -226,7 +237,14 @@ private[ui] class BlockPagedTable( override def tableId: String = "rdd-storage-by-block-table" - override def tableCssClass: String = "table table-bordered table-condensed table-striped" + override def tableCssClass: String = + "table table-bordered table-condensed table-striped table-head-clickable" + + override def pageSizeFormField: String = "block.pageSize" + + override def prevPageSizeFormField: String = "block.prevPageSize" + + override def pageNumberFormField: String = "block.page" override val dataSource: BlockDataSource = new BlockDataSource( rddPartitions, @@ -236,24 +254,16 @@ private[ui] class BlockPagedTable( override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"${basePath}&block.page=$page&block.sort=${encodedSortColumn}&block.desc=${desc}" + - s"&block.pageSize=${pageSize}" + basePath + + s"&$pageNumberFormField=$page" + + s"&block.sort=$encodedSortColumn" + + s"&block.desc=$desc" + + s"&$pageSizeFormField=$pageSize" } - override def goButtonJavascriptFunction: (String, String) = { - val jsFuncName = "goToBlockPage" + override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - val jsFunc = s""" - |currentBlockPageSize = ${pageSize} - |function goToBlockPage(page, pageSize) { - | // Set page to 1 if the page size changes - | page = pageSize == currentBlockPageSize ? page : 1; - | var url = "${basePath}&block.sort=${encodedSortColumn}&block.desc=${desc}" + - | "&block.page=" + page + "&block.pageSize=" + pageSize; - | window.location.href = url; - |} - """.stripMargin - (jsFuncName, jsFunc) + s"$basePath&block.sort=$encodedSortColumn&block.desc=$desc" } override def headers: Seq[Node] = { @@ -271,22 +281,27 @@ private[ui] class BlockPagedTable( val headerRow: Seq[Node] = { blockHeaders.map { header => if (header == sortColumn) { - val headerLink = - s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}&block.desc=${!desc}" + - s"&block.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") + val headerLink = Unparsed( + basePath + + s"&block.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&block.desc=${!desc}" + + s"&block.pageSize=$pageSize") val arrow = if (desc) "▾" else "▴" // UP or DOWN - - {header} -  {Unparsed(arrow)} + + + {header} +  {Unparsed(arrow)} + } else { - val headerLink = - s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}" + - s"&block.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") - - {header} + val headerLink = Unparsed( + basePath + + s"&block.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&block.pageSize=$pageSize") + + + {header} + } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 04f584621e71e..76d7c6d414bcf 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -54,7 +54,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Cached Partitions", "Fraction Cached", "Size in Memory", - "Size in ExternalBlockStore", "Size on Disk") /** Render an HTML row representing an RDD */ @@ -71,7 +70,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.numCachedPartitions.toString} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memSize)} - {Utils.bytesToString(rdd.externalBlockStoreSize)} {Utils.bytesToString(rdd.diskSize)} // scalastyle:on @@ -104,7 +102,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { "Executor ID", "Address", "Total Size in Memory", - "Total Size in ExternalBlockStore", "Total Size on Disk", "Stream Blocks") @@ -119,9 +116,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {Utils.bytesToString(status.totalMemSize)} - - {Utils.bytesToString(status.totalExternalBlockStoreSize)} - {Utils.bytesToString(status.totalDiskSize)} @@ -162,7 +156,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { streamBlockTableSubrow(block._1, replications.head, replications.size, true) } else { streamBlockTableSubrow(block._1, replications.head, replications.size, true) ++ - replications.tail.map(streamBlockTableSubrow(block._1, _, replications.size, false)).flatten + replications.tail.flatMap(streamBlockTableSubrow(block._1, _, replications.size, false)) } } @@ -195,8 +189,6 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { ("Memory", block.memSize) } else if (block.storageLevel.useMemory && !block.storageLevel.deserialized) { ("Memory Serialized", block.memSize) - } else if (block.storageLevel.useOffHeap) { - ("External", block.externalBlockStoreSize) } else { throw new IllegalStateException(s"Invalid Storage Level: ${block.storageLevel}") } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 22e2993b3b5bd..50095831b4a53 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -20,9 +20,9 @@ package org.apache.spark.ui.storage import scala.collection.mutable import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ui._ import org.apache.spark.scheduler._ import org.apache.spark.storage._ +import org.apache.spark.ui._ /** Web UI showing storage status of all RDD's in the given SparkContext. */ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { @@ -43,7 +43,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing - def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList + def activeStorageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList /** Filter RDD info to include only those with cached partitions */ def rddInfoList: Seq[RDDInfo] = synchronized { @@ -54,18 +54,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = { val rddIdsToUpdate = updatedBlocks.flatMap { case (bid, _) => bid.asRDDId.map(_.rddId) }.toSet val rddInfosToUpdate = _rddInfoMap.values.toSeq.filter { s => rddIdsToUpdate.contains(s.id) } - StorageUtils.updateRddInfo(rddInfosToUpdate, storageStatusList) - } - - /** - * Assumes the storage status list is fully up-to-date. This implies the corresponding - * StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener. - */ - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - val metrics = taskEnd.taskMetrics - if (metrics != null && metrics.updatedBlocks.isDefined) { - updateRDDInfo(metrics.updatedBlocks.get) - } + StorageUtils.updateRddInfo(rddInfosToUpdate, activeStorageStatusList) } override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { @@ -84,4 +73,14 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized { _rddInfoMap.remove(unpersistRDD.rddId) } + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { + super.onBlockUpdated(blockUpdated) + val blockId = blockUpdated.blockUpdatedInfo.blockId + val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel + val memSize = blockUpdated.blockUpdatedInfo.memSize + val diskSize = blockUpdated.blockUpdatedInfo.diskSize + val blockStatus = BlockStatus(storageLevel, memSize, diskSize) + updateRDDInfo(Seq((blockId, blockStatus))) + } } diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala deleted file mode 100644 index 81a7cbde01ce5..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import akka.actor.Actor -import org.slf4j.Logger - -/** - * A trait to enable logging all Akka actor messages. Here's an example of using this: - * - * {{{ - * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging { - * ... - * override def receiveWithLogging = { - * case GetLocations(blockId) => - * sender ! getLocations(blockId) - * ... - * } - * ... - * } - * }}} - * - */ -private[spark] trait ActorLogReceive { - self: Actor => - - override def receive: Actor.Receive = new Actor.Receive { - - private val _receiveWithLogging = receiveWithLogging - - override def isDefinedAt(o: Any): Boolean = { - val handled = _receiveWithLogging.isDefinedAt(o) - if (!handled) { - log.debug(s"Received unexpected actor system event: $o") - } - handled - } - - override def apply(o: Any): Unit = { - if (log.isDebugEnabled) { - log.debug(s"[actor] received message $o from ${self.sender}") - } - val start = System.nanoTime - _receiveWithLogging.apply(o) - val timeTaken = (System.nanoTime - start).toDouble / 1000000 - if (log.isDebugEnabled) { - log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}") - } - } - } - - def receiveWithLogging: Actor.Receive - - protected def log: Logger -} diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala deleted file mode 100644 index 1738258a0c794..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ /dev/null @@ -1,242 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.JavaConverters._ - -import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} -import akka.pattern.ask - -import com.typesafe.config.ConfigFactory -import org.apache.log4j.{Level, Logger} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} -import org.apache.spark.rpc.RpcTimeout - -/** - * Various utility classes for working with Akka. - */ -private[spark] object AkkaUtils extends Logging { - - /** - * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the - * ActorSystem itself and its port (which is hard to get from Akka). - * - * Note: the `name` parameter is important, as even if a client sends a message to right - * host + port, if the system name is incorrect, Akka will drop the message. - * - * If indestructible is set to true, the Actor System will continue running in the event - * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. - */ - def createActorSystem( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): (ActorSystem, Int) = { - val startService: Int => (ActorSystem, Int) = { actualPort => - doCreateActorSystem(name, host, actualPort, conf, securityManager) - } - Utils.startServiceOnPort(port, startService, conf, name) - } - - private def doCreateActorSystem( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): (ActorSystem, Int) = { - - val akkaThreads = conf.getInt("spark.akka.threads", 4) - val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeoutS = conf.getTimeAsSeconds("spark.akka.timeout", - conf.get("spark.network.timeout", "120s")) - val akkaFrameSize = maxFrameSizeBytes(conf) - val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) - val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" - if (!akkaLogLifecycleEvents) { - // As a workaround for Akka issue #3787, we coerce the "EndpointWriter" log to be silent. - // See: https://www.assembla.com/spaces/akka/tickets/3787#/ - Option(Logger.getLogger("akka.remote.EndpointWriter")).map(l => l.setLevel(Level.FATAL)) - } - - val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - - val akkaHeartBeatPausesS = conf.getTimeAsSeconds("spark.akka.heartbeat.pauses", "6000s") - val akkaHeartBeatIntervalS = conf.getTimeAsSeconds("spark.akka.heartbeat.interval", "1000s") - - val secretKey = securityManager.getSecretKey() - val isAuthOn = securityManager.isAuthenticationEnabled() - if (isAuthOn && secretKey == null) { - throw new Exception("Secret key is null with authentication on") - } - val requireCookie = if (isAuthOn) "on" else "off" - val secureCookie = if (isAuthOn) secretKey else "" - logDebug(s"In createActorSystem, requireCookie is: $requireCookie") - - val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig - .getOrElse(ConfigFactory.empty()) - - val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap.asJava) - .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString( - s""" - |akka.daemonic = on - |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] - |akka.stdout-loglevel = "ERROR" - |akka.jvm-exit-on-fatal-error = off - |akka.remote.require-cookie = "$requireCookie" - |akka.remote.secure-cookie = "$secureCookie" - |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatIntervalS s - |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPausesS s - |akka.actor.provider = "akka.remote.RemoteActorRefProvider" - |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" - |akka.remote.netty.tcp.hostname = "$host" - |akka.remote.netty.tcp.port = $port - |akka.remote.netty.tcp.tcp-nodelay = on - |akka.remote.netty.tcp.connection-timeout = $akkaTimeoutS s - |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B - |akka.remote.netty.tcp.execution-pool-size = $akkaThreads - |akka.actor.default-dispatcher.throughput = $akkaBatchSize - |akka.log-config-on-start = $logAkkaConfig - |akka.remote.log-remote-lifecycle-events = $lifecycleEvents - |akka.log-dead-letters = $lifecycleEvents - |akka.log-dead-letters-during-shutdown = $lifecycleEvents - """.stripMargin)) - - val actorSystem = ActorSystem(name, akkaConf) - val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider - val boundPort = provider.getDefaultAddress.port.get - (actorSystem, boundPort) - } - - private val AKKA_MAX_FRAME_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 - - /** Returns the configured max frame size for Akka messages in bytes. */ - def maxFrameSizeBytes(conf: SparkConf): Int = { - val frameSizeInMB = conf.getInt("spark.akka.frameSize", 128) - if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) { - throw new IllegalArgumentException( - s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB") - } - frameSizeInMB * 1024 * 1024 - } - - /** Space reserved for extra data in an Akka message besides serialized task or task result. */ - val reservedSizeBytes = 200 * 1024 - - /** - * Send a message to the given actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - def askWithReply[T]( - message: Any, - actor: ActorRef, - timeout: RpcTimeout): T = { - askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout) - } - - /** - * Send a message to the given actor and get its result within a default timeout, or - * throw a SparkException if this fails even after the specified number of retries. - */ - def askWithReply[T]( - message: Any, - actor: ActorRef, - maxAttempts: Int, - retryInterval: Long, - timeout: RpcTimeout): T = { - // TODO: Consider removing multiple attempts - if (actor == null) { - throw new SparkException(s"Error sending message [message = $message]" + - " as actor is null ") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < maxAttempts) { - attempts += 1 - try { - val future = actor.ask(message)(timeout.duration) - val result = timeout.awaitResult(future) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - if (attempts < maxAttempts) { - Thread.sleep(retryInterval) - } - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - - def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { - val driverActorSystemName = SparkEnv.driverActorSystemName - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") - val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = RpcUtils.lookupRpcTimeout(conf) - logInfo(s"Connecting to $name: $url") - timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) - } - - def makeExecutorRef( - name: String, - conf: SparkConf, - host: String, - port: Int, - actorSystem: ActorSystem): ActorRef = { - val executorActorSystemName = SparkEnv.executorActorSystemName - Utils.checkHost(host, "Expected hostname") - val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = RpcUtils.lookupRpcTimeout(conf) - logInfo(s"Connecting to $name: $url") - timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) - } - - def protocol(actorSystem: ActorSystem): String = { - val akkaConf = actorSystem.settings.config - val sslProp = "akka.remote.netty.tcp.enable-ssl" - protocol(akkaConf.hasPath(sslProp) && akkaConf.getBoolean(sslProp)) - } - - def protocol(ssl: Boolean = false): String = { - if (ssl) { - "akka.ssl.tcp" - } else { - "akka.tcp" - } - } - - def address( - protocol: String, - systemName: String, - host: String, - port: Int, - actorName: String): String = { - s"$protocol://$systemName@$host:$port/user/$actorName" - } - -} diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala deleted file mode 100644 index 61b5a4cecddce..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicBoolean - -import com.google.common.annotations.VisibleForTesting -import org.apache.spark.SparkContext - -/** - * Asynchronously passes events to registered listeners. - * - * Until `start()` is called, all posted events are only buffered. Only after this listener bus - * has started will events be actually propagated to all attached listeners. This listener bus - * is stopped when `stop()` is called, and it will drop further events after stopping. - * - * @param name name of the listener bus, will be the name of the listener thread. - * @tparam L type of listener - * @tparam E type of event - */ -private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: String) - extends ListenerBus[L, E] { - - self => - - private var sparkContext: SparkContext = null - - /* Cap the capacity of the event queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[E](EVENT_QUEUE_CAPACITY) - - // Indicate if `start()` is called - private val started = new AtomicBoolean(false) - // Indicate if `stop()` is called - private val stopped = new AtomicBoolean(false) - - // Indicate if we are processing some event - // Guarded by `self` - private var processingEvent = false - - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) - - private val listenerThread = new Thread(name) { - setDaemon(true) - override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - postToAll(event) - } finally { - self.synchronized { - processingEvent = false - } - } - } - } - } - - /** - * Start sending events to attached listeners. - * - * This first sends out all buffered events posted before this listener bus has started, then - * listens for any additional events asynchronously while the listener bus is still running. - * This should only be called once. - * - * @param sc Used to stop the SparkContext in case the listener thread dies. - */ - def start(sc: SparkContext) { - if (started.compareAndSet(false, true)) { - sparkContext = sc - listenerThread.start() - } else { - throw new IllegalStateException(s"$name already started!") - } - } - - def post(event: E) { - if (stopped.get) { - // Drop further events to make `listenerThread` exit ASAP - logError(s"$name has already stopped! Dropping event $event") - return - } - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - onDropEvent(event) - } - } - - /** - * For testing only. Wait until there are no more events in the queue, or until the specified - * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue - * emptied. - */ - @VisibleForTesting - @throws(classOf[TimeoutException]) - def waitUntilEmpty(timeoutMillis: Long): Unit = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!queueIsEmpty) { - if (System.currentTimeMillis > finishTime) { - throw new TimeoutException( - s"The event queue is not empty after $timeoutMillis milliseconds") - } - /* Sleep rather than using wait/notify, because this is used only for testing and - * wait/notify add overhead in the general case. */ - Thread.sleep(10) - } - } - - /** - * For testing only. Return whether the listener daemon thread is still alive. - */ - @VisibleForTesting - def listenerThreadIsAlive: Boolean = listenerThread.isAlive - - /** - * Return whether the event queue is empty. - * - * The use of synchronized here guarantees that all events that once belonged to this queue - * have already been processed by all attached listeners, if this returns true. - */ - private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } - - /** - * Stop the listener bus. It will wait until the queued events have been processed, but drop the - * new events after stopping. - */ - def stop() { - if (!started.get()) { - throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") - } - if (stopped.compareAndSet(false, true)) { - // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know - // `stop` is called. - eventLock.release() - listenerThread.join() - } else { - // Keep quiet - } - } - - /** - * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be - * notified with the dropped events. - * - * Note: `onDropEvent` can be called in any thread. - */ - def onDropEvent(event: E): Unit -} diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala new file mode 100644 index 0000000000000..9e40bafd521d7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Try + +import org.apache.commons.lang3.SystemUtils + +/** + * Utility class to benchmark components. An example of how to use this is: + * val benchmark = new Benchmark("My Benchmark", valuesPerIteration) + * benchmark.addCase("V1")() + * benchmark.addCase("V2")() + * benchmark.run + * This will output the average time to run each function and the rate of each function. + * + * The benchmark function takes one argument that is the iteration that's being run. + * + * If outputPerIteration is true, the timing for each run will be printed to stdout. + */ +private[spark] class Benchmark( + name: String, + valuesPerIteration: Long, + iters: Int = 5, + outputPerIteration: Boolean = false) { + val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case] + + def addCase(name: String)(f: Int => Unit): Unit = { + benchmarks += Benchmark.Case(name, f) + } + + /** + * Runs the benchmark and outputs the results to stdout. This should be copied and added as + * a comment with the benchmark. Although the results vary from machine to machine, it should + * provide some baseline. + */ + def run(): Unit = { + require(benchmarks.nonEmpty) + // scalastyle:off + println("Running benchmark: " + name) + + val results = benchmarks.map { c => + println(" Running case: " + c.name) + Benchmark.measure(valuesPerIteration, iters, outputPerIteration)(c.fn) + } + println + + val firstBest = results.head.bestMs + // The results are going to be processor specific so it is useful to include that. + println(Benchmark.getJVMOSInfo()) + println(Benchmark.getProcessorName()) + printf("%-35s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)", + "Per Row(ns)", "Relative") + println("-----------------------------------------------------------------------------------" + + "--------") + results.zip(benchmarks).foreach { case (result, benchmark) => + printf("%-35s %16s %12s %13s %10s\n", + benchmark.name, + "%5.0f / %4.0f" format (result.bestMs, result.avgMs), + "%10.1f" format result.bestRate, + "%6.1f" format (1000 / result.bestRate), + "%3.1fX" format (firstBest / result.bestMs)) + } + println + // scalastyle:on + } +} + +private[spark] object Benchmark { + case class Case(name: String, fn: Int => Unit) + case class Result(avgMs: Double, bestRate: Double, bestMs: Double) + + /** + * This should return a user helpful processor information. Getting at this depends on the OS. + * This should return something like "Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz" + */ + def getProcessorName(): String = { + val cpu = if (SystemUtils.IS_OS_MAC_OSX) { + Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string")) + } else if (SystemUtils.IS_OS_LINUX) { + Try { + val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")) + Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo")) + .replaceFirst("model name[\\s*]:[\\s*]", "") + }.getOrElse("Unknown processor") + } else { + System.getenv("PROCESSOR_IDENTIFIER") + } + cpu + } + + /** + * This should return a user helpful JVM & OS information. + * This should return something like + * "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64" + */ + def getJVMOSInfo(): String = { + val vmName = System.getProperty("java.vm.name") + val runtimeVersion = System.getProperty("java.runtime.version") + val osName = System.getProperty("os.name") + val osVersion = System.getProperty("os.version") + s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}" + } + + /** + * Runs a single function `f` for iters, returning the average time the function took and + * the rate of the function. + */ + def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = { + val runTimes = ArrayBuffer[Long]() + for (i <- 0 until iters + 1) { + val start = System.nanoTime() + + f(i) + + val end = System.nanoTime() + val runTime = end - start + if (i > 0) { + runTimes += runTime + } + + if (outputPerIteration) { + // scalastyle:off + println(s"Iteration $i took ${runTime / 1000} microseconds") + // scalastyle:on + } + } + val best = runTimes.min + val avg = runTimes.sum / iters + Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0) + } +} + diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala index 54de4d4ee8ca7..dce2ac63a664c 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -20,10 +20,10 @@ package org.apache.spark.util import java.io.InputStream import java.nio.ByteBuffer -import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.StorageUtils /** - * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose() + * Reads data from a ByteBuffer, and optionally cleans it up using StorageUtils.dispose() * at the end of the stream (e.g. to close a memory-mapped file). */ private[spark] @@ -68,12 +68,12 @@ class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = f } /** - * Clean up the buffer, and potentially dispose of it using BlockManager.dispose(). + * Clean up the buffer, and potentially dispose of it using StorageUtils.dispose(). */ private def cleanUp() { if (buffer != null) { if (dispose) { - BlockManager.dispose(buffer) + StorageUtils.dispose(buffer) } buffer = null } diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala new file mode 100644 index 0000000000000..09e7579ae9606 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer + +/** + * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer + */ +private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) { + + def this() = this(32) + + def getCount(): Int = count + + def toByteBuffer: ByteBuffer = { + return ByteBuffer.wrap(buf, 0, count) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/CausedBy.scala b/core/src/main/scala/org/apache/spark/util/CausedBy.scala new file mode 100644 index 0000000000000..73df446d981cb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CausedBy.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Extractor Object for pulling out the root cause of an error. + * If the error contains no cause, it will return the error itself. + * + * Usage: + * try { + * ... + * } catch { + * case CausedBy(ex: CommitDeniedException) => ... + * } + */ +private[spark] object CausedBy { + + def unapply(e: Throwable): Option[Throwable] = { + Option(e.getCause).flatMap(cause => unapply(cause)).orElse(Some(e)) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 1b49dca9dc78b..489688cb0880f 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -19,12 +19,14 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import scala.collection.mutable.{Map, Set} +import scala.collection.mutable.{Map, Set, Stack} +import scala.language.existentials -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm5.Opcodes._ -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.internal.Logging /** * A cleaner that renders closures serializable if they can be done so safely. @@ -76,35 +78,19 @@ private[spark] object ClosureCleaner extends Logging { */ private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = { val seen = Set[Class[_]](obj.getClass) - var stack = List[Class[_]](obj.getClass) + val stack = Stack[Class[_]](obj.getClass) while (!stack.isEmpty) { - val cr = getClassReader(stack.head) - stack = stack.tail + val cr = getClassReader(stack.pop()) val set = Set[Class[_]]() cr.accept(new InnerClosureFinder(set), 0) for (cls <- set -- seen) { seen += cls - stack = cls :: stack + stack.push(cls) } } (seen - obj.getClass).toList } - private def createNullValue(cls: Class[_]): AnyRef = { - if (cls.isPrimitive) { - cls match { - case java.lang.Boolean.TYPE => new java.lang.Boolean(false) - case java.lang.Character.TYPE => new java.lang.Character('\u0000') - case java.lang.Void.TYPE => - // This should not happen because `Foo(void x) {}` does not compile. - throw new IllegalStateException("Unexpected void parameter in constructor") - case _ => new java.lang.Byte(0: Byte) - } - } else { - null - } - } - /** * Clean the given closure in place. * @@ -232,16 +218,24 @@ private[spark] object ClosureCleaner extends Logging { // Note that all outer objects but the outermost one (first one in this list) must be closures var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse var parent: AnyRef = null - if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - // Note that we still need to keep around the outermost object itself because - // we need it to clone its child closure later (see below). - logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}") - parent = outerPairs.head._2 // e.g. SparkContext - outerPairs = outerPairs.tail - } else if (outerPairs.size > 0) { - logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}") + if (outerPairs.size > 0) { + val (outermostClass, outermostObject) = outerPairs.head + if (isClosure(outermostClass)) { + logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it + // as it may carray a lot of unnecessary information, e.g. hadoop conf, spark conf, etc. + logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " + + outerPairs.head) + parent = outermostObject // e.g. SparkContext + outerPairs = outerPairs.tail + } } else { logDebug(" + there are no enclosing objects!") } @@ -325,11 +319,11 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") -private class ReturnStatementFinder extends ClassVisitor(ASM4) { +private class ReturnStatementFinder extends ClassVisitor(ASM5) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name.contains("apply")) { - new MethodVisitor(ASM4) { + new MethodVisitor(ASM5) { override def visitTypeInsn(op: Int, tp: String) { if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { throw new ReturnStatementInClosureException @@ -337,7 +331,7 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) { } } } else { - new MethodVisitor(ASM4) {} + new MethodVisitor(ASM5) {} } } } @@ -361,7 +355,7 @@ private[util] class FieldAccessFinder( findTransitively: Boolean, specificMethod: Option[MethodIdentifier[_]] = None, visitedMethods: Set[MethodIdentifier[_]] = Set.empty) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { override def visitMethod( access: Int, @@ -376,7 +370,7 @@ private[util] class FieldAccessFinder( return null } - new MethodVisitor(ASM4) { + new MethodVisitor(ASM5) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { @@ -385,7 +379,8 @@ private[util] class FieldAccessFinder( } } - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { // Check for calls a getter method for a variable in an interpreter wrapper object. // This means that the corresponding field will be accessed, so we should save it. @@ -408,7 +403,7 @@ private[util] class FieldAccessFinder( } } -private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) { var myName: String = null // TODO: Recursively find inner closures that we indirectly reference, e.g. @@ -423,9 +418,9 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index e9b2b8d24b476..3ea9139e11027 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -17,12 +17,12 @@ package org.apache.spark.util -import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} +import java.util.concurrent.atomic.AtomicBoolean import scala.util.control.NonFatal -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * An event loop to receive events from the caller and process all events in the event thread. It @@ -47,13 +47,12 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { try { onReceive(event) } catch { - case NonFatal(e) => { + case NonFatal(e) => try { onError(e) } catch { case NonFatal(e) => logError("Unexpected error in " + name, e) } - } } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ee2eb58cf5e2a..558767e36f7da 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,19 +19,21 @@ package org.apache.spark.util import java.util.{Properties, UUID} -import org.apache.spark.scheduler.cluster.ExecutorInfo - import scala.collection.JavaConverters._ import scala.collection.Map +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats -import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -54,6 +56,8 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats + private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -96,6 +100,7 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case _ => parse(mapper.writeValueAsString(event)) } } @@ -228,14 +233,14 @@ private[spark] object JsonProtocol { def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { val execId = metricsUpdate.execId - val taskMetrics = metricsUpdate.taskMetrics + val accumUpdates = metricsUpdate.accumUpdates ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ ("Executor ID" -> execId) ~ - ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) => + ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => ("Task ID" -> taskId) ~ ("Stage ID" -> stageId) ~ ("Stage Attempt ID" -> stageAttemptId) ~ - ("Task Metrics" -> taskMetricsToJson(metrics)) + ("Accumulator Updates" -> JArray(updates.map(accumulableInfoToJson).toList)) }) } @@ -260,7 +265,7 @@ private[spark] object JsonProtocol { ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ ("Accumulables" -> JArray( - stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) + stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) } def taskInfoToJson(taskInfo: TaskInfo): JValue = { @@ -279,30 +284,80 @@ private[spark] object JsonProtocol { } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { + val name = accumulableInfo.name ("ID" -> accumulableInfo.id) ~ - ("Name" -> accumulableInfo.name) ~ - ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~ - ("Value" -> accumulableInfo.value) ~ - ("Internal" -> accumulableInfo.internal) + ("Name" -> name) ~ + ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~ + ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~ + ("Internal" -> accumulableInfo.internal) ~ + ("Count Failed Values" -> accumulableInfo.countFailedValues) ~ + ("Metadata" -> accumulableInfo.metadata) + } + + /** + * Serialize the value of an accumulator to JSON. + * + * For accumulators representing internal task metrics, this looks up the relevant + * [[AccumulatorParam]] to serialize the value accordingly. For all other accumulators, + * this will simply serialize the value as a string. + * + * The behavior here must match that of [[accumValueFromJson]]. Exposed for testing. + */ + private[util] def accumValueToJson(name: Option[String], value: Any): JValue = { + import AccumulatorParam._ + if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) { + (value, InternalAccumulator.getParam(name.get)) match { + case (v: Int, IntAccumulatorParam) => JInt(v) + case (v: Long, LongAccumulatorParam) => JInt(v) + case (v: String, StringAccumulatorParam) => JString(v) + case (v, UpdatedBlockStatusesAccumulatorParam) => + JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) + }) + case (v, p) => + throw new IllegalArgumentException(s"unexpected combination of accumulator value " + + s"type (${v.getClass.getName}) and param (${p.getClass.getName}) in '${name.get}'") + } + } else { + // For all external accumulators, just use strings + JString(value.toString) + } } def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = { - val shuffleReadMetrics = - taskMetrics.shuffleReadMetrics.map(shuffleReadMetricsToJson).getOrElse(JNothing) - val shuffleWriteMetrics = - taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing) - val inputMetrics = - taskMetrics.inputMetrics.map(inputMetricsToJson).getOrElse(JNothing) - val outputMetrics = - taskMetrics.outputMetrics.map(outputMetricsToJson).getOrElse(JNothing) - val updatedBlocks = - taskMetrics.updatedBlocks.map { blocks => - JArray(blocks.toList.map { case (id, status) => - ("Block ID" -> id.toString) ~ - ("Status" -> blockStatusToJson(status)) - }) + val shuffleReadMetrics: JValue = + taskMetrics.shuffleReadMetrics.map { rm => + ("Remote Blocks Fetched" -> rm.remoteBlocksFetched) ~ + ("Local Blocks Fetched" -> rm.localBlocksFetched) ~ + ("Fetch Wait Time" -> rm.fetchWaitTime) ~ + ("Remote Bytes Read" -> rm.remoteBytesRead) ~ + ("Local Bytes Read" -> rm.localBytesRead) ~ + ("Total Records Read" -> rm.recordsRead) + }.getOrElse(JNothing) + val shuffleWriteMetrics: JValue = + taskMetrics.shuffleWriteMetrics.map { wm => + ("Shuffle Bytes Written" -> wm.bytesWritten) ~ + ("Shuffle Write Time" -> wm.writeTime) ~ + ("Shuffle Records Written" -> wm.recordsWritten) }.getOrElse(JNothing) - ("Host Name" -> taskMetrics.hostname) ~ + val inputMetrics: JValue = + taskMetrics.inputMetrics.map { im => + ("Data Read Method" -> im.readMethod.toString) ~ + ("Bytes Read" -> im.bytesRead) ~ + ("Records Read" -> im.recordsRead) + }.getOrElse(JNothing) + val outputMetrics: JValue = + taskMetrics.outputMetrics.map { om => + ("Data Write Method" -> om.writeMethod.toString) ~ + ("Bytes Written" -> om.bytesWritten) ~ + ("Records Written" -> om.recordsWritten) + }.getOrElse(JNothing) + val updatedBlocks = + JArray(taskMetrics.updatedBlockStatuses.toList.map { case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) + }) ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~ ("Executor Run Time" -> taskMetrics.executorRunTime) ~ ("Result Size" -> taskMetrics.resultSize) ~ @@ -317,33 +372,6 @@ private[spark] object JsonProtocol { ("Updated Blocks" -> updatedBlocks) } - def shuffleReadMetricsToJson(shuffleReadMetrics: ShuffleReadMetrics): JValue = { - ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~ - ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~ - ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~ - ("Remote Bytes Read" -> shuffleReadMetrics.remoteBytesRead) ~ - ("Local Bytes Read" -> shuffleReadMetrics.localBytesRead) ~ - ("Total Records Read" -> shuffleReadMetrics.recordsRead) - } - - def shuffleWriteMetricsToJson(shuffleWriteMetrics: ShuffleWriteMetrics): JValue = { - ("Shuffle Bytes Written" -> shuffleWriteMetrics.shuffleBytesWritten) ~ - ("Shuffle Write Time" -> shuffleWriteMetrics.shuffleWriteTime) ~ - ("Shuffle Records Written" -> shuffleWriteMetrics.shuffleRecordsWritten) - } - - def inputMetricsToJson(inputMetrics: InputMetrics): JValue = { - ("Data Read Method" -> inputMetrics.readMethod.toString) ~ - ("Bytes Read" -> inputMetrics.bytesRead) ~ - ("Records Read" -> inputMetrics.recordsRead) - } - - def outputMetricsToJson(outputMetrics: OutputMetrics): JValue = { - ("Data Write Method" -> outputMetrics.writeMethod.toString) ~ - ("Bytes Written" -> outputMetrics.bytesWritten) ~ - ("Records Written" -> outputMetrics.recordsWritten) - } - def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { val reason = Utils.getFormattedClassName(taskEndReason) val json: JObject = taskEndReason match { @@ -357,12 +385,12 @@ private[spark] object JsonProtocol { ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) - val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing) + val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList) ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ - ("Metrics" -> metrics) + ("Accumulator Updates" -> accumUpdates) case taskCommitDenied: TaskCommitDenied => ("Job ID" -> taskCommitDenied.jobID) ~ ("Partition ID" -> taskCommitDenied.partitionID) ~ @@ -398,19 +426,18 @@ private[spark] object JsonProtocol { ("RDD ID" -> rddInfo.id) ~ ("Name" -> rddInfo.name) ~ ("Scope" -> rddInfo.scope.map(_.toJson)) ~ + ("Callsite" -> rddInfo.callSite) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ ("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~ ("Memory Size" -> rddInfo.memSize) ~ - ("ExternalBlockStore Size" -> rddInfo.externalBlockStoreSize) ~ ("Disk Size" -> rddInfo.diskSize) } def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ - ("Use ExternalBlockStore" -> storageLevel.useOffHeap) ~ ("Deserialized" -> storageLevel.deserialized) ~ ("Replication" -> storageLevel.replication) } @@ -419,7 +446,6 @@ private[spark] object JsonProtocol { val storageLevel = storageLevelToJson(blockStatus.storageLevel) ("Storage Level" -> storageLevel) ~ ("Memory Size" -> blockStatus.memSize) ~ - ("ExternalBlockStore Size" -> blockStatus.externalBlockStoreSize) ~ ("Disk Size" -> blockStatus.diskSize) } @@ -505,6 +531,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) + .asInstanceOf[SparkListenerEvent] } } @@ -550,7 +578,9 @@ private[spark] object JsonProtocol { // The "Stage Infos" field was added in Spark 1.2.0 val stageInfos = Utils.jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { - stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) + stageIds.map { id => + new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", Seq.empty) + } } SparkListenerJobStart(jobId, submissionTime, stageInfos, properties) } @@ -624,14 +654,15 @@ private[spark] object JsonProtocol { def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = { val execInfo = (json \ "Executor ID").extract[String] - val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json => + val accumUpdates = (json \ "Metrics Updated").extract[List[JValue]].map { json => val taskId = (json \ "Task ID").extract[Long] val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] - val metrics = taskMetricsFromJson(json \ "Task Metrics") - (taskId, stageId, stageAttemptId, metrics) + val updates = + (json \ "Accumulator Updates").extract[List[JValue]].map(accumulableInfoFromJson) + (taskId, stageId, stageAttemptId, updates) } - SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics) + SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) } /** --------------------------------------------------------------------- * @@ -652,12 +683,12 @@ private[spark] object JsonProtocol { val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson(_)) + case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } val stageInfo = new StageInfo( - stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details) + stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details, Seq.empty) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason @@ -680,7 +711,7 @@ private[spark] object JsonProtocol { val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson(_)) + case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } @@ -695,11 +726,44 @@ private[spark] object JsonProtocol { def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = (json \ "Name").extract[String] - val update = Utils.jsonOption(json \ "Update").map(_.extract[String]) - val value = (json \ "Value").extract[String] + val name = (json \ "Name").extractOpt[String] + val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } + val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) - AccumulableInfo(id, name, update, value, internal) + val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) + val metadata = (json \ "Metadata").extractOpt[String] + new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) + } + + /** + * Deserialize the value of an accumulator from JSON. + * + * For accumulators representing internal task metrics, this looks up the relevant + * [[AccumulatorParam]] to deserialize the value accordingly. For all other + * accumulators, this will simply deserialize the value as a string. + * + * The behavior here must match that of [[accumValueToJson]]. Exposed for testing. + */ + private[util] def accumValueFromJson(name: Option[String], value: JValue): Any = { + import AccumulatorParam._ + if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) { + (value, InternalAccumulator.getParam(name.get)) match { + case (JInt(v), IntAccumulatorParam) => v.toInt + case (JInt(v), LongAccumulatorParam) => v.toLong + case (JString(v), StringAccumulatorParam) => v + case (JArray(v), UpdatedBlockStatusesAccumulatorParam) => + v.map { blockJson => + val id = BlockId((blockJson \ "Block ID").extract[String]) + val status = blockStatusFromJson(blockJson \ "Status") + (id, status) + } + case (v, p) => + throw new IllegalArgumentException(s"unexpected combination of accumulator " + + s"value in JSON ($v) and accumulator param (${p.getClass.getName}) in '${name.get}'") + } + } else { + value.extract[String] + } } def taskMetricsFromJson(json: JValue): TaskMetrics = { @@ -707,7 +771,6 @@ private[spark] object JsonProtocol { return TaskMetrics.empty } val metrics = new TaskMetrics - metrics.setHostname((json \ "Host Name").extract[String]) metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long]) metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long]) metrics.setResultSize((json \ "Result Size").extract[Long]) @@ -715,58 +778,54 @@ private[spark] object JsonProtocol { metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long]) metrics.incMemoryBytesSpilled((json \ "Memory Bytes Spilled").extract[Long]) metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long]) - metrics.setShuffleReadMetrics( - Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) - metrics.shuffleWriteMetrics = - Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) - metrics.setInputMetrics( - Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson)) - metrics.outputMetrics = - Utils.jsonOption(json \ "Output Metrics").map(outputMetricsFromJson) - metrics.updatedBlocks = - Utils.jsonOption(json \ "Updated Blocks").map { value => - value.extract[List[JValue]].map { block => - val id = BlockId((block \ "Block ID").extract[String]) - val status = blockStatusFromJson(block \ "Status") - (id, status) - } - } - metrics - } - def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = { - val metrics = new ShuffleReadMetrics - metrics.incRemoteBlocksFetched((json \ "Remote Blocks Fetched").extract[Int]) - metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int]) - metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long]) - metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long]) - metrics.incLocalBytesRead((json \ "Local Bytes Read").extractOpt[Long].getOrElse(0)) - metrics.incRecordsRead((json \ "Total Records Read").extractOpt[Long].getOrElse(0)) - metrics - } + // Shuffle read metrics + Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => + val readMetrics = metrics.registerTempShuffleReadMetrics() + readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) + readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) + readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) + readMetrics.incLocalBytesRead((readJson \ "Local Bytes Read").extractOpt[Long].getOrElse(0L)) + readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) + readMetrics.incRecordsRead((readJson \ "Total Records Read").extractOpt[Long].getOrElse(0L)) + metrics.mergeShuffleReadMetrics() + } - def shuffleWriteMetricsFromJson(json: JValue): ShuffleWriteMetrics = { - val metrics = new ShuffleWriteMetrics - metrics.incShuffleBytesWritten((json \ "Shuffle Bytes Written").extract[Long]) - metrics.incShuffleWriteTime((json \ "Shuffle Write Time").extract[Long]) - metrics.setShuffleRecordsWritten((json \ "Shuffle Records Written") - .extractOpt[Long].getOrElse(0)) - metrics - } + // Shuffle write metrics + // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. + Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => + val writeMetrics = metrics.registerShuffleWriteMetrics() + writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) + writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written") + .extractOpt[Long].getOrElse(0L)) + writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) + } - def inputMetricsFromJson(json: JValue): InputMetrics = { - val metrics = new InputMetrics( - DataReadMethod.withName((json \ "Data Read Method").extract[String])) - metrics.incBytesRead((json \ "Bytes Read").extract[Long]) - metrics.incRecordsRead((json \ "Records Read").extractOpt[Long].getOrElse(0)) - metrics - } + // Output metrics + Utils.jsonOption(json \ "Output Metrics").foreach { outJson => + val writeMethod = DataWriteMethod.withName((outJson \ "Data Write Method").extract[String]) + val outputMetrics = metrics.registerOutputMetrics(writeMethod) + outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) + outputMetrics.setRecordsWritten((outJson \ "Records Written").extractOpt[Long].getOrElse(0L)) + } + + // Input metrics + Utils.jsonOption(json \ "Input Metrics").foreach { inJson => + val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String]) + val inputMetrics = metrics.registerInputMetrics(readMethod) + inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) + inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + } + + // Updated blocks + Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson => + metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson => + val id = BlockId((blockJson \ "Block ID").extract[String]) + val status = blockStatusFromJson(blockJson \ "Status") + (id, status) + }) + } - def outputMetricsFromJson(json: JValue): OutputMetrics = { - val metrics = new OutputMetrics( - DataWriteMethod.withName((json \ "Data Write Method").extract[String])) - metrics.setBytesWritten((json \ "Bytes Written").extract[Long]) - metrics.setRecordsWritten((json \ "Records Written").extractOpt[Long].getOrElse(0)) metrics } @@ -796,10 +855,12 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") - val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). - map(_.extract[String]).orNull - val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) + val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull + // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x + val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulatorUpdates()) + ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `taskCommitDenied` => @@ -851,6 +912,7 @@ private[spark] object JsonProtocol { val scope = Utils.jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) + val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) @@ -858,15 +920,11 @@ private[spark] object JsonProtocol { val numPartitions = (json \ "Number of Partitions").extract[Int] val numCachedPartitions = (json \ "Number of Cached Partitions").extract[Int] val memSize = (json \ "Memory Size").extract[Long] - // fallback to tachyon for backward compatibility - val externalBlockStoreSize = (json \ "ExternalBlockStore Size").toSome - .getOrElse(json \ "Tachyon Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] - val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, scope) + val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, callsite, scope) rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize - rddInfo.externalBlockStoreSize = externalBlockStoreSize rddInfo.diskSize = diskSize rddInfo } @@ -874,22 +932,16 @@ private[spark] object JsonProtocol { def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ "Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] - // fallback to tachyon for backward compatability - val useExternalBlockStore = (json \ "Use ExternalBlockStore").toSome - .getOrElse(json \ "Use Tachyon").extract[Boolean] val deserialized = (json \ "Deserialized").extract[Boolean] val replication = (json \ "Replication").extract[Int] - StorageLevel(useDisk, useMemory, useExternalBlockStore, deserialized, replication) + StorageLevel(useDisk, useMemory, deserialized, replication) } def blockStatusFromJson(json: JValue): BlockStatus = { val storageLevel = storageLevelFromJson(json \ "Storage Level") val memorySize = (json \ "Memory Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] - // fallback to tachyon for backward compatability - val externalBlockStoreSize = (json \ "ExternalBlockStore Size").toSome - .getOrElse(json \ "Tachyon Size").extract[Long] - BlockStatus(storageLevel, memorySize, diskSize, externalBlockStoreSize) + BlockStatus(storageLevel, memorySize, diskSize) } def executorInfoFromJson(json: JValue): ExecutorInfo = { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 13cb516b583e9..436c1951dee2f 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * An event bus which posts events to its listeners. @@ -36,10 +36,18 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { /** * Add a listener to listen events. This method is thread-safe and can be called in any thread. */ - final def addListener(listener: L) { + final def addListener(listener: L): Unit = { listeners.add(listener) } + /** + * Remove a listener and it won't receive any events. This method is thread-safe and can be called + * in any thread. + */ + final def removeListener(listener: L): Unit = { + listeners.remove(listener) + } + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -52,7 +60,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { while (iter.hasNext) { val listener = iter.next() try { - onPostEvent(listener, event) + doPostEvent(listener, event) } catch { case NonFatal(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) @@ -64,7 +72,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * Post an event to the specified listener. `onPostEvent` is guaranteed to be called in the same * thread. */ - def onPostEvent(listener: L, event: E): Unit + protected def doPostEvent(listener: L, event: E): Unit private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala deleted file mode 100644 index a8bbad086849e..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.{Timer, TimerTask} - -import org.apache.spark.{Logging, SparkConf} - -/** - * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) - */ -private[spark] class MetadataCleaner( - cleanerType: MetadataCleanerType.MetadataCleanerType, - cleanupFunc: (Long) => Unit, - conf: SparkConf) - extends Logging -{ - val name = cleanerType.toString - - private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType) - private val periodSeconds = math.max(10, delaySeconds / 10) - private val timer = new Timer(name + " cleanup timer", true) - - - private val task = new TimerTask { - override def run() { - try { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran metadata cleaner for " + name) - } catch { - case e: Exception => logError("Error running cleanup task for " + name, e) - } - } - } - - if (delaySeconds > 0) { - logDebug( - "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + - "and period of " + periodSeconds + " secs") - timer.schedule(task, delaySeconds * 1000, periodSeconds * 1000) - } - - def cancel() { - timer.cancel() - } -} - -private[spark] object MetadataCleanerType extends Enumeration { - - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER, - SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value - - type MetadataCleanerType = Value - - def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = { - "spark.cleaner.ttl." + which.toString - } -} - -// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the -// initialization of StreamingContext. It's okay for users trying to configure stuff themselves. -private[spark] object MetadataCleaner { - def getDelaySeconds(conf: SparkConf): Int = { - conf.getTimeAsSeconds("spark.cleaner.ttl", "-1").toInt - } - - def getDelaySeconds( - conf: SparkConf, - cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { - conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt - } - - def setDelaySeconds( - conf: SparkConf, - cleanerType: MetadataCleanerType.MetadataCleanerType, - delay: Int) { - conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) - } - - /** - * Set the default delay time (in seconds). - * @param conf SparkConf instance - * @param delay default delay time to set - * @param resetAll whether to reset all to default - */ - def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) { - conf.set("spark.cleaner.ttl", delay.toString) - if (resetAll) { - for (cleanerType <- MetadataCleanerType.values) { - System.clearProperty(MetadataCleanerType.systemProperty(cleanerType)) - } - } - } -} - diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 945217203be72..0a3180da87987 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.net.{URLClassLoader, URL} +import java.net.{URL, URLClassLoader} import java.util.Enumeration import java.util.concurrent.ConcurrentHashMap diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 73d126ff6254e..c9b7493fcdc1b 100644 --- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -18,7 +18,7 @@ package org.apache.spark.util /** - * A class loader which makes some protected methods in ClassLoader accesible. + * A class loader which makes some protected methods in ClassLoader accessible. */ private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 7578a3b1d85f2..2bb8de568e803 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,23 +17,21 @@ package org.apache.spark.util -import scala.concurrent.duration.FiniteDuration import scala.language.postfixOps -import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} -object RpcUtils { +private[spark] object RpcUtils { /** * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. */ def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { - val driverActorSystemName = SparkEnv.driverActorSystemName val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") - rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } /** Returns the configured number of times to retry connecting */ @@ -47,22 +45,24 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC ask operations. */ - private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = { + def askRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") } - @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0") - def askTimeout(conf: SparkConf): FiniteDuration = { - askRpcTimeout(conf).duration - } - /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ - private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { + def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") } - @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0") - def lookupTimeout(conf: SparkConf): FiniteDuration = { - lookupRpcTimeout(conf).duration + private val MAX_MESSAGE_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 + + /** Returns the configured max message size for messages in bytes. */ + def maxMessageSizeBytes(conf: SparkConf): Int = { + val maxSizeInMB = conf.getInt("spark.rpc.message.maxSize", 128) + if (maxSizeInMB > MAX_MESSAGE_SIZE_IN_MB) { + throw new IllegalArgumentException( + s"spark.rpc.message.maxSize should not be greater than $MAX_MESSAGE_SIZE_IN_MB MB") + } + maxSizeInMB * 1024 * 1024 } } diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index db4a8b304ec3e..bd26bfd848ff1 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -20,11 +20,11 @@ package org.apache.spark.util import java.io.File import java.util.PriorityQueue -import scala.util.{Failure, Success, Try} -import tachyon.client.TachyonFile +import scala.util.Try import org.apache.hadoop.fs.FileSystem -import org.apache.spark.Logging + +import org.apache.spark.internal.Logging /** * Various utility methods used by Spark. @@ -52,12 +52,13 @@ private[spark] object ShutdownHookManager extends Logging { } private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Add a shutdown hook to delete the temp dirs when the JVM exits addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => + // we need to materialize the paths to delete because deleteRecursively removes items from + // shutdownDeletePaths as we are traversing through it. + shutdownDeletePaths.toArray.foreach { dirPath => try { logInfo("Deleting directory " + dirPath) Utils.deleteRecursively(new File(dirPath)) @@ -75,14 +76,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - // Remove the path to be deleted via shutdown hook def removeShutdownDeleteDir(file: File) { val absolutePath = file.getAbsolutePath() @@ -91,14 +84,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Remove the tachyon path to be deleted via shutdown hook - def removeShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.remove(absolutePath) - } - } - // Is the path already registered to be deleted via a shutdown hook ? def hasShutdownDeleteDir(file: File): Boolean = { val absolutePath = file.getAbsolutePath() @@ -107,14 +92,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - // Note: if file is child of some registered path, while not equal to it, then return true; // else false. This is to ensure that two shutdown hooks do not try to delete each others // paths - resulting in IOException and incomplete cleanup. @@ -131,22 +108,6 @@ private[spark] object ShutdownHookManager extends Logging { retval } - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - /** * Detect whether this thread might be executing a shutdown hook. Will always return true if * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. @@ -160,7 +121,9 @@ private[spark] object ShutdownHookManager extends Logging { val hook = new Thread { override def run() {} } + // scalastyle:off runtimeaddshutdownhook Runtime.getRuntime.addShutdownHook(hook) + // scalastyle:on runtimeaddshutdownhook Runtime.getRuntime.removeShutdownHook(hook) } catch { case ise: IllegalStateException => return true @@ -204,7 +167,7 @@ private[spark] object ShutdownHookManager extends Logging { private [util] class SparkShutdownHookManager { private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false + @volatile private var shuttingDown = false /** * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not @@ -215,43 +178,31 @@ private [util] class SparkShutdownHookManager { val hookTask = new Runnable() { override def run(): Unit = runAll() } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem] - .getField("SHUTDOWN_HOOK_PRIORITY") - .get(null) // static field, the value is not used - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - } + org.apache.hadoop.util.ShutdownHookManager.get().addShutdownHook( + hookTask, FileSystem.SHUTDOWN_HOOK_PRIORITY + 30) } - def runAll(): Unit = synchronized { + def runAll(): Unit = { shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) + var nextHook: SparkShutdownHook = null + while ({ nextHook = hooks.synchronized { hooks.poll() }; nextHook != null }) { + Try(Utils.logUncaughtExceptions(nextHook.run())) } } - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) + def add(priority: Int, hook: () => Unit): AnyRef = { + hooks.synchronized { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } } - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } + def remove(ref: AnyRef): Boolean = { + hooks.synchronized { hooks.remove(ref) } } } diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 23ee4eff0881b..6861a75612dd1 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -17,20 +17,34 @@ package org.apache.spark.util -import com.google.common.collect.MapMaker - import java.lang.management.ManagementFactory import java.lang.reflect.{Field, Modifier} import java.util.{IdentityHashMap, Random} -import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer import scala.runtime.ScalaRunTime -import org.apache.spark.Logging +import com.google.common.collect.MapMaker + import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging import org.apache.spark.util.collection.OpenHashSet +/** + * A trait that allows a class to give [[SizeEstimator]] more accurate size estimation. + * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. + * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size + * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. + * The difference between a [[KnownSizeEstimation]] and + * [[org.apache.spark.util.collection.SizeTracker]] is that, a + * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to + * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without + * using [[SizeEstimator]]. + */ +private[spark] trait KnownSizeEstimation { + def estimatedSize: Long +} + /** * :: DeveloperApi :: * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in @@ -137,13 +151,12 @@ object SizeEstimator extends Logging { // TODO: We could use reflection on the VMOption returned ? getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { - case e: Exception => { + case e: Exception => // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) val guessInWords = if (guess) "yes" else "not" logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) return guess - } } } @@ -199,10 +212,15 @@ object SizeEstimator extends Logging { // the size estimator since it references the whole REPL. Do nothing in this case. In // general all ClassLoaders and Classes will be shared between objects anyway. } else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) + obj match { + case s: KnownSizeEstimation => + state.size += s.estimatedSize + case _ => + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } @@ -234,7 +252,7 @@ object SizeEstimator extends Logging { } else { // Estimate the size of a large array by sampling elements without replacement. // To exclude the shared objects that the array elements may link, sample twice - // and use the min one to caculate array size. + // and use the min one to calculate array size. val rand = new Random(42) val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE) val s1 = sampleArray(array, state, rand, drawn, length) diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index 7248187247330..95bf3f58bc77f 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * The default uncaught exception handler for Executors terminates the whole process, to avoid @@ -29,7 +29,11 @@ private[spark] object SparkUncaughtExceptionHandler override def uncaughtException(thread: Thread, exception: Throwable) { try { - logError("Uncaught exception in thread " + thread, exception) + // Make it explicit that uncaught exceptions are thrown when container is shutting down. + // It will help users when they analyze the executor logs + val inShutdownMsg = if (ShutdownHookManager.inShutdown()) "[Container in shutdown] " else "" + val errMsg = "Uncaught exception in thread " + logError(inShutdownMsg + errMsg + thread, exception) // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala deleted file mode 100644 index c1b8bf052c0ca..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.EventListener - -import org.apache.spark.TaskContext -import org.apache.spark.annotation.DeveloperApi - -/** - * :: DeveloperApi :: - * - * Listener providing a callback function to invoke when a task's execution completes. - */ -@DeveloperApi -trait TaskCompletionListener extends EventListener { - def onTaskCompletion(context: TaskContext) -} diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala deleted file mode 100644 index f64e069cd1724..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -/** - * Exception thrown when there is an exception in - * executing the callback in TaskCompletionListener. - */ -private[spark] -class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception { - - override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } - } -} diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 53283448c87b1..9abbf4a7a3971 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.util import java.util.concurrent._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -56,10 +57,18 @@ private[spark] object ThreadUtils { * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer. */ - def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = { + def newDaemonCachedThreadPool( + prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = { val threadFactory = namedThreadFactory(prefix) - new ThreadPoolExecutor( - 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory) + val threadPool = new ThreadPoolExecutor( + maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks + maxThreadNumber, // maximumPoolSize: because we use LinkedBlockingDeque, this one is not used + keepAliveSeconds, + TimeUnit.SECONDS, + new LinkedBlockingQueue[Runnable], + threadFactory) + threadPool.allowCoreThreadTimeOut(true) + threadPool } /** @@ -148,4 +157,21 @@ private[spark] object ThreadUtils { result } } + + /** + * Construct a new Scala ForkJoinPool with a specified max parallelism and name prefix. + */ + def newForkJoinPool(prefix: String, maxThreadNumber: Int): SForkJoinPool = { + // Custom factory to set thread names + val factory = new SForkJoinPool.ForkJoinWorkerThreadFactory { + override def newThread(pool: SForkJoinPool) = + new SForkJoinWorkerThread(pool) { + setName(prefix + "-" + super.getName) + } + } + new SForkJoinPool(maxThreadNumber, factory, + null, // handler + false // asyncMode + ) + } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index d7e5143c30953..32af0127bbf38 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -17,14 +17,14 @@ package org.apache.spark.util -import java.util.Set import java.util.Map.Entry +import java.util.Set import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.spark.Logging +import org.apache.spark.internal.Logging private[spark] case class TimeStampedValue[V](value: V, timestamp: Long) diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala deleted file mode 100644 index 65efeb1f4c19c..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ -import scala.collection.mutable.Set - -private[spark] class TimeStampedHashSet[A] extends Set[A] { - val internalMap = new ConcurrentHashMap[A, Long]() - - def contains(key: A): Boolean = { - internalMap.contains(key) - } - - def iterator: Iterator[A] = { - val jIterator = internalMap.entrySet().iterator() - jIterator.asScala.map(_.getKey) - } - - override def + (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet += elem - newSet - } - - override def - (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet -= elem - newSet - } - - override def += (key: A): this.type = { - internalMap.put(key, currentTime) - this - } - - override def -= (key: A): this.type = { - internalMap.remove(key) - this - } - - override def empty: Set[A] = new TimeStampedHashSet[A]() - - override def size(): Int = internalMap.size() - - override def foreach[U](f: (A) => U): Unit = { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - f(iterator.next.getKey) - } - } - - /** - * Removes old values that have timestamp earlier than `threshTime` - */ - def clearOldValues(threshTime: Long) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue < threshTime) { - iterator.remove() - } - } - } - - private def currentTime: Long = System.currentTimeMillis() -} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala deleted file mode 100644 index 310c0c109416c..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.lang.ref.WeakReference -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable -import scala.language.implicitConversions - -import org.apache.spark.Logging - -/** - * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. - * - * If the value is garbage collected and the weak reference is null, get() will return a - * non-existent value. These entries are removed from the map periodically (every N inserts), as - * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are - * older than a particular threshold can be removed using the clearOldValues method. - * - * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it - * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, - * so all operations on this HashMap are thread-safe. - * - * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. - */ -private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends mutable.Map[A, B]() with Logging { - - import TimeStampedWeakValueHashMap._ - - private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) - private val insertCount = new AtomicInteger(0) - - /** Return a map consisting only of entries whose values are still strongly reachable. */ - private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } - - def get(key: A): Option[B] = internalMap.get(key) - - def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator - - override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { - val newMap = new TimeStampedWeakValueHashMap[A, B1] - val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] - newMap.internalMap.putAll(oldMap.toMap) - newMap.internalMap += kv - newMap - } - - override def - (key: A): mutable.Map[A, B] = { - val newMap = new TimeStampedWeakValueHashMap[A, B] - newMap.internalMap.putAll(nonNullReferenceMap.toMap) - newMap.internalMap -= key - newMap - } - - override def += (kv: (A, B)): this.type = { - internalMap += kv - if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { - clearNullValues() - } - this - } - - override def -= (key: A): this.type = { - internalMap -= key - this - } - - override def update(key: A, value: B): Unit = this += ((key, value)) - - override def apply(key: A): B = internalMap.apply(key) - - override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) - - override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() - - override def size: Int = internalMap.size - - override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f) - - def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) - - def toMap: Map[A, B] = iterator.toMap - - /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ - def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime) - - /** Remove entries with values that are no longer strongly reachable. */ - def clearNullValues() { - val it = internalMap.getEntrySet.iterator - while (it.hasNext) { - val entry = it.next() - if (entry.getValue.value.get == null) { - logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") - it.remove() - } - } - } - - // For testing - - def getTimestamp(key: A): Option[Long] = { - internalMap.getTimeStampedValue(key).map(_.timestamp) - } - - def getReference(key: A): Option[WeakReference[B]] = { - internalMap.getTimeStampedValue(key).map(_.value) - } -} - -/** - * Helper methods for converting to and from WeakReferences. - */ -private object TimeStampedWeakValueHashMap { - - // Number of inserts after which entries with null references are removed - val CLEAR_NULL_VALUES_INTERVAL = 100 - - /* Implicit conversion methods to WeakReferences. */ - - implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) - - implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = { - kv match { case (k, v) => (k, toWeakReference(v)) } - } - - implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = { - (kv: (K, WeakReference[V])) => p(kv) - } - - /* Implicit conversion methods from WeakReferences. */ - - implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get - - implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { - v match { - case Some(ref) => Option(fromWeakReference(ref)) - case None => None - } - } - - implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { - kv match { case (k, v) => (k, fromWeakReference(v)) } - } - - implicit def fromWeakReferenceIterator[K, V]( - it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = { - it.map(fromWeakReferenceTuple) - } - - implicit def fromWeakReferenceMap[K, V]( - map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { - mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) - } -} diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala new file mode 100644 index 0000000000000..4dcf95177aa78 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import javax.annotation.concurrent.GuardedBy + +/** + * A special Thread that provides "runUninterruptibly" to allow running codes without being + * interrupted by `Thread.interrupt()`. If `Thread.interrupt()` is called during runUninterruptibly + * is running, it won't set the interrupted status. Instead, setting the interrupted status will be + * deferred until it's returning from "runUninterruptibly". + * + * Note: "runUninterruptibly" should be called only in `this` thread. + */ +private[spark] class UninterruptibleThread(name: String) extends Thread(name) { + + /** A monitor to protect "uninterruptible" and "interrupted" */ + private val uninterruptibleLock = new Object + + /** + * Indicates if `this` thread are in the uninterruptible status. If so, interrupting + * "this" will be deferred until `this` enters into the interruptible status. + */ + @GuardedBy("uninterruptibleLock") + private var uninterruptible = false + + /** + * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone. + */ + @GuardedBy("uninterruptibleLock") + private var shouldInterruptThread = false + + /** + * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning + * from `f`. + * + * If this method finds that `interrupt` is called before calling `f` and it's not inside another + * `runUninterruptibly`, it will throw `InterruptedException`. + * + * Note: this method should be called only in `this` thread. + */ + def runUninterruptibly[T](f: => T): T = { + if (Thread.currentThread() != this) { + throw new IllegalStateException(s"Call runUninterruptibly in a wrong thread. " + + s"Expected: $this but was ${Thread.currentThread()}") + } + + if (uninterruptibleLock.synchronized { uninterruptible }) { + // We are already in the uninterruptible status. So just run "f" and return + return f + } + + uninterruptibleLock.synchronized { + // Clear the interrupted status if it's set. + if (Thread.interrupted() || shouldInterruptThread) { + shouldInterruptThread = false + // Since it's interrupted, we don't need to run `f` which may be a long computation. + // Throw InterruptedException as we don't have a T to return. + throw new InterruptedException() + } + uninterruptible = true + } + try { + f + } finally { + uninterruptibleLock.synchronized { + uninterruptible = false + if (shouldInterruptThread) { + // Recover the interrupted status + super.interrupt() + shouldInterruptThread = false + } + } + } + } + + /** + * Tests whether `interrupt()` has been called. + */ + override def isInterrupted: Boolean = { + super.isInterrupted || uninterruptibleLock.synchronized { shouldInterruptThread } + } + + /** + * Interrupt `this` thread if possible. If `this` is in the uninterruptible status, it won't be + * interrupted until it enters into the interruptible status. + */ + override def interrupt(): Unit = { + uninterruptibleLock.synchronized { + if (uninterruptible) { + shouldInterruptThread = true + } else { + super.interrupt() + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5a976ee839b1e..78e164cff7738 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,19 +21,23 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{Properties, Locale, Random, UUID} +import java.nio.channels.Channels +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection +import scala.annotation.tailrec import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.{Failure, Success, Try} +import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} -import com.google.common.io.{ByteStreams, Files} +import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration @@ -42,12 +46,11 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ - -import tachyon.TachyonURI -import tachyon.client.{TachyonFS, TachyonFile} +import org.slf4j.Logger import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} @@ -57,6 +60,7 @@ private[spark] case class CallSite(shortForm: String, longForm: String) private[spark] object CallSite { val SHORT_FORM = "callSite.short" val LONG_FORM = "callSite.long" + val empty = CallSite("", "") } /** @@ -177,7 +181,20 @@ private[spark] object Utils extends Logging { /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { + def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + + /** + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]] + */ + def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { @@ -239,10 +256,11 @@ private[spark] object Utils extends Logging { dir } - /** Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream - * copying is disabled by default unless explicitly set transferToEnabled as true, - * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. - */ + /** + * Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream + * copying is disabled by default unless explicitly set transferToEnabled as true, + * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. + */ def copyStream(in: InputStream, out: OutputStream, closeStreams: Boolean = false, @@ -317,6 +335,30 @@ private[spark] object Utils extends Logging { } /** + * A file name may contain some invalid URI characters, such as " ". This method will convert the + * file name to a raw path accepted by `java.net.URI(String)`. + * + * Note: the file name must not contain "/" or "\" + */ + def encodeFileNameToURIRawPath(fileName: String): String = { + require(!fileName.contains("/") && !fileName.contains("\\")) + // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as + // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. + // We should remove it after we get the raw path. + new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1) + } + + /** + * Get the file name from uri's raw path and decode it. If the raw path of uri ends with "/", + * return the name before the last "/". + */ + def decodeFileNameInURI(uri: URI): String = { + val rawPath = uri.getRawPath + val rawFileName = rawPath.split("/").last + new URI("file:///" + rawFileName).getPath.substring(1) + } + + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -337,7 +379,7 @@ private[spark] object Utils extends Logging { hadoopConf: Configuration, timestamp: Long, useCache: Boolean) { - val fileName = url.split("/").last + val fileName = decodeFileNameInURI(new URI(url)) val targetFile = new File(targetDir, fileName) val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) if (useCache && fetchCacheEnabled) { @@ -479,7 +521,7 @@ private[spark] object Utils extends Logging { // The file does not exist in the target directory. Copy or move it there. if (removeSourceFile) { - Files.move(sourceFile, destFile) + Files.move(sourceFile.toPath, destFile.toPath) } else { logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}") copyRecursive(sourceFile, destFile) @@ -497,7 +539,7 @@ private[spark] object Utils extends Logging { case (f1, f2) => filesEqualRecursive(f1, f2) } } else if (file1.isFile && file2.isFile) { - Files.equal(file1, file2) + GFiles.equal(file1, file2) } else { false } @@ -511,7 +553,7 @@ private[spark] object Utils extends Logging { val subfiles = source.listFiles() subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName))) } else { - Files.copy(source, dest) + Files.copy(source.toPath, dest.toPath) } } @@ -535,6 +577,14 @@ private[spark] object Utils extends Logging { val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { + case "spark" => + if (SparkEnv.get == null) { + throw new IllegalStateException( + "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") + } + val source = SparkEnv.get.rpcEnv.openChannel(url) + val is = Channels.newInputStream(source) + downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { @@ -617,9 +667,7 @@ private[spark] object Utils extends Logging { private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { // These environment variables are set by YARN. - // For Hadoop 0.23.X, we check for YARN_LOCAL_DIRS (we use this below in getYarnLocalDirs()) - // For Hadoop 2.X, we check for CONTAINER_ID. - conf.getenv("CONTAINER_ID") != null || conf.getenv("YARN_LOCAL_DIRS") != null + conf.getenv("CONTAINER_ID") != null } /** @@ -695,17 +743,12 @@ private[spark] object Utils extends Logging { logError(s"Failed to create local root dir in $root. Ignoring this directory.") None } - }.toArray + } } /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(conf.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(conf.getenv("LOCAL_DIRS")) - .getOrElse("")) + val localDirs = Option(conf.getenv("LOCAL_DIRS")).getOrElse("") if (localDirs.isEmpty) { throw new Exception("Yarn Local dirs can't be empty") @@ -899,15 +942,6 @@ private[spark] object Utils extends Logging { } } - /** - * Delete a file or directory and its contents recursively. - */ - def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(new TachyonURI(dir.getPath()), true)) { - throw new IOException("Failed to delete the tachyon dir: " + dir) - } - } - /** * Check to see if file is a symbolic link. */ @@ -1087,9 +1121,9 @@ private[spark] object Utils extends Logging { extraEnvironment: Map[String, String] = Map.empty, redirectStderr: Boolean = true): String = { val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr) - val output = new StringBuffer + val output = new StringBuilder val threadName = "read stdout for " + command(0) - def appendToOutput(s: String): Unit = output.append(s) + def appendToOutput(s: String): Unit = output.append(s).append("\n") val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput) val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output @@ -1160,21 +1194,6 @@ private[spark] object Utils extends Logging { } } - /** - * Execute a block of code that evaluates to Unit, re-throwing any non-fatal uncaught - * exceptions as IOException. This is used when implementing Externalizable and Serializable's - * read and write methods, since Java's serializer will not report non-IOExceptions properly; - * see SPARK-4080 for more context. - */ - def tryOrIOException(block: => Unit) { - try { - block - } catch { - case e: IOException => throw e - case NonFatal(t) => throw new IOException(t) - } - } - /** * Execute a block of code that returns a value, re-throwing any non-fatal uncaught * exceptions as IOException. This is used when implementing Externalizable and Serializable's @@ -1185,8 +1204,12 @@ private[spark] object Utils extends Logging { try { block } catch { - case e: IOException => throw e - case NonFatal(t) => throw new IOException(t) + case e: IOException => + logError("Exception encountered", e) + throw e + case NonFatal(e) => + logError("Exception encountered", e) + throw new IOException(e) } } @@ -1211,7 +1234,6 @@ private[spark] object Utils extends Logging { * exception from the original `out.write` call. */ def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { - // It would be nice to find a method on Try that did this var originalThrowable: Throwable = null try { block @@ -1237,6 +1259,53 @@ private[spark] object Utils extends Logging { } } + /** + * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur + * in either the catch or the finally block, they are appended to the list of suppressed + * exceptions in original exception which is then rethrown. + * + * This is primarily an issue with `catch { abort() }` or `finally { out.close() }` blocks, + * where the abort/close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `abort` or `out.close` will + * fail as well. This would then suppress the original/likely more meaningful + * exception from the original `out.write` call. + */ + def tryWithSafeFinallyAndFailureCallbacks[T](block: => T) + (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = { + var originalThrowable: Throwable = null + try { + block + } catch { + case cause: Throwable => + // Purposefully not using NonFatal, because even fatal exceptions + // we don't want to have our finallyBlock suppress + originalThrowable = cause + try { + logError("Aborting task", originalThrowable) + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable) + catchBlock + } catch { + case t: Throwable => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in catch: " + t.getMessage, t) + } + throw originalThrowable + } finally { + try { + finallyBlock + } catch { + case t: Throwable => + if (originalThrowable != null) { + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: " + t.getMessage, t) + throw originalThrowable + } else { + throw t + } + } + } + } + /** Default filtering function for finding call sites using `getCallSite`. */ private def sparkInternalExclusionFunction(className: String): Boolean = { // A regular expression to match classes of the internal Spark API's @@ -1461,7 +1530,7 @@ private[spark] object Utils extends Logging { rawMod + (if (rawMod < 0) mod else 0) } - // Handles idiosyncracies with hash (add more as required) + // Handles idiosyncrasies with hash (add more as required) // This method should be kept in sync with // org.apache.spark.network.util.JavaUtils#nonNegativeHash(). def nonNegativeHash(obj: AnyRef): Int = { @@ -1505,9 +1574,11 @@ private[spark] object Utils extends Logging { else -1 } - /** Returns the system properties map that is thread-safe to iterator over. It gets the - * properties which have been set explicitly, as well as those for which only a default value - * has been defined. */ + /** + * Returns the system properties map that is thread-safe to iterator over. It gets the + * properties which have been set explicitly, as well as those for which only a default value + * has been defined. + */ def getSystemProperties: Map[String, String] = { System.getProperties.stringPropertyNames().asScala .map(key => (key, System.getProperty(key))).toMap @@ -1531,7 +1602,7 @@ private[spark] object Utils extends Logging { * @param f function to be executed. If prepare is not None, the running time of each call to f * must be an order of magnitude longer than one millisecond for accurate timing. * @param prepare function to be executed before each call to f. Its running time doesn't count. - * @return the total time across all iterations (not couting preparation time) + * @return the total time across all iterations (not counting preparation time) */ def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = { if (prepare.isEmpty) { @@ -1567,30 +1638,18 @@ private[spark] object Utils extends Logging { } /** - * Creates a symlink. Note jdk1.7 has Files.createSymbolicLink but not used here - * for jdk1.6 support. Supports windows by doing copy, everything else uses "ln -sf". + * Creates a symlink. * @param src absolute path to the source * @param dst relative path for the destination */ - def symlink(src: File, dst: File) { + def symlink(src: File, dst: File): Unit = { if (!src.isAbsolute()) { throw new IOException("Source must be absolute") } if (dst.isAbsolute()) { throw new IOException("Destination must be relative") } - var cmdSuffix = "" - val linkCmd = if (isWindows) { - // refer to http://technet.microsoft.com/en-us/library/cc771254.aspx - cmdSuffix = " /s /e /k /h /y /i" - "cmd /c xcopy " - } else { - cmdSuffix = "" - "ln -sf " - } - import scala.sys.process._ - (linkCmd + src.getAbsolutePath() + " " + dst.getPath() + cmdSuffix) lines_! - ProcessLogger(line => logInfo(line)) + Files.createSymbolicLink(dst.toPath, src.toPath) } @@ -1662,6 +1721,30 @@ private[spark] object Utils extends Logging { new File(path).getName } + /** + * Terminates a process waiting for at most the specified duration. Returns whether + * the process terminated. + */ + def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = { + try { + // Java8 added a new API which will more forcibly kill the process. Use that if available. + val destroyMethod = process.getClass().getMethod("destroyForcibly"); + destroyMethod.setAccessible(true) + destroyMethod.invoke(process) + } catch { + case NonFatal(e) => + if (!e.isInstanceOf[NoSuchMethodException]) { + logWarning("Exception when attempting to kill process", e) + } + process.destroy() + } + if (waitForProcess(process, timeoutMs)) { + Option(process.exitValue()) + } else { + None + } + } + /** * Wait for a process to terminate for at most the specified duration. * Return whether the process actually terminated after the given timeout. @@ -1739,15 +1822,6 @@ private[spark] object Utils extends Logging { } } - lazy val isInInterpreter: Boolean = { - try { - val interpClass = classForName("org.apache.spark.repl.Main") - interpClass.getMethod("interp").invoke(null) != null - } catch { - case _: ClassNotFoundException => false - } - } - /** * Return a well-formed URI for the file described by a user input string. * @@ -1824,7 +1898,7 @@ private[spark] object Utils extends Logging { require(file.exists(), s"Properties file $file does not exist") require(file.isFile(), s"Properties file $file is not a normal file") - val inReader = new InputStreamReader(new FileInputStream(file), "UTF-8") + val inReader = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8) try { val properties = new Properties() properties.load(inReader) @@ -1934,8 +2008,10 @@ private[spark] object Utils extends Logging { } catch { case e: Exception if isBindCollision(e) => if (offset >= maxRetries) { - val exceptionMessage = - s"${e.getMessage}: Service$serviceString failed after $maxRetries retries!" + val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " + + s"$maxRetries retries! Consider explicitly setting the appropriate port for the " + + s"service$serviceString (for example spark.ui.port for SparkUI) to an available " + + "port or increasing spark.port.maxRetries." val exception = new BindException(exceptionMessage) // restore original stack trace exception.setStackTrace(e.getStackTrace) @@ -2140,6 +2216,7 @@ private[spark] object Utils extends Logging { /** * Return whether the specified file is a parent directory of the child file. */ + @tailrec def isInDirectory(parent: File, child: File): Boolean = { if (child == null || parent == null) { return false @@ -2153,6 +2230,16 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } + + /** + * + * @return whether it is local mode + */ + def isLocalMaster(conf: SparkConf): Boolean = { + val master = conf.get("spark.master", "") + master == "local" || master.startsWith("local[") + } + /** * Return whether dynamic allocation is enabled in the given conf * Dynamic allocation and explicitly setting the number of executors are inherently @@ -2160,14 +2247,43 @@ private[spark] object Utils extends Logging { * the latter should override the former (SPARK-9092). */ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - conf.getBoolean("spark.dynamicAllocation.enabled", false) && - conf.getInt("spark.executor.instances", 0) == 0 + val numExecutor = conf.getInt("spark.executor.instances", 0) + val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) + if (numExecutor != 0 && dynamicAllocationEnabled) { + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + numExecutor == 0 && dynamicAllocationEnabled && + (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false)) } def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { val resource = createResource try f.apply(resource) finally resource.close() } + + /** + * Returns a path of temporary file which is in the same directory with `path`. + */ + def tempFileWith(path: File): File = { + new File(path.getAbsolutePath + "." + UUID.randomUUID()) + } + + /** + * Returns the name of this JVM process. This is OS dependent but typically (OSX, Linux, Windows), + * this is formatted as PID@hostname. + */ + def getProcessName(): String = { + ManagementFactory.getRuntimeMXBean().getName() + } + + /** + * Utility function that should be called early in `main()` for daemons to set up some common + * diagnostic state. + */ + def initDaemon(log: Logger): Unit = { + log.info(s"Started daemon with process name: ${Utils.getProcessName()}") + SignalLogger.register(log) + } } /** @@ -2222,7 +2338,7 @@ private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.Ou def read(): Int = if (iterator.hasNext) iterator.next() else -1 } - val reader = new BufferedReader(new InputStreamReader(input)) + val reader = new BufferedReader(new InputStreamReader(input, StandardCharsets.UTF_8)) val stringBuilder = new StringBuilder var line = reader.readLine() while (line != null) { diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala deleted file mode 100644 index 2ed827eab46df..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.language.implicitConversions -import scala.util.Random - -import org.apache.spark.util.random.XORShiftRandom - -@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") -class Vector(val elements: Array[Double]) extends Serializable { - def length: Int = elements.length - - def apply(index: Int): Double = elements(index) - - def + (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - Vector(length, i => this(i) + other(i)) - } - - def add(other: Vector): Vector = this + other - - def - (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - Vector(length, i => this(i) - other(i)) - } - - def subtract(other: Vector): Vector = this - other - - def dot(other: Vector): Double = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var ans = 0.0 - var i = 0 - while (i < length) { - ans += this(i) * other(i) - i += 1 - } - ans - } - - /** - * return (this + plus) dot other, but without creating any intermediate storage - * @param plus - * @param other - * @return - */ - def plusDot(plus: Vector, other: Vector): Double = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - if (length != plus.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) + plus(i)) * other(i) - i += 1 - } - ans - } - - def += (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var i = 0 - while (i < length) { - elements(i) += other(i) - i += 1 - } - this - } - - def addInPlace(other: Vector): Vector = this +=other - - def * (scale: Double): Vector = Vector(length, i => this(i) * scale) - - def multiply (d: Double): Vector = this * d - - def / (d: Double): Vector = this * (1 / d) - - def divide (d: Double): Vector = this / d - - def unary_- : Vector = this * -1 - - def sum: Double = elements.reduceLeft(_ + _) - - def squaredDist(other: Vector): Double = { - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) - other(i)) * (this(i) - other(i)) - i += 1 - } - ans - } - - def dist(other: Vector): Double = math.sqrt(squaredDist(other)) - - override def toString: String = elements.mkString("(", ", ", ")") -} - -object Vector { - def apply(elements: Array[Double]): Vector = new Vector(elements) - - def apply(elements: Double*): Vector = new Vector(elements.toArray) - - def apply(length: Int, initializer: Int => Double): Vector = { - val elements: Array[Double] = Array.tabulate(length)(initializer) - new Vector(elements) - } - - def zeros(length: Int): Vector = new Vector(new Array[Double](length)) - - def ones(length: Int): Vector = Vector(length, _ => 1) - - /** - * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers - * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided. - */ - def random(length: Int, random: Random = new XORShiftRandom()): Vector = - Vector(length, _ => random.nextDouble()) - - class Multiplier(num: Double) { - def * (vec: Vector): Vector = vec * num - } - - implicit def doubleToMultiplier(num: Double): Multiplier = new Multiplier(num) - - implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] { - def addInPlace(t1: Vector, t2: Vector): Vector = t1 + t2 - - def zero(initialValue: Vector): Vector = Vector.zeros(initialValue.length) - } - -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 4c1e16155462e..6b74a29aceda9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.util.{Arrays, Comparator} +import java.util.Comparator import com.google.common.hash.Hashing diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 85c5bdbfcebc0..7ab67fc3a2de9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -17,21 +17,14 @@ package org.apache.spark.util.collection -import java.io.{Externalizable, ObjectInput, ObjectOutput} - -import org.apache.spark.util.{Utils => UUtils} - - /** * A simple, fixed-size bit set implementation. This implementation is fast because it avoids * safety/bound checking. */ -class BitSet(private[this] var numBits: Int) extends Externalizable { +class BitSet(numBits: Int) extends Serializable { - private var words = new Array[Long](bit2words(numBits)) - private def numWords = words.length - - def this() = this(0) + private val words = new Array[Long](bit2words(numBits)) + private val numWords = words.length /** * Compute the capacity (number of bits) that can be represented @@ -237,19 +230,4 @@ class BitSet(private[this] var numBits: Int) extends Externalizable { /** Return the number of longs it would take to hold numBits. */ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 - - override def writeExternal(out: ObjectOutput): Unit = UUtils.tryOrIOException { - out.writeInt(numBits) - words.foreach(out.writeLong(_)) - } - - override def readExternal(in: ObjectInput): Unit = UUtils.tryOrIOException { - numBits = in.readInt() - words = new Array[Long](bit2words(numBits)) - var index = 0 - while (index < words.length) { - words(index) = in.readLong() - index += 1 - } - } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index f6d81ee5bf05e..95351e98261d7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -26,14 +26,15 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams -import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.serializer.{DeserializationStream, Serializer} +import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerManager} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator -import org.apache.spark.executor.ShuffleWriteMetrics /** * :: DeveloperApi :: @@ -58,7 +59,8 @@ class ExternalAppendOnlyMap[K, V, C]( mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager, - context: TaskContext = TaskContext.get()) + context: TaskContext = TaskContext.get(), + serializerManager: SerializerManager = SparkEnv.get.serializerManager) extends Iterable[(K, C)] with Serializable with Logging @@ -193,8 +195,8 @@ class ExternalAppendOnlyMap[K, V, C]( val w = writer writer = null w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten - batchSizes.append(curWriteMetrics.shuffleBytesWritten) + _diskBytesSpilled += curWriteMetrics.bytesWritten + batchSizes.append(curWriteMetrics.bytesWritten) objectsWritten = 0 } @@ -457,7 +459,7 @@ class ExternalAppendOnlyMap[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) + val compressedStream = serializerManager.wrapForCompression(blockId, bufferedStream) ser.deserializeStream(compressedStream) } else { // No more batches left diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index a44e72b7c16d3..561ba22df557f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -20,16 +20,16 @@ package org.apache.spark.util.collection import java.io._ import java.util.Comparator -import scala.collection.mutable.ArrayBuffer import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -68,31 +68,31 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * * At a high level, this class works internally as follows: * - * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if - * we want to combine by key, or a PartitionedPairBuffer if we don't. - * Inside these buffers, we sort elements by partition ID and then possibly also by key. - * To avoid calling the partitioner multiple times with each key, we store the partition ID - * alongside each record. + * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. + * To avoid calling the partitioner multiple times with each key, we store the partition ID + * alongside each record. * - * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first - * by partition ID and possibly second by key or by hash code of the key, if we want to do - * aggregation. For each file, we track how many objects were in each partition in memory, so we - * don't have to write out the partition ID for every element. + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first + * by partition ID and possibly second by key or by hash code of the key, if we want to do + * aggregation. For each file, we track how many objects were in each partition in memory, so we + * don't have to write out the partition ID for every element. * - * - When the user requests an iterator or file output, the spilled files are merged, along with - * any remaining in-memory data, using the same sort order defined above (unless both sorting - * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering - * from the ordering parameter, or read the keys with the same hash code and compare them with - * each other for equality to merge values. + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. * - * - Users are expected to call stop() at the end to delete all the intermediate files. + * - Users are expected to call stop() at the end to delete all the intermediate files. */ private[spark] class ExternalSorter[K, V, C]( context: TaskContext, aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, - serializer: Option[Serializer] = None) + serializer: Serializer = SparkEnv.get.serializer) extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] { @@ -108,8 +108,8 @@ private[spark] class ExternalSorter[K, V, C]( private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager - private val ser = Serializer.getSerializer(serializer) - private val serInstance = ser.newInstance() + private val serializerManager = SparkEnv.get.serializerManager + private val serInstance = serializer.newInstance() // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 @@ -263,8 +263,8 @@ private[spark] class ExternalSorter[K, V, C]( val w = writer writer = null w.commitAndClose() - _diskBytesSpilled += spillMetrics.shuffleBytesWritten - batchSizes.append(spillMetrics.shuffleBytesWritten) + _diskBytesSpilled += spillMetrics.bytesWritten + batchSizes.append(spillMetrics.bytesWritten) spillMetrics = null objectsWritten = 0 } @@ -504,7 +504,7 @@ private[spark] class ExternalSorter[K, V, C]( ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream) + val compressedStream = serializerManager.wrapForCompression(spill.blockId, bufferedStream) serInstance.deserializeStream(compressedStream) } else { // No more batches left @@ -608,8 +608,8 @@ private[spark] class ExternalSorter[K, V, C]( * * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. + * Exposed for testing. */ - @VisibleForTesting def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer @@ -639,13 +639,14 @@ private[spark] class ExternalSorter[K, V, C]( * called by the SortShuffleWriter. * * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ def writePartitionedFile( blockId: BlockId, outputFile: File): Array[Long] = { + val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics() + // Track location of each range in the output file val lengths = new Array[Long](numPartitions) @@ -654,8 +655,8 @@ private[spark] class ExternalSorter[K, V, C]( val collection = if (aggregator.isDefined) map else buffer val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, - context.taskMetrics.shuffleWriteMetrics.get) + val writer = blockManager.getDiskWriter( + blockId, outputFile, serInstance, fileBufferSize, writeMetrics) val partitionId = it.nextPartition() while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(writer) @@ -668,8 +669,8 @@ private[spark] class ExternalSorter[K, V, C]( // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, - context.taskMetrics.shuffleWriteMetrics.get) + val writer = blockManager.getDiskWriter( + blockId, outputFile, serInstance, fileBufferSize, writeMetrics) for (elem <- elements) { writer.write(elem._1, elem._2) } @@ -682,8 +683,7 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes) + context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) lengths } diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index efc2482c74ddf..22d7a4988bb56 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -19,17 +19,13 @@ package org.apache.spark.util.collection import scala.reflect.ClassTag -import org.apache.spark.annotation.DeveloperApi - /** - * :: DeveloperApi :: * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less * space overhead. * * Under the hood, it uses our OpenHashSet implementation. */ -@DeveloperApi private[spark] class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( initialCapacity: Int) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 60bf4dd7469f1..0f6a425e3db9a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.collection import scala.reflect._ + import com.google.common.hash.Hashing import org.apache.spark.annotation.Private diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 9e002621a6909..25ca2037bbac6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,8 +17,9 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} /** * Spills contents of an in-memory collection to disk when the memory threshold @@ -78,7 +79,8 @@ private[spark] trait Spillable[C] extends Logging { if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null) + val granted = + taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -107,7 +109,8 @@ private[spark] trait Spillable[C] extends Logging { */ def releaseMemory(): Unit = { // The amount we requested does not include the initial memory tracking threshold - taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null) + taskMemoryManager.releaseExecutionMemory( + myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, null) myMemoryThreshold = initialMemoryThreshold } diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 38848e9018c6c..5232c2bd8d6f6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -23,9 +23,10 @@ import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that - * - Have an associated partition for each key-value pair. - * - Support a memory-efficient sorted iterator - * - Support a WritablePartitionedIterator for writing the contents directly as bytes. + * + * - Have an associated partition for each key-value pair. + * - Support a memory-efficient sorted iterator + * - Support a WritablePartitionedIterator for writing the contents directly as bytes. */ private[spark] trait WritablePartitionedPairCollection[K, V] { /** diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala deleted file mode 100644 index daac6f971eb20..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.io - -import java.io.OutputStream - -import scala.collection.mutable.ArrayBuffer - - -/** - * An OutputStream that writes to fixed-size chunks of byte arrays. - * - * @param chunkSize size of each chunk, in bytes. - */ -private[spark] -class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream { - - private val chunks = new ArrayBuffer[Array[Byte]] - - /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ - private var lastChunkIndex = -1 - - /** - * Next position to write in the last chunk. - * - * If this equals chunkSize, it means for next write we need to allocate a new chunk. - * This can also never be 0. - */ - private var position = chunkSize - - override def write(b: Int): Unit = { - allocateNewChunkIfNeeded() - chunks(lastChunkIndex)(position) = b.toByte - position += 1 - } - - override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { - var written = 0 - while (written < len) { - allocateNewChunkIfNeeded() - val thisBatch = math.min(chunkSize - position, len - written) - System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch) - written += thisBatch - position += thisBatch - } - } - - @inline - private def allocateNewChunkIfNeeded(): Unit = { - if (position == chunkSize) { - chunks += new Array[Byte](chunkSize) - lastChunkIndex += 1 - position = 0 - } - } - - def toArrays: Array[Array[Byte]] = { - if (lastChunkIndex == -1) { - new Array[Array[Byte]](0) - } else { - // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. - // An alternative would have been returning an array of ByteBuffers, with the last buffer - // bounded to only the last chunk's position. However, given our use case in Spark (to put - // the chunks in block manager), only limiting the view bound of the buffer would still - // require the block manager to store the whole chunk. - val ret = new Array[Array[Byte]](chunks.size) - for (i <- 0 until chunks.size - 1) { - ret(i) = chunks(i) - } - if (position == chunkSize) { - ret(lastChunkIndex) = chunks(lastChunkIndex) - } else { - ret(lastChunkIndex) = new Array[Byte](position) - System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position) - } - ret - } - } -} diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala new file mode 100644 index 0000000000000..fb4706e78d38f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import java.io.InputStream +import java.nio.ByteBuffer +import java.nio.channels.WritableByteChannel + +import com.google.common.primitives.UnsignedBytes +import io.netty.buffer.{ByteBuf, Unpooled} + +import org.apache.spark.network.util.ByteArrayWritableChannel +import org.apache.spark.storage.StorageUtils + +/** + * Read-only byte buffer which is physically stored as multiple chunks rather than a single + * contiguous array. + * + * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must be non-empty and have + * position == 0. Ownership of these buffers is transferred to the ChunkedByteBuffer, + * so if these buffers may also be used elsewhere then the caller is responsible for + * copying them as needed. + */ +private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { + require(chunks != null, "chunks must not be null") + require(chunks.forall(_.limit() > 0), "chunks must be non-empty") + require(chunks.forall(_.position() == 0), "chunks' positions must be 0") + + private[this] var disposed: Boolean = false + + /** + * This size of this buffer, in bytes. + */ + val size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum + + def this(byteBuffer: ByteBuffer) = { + this(Array(byteBuffer)) + } + + /** + * Write this buffer to a channel. + */ + def writeFully(channel: WritableByteChannel): Unit = { + for (bytes <- getChunks()) { + while (bytes.remaining > 0) { + channel.write(bytes) + } + } + } + + /** + * Wrap this buffer to view it as a Netty ByteBuf. + */ + def toNetty: ByteBuf = { + Unpooled.wrappedBuffer(getChunks(): _*) + } + + /** + * Copy this buffer into a new byte array. + * + * @throws UnsupportedOperationException if this buffer's size exceeds the maximum array size. + */ + def toArray: Array[Byte] = { + if (size >= Integer.MAX_VALUE) { + throw new UnsupportedOperationException( + s"cannot call toArray because buffer size ($size bytes) exceeds maximum array size") + } + val byteChannel = new ByteArrayWritableChannel(size.toInt) + writeFully(byteChannel) + byteChannel.close() + byteChannel.getData + } + + /** + * Copy this buffer into a new ByteBuffer. + * + * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size. + */ + def toByteBuffer: ByteBuffer = { + if (chunks.length == 1) { + chunks.head.duplicate() + } else { + ByteBuffer.wrap(toArray) + } + } + + /** + * Creates an input stream to read data from this ChunkedByteBuffer. + * + * @param dispose if true, [[dispose()]] will be called at the end of the stream + * in order to close any memory-mapped files which back this buffer. + */ + def toInputStream(dispose: Boolean = false): InputStream = { + new ChunkedByteBufferInputStream(this, dispose) + } + + /** + * Get duplicates of the ByteBuffers backing this ChunkedByteBuffer. + */ + def getChunks(): Array[ByteBuffer] = { + chunks.map(_.duplicate()) + } + + /** + * Make a copy of this ChunkedByteBuffer, copying all of the backing data into new buffers. + * The new buffer will share no resources with the original buffer. + * + * @param allocator a method for allocating byte buffers + */ + def copy(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + val copiedChunks = getChunks().map { chunk => + val newChunk = allocator(chunk.limit()) + newChunk.put(chunk) + newChunk.flip() + newChunk + } + new ChunkedByteBuffer(copiedChunks) + } + + /** + * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that + * might cause errors if one attempts to read from the unmapped buffer, but it's better than + * waiting for the GC to find it because that could lead to huge numbers of open files. There's + * unfortunately no standard API to do this. + */ + def dispose(): Unit = { + if (!disposed) { + chunks.foreach(StorageUtils.dispose) + disposed = true + } + } +} + +/** + * Reads data from a ChunkedByteBuffer. + * + * @param dispose if true, [[ChunkedByteBuffer.dispose()]] will be called at the end of the stream + * in order to close any memory-mapped files which back the buffer. + */ +private class ChunkedByteBufferInputStream( + var chunkedByteBuffer: ChunkedByteBuffer, + dispose: Boolean) + extends InputStream { + + private[this] var chunks = chunkedByteBuffer.getChunks().iterator + private[this] var currentChunk: ByteBuffer = { + if (chunks.hasNext) { + chunks.next() + } else { + null + } + } + + override def read(): Int = { + if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) { + currentChunk = chunks.next() + } + if (currentChunk != null && currentChunk.hasRemaining) { + UnsignedBytes.toInt(currentChunk.get()) + } else { + close() + -1 + } + } + + override def read(dest: Array[Byte], offset: Int, length: Int): Int = { + if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) { + currentChunk = chunks.next() + } + if (currentChunk != null && currentChunk.hasRemaining) { + val amountToGet = math.min(currentChunk.remaining(), length) + currentChunk.get(dest, offset, amountToGet) + amountToGet + } else { + close() + -1 + } + } + + override def skip(bytes: Long): Long = { + if (currentChunk != null) { + val amountToSkip = math.min(bytes, currentChunk.remaining).toInt + currentChunk.position(currentChunk.position + amountToSkip) + if (currentChunk.remaining() == 0) { + if (chunks.hasNext) { + currentChunk = chunks.next() + } else { + close() + } + } + amountToSkip + } else { + 0L + } + } + + override def close(): Unit = { + if (chunkedByteBuffer != null && dispose) { + chunkedByteBuffer.dispose() + } + chunkedByteBuffer = null + chunks = null + currentChunk = null + } +} diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala new file mode 100644 index 0000000000000..67b50d1e70437 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import java.io.OutputStream +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.storage.StorageUtils + +/** + * An OutputStream that writes to fixed-size chunks of byte arrays. + * + * @param chunkSize size of each chunk, in bytes. + */ +private[spark] class ChunkedByteBufferOutputStream( + chunkSize: Int, + allocator: Int => ByteBuffer) + extends OutputStream { + + private[this] var toChunkedByteBufferWasCalled = false + + private val chunks = new ArrayBuffer[ByteBuffer] + + /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ + private[this] var lastChunkIndex = -1 + + /** + * Next position to write in the last chunk. + * + * If this equals chunkSize, it means for next write we need to allocate a new chunk. + * This can also never be 0. + */ + private[this] var position = chunkSize + private[this] var _size = 0 + + def size: Long = _size + + override def write(b: Int): Unit = { + allocateNewChunkIfNeeded() + chunks(lastChunkIndex).put(b.toByte) + position += 1 + _size += 1 + } + + override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + var written = 0 + while (written < len) { + allocateNewChunkIfNeeded() + val thisBatch = math.min(chunkSize - position, len - written) + chunks(lastChunkIndex).put(bytes, written + off, thisBatch) + written += thisBatch + position += thisBatch + } + _size += len + } + + @inline + private def allocateNewChunkIfNeeded(): Unit = { + require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called") + if (position == chunkSize) { + chunks += allocator(chunkSize) + lastChunkIndex += 1 + position = 0 + } + } + + def toChunkedByteBuffer: ChunkedByteBuffer = { + require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once") + toChunkedByteBufferWasCalled = true + if (lastChunkIndex == -1) { + new ChunkedByteBuffer(Array.empty[ByteBuffer]) + } else { + // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. + // An alternative would have been returning an array of ByteBuffers, with the last buffer + // bounded to only the last chunk's position. However, given our use case in Spark (to put + // the chunks in block manager), only limiting the view bound of the buffer would still + // require the block manager to store the whole chunk. + val ret = new Array[ByteBuffer](chunks.size) + for (i <- 0 until chunks.size - 1) { + ret(i) = chunks(i) + ret(i).flip() + } + if (position == chunkSize) { + ret(lastChunkIndex) = chunks(lastChunkIndex) + ret(lastChunkIndex).flip() + } else { + ret(lastChunkIndex) = allocator(position) + chunks(lastChunkIndex).flip() + ret(lastChunkIndex).put(chunks(lastChunkIndex)) + ret(lastChunkIndex).flip() + StorageUtils.dispose(chunks(lastChunkIndex)) + } + new ChunkedByteBuffer(ret) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 14b6ba4af489a..fdb1495899bc3 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -17,9 +17,10 @@ package org.apache.spark.util.logging -import java.io.{File, FileOutputStream, InputStream} +import java.io.{File, FileOutputStream, InputStream, IOException} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.util.{IntParam, Utils} /** @@ -29,7 +30,6 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi extends Logging { @volatile private var outputStream: FileOutputStream = null @volatile private var markedForStop = false // has the appender been asked to stopped - @volatile private var stopped = false // has the appender stopped // Thread that reads the input stream and writes to file private val writingThread = new Thread("File appending thread for " + file) { @@ -47,11 +47,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi * or because of any error in appending */ def awaitTermination() { - synchronized { - if (!stopped) { - wait() - } - } + writingThread.join() } /** Stop the appender */ @@ -63,24 +59,28 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi protected def appendStreamToFile() { try { logDebug("Started appending thread") - openFile() - val buf = new Array[Byte](bufferSize) - var n = 0 - while (!markedForStop && n != -1) { - n = inputStream.read(buf) - if (n != -1) { - appendToFile(buf, n) + Utils.tryWithSafeFinally { + openFile() + val buf = new Array[Byte](bufferSize) + var n = 0 + while (!markedForStop && n != -1) { + try { + n = inputStream.read(buf) + } catch { + // An InputStream can throw IOException during read if the stream is closed + // asynchronously, so once appender has been flagged to stop these will be ignored + case _: IOException if markedForStop => // do nothing and proceed to stop appending + } + if (n > 0) { + appendToFile(buf, n) + } } + } { + closeFile() } } catch { case e: Exception => logError(s"Error writing stream to file $file", e) - } finally { - closeFile() - synchronized { - stopped = true - notifyAll() - } } } diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 1e8476c4a047e..a0eb05c7c0e82 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -20,8 +20,8 @@ package org.apache.spark.util.logging import java.io.{File, FileFilter, InputStream} import com.google.common.io.Files + import org.apache.spark.SparkConf -import RollingFileAppender._ /** * Continuously appends data from input stream into the given file, and rolls @@ -39,9 +39,11 @@ private[spark] class RollingFileAppender( activeFile: File, val rollingPolicy: RollingPolicy, conf: SparkConf, - bufferSize: Int = DEFAULT_BUFFER_SIZE + bufferSize: Int = RollingFileAppender.DEFAULT_BUFFER_SIZE ) extends FileAppender(inputStream, activeFile, bufferSize) { + import RollingFileAppender._ + private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1) /** Stop the appender */ @@ -115,7 +117,7 @@ private[spark] class RollingFileAppender( } }).sorted val filesToBeDeleted = rolledoverFiles.take( - math.max(0, rolledoverFiles.size - maxRetainedFiles)) + math.max(0, rolledoverFiles.length - maxRetainedFiles)) filesToBeDeleted.foreach { file => logInfo(s"Deleting file executor log file ${file.getAbsolutePath}") file.delete() diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index d7b7219e179d0..6e80db2f51f9c 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -20,7 +20,7 @@ package org.apache.spark.util.logging import java.text.SimpleDateFormat import java.util.Calendar -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * Defines the policy based on which [[org.apache.spark.util.logging.RollingFileAppender]] will @@ -32,10 +32,10 @@ private[spark] trait RollingPolicy { def shouldRollover(bytesToBeWritten: Long): Boolean /** Notify that rollover has occurred */ - def rolledOver() + def rolledOver(): Unit /** Notify that bytes have been written */ - def bytesWritten(bytes: Long) + def bytesWritten(bytes: Long): Unit /** Get the desired name of the rollover file */ def generateRolledOverFileSuffix(): String diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala index 70f3dd62b9b19..41f28f6e511e3 100644 --- a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala @@ -26,5 +26,5 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi trait Pseudorandom { /** Set random seed. */ - def setSeed(seed: Long) + def setSeed(seed: Long): Unit } diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index c156b03cdb7c4..8c67364ef1a05 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -19,8 +19,8 @@ package org.apache.spark.util.random import java.util.Random -import scala.reflect.ClassTag import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag import org.apache.commons.math3.distribution.PoissonDistribution @@ -39,7 +39,14 @@ import org.apache.spark.annotation.DeveloperApi trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable { /** take a random sample */ - def sample(items: Iterator[T]): Iterator[U] + def sample(items: Iterator[T]): Iterator[U] = + items.filter(_ => sample > 0).asInstanceOf[Iterator[U]] + + /** + * Whether to sample the next item or not. + * Return how many times the next item will be sampled. Return 0 if it is not sampled. + */ + def sample(): Int /** return a copy of the RandomSampler object */ override def clone: RandomSampler[T, U] = @@ -54,7 +61,7 @@ object RandomSampler { /** * Default maximum gap-sampling fraction. * For sampling fractions <= this value, the gap sampling optimization will be applied. - * Above this value, it is assumed that "tradtional" Bernoulli sampling is faster. The + * Above this value, it is assumed that "traditional" Bernoulli sampling is faster. The * optimal value for this will depend on the RNG. More expensive RNGs will tend to make * the optimal value higher. The most reliable way to determine this value for a new RNG * is to experiment. When tuning for a new RNG, I would expect a value of 0.5 to be close @@ -107,21 +114,13 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals override def setSeed(seed: Long): Unit = rng.setSeed(seed) - override def sample(items: Iterator[T]): Iterator[T] = { + override def sample(): Int = { if (ub - lb <= 0.0) { - if (complement) items else Iterator.empty + if (complement) 1 else 0 } else { - if (complement) { - items.filter { item => { - val x = rng.nextDouble() - (x < lb) || (x >= ub) - }} - } else { - items.filter { item => { - val x = rng.nextDouble() - (x >= lb) && (x < ub) - }} - } + val x = rng.nextDouble() + val n = if ((x >= lb) && (x < ub)) 1 else 0 + if (complement) 1 - n else n } } @@ -155,15 +154,22 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T override def setSeed(seed: Long): Unit = rng.setSeed(seed) - override def sample(items: Iterator[T]): Iterator[T] = { + private lazy val gapSampling: GapSampling = + new GapSampling(fraction, rng, RandomSampler.rngEpsilon) + + override def sample(): Int = { if (fraction <= 0.0) { - Iterator.empty + 0 } else if (fraction >= 1.0) { - items + 1 } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon) + gapSampling.sample() } else { - items.filter { _ => rng.nextDouble() <= fraction } + if (rng.nextDouble() <= fraction) { + 1 + } else { + 0 + } } } @@ -180,7 +186,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T * @tparam T item type */ @DeveloperApi -class PoissonSampler[T: ClassTag]( +class PoissonSampler[T]( fraction: Double, useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] { @@ -201,15 +207,29 @@ class PoissonSampler[T: ClassTag]( rngGap.setSeed(seed) } - override def sample(items: Iterator[T]): Iterator[T] = { + private lazy val gapSamplingReplacement = + new GapSamplingReplacement(fraction, rngGap, RandomSampler.rngEpsilon) + + override def sample(): Int = { if (fraction <= 0.0) { - Iterator.empty + 0 } else if (useGapSamplingIfPossible && fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) + gapSamplingReplacement.sample() } else { + rng.sample() + } + } + + override def sample(items: Iterator[T]): Iterator[T] = { + if (fraction <= 0.0) { + Iterator.empty + } else { + val useGapSampling = useGapSamplingIfPossible && + fraction <= RandomSampler.defaultMaxGapSamplingFraction + items.flatMap { item => - val count = rng.sample() + val count = if (useGapSampling) gapSamplingReplacement.sample() else rng.sample() if (count == 0) Iterator.empty else Iterator.fill(count)(item) } } @@ -220,50 +240,36 @@ class PoissonSampler[T: ClassTag]( private[spark] -class GapSamplingIterator[T: ClassTag]( - var data: Iterator[T], +class GapSampling( f: Double, rng: Random = RandomSampler.newDefaultRNG, - epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] { + epsilon: Double = RandomSampler.rngEpsilon) extends Serializable { require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)") require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0") - /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */ - private val iterDrop: Int => Unit = { - val arrayClass = Array.empty[T].iterator.getClass - val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass - data.getClass match { - case `arrayClass` => - (n: Int) => { data = data.drop(n) } - case `arrayBufferClass` => - (n: Int) => { data = data.drop(n) } - case _ => - (n: Int) => { - var j = 0 - while (j < n && data.hasNext) { - data.next() - j += 1 - } - } - } - } - - override def hasNext: Boolean = data.hasNext + private val lnq = math.log1p(-f) - override def next(): T = { - val r = data.next() - advance() - r + /** Return 1 if the next item should be sampled. Otherwise, return 0. */ + def sample(): Int = { + if (countForDropping > 0) { + countForDropping -= 1 + 0 + } else { + advance() + 1 + } } - private val lnq = math.log1p(-f) + private var countForDropping: Int = 0 - /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */ + /** + * Decide the number of elements that won't be sampled, + * according to geometric dist P(k) = (f)(1-f)^k. + */ private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) - val k = (math.log(u) / lnq).toInt - iterDrop(k) + countForDropping = (math.log(u) / lnq).toInt } /** advance to first sample as part of object construction. */ @@ -273,73 +279,24 @@ class GapSamplingIterator[T: ClassTag]( // work reliably. } + private[spark] -class GapSamplingReplacementIterator[T: ClassTag]( - var data: Iterator[T], - f: Double, - rng: Random = RandomSampler.newDefaultRNG, - epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] { +class GapSamplingReplacement( + val f: Double, + val rng: Random = RandomSampler.newDefaultRNG, + epsilon: Double = RandomSampler.rngEpsilon) extends Serializable { require(f > 0.0, s"Sampling fraction ($f) must be > 0") require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0") - /** implement efficient linear-sequence drop until scala includes fix for jira SI-8835. */ - private val iterDrop: Int => Unit = { - val arrayClass = Array.empty[T].iterator.getClass - val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass - data.getClass match { - case `arrayClass` => - (n: Int) => { data = data.drop(n) } - case `arrayBufferClass` => - (n: Int) => { data = data.drop(n) } - case _ => - (n: Int) => { - var j = 0 - while (j < n && data.hasNext) { - data.next() - j += 1 - } - } - } - } - - /** current sampling value, and its replication factor, as we are sampling with replacement. */ - private var v: T = _ - private var rep: Int = 0 - - override def hasNext: Boolean = data.hasNext || rep > 0 - - override def next(): T = { - val r = v - rep -= 1 - if (rep <= 0) advance() - r - } - - /** - * Skip elements with replication factor zero (i.e. elements that won't be sampled). - * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is - * q is the probabililty of Poisson(0; f) - */ - private def advance(): Unit = { - val u = math.max(rng.nextDouble(), epsilon) - val k = (math.log(u) / (-f)).toInt - iterDrop(k) - // set the value and replication factor for the next value - if (data.hasNext) { - v = data.next() - rep = poissonGE1 - } - } - - private val q = math.exp(-f) + protected val q = math.exp(-f) /** * Sample from Poisson distribution, conditioned such that the sampled value is >= 1. * This is an adaptation from the algorithm for Generating Poisson distributed random variables: * http://en.wikipedia.org/wiki/Poisson_distribution */ - private def poissonGE1: Int = { + protected def poissonGE1: Int = { // simulate that the standard poisson sampling // gave us at least one iteration, for a sample of >= 1 var pp = q + ((1.0 - q) * rng.nextDouble()) @@ -353,6 +310,28 @@ class GapSamplingReplacementIterator[T: ClassTag]( } r } + private var countForDropping: Int = 0 + + def sample(): Int = { + if (countForDropping > 0) { + countForDropping -= 1 + 0 + } else { + val r = poissonGE1 + advance() + r + } + } + + /** + * Skip elements with replication factor zero (i.e. elements that won't be sampled). + * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is + * q is the probability of Poisson(0; f) + */ + private def advance(): Unit = { + val u = math.max(rng.nextDouble(), epsilon) + countForDropping = (math.log(u) / (-f)).toInt + } /** advance to first sample as part of object construction. */ advance() diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index c9a864ae62778..f98932a470165 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -34,7 +34,7 @@ private[spark] object SamplingUtils { input: Iterator[T], k: Int, seed: Long = Random.nextLong()) - : (Array[T], Int) = { + : (Array[T], Long) = { val reservoir = new Array[T](k) // Put the first k elements in the reservoir. var i = 0 @@ -52,16 +52,17 @@ private[spark] object SamplingUtils { (trimReservoir, i) } else { // If input size > k, continue the sampling process. + var l = i.toLong val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() - val replacementIndex = rand.nextInt(i) + val replacementIndex = (rand.nextDouble() * l).toLong if (replacementIndex < k) { - reservoir(replacementIndex) = item + reservoir(replacementIndex.toInt) = item } - i += 1 + l += 1 } - (reservoir, i) + (reservoir, l) } } diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index effe6fa2adcfa..67822749112c6 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.commons.math3.distribution.PoissonDistribution -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD /** diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 85fb923cd9bc7..e8cdb6e98bf36 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -60,9 +60,11 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { private[spark] object XORShiftRandom { /** Hash seeds to have 0/1 bits throughout. */ - private def hashSeed(seed: Long): Long = { + private[random] def hashSeed(seed: Long): Long = { val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array() - MurmurHash3.bytesHash(bytes) + val lowBits = MurmurHash3.bytesHash(bytes) + val highBits = MurmurHash3.bytesHash(bytes, lowBits) + (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL) } /** diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala new file mode 100644 index 0000000000000..1be31e88ab68e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.EventListener + +import org.apache.spark.TaskContext +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution completes. + */ +@DeveloperApi +trait TaskCompletionListener extends EventListener { + def onTaskCompletion(context: TaskContext): Unit +} + + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution encounters an error. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ +@DeveloperApi +trait TaskFailureListener extends EventListener { + def onTaskFailure(context: TaskContext, error: Throwable): Unit +} + + +/** + * Exception thrown when there is an exception in executing the callback in TaskCompletionListener. + */ +private[spark] +class TaskCompletionListenerException( + errorMessages: Seq[String], + val previousError: Option[Throwable] = None) + extends RuntimeException { + + override def getMessage: String = { + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + + previousError.map { e => + "\n\nPrevious exception in task: " + e.getMessage + "\n" + + e.getStackTrace.mkString("\t", "\n\t", "") + }.getOrElse("") + } +} diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index fd8f7f39b7cc8..0f65554516153 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -21,7 +21,18 @@ import java.nio.channels.FileChannel; import java.nio.ByteBuffer; import java.net.URI; -import java.util.*; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.*; import scala.Tuple2; @@ -35,8 +46,6 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.base.Throwables; -import com.google.common.base.Optional; -import com.google.common.base.Charsets; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; @@ -45,11 +54,16 @@ import org.apache.hadoop.mapred.SequenceFileOutputFormat; import org.apache.hadoop.mapreduce.Job; import org.junit.After; -import org.junit.Assert; +import static org.junit.Assert.*; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.*; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaFutureAction; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.*; import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; @@ -88,19 +102,19 @@ public void sparkContextUnion() { JavaRDD s2 = sc.parallelize(strings); // Varargs JavaRDD sUnion = sc.union(s1, s2); - Assert.assertEquals(4, sUnion.count()); + assertEquals(4, sUnion.count()); // List List> list = new ArrayList<>(); list.add(s2); sUnion = sc.union(s1, list); - Assert.assertEquals(4, sUnion.count()); + assertEquals(4, sUnion.count()); // Union of JavaDoubleRDDs List doubles = Arrays.asList(1.0, 2.0); JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); JavaDoubleRDD dUnion = sc.union(d1, d2); - Assert.assertEquals(4, dUnion.count()); + assertEquals(4, dUnion.count()); // Union of JavaPairRDDs List> pairs = new ArrayList<>(); @@ -109,7 +123,7 @@ public void sparkContextUnion() { JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pUnion = sc.union(p1, p2); - Assert.assertEquals(4, pUnion.count()); + assertEquals(4, pUnion.count()); } @SuppressWarnings("unchecked") @@ -121,17 +135,17 @@ public void intersection() { JavaRDD s2 = sc.parallelize(ints2); JavaRDD intersections = s1.intersection(s2); - Assert.assertEquals(3, intersections.count()); + assertEquals(3, intersections.count()); JavaRDD empty = sc.emptyRDD(); JavaRDD emptyIntersection = empty.intersection(s2); - Assert.assertEquals(0, emptyIntersection.count()); + assertEquals(0, emptyIntersection.count()); List doubles = Arrays.asList(1.0, 2.0); JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); JavaDoubleRDD dIntersection = d1.intersection(d2); - Assert.assertEquals(2, dIntersection.count()); + assertEquals(2, dIntersection.count()); List> pairs = new ArrayList<>(); pairs.add(new Tuple2<>(1, 2)); @@ -139,28 +153,36 @@ public void intersection() { JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pIntersection = p1.intersection(p2); - Assert.assertEquals(2, pIntersection.count()); + assertEquals(2, pIntersection.count()); } @Test public void sample() { List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD rdd = sc.parallelize(ints); - JavaRDD sample20 = rdd.sample(true, 0.2, 3); - Assert.assertEquals(2, sample20.count()); - JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 5); - Assert.assertEquals(2, sample20WithoutReplacement.count()); + // the seeds here are "magic" to make this work out nicely + JavaRDD sample20 = rdd.sample(true, 0.2, 8); + assertEquals(2, sample20.count()); + JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 2); + assertEquals(2, sample20WithoutReplacement.count()); } @Test public void randomSplit() { - List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + List ints = new ArrayList<>(1000); + for (int i = 0; i < 1000; i++) { + ints.add(i); + } JavaRDD rdd = sc.parallelize(ints); JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31); - Assert.assertEquals(3, splits.length); - Assert.assertEquals(1, splits[0].count()); - Assert.assertEquals(2, splits[1].count()); - Assert.assertEquals(7, splits[2].count()); + // the splits aren't perfect -- not enough data for them to be -- just check they're about right + assertEquals(3, splits.length); + long s0 = splits[0].count(); + long s1 = splits[1].count(); + long s2 = splits[2].count(); + assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250); + assertTrue(s1 + " not within expected range", s1 > 250 && s0 < 350); + assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570); } @Test @@ -174,17 +196,17 @@ public void sortByKey() { // Default comparator JavaPairRDD sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); + assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); + assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); - Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); + assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); + assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); } @SuppressWarnings("unchecked") @@ -213,19 +235,19 @@ public int getPartition(Object key) { JavaPairRDD repartitioned = rdd.repartitionAndSortWithinPartitions(partitioner); - Assert.assertTrue(repartitioned.partitioner().isPresent()); - Assert.assertEquals(repartitioned.partitioner().get(), partitioner); + assertTrue(repartitioned.partitioner().isPresent()); + assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); - Assert.assertEquals(partitions.get(0), + assertEquals(partitions.get(0), Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6))); - Assert.assertEquals(partitions.get(1), + assertEquals(partitions.get(1), Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8))); } @Test public void emptyRDD() { JavaRDD rdd = sc.emptyRDD(); - Assert.assertEquals("Empty RDD shouldn't have any values", 0, rdd.count()); + assertEquals("Empty RDD shouldn't have any values", 0, rdd.count()); } @Test @@ -238,17 +260,18 @@ public void sortBy() { JavaRDD> rdd = sc.parallelize(pairs); // compare on first value - JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { + JavaRDD> sortedRDD = + rdd.sortBy(new Function, Integer>() { @Override public Integer call(Tuple2 t) { return t._1(); } }, true, 2); - Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); + assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); + assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value sortedRDD = rdd.sortBy(new Function, Integer>() { @@ -257,10 +280,10 @@ public Integer call(Tuple2 t) { return t._2(); } }, true, 2); - Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); + assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2)); + assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); + assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2)); } @Test @@ -273,7 +296,7 @@ public void call(String s) { accum.add(1); } }); - Assert.assertEquals(2, accum.value().intValue()); + assertEquals(2, accum.value().intValue()); } @Test @@ -289,7 +312,7 @@ public void call(Iterator iter) { } } }); - Assert.assertEquals(2, accum.value().intValue()); + assertEquals(2, accum.value().intValue()); } @Test @@ -297,7 +320,7 @@ public void toLocalIterator() { List correct = Arrays.asList(1, 2, 3, 4); JavaRDD rdd = sc.parallelize(correct); List result = Lists.newArrayList(rdd.toLocalIterator()); - Assert.assertEquals(correct, result); + assertEquals(correct, result); } @Test @@ -305,7 +328,7 @@ public void zipWithUniqueId() { List dataArray = Arrays.asList(1, 2, 3, 4); JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD indexes = zip.values(); - Assert.assertEquals(4, new HashSet<>(indexes.collect()).size()); + assertEquals(4, new HashSet<>(indexes.collect()).size()); } @Test @@ -314,7 +337,7 @@ public void zipWithIndex() { JavaPairRDD zip = sc.parallelize(dataArray).zipWithIndex(); JavaRDD indexes = zip.values(); List correctIndexes = Arrays.asList(0L, 1L, 2L, 3L); - Assert.assertEquals(correctIndexes, indexes.collect()); + assertEquals(correctIndexes, indexes.collect()); } @SuppressWarnings("unchecked") @@ -325,8 +348,8 @@ public void lookup() { new Tuple2<>("Oranges", "Fruit"), new Tuple2<>("Oranges", "Citrus") )); - Assert.assertEquals(2, categories.lookup("Oranges").size()); - Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); + assertEquals(2, categories.lookup("Oranges").size()); + assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); } @Test @@ -339,14 +362,14 @@ public Boolean call(Integer x) { } }; JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens - Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds + assertEquals(2, oddsAndEvens.count()); + assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds oddsAndEvens = rdd.groupBy(isOdd, 1); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens - Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds + assertEquals(2, oddsAndEvens.count()); + assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds } @Test @@ -362,14 +385,14 @@ public Boolean call(Tuple2 x) { }; JavaPairRDD pairRDD = rdd.zip(rdd); JavaPairRDD>> oddsAndEvens = pairRDD.groupBy(areOdd); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens - Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds + assertEquals(2, oddsAndEvens.count()); + assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds oddsAndEvens = pairRDD.groupBy(areOdd, 1); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens - Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds + assertEquals(2, oddsAndEvens.count()); + assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds } @SuppressWarnings("unchecked") @@ -386,8 +409,8 @@ public String call(Tuple2 x) { }; JavaPairRDD pairRDD = rdd.zip(rdd); JavaPairRDD> keyed = pairRDD.keyBy(sumToString); - Assert.assertEquals(7, keyed.count()); - Assert.assertEquals(1, (long) keyed.lookup("2").get(0)._1()); + assertEquals(7, keyed.count()); + assertEquals(1, (long) keyed.lookup("2").get(0)._1()); } @SuppressWarnings("unchecked") @@ -404,8 +427,8 @@ public void cogroup() { )); JavaPairRDD, Iterable>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); - Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); } @@ -429,9 +452,9 @@ public void cogroup3() { JavaPairRDD, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities); - Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); - Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); - Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); cogrouped.collect(); @@ -458,12 +481,12 @@ public void cogroup4() { new Tuple2<>("Apples", "US") )); - JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = - categories.cogroup(prices, quantities, countries); - Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); - Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); - Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); - Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); + JavaPairRDD, Iterable, Iterable, + Iterable>> cogrouped = categories.cogroup(prices, quantities, countries); + assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); + assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); + assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); cogrouped.collect(); } @@ -485,7 +508,7 @@ public void leftOuterJoin() { )); List>>> joined = rdd1.leftOuterJoin(rdd2).collect(); - Assert.assertEquals(5, joined.size()); + assertEquals(5, joined.size()); Tuple2>> firstUnmatched = rdd1.leftOuterJoin(rdd2).filter( new Function>>, Boolean>() { @@ -494,7 +517,7 @@ public Boolean call(Tuple2>> tup) { return !tup._2()._2().isPresent(); } }).first(); - Assert.assertEquals(3, firstUnmatched._1().intValue()); + assertEquals(3, firstUnmatched._1().intValue()); } @Test @@ -508,10 +531,10 @@ public Integer call(Integer a, Integer b) { }; int sum = rdd.fold(0, add); - Assert.assertEquals(33, sum); + assertEquals(33, sum); sum = rdd.reduce(add); - Assert.assertEquals(33, sum); + assertEquals(33, sum); } @Test @@ -525,7 +548,7 @@ public Integer call(Integer a, Integer b) { }; for (int depth = 1; depth <= 10; depth++) { int sum = rdd.treeReduce(add, depth); - Assert.assertEquals(-5, sum); + assertEquals(-5, sum); } } @@ -540,7 +563,7 @@ public Integer call(Integer a, Integer b) { }; for (int depth = 1; depth <= 10; depth++) { int sum = rdd.treeAggregate(0, add, add, depth); - Assert.assertEquals(-5, sum); + assertEquals(-5, sum); } } @@ -570,10 +593,10 @@ public Set call(Set a, Set b) { return a; } }).collectAsMap(); - Assert.assertEquals(3, sets.size()); - Assert.assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); - Assert.assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); - Assert.assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); + assertEquals(3, sets.size()); + assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); + assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); + assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); } @SuppressWarnings("unchecked") @@ -594,9 +617,9 @@ public Integer call(Integer a, Integer b) { return a + b; } }); - Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); - Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); - Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); + assertEquals(1, sums.lookup(1).get(0).intValue()); + assertEquals(2, sums.lookup(2).get(0).intValue()); + assertEquals(3, sums.lookup(3).get(0).intValue()); } @SuppressWarnings("unchecked") @@ -617,14 +640,14 @@ public Integer call(Integer a, Integer b) { return a + b; } }); - Assert.assertEquals(1, counts.lookup(1).get(0).intValue()); - Assert.assertEquals(2, counts.lookup(2).get(0).intValue()); - Assert.assertEquals(3, counts.lookup(3).get(0).intValue()); + assertEquals(1, counts.lookup(1).get(0).intValue()); + assertEquals(2, counts.lookup(2).get(0).intValue()); + assertEquals(3, counts.lookup(3).get(0).intValue()); Map localCounts = counts.collectAsMap(); - Assert.assertEquals(1, localCounts.get(1).intValue()); - Assert.assertEquals(2, localCounts.get(2).intValue()); - Assert.assertEquals(3, localCounts.get(3).intValue()); + assertEquals(1, localCounts.get(1).intValue()); + assertEquals(2, localCounts.get(2).intValue()); + assertEquals(3, localCounts.get(3).intValue()); localCounts = rdd.reduceByKeyLocally(new Function2() { @Override @@ -632,45 +655,45 @@ public Integer call(Integer a, Integer b) { return a + b; } }); - Assert.assertEquals(1, localCounts.get(1).intValue()); - Assert.assertEquals(2, localCounts.get(2).intValue()); - Assert.assertEquals(3, localCounts.get(3).intValue()); + assertEquals(1, localCounts.get(1).intValue()); + assertEquals(2, localCounts.get(2).intValue()); + assertEquals(3, localCounts.get(3).intValue()); } @Test public void approximateResults() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); Map countsByValue = rdd.countByValue(); - Assert.assertEquals(2, countsByValue.get(1).longValue()); - Assert.assertEquals(1, countsByValue.get(13).longValue()); + assertEquals(2, countsByValue.get(1).longValue()); + assertEquals(1, countsByValue.get(13).longValue()); PartialResult> approx = rdd.countByValueApprox(1); Map finalValue = approx.getFinalValue(); - Assert.assertEquals(2.0, finalValue.get(1).mean(), 0.01); - Assert.assertEquals(1.0, finalValue.get(13).mean(), 0.01); + assertEquals(2.0, finalValue.get(1).mean(), 0.01); + assertEquals(1.0, finalValue.get(13).mean(), 0.01); } @Test public void take() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Assert.assertEquals(1, rdd.first().intValue()); + assertEquals(1, rdd.first().intValue()); rdd.take(2); rdd.takeSample(false, 2, 42); } @Test public void isEmpty() { - Assert.assertTrue(sc.emptyRDD().isEmpty()); - Assert.assertTrue(sc.parallelize(new ArrayList()).isEmpty()); - Assert.assertFalse(sc.parallelize(Arrays.asList(1)).isEmpty()); - Assert.assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter( + assertTrue(sc.emptyRDD().isEmpty()); + assertTrue(sc.parallelize(new ArrayList()).isEmpty()); + assertFalse(sc.parallelize(Arrays.asList(1)).isEmpty()); + assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter( new Function() { @Override public Boolean call(Integer i) { return i < 0; } }).isEmpty()); - Assert.assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter( + assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter( new Function() { @Override public Boolean call(Integer i) { @@ -679,47 +702,40 @@ public Boolean call(Integer i) { }).isEmpty()); } - @Test - public void toArray() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3)); - List list = rdd.toArray(); - Assert.assertEquals(Arrays.asList(1, 2, 3), list); - } - @Test public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first()); + assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first()); } @Test public void javaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaDoubleRDD distinct = rdd.distinct(); - Assert.assertEquals(5, distinct.count()); + assertEquals(5, distinct.count()); JavaDoubleRDD filter = rdd.filter(new Function() { @Override public Boolean call(Double x) { return x > 2.0; } }); - Assert.assertEquals(3, filter.count()); + assertEquals(3, filter.count()); JavaDoubleRDD union = rdd.union(rdd); - Assert.assertEquals(12, union.count()); + assertEquals(12, union.count()); union = union.cache(); - Assert.assertEquals(12, union.count()); + assertEquals(12, union.count()); - Assert.assertEquals(20, rdd.sum(), 0.01); + assertEquals(20, rdd.sum(), 0.01); StatCounter stats = rdd.stats(); - Assert.assertEquals(20, stats.sum(), 0.01); - Assert.assertEquals(20/6.0, rdd.mean(), 0.01); - Assert.assertEquals(20/6.0, rdd.mean(), 0.01); - Assert.assertEquals(6.22222, rdd.variance(), 0.01); - Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01); - Assert.assertEquals(2.49444, rdd.stdev(), 0.01); - Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01); + assertEquals(20, stats.sum(), 0.01); + assertEquals(20/6.0, rdd.mean(), 0.01); + assertEquals(20/6.0, rdd.mean(), 0.01); + assertEquals(6.22222, rdd.variance(), 0.01); + assertEquals(7.46667, rdd.sampleVariance(), 0.01); + assertEquals(2.49444, rdd.stdev(), 0.01); + assertEquals(2.73252, rdd.sampleStdev(), 0.01); rdd.first(); rdd.take(5); @@ -732,13 +748,13 @@ public void javaDoubleRDDHistoGram() { Tuple2 results = rdd.histogram(2); double[] expected_buckets = {1.0, 2.5, 4.0}; long[] expected_counts = {2, 2}; - Assert.assertArrayEquals(expected_buckets, results._1(), 0.1); - Assert.assertArrayEquals(expected_counts, results._2()); + assertArrayEquals(expected_buckets, results._1(), 0.1); + assertArrayEquals(expected_counts, results._2()); // Test with provided buckets long[] histogram = rdd.histogram(expected_buckets); - Assert.assertArrayEquals(expected_counts, histogram); + assertArrayEquals(expected_counts, histogram); // SPARK-5744 - Assert.assertArrayEquals( + assertArrayEquals( new long[] {0}, sc.parallelizeDoubles(new ArrayList(0), 1).histogram(new double[]{0.0, 1.0})); } @@ -754,42 +770,42 @@ public int compare(Double o1, Double o2) { public void max() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.max(new DoubleComparator()); - Assert.assertEquals(4.0, max, 0.001); + assertEquals(4.0, max, 0.001); } @Test public void min() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.min(new DoubleComparator()); - Assert.assertEquals(1.0, max, 0.001); + assertEquals(1.0, max, 0.001); } @Test public void naturalMax() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.max(); - Assert.assertEquals(4.0, max, 0.0); + assertEquals(4.0, max, 0.0); } @Test public void naturalMin() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.min(); - Assert.assertEquals(1.0, max, 0.0); + assertEquals(1.0, max, 0.0); } @Test public void takeOrdered() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); - Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator())); - Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2)); + assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator())); + assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2)); } @Test public void top() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); List top2 = rdd.top(2); - Assert.assertEquals(Arrays.asList(4, 3), top2); + assertEquals(Arrays.asList(4, 3), top2); } private static class AddInts implements Function2 { @@ -803,7 +819,7 @@ public Integer call(Integer a, Integer b) { public void reduce() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); int sum = rdd.reduce(new AddInts()); - Assert.assertEquals(10, sum); + assertEquals(10, sum); } @Test @@ -815,21 +831,21 @@ public Double call(Double v1, Double v2) { return v1 + v2; } }); - Assert.assertEquals(10.0, sum, 0.001); + assertEquals(10.0, sum, 0.001); } @Test public void fold() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); int sum = rdd.fold(0, new AddInts()); - Assert.assertEquals(10, sum); + assertEquals(10, sum); } @Test public void aggregate() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); int sum = rdd.aggregate(0, new AddInts(), new AddInts()); - Assert.assertEquals(10, sum); + assertEquals(10, sum); } @Test @@ -865,40 +881,40 @@ public void flatMap() { "The quick brown fox jumps over the lazy dog.")); JavaRDD words = rdd.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Arrays.asList(x.split(" ")); + public Iterator call(String x) { + return Arrays.asList(x.split(" ")).iterator(); } }); - Assert.assertEquals("Hello", words.first()); - Assert.assertEquals(11, words.count()); + assertEquals("Hello", words.first()); + assertEquals(11, words.count()); JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String s) { + public Iterator> call(String s) { List> pairs = new LinkedList<>(); for (String word : s.split(" ")) { pairs.add(new Tuple2<>(word, word)); } - return pairs; + return pairs.iterator(); } } ); - Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); - Assert.assertEquals(11, pairsRDD.count()); + assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); + assertEquals(11, pairsRDD.count()); JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override - public Iterable call(String s) { + public Iterator call(String s) { List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } - return lengths; + return lengths.iterator(); } }); - Assert.assertEquals(5.0, doubles.first(), 0.01); - Assert.assertEquals(11, pairsRDD.count()); + assertEquals(5.0, doubles.first(), 0.01); + assertEquals(11, pairsRDD.count()); } @SuppressWarnings("unchecked") @@ -915,8 +931,8 @@ public void mapsFromPairsToPairs() { JavaPairRDD swapped = pairRDD.flatMapToPair( new PairFlatMapFunction, String, Integer>() { @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); + public Iterator> call(Tuple2 item) { + return Collections.singletonList(item.swap()).iterator(); } }); swapped.collect(); @@ -936,15 +952,15 @@ public void mapPartitions() { JavaRDD partitionSums = rdd.mapPartitions( new FlatMapFunction, Integer>() { @Override - public Iterable call(Iterator iter) { + public Iterator call(Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); } - return Collections.singletonList(sum); + return Collections.singletonList(sum).iterator(); } }); - Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + assertEquals("[3, 7]", partitionSums.collect().toString()); } @@ -962,9 +978,22 @@ public Iterator call(Integer index, Iterator iter) { return Collections.singletonList(sum).iterator(); } }, false); - Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + assertEquals("[3, 7]", partitionSums.collect().toString()); } + @Test + public void getNumPartitions(){ + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2); + JavaPairRDD rdd3 = sc.parallelizePairs(Arrays.asList( + new Tuple2<>("a", 1), + new Tuple2<>("aa", 2), + new Tuple2<>("aaa", 3) + ), 2); + assertEquals(3, rdd1.getNumPartitions()); + assertEquals(2, rdd2.getNumPartitions()); + assertEquals(2, rdd3.getNumPartitions()); + } @Test public void repartition() { @@ -972,18 +1001,18 @@ public void repartition() { JavaRDD in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2); JavaRDD repartitioned1 = in1.repartition(4); List> result1 = repartitioned1.glom().collect(); - Assert.assertEquals(4, result1.size()); + assertEquals(4, result1.size()); for (List l : result1) { - Assert.assertFalse(l.isEmpty()); + assertFalse(l.isEmpty()); } // Growing number of partitions JavaRDD in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4); JavaRDD repartitioned2 = in2.repartition(2); List> result2 = repartitioned2.glom().collect(); - Assert.assertEquals(2, result2.size()); + assertEquals(2, result2.size()); for (List l: result2) { - Assert.assertFalse(l.isEmpty()); + assertFalse(l.isEmpty()); } } @@ -992,7 +1021,7 @@ public void repartition() { public void persist() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals(20, doubleRDD.sum(), 0.1); + assertEquals(20, doubleRDD.sum(), 0.1); List> pairs = Arrays.asList( new Tuple2<>(1, "a"), @@ -1001,24 +1030,24 @@ public void persist() { ); JavaPairRDD pairRDD = sc.parallelizePairs(pairs); pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals("a", pairRDD.first()._2()); + assertEquals("a", pairRDD.first()._2()); JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); rdd = rdd.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals(1, rdd.first().intValue()); + assertEquals(1, rdd.first().intValue()); } @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); TaskContext context = TaskContext$.MODULE$.empty(); - Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); + assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } @Test public void glom() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - Assert.assertEquals("[1, 2]", rdd.glom().first().toString()); + assertEquals("[1, 2]", rdd.glom().first().toString()); } // File input / output tests are largely adapted from FileSuite: @@ -1030,18 +1059,18 @@ public void textFiles() throws IOException { rdd.saveAsTextFile(outputDir); // Read the plain text file and check it's OK File outputFile = new File(outputDir, "part-00000"); - String content = Files.toString(outputFile, Charsets.UTF_8); - Assert.assertEquals("1\n2\n3\n4\n", content); + String content = Files.toString(outputFile, StandardCharsets.UTF_8); + assertEquals("1\n2\n3\n4\n", content); // Also try reading it in as a text file RDD List expected = Arrays.asList("1", "2", "3", "4"); JavaRDD readRDD = sc.textFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); + assertEquals(expected, readRDD.collect()); } @Test public void wholeTextFiles() throws Exception { - byte[] content1 = "spark is easy to use.\n".getBytes("utf-8"); - byte[] content2 = "spark is also easy to use.\n".getBytes("utf-8"); + byte[] content1 = "spark is easy to use.\n".getBytes(StandardCharsets.UTF_8); + byte[] content2 = "spark is also easy to use.\n".getBytes(StandardCharsets.UTF_8); String tempDirName = tempDir.getAbsolutePath(); Files.write(content1, new File(tempDirName + "/part-00000")); @@ -1055,7 +1084,7 @@ public void wholeTextFiles() throws Exception { List> result = readRDD.collect(); for (Tuple2 res : result) { - Assert.assertEquals(res._2(), container.get(new URI(res._1()).getPath())); + assertEquals(res._2(), container.get(new URI(res._1()).getPath())); } } @@ -1068,7 +1097,7 @@ public void textFilesCompressed() throws IOException { // Try reading it in as a text file RDD List expected = Arrays.asList("1", "2", "3", "4"); JavaRDD readRDD = sc.textFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); + assertEquals(expected, readRDD.collect()); } @SuppressWarnings("unchecked") @@ -1097,13 +1126,13 @@ public Tuple2 call(Tuple2 pair) { return new Tuple2<>(pair._1().get(), pair._2().toString()); } }); - Assert.assertEquals(pairs, readRDD.collect()); + assertEquals(pairs, readRDD.collect()); } @Test public void binaryFiles() throws Exception { // Reusing the wholeText files example - byte[] content1 = "spark is easy to use.\n".getBytes("utf-8"); + byte[] content1 = "spark is easy to use.\n".getBytes(StandardCharsets.UTF_8); String tempDirName = tempDir.getAbsolutePath(); File file1 = new File(tempDirName + "/part-00000"); @@ -1117,14 +1146,14 @@ public void binaryFiles() throws Exception { JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { - Assert.assertArrayEquals(content1, res._2().toArray()); + assertArrayEquals(content1, res._2().toArray()); } } @Test public void binaryFilesCaching() throws Exception { // Reusing the wholeText files example - byte[] content1 = "spark is easy to use.\n".getBytes("utf-8"); + byte[] content1 = "spark is easy to use.\n".getBytes(StandardCharsets.UTF_8); String tempDirName = tempDir.getAbsolutePath(); File file1 = new File(tempDirName + "/part-00000"); @@ -1146,14 +1175,14 @@ public void call(Tuple2 pair) { List> result = readRDD.collect(); for (Tuple2 res : result) { - Assert.assertArrayEquals(content1, res._2().toArray()); + assertArrayEquals(content1, res._2().toArray()); } } @Test public void binaryRecords() throws Exception { // Reusing the wholeText files example - byte[] content1 = "spark isn't always easy to use.\n".getBytes("utf-8"); + byte[] content1 = "spark isn't always easy to use.\n".getBytes(StandardCharsets.UTF_8); int numOfCopies = 10; String tempDirName = tempDir.getAbsolutePath(); File file1 = new File(tempDirName + "/part-00000"); @@ -1169,10 +1198,10 @@ public void binaryRecords() throws Exception { channel1.close(); JavaRDD readRDD = sc.binaryRecords(tempDirName, content1.length); - Assert.assertEquals(numOfCopies,readRDD.count()); + assertEquals(numOfCopies,readRDD.count()); List result = readRDD.collect(); for (byte[] res : result) { - Assert.assertArrayEquals(content1, res); + assertArrayEquals(content1, res); } } @@ -1196,8 +1225,9 @@ public Tuple2 call(Tuple2 pair) { outputDir, IntWritable.class, Text.class, org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { + JavaPairRDD output = + sc.sequenceFile(outputDir, IntWritable.class, Text.class); + assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1225,8 +1255,8 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.newAPIHadoopFile(outputDir, org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, - IntWritable.class, Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { + IntWritable.class, Text.class, Job.getInstance().getConfiguration()); + assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1242,7 +1272,7 @@ public void objectFilesOfInts() { // Try reading the output back as an object file List expected = Arrays.asList(1, 2, 3, 4); JavaRDD readRDD = sc.objectFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); + assertEquals(expected, readRDD.collect()); } @SuppressWarnings("unchecked") @@ -1258,7 +1288,7 @@ public void objectFilesOfComplexTypes() { rdd.saveAsObjectFile(outputDir); // Try reading the output back as an object file JavaRDD> readRDD = sc.objectFile(outputDir); - Assert.assertEquals(pairs, readRDD.collect()); + assertEquals(pairs, readRDD.collect()); } @SuppressWarnings("unchecked") @@ -1281,7 +1311,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { + assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1311,7 +1341,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { + assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1339,13 +1369,13 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = new FlatMapFunction2, Iterator, Integer>() { @Override - public Iterable call(Iterator i, Iterator s) { - return Arrays.asList(Iterators.size(i), Iterators.size(s)); + public Iterator call(Iterator i, Iterator s) { + return Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator(); } }; JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); - Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); } @Test @@ -1359,7 +1389,7 @@ public void call(Integer x) { intAccum.add(x); } }); - Assert.assertEquals((Integer) 25, intAccum.value()); + assertEquals((Integer) 25, intAccum.value()); final Accumulator doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(new VoidFunction() { @@ -1368,7 +1398,7 @@ public void call(Integer x) { doubleAccum.add((double) x); } }); - Assert.assertEquals((Double) 25.0, doubleAccum.value()); + assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { @@ -1395,11 +1425,11 @@ public void call(Integer x) { floatAccum.add((float) x); } }); - Assert.assertEquals((Float) 25.0f, floatAccum.value()); + assertEquals((Float) 25.0f, floatAccum.value()); // Test the setValue method floatAccum.setValue(5.0f); - Assert.assertEquals((Float) 5.0f, floatAccum.value()); + assertEquals((Float) 5.0f, floatAccum.value()); } @Test @@ -1411,33 +1441,33 @@ public String call(Integer t) { return t.toString(); } }).collect(); - Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); + assertEquals(new Tuple2<>("1", 1), s.get(0)); + assertEquals(new Tuple2<>("2", 2), s.get(1)); } @Test public void checkpointAndComputation() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); sc.setCheckpointDir(tempDir.getAbsolutePath()); - Assert.assertFalse(rdd.isCheckpointed()); + assertFalse(rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertTrue(rdd.isCheckpointed()); - Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); + assertTrue(rdd.isCheckpointed()); + assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); } @Test public void checkpointAndRestore() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); sc.setCheckpointDir(tempDir.getAbsolutePath()); - Assert.assertFalse(rdd.isCheckpointed()); + assertFalse(rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertTrue(rdd.isCheckpointed()); + assertTrue(rdd.isCheckpointed()); - Assert.assertTrue(rdd.getCheckpointFile().isPresent()); + assertTrue(rdd.getCheckpointFile().isPresent()); JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); - Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); + assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); } @Test @@ -1456,7 +1486,8 @@ public Integer call(Integer v1) { } }; - Function2 mergeValueFunction = new Function2() { + Function2 mergeValueFunction = + new Function2() { @Override public Integer call(Integer v1, Integer v2) { return v1 + v2; @@ -1467,7 +1498,7 @@ public Integer call(Integer v1, Integer v2) { .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); Map results = combinedRDD.collectAsMap(); ImmutableMap expected = ImmutableMap.of(0, 9, 1, 5, 2, 7); - Assert.assertEquals(expected, results); + assertEquals(expected, results); Partitioner defaultPartitioner = Partitioner.defaultPartitioner( combinedRDD.rdd(), @@ -1482,7 +1513,7 @@ public Integer call(Integer v1, Integer v2) { false, new KryoSerializer(new SparkConf())); results = combinedRDD.collectAsMap(); - Assert.assertEquals(expected, results); + assertEquals(expected, results); } @SuppressWarnings("unchecked") @@ -1503,7 +1534,7 @@ public Tuple2 call(Tuple2 in) { return new Tuple2<>(in._2(), in._1()); } }); - Assert.assertEquals(Arrays.asList( + assertEquals(Arrays.asList( new Tuple2<>(1, 1), new Tuple2<>(0, 2), new Tuple2<>(1, 3), @@ -1525,21 +1556,19 @@ public Tuple2 call(Integer i) { }); List[] parts = rdd1.collectPartitions(new int[] {0}); - Assert.assertEquals(Arrays.asList(1, 2), parts[0]); + assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[] {1, 2}); - Assert.assertEquals(Arrays.asList(3, 4), parts[0]); - Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); + assertEquals(Arrays.asList(3, 4), parts[0]); + assertEquals(Arrays.asList(5, 6, 7), parts[1]); - Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), + assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), - new Tuple2<>(4, 0)), - parts2[0]); - Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), + assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); + assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), parts2[1]); @@ -1553,7 +1582,7 @@ public void countApproxDistinct() { arrayData.add(i % size); } JavaRDD simpleRdd = sc.parallelize(arrayData, 10); - Assert.assertTrue(Math.abs((simpleRdd.countApproxDistinct(0.05) - size) / (size * 1.0)) <= 0.1); + assertTrue(Math.abs((simpleRdd.countApproxDistinct(0.05) - size) / (size * 1.0)) <= 0.1); } @Test @@ -1566,12 +1595,12 @@ public void countApproxDistinctByKey() { } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); - for (Tuple2 resItem : res) { - double count = (double)resItem._1(); - Long resCount = (Long)resItem._2(); - Double error = Math.abs((resCount - count) / count); - Assert.assertTrue(error < 0.1); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); + for (Tuple2 resItem : res) { + double count = resItem._1(); + long resCount = resItem._2(); + double error = Math.abs((resCount - count) / count); + assertTrue(error < 0.1); } } @@ -1601,7 +1630,7 @@ public void collectAsMapAndSerialize() throws Exception { new ObjectOutputStream(bytes).writeObject(map); Map deserializedMap = (Map) new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray())).readObject(); - Assert.assertEquals(1, deserializedMap.get("foo").intValue()); + assertEquals(1, deserializedMap.get("foo").intValue()); } @Test @@ -1619,15 +1648,15 @@ public Tuple2 call(Integer i) { fractions.put(0, 0.5); fractions.put(1, 1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); - Map wrCounts = (Map) (Object) wr.countByKey(); - Assert.assertEquals(2, wrCounts.size()); - Assert.assertTrue(wrCounts.get(0) > 0); - Assert.assertTrue(wrCounts.get(1) > 0); + Map wrCounts = wr.countByKey(); + assertEquals(2, wrCounts.size()); + assertTrue(wrCounts.get(0) > 0); + assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); - Map worCounts = (Map) (Object) wor.countByKey(); - Assert.assertEquals(2, worCounts.size()); - Assert.assertTrue(worCounts.get(0) > 0); - Assert.assertTrue(worCounts.get(1) > 0); + Map worCounts = wor.countByKey(); + assertEquals(2, worCounts.size()); + assertTrue(worCounts.get(0) > 0); + assertTrue(worCounts.get(1) > 0); } @Test @@ -1645,15 +1674,15 @@ public Tuple2 call(Integer i) { fractions.put(0, 0.5); fractions.put(1, 1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); - Map wrExactCounts = (Map) (Object) wrExact.countByKey(); - Assert.assertEquals(2, wrExactCounts.size()); - Assert.assertTrue(wrExactCounts.get(0) == 2); - Assert.assertTrue(wrExactCounts.get(1) == 4); + Map wrExactCounts = wrExact.countByKey(); + assertEquals(2, wrExactCounts.size()); + assertTrue(wrExactCounts.get(0) == 2); + assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); - Map worExactCounts = (Map) (Object) worExact.countByKey(); - Assert.assertEquals(2, worExactCounts.size()); - Assert.assertTrue(worExactCounts.get(0) == 2); - Assert.assertTrue(worExactCounts.get(1) == 4); + Map worExactCounts = worExact.countByKey(); + assertEquals(2, worExactCounts.size()); + assertTrue(worExactCounts.get(0) == 2); + assertTrue(worExactCounts.get(1) == 4); } private static class SomeCustomClass implements Serializable { @@ -1669,8 +1698,9 @@ public void collectUnderlyingScalaRDD() { data.add(new SomeCustomClass()); } JavaRDD rdd = sc.parallelize(data); - SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect(); - Assert.assertEquals(data.size(), collected.length); + SomeCustomClass[] collected = + (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect(); + assertEquals(data.size(), collected.length); } private static final class BuggyMapFunction implements Function { @@ -1687,10 +1717,10 @@ public void collectAsync() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction> future = rdd.collectAsync(); List result = future.get(); - Assert.assertEquals(data, result); - Assert.assertFalse(future.isCancelled()); - Assert.assertTrue(future.isDone()); - Assert.assertEquals(1, future.jobIds().size()); + assertEquals(data, result); + assertFalse(future.isCancelled()); + assertTrue(future.isDone()); + assertEquals(1, future.jobIds().size()); } @Test @@ -1699,11 +1729,11 @@ public void takeAsync() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction> future = rdd.takeAsync(1); List result = future.get(); - Assert.assertEquals(1, result.size()); - Assert.assertEquals((Integer) 1, result.get(0)); - Assert.assertFalse(future.isCancelled()); - Assert.assertTrue(future.isDone()); - Assert.assertEquals(1, future.jobIds().size()); + assertEquals(1, result.size()); + assertEquals((Integer) 1, result.get(0)); + assertFalse(future.isCancelled()); + assertTrue(future.isDone()); + assertEquals(1, future.jobIds().size()); } @Test @@ -1719,9 +1749,9 @@ public void call(Integer integer) { } ); future.get(); - Assert.assertFalse(future.isCancelled()); - Assert.assertTrue(future.isDone()); - Assert.assertEquals(1, future.jobIds().size()); + assertFalse(future.isCancelled()); + assertTrue(future.isDone()); + assertEquals(1, future.jobIds().size()); } @Test @@ -1730,10 +1760,10 @@ public void countAsync() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction future = rdd.countAsync(); long count = future.get(); - Assert.assertEquals(data.size(), count); - Assert.assertFalse(future.isCancelled()); - Assert.assertTrue(future.isDone()); - Assert.assertEquals(1, future.jobIds().size()); + assertEquals(data.size(), count); + assertFalse(future.isCancelled()); + assertTrue(future.isDone()); + assertEquals(1, future.jobIds().size()); } @Test @@ -1747,11 +1777,11 @@ public void call(Integer integer) throws InterruptedException { } }); future.cancel(true); - Assert.assertTrue(future.isCancelled()); - Assert.assertTrue(future.isDone()); + assertTrue(future.isCancelled()); + assertTrue(future.isDone()); try { future.get(2000, TimeUnit.MILLISECONDS); - Assert.fail("Expected future.get() for cancelled job to throw CancellationException"); + fail("Expected future.get() for cancelled job to throw CancellationException"); } catch (CancellationException ignored) { // pass } @@ -1764,37 +1794,11 @@ public void testAsyncActionErrorWrapping() throws Exception { JavaFutureAction future = rdd.map(new BuggyMapFunction()).countAsync(); try { future.get(2, TimeUnit.SECONDS); - Assert.fail("Expected future.get() for failed job to throw ExcecutionException"); + fail("Expected future.get() for failed job to throw ExcecutionException"); } catch (ExecutionException ee) { - Assert.assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!")); - } - Assert.assertTrue(future.isDone()); - } - - - /** - * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue, - * since that's the only artifact where Guava classes have been relocated. - */ - @Test - public void testGuavaOptional() { - // Stop the context created in setUp() and start a local-cluster one, to force usage of the - // assembly. - sc.stop(); - JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,1024]", "JavaAPISuite"); - try { - JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); - JavaRDD> rdd2 = rdd1.map( - new Function>() { - @Override - public Optional call(Integer i) { - return Optional.fromNullable(i); - } - }); - rdd2.collect(); - } finally { - localCluster.stop(); + assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!")); } + assertTrue(future.isDone()); } static class Class1 {} @@ -1804,9 +1808,21 @@ static class Class2 {} public void testRegisterKryoClasses() { SparkConf conf = new SparkConf(); conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class }); - Assert.assertEquals( + assertEquals( Class1.class.getName() + "," + Class2.class.getName(), conf.get("spark.kryo.classesToRegister")); } + @Test + public void testGetPersistentRDDs() { + java.util.Map> cachedRddsMap = sc.getPersistentRDDs(); + assertTrue(cachedRddsMap.isEmpty()); + JavaRDD rdd1 = sc.parallelize(Arrays.asList("a", "b")).setName("RDD1").cache(); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("c", "d")).setName("RDD2").cache(); + cachedRddsMap = sc.getPersistentRDDs(); + assertEquals(2, cachedRddsMap.size()); + assertEquals("RDD1", cachedRddsMap.get(0).name()); + assertEquals("RDD2", cachedRddsMap.get(1).name()); + } + } diff --git a/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java b/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java new file mode 100644 index 0000000000000..4b97c18198c1a --- /dev/null +++ b/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests {@link Optional}. + */ +public class OptionalSuite { + + @Test + public void testEmpty() { + Assert.assertFalse(Optional.empty().isPresent()); + Assert.assertNull(Optional.empty().orNull()); + Assert.assertEquals("foo", Optional.empty().or("foo")); + Assert.assertEquals("foo", Optional.empty().orElse("foo")); + } + + @Test(expected = NullPointerException.class) + public void testEmptyGet() { + Optional.empty().get(); + } + + @Test + public void testAbsent() { + Assert.assertFalse(Optional.absent().isPresent()); + Assert.assertNull(Optional.absent().orNull()); + Assert.assertEquals("foo", Optional.absent().or("foo")); + Assert.assertEquals("foo", Optional.absent().orElse("foo")); + } + + @Test(expected = NullPointerException.class) + public void testAbsentGet() { + Optional.absent().get(); + } + + @Test + public void testOf() { + Assert.assertTrue(Optional.of(1).isPresent()); + Assert.assertNotNull(Optional.of(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).orElse(2)); + } + + @Test(expected = NullPointerException.class) + public void testOfWithNull() { + Optional.of(null); + } + + @Test + public void testOfNullable() { + Assert.assertTrue(Optional.ofNullable(1).isPresent()); + Assert.assertNotNull(Optional.ofNullable(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).orElse(2)); + Assert.assertFalse(Optional.ofNullable(null).isPresent()); + Assert.assertNull(Optional.ofNullable(null).orNull()); + Assert.assertEquals(Integer.valueOf(2), Optional.ofNullable(null).or(2)); + Assert.assertEquals(Integer.valueOf(2), Optional.ofNullable(null).orElse(2)); + } + + @Test + public void testFromNullable() { + Assert.assertTrue(Optional.fromNullable(1).isPresent()); + Assert.assertNotNull(Optional.fromNullable(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).orElse(2)); + Assert.assertFalse(Optional.fromNullable(null).isPresent()); + Assert.assertNull(Optional.fromNullable(null).orNull()); + Assert.assertEquals(Integer.valueOf(2), Optional.fromNullable(null).or(2)); + Assert.assertEquals(Integer.valueOf(2), Optional.fromNullable(null).orElse(2)); + } + +} diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index aa15e792e2b27..3e47bfc274cb1 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,9 +17,6 @@ package org.apache.spark.launcher; -import java.io.BufferedReader; -import java.io.InputStream; -import java.io.InputStreamReader; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -91,7 +88,7 @@ public void testSparkArgumentHandling() throws Exception { @Test public void testChildProcLauncher() throws Exception { SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); - Map env = new HashMap(); + Map env = new HashMap<>(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); SparkLauncher launcher = new SparkLauncher(env) diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index dab7b0592cb4e..127789b632b44 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.memory; -import java.io.IOException; - import org.junit.Assert; import org.junit.Test; @@ -27,39 +25,26 @@ public class TaskMemoryManagerSuite { - class TestMemoryConsumer extends MemoryConsumer { - TestMemoryConsumer(TaskMemoryManager memoryManager) { - super(memoryManager); - } - - @Override - public long spill(long size, MemoryConsumer trigger) throws IOException { - long used = getUsed(); - releaseMemory(used); - return used; - } - - void use(long size) { - acquireMemory(size); - } - - void free(long size) { - releaseMemory(size); - } - } - @Test public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MAX_VALUE, + Long.MAX_VALUE, + 1), + 0); manager.allocatePage(4096, null); // leak memory + Assert.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } @Test public void encodePageNumberAndOffsetOffHeap() { - final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0); + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); // In off-heap mode, an offset is an absolute address that may require more than 51 bits to // encode. This test exercises that corner-case: @@ -72,7 +57,7 @@ public void encodePageNumberAndOffsetOffHeap() { @Test public void encodePageNumberAndOffsetOnHeap() { final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); @@ -88,37 +73,48 @@ public void cooperativeSpilling() { TestMemoryConsumer c1 = new TestMemoryConsumer(manager); TestMemoryConsumer c2 = new TestMemoryConsumer(manager); c1.use(100); - assert(c1.getUsed() == 100); + Assert.assertEquals(100, c1.getUsed()); c2.use(100); - assert(c2.getUsed() == 100); - assert(c1.getUsed() == 0); // spilled + Assert.assertEquals(100, c2.getUsed()); + Assert.assertEquals(0, c1.getUsed()); // spilled c1.use(100); - assert(c1.getUsed() == 100); - assert(c2.getUsed() == 0); // spilled + Assert.assertEquals(100, c1.getUsed()); + Assert.assertEquals(0, c2.getUsed()); // spilled c1.use(50); - assert(c1.getUsed() == 50); // spilled - assert(c2.getUsed() == 0); + Assert.assertEquals(50, c1.getUsed()); // spilled + Assert.assertEquals(0, c2.getUsed()); c2.use(50); - assert(c1.getUsed() == 50); - assert(c2.getUsed() == 50); + Assert.assertEquals(50, c1.getUsed()); + Assert.assertEquals(50, c2.getUsed()); c1.use(100); - assert(c1.getUsed() == 100); - assert(c2.getUsed() == 0); // spilled + Assert.assertEquals(100, c1.getUsed()); + Assert.assertEquals(0, c2.getUsed()); // spilled c1.free(20); - assert(c1.getUsed() == 80); + Assert.assertEquals(80, c1.getUsed()); c2.use(10); - assert(c1.getUsed() == 80); - assert(c2.getUsed() == 10); + Assert.assertEquals(80, c1.getUsed()); + Assert.assertEquals(10, c2.getUsed()); c2.use(100); - assert(c2.getUsed() == 100); - assert(c1.getUsed() == 0); // spilled + Assert.assertEquals(100, c2.getUsed()); + Assert.assertEquals(0, c1.getUsed()); // spilled c1.free(0); c2.free(100); - assert(manager.cleanUpAllAllocatedMemory() == 0); + Assert.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void offHeapConfigurationBackwardsCompatibility() { + // Tests backwards-compatibility with the old `spark.unsafe.offHeap` configuration, which + // was deprecated in Spark 1.6 and replaced by `spark.memory.offHeap.enabled` (see SPARK-12251). + final SparkConf conf = new SparkConf() + .set("spark.unsafe.offHeap", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); + Assert.assertSame(MemoryMode.OFF_HEAP, manager.tungstenMemoryMode); } } diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java new file mode 100644 index 0000000000000..e6e16fff80401 --- /dev/null +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import java.io.IOException; + +public class TestMemoryConsumer extends MemoryConsumer { + public TestMemoryConsumer(TaskMemoryManager memoryManager) { + super(memoryManager); + } + + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + long used = getUsed(); + free(used); + return used; + } + + void use(long size) { + long got = taskMemoryManager.acquireExecutionMemory( + size, + taskMemoryManager.tungstenMemoryMode, + this); + used += got; + } + + void free(long size) { + used -= size; + taskMemoryManager.releaseExecutionMemory( + size, + taskMemoryManager.tungstenMemoryMode, + this); + } +} + + diff --git a/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java index 3d50ab4fabe42..8aa0636700991 100644 --- a/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java +++ b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java @@ -21,7 +21,6 @@ import java.io.OutputStream; import java.nio.ByteBuffer; -import scala.Option; import scala.reflect.ClassTag; diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 9a43f1f3a9235..fe5abc5c23049 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -35,7 +35,7 @@ public class PackedRecordPointerSuite { @Test public void heap() throws IOException { - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128, null); @@ -54,7 +54,9 @@ public void heap() throws IOException { @Test public void offHeap() throws IOException { - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true"); + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "10000"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128, null); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 2293b1bbc113e..4cd3600df1c29 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.sort; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Random; @@ -25,24 +26,30 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.memory.TaskMemoryManager; public class ShuffleInMemorySorterSuite { + final TestMemoryManager memoryManager = + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); + private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); - return new String(strBytes); + return new String(strBytes, StandardCharsets.UTF_8); } @Test public void testSortingEmptyInput() { - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100); final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); - assert(!iter.hasNext()); + Assert.assertFalse(iter.hasNext()); } @Test @@ -58,19 +65,22 @@ public void testBasicSorting() throws Exception { "Lychee", "Mango" }; - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter long position = dataPage.getBaseOffset(); for (String str : dataToSort) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2)); + } final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); - final byte[] strBytes = str.getBytes("utf-8"); + final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8); Platform.putInt(baseObject, position, strBytes.length); position += 4; Platform.copyMemory( @@ -104,10 +114,13 @@ public void testBasicSorting() throws Exception { @Test public void testSortingManyNumbers() throws Exception { - ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2)); + } numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); sorter.insertRecord(0, numbersToSort[i]); } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 4763395d7d401..30750b1bf1980 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -67,9 +67,10 @@ public class UnsafeShuffleWriterSuite { File mergedOutputFile; File tempDir; long[] partitionSizesInMergedFile; - final LinkedList spillFilesCreated = new LinkedList(); + final LinkedList spillFilesCreated = new LinkedList<>(); SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); + final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf()); TaskMetrics taskMetrics; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -108,10 +109,10 @@ public void setUp() throws IOException { spillFilesCreated.clear(); conf = new SparkConf() .set("spark.buffer.pageSize", "1m") - .set("spark.unsafe.offHeap", "false"); + .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); memoryManager = new TestMemoryManager(conf); - taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -130,48 +131,24 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (Integer) args[3], new CompressStream(), false, - (ShuffleWriteMetrics) args[4] + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( - new Answer() { - @Override - public InputStream answer(InvocationOnMock invocation) throws Throwable { - assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); - InputStream is = (InputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); - } else { - return is; - } - } - } - ); - - when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( - new Answer() { - @Override - public OutputStream answer(InvocationOnMock invocation) throws Throwable { - assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); - OutputStream os = (OutputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); - } else { - return os; - } - } - } - ); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + File tmp = (File) invocationOnMock.getArguments()[3]; + mergedOutputFile.delete(); + tmp.renameTo(mergedOutputFile); return null; } - }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); + }).when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer( new Answer>() { @@ -186,20 +163,18 @@ public Tuple2 answer( }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); - when(taskContext.internalMetricsToAccumulators()).thenReturn(null); - - when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); + when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); } private UnsafeShuffleWriter createWriter( boolean transferToEnabled) throws IOException { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); - return new UnsafeShuffleWriter( + return new UnsafeShuffleWriter<>( blockManager, shuffleBlockResolver, taskMemoryManager, - new SerializedShuffleHandle(0, 1, shuffleDep), + new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf @@ -214,7 +189,7 @@ private void assertSpillFilesWereCleanedUp() { } private List> readRecordsFromFile() throws IOException { - final ArrayList> recordsList = new ArrayList>(); + final ArrayList> recordsList = new ArrayList<>(); long startOffset = 0; for (int i = 0; i < NUM_PARTITITONS; i++) { final long partitionSize = partitionSizesInMergedFile[i]; @@ -249,7 +224,7 @@ public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { createWriter(false).stop(false); } - class PandaException extends RuntimeException { + static class PandaException extends RuntimeException { } @Test(expected=PandaException.class) @@ -274,8 +249,8 @@ public void writeEmptyIterator() throws Exception { assertTrue(mapStatus.isDefined()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); - assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten()); - assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().recordsWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().bytesWritten()); assertEquals(0, taskMetrics.diskBytesSpilled()); assertEquals(0, taskMetrics.memoryBytesSpilled()); } @@ -283,8 +258,7 @@ public void writeEmptyIterator() throws Exception { @Test public void writeWithoutSpilling() throws Exception { // In this example, each partition should have exactly one record: - final ArrayList> dataToWrite = - new ArrayList>(); + final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < NUM_PARTITITONS; i++) { dataToWrite.add(new Tuple2(i, i)); } @@ -306,10 +280,10 @@ public void writeWithoutSpilling() throws Exception { HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertEquals(0, taskMetrics.diskBytesSpilled()); assertEquals(0, taskMetrics.memoryBytesSpilled()); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } private void testMergingSpills( @@ -322,8 +296,7 @@ private void testMergingSpills( conf.set("spark.shuffle.compress", "false"); } final UnsafeShuffleWriter writer = createWriter(transferToEnabled); - final ArrayList> dataToWrite = - new ArrayList>(); + final ArrayList> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { dataToWrite.add(new Tuple2(i, i)); } @@ -349,11 +322,11 @@ private void testMergingSpills( assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @Test @@ -400,7 +373,7 @@ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { public void writeEnoughDataToTriggerSpill() throws Exception { memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES); final UnsafeShuffleWriter writer = createWriter(false); - final ArrayList> dataToWrite = new ArrayList>(); + final ArrayList> dataToWrite = new ArrayList<>(); final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10]; for (int i = 0; i < 10 + 1; i++) { dataToWrite.add(new Tuple2(i, bigByteArray)); @@ -411,11 +384,11 @@ public void writeEnoughDataToTriggerSpill() throws Exception { readRecordsFromFile(); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @Test @@ -423,7 +396,7 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exce memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); - for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { + for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE + 1; i++) { dataToWrite.add(new Tuple2(i, i)); } writer.write(dataToWrite.iterator()); @@ -432,18 +405,17 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exce readRecordsFromFile(); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } @Test public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception { final UnsafeShuffleWriter writer = createWriter(false); - final ArrayList> dataToWrite = - new ArrayList>(); + final ArrayList> dataToWrite = new ArrayList<>(); final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; new Random(42).nextBytes(bytes); dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); @@ -458,7 +430,7 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception @Test public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { final UnsafeShuffleWriter writer = createWriter(false); - final ArrayList> dataToWrite = new ArrayList>(); + final ArrayList> dataToWrite = new ArrayList<>(); dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(new byte[1]))); // We should be able to write a record that's right _at_ the max record size final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4]; @@ -495,7 +467,7 @@ public void testPeakMemoryUsed() throws Exception { taskMemoryManager = spy(taskMemoryManager); when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); final UnsafeShuffleWriter writer = - new UnsafeShuffleWriter( + new UnsafeShuffleWriter<>( blockManager, shuffleBlockResolver, taskMemoryManager, diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 92bd45e5fa241..84b82f5a4742c 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -19,7 +19,6 @@ import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.*; @@ -39,19 +38,20 @@ import org.apache.spark.SparkConf; import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.util.Utils; import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; @@ -64,9 +64,12 @@ public abstract class AbstractBytesToBytesMapSuite { private TestMemoryManager memoryManager; private TaskMemoryManager taskMemoryManager; - private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + private SerializerManager serializerManager = new SerializerManager( + new JavaSerializer(new SparkConf()), + new SparkConf().set("spark.shuffle.spill.compress", "false")); + private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes - final LinkedList spillFilesCreated = new LinkedList(); + final LinkedList spillFilesCreated = new LinkedList<>(); File tempDir; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -83,16 +86,22 @@ public OutputStream apply(OutputStream stream) { public void setup() { memoryManager = new TestMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator())); + new SparkConf() + .set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator()) + .set("spark.memory.offHeap.size", "256mb") + .set("spark.shuffle.spill.compress", "false") + .set("spark.shuffle.compress", "false")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); spillFilesCreated.clear(); MockitoAnnotations.initMocks(this); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + when(diskBlockManager.createTempLocalBlock()).thenAnswer( + new Answer>() { @Override - public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + public Tuple2 answer(InvocationOnMock invocationOnMock) + throws Throwable { TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); File file = File.createTempFile("spillFile", ".spill", tempDir); spillFilesCreated.add(file); @@ -115,12 +124,11 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (Integer) args[3], new CompressStream(), false, - (ShuffleWriteMetrics) args[4] + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); } @After @@ -128,8 +136,8 @@ public void tearDown() { Utils.deleteRecursively(tempDir); tempDir = null; - Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); if (taskMemoryManager != null) { + Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask(); taskMemoryManager = null; Assert.assertEquals(0L, leakedMemory); @@ -138,10 +146,9 @@ public void tearDown() { protected abstract boolean useOffHeapMemoryAllocator(); - private static byte[] getByteArray(MemoryLocation loc, int size) { + private static byte[] getByteArray(Object base, long offset, int size) { final byte[] arr = new byte[size]; - Platform.copyMemory( - loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size); + Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size); return arr; } @@ -159,13 +166,14 @@ private byte[] getRandomByteArray(int numWords) { */ private static boolean arrayEquals( byte[] expected, - MemoryLocation actualAddr, + Object base, + long offset, long actualLengthBytes) { return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( expected, Platform.BYTE_ARRAY_OFFSET, - actualAddr.getBaseObject(), - actualAddr.getBaseOffset(), + base, + offset, expected.length ); } @@ -174,7 +182,7 @@ private static boolean arrayEquals( public void emptyMap() { BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES); try { - Assert.assertEquals(0, map.numElements()); + Assert.assertEquals(0, map.numKeys()); final int keyLengthInWords = 10; final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); @@ -196,7 +204,7 @@ public void setAndRetrieveAKey() { final BytesToBytesMap.Location loc = map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes); Assert.assertFalse(loc.isDefined()); - Assert.assertTrue(loc.putNewKey( + Assert.assertTrue(loc.append( keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, @@ -208,19 +216,23 @@ public void setAndRetrieveAKey() { // reflect the result of this store without us having to call lookup() again on the same key. Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); - Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); - Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + Assert.assertArrayEquals(keyData, + getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, + getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes)); // After calling lookup() the location should still point to the correct data. Assert.assertTrue( map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); - Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); - Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + Assert.assertArrayEquals(keyData, + getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, + getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes)); try { - Assert.assertTrue(loc.putNewKey( + Assert.assertTrue(loc.append( keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, @@ -248,7 +260,7 @@ private void iteratorTestBase(boolean destructive) throws Exception { Assert.assertFalse(loc.isDefined()); // Ensure that we store some zero-length keys if (i % 5 == 0) { - Assert.assertTrue(loc.putNewKey( + Assert.assertTrue(loc.append( null, Platform.LONG_ARRAY_OFFSET, 0, @@ -257,7 +269,7 @@ private void iteratorTestBase(boolean destructive) throws Exception { 8 )); } else { - Assert.assertTrue(loc.putNewKey( + Assert.assertTrue(loc.append( value, Platform.LONG_ARRAY_OFFSET, 8, @@ -279,15 +291,12 @@ private void iteratorTestBase(boolean destructive) throws Exception { while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); - final MemoryLocation keyAddress = loc.getKeyAddress(); - final MemoryLocation valueAddress = loc.getValueAddress(); - final long value = Platform.getLong( - valueAddress.getBaseObject(), valueAddress.getBaseOffset()); + final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset()); final long keyLength = loc.getKeyLength(); if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); @@ -340,7 +349,7 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { KEY_LENGTH ); Assert.assertFalse(loc.isDefined()); - Assert.assertTrue(loc.putNewKey( + Assert.assertTrue(loc.append( key, Platform.LONG_ARRAY_OFFSET, KEY_LENGTH, @@ -353,23 +362,23 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); final Iterator iter = map.iterator(); - final long key[] = new long[KEY_LENGTH / 8]; - final long value[] = new long[VALUE_LENGTH / 8]; + final long[] key = new long[KEY_LENGTH / 8]; + final long[] value = new long[VALUE_LENGTH / 8]; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); Platform.copyMemory( - loc.getKeyAddress().getBaseObject(), - loc.getKeyAddress().getBaseOffset(), + loc.getKeyBase(), + loc.getKeyOffset(), key, Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); Platform.copyMemory( - loc.getValueAddress().getBaseObject(), - loc.getValueAddress().getBaseOffset(), + loc.getValueBase(), + loc.getValueOffset(), value, Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH @@ -393,7 +402,7 @@ public void randomizedStressTest() { final int size = 65536; // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. - final Map expected = new HashMap(); + final Map expected = new HashMap<>(); final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size, PAGE_SIZE_BYTES); try { // Fill the map to 90% full so that we can trigger probing @@ -408,7 +417,7 @@ public void randomizedStressTest() { key.length ); Assert.assertFalse(loc.isDefined()); - Assert.assertTrue(loc.putNewKey( + Assert.assertTrue(loc.append( key, Platform.BYTE_ARRAY_OFFSET, key.length, @@ -421,19 +430,22 @@ public void randomizedStressTest() { Assert.assertTrue(loc.isDefined()); Assert.assertEquals(key.length, loc.getKeyLength()); Assert.assertEquals(value.length, loc.getValueLength()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length)); + Assert.assertTrue( + arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length)); } } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + Assert.assertTrue( + arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength())); + Assert.assertTrue( + arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength())); } } finally { map.free(); @@ -446,7 +458,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() { final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, pageSizeBytes); // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. - final Map expected = new HashMap(); + final Map expected = new HashMap<>(); try { for (int i = 0; i < 1000; i++) { final byte[] key = getRandomByteArray(rand.nextInt(128)); @@ -459,7 +471,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() { key.length ); Assert.assertFalse(loc.isDefined()); - Assert.assertTrue(loc.putNewKey( + Assert.assertTrue(loc.append( key, Platform.BYTE_ARRAY_OFFSET, key.length, @@ -472,18 +484,21 @@ public void randomizedTestWithRecordsLargerThanPageSize() { Assert.assertTrue(loc.isDefined()); Assert.assertEquals(key.length, loc.getKeyLength()); Assert.assertEquals(value.length, loc.getValueLength()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length)); + Assert.assertTrue( + arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length)); } } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); - Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); - Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + Assert.assertTrue( + arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength())); + Assert.assertTrue( + arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength())); } } finally { map.free(); @@ -499,7 +514,7 @@ public void failureToAllocateFirstPage() { final BytesToBytesMap.Location loc = map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0); Assert.assertFalse(loc.isDefined()); - Assert.assertFalse(loc.putNewKey( + Assert.assertFalse(loc.append( emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0)); } finally { map.free(); @@ -520,7 +535,7 @@ public void failureToGrow() { final long[] arr = new long[]{i}; final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); success = - loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); + loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); if (!success) { break; } @@ -534,13 +549,14 @@ public void failureToGrow() { @Test public void spillInIterator() throws IOException { - BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false); + BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false); try { int i; for (i = 0; i < 1024; i++) { final long[] arr = new long[]{i}; final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); - loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); + loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); } BytesToBytesMap.MapIterator iter = map.iterator(); for (i = 0; i < 100; i++) { @@ -570,6 +586,44 @@ public void spillInIterator() throws IOException { } } + @Test + public void multipleValuesForSameKey() { + BytesToBytesMap map = + new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false); + try { + int i; + for (i = 0; i < 1024; i++) { + final long[] arr = new long[]{i}; + map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8) + .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); + } + assert map.numKeys() == 1024; + assert map.numValues() == 1024; + for (i = 0; i < 1024; i++) { + final long[] arr = new long[]{i}; + map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8) + .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); + } + assert map.numKeys() == 1024; + assert map.numValues() == 2048; + for (i = 0; i < 1024; i++) { + final long[] arr = new long[]{i}; + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + assert loc.isDefined(); + assert loc.nextValue(); + assert !loc.nextValue(); + } + BytesToBytesMap.MapIterator iter = map.iterator(); + for (i = 0; i < 2048; i++) { + assert iter.hasNext(); + final BytesToBytesMap.Location loc = iter.next(); + assert loc.isDefined(); + } + } finally { + map.free(); + } + } + @Test public void initialCapacityBoundsChecking() { try { @@ -592,7 +646,7 @@ public void initialCapacityBoundsChecking() { @Test public void testPeakMemoryUsed() { - final long recordLengthBytes = 24; + final long recordLengthBytes = 32; final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes); @@ -606,7 +660,7 @@ public void testPeakMemoryUsed() { try { for (long i = 0; i < numRecordsPerPage * 10; i++) { final long[] value = new long[]{i}; - map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey( + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).append( value, Platform.LONG_ARRAY_OFFSET, 8, diff --git a/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java b/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java index 45772b6d3c20d..e884b1bc123b8 100644 --- a/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java +++ b/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java @@ -76,7 +76,7 @@ private static int[] createArray(List runs, int length) { * @param length The sum of all run lengths that will be added to runs. */ private static List runsJDKWorstCase(int minRun, int length) { - List runs = new ArrayList(); + List runs = new ArrayList<>(); long runningTotal = 0, Y = minRun + 4, X = minRun; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index cfead0e5924b8..a2253d8559640 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -19,7 +19,6 @@ import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.io.OutputStream; import java.util.Arrays; import java.util.LinkedList; @@ -43,23 +42,27 @@ import org.apache.spark.executor.TaskMetrics; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.JavaSerializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; public class UnsafeExternalSorterSuite { - final LinkedList spillFilesCreated = new LinkedList(); + final LinkedList spillFilesCreated = new LinkedList<>(); final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final SerializerManager serializerManager = new SerializerManager( + new JavaSerializer(new SparkConf()), + new SparkConf().set("spark.shuffle.spill.compress", "false")); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @Override @@ -80,7 +83,6 @@ public int compare( } }; - SparkConf sparkConf; File tempDir; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @@ -99,15 +101,16 @@ public OutputStream apply(OutputStream stream) { @Before public void setUp() { MockitoAnnotations.initMocks(this); - sparkConf = new SparkConf(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + when(diskBlockManager.createTempLocalBlock()).thenAnswer( + new Answer>() { @Override - public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + public Tuple2 answer(InvocationOnMock invocationOnMock) + throws Throwable { TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); File file = File.createTempFile("spillFile", ".spill", tempDir); spillFilesCreated.add(file); @@ -130,12 +133,11 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (Integer) args[3], new CompressStream(), false, - (ShuffleWriteMetrics) args[4] + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] ); } }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) - .then(returnsSecondArg()); } @After @@ -171,6 +173,7 @@ private UnsafeExternalSorter newSorter() throws IOException { return UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, @@ -322,23 +325,23 @@ public void forcedSpillingWithReadIterator() throws Exception { record[0] = (long) i; sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); } - assert(sorter.getNumberOfAllocatedPages() >= 2); + assertTrue(sorter.getNumberOfAllocatedPages() >= 2); UnsafeExternalSorter.SpillableIterator iter = (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator(); int lastv = 0; for (int i = 0; i < n / 3; i++) { iter.hasNext(); iter.loadNext(); - assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); + assertTrue(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); lastv = i; } - assert(iter.spill() > 0); - assert(iter.spill() == 0); - assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == lastv); + assertTrue(iter.spill() > 0); + assertEquals(0, iter.spill()); + assertTrue(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == lastv); for (int i = n / 3; i < n; i++) { iter.hasNext(); iter.loadNext(); - assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); + assertEquals(i, Platform.getLong(iter.getBaseObject(), iter.getBaseOffset())); } sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); @@ -354,15 +357,47 @@ public void forcedSpillingWithNotReadIterator() throws Exception { record[0] = (long) i; sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); } - assert(sorter.getNumberOfAllocatedPages() >= 2); + assertTrue(sorter.getNumberOfAllocatedPages() >= 2); UnsafeExternalSorter.SpillableIterator iter = (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator(); - assert(iter.spill() > 0); - assert(iter.spill() == 0); + assertTrue(iter.spill() > 0); + assertEquals(0, iter.spill()); for (int i = 0; i < n; i++) { iter.hasNext(); iter.loadNext(); - assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); + assertEquals(i, Platform.getLong(iter.getBaseObject(), iter.getBaseOffset())); + } + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void forcedSpillingWithoutComparator() throws Exception { + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, + blockManager, + serializerManager, + taskContext, + null, + null, + /* initialSize */ 1024, + pageSizeBytes); + long[] record = new long[100]; + int recordSize = record.length * 8; + int n = (int) pageSizeBytes / recordSize * 3; + int batch = n / 4; + for (int i = 0; i < n; i++) { + record[0] = (long) i; + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); + if (i % batch == batch - 1) { + sorter.spill(); + } + } + UnsafeSorterIterator iter = sorter.getIterator(); + for (int i = 0; i < n; i++) { + iter.hasNext(); + iter.loadNext(); + assertEquals(i, Platform.getLong(iter.getBaseObject(), iter.getBaseOffset())); } sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); @@ -376,6 +411,7 @@ public void testPeakMemoryUsed() throws Exception { final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, @@ -390,7 +426,6 @@ public void testPeakMemoryUsed() throws Exception { for (int i = 0; i < numRecordsPerPage * 10; i++) { insertNumber(sorter, i); newPeakMemory = sorter.getPeakMemoryUsedBytes(); - // The first page is pre-allocated on instantiation if (i % numRecordsPerPage == 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 642f6585f8a15..f90214fffd396 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -17,12 +17,15 @@ package org.apache.spark.util.collection.unsafe.sort; +import java.nio.charset.StandardCharsets; import java.util.Arrays; +import org.junit.Assert; import org.junit.Test; import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -39,19 +42,21 @@ public class UnsafeInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { final byte[] strBytes = new byte[length]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); - return new String(strBytes); + return new String(strBytes, StandardCharsets.UTF_8); } @Test public void testSortingEmptyInput() { - final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0), + final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, + memoryManager, mock(RecordComparator.class), mock(PrefixComparator.class), 100); final UnsafeSorterIterator iter = sorter.getSortedIterator(); - assert(!iter.hasNext()); + Assert.assertFalse(iter.hasNext()); } @Test @@ -68,13 +73,14 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { "Mango" }; final TaskMemoryManager memoryManager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); // Write the records into the data page: long position = dataPage.getBaseOffset(); for (String str : dataToSort) { - final byte[] strBytes = str.getBytes("utf-8"); + final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8); Platform.putInt(baseObject, position, strBytes.length); position += 4; Platform.copyMemory( @@ -102,11 +108,14 @@ public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; } }; - UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, - prefixComparator, dataToSort.length); + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, + recordComparator, prefixComparator, dataToSort.length); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { + if (!sorter.hasSpaceForAnotherRecord()) { + sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2 * 2)); + } // position now points to the start of a record (which holds its length). final int recordLength = Platform.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java deleted file mode 100644 index e38bc38949d7c..0000000000000 --- a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark; - -import org.apache.spark.TaskContext; -import org.apache.spark.util.TaskCompletionListener; - - -/** - * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and - * TaskContext is Java friendly. - */ -public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { - - @Override - public void onTaskCompletion(TaskContext context) { - context.isCompleted(); - context.isInterrupted(); - context.stageId(); - context.partitionId(); - context.isRunningLocally(); - context.addTaskCompletionListener(this); - } -} diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 4a918f725dc91..94f5805853e1e 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -18,6 +18,8 @@ package test.org.apache.spark; import org.apache.spark.TaskContext; +import org.apache.spark.util.TaskCompletionListener; +import org.apache.spark.util.TaskFailureListener; /** * Something to make sure that TaskContext can be used in Java. @@ -29,13 +31,39 @@ public static void test() { tc.isCompleted(); tc.isInterrupted(); - tc.isRunningLocally(); tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl()); + tc.addTaskFailureListener(new JavaTaskFailureListenerImpl()); tc.attemptNumber(); tc.partitionId(); tc.stageId(); tc.taskAttemptId(); } + + /** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ + static class JavaTaskCompletionListenerImpl implements TaskCompletionListener { + @Override + public void onTaskCompletion(TaskContext context) { + context.isCompleted(); + context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.addTaskCompletionListener(this); + } + } + + /** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ + static class JavaTaskFailureListenerImpl implements TaskFailureListener { + @Override + public void onTaskFailure(TaskContext context, Throwable error) { + } + } + } diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index d575bf2f284b9..1a13233133b1e 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -2,8 +2,13 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -12,14 +17,24 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -28,14 +43,24 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -43,8 +68,13 @@ "id" : "local-1425081759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] @@ -52,8 +82,13 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -61,9 +96,14 @@ "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 31ac9beea8788..8f8067f86d57f 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -64,4 +73,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index d575bf2f284b9..1a13233133b1e 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -2,8 +2,13 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -12,14 +17,24 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -28,14 +43,24 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -43,8 +68,13 @@ "id" : "local-1425081759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] @@ -52,8 +82,13 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -61,9 +96,14 @@ "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index cb622e147249e..efc865919b0d7 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -1,14 +1,18 @@ [ { "id" : "", "hostPort" : "localhost:57971", - "rddBlocks" : 8, - "memoryUsed" : 28000128, + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, "activeTasks" : 0, "failedTasks" : 1, "completedTasks" : 31, "totalTasks" : 32, "totalDuration" : 8820, + "totalGCTime" : 352, "totalInputBytes" : 28000288, "totalShuffleRead" : 0, "totalShuffleWrite" : 13180, diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index bff6a4f69d077..08b692eda8028 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -20,4 +23,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json index 483632a3956ed..eacf04b9016ac 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json @@ -2,9 +2,14 @@ "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json index 4b85690fd9199..adad25bf17fd5 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json @@ -2,8 +2,13 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] @@ -11,9 +16,14 @@ "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser" : "irashid", "completed" : true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 15c2de8ef99ea..a658909088a4a 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -2,8 +2,13 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, "sparkUser" : "irashid", "completed" : true } ] @@ -12,14 +17,24 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, "sparkUser" : "irashid", "completed" : true } ] @@ -28,14 +43,24 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] @@ -44,10 +69,15 @@ "name": "Spark shell", "attempts": [ { + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTime": "2015-02-28T00:02:38.277GMT", "endTime": "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, "sparkUser": "irashid", "completed": true } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json index 07489ad96414a..0217facad9ded 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json @@ -2,9 +2,14 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, "sparkUser" : "irashid", "completed" : true } ] -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json index 8f3d7160c723f..b20a26648e430 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json @@ -3,15 +3,25 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, "sparkUser" : "irashid", "completed" : true } ] -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 111cb8163eb3d..b07011d4f113f 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index ef339f89afa45..2f71520549e1f 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json index f79a31022d214..8878e547a7984 100644 --- a/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json @@ -1,9 +1 @@ -[ { - "id" : 0, - "name" : "0", - "numPartitions" : 8, - "numCachedPartitions" : 8, - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 28000128, - "diskUsed" : 0 -} ] \ No newline at end of file +[ ] \ No newline at end of file diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 056fac7088594..5b957ed549556 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -72,6 +81,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -86,4 +98,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index 79ccacd309693..afa425f8c27bb 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -24,4 +27,4 @@ "name" : "my counter", "value" : "5050" } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 32d5731676ad5..12665a152c9ec 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -239,4 +242,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index a54d27de91ed2..fb9d9851cb4de 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -33,5 +33,4 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%t: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981759269 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981759269/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1422981759269 diff --git a/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981759269/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1422981780767 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981780767/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1422981780767 diff --git a/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1422981780767/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1422981780767/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1425081759269 similarity index 100% rename from core/src/test/resources/spark-events/local-1425081759269/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1425081759269 diff --git a/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1425081759269/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1425081759269/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426533911241 similarity index 100% rename from core/src/test/resources/spark-events/local-1426533911241/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1426533911241 diff --git a/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426533911241/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426533911241/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 b/core/src/test/resources/spark-events/local-1426633911242 similarity index 100% rename from core/src/test/resources/spark-events/local-1426633911242/EVENT_LOG_1 rename to core/src/test/resources/spark-events/local-1426633911242 diff --git a/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE b/core/src/test/resources/spark-events/local-1426633911242/APPLICATION_COMPLETE deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 b/core/src/test/resources/spark-events/local-1426633911242/SPARK_VERSION_1.2.0 deleted file mode 100755 index e69de29bb2d1d..0000000000000 diff --git a/core/src/test/resources/spark.keystore b/core/src/test/resources/spark.keystore new file mode 100644 index 0000000000000..f30716b57b302 Binary files /dev/null and b/core/src/test/resources/spark.keystore differ diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 5b84acf40be4e..37879d11caec4 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -17,18 +17,32 @@ package org.apache.spark +import java.util.Properties +import java.util.concurrent.Semaphore +import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference +import scala.util.control.NonFatal import org.scalatest.Matchers import org.scalatest.exceptions.TestFailedException import org.apache.spark.scheduler._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { - import InternalAccumulator._ + import AccumulatorParam._ + + override def afterEach(): Unit = { + try { + Accumulators.clear() + } finally { + super.afterEach() + } + } implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { @@ -45,7 +59,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } } - test ("basic accumulation"){ + test ("basic accumulation") { sc = new SparkContext("local", "test") val acc : Accumulator[Int] = sc.accumulator(0) @@ -59,7 +73,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex longAcc.value should be (210L + maxInt * 20) } - test ("value not assignable from tasks") { + test("value not assignable from tasks") { sc = new SparkContext("local", "test") val acc : Accumulator[Int] = sc.accumulator(0) @@ -84,7 +98,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } } - test ("value not readable in tasks") { + test("value not readable in tasks") { val maxI = 1000 for (nThreads <- List(1, 10)) { // test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") @@ -159,193 +173,159 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(!Accumulators.originals.get(accId).isDefined) } - test("internal accumulators in TaskContext") { + test("get accum") { sc = new SparkContext("local", "test") - val accums = InternalAccumulator.create(sc) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) - val internalMetricsToAccums = taskContext.internalMetricsToAccumulators - val collectedInternalAccums = taskContext.collectInternalAccumulators() - val collectedAccums = taskContext.collectAccumulators() - assert(internalMetricsToAccums.size > 0) - assert(internalMetricsToAccums.values.forall(_.isInternal)) - assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR)) - val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR) - assert(collectedInternalAccums.size === internalMetricsToAccums.size) - assert(collectedInternalAccums.size === collectedAccums.size) - assert(collectedInternalAccums.contains(testAccum.id)) - assert(collectedAccums.contains(testAccum.id)) - } + // Don't register with SparkContext for cleanup + var acc = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true) + val accId = acc.id + val ref = WeakReference(acc) + assert(ref.get.isDefined) + Accumulators.register(ref.get.get) - test("internal accumulators in a stage") { - val listener = new SaveInfoListener - val numPartitions = 10 - sc = new SparkContext("local", "test") - sc.addSparkListener(listener) - // Have each task add 1 to the internal accumulator - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions) - // The accumulator values should be merged in the stage - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - assert(stageAccum.value.toLong === numPartitions) - // The accumulator should be updated locally on each task - val taskAccumValues = taskInfos.map { taskInfo => - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - taskAccum.value.toLong - } - // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + // Remove the explicit reference to it and allow weak reference to get garbage collected + acc = null + System.gc() + assert(ref.get.isEmpty) + + // Getting a garbage collected accum should throw error + intercept[IllegalAccessError] { + Accumulators.get(accId) } - rdd.count() + + // Getting a normal accumulator. Note: this has to be separate because referencing an + // accumulator above in an `assert` would keep it from being garbage collected. + val acc2 = new Accumulable[Long, Long](0L, LongAccumulatorParam, None, true, true) + Accumulators.register(acc2) + assert(Accumulators.get(acc2.id) === Some(acc2)) + + // Getting an accumulator that does not exist should return None + assert(Accumulators.get(100000).isEmpty) } - test("internal accumulators in multiple stages") { - val listener = new SaveInfoListener - val numPartitions = 10 - sc = new SparkContext("local", "test") - sc.addSparkListener(listener) - // Each stage creates its own set of internal accumulators so the - // values for the same metric should not be mixed up across stages - val rdd = sc.parallelize(1 to 100, numPartitions) - .map { i => (i, i) } - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 - iter - } - .reduceByKey { case (x, y) => x + y } - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10 - iter - } - .repartition(numPartitions * 2) - .mapPartitions { iter => - TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - // We ran 3 stages, and the accumulator values should be distinct - val stageInfos = listener.getCompletedStageInfos - assert(stageInfos.size === 3) - val (firstStageAccum, secondStageAccum, thirdStageAccum) = - (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR), - findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR), - findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)) - assert(firstStageAccum.value.toLong === numPartitions) - assert(secondStageAccum.value.toLong === numPartitions * 10) - assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) - } - rdd.count() + test("only external accums are automatically registered") { + val accEx = new Accumulator(0, IntAccumulatorParam, Some("external"), internal = false) + val accIn = new Accumulator(0, IntAccumulatorParam, Some("internal"), internal = true) + assert(!accEx.isInternal) + assert(accIn.isInternal) + assert(Accumulators.get(accEx.id).isDefined) + assert(Accumulators.get(accIn.id).isEmpty) } - test("internal accumulators in fully resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + test("copy") { + val acc1 = new Accumulable[Long, Long](456L, LongAccumulatorParam, Some("x"), true, false) + val acc2 = acc1.copy() + assert(acc1.id === acc2.id) + assert(acc1.value === acc2.value) + assert(acc1.name === acc2.name) + assert(acc1.isInternal === acc2.isInternal) + assert(acc1.countFailedValues === acc2.countFailedValues) + assert(acc1 !== acc2) + // Modifying one does not affect the other + acc1.add(44L) + assert(acc1.value === 500L) + assert(acc2.value === 456L) + acc2.add(144L) + assert(acc1.value === 500L) + assert(acc2.value === 600L) } - test("internal accumulators in partially resubmitted stages") { - testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + test("register multiple accums with same ID") { + // Make sure these are internal accums so we don't automatically register them already + val acc1 = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true) + val acc2 = acc1.copy() + assert(acc1 !== acc2) + assert(acc1.id === acc2.id) + assert(Accumulators.originals.isEmpty) + assert(Accumulators.get(acc1.id).isEmpty) + Accumulators.register(acc1) + Accumulators.register(acc2) + // The second one does not override the first one + assert(Accumulators.originals.size === 1) + assert(Accumulators.get(acc1.id) === Some(acc1)) } - /** - * Return the accumulable info that matches the specified name. - */ - private def findAccumulableInfo( - accums: Iterable[AccumulableInfo], - name: String): AccumulableInfo = { - accums.find { a => a.name == name }.getOrElse { - throw new TestFailedException(s"internal accumulator '$name' not found", 0) - } + test("string accumulator param") { + val acc = new Accumulator("", StringAccumulatorParam, Some("darkness")) + assert(acc.value === "") + acc.setValue("feeds") + assert(acc.value === "feeds") + acc.add("your") + assert(acc.value === "your") // value is overwritten, not concatenated + acc += "soul" + assert(acc.value === "soul") + acc ++= "with" + assert(acc.value === "with") + acc.merge("kindness") + assert(acc.value === "kindness") } - /** - * Test whether internal accumulators are merged properly if some tasks fail. - */ - private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { - val listener = new SaveInfoListener - val numPartitions = 10 - val numFailedPartitions = (0 until numPartitions).count(failCondition) - // This says use 1 core and retry tasks up to 2 times - sc = new SparkContext("local[1, 2]", "test") - sc.addSparkListener(listener) - val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => - val taskContext = TaskContext.get() - taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 - // Fail the first attempts of a subset of the tasks - if (failCondition(i) && taskContext.attemptNumber() == 0) { - throw new Exception("Failing a task intentionally.") - } - iter - } - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { _ => - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - // We should not double count values in the merged accumulator - assert(stageAccum.value.toLong === numPartitions) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should always be 1 - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - Some(taskAccum.value.toLong) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None - } - } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) - } - rdd.count() + test("list accumulator param") { + val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers")) + assert(acc.value === Seq.empty[Int]) + acc.add(Seq(1, 2)) + assert(acc.value === Seq(1, 2)) + acc += Seq(3, 4) + assert(acc.value === Seq(1, 2, 3, 4)) + acc ++= Seq(5, 6) + assert(acc.value === Seq(1, 2, 3, 4, 5, 6)) + acc.merge(Seq(7, 8)) + assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8)) + acc.setValue(Seq(9, 10)) + assert(acc.value === Seq(9, 10)) + } + + test("value is reset on the executors") { + val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"), internal = false) + val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"), internal = false) + val externalAccums = Seq(acc1, acc2) + val internalAccums = InternalAccumulator.createAll() + // Set some values; these should not be observed later on the "executors" + acc1.setValue(10) + acc2.setValue(20L) + internalAccums + .find(_.name == Some(InternalAccumulator.TEST_ACCUM)) + .get.asInstanceOf[Accumulator[Long]] + .setValue(30L) + // Simulate the task being serialized and sent to the executors. + val dummyTask = new DummyTask(internalAccums, externalAccums) + val serInstance = new JavaSerializer(new SparkConf).newInstance() + val taskSer = Task.serializeWithDependencies( + dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance) + // Now we're on the executors. + // Deserialize the task and assert that its accumulators are zero'ed out. + val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer) + val taskDeser = serInstance.deserialize[DummyTask]( + taskBytes, Thread.currentThread.getContextClassLoader) + // Assert that executors see only zeros + taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) } + taskDeser.internalAccums.foreach { a => assert(a.localValue == a.zero) } } } private[spark] object AccumulatorSuite { + import InternalAccumulator._ + /** - * Run one or more Spark jobs and verify that the peak execution memory accumulator - * is updated afterwards. + * Run one or more Spark jobs and verify that in at least one job the peak execution memory + * accumulator is updated afterwards. */ def verifyPeakExecutionMemorySet( sc: SparkContext, testName: String)(testBody: => Unit): Unit = { val listener = new SaveInfoListener sc.addSparkListener(listener) - // Register asserts in job completion callback to avoid flakiness - listener.registerJobCompletionCallback { jobId => - if (jobId == 0) { - // The first job is a dummy one to verify that the accumulator does not already exist - val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) - assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) - } else { - // In the subsequent jobs, verify that peak execution memory is updated - val accum = listener.getCompletedStageInfos - .flatMap(_.accumulables.values) - .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) - .getOrElse { - throw new TestFailedException( - s"peak execution memory accumulator not set in '$testName'", 0) - } - assert(accum.value.toLong > 0) - } - } - // Run the jobs - sc.parallelize(1 to 10).count() testBody + // wait until all events have been processed before proceeding to assert things + sc.listenerBus.waitUntilEmpty(10 * 1000) + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + val isSet = accums.exists { a => + a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L) + } + if (!isSet) { + throw new TestFailedException(s"peak execution memory accumulator not set in '$testName'", 0) + } } } @@ -353,21 +333,57 @@ private[spark] object AccumulatorSuite { * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. */ private class SaveInfoListener extends SparkListener { - private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] - private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] - private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID + type StageId = Int + type StageAttemptId = Int + + private val completedStageInfos = new ArrayBuffer[StageInfo] + private val completedTaskInfos = + new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]] + + // Callback to call when a job completes. Parameter is job ID. + @GuardedBy("this") + private var jobCompletionCallback: () => Unit = null + private val jobCompletionSem = new Semaphore(0) + private var exception: Throwable = null def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq - def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq + def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] = + completedTaskInfos.getOrElse((stageId, stageAttemptId), Seq.empty[TaskInfo]) + + /** + * If `jobCompletionCallback` is set, block until the next call has finished. + * If the callback failed with an exception, throw it. + */ + def awaitNextJobCompletion(): Unit = { + if (jobCompletionCallback != null) { + jobCompletionSem.acquire() + if (exception != null) { + exception = null + throw exception + } + } + } - /** Register a callback to be called on job end. */ - def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { + /** + * Register a callback to be called on job end. + * A call to this should be followed by [[awaitNextJobCompletion]]. + */ + def registerJobCompletionCallback(callback: () => Unit): Unit = { jobCompletionCallback = callback } override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { if (jobCompletionCallback != null) { - jobCompletionCallback(jobEnd.jobId) + try { + jobCompletionCallback() + } catch { + // Store any exception thrown here so we can throw them later in the main thread. + // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test. + case NonFatal(e) => exception = e + } finally { + jobCompletionSem.release() + } } } @@ -376,6 +392,18 @@ private class SaveInfoListener extends SparkListener { } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { - completedTaskInfos += taskEnd.taskInfo + completedTaskInfos.getOrElseUpdate( + (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo } } + + +/** + * A dummy [[Task]] that contains internal and external [[Accumulator]]s. + */ +private[spark] class DummyTask( + val internalAccums: Seq[Accumulator[_]], + val externalAccums: Seq[Accumulator[_]]) + extends Task[Int](0, 0, 0, internalAccums, new Properties) { + override def runTask(c: TaskContext): Int = 1 +} diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala deleted file mode 100644 index cb8bd04e496a7..0000000000000 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.mockito.Mockito._ -import org.scalatest.BeforeAndAfter -import org.scalatest.mock.MockitoSugar - -import org.apache.spark.executor.{DataReadMethod, TaskMetrics} -import org.apache.spark.rdd.RDD -import org.apache.spark.storage._ - -// TODO: Test the CacheManager's thread-safety aspects -class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter - with MockitoSugar { - - var blockManager: BlockManager = _ - var cacheManager: CacheManager = _ - var split: Partition = _ - /** An RDD which returns the values [1, 2, 3, 4]. */ - var rdd: RDD[Int] = _ - var rdd2: RDD[Int] = _ - var rdd3: RDD[Int] = _ - - before { - sc = new SparkContext("local", "test") - blockManager = mock[BlockManager] - cacheManager = new CacheManager(blockManager) - split = new Partition { override def index: Int = 0 } - rdd = new RDD[Int](sc, Nil) { - override def getPartitions: Array[Partition] = Array(split) - override val getDependencies = List[Dependency[_]]() - override def compute(split: Partition, context: TaskContext): Iterator[Int] = - Array(1, 2, 3, 4).iterator - } - rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) { - override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext): Iterator[Int] = - firstParent[Int].iterator(split, context) - }.cache() - rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) { - override def getPartitions: Array[Partition] = firstParent[Int].partitions - override def compute(split: Partition, context: TaskContext): Iterator[Int] = - firstParent[Int].iterator(split, context) - }.cache() - } - - test("get uncached rdd") { - // Do not mock this test, because attempting to match Array[Any], which is not covariant, - // in blockManager.put is a losing battle. You have been warned. - blockManager = sc.env.blockManager - cacheManager = sc.env.cacheManager - val context = TaskContext.empty() - val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) - assert(computeValue.toList === List(1, 2, 3, 4)) - assert(getValue.isDefined, "Block cached from getOrCompute is not found!") - assert(getValue.get.data.toList === List(1, 2, 3, 4)) - } - - test("get cached rdd") { - val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) - when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - - val context = TaskContext.empty() - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(5, 6, 7)) - } - - test("get uncached local rdd") { - // Local computation should not persist the resulting value, so don't expect a put(). - when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - - val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true) - val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) - assert(value.toList === List(1, 2, 3, 4)) - } - - test("verify task metrics updated correctly") { - cacheManager = sc.env.cacheManager - val context = TaskContext.empty() - cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) - assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) - } -} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 119e5fc28e412..9f94e36324536 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,17 +21,231 @@ import java.io.File import scala.reflect.ClassTag +import org.apache.hadoop.fs.Path + import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils +trait RDDCheckpointTester { self: SparkFunSuite => + + protected val partitioner = new HashPartitioner(2) + + private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + + /** Implementations of this trait must implement this method */ + protected def sparkContext: SparkContext + + /** + * Test checkpointing of the RDD generated by the given operation. It tests whether the + * serialized size of the RDD is reduce after checkpointing or not. This function should be called + * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). + * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDD[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentDependency = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName + val numPartitions = operatedRDD.partitions.length + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + val partitionsBeforeCheckpoint = operatedRDD.partitions + + // Find serialized sizes before and after the checkpoint + logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + checkpoint(operatedRDD, reliableCheckpoint) + val result = collectFunc(operatedRDD) + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the checkpoint file has been created + if (reliableCheckpoint) { + assert(operatedRDD.getCheckpointFile.nonEmpty) + val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get) + assert(collectFunc(recoveredRDD) === result) + assert(recoveredRDD.partitioner === operatedRDD.partitioner) + } + + // Test whether dependencies have been changed from its earlier parent RDD + assert(operatedRDD.dependencies.head != parentDependency) + + // Test whether the partitions have been changed from its earlier partitions + assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) + + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) + + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the RDD has reduced. + logInfo("Size of " + rddType + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, + * the generated RDD will remember the partitions and therefore potentially the whole lineage. + * This function should be called only those RDD whose partitions refer to parent RDD's + * partitions (i.e., do not call it on simple RDD like MappedRDD). + * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDDPartitions[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDDs = operatedRDD.dependencies.map(_.rdd) + val rddType = operatedRDD.getClass.getSimpleName + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + // Find serialized sizes before and after the checkpoint + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + // checkpoint the parent RDD, not the generated one + parentRDDs.foreach { rdd => + checkpoint(rdd, reliableCheckpoint) + } + val result = collectFunc(operatedRDD) // force checkpointing + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the partitions has reduced + logInfo("Size of partitions of " + rddType + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") + assert( + partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" + ) + } + + /** + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. + */ + private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + val rddSize = Utils.serialize(rdd).size + val rddCpDataSize = Utils.serialize(rdd.checkpointData).size + val rddPartitionSize = Utils.serialize(rdd.partitions).size + val rddDependenciesSize = Utils.serialize(rdd.dependencies).size + + // Print detailed size, helps in debugging + logInfo("Serialized sizes of " + rdd + + ": RDD = " + rddSize + + ", RDD checkpoint data = " + rddCpDataSize + + ", RDD partitions = " + rddPartitionSize + + ", RDD dependencies = " + rddDependenciesSize + ) + // this makes sure that serializing the RDD's checkpoint data does not + // serialize the whole RDD as well + assert( + rddSize > rddCpDataSize, + "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + + "whole RDD with checkpoint data (" + rddSize + ")" + ) + (rddSize - rddCpDataSize, rddPartitionSize) + } + + /** + * Serialize and deserialize an object. This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + protected def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) + } + + /** + * Recursively force the initialization of the all members of an RDD and it parents. + */ + private def initializeRdd(rdd: RDD[_]): Unit = { + rdd.partitions // forces the initialization of the partitions + rdd.dependencies.map(_.rdd).foreach(initializeRdd) + } + + /** Checkpoint the RDD either locally or reliably. */ + protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { + if (reliableCheckpoint) { + rdd.checkpoint() + } else { + rdd.localCheckpoint() + } + } + + /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ + protected def runTest( + name: String, + skipLocalCheckpoint: Boolean = false + )(body: Boolean => Unit): Unit = { + test(name + " [reliable checkpoint]")(body(true)) + if (!skipLocalCheckpoint) { + test(name + " [local checkpoint]")(body(false)) + } + } + + /** + * Generate an RDD such that both the RDD and its partitions have large size. + */ + protected def generateFatRDD(): RDD[Int] = { + new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x) + } + + /** + * Generate an pair RDD (with partitioner) such that both the RDD and its partitions + * have large size. + */ + protected def generateFatPairRDD(): RDD[(Int, Int)] = { + new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) + } +} + /** * Test suite for end-to-end checkpointing functionality. * This tests both reliable checkpoints and local checkpoints. */ -class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext { private var checkpointDir: File = _ - private val partitioner = new HashPartitioner(2) override def beforeEach(): Unit = { super.beforeEach() @@ -42,10 +256,15 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging } override def afterEach(): Unit = { - super.afterEach() - Utils.deleteRecursively(checkpointDir) + try { + Utils.deleteRecursively(checkpointDir) + } finally { + super.afterEach() + } } + override def sparkContext: SparkContext = sc + runTest("basic checkpointing") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) @@ -56,6 +275,49 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(flatMappedRDD.collect() === result) } + runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean => + + def testPartitionerCheckpointing( + partitioner: Partitioner, + corruptPartitionerFile: Boolean = false + ): Unit = { + val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner) + rddWithPartitioner.checkpoint() + rddWithPartitioner.count() + assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty, + "checkpointing was not successful") + + if (corruptPartitionerFile) { + // Overwrite the partitioner file with garbage data + val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get) + val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration) + val partitionerFile = fs.listStatus(checkpointDir) + .find(_.getPath.getName.contains("partitioner")) + .map(_.getPath) + require(partitionerFile.nonEmpty, "could not find the partitioner file for testing") + val output = fs.create(partitionerFile.get, true) + output.write(100) + output.close() + } + + val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get) + assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered") + + if (!corruptPartitionerFile) { + assert(newRDD.partitioner != None, "partitioner not recovered") + assert(newRDD.partitioner === rddWithPartitioner.partitioner, + "recovered partitioner does not match") + } else { + assert(newRDD.partitioner == None, "partitioner unexpectedly recovered") + } + } + + testPartitionerCheckpointing(partitioner) + + // Test that corrupted partitioner file does not prevent recovery of RDD + testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true) + } + runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean => testRDD(_.map(x => x.toString), reliableCheckpoint) testRDD(_.flatMap(x => 1 to x), reliableCheckpoint) @@ -251,203 +513,26 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(rdd.partitions.size === 0) } - // Utility test methods - - /** Checkpoint the RDD either locally or reliably. */ - private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { - if (reliableCheckpoint) { - rdd.checkpoint() - } else { - rdd.localCheckpoint() - } - } - - /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - private def runTest(name: String)(body: Boolean => Unit): Unit = { - test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) + runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean => + testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true) + testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false) } - private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() - - /** - * Test checkpointing of the RDD generated by the given operation. It tests whether the - * serialized size of the RDD is reduce after checkpointing or not. This function should be called - * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). - * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDD[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull - val rddType = operatedRDD.getClass.getSimpleName - val numPartitions = operatedRDD.partitions.length - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - val partitionsBeforeCheckpoint = operatedRDD.partitions - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - checkpoint(operatedRDD, reliableCheckpoint) - val result = collectFunc(operatedRDD) - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the checkpoint file has been created - if (reliableCheckpoint) { - assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) - } - - // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) - - // Test whether the partitions have been changed from its earlier partitions - assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) - - // Test whether the partitions have been changed to the new Hadoop partitions - assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - - // Test whether the number of partitions is same as before - assert(operatedRDD.partitions.length === numPartitions) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the RDD has reduced. - logInfo("Size of " + rddType + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - /** - * Test whether checkpointing of the parent of the generated RDD also - * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * the generated RDD will remember the partitions and therefore potentially the whole lineage. - * This function should be called only those RDD whose partitions refer to parent RDD's - * partitions (i.e., do not call it on simple RDD like MappedRDD). - * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDDPartitions[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDDs = operatedRDD.dependencies.map(_.rdd) - val rddType = operatedRDD.getClass.getSimpleName - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - // checkpoint the parent RDD, not the generated one - parentRDDs.foreach { rdd => - checkpoint(rdd, reliableCheckpoint) + private def testCheckpointAllMarkedAncestors( + reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = { + sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString) + try { + val rdd1 = sc.parallelize(1 to 10) + checkpoint(rdd1, reliableCheckpoint) + val rdd2 = rdd1.map(_ + 1) + checkpoint(rdd2, reliableCheckpoint) + rdd2.count() + assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors) + assert(rdd2.isCheckpointed === true) + } finally { + sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null) } - val result = collectFunc(operatedRDD) // force checkpointing - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the partitions has reduced - logInfo("Size of partitions of " + rddType + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") - assert( - partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" - ) - } - - /** - * Generate an RDD such that both the RDD and its partitions have large size. - */ - private def generateFatRDD(): RDD[Int] = { - new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) } - - /** - * Generate an pair RDD (with partitioner) such that both the RDD and its partitions - * have large size. - */ - private def generateFatPairRDD(): RDD[(Int, Int)] = { - new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) - } - - /** - * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks - * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. - */ - private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - val rddSize = Utils.serialize(rdd).size - val rddCpDataSize = Utils.serialize(rdd.checkpointData).size - val rddPartitionSize = Utils.serialize(rdd.partitions).size - val rddDependenciesSize = Utils.serialize(rdd.dependencies).size - - // Print detailed size, helps in debugging - logInfo("Serialized sizes of " + rdd + - ": RDD = " + rddSize + - ", RDD checkpoint data = " + rddCpDataSize + - ", RDD partitions = " + rddPartitionSize + - ", RDD dependencies = " + rddDependenciesSize - ) - // this makes sure that serializing the RDD's checkpoint data does not - // serialize the whole RDD as well - assert( - rddSize > rddCpDataSize, - "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + - "whole RDD with checkpoint data (" + rddSize + ")" - ) - (rddSize - rddCpDataSize, rddPartitionSize) - } - - /** - * Serialize and deserialize an object. This is useful to verify the objects - * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) - */ - private def serializeDeserialize[T](obj: T): T = { - val bytes = Utils.serialize(obj) - Utils.deserialize[T](bytes) - } - - /** - * Recursively force the initialization of the all members of an RDD and it parents. - */ - private def initializeRdd(rdd: RDD[_]): Unit = { - rdd.partitions // forces the - rdd.dependencies.map(_.rdd).foreach(initializeRdd) - } - } /** RDD partition that has large serialized size. */ @@ -494,5 +579,4 @@ object CheckpointSuite { part ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } - } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 0c14bef7befd8..f98150536d8a8 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -19,23 +19,20 @@ package org.apache.spark import java.lang.ref.WeakReference -import scala.collection.mutable.{HashSet, SynchronizedSet} +import scala.collection.mutable.HashSet import scala.language.existentials import scala.util.Random import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.time.SpanSugar._ -import org.apache.spark.rdd.{ReliableRDDCheckpointData, RDD} -import org.apache.spark.storage._ +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.BroadcastBlockId -import org.apache.spark.storage.RDDBlockId -import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.storage.ShuffleIndexBlockId +import org.apache.spark.storage._ /** * An abstract base class for context cleaner tests, which sets up a context with a config @@ -446,25 +443,25 @@ class CleanerTester( checkpointIds: Seq[Long] = Seq.empty) extends Logging { - val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds - val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds - val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds - val toBeCheckpointIds = new HashSet[Long] with SynchronizedSet[Long] ++= checkpointIds + val toBeCleanedRDDIds = new HashSet[Int] ++= rddIds + val toBeCleanedShuffleIds = new HashSet[Int] ++= shuffleIds + val toBeCleanedBroadcstIds = new HashSet[Long] ++= broadcastIds + val toBeCheckpointIds = new HashSet[Long] ++= checkpointIds val isDistributed = !sc.isLocal val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { - toBeCleanedRDDIds -= rddId + toBeCleanedRDDIds.synchronized { toBeCleanedRDDIds -= rddId } logInfo("RDD " + rddId + " cleaned") } def shuffleCleaned(shuffleId: Int): Unit = { - toBeCleanedShuffleIds -= shuffleId + toBeCleanedShuffleIds.synchronized { toBeCleanedShuffleIds -= shuffleId } logInfo("Shuffle " + shuffleId + " cleaned") } def broadcastCleaned(broadcastId: Long): Unit = { - toBeCleanedBroadcstIds -= broadcastId + toBeCleanedBroadcstIds.synchronized { toBeCleanedBroadcstIds -= broadcastId } logInfo("Broadcast " + broadcastId + " cleaned") } @@ -473,7 +470,7 @@ class CleanerTester( } def checkpointCleaned(rddId: Long): Unit = { - toBeCheckpointIds -= rddId + toBeCheckpointIds.synchronized { toBeCheckpointIds -= rddId } logInfo("checkpoint " + rddId + " cleaned") } } @@ -489,7 +486,8 @@ class CleanerTester( def assertCleanup()(implicit waitTimeout: PatienceConfiguration.Timeout) { try { eventually(waitTimeout, interval(100 millis)) { - assert(isAllCleanedUp) + assert(isAllCleanedUp, + "The following resources were not cleaned up:\n" + uncleanedResourcesToString) } postCleanupValidate() } finally { @@ -581,18 +579,27 @@ class CleanerTester( } private def uncleanedResourcesToString = { + val s1 = toBeCleanedRDDIds.synchronized { + toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]") + } + val s2 = toBeCleanedShuffleIds.synchronized { + toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]") + } + val s3 = toBeCleanedBroadcstIds.synchronized { + toBeCleanedBroadcstIds.toSeq.sorted.mkString("[", ", ", "]") + } s""" - |\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")} - |\tShuffles = ${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")} - |\tBroadcasts = ${toBeCleanedBroadcstIds.toSeq.sorted.mkString("[", ", ", "]")} + |\tRDDs = $s1 + |\tShuffles = $s2 + |\tBroadcasts = $s3 """.stripMargin } private def isAllCleanedUp = - toBeCleanedRDDIds.isEmpty && - toBeCleanedShuffleIds.isEmpty && - toBeCleanedBroadcstIds.isEmpty && - toBeCheckpointIds.isEmpty + toBeCleanedRDDIds.synchronized { toBeCleanedRDDIds.isEmpty } && + toBeCleanedShuffleIds.synchronized { toBeCleanedShuffleIds.isEmpty } && + toBeCleanedBroadcstIds.synchronized { toBeCleanedBroadcstIds.isEmpty } && + toBeCheckpointIds.synchronized { toBeCheckpointIds.isEmpty } private def getRDDBlocks(rddId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 1c3f2bc315ddc..2110d3d770d5d 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.util.io.ChunkedByteBuffer class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} @@ -193,11 +194,12 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val blockId = blockIds(0) val blockManager = SparkEnv.get.blockManager val blockTransfer = SparkEnv.get.blockTransferService + val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString) - val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer()) - .asInstanceOf[Iterator[Int]].toList + val deserialized = serializerManager.dataDeserializeStream[Int](blockId, + new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList assert(deserialized === (1 to 100).toList) } } @@ -222,7 +224,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val numPartitions = 10 val conf = new SparkConf() .set("spark.storage.unrollMemoryThreshold", "1024") - .set("spark.testing.memory", (size * numPartitions).toString) + .set("spark.testing.memory", size.toString) sc = new SparkContext(clusterUrl, "test", conf) val data = sc.parallelize(1 to size, numPartitions).persist(StorageLevel.MEMORY_ONLY) assert(data.count() === size) @@ -318,7 +320,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex Thread.sleep(200) } } catch { - case _: Throwable => { Thread.sleep(10) } + case _: Throwable => Thread.sleep(10) // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 116f027a0f987..ee6b991461902 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.collection.mutable import org.scalatest.{BeforeAndAfter, PrivateMethodTester} + import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -800,11 +801,95 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. - val taskEndReason = ExceptionFailure(null, null, null, null, null, None) + val taskEndReason = ExceptionFailure(null, null, null, null, None) sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } + test("reset the state of allocation manager") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + + // Allocation manager is reset when adding executor requests are sent without reporting back + // executor added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 4) + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 5) + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + + // Allocation manager is reset when executors are added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + + // Allocation manager is reset when executors are pending to remove + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + removeExecutor(manager, "first") + removeExecutor(manager, "second") + assert(executorsPendingToRemove(manager) === Set("first", "second")) + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorsPendingToRemove(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -843,8 +928,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { numTasks: Int, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty ): StageInfo = { - new StageInfo( - stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences) + new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", + Seq.empty, taskLocalityPreferences) } private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 231f4631e0a47..eb3fb99747d12 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,7 +35,8 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) + super.beforeAll() + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() @@ -46,7 +47,11 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { } override def afterAll() { - server.close() + try { + server.close() + } finally { + super.afterAll() + } } // This test ensures that the external shuffle service is actually in use for the other tests. diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 203dab934ca1f..3def8b0b1850e 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark -import org.apache.spark.util.NonSerializable - import java.io.{IOException, NotSerializableException, ObjectInputStream} +import org.apache.spark.util.NonSerializable + // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully // be copied for each task, so there's no other way for them to share state. diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala deleted file mode 100644 index 1255e71af6c0b..0000000000000 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io._ -import java.net.URI -import java.util.jar.{JarEntry, JarOutputStream} -import javax.net.ssl.SSLException - -import com.google.common.io.{ByteStreams, Files} -import org.apache.commons.lang3.RandomUtils - -import org.apache.spark.util.Utils - -import SSLSampleConfigs._ - -class FileServerSuite extends SparkFunSuite with LocalSparkContext { - - @transient var tmpDir: File = _ - @transient var tmpFile: File = _ - @transient var tmpJarUrl: String = _ - - def newConf: SparkConf = new SparkConf(loadDefaults = false).set("spark.authenticate", "false") - - override def beforeEach() { - super.beforeEach() - resetSparkContext() - } - - override def beforeAll() { - super.beforeAll() - - tmpDir = Utils.createTempDir() - val testTempDir = new File(tmpDir, "test") - testTempDir.mkdir() - - val textFile = new File(testTempDir, "FileServerSuite.txt") - val pw = new PrintWriter(textFile) - // scalastyle:off println - pw.println("100") - // scalastyle:on println - pw.close() - - val jarFile = new File(testTempDir, "test.jar") - val jarStream = new FileOutputStream(jarFile) - val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) - - val jarEntry = new JarEntry(textFile.getName) - jar.putNextEntry(jarEntry) - - val in = new FileInputStream(textFile) - ByteStreams.copy(in, jar) - - in.close() - jar.close() - jarStream.close() - - tmpFile = textFile - tmpJarUrl = jarFile.toURI.toURL.toString - } - - override def afterAll() { - super.afterAll() - Utils.deleteRecursively(tmpDir) - } - - test("Distributing files locally") { - sc = new SparkContext("local[4]", "test", newConf) - sc.addFile(tmpFile.toString) - val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) - } - - test("Distributing files locally security On") { - val sparkConf = new SparkConf(false) - sparkConf.set("spark.authenticate", "true") - sparkConf.set("spark.authenticate.secret", "good") - sc = new SparkContext("local[4]", "test", sparkConf) - - sc.addFile(tmpFile.toString) - assert(sc.env.securityManager.isAuthenticationEnabled() === true) - val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) - } - - test("Distributing files locally using URL as input") { - // addFile("file:///....") - sc = new SparkContext("local[4]", "test", newConf) - sc.addFile(new File(tmpFile.toString).toURI.toString) - val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) - } - - test ("Dynamically adding JARS locally") { - sc = new SparkContext("local[4]", "test", newConf) - sc.addJar(tmpJarUrl) - val testData = Array((1, 1)) - sc.parallelize(testData).foreach { x => - if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { - throw new SparkException("jar not added") - } - } - } - - test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) - sc.addFile(tmpFile.toString) - val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1, 200), (2, 300), (3, 500))) - } - - test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) - sc.addJar(tmpJarUrl) - val testData = Array((1, 1)) - sc.parallelize(testData).foreach { x => - if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { - throw new SparkException("jar not added") - } - } - } - - test ("Dynamically adding JARS on a standalone cluster using local: URL") { - sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) - sc.addJar(tmpJarUrl.replace("file", "local")) - val testData = Array((1, 1)) - sc.parallelize(testData).foreach { x => - if (Thread.currentThread.getContextClassLoader.getResource("FileServerSuite.txt") == null) { - throw new SparkException("jar not added") - } - } - } - - test ("HttpFileServer should work with SSL") { - val sparkConf = sparkSSLConfig() - val sm = new SecurityManager(sparkConf) - val server = new HttpFileServer(sparkConf, sm, 0) - try { - server.initialize() - - fileTransferTest(server, sm) - } finally { - server.stop() - } - } - - test ("HttpFileServer should work with SSL and good credentials") { - val sparkConf = sparkSSLConfig() - sparkConf.set("spark.authenticate", "true") - sparkConf.set("spark.authenticate.secret", "good") - - val sm = new SecurityManager(sparkConf) - val server = new HttpFileServer(sparkConf, sm, 0) - try { - server.initialize() - - fileTransferTest(server, sm) - } finally { - server.stop() - } - } - - test ("HttpFileServer should not work with valid SSL and bad credentials") { - val sparkConf = sparkSSLConfig() - sparkConf.set("spark.authenticate", "true") - sparkConf.set("spark.authenticate.secret", "bad") - - val sm = new SecurityManager(sparkConf) - val server = new HttpFileServer(sparkConf, sm, 0) - try { - server.initialize() - - intercept[IOException] { - fileTransferTest(server) - } - } finally { - server.stop() - } - } - - test ("HttpFileServer should not work with SSL when the server is untrusted") { - val sparkConf = sparkSSLConfigUntrusted() - val sm = new SecurityManager(sparkConf) - val server = new HttpFileServer(sparkConf, sm, 0) - try { - server.initialize() - - intercept[SSLException] { - fileTransferTest(server) - } - } finally { - server.stop() - } - } - - def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = { - val randomContent = RandomUtils.nextBytes(100) - val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir) - Files.write(randomContent, file) - server.addFile(file) - - val uri = new URI(server.serverUri + "/files/" + file.getName) - - val connection = if (sm != null && sm.isAuthenticationEnabled()) { - Utils.constructURIForAuthentication(uri, sm).toURL.openConnection() - } else { - uri.toURL.openConnection() - } - - if (sm != null) { - Utils.setupSecureURLConnection(connection, sm) - } - - val buf = ByteStreams.toByteArray(connection.getInputStream) - assert(buf === randomContent) - } - -} diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index fdb00aafc4a48..993834f8d7d42 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,20 +19,18 @@ package org.apache.spark import java.io.{File, FileWriter} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.input.PortableDataStream -import org.apache.spark.storage.StorageLevel - import scala.io.Source import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec -import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} +import org.apache.hadoop.mapred.{FileAlreadyExistsException, FileSplit, JobConf, TextInputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class FileSuite extends SparkFunSuite with LocalSparkContext { @@ -44,8 +42,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } override def afterEach() { - super.afterEach() - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } } test("text files") { @@ -503,11 +504,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val randomRDD = sc.parallelize( Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - val job = new Job(sc.hadoopConfiguration) + val job = Job.getInstance(sc.hadoopConfiguration) job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfig = job.getConfiguration jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala index 19180e88ebe0a..10794235ed392 100644 --- a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala @@ -24,6 +24,7 @@ class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with hash-based shuffle. override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.manager", "hash") } } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 3cd80c0f7d171..713d5e58b4ffc 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -25,13 +25,13 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} -import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ +import org.mockito.Mockito.{mock, spy, verify, when} +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -66,6 +66,7 @@ class HeartbeatReceiverSuite * that uses a manual clock. */ override def beforeEach(): Unit = { + super.beforeEach() val conf = new SparkConf() .setMaster("local[2]") .setAppName("test") @@ -173,10 +174,10 @@ class HeartbeatReceiverSuite val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv) val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) - fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( - RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "dummy:4040", 0, Map.empty)) - fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( - RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty)) + fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean]( + RegisterExecutor(executorId1, dummyExecutorEndpointRef1, 0, Map.empty)) + fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean]( + RegisterExecutor(executorId2, dummyExecutorEndpointRef2, 0, Map.empty)) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) addExecutorAndVerify(executorId2) @@ -214,14 +215,16 @@ class HeartbeatReceiverSuite val metrics = new TaskMetrics val blockManagerId = BlockManagerId(executorId, "localhost", 12345) val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat(executorId, Array(1L -> metrics), blockManagerId)) + Heartbeat(executorId, Array(1L -> metrics.accumulatorUpdates()), blockManagerId)) if (executorShouldReregister) { assert(response.reregisterBlockManager) } else { assert(!response.reregisterBlockManager) // Additionally verify that the scheduler callback is called with the correct parameters verify(scheduler).executorHeartbeatReceived( - Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + Matchers.eq(executorId), + Matchers.eq(Array(1L -> metrics.accumulatorUpdates())), + Matchers.eq(blockManagerId)) } } @@ -252,7 +255,12 @@ class HeartbeatReceiverSuite /** * Dummy RPC endpoint to simulate executors. */ -private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint +private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint { + + override def receive: PartialFunction[Any, Unit] = { + case _ => + } +} /** * Dummy scheduler backend to simulate executor allocation requests to the cluster manager. diff --git a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala index 4399f25626472..939f12f94f5c3 100644 --- a/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD class ImplicitOrderingSuite extends SparkFunSuite with LocalSparkContext { // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should. - test("basic inference of Orderings"){ + test("basic inference of Orderings") { sc = new SparkContext("local", "test") val rdd = sc.parallelize(1 to 10) diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala new file mode 100644 index 0000000000000..474550608ba2f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.{BlockId, BlockStatus} + + +class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { + import InternalAccumulator._ + import AccumulatorParam._ + + override def afterEach(): Unit = { + try { + Accumulators.clear() + } finally { + super.afterEach() + } + } + + test("get param") { + assert(getParam(EXECUTOR_DESERIALIZE_TIME) === LongAccumulatorParam) + assert(getParam(EXECUTOR_RUN_TIME) === LongAccumulatorParam) + assert(getParam(RESULT_SIZE) === LongAccumulatorParam) + assert(getParam(JVM_GC_TIME) === LongAccumulatorParam) + assert(getParam(RESULT_SERIALIZATION_TIME) === LongAccumulatorParam) + assert(getParam(MEMORY_BYTES_SPILLED) === LongAccumulatorParam) + assert(getParam(DISK_BYTES_SPILLED) === LongAccumulatorParam) + assert(getParam(PEAK_EXECUTION_MEMORY) === LongAccumulatorParam) + assert(getParam(UPDATED_BLOCK_STATUSES) === UpdatedBlockStatusesAccumulatorParam) + assert(getParam(TEST_ACCUM) === LongAccumulatorParam) + // shuffle read + assert(getParam(shuffleRead.REMOTE_BLOCKS_FETCHED) === IntAccumulatorParam) + assert(getParam(shuffleRead.LOCAL_BLOCKS_FETCHED) === IntAccumulatorParam) + assert(getParam(shuffleRead.REMOTE_BYTES_READ) === LongAccumulatorParam) + assert(getParam(shuffleRead.LOCAL_BYTES_READ) === LongAccumulatorParam) + assert(getParam(shuffleRead.FETCH_WAIT_TIME) === LongAccumulatorParam) + assert(getParam(shuffleRead.RECORDS_READ) === LongAccumulatorParam) + // shuffle write + assert(getParam(shuffleWrite.BYTES_WRITTEN) === LongAccumulatorParam) + assert(getParam(shuffleWrite.RECORDS_WRITTEN) === LongAccumulatorParam) + assert(getParam(shuffleWrite.WRITE_TIME) === LongAccumulatorParam) + // input + assert(getParam(input.READ_METHOD) === StringAccumulatorParam) + assert(getParam(input.RECORDS_READ) === LongAccumulatorParam) + assert(getParam(input.BYTES_READ) === LongAccumulatorParam) + // output + assert(getParam(output.WRITE_METHOD) === StringAccumulatorParam) + assert(getParam(output.RECORDS_WRITTEN) === LongAccumulatorParam) + assert(getParam(output.BYTES_WRITTEN) === LongAccumulatorParam) + // default to Long + assert(getParam(METRICS_PREFIX + "anything") === LongAccumulatorParam) + intercept[IllegalArgumentException] { + getParam("something that does not start with the right prefix") + } + } + + test("create by name") { + val executorRunTime = create(EXECUTOR_RUN_TIME) + val updatedBlockStatuses = create(UPDATED_BLOCK_STATUSES) + val shuffleRemoteBlocksRead = create(shuffleRead.REMOTE_BLOCKS_FETCHED) + val inputReadMethod = create(input.READ_METHOD) + assert(executorRunTime.name === Some(EXECUTOR_RUN_TIME)) + assert(updatedBlockStatuses.name === Some(UPDATED_BLOCK_STATUSES)) + assert(shuffleRemoteBlocksRead.name === Some(shuffleRead.REMOTE_BLOCKS_FETCHED)) + assert(inputReadMethod.name === Some(input.READ_METHOD)) + assert(executorRunTime.value.isInstanceOf[Long]) + assert(updatedBlockStatuses.value.isInstanceOf[Seq[_]]) + // We cannot assert the type of the value directly since the type parameter is erased. + // Instead, try casting a `Seq` of expected type and see if it fails in run time. + updatedBlockStatuses.setValueAny(Seq.empty[(BlockId, BlockStatus)]) + assert(shuffleRemoteBlocksRead.value.isInstanceOf[Int]) + assert(inputReadMethod.value.isInstanceOf[String]) + // default to Long + val anything = create(METRICS_PREFIX + "anything") + assert(anything.value.isInstanceOf[Long]) + } + + test("create") { + val accums = createAll() + val shuffleReadAccums = createShuffleReadAccums() + val shuffleWriteAccums = createShuffleWriteAccums() + val inputAccums = createInputAccums() + val outputAccums = createOutputAccums() + // assert they're all internal + assert(accums.forall(_.isInternal)) + assert(shuffleReadAccums.forall(_.isInternal)) + assert(shuffleWriteAccums.forall(_.isInternal)) + assert(inputAccums.forall(_.isInternal)) + assert(outputAccums.forall(_.isInternal)) + // assert they all count on failures + assert(accums.forall(_.countFailedValues)) + assert(shuffleReadAccums.forall(_.countFailedValues)) + assert(shuffleWriteAccums.forall(_.countFailedValues)) + assert(inputAccums.forall(_.countFailedValues)) + assert(outputAccums.forall(_.countFailedValues)) + // assert they all have names + assert(accums.forall(_.name.isDefined)) + assert(shuffleReadAccums.forall(_.name.isDefined)) + assert(shuffleWriteAccums.forall(_.name.isDefined)) + assert(inputAccums.forall(_.name.isDefined)) + assert(outputAccums.forall(_.name.isDefined)) + // assert `accums` is a strict superset of the others + val accumNames = accums.map(_.name.get).toSet + val shuffleReadAccumNames = shuffleReadAccums.map(_.name.get).toSet + val shuffleWriteAccumNames = shuffleWriteAccums.map(_.name.get).toSet + val inputAccumNames = inputAccums.map(_.name.get).toSet + val outputAccumNames = outputAccums.map(_.name.get).toSet + assert(shuffleReadAccumNames.subsetOf(accumNames)) + assert(shuffleWriteAccumNames.subsetOf(accumNames)) + assert(inputAccumNames.subsetOf(accumNames)) + assert(outputAccumNames.subsetOf(accumNames)) + } + + test("naming") { + val accums = createAll() + val shuffleReadAccums = createShuffleReadAccums() + val shuffleWriteAccums = createShuffleWriteAccums() + val inputAccums = createInputAccums() + val outputAccums = createOutputAccums() + // assert that prefixes are properly namespaced + assert(SHUFFLE_READ_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(SHUFFLE_WRITE_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(INPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(OUTPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX)) + assert(accums.forall(_.name.get.startsWith(METRICS_PREFIX))) + // assert they all start with the expected prefixes + assert(shuffleReadAccums.forall(_.name.get.startsWith(SHUFFLE_READ_METRICS_PREFIX))) + assert(shuffleWriteAccums.forall(_.name.get.startsWith(SHUFFLE_WRITE_METRICS_PREFIX))) + assert(inputAccums.forall(_.name.get.startsWith(INPUT_METRICS_PREFIX))) + assert(outputAccums.forall(_.name.get.startsWith(OUTPUT_METRICS_PREFIX))) + } + + test("internal accumulators in TaskContext") { + val taskContext = TaskContext.empty() + val accumUpdates = taskContext.taskMetrics.accumulatorUpdates() + assert(accumUpdates.size > 0) + assert(accumUpdates.forall(_.internal)) + val testAccum = taskContext.taskMetrics.getAccum(TEST_ACCUM) + assert(accumUpdates.exists(_.id == testAccum.id)) + } + + test("internal accumulators in a stage") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Have each task add 1 to the internal accumulator + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { () => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findTestAccum(stageInfos.head.accumulables.values) + assert(stageAccum.value.get.asInstanceOf[Long] === numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findTestAccum(taskInfo.accumulables) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.asInstanceOf[Long] === 1L) + taskAccum.value.get.asInstanceOf[Long] + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() + listener.awaitNextJobCompletion() + } + + test("internal accumulators in multiple stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Each stage creates its own set of internal accumulators so the + // values for the same metric should not be mixed up across stages + val rdd = sc.parallelize(1 to 100, numPartitions) + .map { i => (i, i) } + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1 + iter + } + .reduceByKey { case (x, y) => x + y } + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 10 + iter + } + .repartition(numPartitions * 2) + .mapPartitions { iter => + TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 100 + iter + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { () => + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val (firstStageAccum, secondStageAccum, thirdStageAccum) = + (findTestAccum(stageInfos(0).accumulables.values), + findTestAccum(stageInfos(1).accumulables.values), + findTestAccum(stageInfos(2).accumulables.values)) + assert(firstStageAccum.value.get.asInstanceOf[Long] === numPartitions) + assert(secondStageAccum.value.get.asInstanceOf[Long] === numPartitions * 10) + assert(thirdStageAccum.value.get.asInstanceOf[Long] === numPartitions * 2 * 100) + } + rdd.count() + } + + test("internal accumulators in resubmitted stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + + // Simulate fetch failures in order to trigger a stage retry. Here we run 1 job with + // 2 stages. On the second stage, we trigger a fetch failure on the first stage attempt. + // This should retry both stages in the scheduler. Note that we only want to fail the + // first stage attempt because we want the stage to eventually succeed. + val x = sc.parallelize(1 to 100, numPartitions) + .mapPartitions { iter => TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1; iter } + .groupBy(identity) + val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId + val rdd = x.mapPartitionsWithIndex { case (i, iter) => + // Fail the first stage attempt. Here we use the task attempt ID to determine this. + // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt + // ID that's < 2 * numPartitions belongs to the first attempt of this stage. + val taskContext = TaskContext.get() + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + if (isFirstStageAttempt) { + throw new FetchFailedException( + SparkEnv.get.blockManager.blockManagerId, + sid, + taskContext.partitionId(), + taskContext.partitionId(), + "simulated fetch failure") + } else { + iter + } + } + + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { () => + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 4) // 1 shuffle map stage + 1 result stage, both are retried + val mapStageId = stageInfos.head.stageId + val mapStageInfo1stAttempt = stageInfos.head + val mapStageInfo2ndAttempt = { + stageInfos.tail.find(_.stageId == mapStageId).getOrElse { + fail("expected two attempts of the same shuffle map stage.") + } + } + val stageAccum1stAttempt = findTestAccum(mapStageInfo1stAttempt.accumulables.values) + val stageAccum2ndAttempt = findTestAccum(mapStageInfo2ndAttempt.accumulables.values) + // Both map stages should have succeeded, since the fetch failure happened in the + // result stage, not the map stage. This means we should get the accumulator updates + // from all partitions. + assert(stageAccum1stAttempt.value.get.asInstanceOf[Long] === numPartitions) + assert(stageAccum2ndAttempt.value.get.asInstanceOf[Long] === numPartitions) + // Because this test resubmitted the map stage with all missing partitions, we should have + // created a fresh set of internal accumulators in the 2nd stage attempt. Assert this is + // the case by comparing the accumulator IDs between the two attempts. + // Note: it would be good to also test the case where the map stage is resubmitted where + // only a subset of the original partitions are missing. However, this scenario is very + // difficult to construct without potentially introducing flakiness. + assert(stageAccum1stAttempt.id != stageAccum2ndAttempt.id) + } + rdd.count() + listener.awaitNextJobCompletion() + } + + test("internal accumulators are registered for cleanups") { + sc = new SparkContext("local", "test") { + private val myCleaner = new SaveAccumContextCleaner(this) + override def cleaner: Option[ContextCleaner] = Some(myCleaner) + } + assert(Accumulators.originals.isEmpty) + sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count() + val internalAccums = InternalAccumulator.createAll() + // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage + assert(Accumulators.originals.size === internalAccums.size * 2) + val accumsRegistered = sc.cleaner match { + case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup + case _ => Seq.empty[Long] + } + // Make sure the same set of accumulators is registered for cleanup + assert(accumsRegistered.size === internalAccums.size * 2) + assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet) + } + + /** + * Return the accumulable info that matches the specified name. + */ + private def findTestAccum(accums: Iterable[AccumulableInfo]): AccumulableInfo = { + accums.find { a => a.name == Some(TEST_ACCUM) }.getOrElse { + fail(s"unable to find internal accumulator called $TEST_ACCUM") + } + } + + /** + * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. + */ + private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) { + private val accumsRegistered = new ArrayBuffer[Long] + + override def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = { + accumsRegistered += a.id + super.registerAccumulatorForCleanup(a) + } + + def accumsRegisteredForCleanup: Seq[Long] = accumsRegistered.toArray + } + +} diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 1168eb0b802f2..c347ab8dc8020 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -22,7 +22,7 @@ import java.util.concurrent.Semaphore import scala.concurrent.Await import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.future +import scala.concurrent.Future import org.scalatest.BeforeAndAfter import org.scalatest.Matchers @@ -38,8 +38,11 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft with LocalSparkContext { override def afterEach() { - super.afterEach() - resetSparkContext() + try { + resetSparkContext() + } finally { + super.afterEach() + } } test("local mode, FIFO scheduler") { @@ -100,7 +103,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val rdd1 = rdd.map(x => x) - future { + Future { taskStartedSemaphore.acquire() sc.cancelAllJobs() taskCancelledSemaphore.release(100000) @@ -123,7 +126,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) // jobA is the one to be cancelled. - val jobA = future { + val jobA = Future { sc.setJobGroup("jobA", "this is a job to be cancelled") sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() } @@ -188,7 +191,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) // jobA is the one to be cancelled. - val jobA = future { + val jobA = Future { sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(100000); i }.count() } @@ -228,7 +231,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val f2 = rdd.countAsync() // Kill one of the action. - future { + Future { sem1.acquire() f1.cancel() JobCancellationSuite.twoJobsSharingStageSemaphore.release(10) @@ -244,7 +247,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Cancel before launching any tasks { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() - future { f.cancel() } + Future { f.cancel() } val e = intercept[SparkException] { f.get() } assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -260,7 +263,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft }) val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() - future { + Future { // Wait until some tasks were launched before we cancel the job. sem.acquire() f.cancel() @@ -274,7 +277,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Cancel before launching any tasks { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) - future { f.cancel() } + Future { f.cancel() } val e = intercept[SparkException] { f.get() } assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -289,7 +292,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } }) val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) - future { + Future { sem.acquire() f.cancel() } diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 8bf2e55defd02..24ec99c7e5e60 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -17,7 +17,7 @@ package org.apache.spark -import _root_.io.netty.util.internal.logging.{Slf4JLoggerFactory, InternalLoggerFactory} +import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory} import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite @@ -28,13 +28,16 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self @transient var sc: SparkContext = _ override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) super.beforeAll() + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) } override def afterEach() { - resetSparkContext() - super.afterEach() + try { + resetSparkContext() + } finally { + super.afterEach() + } } def resetSparkContext(): Unit = { @@ -49,7 +52,7 @@ object LocalSparkContext { if (sc != null) { sc.stop() } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7e70308bb360c..ddf48765ec30a 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer -import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} +import org.mockito.Mockito._ -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} @@ -125,7 +125,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -154,9 +154,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveRpcEnv.shutdown() } - test("remote fetch below akka frame size") { + test("remote fetch below max RPC message size") { val newConf = new SparkConf - newConf.set("spark.akka.frameSize", "1") + newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~123B, and no exception should be thrown + // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) @@ -179,9 +179,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } - test("remote fetch exceeds akka frame size") { + test("remote fetch exceeds max RPC message size") { val newConf = new SparkConf - newConf.set("spark.akka.frameSize", "1") + newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) @@ -189,7 +189,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. + // Message size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before // being sent. masterTracker.registerShuffle(20, 100) diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index aa8028792cb41..3d31c7864e760 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -163,8 +163,8 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva val hashP2 = new HashPartitioner(2) assert(rangeP2 === rangeP2) assert(hashP2 === hashP2) - assert(hashP2 != rangeP2) - assert(rangeP2 != hashP2) + assert(hashP2 !== rangeP2) + assert(rangeP2 !== hashP2) } test("partitioner preservation") { diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 25b79bce6ab98..159b448e05b02 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark import java.io.File import javax.net.ssl.SSLContext -import com.google.common.io.Files -import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfterAll class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 2d14249855c9d..33270bec6247c 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -41,7 +41,6 @@ object SSLSampleConfigs { def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -55,7 +54,6 @@ object SSLSampleConfigs { def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) conf.set("spark.ssl.keyStorePassword", "password") diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 26b95c06789f7..8bdb237c28f66 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark import java.io.File -import org.apache.spark.util.{SparkConfWithEnv, Utils} +import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils} -class SecurityManagerSuite extends SparkFunSuite { +class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { test("set security with conf") { val conf = new SparkConf @@ -183,7 +183,6 @@ class SecurityManagerSuite extends SparkFunSuite { val securityManager = new SecurityManager(conf) assert(securityManager.fileServerSSLOptions.enabled === true) - assert(securityManager.akkaSSLOptions.enabled === true) assert(securityManager.sslSocketFactory.isDefined === true) assert(securityManager.hostnameVerifier.isDefined === true) @@ -197,16 +196,6 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) - - assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) - assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") - assert(securityManager.akkaSSLOptions.keyStore.isDefined === true) - assert(securityManager.akkaSSLOptions.keyStore.get.getName === "keystore") - assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) - assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) - assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) - assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2")) - assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) } test("ssl off setup") { @@ -218,7 +207,6 @@ class SecurityManagerSuite extends SparkFunSuite { val securityManager = new SecurityManager(conf) assert(securityManager.fileServerSSLOptions.enabled === false) - assert(securityManager.akkaSSLOptions.enabled === false) assert(securityManager.sslSocketFactory.isDefined === false) assert(securityManager.hostnameVerifier.isDefined === false) } diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 3d2700b7e6be4..858bc742e07cf 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -30,13 +30,16 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => var conf = new SparkConf(false) override def beforeAll() { - _sc = new SparkContext("local[4]", "test", conf) super.beforeAll() + _sc = new SparkContext("local[4]", "test", conf) } override def afterAll() { - LocalSparkContext.stop(_sc) - _sc = null - super.afterAll() + try { + LocalSparkContext.stop(_sc) + _sc = null + } finally { + super.afterAll() + } } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index d78c99c2e1e06..73638d9b131ea 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -24,6 +24,7 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.blockTransferService", "netty") } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 4a0877d86f2c6..cd7d2e15700d3 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark +import java.util.Properties +import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} + import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} +import org.apache.spark.shuffle.ShuffleWriter +import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId} import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -317,6 +322,107 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(metrics.bytesWritten === metrics.byresRead) assert(metrics.bytesWritten > 0) } + + test("multiple simultaneous attempts for one task (SPARK-8029)") { + sc = new SparkContext("local", "test", conf) + val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val manager = sc.env.shuffleManager + + val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0L) + val metricsSystem = sc.env.metricsSystem + val shuffleMapRdd = new MyRDD(sc, 1, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + + // first attempt -- its successful + val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, + new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem, + InternalAccumulator.createAll(sc))) + val data1 = (1 to 10).map { x => x -> x} + + // second attempt -- also successful. We'll write out different data, + // just to simulate the fact that the records may get written differently + // depending on what gets spilled, what gets combined, etc. + val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, + new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem, + InternalAccumulator.createAll(sc))) + val data2 = (11 to 20).map { x => x -> x} + + // interleave writes of both attempts -- we want to test that both attempts can occur + // simultaneously, and everything is still OK + + def writeAndClose( + writer: ShuffleWriter[Int, Int])( + iter: Iterator[(Int, Int)]): Option[MapStatus] = { + val files = writer.write(iter) + writer.stop(true) + } + val interleaver = new InterleaveIterators( + data1, writeAndClose(writer1), data2, writeAndClose(writer2)) + val (mapOutput1, mapOutput2) = interleaver.run() + + // check that we can read the map output and it has the right data + assert(mapOutput1.isDefined) + assert(mapOutput2.isDefined) + assert(mapOutput1.get.location === mapOutput2.get.location) + assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + + // register one of the map outputs -- doesn't matter which one + mapOutput1.foreach { case mapStatus => + mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + } + + val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, + new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem, + InternalAccumulator.createAll(sc))) + val readData = reader.read().toIndexedSeq + assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) + + manager.unregisterShuffle(0) + } +} + +/** + * Utility to help tests make sure that we can process two different iterators simultaneously + * in different threads. This makes sure that in your test, you don't completely process data1 with + * f1 before processing data2 with f2 (or vice versa). It adds a barrier so that the functions only + * process one element, before pausing to wait for the other function to "catch up". + */ +class InterleaveIterators[T, R]( + data1: Seq[T], + f1: Iterator[T] => R, + data2: Seq[T], + f2: Iterator[T] => R) { + + require(data1.size == data2.size) + + val barrier = new CyclicBarrier(2) + class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] { + def hasNext: Boolean = sub.hasNext + + def next: E = { + barrier.await() + sub.next() + } + } + + val c1 = new Callable[R] { + override def call(): R = f1(new BarrierIterator(1, data1.iterator)) + } + val c2 = new Callable[R] { + override def call(): R = f2(new BarrierIterator(2, data2.iterator)) + } + + val e: ExecutorService = Executors.newFixedThreadPool(2) + + def run(): (R, R) = { + val future1 = e.submit(c1) + val future2 = e.submit(c2) + val r1 = future1.get() + val r2 = future2.get() + e.shutdown() + (r1, r2) + } } object ShuffleSuite { @@ -345,8 +451,8 @@ object ShuffleSuite { val listener = new SparkListener { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m => - recordsWritten += m.shuffleRecordsWritten - bytesWritten += m.shuffleBytesWritten + recordsWritten += m.recordsWritten + bytesWritten += m.bytesWritten } taskEnd.taskMetrics.shuffleReadMetrics.foreach { m => recordsRead += m.recordsRead diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala new file mode 100644 index 0000000000000..9d9217ea1b485 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.UUID +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.mutable +import scala.language.implicitConversions + +/** + * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. + * This is intended for testing purposes, primarily to make locks, semaphores, and + * other constructs that would not survive serialization available from within tasks. + * A Smuggle reference is itself serializable, but after being serialized and + * deserialized, it still refers to the same underlying "smuggled" object, as long + * as it was deserialized within the same JVM. This can be useful for tests that + * depend on the timing of task completion to be deterministic, since one can "smuggle" + * a lock or semaphore into the task, and then the task can block until the test gives + * the go-ahead to proceed via the lock. + */ +class Smuggle[T] private(val key: Symbol) extends Serializable { + def smuggledObject: T = Smuggle.get(key) +} + + +object Smuggle { + /** + * Wraps the specified object to be smuggled into a serialized task without + * being serialized itself. + * + * @param smuggledObject + * @tparam T + * @return Smuggle wrapper around smuggledObject. + */ + def apply[T](smuggledObject: T): Smuggle[T] = { + val key = Symbol(UUID.randomUUID().toString) + lock.writeLock().lock() + try { + smuggledObjects += key -> smuggledObject + } finally { + lock.writeLock().unlock() + } + new Smuggle(key) + } + + private val lock = new ReentrantReadWriteLock + private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any] + + private def get[T](key: Symbol) : T = { + lock.readLock().lock() + try { + smuggledObjects(key).asInstanceOf[T] + } finally { + lock.readLock().unlock() + } + } + + /** + * Implicit conversion of a Smuggle wrapper to the object being smuggled. + * + * @param smuggle the wrapper to unpack. + * @tparam T + * @return the smuggled object represented by the wrapper. + */ + implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject + +} diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index b8ab227517cc4..7a897c2b4698f 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -26,8 +26,8 @@ import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util.Utils class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { @@ -37,10 +37,12 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { private var tempDir: File = _ override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.manager", "sort") } override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() conf.set("spark.local.dir", tempDir.getAbsolutePath) } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index ff9a92cc0a421..a883d1b57e526 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -17,17 +17,18 @@ package org.apache.spark -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.{Executors, TimeUnit} import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps -import scala.util.{Try, Random} +import scala.util.{Random, Try} + +import com.esotericsoftware.kryo.Kryo import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} -import org.apache.spark.util.{RpcUtils, ResetSystemProperties} -import com.esotericsoftware.kryo.Kryo +import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("Test byteString conversion") { @@ -236,7 +237,7 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst conf.set(newName, "4") assert(conf.get(newName) === "4") - val count = conf.getAll.filter { case (k, v) => k.startsWith("spark.history.") }.size + val count = conf.getAll.count { case (k, v) => k.startsWith("spark.history.") } assert(count === 4) conf.set("spark.yarn.applicationMaster.waitTries", "42") @@ -266,6 +267,20 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst conf.set("spark.akka.lookupTimeout", "4") assert(RpcUtils.lookupRpcTimeout(conf).duration === (4 seconds)) } + + test("SPARK-13727") { + val conf = new SparkConf() + // set the conf in the deprecated way + conf.set("spark.io.compression.lz4.block.size", "12345") + // get the conf in the recommended way + assert(conf.get("spark.io.compression.lz4.blockSize") === "12345") + // we can still get the conf in the deprecated way + assert(conf.get("spark.io.compression.lz4.block.size") === "12345") + // the contains() also works as expected + assert(conf.contains("spark.io.compression.lz4.block.size")) + assert(conf.contains("spark.io.compression.lz4.blockSize")) + assert(conf.contains("spark.io.unknown") === false) + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 2bdbd70c638a5..8feb3dee050d2 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.scalatest.Assertions + import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { @@ -81,20 +82,18 @@ package object testPackage extends Assertions { val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { - case CALL_SITE_REGEX(func, file, line) => { + case CALL_SITE_REGEX(func, file, line) => assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt - } case _ => fail("Did not match expected call site format") } curCallSite match { - case CALL_SITE_REGEX(func, file, line) => { + case CALL_SITE_REGEX(func, file, line) => assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) - } case _ => fail("Did not match expected call site format") } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index e5a14a69ef05f..49c2bf6bcad18 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,25 +19,32 @@ package org.apache.spark import org.scalatest.PrivateMethodTester -import org.apache.spark.util.Utils +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} +import org.apache.spark.scheduler.cluster.SparkDeploySchedulerBackend import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend +import org.apache.spark.util.Utils class SparkContextSchedulerCreationSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging { def createTaskScheduler(master: String): TaskSchedulerImpl = - createTaskScheduler(master, new SparkConf()) + createTaskScheduler(master, "client") + + def createTaskScheduler(master: String, deployMode: String): TaskSchedulerImpl = + createTaskScheduler(master, deployMode, new SparkConf()) - def createTaskScheduler(master: String, conf: SparkConf): TaskSchedulerImpl = { + def createTaskScheduler( + master: String, + deployMode: String, + conf: SparkConf): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. sc = new SparkContext("local", "test", conf) val createTaskSchedulerMethod = PrivateMethod[Tuple2[SchedulerBackend, TaskScheduler]]('createTaskScheduler) - val (_, sched) = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) + val (_, sched) = SparkContext invokePrivate createTaskSchedulerMethod(sc, master, deployMode) sched.asInstanceOf[TaskSchedulerImpl] } @@ -107,7 +114,7 @@ class SparkContextSchedulerCreationSuite test("local-default-parallelism") { val conf = new SparkConf().set("spark.default.parallelism", "16") - val sched = createTaskScheduler("local", conf) + val sched = createTaskScheduler("local", "client", conf) sched.backend match { case s: LocalBackend => assert(s.defaultParallelism() === 16) @@ -115,13 +122,6 @@ class SparkContextSchedulerCreationSuite } } - test("simr") { - createTaskScheduler("simr://uri").backend match { - case s: SimrSchedulerBackend => // OK - case _ => fail() - } - } - test("local-cluster") { createTaskScheduler("local-cluster[3, 14, 1024]").backend match { case s: SparkDeploySchedulerBackend => // OK @@ -129,9 +129,9 @@ class SparkContextSchedulerCreationSuite } } - def testYarn(master: String, expectedClassName: String) { + def testYarn(master: String, deployMode: String, expectedClassName: String) { try { - val sched = createTaskScheduler(master) + val sched = createTaskScheduler(master, deployMode) assert(sched.getClass === Utils.classForName(expectedClassName)) } catch { case e: SparkException => @@ -142,21 +142,17 @@ class SparkContextSchedulerCreationSuite } test("yarn-cluster") { - testYarn("yarn-cluster", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") - } - - test("yarn-standalone") { - testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") + testYarn("yarn", "cluster", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") } test("yarn-client") { - testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnScheduler") + testYarn("yarn", "client", "org.apache.spark.scheduler.cluster.YarnScheduler") } def testMesos(master: String, expectedClass: Class[_], coarse: Boolean) { val conf = new SparkConf().set("spark.mesos.coarse", coarse.toString) try { - val sched = createTaskScheduler(master, conf) + val sched = createTaskScheduler(master, "client", conf) assert(sched.backend.getClass === expectedClass) } catch { case e: UnsatisfiedLinkError => @@ -175,6 +171,11 @@ class SparkContextSchedulerCreationSuite } test("mesos with zookeeper") { + testMesos("mesos://zk://localhost:1234,localhost:2345", + classOf[MesosSchedulerBackend], coarse = false) + } + + test("mesos with zookeeper and Master URL starting with zk://") { testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend], coarse = false) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index d4f2ea87650a9..841fd02ae8bb6 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,20 +18,20 @@ package org.apache.spark import java.io.File +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit -import com.google.common.base.Charsets._ -import com.google.common.io.Files +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import com.google.common.io.Files import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} -import org.apache.spark.util.Utils - -import scala.concurrent.Await -import scala.concurrent.duration.Duration import org.scalatest.Matchers._ +import org.apache.spark.util.Utils + class SparkContextSuite extends SparkFunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -115,8 +115,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { val absolutePath2 = file2.getAbsolutePath try { - Files.write("somewords1", file1, UTF_8) - Files.write("somewords2", file2, UTF_8) + Files.write("somewords1", file1, StandardCharsets.UTF_8) + Files.write("somewords2", file2, StandardCharsets.UTF_8) val length1 = file1.length() val length2 = file2.length() @@ -243,11 +243,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { try { // Create 5 text files. - Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, UTF_8) - Files.write("someline1 in file2\nsomeline2 in file2", file2, UTF_8) - Files.write("someline1 in file3", file3, UTF_8) - Files.write("someline1 in file4\nsomeline2 in file4", file4, UTF_8) - Files.write("someline1 in file2\nsomeline2 in file5", file5, UTF_8) + Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, + StandardCharsets.UTF_8) + Files.write("someline1 in file2\nsomeline2 in file2", file2, StandardCharsets.UTF_8) + Files.write("someline1 in file3", file3, StandardCharsets.UTF_8) + Files.write("someline1 in file4\nsomeline2 in file4", file4, StandardCharsets.UTF_8) + Files.write("someline1 in file2\nsomeline2 in file5", file5, StandardCharsets.UTF_8) sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) @@ -274,6 +275,31 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("Default path for file based RDDs is properly set (SPARK-12517)") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + // Test filetextFile, wholeTextFiles, binaryFiles, hadoopFile and + // newAPIHadoopFile for setting the default path as the RDD name + val mockPath = "default/path/for/" + + var targetPath = mockPath + "textFile" + assert(sc.textFile(targetPath).name === targetPath) + + targetPath = mockPath + "wholeTextFiles" + assert(sc.wholeTextFiles(targetPath).name === targetPath) + + targetPath = mockPath + "binaryFiles" + assert(sc.binaryFiles(targetPath).name === targetPath) + + targetPath = mockPath + "hadoopFile" + assert(sc.hadoopFile(targetPath).name === targetPath) + + targetPath = mockPath + "newAPIHadoopFile" + assert(sc.newAPIHadoopFile(targetPath).name === targetPath) + + sc.stop() + } + test("calling multiple sc.stop() must not throw any exception") { noException should be thrownBy { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 9be9db01c7de9..3228752b96389 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -18,14 +18,28 @@ package org.apache.spark // scalastyle:off -import org.scalatest.{FunSuite, Outcome} +import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome} + +import org.apache.spark.internal.Logging /** * Base abstract class for all unit tests in Spark for handling common functionality. */ -private[spark] abstract class SparkFunSuite extends FunSuite with Logging { +private[spark] abstract class SparkFunSuite + extends FunSuite + with BeforeAndAfterAll + with Logging { // scalastyle:on + protected override def afterAll(): Unit = { + try { + // Avoid leaking map entries in tests that use accumulators without SparkContext + Accumulators.clear() + } finally { + super.afterAll() + } + } + /** * Log the suite name and the test name before and after each test. * diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 46516e8d25298..5483f2b8434aa 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont Set(firstJobId, secondJobId)) } } + + test("getJobIdsForGroup() with takeAsync()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId)) + } + } + + test("getJobIdsForGroup() with takeAsync() across multiple partitions") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2 + } + } } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 54c131cdae367..36273d722f50a 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark -import java.util.concurrent.{TimeUnit, Semaphore} -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} -import org.apache.spark.scheduler._ +import org.apache.spark.internal.Logging /** * Holds state shared across task threads in some ThreadingSuite tests. diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index f7a13ab3996d8..09e21646ee744 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -35,7 +35,7 @@ class UnpersistSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(200) } } catch { - case _: Throwable => { Thread.sleep(10) } + case _: Throwable => Thread.sleep(10) // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala index 135c56bf5bc9d..b38a3667abee1 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.api.python -import scala.io.Source +import java.io.{File, PrintWriter} -import java.io.{PrintWriter, File} +import scala.io.Source import org.scalatest.Matchers diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index 41f2a5c972b6b..05b4e67412f2e 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.api.python import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.charset.StandardCharsets import org.apache.spark.SparkFunSuite @@ -35,10 +36,12 @@ class PythonRDDSuite extends SparkFunSuite { // The correctness will be tested in Python PythonRDD.writeIteratorToStream(Iterator("a", null), buffer) PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer) - PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer) - PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer) + PythonRDD.writeIteratorToStream(Iterator("a".getBytes(StandardCharsets.UTF_8), null), buffer) + PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes(StandardCharsets.UTF_8)), buffer) PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer) - PythonRDD.writeIteratorToStream( - Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer) + PythonRDD.writeIteratorToStream(Iterator( + (null, null), + ("a".getBytes(StandardCharsets.UTF_8), null), + (null, "b".getBytes(StandardCharsets.UTF_8))), buffer) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index ba21075ce6be5..6657104823e71 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -37,7 +37,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { rdd.map { x => val bm = SparkEnv.get.blockManager // Check if broadcast block was fetched - val isFound = bm.getLocal(BroadcastBlockId(bid)).isDefined + val isFound = bm.getLocalValues(BroadcastBlockId(bid)).isDefined (x, isFound) }.collect().toSet } @@ -45,39 +45,8 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { class BroadcastSuite extends SparkFunSuite with LocalSparkContext { - private val httpConf = broadcastConf("HttpBroadcastFactory") - private val torrentConf = broadcastConf("TorrentBroadcastFactory") - - test("Using HttpBroadcast locally") { - sc = new SparkContext("local", "test", httpConf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === Set((1, 10), (2, 10))) - } - - test("Accessing HttpBroadcast variables from multiple threads") { - sc = new SparkContext("local[10]", "test", httpConf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) - } - - test("Accessing HttpBroadcast variables in a local cluster") { - val numSlaves = 4 - val conf = httpConf.clone - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) - } - test("Using TorrentBroadcast locally") { - sc = new SparkContext("local", "test", torrentConf) + sc = new SparkContext("local", "test") val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) @@ -85,7 +54,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { } test("Accessing TorrentBroadcast variables from multiple threads") { - sc = new SparkContext("local[10]", "test", torrentConf) + sc = new SparkContext("local[10]", "test") val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) @@ -94,7 +63,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Accessing TorrentBroadcast variables in a local cluster") { val numSlaves = 4 - val conf = torrentConf.clone + val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -124,31 +93,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Test Lazy Broadcast variables with TorrentBroadcast") { val numSlaves = 2 - val conf = torrentConf.clone - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") val rdd = sc.parallelize(1 to numSlaves) - val results = new DummyBroadcastClass(rdd).doSomething() assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet) } - test("Unpersisting HttpBroadcast on executors only in local mode") { - testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) - } - - test("Unpersisting HttpBroadcast on executors and driver in local mode") { - testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true) - } - - test("Unpersisting HttpBroadcast on executors only in distributed mode") { - testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false) - } - - test("Unpersisting HttpBroadcast on executors and driver in distributed mode") { - testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true) - } - test("Unpersisting TorrentBroadcast on executors only in local mode") { testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false) } @@ -179,64 +130,11 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(thrown.getMessage.toLowerCase.contains("stopped")) } - /** - * Verify the persistence of state associated with an HttpBroadcast in either local mode or - * local-cluster mode (when distributed = true). - * - * This test creates a broadcast variable, uses it on all executors, and then unpersists it. - * In between each step, this test verifies that the broadcast blocks and the broadcast file - * are present only on the expected nodes. - */ - private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { - val numSlaves = if (distributed) 2 else 0 - - // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === 1) - statuses.head match { case (bm, status) => - assert(bm.isDriver, "Block should only be on the driver") - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store on the driver") - assert(status.diskSize === 0, "Block should not be in disk store on the driver") - } - if (distributed) { - // this file is only generated in distributed mode - assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!") - } - } - - // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === numSlaves + 1) - statuses.foreach { case (_, status) => - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store") - assert(status.diskSize === 0, "Block should not be in disk store") - } - } - - // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver - // is true. In the latter case, also verify that the broadcast file is deleted on the driver. - def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - val expectedNumBlocks = if (removeFromDriver) 0 else 1 - val possiblyNot = if (removeFromDriver) "" else " not" - assert(statuses.size === expectedNumBlocks, - "Block should%s be unpersisted on the driver".format(possiblyNot)) - if (distributed && removeFromDriver) { - // this file is only generated in distributed mode - assert(!HttpBroadcast.getFile(blockId.broadcastId).exists, - "Broadcast file should%s be deleted".format(possiblyNot)) - } - } - - testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation, - afterUsingBroadcast, afterUnpersist, removeFromDriver) + test("Forbid broadcasting RDD directly") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val rdd = sc.parallelize(1 to 4) + intercept[IllegalArgumentException] { sc.broadcast(rdd) } + sc.stop() } /** @@ -284,7 +182,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -300,7 +198,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { private def testUnpersistBroadcast( distributed: Boolean, numSlaves: Int, // used only when distributed = true - broadcastConf: SparkConf, afterCreation: (Long, BlockManagerMaster) => Unit, afterUsingBroadcast: (Long, BlockManagerMaster) => Unit, afterUnpersist: (Long, BlockManagerMaster) => Unit, @@ -308,7 +205,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) + new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") // Wait until all salves are up try { _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) @@ -319,7 +216,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { throw e } } else { - new SparkContext("local", "test", broadcastConf) + new SparkContext("local", "test") } val blockManagerMaster = sc.env.blockManager.master val list = List[Int](1, 2, 3, 4) @@ -356,13 +253,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) } } - - /** Helper method to create a SparkConf that uses the given broadcast factory. */ - private def broadcastConf(factoryName: String): SparkConf = { - val conf = new SparkConf - conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) - conf - } } package object testPackage extends Assertions { diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 3164760b08a71..9c13c15281a42 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -20,9 +20,9 @@ package org.apache.spark.deploy import java.io.File import java.util.Date +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{SecurityManager, SparkConf} private[deploy] object DeployTestUtils { def createAppDesc(): ApplicationDescription = { @@ -50,7 +50,7 @@ private[deploy] object DeployTestUtils { createDriverDesc(), new Date()) def createWorkerInfo(): WorkerInfo = { - val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") + val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, "http://publicAddress:80") workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis workerInfo } @@ -69,7 +69,7 @@ private[deploy] object DeployTestUtils { "publicAddress", new File("sparkHome"), new File("workDir"), - "akka://worker", + "spark://worker", new SparkConf, Seq("localDir"), ExecutorState.RUNNING) @@ -84,7 +84,7 @@ private[deploy] object DeployTestUtils { new File("sparkHome"), createDriverDesc(), null, - "akka://worker", + "spark://worker", new SecurityManager(conf)) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index d93febcfd23fd..9ecf49b59898b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -24,10 +24,8 @@ import java.util.jar.Manifest import scala.collection.mutable.ArrayBuffer -import com.google.common.io.{Files, ByteStreams} - +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.FileUtils - import org.apache.ivy.core.settings.IvySettings import org.apache.spark.TestUtils.{createCompiledClass, JavaSourceFromString} diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 0a9f128a3a6b6..2d48e75cfbd96 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -23,10 +23,10 @@ import com.fasterxml.jackson.core.JsonParseException import org.json4s._ import org.json4s.jackson.JsonMethods +import org.apache.spark.{JsonTestUtils, SparkFunSuite} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState} import org.apache.spark.deploy.worker.ExecutorRunner -import org.apache.spark.{JsonTestUtils, SparkFunSuite} class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 8dd31b4b6fdda..cbdf1755b0c5b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -22,9 +22,9 @@ import java.net.URL import scala.collection.mutable import scala.io.Source +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.SparkConfWithEnv class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 1ed4bae3ca21e..13cba94578a6a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.deploy -import java.io.{PrintStream, OutputStream, File} +import java.io.{File, OutputStream, PrintStream} import java.net.URI -import java.util.jar.Attributes.Name import java.util.jar.{JarFile, Manifest} +import java.util.jar.Attributes.Name import java.util.zip.ZipFile import scala.collection.JavaConverters._ @@ -33,8 +33,12 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.ResetSystemProperties -class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { +class RPackageUtilsSuite + extends SparkFunSuite + with BeforeAndAfterEach + with ResetSystemProperties { private val main = MavenCoordinate("a", "b", "c") private val dep1 = MavenCoordinate("a", "dep1", "c") @@ -60,11 +64,9 @@ class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { } } - def beforeAll() { - System.setProperty("spark.testing", "true") - } - override def beforeEach(): Unit = { + super.beforeEach() + System.setProperty("spark.testing", "true") lineBuffer.clear() } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1fd470cd3b01d..271897699201b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -18,29 +18,33 @@ package org.apache.spark.deploy import java.io._ +import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.internal.Logging import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch -// of properties that neeed to be cleared after tests. +// of properties that needed to be cleared after tests. class SparkSubmitSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach with ResetSystemProperties with Timeouts { - def beforeAll() { + override def beforeEach() { + super.beforeEach() System.setProperty("spark.testing", "true") } @@ -133,6 +137,47 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("specify deploy mode through configuration") { + val clArgs = Seq( + "--master", "yarn", + "--conf", "spark.submit.deployMode=client", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + + appArgs.deployMode should be ("client") + sysProps("spark.submit.deployMode") should be ("client") + + // Both cmd line and configuration are specified, cmdline option takes the priority + val clArgs1 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--conf", "spark.submit.deployMode=client", + "-class", "org.SomeClass", + "thejar.jar" + ) + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + + appArgs1.deployMode should be ("cluster") + sysProps1("spark.submit.deployMode") should be ("cluster") + + // Neither cmdline nor configuration are specified, client mode is the default choice + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + appArgs2.deployMode should be (null) + + val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + appArgs2.deployMode should be ("client") + sysProps2("spark.submit.deployMode") should be ("client") + } + test("handles YARN cluster mode") { val clArgs = Seq( "--deploy-mode", "cluster", @@ -154,21 +199,21 @@ class SparkSubmitSuite val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") - childArgsStr should include ("--executor-memory 5g") - childArgsStr should include ("--driver-memory 4g") - childArgsStr should include ("--executor-cores 5") childArgsStr should include ("--arg arg1 --arg arg2") - childArgsStr should include ("--queue thequeue") childArgsStr should include regex ("--jar .*thejar.jar") - childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar") - childArgsStr should include regex ("--files .*file1.txt,.*file2.txt") - childArgsStr should include regex ("--archives .*archive1.txt,.*archive2.txt") mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) + + sysProps("spark.executor.memory") should be ("5g") + sysProps("spark.driver.memory") should be ("4g") + sysProps("spark.executor.cores") should be ("5") + sysProps("spark.yarn.queue") should be ("thequeue") + sysProps("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") + sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.app.name") should be ("beauty") sysProps("spark.ui.enabled") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") - sysProps.keys should not contain ("spark.jars") } test("handles YARN client mode") { @@ -204,7 +249,8 @@ class SparkSubmitSuite sysProps("spark.executor.instances") should be ("6") sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") + sysProps("spark.yarn.dist.jars") should include + regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") sysProps("spark.ui.enabled") should be ("false") } @@ -314,7 +360,8 @@ class SparkSubmitSuite val appArgs = new SparkSubmitArguments(clArgs) val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.master") should be ("yarn-cluster") + sysProps("spark.master") should be ("yarn") + sysProps("spark.submit.deployMode") should be ("cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") } @@ -366,10 +413,9 @@ class SparkSubmitSuite } } - test("correctly builds R packages included in a jar with --packages") { - // TODO(SPARK-9603): Building a package to $SPARK_HOME/R/lib is unavailable on Jenkins. - // It's hard to write the test in SparkR (because we can't create the repository dynamically) - /* + // TODO(SPARK-9603): Building a package is flaky on Jenkins Maven builds. + // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log + ignore("correctly builds R packages included in a jar with --packages") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -387,7 +433,6 @@ class SparkSubmitSuite rScriptDir) runSparkSubmit(args) } - */ } test("resolves command line argument paths correctly") { @@ -412,7 +457,7 @@ class SparkSubmitSuite // Test files and archives (Yarn) val clArgs2 = Seq( - "--master", "yarn-client", + "--master", "yarn", "--class", "org.SomeClass", "--files", files, "--archives", archives, @@ -470,7 +515,7 @@ class SparkSubmitSuite writer2.println("spark.yarn.dist.archives " + archives) writer2.close() val clArgs2 = Seq( - "--master", "yarn-client", + "--master", "yarn", "--class", "org.SomeClass", "--properties-file", f2.getPath, "thejar.jar" @@ -550,7 +595,7 @@ class SparkSubmitSuite val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") - val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf)) + val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf), StandardCharsets.UTF_8) for ((key, value) <- defaults) writer.write(s"$key $value\n") writer.close() @@ -575,7 +620,7 @@ object JarCreationTest extends Logging { Utils.classForName(args(1)) } catch { case t: Throwable => - exception = t + "\n" + t.getStackTraceString + exception = t + "\n" + Utils.exceptionString(t) exception = exception.replaceAll("\n", "\n\t") } Option(exception).toSeq.iterator @@ -618,7 +663,7 @@ object UserClasspathFirstTest { val ccl = Thread.currentThread().getContextClassLoader() val resource = ccl.getResourceAsStream("test.resource") val bytes = ByteStreams.toByteArray(resource) - val contents = new String(bytes, 0, bytes.length, UTF_8) + val contents = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8) if (contents != "USER") { throw new SparkException("Should have read user resource, but instead read: " + contents) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 63c346c1b8908..4877710c1237d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.deploy -import java.io.{File, PrintStream, OutputStream} +import java.io.{File, OutputStream, PrintStream} import scala.collection.mutable.ArrayBuffer -import org.scalatest.BeforeAndAfterAll import org.apache.ivy.core.module.descriptor.MDArtifact import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.{AbstractResolver, FileSystemResolver, IBiblioResolver} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate @@ -171,7 +171,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("neglects Spark and Spark's dependencies") { - val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") val coordinates = diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index d145e78834b1b..3d39bd4a748cf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy +import scala.collection.mutable import scala.concurrent.duration._ import org.mockito.Mockito.{mock, when} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ @@ -29,6 +30,7 @@ import org.apache.spark.deploy.master.ApplicationInfo import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.Worker import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor @@ -38,7 +40,8 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterE class StandaloneDynamicAllocationSuite extends SparkFunSuite with LocalSparkContext - with BeforeAndAfterAll { + with BeforeAndAfterAll + with PrivateMethodTester { private val numWorkers = 2 private val conf = new SparkConf() @@ -68,15 +71,18 @@ class StandaloneDynamicAllocationSuite } override def afterAll(): Unit = { - masterRpcEnv.shutdown() - workerRpcEnvs.foreach(_.shutdown()) - master.stop() - workers.foreach(_.stop()) - masterRpcEnv = null - workerRpcEnvs = null - master = null - workers = null - super.afterAll() + try { + masterRpcEnv.shutdown() + workerRpcEnvs.foreach(_.shutdown()) + master.stop() + workers.foreach(_.stop()) + masterRpcEnv = null + workerRpcEnvs = null + master = null + workers = null + } finally { + super.afterAll() + } } test("dynamic allocation default behavior") { @@ -362,7 +368,7 @@ class StandaloneDynamicAllocationSuite val executors = getExecutorIds(sc) assert(executors.size === 2) assert(sc.killExecutor(executors.head)) - assert(sc.killExecutor(executors.head)) + assert(!sc.killExecutor(executors.head)) val apps = getApplications() assert(apps.head.executors.size === 1) // The limit should not be lowered twice @@ -383,25 +389,81 @@ class StandaloneDynamicAllocationSuite // the driver refuses to kill executors it does not know about syncExecutors(sc) val executors = getExecutorIds(sc) + val executorIdsBefore = executors.toSet assert(executors.size === 2) - // kill executor 1, and replace it + // kill and replace an executor assert(sc.killAndReplaceExecutor(executors.head)) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.head.executors.size === 2) + val executorIdsAfter = getExecutorIds(sc).toSet + // make sure the executor was killed and replaced + assert(executorIdsBefore != executorIdsAfter) } + // kill old executor (which is killedAndReplaced) should fail + assert(!sc.killExecutor(executors.head)) + + // refresh executors list + val newExecutors = getExecutorIds(sc) + syncExecutors(sc) + + // kill newly created executor and do not replace it + assert(sc.killExecutor(newExecutors(1))) + val apps = getApplications() + assert(apps.head.executors.size === 1) + assert(apps.head.getExecutorLimit === 1) + } + + test("disable force kill for busy executors (SPARK-9552)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } var apps = getApplications() - // kill executor 1 - assert(sc.killExecutor(executors.head)) + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + val executors = getExecutorIds(sc) + assert(executors.size === 2) + + // simulate running a task on the executor + val getMap = PrivateMethod[mutable.HashMap[String, Int]]('executorIdToTaskCount) + val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + val executorIdToTaskCount = taskScheduler invokePrivate getMap() + executorIdToTaskCount(executors.head) = 1 + // kill the busy executor without force; this should fail + assert(!killExecutor(sc, executors.head, force = false)) apps = getApplications() assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === 2) - // kill executor 2 - assert(sc.killExecutor(executors(1))) + + // force kill busy executor + assert(killExecutor(sc, executors.head, force = true)) apps = getApplications() + // kill executor successfully assert(apps.head.executors.size === 1) - assert(apps.head.getExecutorLimit === 1) + } + + test("initial executor limit") { + val initialExecutorLimit = 1 + val myConf = appConf + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.initialExecutors", initialExecutorLimit.toString) + sc = new SparkContext(myConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === initialExecutorLimit) + assert(apps.head.getExecutorLimit === initialExecutorLimit) + } } // =============================== @@ -428,7 +490,7 @@ class StandaloneDynamicAllocationSuite (0 until numWorkers).map { i => val rpcEnv = workerRpcEnvs(i) val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), - Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + Worker.ENDPOINT_NAME, null, conf, securityManager) rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) worker } @@ -439,7 +501,7 @@ class StandaloneDynamicAllocationSuite master.self.askWithRetry[MasterStateResponse](RequestMasterState) } - /** Get the applictions that are active from Master */ + /** Get the applications that are active from Master */ private def getApplications(): Seq[ApplicationInfo] = { getMasterState.activeApps } @@ -455,6 +517,16 @@ class StandaloneDynamicAllocationSuite sc.killExecutors(getExecutorIds(sc).take(n)) } + /** Kill the given executor, specifying whether to force kill it. */ + private def killExecutor(sc: SparkContext, executorId: String, force: Boolean): Boolean = { + syncExecutors(sc) + sc.schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(Seq(executorId), replace = false, force) + case _ => fail("expected coarse grained scheduler") + } + } + /** * Return a list of executor IDs belonging to this application. * @@ -484,13 +556,12 @@ class StandaloneDynamicAllocationSuite val missingExecutors = masterExecutors.toSet.diff(driverExecutors.toSet).toSeq.sorted missingExecutors.foreach { id => // Fake an executor registration so the driver knows about us - val port = System.currentTimeMillis % 65536 val endpointRef = mock(classOf[RpcEndpointRef]) val mockAddress = mock(classOf[RpcAddress]) when(endpointRef.address).thenReturn(mockAddress) - val message = RegisterExecutor(id, endpointRef, s"localhost:$port", 10, Map.empty) + val message = RegisterExecutor(id, endpointRef, 10, Map.empty) val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend] - backend.driverEndpoint.askWithRetry[CoarseGrainedClusterMessage](message) + backend.driverEndpoint.askWithRetry[Boolean](message) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala new file mode 100644 index 0000000000000..7b46f9101d89b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.client + +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.concurrent.duration._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.deploy.{ApplicationDescription, Command} +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master.{ApplicationInfo, Master} +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.Utils + +/** + * End-to-end tests for application client in standalone mode. + */ +class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterAll { + private val numWorkers = 2 + private val conf = new SparkConf() + private val securityManager = new SecurityManager(conf) + + private var masterRpcEnv: RpcEnv = null + private var workerRpcEnvs: Seq[RpcEnv] = null + private var master: Master = null + private var workers: Seq[Worker] = null + + /** + * Start the local cluster. + * Note: local-cluster mode is insufficient because we want a reference to the Master. + */ + override def beforeAll(): Unit = { + super.beforeAll() + masterRpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityManager) + workerRpcEnvs = (0 until numWorkers).map { i => + RpcEnv.create(Worker.SYSTEM_NAME + i, "localhost", 0, conf, securityManager) + } + master = makeMaster() + workers = makeWorkers(10, 2048) + // Wait until all workers register with master successfully + eventually(timeout(60.seconds), interval(10.millis)) { + assert(getMasterState.workers.size === numWorkers) + } + } + + override def afterAll(): Unit = { + try { + workerRpcEnvs.foreach(_.shutdown()) + masterRpcEnv.shutdown() + workers.foreach(_.stop()) + master.stop() + workerRpcEnvs = null + masterRpcEnv = null + workers = null + master = null + } finally { + super.afterAll() + } + } + + test("interface methods of AppClient using local Master") { + val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) + + ci.client.start() + + // Client should connect with one Master which registers the application + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(ci.listener.connectedIdList.size === 1, "client listener should have one connection") + assert(apps.size === 1, "master should have 1 registered app") + } + + // Send message to Master to request Executors, verify request by change in executor limit + val numExecutorsRequested = 1 + assert(ci.client.requestTotalExecutors(numExecutorsRequested)) + + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.head.getExecutorLimit === numExecutorsRequested, s"executor request failed") + } + + // Send request to kill executor, verify request was made + assert { + val apps = getApplications() + val executorId: String = apps.head.executors.head._2.fullId + ci.client.killExecutors(Seq(executorId)) + } + + // Issue stop command for Client to disconnect from Master + ci.client.stop() + + // Verify Client is marked dead and unregistered from Master + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(ci.listener.deadReasonList.size === 1, "client should have been marked dead") + assert(apps.isEmpty, "master should have 0 registered apps") + } + } + + test("request from AppClient before initialized with master") { + val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) + + // requests to master should fail immediately + assert(ci.client.requestTotalExecutors(3) === false) + } + + // =============================== + // | Utility methods for testing | + // =============================== + + /** Return a SparkConf for applications that want to talk to our Master. */ + private def appConf: SparkConf = { + new SparkConf() + .setMaster(masterRpcEnv.address.toSparkURL) + .setAppName("test") + .set("spark.executor.memory", "256m") + } + + /** Make a master to which our application will send executor requests. */ + private def makeMaster(): Master = { + val master = new Master(masterRpcEnv, masterRpcEnv.address, 0, securityManager, conf) + masterRpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + master + } + + /** Make a few workers that talk to our master. */ + private def makeWorkers(cores: Int, memory: Int): Seq[Worker] = { + (0 until numWorkers).map { i => + val rpcEnv = workerRpcEnvs(i) + val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), + Worker.ENDPOINT_NAME, null, conf, securityManager) + rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) + worker + } + } + + /** Get the Master state */ + private def getMasterState: MasterStateResponse = { + master.self.askWithRetry[MasterStateResponse](RequestMasterState) + } + + /** Get the applications that are active from Master */ + private def getApplications(): Seq[ApplicationInfo] = { + getMasterState.activeApps + } + + /** Application Listener to collect events */ + private class AppClientCollector extends AppClientListener with Logging { + val connectedIdList = new ConcurrentLinkedQueue[String]() + @volatile var disconnectedCount: Int = 0 + val deadReasonList = new ConcurrentLinkedQueue[String]() + val execAddedList = new ConcurrentLinkedQueue[String]() + val execRemovedList = new ConcurrentLinkedQueue[String]() + + def connected(id: String): Unit = { + connectedIdList.add(id) + } + + def disconnected(): Unit = { + synchronized { + disconnectedCount += 1 + } + } + + def dead(reason: String): Unit = { + deadReasonList.add(reason) + } + + def executorAdded( + id: String, + workerId: String, + hostPort: String, + cores: Int, + memory: Int): Unit = { + execAddedList.add(id) + } + + def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { + execRemovedList.add(id) + } + } + + /** Create AppClient and supporting objects */ + private class AppClientInst(masterUrl: String) { + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, securityManager) + private val cmd = new Command(TestExecutor.getClass.getCanonicalName.stripSuffix("$"), + List(), Map(), Seq(), Seq(), Seq()) + private val desc = new ApplicationDescription("AppClientSuite", Some(1), 512, cmd, "ignored") + val listener = new AppClientCollector + val client = new AppClient(rpcEnv, Array(masterUrl), desc, listener, new SparkConf) + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/test/scala/org/apache/spark/deploy/client/TestExecutor.scala similarity index 100% rename from core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala rename to core/src/test/scala/org/apache/spark/deploy/client/TestExecutor.scala diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala new file mode 100644 index 0000000000000..4ab000b53ad10 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.util.{Date, NoSuchElementException} +import javax.servlet.Filter +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import scala.language.postfixOps + +import com.codahale.metrics.Counter +import com.google.common.cache.LoadingCache +import com.google.common.util.concurrent.UncheckedExecutionException +import org.eclipse.jetty.servlet.ServletContextHandler +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.Matchers +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.status.api.v1.{ApplicationAttemptInfo => AttemptInfo, ApplicationInfo} +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.{Clock, ManualClock, Utils} + +class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar with Matchers { + + /** + * subclass with access to the cache internals + * @param retainedApplications number of retained applications + */ + class TestApplicationCache( + operations: ApplicationCacheOperations = new StubCacheOperations(), + retainedApplications: Int, + clock: Clock = new ManualClock(0)) + extends ApplicationCache(operations, retainedApplications, clock) { + + def cache(): LoadingCache[CacheKey, CacheEntry] = appCache + } + + /** + * Stub cache operations. + * The state is kept in a map of [[CacheKey]] to [[CacheEntry]], + * the `probeTime` field in the cache entry setting the timestamp of the entry + */ + class StubCacheOperations extends ApplicationCacheOperations with Logging { + + /** map to UI instances, including timestamps, which are used in update probes */ + val instances = mutable.HashMap.empty[CacheKey, CacheEntry] + + /** Map of attached spark UIs */ + val attached = mutable.HashMap.empty[CacheKey, SparkUI] + + var getAppUICount = 0L + var attachCount = 0L + var detachCount = 0L + var updateProbeCount = 0L + + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { + logDebug(s"getAppUI($appId, $attemptId)") + getAppUICount += 1 + instances.get(CacheKey(appId, attemptId)).map( e => + LoadedAppUI(e.ui, updateProbe(appId, attemptId, e.probeTime))) + } + + override def attachSparkUI( + appId: String, + attemptId: Option[String], + ui: SparkUI, + completed: Boolean): Unit = { + logDebug(s"attachSparkUI($appId, $attemptId, $ui)") + attachCount += 1 + attached += (CacheKey(appId, attemptId) -> ui) + } + + def putAndAttach( + appId: String, + attemptId: Option[String], + completed: Boolean, + started: Long, + ended: Long, + timestamp: Long): SparkUI = { + val ui = putAppUI(appId, attemptId, completed, started, ended, timestamp) + attachSparkUI(appId, attemptId, ui, completed) + ui + } + + def putAppUI( + appId: String, + attemptId: Option[String], + completed: Boolean, + started: Long, + ended: Long, + timestamp: Long): SparkUI = { + val ui = newUI(appId, attemptId, completed, started, ended) + putInstance(appId, attemptId, ui, completed, timestamp) + ui + } + + def putInstance( + appId: String, + attemptId: Option[String], + ui: SparkUI, + completed: Boolean, + timestamp: Long): Unit = { + instances += (CacheKey(appId, attemptId) -> + new CacheEntry(ui, completed, updateProbe(appId, attemptId, timestamp), timestamp)) + } + + /** + * Detach a reconstructed UI + * + * @param ui Spark UI + */ + override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { + logDebug(s"detachSparkUI($appId, $attemptId, $ui)") + detachCount += 1 + var name = ui.getAppName + val key = CacheKey(appId, attemptId) + attached.getOrElse(key, { throw new java.util.NoSuchElementException() }) + attached -= key + } + + /** + * Lookup from the internal cache of attached UIs + */ + def getAttached(appId: String, attemptId: Option[String]): Option[SparkUI] = { + attached.get(CacheKey(appId, attemptId)) + } + + /** + * The update probe. + * @param appId application to probe + * @param attemptId attempt to probe + * @param updateTime timestamp of this UI load + */ + private[history] def updateProbe( + appId: String, + attemptId: Option[String], + updateTime: Long)(): Boolean = { + updateProbeCount += 1 + logDebug(s"isUpdated($appId, $attemptId, ${updateTime})") + val entry = instances.get(CacheKey(appId, attemptId)).get + val updated = entry.probeTime > updateTime + logDebug(s"entry = $entry; updated = $updated") + updated + } + } + + /** + * Create a new UI. The info/attempt info classes here are from the package + * `org.apache.spark.status.api.v1`, not the near-equivalents from the history package + */ + def newUI( + name: String, + attemptId: Option[String], + completed: Boolean, + started: Long, + ended: Long): SparkUI = { + val info = new ApplicationInfo(name, name, Some(1), Some(1), Some(1), Some(64), + Seq(new AttemptInfo(attemptId, new Date(started), new Date(ended), + new Date(ended), ended - started, "user", completed))) + val ui = mock[SparkUI] + when(ui.getApplicationInfoList).thenReturn(List(info).iterator) + when(ui.getAppName).thenReturn(name) + when(ui.appName).thenReturn(name) + val handler = new ServletContextHandler() + when(ui.getHandlers).thenReturn(Seq(handler)) + ui + } + + /** + * Test operations on completed UIs: they are loaded on demand, entries + * are removed on overload. + * + * This effectively tests the original behavior of the history server's cache. + */ + test("Completed UI get") { + val operations = new StubCacheOperations() + val clock = new ManualClock(1) + implicit val cache = new ApplicationCache(operations, 2, clock) + val metrics = cache.metrics + // cache misses + val app1 = "app-1" + assertNotFound(app1, None) + assertMetric("lookupCount", metrics.lookupCount, 1) + assertMetric("lookupFailureCount", metrics.lookupFailureCount, 1) + assert(1 === operations.getAppUICount, "getAppUICount") + assertNotFound(app1, None) + assert(2 === operations.getAppUICount, "getAppUICount") + assert(0 === operations.attachCount, "attachCount") + + val now = clock.getTimeMillis() + // add the entry + operations.putAppUI(app1, None, true, now, now, now) + + // make sure its local + operations.getAppUI(app1, None).get + operations.getAppUICount = 0 + // now expect it to be found + val cacheEntry = cache.lookupCacheEntry(app1, None) + assert(1 === cacheEntry.probeTime) + assert(cacheEntry.completed) + // assert about queries made of the operations + assert(1 === operations.getAppUICount, "getAppUICount") + assert(1 === operations.attachCount, "attachCount") + + // and in the map of attached + assert(operations.getAttached(app1, None).isDefined, s"attached entry '1' from $cache") + + // go forward in time + clock.setTime(10) + val time2 = clock.getTimeMillis() + val cacheEntry2 = cache.get(app1) + // no more refresh as this is a completed app + assert(1 === operations.getAppUICount, "getAppUICount") + assert(0 === operations.updateProbeCount, "updateProbeCount") + assert(0 === operations.detachCount, "attachCount") + + // evict the entry + operations.putAndAttach("2", None, true, time2, time2, time2) + operations.putAndAttach("3", None, true, time2, time2, time2) + cache.get("2") + cache.get("3") + + // there should have been a detachment here + assert(1 === operations.detachCount, s"detach count from $cache") + // and entry app1 no longer attached + assert(operations.getAttached(app1, None).isEmpty, s"get($app1) in $cache") + val appId = "app1" + val attemptId = Some("_01") + val time3 = clock.getTimeMillis() + operations.putAppUI(appId, attemptId, false, time3, 0, time3) + // expect an error here + assertNotFound(appId, None) + } + + test("Test that if an attempt ID is is set, it must be used in lookups") { + val operations = new StubCacheOperations() + val clock = new ManualClock(1) + implicit val cache = new ApplicationCache(operations, retainedApplications = 10, clock = clock) + val appId = "app1" + val attemptId = Some("_01") + operations.putAppUI(appId, attemptId, false, clock.getTimeMillis(), 0, 0) + assertNotFound(appId, None) + } + + /** + * Test that incomplete apps are not probed for updates during the time window, + * but that they are checked if that window has expired and they are not completed. + * Then, if they have changed, the old entry is replaced by a new one. + */ + test("Incomplete apps refreshed") { + val operations = new StubCacheOperations() + val clock = new ManualClock(50) + val window = 500 + implicit val cache = new ApplicationCache(operations, retainedApplications = 5, clock = clock) + val metrics = cache.metrics + // add the incomplete app + // add the entry + val started = clock.getTimeMillis() + val appId = "app1" + val attemptId = Some("001") + operations.putAppUI(appId, attemptId, false, started, 0, started) + val firstEntry = cache.lookupCacheEntry(appId, attemptId) + assert(started === firstEntry.probeTime, s"timestamp in $firstEntry") + assert(!firstEntry.completed, s"entry is complete: $firstEntry") + assertMetric("lookupCount", metrics.lookupCount, 1) + + assert(0 === operations.updateProbeCount, "expected no update probe on that first get") + + val checkTime = window * 2 + clock.setTime(checkTime) + val entry3 = cache.lookupCacheEntry(appId, attemptId) + assert(firstEntry !== entry3, s"updated entry test from $cache") + assertMetric("lookupCount", metrics.lookupCount, 2) + assertMetric("updateProbeCount", metrics.updateProbeCount, 1) + assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 0) + assert(1 === operations.updateProbeCount, s"refresh count in $cache") + assert(0 === operations.detachCount, s"detach count") + assert(entry3.probeTime === checkTime) + + val updateTime = window * 3 + // update the cached value + val updatedApp = operations.putAppUI(appId, attemptId, true, started, updateTime, updateTime) + val endTime = window * 10 + clock.setTime(endTime) + logDebug(s"Before operation = $cache") + val entry5 = cache.lookupCacheEntry(appId, attemptId) + assertMetric("lookupCount", metrics.lookupCount, 3) + assertMetric("updateProbeCount", metrics.updateProbeCount, 2) + // the update was triggered + assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 1) + assert(updatedApp === entry5.ui, s"UI {$updatedApp} did not match entry {$entry5} in $cache") + + // at which point, the refreshes stop + clock.setTime(window * 20) + assertCacheEntryEquals(appId, attemptId, entry5) + assertMetric("updateProbeCount", metrics.updateProbeCount, 2) + } + + /** + * Assert that a metric counter has a specific value; failure raises an exception + * including the cache's toString value + * @param name counter name (for exceptions) + * @param counter counter + * @param expected expected value. + * @param cache cache + */ + def assertMetric( + name: String, + counter: Counter, + expected: Long) + (implicit cache: ApplicationCache): Unit = { + val actual = counter.getCount + if (actual != expected) { + // this is here because Scalatest loses stack depth + throw new Exception(s"Wrong $name value - expected $expected but got $actual in $cache") + } + } + + /** + * Look up the cache entry and assert that it matches in the expected value. + * This assertion works if the two CacheEntries are different -it looks at the fields. + * UI are compared on object equality; the timestamp and completed flags directly. + * @param appId application ID + * @param attemptId attempt ID + * @param expected expected value + * @param cache app cache + */ + def assertCacheEntryEquals( + appId: String, + attemptId: Option[String], + expected: CacheEntry) + (implicit cache: ApplicationCache): Unit = { + val actual = cache.lookupCacheEntry(appId, attemptId) + val errorText = s"Expected get($appId, $attemptId) -> $expected, but got $actual from $cache" + assert(expected.ui === actual.ui, errorText + " SparkUI reference") + assert(expected.completed === actual.completed, errorText + " -completed flag") + assert(expected.probeTime === actual.probeTime, errorText + " -timestamp") + } + + /** + * Assert that a key wasn't found in cache or loaded. + * + * Looks for the specific nested exception raised by [[ApplicationCache]] + * @param appId application ID + * @param attemptId attempt ID + * @param cache app cache + */ + def assertNotFound( + appId: String, + attemptId: Option[String]) + (implicit cache: ApplicationCache): Unit = { + val ex = intercept[UncheckedExecutionException] { + cache.get(appId, attemptId) + } + var cause = ex.getCause + assert(cause !== null) + if (!cause.isInstanceOf[NoSuchElementException]) { + throw cause + } + } + + test("Large Scale Application Eviction") { + val operations = new StubCacheOperations() + val clock = new ManualClock(0) + val size = 5 + // only two entries are retained, so we expect evictions to occur on lookups + implicit val cache: ApplicationCache = new TestApplicationCache(operations, + retainedApplications = size, clock = clock) + + val attempt1 = Some("01") + + val ids = new ListBuffer[String]() + // build a list of applications + val count = 100 + for (i <- 1 to count ) { + val appId = f"app-$i%04d" + ids += appId + clock.advance(10) + val t = clock.getTimeMillis() + operations.putAppUI(appId, attempt1, true, t, t, t) + } + // now go through them in sequence reading them, expect evictions + ids.foreach { id => + cache.get(id, attempt1) + } + logInfo(cache.toString) + val metrics = cache.metrics + + assertMetric("loadCount", metrics.loadCount, count) + assertMetric("evictionCount", metrics.evictionCount, count - size) +} + + test("Attempts are Evicted") { + val operations = new StubCacheOperations() + implicit val cache: ApplicationCache = new TestApplicationCache(operations, + retainedApplications = 4) + val metrics = cache.metrics + val appId = "app1" + val attempt1 = Some("01") + val attempt2 = Some("02") + val attempt3 = Some("03") + operations.putAppUI(appId, attempt1, true, 100, 110, 110) + operations.putAppUI(appId, attempt2, true, 200, 210, 210) + operations.putAppUI(appId, attempt3, true, 300, 310, 310) + val attempt4 = Some("04") + operations.putAppUI(appId, attempt4, true, 400, 410, 410) + val attempt5 = Some("05") + operations.putAppUI(appId, attempt5, true, 500, 510, 510) + + def expectLoadAndEvictionCounts(expectedLoad: Int, expectedEvictionCount: Int): Unit = { + assertMetric("loadCount", metrics.loadCount, expectedLoad) + assertMetric("evictionCount", metrics.evictionCount, expectedEvictionCount) + } + + // first entry + cache.get(appId, attempt1) + expectLoadAndEvictionCounts(1, 0) + + // second + cache.get(appId, attempt2) + expectLoadAndEvictionCounts(2, 0) + + // no change + cache.get(appId, attempt2) + expectLoadAndEvictionCounts(2, 0) + + // eviction time + cache.get(appId, attempt3) + cache.size() should be(3) + cache.get(appId, attempt4) + expectLoadAndEvictionCounts(4, 0) + cache.get(appId, attempt5) + expectLoadAndEvictionCounts(5, 1) + cache.get(appId, attempt5) + expectLoadAndEvictionCounts(5, 1) + + } + + test("Instantiate Filter") { + // this is a regression test on the filter being constructable + val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) + val instance = clazz.newInstance() + instance shouldBe a [Filter] + } + + test("redirect includes query params") { + val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) + val filter = clazz.newInstance().asInstanceOf[ApplicationCacheCheckFilter] + filter.appId = "local-123" + val cache = mock[ApplicationCache] + when(cache.checkForUpdates(any(), any())).thenReturn(true) + ApplicationCacheCheckFilterRelay.setApplicationCache(cache) + val request = mock[HttpServletRequest] + when(request.getMethod()).thenReturn("GET") + when(request.getRequestURI()).thenReturn("http://localhost:18080/history/local-123/jobs/job/") + when(request.getQueryString()).thenReturn("id=2") + val resp = mock[HttpServletResponse] + when(resp.encodeRedirectURL(any())).thenAnswer(new Answer[String]() { + override def answer(invocationOnMock: InvocationOnMock): String = { + invocationOnMock.getArguments()(0).asInstanceOf[String] + } + }) + filter.doFilter(request, resp, null) + verify(resp).sendRedirect("http://localhost:18080/history/local-123/jobs/job/?id=2") + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 5cab17f8a38f5..39c5857b13451 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -20,33 +20,30 @@ package org.apache.spark.deploy.history import java.io.{BufferedOutputStream, ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, OutputStreamWriter} import java.net.URI +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} -import scala.io.Source import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.fs.Path import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{doReturn, mock, spy, verify, when} +import org.mockito.Mockito.{mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { - import FsHistoryProvider._ - private var testDir: File = null before { @@ -69,7 +66,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new File(logPath) } - test("Parse new and old application logs") { + test("Parse application logs") { val provider = new FsHistoryProvider(createTestConf()) // Write a new-style application log. @@ -95,26 +92,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc None) ) - // Write an old-style application log. - val oldAppComplete = writeOldLog("old1", "1.0", None, true, - SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None), - SparkListenerApplicationEnd(3L) - ) - - // Check for logs so that we force the older unfinished app to be loaded, to make - // sure unfinished apps are also sorted correctly. - provider.checkForLogs() - - // Write an unfinished app, old-style. - val oldAppIncomplete = writeOldLog("old2", "1.0", None, false, - SparkListenerApplicationStart("old2", None, 2L, "test", None) - ) - - // Force a reload of data from the log directory, and check that both logs are loaded. + // Force a reload of data from the log directory, and check that logs are loaded. // Take the opportunity to check that the offset checks work as expected. updateAndCheck(provider) { list => - list.size should be (5) - list.count(_.attempts.head.completed) should be (3) + list.size should be (3) + list.count(_.attempts.head.completed) should be (2) def makeAppInfo( id: String, @@ -132,11 +114,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc newAppComplete.lastModified(), "test", true)) list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) - list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L, - oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L, - -1L, oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, + list(2) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. @@ -148,38 +126,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("Parse legacy logs with compression codec set") { - val provider = new FsHistoryProvider(createTestConf()) - val testCodecs = List((classOf[LZFCompressionCodec].getName(), true), - (classOf[SnappyCompressionCodec].getName(), true), - ("invalid.codec", false)) - - testCodecs.foreach { case (codecName, valid) => - val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null - val logDir = new File(testDir, codecName) - logDir.mkdir() - createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec), - SparkListenerApplicationStart("app2", None, 2L, "test", None), - SparkListenerApplicationEnd(3L) - ) - createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName)) - - val logPath = new Path(logDir.getAbsolutePath()) - try { - val logInput = provider.openLegacyEventLog(logPath) - try { - Source.fromInputStream(logInput).getLines().toSeq.size should be (2) - } finally { - logInput.close() - } - } catch { - case e: IllegalArgumentException => - valid should be (false) - } - } - } - test("SPARK-3697: ignore directories that cannot be read.") { val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, @@ -375,8 +321,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc var entry = inputStream.getNextEntry entry should not be null while (entry != null) { - val actual = new String(ByteStreams.toByteArray(inputStream), Charsets.UTF_8) - val expected = Files.toString(logs.find(_.getName == entry.getName).get, Charsets.UTF_8) + val actual = new String(ByteStreams.toByteArray(inputStream), StandardCharsets.UTF_8) + val expected = + Files.toString(logs.find(_.getName == entry.getName).get, StandardCharsets.UTF_8) actual should be (expected) totalEntries += 1 entry = inputStream.getNextEntry @@ -395,21 +342,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc SparkListenerLogStart("1.4") ) - // Write a 1.2 log file with no start event (= no app id), it should be ignored. - writeOldLog("v12Log", "1.2", None, false) - - // Write 1.0 and 1.1 logs, which don't have app ids. - writeOldLog("v11Log", "1.1", None, true, - SparkListenerApplicationStart("v11Log", None, 2L, "test", None), - SparkListenerApplicationEnd(3L)) - writeOldLog("v10Log", "1.0", None, true, - SparkListenerApplicationStart("v10Log", None, 2L, "test", None), - SparkListenerApplicationEnd(4L)) - updateAndCheck(provider) { list => - list.size should be (2) - list(0).id should be ("v10Log") - list(1).id should be ("v11Log") + list.size should be (0) } } @@ -483,7 +417,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc if (isNewFormat) { EventLoggingListener.initEventLog(new FileOutputStream(file)) } - val writer = new OutputStreamWriter(bstream, "UTF-8") + val writer = new OutputStreamWriter(bstream, StandardCharsets.UTF_8) Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) } { @@ -499,25 +433,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) } - private def writeOldLog( - fname: String, - sparkVersion: String, - codec: Option[CompressionCodec], - completed: Boolean, - events: SparkListenerEvent*): File = { - val log = new File(testDir, fname) - log.mkdir() - - val oldEventLog = new File(log, LOG_PREFIX + "1") - createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion)) - writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*) - if (completed) { - createEmptyFile(new File(log, APPLICATION_COMPLETE)) - } - - log - } - private class SafeModeTestProvider(conf: SparkConf, clock: Clock) extends FsHistoryProvider(conf, clock) { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 4b7fd4f13b692..2a013aca7b895 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -18,23 +18,36 @@ package org.apache.spark.deploy.history import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} import java.net.{HttpURLConnection, URL} +import java.nio.charset.StandardCharsets import java.util.zip.ZipInputStream import javax.servlet.http.{HttpServletRequest, HttpServletResponse} -import com.google.common.base.Charsets +import scala.concurrent.duration._ +import scala.language.postfixOps + +import com.codahale.metrics.Counter import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} -import org.mockito.Mockito.when +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods +import org.json4s.jackson.JsonMethods._ +import org.openqa.selenium.WebDriver +import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.concurrent.Eventually import org.scalatest.mock.MockitoSugar +import org.scalatest.selenium.WebBrowser -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.ui.{SparkUI, UIUtils} +import org.apache.spark._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.jobs.UIData.JobUIData +import org.apache.spark.util.{ResetSystemProperties, Utils} /** * A collection of tests against the historyserver, including comparing responses from the json * metrics api to a set of known "golden files". If new endpoints / parameters are added, - * cases should be added to this test suite. The expected outcomes can be genered by running + * cases should be added to this test suite. The expected outcomes can be generated by running * the HistoryServerSuite.main. Note that this will blindly generate new expectation files matching * the current behavior -- the developer must verify that behavior is correct. * @@ -43,7 +56,8 @@ import org.apache.spark.ui.{SparkUI, UIUtils} * are considered part of Spark's public api. */ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with MockitoSugar - with JsonTestUtils { + with JsonTestUtils with Eventually with WebBrowser with LocalSparkContext + with ResetSystemProperties { private val logDir = new File("src/test/resources/spark-events") private val expRoot = new File("src/test/resources/HistoryServerExpectations/") @@ -55,7 +69,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers def init(): Unit = { val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) - .set("spark.history.fs.updateInterval", "0") + .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -126,8 +140,9 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "stage task list from multi-attempt app json(2)" -> "applications/local-1426533911241/2/stages/0/0/taskList", - "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", - "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0" + "rdd list storage json" -> "applications/local-1422981780767/storage/rdd" + // Todo: enable this test when logging the even of onBlockUpdated. See: SPARK-13845 + // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0" ) // run a bunch of characterization tests -- just verify the behavior is the same as what is saved @@ -138,7 +153,26 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers code should be (HttpServletResponse.SC_OK) jsonOpt should be ('defined) errOpt should be (None) - val json = jsonOpt.get + val jsonOrg = jsonOpt.get + + // SPARK-10873 added the lastUpdated field for each application's attempt, + // the REST API returns the last modified time of EVENT LOG file for this field. + // It is not applicable to hard-code this dynamic field in a static expected file, + // so here we skip checking the lastUpdated field's value (setting it as ""). + val json = if (jsonOrg.indexOf("lastUpdated") >= 0) { + val subStrings = jsonOrg.split(",") + for (i <- subStrings.indices) { + if (subStrings(i).indexOf("lastUpdatedEpoch") >= 0) { + subStrings(i) = subStrings(i).replaceAll("(\\d+)", "0") + } else if (subStrings(i).indexOf("lastUpdated") >= 0) { + subStrings(i) = "\"lastUpdated\":\"\"" + } + } + subStrings.mkString(",") + } else { + jsonOrg + } + val exp = IOUtils.toString(new FileInputStream( new File(expRoot, HistoryServerSuite.sanitizePath(name) + "_expectation.json"))) // compare the ASTs so formatting differences don't cause failures @@ -158,18 +192,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers (1 to 2).foreach { attemptId => doDownloadTest("local-1430917381535", Some(attemptId)) } } - test("download legacy logs - all attempts") { - doDownloadTest("local-1426533911241", None, legacy = true) - } - - test("download legacy logs - single attempts") { - (1 to 2). foreach { - attemptId => doDownloadTest("local-1426533911241", Some(attemptId), legacy = true) - } - } - // Test that the files are downloaded correctly, and validate them. - def doDownloadTest(appId: String, attemptId: Option[Int], legacy: Boolean = false): Unit = { + def doDownloadTest(appId: String, attemptId: Option[Int]): Unit = { val url = attemptId match { case Some(id) => @@ -187,25 +211,16 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers var entry = zipStream.getNextEntry entry should not be null val totalFiles = { - if (legacy) { - attemptId.map { x => 3 }.getOrElse(6) - } else { - attemptId.map { x => 1 }.getOrElse(2) - } + attemptId.map { x => 1 }.getOrElse(2) } var filesCompared = 0 while (entry != null) { if (!entry.isDirectory) { val expectedFile = { - if (legacy) { - val splits = entry.getName.split("/") - new File(new File(logDir, splits(0)), splits(1)) - } else { - new File(logDir, entry.getName) - } + new File(logDir, entry.getName) } - val expected = Files.toString(expectedFile, Charsets.UTF_8) - val actual = new String(ByteStreams.toByteArray(zipStream), Charsets.UTF_8) + val expected = Files.toString(expectedFile, StandardCharsets.UTF_8) + val actual = new String(ByteStreams.toByteArray(zipStream), StandardCharsets.UTF_8) actual should be (expected) filesCompared += 1 } @@ -240,30 +255,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } - test("generate history page with relative links") { - val historyServer = mock[HistoryServer] - val request = mock[HttpServletRequest] - val ui = mock[SparkUI] - val link = "/history/app1" - val info = new ApplicationHistoryInfo("app1", "app1", - List(ApplicationAttemptInfo(None, 0, 2, 1, "xxx", true))) - when(historyServer.getApplicationList()).thenReturn(Seq(info)) - when(ui.basePath).thenReturn(link) - when(historyServer.getProviderConfig()).thenReturn(Map[String, String]()) - val page = new HistoryPage(historyServer) - - // when - val response = page.render(request) - - // then - val links = response \\ "a" - val justHrefs = for { - l <- links - attrs <- l.attribute("href") - } yield (attrs.toString) - justHrefs should contain (UIUtils.prependBaseUri(resource = link)) - } - test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") @@ -281,6 +272,204 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("incomplete apps get refreshed") { + + implicit val webDriver: WebDriver = new HtmlUnitDriver + implicit val formats = org.json4s.DefaultFormats + + // this test dir is explicitly deleted on successful runs; retained for diagnostics when + // not + val logDir = Utils.createDirectory(System.getProperty("java.io.tmpdir", "logs")) + + // a new conf is used with the background thread set and running at its fastest + // allowed refresh rate (1Hz) + val myConf = new SparkConf() + .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) + .set("spark.eventLog.dir", logDir.getAbsolutePath) + .set("spark.history.fs.update.interval", "1s") + .set("spark.eventLog.enabled", "true") + .set("spark.history.cache.window", "250ms") + .remove("spark.testing") + val provider = new FsHistoryProvider(myConf) + val securityManager = new SecurityManager(myConf) + + sc = new SparkContext("local", "test", myConf) + val logDirUri = logDir.toURI + val logDirPath = new Path(logDirUri) + val fs = FileSystem.get(logDirUri, sc.hadoopConfiguration) + + def listDir(dir: Path): Seq[FileStatus] = { + val statuses = fs.listStatus(dir) + statuses.flatMap( + stat => if (stat.isDirectory) listDir(stat.getPath) else Seq(stat)) + } + + def dumpLogDir(msg: String = ""): Unit = { + if (log.isDebugEnabled) { + logDebug(msg) + listDir(logDirPath).foreach { status => + val s = status.toString + logDebug(s) + } + } + } + + // stop the server with the old config, and start the new one + server.stop() + server = new HistoryServer(myConf, provider, securityManager, 18080) + server.initialize() + server.bind() + val port = server.boundPort + val metrics = server.cacheMetrics + + // assert that a metric has a value; if not dump the whole metrics instance + def assertMetric(name: String, counter: Counter, expected: Long): Unit = { + val actual = counter.getCount + if (actual != expected) { + // this is here because Scalatest loses stack depth + fail(s"Wrong $name value - expected $expected but got $actual" + + s" in metrics\n$metrics") + } + } + + // build a URL for an app or app/attempt plus a page underneath + def buildURL(appId: String, suffix: String): URL = { + new URL(s"http://localhost:$port/history/$appId$suffix") + } + + // build a rest URL for the application and suffix. + def applications(appId: String, suffix: String): URL = { + new URL(s"http://localhost:$port/api/v1/applications/$appId$suffix") + } + + val historyServerRoot = new URL(s"http://localhost:$port/") + + // start initial job + val d = sc.parallelize(1 to 10) + d.count() + val stdInterval = interval(100 milliseconds) + val appId = eventually(timeout(20 seconds), stdInterval) { + val json = getContentAndCode("applications", port)._2.get + val apps = parse(json).asInstanceOf[JArray].arr + apps should have size 1 + (apps.head \ "id").extract[String] + } + + val appIdRoot = buildURL(appId, "") + val rootAppPage = HistoryServerSuite.getUrl(appIdRoot) + logDebug(s"$appIdRoot ->[${rootAppPage.length}] \n$rootAppPage") + // sanity check to make sure filter is chaining calls + rootAppPage should not be empty + + def getAppUI: SparkUI = { + provider.getAppUI(appId, None).get.ui + } + + // selenium isn't that useful on failures...add our own reporting + def getNumJobs(suffix: String): Int = { + val target = buildURL(appId, suffix) + val targetBody = HistoryServerSuite.getUrl(target) + try { + go to target.toExternalForm + findAll(cssSelector("tbody tr")).toIndexedSeq.size + } catch { + case ex: Exception => + throw new Exception(s"Against $target\n$targetBody", ex) + } + } + // use REST API to get #of jobs + def getNumJobsRestful(): Int = { + val json = HistoryServerSuite.getUrl(applications(appId, "/jobs")) + val jsonAst = parse(json) + val jobList = jsonAst.asInstanceOf[JArray] + jobList.values.size + } + + // get a list of app Ids of all apps in a given state. REST API + def listApplications(completed: Boolean): Seq[String] = { + val json = parse(HistoryServerSuite.getUrl(applications("", ""))) + logDebug(s"${JsonMethods.pretty(json)}") + json match { + case JNothing => Seq() + case apps: JArray => + apps.filter(app => { + (app \ "attempts") match { + case attempts: JArray => + val state = (attempts.children.head \ "completed").asInstanceOf[JBool] + state.value == completed + case _ => false + } + }).map(app => (app \ "id").asInstanceOf[JString].values) + case _ => Seq() + } + } + + def completedJobs(): Seq[JobUIData] = { + getAppUI.jobProgressListener.completedJobs + } + + def activeJobs(): Seq[JobUIData] = { + getAppUI.jobProgressListener.activeJobs.values.toSeq + } + + activeJobs() should have size 0 + completedJobs() should have size 1 + getNumJobs("") should be (1) + getNumJobs("/jobs") should be (1) + getNumJobsRestful() should be (1) + assert(metrics.lookupCount.getCount > 1, s"lookup count too low in $metrics") + + // dump state before the next bit of test, which is where update + // checking really gets stressed + dumpLogDir("filesystem before executing second job") + logDebug(s"History Server: $server") + + val d2 = sc.parallelize(1 to 10) + d2.count() + dumpLogDir("After second job") + + val stdTimeout = timeout(10 seconds) + logDebug("waiting for UI to update") + eventually(stdTimeout, stdInterval) { + assert(2 === getNumJobs(""), + s"jobs not updated, server=$server\n dir = ${listDir(logDirPath)}") + assert(2 === getNumJobs("/jobs"), + s"job count under /jobs not updated, server=$server\n dir = ${listDir(logDirPath)}") + getNumJobsRestful() should be(2) + } + + d.count() + d.count() + eventually(stdTimeout, stdInterval) { + assert(4 === getNumJobsRestful(), s"two jobs back-to-back not updated, server=$server\n") + } + val jobcount = getNumJobs("/jobs") + assert(!provider.getListing().head.completed) + + listApplications(false) should contain(appId) + + // stop the spark context + resetSparkContext() + // check the app is now found as completed + eventually(stdTimeout, stdInterval) { + assert(provider.getListing().head.completed, + s"application never completed, server=$server\n") + } + + // app becomes observably complete + eventually(stdTimeout, stdInterval) { + listApplications(true) should contain (appId) + } + // app is no longer incomplete + listApplications(false) should not contain(appId) + + assert(jobcount === getNumJobs("/jobs")) + + // no need to retain the test dir now the tests complete + logDir.deleteOnExit(); + + } + def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { HistoryServerSuite.getContentAndCode(new URL(s"http://localhost:$port/api/v1/$path")) } @@ -300,6 +489,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers out.write(json) out.close() } + } object HistoryServerSuite { diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 242bf4b5566eb..7cbe4e342eaa5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -18,22 +18,36 @@ package org.apache.spark.deploy.master import java.util.Date +import java.util.concurrent.ConcurrentLinkedQueue +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.io.Source import scala.language.postfixOps import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.scalatest.{Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ -import org.apache.spark.rpc.RpcEnv +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} -class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester { +class MasterSuite extends SparkFunSuite + with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter { + + private var _master: Master = _ + + after { + if (_master != null) { + _master.rpcEnv.shutdown() + _master.rpcEnv.awaitTermination() + _master = null + } + } test("can use a custom recovery mode factory") { val conf = new SparkConf(loadDefaults = false) @@ -90,15 +104,14 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva cores = 0, memory = 0, endpoint = null, - webUiPort = 0, - publicAddress = "" + webUiAddress = "http://localhost:80" ) val (rpcEnv, _, _) = Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf) try { - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) + rpcEnv.setupEndpointRef(rpcEnv.address, Master.ENDPOINT_NAME) CustomPersistenceEngine.lastInstance.isDefined shouldBe true val persistenceEngine = CustomPersistenceEngine.lastInstance.get @@ -358,10 +371,11 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva private val workerInfos = Array(workerInfo, workerInfo, workerInfo) private def makeMaster(conf: SparkConf = new SparkConf): Master = { + assert(_master === null, "Some Master's RpcEnv is leaked in tests") val securityMgr = new SecurityManager(conf) val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) - val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) - master + _master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) + _master } private def makeAppInfo( @@ -376,7 +390,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = { val workerId = System.currentTimeMillis.toString - new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, 101, "address") + new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, "http://localhost:80") } private def scheduleExecutorsOnWorkers( @@ -387,4 +401,35 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva master.invokePrivate(_scheduleExecutorsOnWorkers(appInfo, workerInfos, spreadOut)) } + test("SPARK-13604: Master should ask Worker kill unknown executors and drivers") { + val master = makeMaster() + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askWithRetry[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + + val killedExecutors = new ConcurrentLinkedQueue[(String, Int)]() + val killedDrivers = new ConcurrentLinkedQueue[String]() + val fakeWorker = master.rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = master.rpcEnv + + override def receive: PartialFunction[Any, Unit] = { + case KillExecutor(_, appId, execId) => killedExecutors.add(appId, execId) + case KillDriver(driverId) => killedDrivers.add(driverId) + } + }) + + master.self.ask( + RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080")) + val executors = (0 until 3).map { i => + new ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING) + } + master.self.send(WorkerLatestState("1", executors, driverIds = Seq("0", "1", "2"))) + + eventually(timeout(10.seconds)) { + assert(killedExecutors.asScala.toList.sorted === List("0" -> 0, "1" -> 1, "2" -> 2)) + assert(killedDrivers.asScala.toList.sorted === List("0", "1", "2")) + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 34775577de8a3..62fe0eaedfd27 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -25,7 +25,7 @@ import org.apache.curator.test.TestingServer import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} -import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.util.Utils class PersistenceEngineSuite extends SparkFunSuite { @@ -63,56 +63,57 @@ class PersistenceEngineSuite extends SparkFunSuite { conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { val serializer = new JavaSerializer(conf) val persistenceEngine = persistenceEngineCreator(serializer) - persistenceEngine.persist("test_1", "test_1_value") - assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.persist("test_2", "test_2_value") - assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) - persistenceEngine.unpersist("test_1") - assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.unpersist("test_2") - assert(persistenceEngine.read[String]("test_").isEmpty) - - // Test deserializing objects that contain RpcEndpointRef - val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) try { - // Create a real endpoint so that we can test RpcEndpointRef deserialization - val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { - override val rpcEnv: RpcEnv = testRpcEnv - }) - - val workerToPersist = new WorkerInfo( - id = "test_worker", - host = "127.0.0.1", - port = 10000, - cores = 0, - memory = 0, - endpoint = workerEndpoint, - webUiPort = 0, - publicAddress = "" - ) - - persistenceEngine.addWorker(workerToPersist) - - val (storedApps, storedDrivers, storedWorkers) = - persistenceEngine.readPersistedData(testRpcEnv) - - assert(storedApps.isEmpty) - assert(storedDrivers.isEmpty) - - // Check deserializing WorkerInfo - assert(storedWorkers.size == 1) - val recoveryWorkerInfo = storedWorkers.head - assert(workerToPersist.id === recoveryWorkerInfo.id) - assert(workerToPersist.host === recoveryWorkerInfo.host) - assert(workerToPersist.port === recoveryWorkerInfo.port) - assert(workerToPersist.cores === recoveryWorkerInfo.cores) - assert(workerToPersist.memory === recoveryWorkerInfo.memory) - assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) - assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) - assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = testRpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiAddress = "http://localhost:80") + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = + persistenceEngine.readPersistedData(testRpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiAddress === recoveryWorkerInfo.webUiAddress) + } finally { + testRpcEnv.shutdown() + testRpcEnv.awaitTermination() + } } finally { - testRpcEnv.shutdown() - testRpcEnv.awaitTermination() + persistenceEngine.close() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala new file mode 100644 index 0000000000000..0c9382a92bcaf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import java.util.Date + +import scala.io.Source +import scala.language.postfixOps + +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST.{JInt, JNothing, JString} +import org.mockito.Mockito.{mock, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.DeployMessages.MasterStateResponse +import org.apache.spark.deploy.DeployTestUtils._ +import org.apache.spark.deploy.master._ +import org.apache.spark.rpc.RpcEnv + + +class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter { + + val masterPage = mock(classOf[MasterPage]) + val master = { + val conf = new SparkConf + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) + master + } + val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage)) + + before { + masterWebUI.bind() + } + + after { + masterWebUI.stop() + } + + test("list applications") { + val worker = createWorkerInfo() + val appDesc = createAppDesc() + // use new start date so it isn't filtered by UI + val activeApp = new ApplicationInfo( + new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue) + activeApp.addExecutor(worker, 2) + + val workers = Array[WorkerInfo](worker) + val activeApps = Array(activeApp) + val completedApps = Array[ApplicationInfo]() + val activeDrivers = Array[DriverInfo]() + val completedDrivers = Array[DriverInfo]() + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, + activeDrivers, completedDrivers, RecoveryState.ALIVE) + + when(masterPage.getMasterState).thenReturn(stateResponse) + + val resultJson = Source.fromURL( + s"http://localhost:${masterWebUI.boundPort}/api/v1/applications") + .mkString + val parsedJson = parse(resultJson) + val firstApp = parsedJson(0) + + assert(firstApp \ "id" === JString(activeApp.id)) + assert(firstApp \ "name" === JString(activeApp.desc.name)) + assert(firstApp \ "coresGranted" === JInt(2)) + assert(firstApp \ "maxCores" === JInt(4)) + assert(firstApp \ "memoryPerExecutorMB" === JInt(1234)) + assert(firstApp \ "coresPerExecutor" === JNothing) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 9693e32bf6af6..a7bb9aa4686eb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -19,21 +19,21 @@ package org.apache.spark.deploy.rest import java.io.DataOutputStream import java.net.{HttpURLConnection, URL} +import java.nio.charset.StandardCharsets import javax.servlet.http.HttpServletResponse import scala.collection.mutable -import com.google.common.base.Charsets -import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ +import org.scalatest.BeforeAndAfterEach import org.apache.spark._ -import org.apache.spark.rpc._ -import org.apache.spark.util.Utils -import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} +import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState._ +import org.apache.spark.rpc._ +import org.apache.spark.util.Utils /** * Tests for the REST application submission protocol used in standalone cluster mode. @@ -43,8 +43,12 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { private var server: Option[RestSubmissionServer] = None override def afterEach() { - rpcEnv.foreach(_.shutdown()) - server.foreach(_.stop()) + try { + rpcEnv.foreach(_.shutdown()) + server.foreach(_.stop()) + } finally { + super.afterEach() + } } test("construct submit request") { @@ -494,7 +498,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { if (body.nonEmpty) { conn.setDoOutput(true) val out = new DataOutputStream(conn.getOutputStream) - out.write(body.getBytes(Charsets.UTF_8)) + out.write(body.getBytes(StandardCharsets.UTF_8)) out.close() } conn diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index 7101cb9978df3..607c0a4fac46b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy.worker +import org.scalatest.{Matchers, PrivateMethodTester} + import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.util.Utils -import org.scalatest.{Matchers, PrivateMethodTester} class CommandUtilsSuite extends SparkFunSuite with Matchers with PrivateMethodTester { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 6258c18d177fd..2a1696be3660a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker import java.io.File -import org.mockito.Mockito._ import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -34,7 +34,7 @@ class DriverRunnerTest extends SparkFunSuite { val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) val conf = new SparkConf() new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), - driverDescription, null, "akka://1.2.3.4/worker/", new SecurityManager(conf)) + driverDescription, null, "spark://1.2.3.4/worker/", new SecurityManager(conf)) } private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 98664dc1101e6..0240bf8aed4cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker import java.io.File -import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index faed4bdc68447..101a44edd8ee2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.deploy.worker import org.scalatest.Matchers +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.{Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} import org.apache.spark.deploy.master.DriverState -import org.apache.spark.deploy.{Command, ExecutorState} import org.apache.spark.rpc.{RpcAddress, RpcEnv} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} class WorkerSuite extends SparkFunSuite with Matchers { @@ -67,7 +67,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedExecutors", 2.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 5) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -93,7 +93,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedExecutors", 30.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 50) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -128,7 +128,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedDrivers", 2.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 5) { val driverId = s"driverId-$i" @@ -154,7 +154,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedDrivers", 30.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 50) { val driverId = s"driverId-$i" diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 40c24bdecc6ce..31bea3293ae77 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.deploy.worker -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.SecurityManager -import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val targetWorkerUrl = RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "Worker").toString val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) @@ -36,7 +35,7 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val targetWorkerUrl = RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "Worker").toString val otherRpcAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 8275fd87764cd..d91f50f18f431 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -17,12 +17,538 @@ package org.apache.spark.executor -import org.apache.spark.SparkFunSuite +import org.scalatest.Assertions + +import org.apache.spark._ +import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId} + class TaskMetricsSuite extends SparkFunSuite { - test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") { - val taskMetrics = new TaskMetrics() - taskMetrics.updateShuffleReadMetrics() - assert(taskMetrics.shuffleReadMetrics.isEmpty) + import AccumulatorParam._ + import InternalAccumulator._ + import StorageLevel._ + import TaskMetricsSuite._ + + test("create") { + val internalAccums = InternalAccumulator.createAll() + val tm1 = new TaskMetrics + val tm2 = new TaskMetrics(internalAccums) + assert(tm1.accumulatorUpdates().size === internalAccums.size) + assert(tm1.shuffleReadMetrics.isEmpty) + assert(tm1.shuffleWriteMetrics.isEmpty) + assert(tm1.inputMetrics.isEmpty) + assert(tm1.outputMetrics.isEmpty) + assert(tm2.accumulatorUpdates().size === internalAccums.size) + assert(tm2.shuffleReadMetrics.isEmpty) + assert(tm2.shuffleWriteMetrics.isEmpty) + assert(tm2.inputMetrics.isEmpty) + assert(tm2.outputMetrics.isEmpty) + // TaskMetrics constructor expects minimal set of initial accumulators + intercept[IllegalArgumentException] { new TaskMetrics(Seq.empty[Accumulator[_]]) } + } + + test("create with unnamed accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.createAll() ++ Seq( + new Accumulator(0, IntAccumulatorParam, None, internal = true))) + } + } + + test("create with duplicate name accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.createAll() ++ Seq( + new Accumulator(0, IntAccumulatorParam, Some(RESULT_SIZE), internal = true))) + } + } + + test("create with external accum") { + intercept[IllegalArgumentException] { + new TaskMetrics( + InternalAccumulator.createAll() ++ Seq( + new Accumulator(0, IntAccumulatorParam, Some("x")))) + } + } + + test("create shuffle read metrics") { + import shuffleRead._ + val accums = InternalAccumulator.createShuffleReadAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(REMOTE_BLOCKS_FETCHED).setValueAny(1) + accums(LOCAL_BLOCKS_FETCHED).setValueAny(2) + accums(REMOTE_BYTES_READ).setValueAny(3L) + accums(LOCAL_BYTES_READ).setValueAny(4L) + accums(FETCH_WAIT_TIME).setValueAny(5L) + accums(RECORDS_READ).setValueAny(6L) + val sr = new ShuffleReadMetrics(accums) + assert(sr.remoteBlocksFetched === 1) + assert(sr.localBlocksFetched === 2) + assert(sr.remoteBytesRead === 3L) + assert(sr.localBytesRead === 4L) + assert(sr.fetchWaitTime === 5L) + assert(sr.recordsRead === 6L) + } + + test("create shuffle write metrics") { + import shuffleWrite._ + val accums = InternalAccumulator.createShuffleWriteAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_WRITTEN).setValueAny(1L) + accums(RECORDS_WRITTEN).setValueAny(2L) + accums(WRITE_TIME).setValueAny(3L) + val sw = new ShuffleWriteMetrics(accums) + assert(sw.bytesWritten === 1L) + assert(sw.recordsWritten === 2L) + assert(sw.writeTime === 3L) + } + + test("create input metrics") { + import input._ + val accums = InternalAccumulator.createInputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_READ).setValueAny(1L) + accums(RECORDS_READ).setValueAny(2L) + accums(READ_METHOD).setValueAny(DataReadMethod.Hadoop.toString) + val im = new InputMetrics(accums) + assert(im.bytesRead === 1L) + assert(im.recordsRead === 2L) + assert(im.readMethod === DataReadMethod.Hadoop) } + + test("create output metrics") { + import output._ + val accums = InternalAccumulator.createOutputAccums() + .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] + accums(BYTES_WRITTEN).setValueAny(1L) + accums(RECORDS_WRITTEN).setValueAny(2L) + accums(WRITE_METHOD).setValueAny(DataWriteMethod.Hadoop.toString) + val om = new OutputMetrics(accums) + assert(om.bytesWritten === 1L) + assert(om.recordsWritten === 2L) + assert(om.writeMethod === DataWriteMethod.Hadoop) + } + + test("mutating values") { + val accums = InternalAccumulator.createAll() + val tm = new TaskMetrics(accums) + // initial values + assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 0L) + assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 0L) + assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 0L) + assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 0L) + assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 0L) + assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 0L) + assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 0L) + assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 0L) + assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES, + Seq.empty[(BlockId, BlockStatus)]) + // set or increment values + tm.setExecutorDeserializeTime(100L) + tm.setExecutorDeserializeTime(1L) // overwrite + tm.setExecutorRunTime(200L) + tm.setExecutorRunTime(2L) + tm.setResultSize(300L) + tm.setResultSize(3L) + tm.setJvmGCTime(400L) + tm.setJvmGCTime(4L) + tm.setResultSerializationTime(500L) + tm.setResultSerializationTime(5L) + tm.incMemoryBytesSpilled(600L) + tm.incMemoryBytesSpilled(6L) // add + tm.incDiskBytesSpilled(700L) + tm.incDiskBytesSpilled(7L) + tm.incPeakExecutionMemory(800L) + tm.incPeakExecutionMemory(8L) + val block1 = (TestBlockId("a"), BlockStatus(MEMORY_ONLY, 1L, 2L)) + val block2 = (TestBlockId("b"), BlockStatus(MEMORY_ONLY, 3L, 4L)) + tm.incUpdatedBlockStatuses(Seq(block1)) + tm.incUpdatedBlockStatuses(Seq(block2)) + // assert new values exist + assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 1L) + assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 2L) + assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 3L) + assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 4L) + assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 5L) + assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 606L) + assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 707L) + assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 808L) + assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES, + Seq(block1, block2)) + } + + test("mutating shuffle read metrics values") { + import shuffleRead._ + val accums = InternalAccumulator.createAll() + val tm = new TaskMetrics(accums) + def assertValEquals[T](tmValue: ShuffleReadMetrics => T, name: String, value: T): Unit = { + assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics.get), accums, name, value) + } + // create shuffle read metrics + assert(tm.shuffleReadMetrics.isEmpty) + tm.registerTempShuffleReadMetrics() + tm.mergeShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isDefined) + val sr = tm.shuffleReadMetrics.get + // initial values + assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 0) + assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 0) + assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 0L) + assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 0L) + assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 0L) + assertValEquals(_.recordsRead, RECORDS_READ, 0L) + // set and increment values + sr.setRemoteBlocksFetched(100) + sr.setRemoteBlocksFetched(10) + sr.incRemoteBlocksFetched(1) // 10 + 1 + sr.incRemoteBlocksFetched(1) // 10 + 1 + 1 + sr.setLocalBlocksFetched(200) + sr.setLocalBlocksFetched(20) + sr.incLocalBlocksFetched(2) + sr.incLocalBlocksFetched(2) + sr.setRemoteBytesRead(300L) + sr.setRemoteBytesRead(30L) + sr.incRemoteBytesRead(3L) + sr.incRemoteBytesRead(3L) + sr.setLocalBytesRead(400L) + sr.setLocalBytesRead(40L) + sr.incLocalBytesRead(4L) + sr.incLocalBytesRead(4L) + sr.setFetchWaitTime(500L) + sr.setFetchWaitTime(50L) + sr.incFetchWaitTime(5L) + sr.incFetchWaitTime(5L) + sr.setRecordsRead(600L) + sr.setRecordsRead(60L) + sr.incRecordsRead(6L) + sr.incRecordsRead(6L) + // assert new values exist + assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 12) + assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 24) + assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 36L) + assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 48L) + assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 60L) + assertValEquals(_.recordsRead, RECORDS_READ, 72L) + } + + test("mutating shuffle write metrics values") { + import shuffleWrite._ + val accums = InternalAccumulator.createAll() + val tm = new TaskMetrics(accums) + def assertValEquals[T](tmValue: ShuffleWriteMetrics => T, name: String, value: T): Unit = { + assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics.get), accums, name, value) + } + // create shuffle write metrics + assert(tm.shuffleWriteMetrics.isEmpty) + tm.registerShuffleWriteMetrics() + assert(tm.shuffleWriteMetrics.isDefined) + val sw = tm.shuffleWriteMetrics.get + // initial values + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L) + assertValEquals(_.writeTime, WRITE_TIME, 0L) + // increment and decrement values + sw.incBytesWritten(100L) + sw.incBytesWritten(10L) // 100 + 10 + sw.decBytesWritten(1L) // 100 + 10 - 1 + sw.decBytesWritten(1L) // 100 + 10 - 1 - 1 + sw.incRecordsWritten(200L) + sw.incRecordsWritten(20L) + sw.decRecordsWritten(2L) + sw.decRecordsWritten(2L) + sw.incWriteTime(300L) + sw.incWriteTime(30L) + // assert new values exist + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 108L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 216L) + assertValEquals(_.writeTime, WRITE_TIME, 330L) + } + + test("mutating input metrics values") { + import input._ + val accums = InternalAccumulator.createAll() + val tm = new TaskMetrics(accums) + def assertValEquals(tmValue: InputMetrics => Any, name: String, value: Any): Unit = { + assertValueEquals(tm, tm => tmValue(tm.inputMetrics.get), accums, name, value, + (x: Any, y: Any) => assert(x.toString === y.toString)) + } + // create input metrics + assert(tm.inputMetrics.isEmpty) + tm.registerInputMetrics(DataReadMethod.Memory) + assert(tm.inputMetrics.isDefined) + val in = tm.inputMetrics.get + // initial values + assertValEquals(_.bytesRead, BYTES_READ, 0L) + assertValEquals(_.recordsRead, RECORDS_READ, 0L) + assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Memory) + // set and increment values + in.setBytesRead(1L) + in.setBytesRead(2L) + in.incRecordsRead(1L) + in.incRecordsRead(2L) + in.setReadMethod(DataReadMethod.Disk) + // assert new values exist + assertValEquals(_.bytesRead, BYTES_READ, 2L) + assertValEquals(_.recordsRead, RECORDS_READ, 3L) + assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Disk) + } + + test("mutating output metrics values") { + import output._ + val accums = InternalAccumulator.createAll() + val tm = new TaskMetrics(accums) + def assertValEquals(tmValue: OutputMetrics => Any, name: String, value: Any): Unit = { + assertValueEquals(tm, tm => tmValue(tm.outputMetrics.get), accums, name, value, + (x: Any, y: Any) => assert(x.toString === y.toString)) + } + // create input metrics + assert(tm.outputMetrics.isEmpty) + tm.registerOutputMetrics(DataWriteMethod.Hadoop) + assert(tm.outputMetrics.isDefined) + val out = tm.outputMetrics.get + // initial values + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L) + assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop) + // set values + out.setBytesWritten(1L) + out.setBytesWritten(2L) + out.setRecordsWritten(3L) + out.setRecordsWritten(4L) + out.setWriteMethod(DataWriteMethod.Hadoop) + // assert new values exist + assertValEquals(_.bytesWritten, BYTES_WRITTEN, 2L) + assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 4L) + // Note: this doesn't actually test anything, but there's only one DataWriteMethod + // so we can't set it to anything else + assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop) + } + + test("merging multiple shuffle read metrics") { + val tm = new TaskMetrics + assert(tm.shuffleReadMetrics.isEmpty) + val sr1 = tm.registerTempShuffleReadMetrics() + val sr2 = tm.registerTempShuffleReadMetrics() + val sr3 = tm.registerTempShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isEmpty) + sr1.setRecordsRead(10L) + sr2.setRecordsRead(10L) + sr1.setFetchWaitTime(1L) + sr2.setFetchWaitTime(2L) + sr3.setFetchWaitTime(3L) + tm.mergeShuffleReadMetrics() + assert(tm.shuffleReadMetrics.isDefined) + val sr = tm.shuffleReadMetrics.get + assert(sr.remoteBlocksFetched === 0L) + assert(sr.recordsRead === 20L) + assert(sr.fetchWaitTime === 6L) + + // SPARK-5701: calling merge without any shuffle deps does nothing + val tm2 = new TaskMetrics + tm2.mergeShuffleReadMetrics() + assert(tm2.shuffleReadMetrics.isEmpty) + } + + test("register multiple shuffle write metrics") { + val tm = new TaskMetrics + val sw1 = tm.registerShuffleWriteMetrics() + val sw2 = tm.registerShuffleWriteMetrics() + assert(sw1 === sw2) + assert(tm.shuffleWriteMetrics === Some(sw1)) + } + + test("register multiple input metrics") { + val tm = new TaskMetrics + val im1 = tm.registerInputMetrics(DataReadMethod.Memory) + val im2 = tm.registerInputMetrics(DataReadMethod.Memory) + // input metrics with a different read method than the one already registered are ignored + val im3 = tm.registerInputMetrics(DataReadMethod.Hadoop) + assert(im1 === im2) + assert(im1 !== im3) + assert(tm.inputMetrics === Some(im1)) + im2.setBytesRead(50L) + im3.setBytesRead(100L) + assert(tm.inputMetrics.get.bytesRead === 50L) + } + + test("register multiple output metrics") { + val tm = new TaskMetrics + val om1 = tm.registerOutputMetrics(DataWriteMethod.Hadoop) + val om2 = tm.registerOutputMetrics(DataWriteMethod.Hadoop) + assert(om1 === om2) + assert(tm.outputMetrics === Some(om1)) + } + + test("additional accumulables") { + val internalAccums = InternalAccumulator.createAll() + val tm = new TaskMetrics(internalAccums) + assert(tm.accumulatorUpdates().size === internalAccums.size) + val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a")) + val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b")) + val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c")) + val acc4 = new Accumulator(0, IntAccumulatorParam, Some("d"), + internal = true, countFailedValues = true) + tm.registerAccumulator(acc1) + tm.registerAccumulator(acc2) + tm.registerAccumulator(acc3) + tm.registerAccumulator(acc4) + acc1 += 1 + acc2 += 2 + val newUpdates = tm.accumulatorUpdates().map { a => (a.id, a) }.toMap + assert(newUpdates.contains(acc1.id)) + assert(newUpdates.contains(acc2.id)) + assert(newUpdates.contains(acc3.id)) + assert(newUpdates.contains(acc4.id)) + assert(newUpdates(acc1.id).name === Some("a")) + assert(newUpdates(acc2.id).name === Some("b")) + assert(newUpdates(acc3.id).name === Some("c")) + assert(newUpdates(acc4.id).name === Some("d")) + assert(newUpdates(acc1.id).update === Some(1)) + assert(newUpdates(acc2.id).update === Some(2)) + assert(newUpdates(acc3.id).update === Some(0)) + assert(newUpdates(acc4.id).update === Some(0)) + assert(!newUpdates(acc3.id).internal) + assert(!newUpdates(acc3.id).countFailedValues) + assert(newUpdates(acc4.id).internal) + assert(newUpdates(acc4.id).countFailedValues) + assert(newUpdates.values.map(_.update).forall(_.isDefined)) + assert(newUpdates.values.map(_.value).forall(_.isEmpty)) + assert(newUpdates.size === internalAccums.size + 4) + } + + test("existing values in shuffle read accums") { + // set shuffle read accum before passing it into TaskMetrics + val accums = InternalAccumulator.createAll() + val srAccum = accums.find(_.name === Some(shuffleRead.FETCH_WAIT_TIME)) + assert(srAccum.isDefined) + srAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isDefined) + assert(tm.shuffleWriteMetrics.isEmpty) + assert(tm.inputMetrics.isEmpty) + assert(tm.outputMetrics.isEmpty) + } + + test("existing values in shuffle write accums") { + // set shuffle write accum before passing it into TaskMetrics + val accums = InternalAccumulator.createAll() + val swAccum = accums.find(_.name === Some(shuffleWrite.RECORDS_WRITTEN)) + assert(swAccum.isDefined) + swAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isEmpty) + assert(tm.shuffleWriteMetrics.isDefined) + assert(tm.inputMetrics.isEmpty) + assert(tm.outputMetrics.isEmpty) + } + + test("existing values in input accums") { + // set input accum before passing it into TaskMetrics + val accums = InternalAccumulator.createAll() + val inAccum = accums.find(_.name === Some(input.RECORDS_READ)) + assert(inAccum.isDefined) + inAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm = new TaskMetrics(accums) + assert(tm.shuffleReadMetrics.isEmpty) + assert(tm.shuffleWriteMetrics.isEmpty) + assert(tm.inputMetrics.isDefined) + assert(tm.outputMetrics.isEmpty) + } + + test("existing values in output accums") { + // set output accum before passing it into TaskMetrics + val accums = InternalAccumulator.createAll() + val outAccum = accums.find(_.name === Some(output.RECORDS_WRITTEN)) + assert(outAccum.isDefined) + outAccum.get.asInstanceOf[Accumulator[Long]] += 10L + val tm4 = new TaskMetrics(accums) + assert(tm4.shuffleReadMetrics.isEmpty) + assert(tm4.shuffleWriteMetrics.isEmpty) + assert(tm4.inputMetrics.isEmpty) + assert(tm4.outputMetrics.isDefined) + } + + test("from accumulator updates") { + val accumUpdates1 = InternalAccumulator.createAll().map { a => + AccumulableInfo(a.id, a.name, Some(3L), None, a.isInternal, a.countFailedValues) + } + val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1) + assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1) + // Test this with additional accumulators to ensure that we do not crash when handling + // updates from unregistered accumulators. In practice, all accumulators created + // on the driver, internal or not, should be registered with `Accumulators` at some point. + val param = IntAccumulatorParam + val registeredAccums = Seq( + new Accumulator(0, param, Some("a"), internal = true, countFailedValues = true), + new Accumulator(0, param, Some("b"), internal = true, countFailedValues = false), + new Accumulator(0, param, Some("c"), internal = false, countFailedValues = true), + new Accumulator(0, param, Some("d"), internal = false, countFailedValues = false)) + val unregisteredAccums = Seq( + new Accumulator(0, param, Some("e"), internal = true, countFailedValues = true), + new Accumulator(0, param, Some("f"), internal = true, countFailedValues = false)) + registeredAccums.foreach(Accumulators.register) + registeredAccums.foreach { a => assert(Accumulators.originals.contains(a.id)) } + unregisteredAccums.foreach { a => assert(!Accumulators.originals.contains(a.id)) } + // set some values in these accums + registeredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) } + unregisteredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) } + val registeredAccumInfos = registeredAccums.map(makeInfo) + val unregisteredAccumInfos = unregisteredAccums.map(makeInfo) + val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos + // Simply checking that this does not crash: + TaskMetrics.fromAccumulatorUpdates(accumUpdates2) + } +} + + +private[spark] object TaskMetricsSuite extends Assertions { + + /** + * Assert that the following three things are equal to `value`: + * (1) TaskMetrics value + * (2) TaskMetrics accumulator update value + * (3) Original accumulator value + */ + def assertValueEquals( + tm: TaskMetrics, + tmValue: TaskMetrics => Any, + accums: Seq[Accumulator[_]], + metricName: String, + value: Any, + assertEquals: (Any, Any) => Unit = (x: Any, y: Any) => assert(x === y)): Unit = { + assertEquals(tmValue(tm), value) + val accum = accums.find(_.name == Some(metricName)) + assert(accum.isDefined) + assertEquals(accum.get.value, value) + val accumUpdate = tm.accumulatorUpdates().find(_.name == Some(metricName)) + assert(accumUpdate.isDefined) + assert(accumUpdate.get.value === None) + assertEquals(accumUpdate.get.update, Some(value)) + } + + /** + * Assert that two lists of accumulator updates are equal. + * Note: this does NOT check accumulator ID equality. + */ + def assertUpdatesEquals( + updates1: Seq[AccumulableInfo], + updates2: Seq[AccumulableInfo]): Unit = { + assert(updates1.size === updates2.size) + updates1.zip(updates2).foreach { case (info1, info2) => + // do not assert ID equals here + assert(info1.name === info2.name) + assert(info1.update === info2.update) + assert(info1.value === info2.value) + assert(info1.internal === info2.internal) + assert(info1.countFailedValues === info2.countFailedValues) + } + } + + /** + * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the + * info as an accumulator update. + */ + def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) + } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 8a199459c1ddf..ddf73d6370631 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -23,13 +23,13 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq -import org.scalatest.BeforeAndAfterAll - import org.apache.hadoop.io.Text +import org.apache.hadoop.io.compress.{CompressionCodecFactory, GzipCodec} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} /** * Tests the correctness of @@ -47,6 +47,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl // hard-to-reproduce test failures, since any suites that were run after this one would inherit // the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this, // we disable FileSystem caching in this suite. + super.beforeAll() val conf = new SparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true") sc = new SparkContext("local", "test", conf) @@ -59,7 +60,11 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl } override def afterAll() { - sc.stop() + try { + sc.stop() + } finally { + super.afterAll() + } } private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala new file mode 100644 index 0000000000000..337fd7e85e81c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.config + +import java.util.concurrent.TimeUnit + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.network.util.ByteUnit + +class ConfigEntrySuite extends SparkFunSuite { + + test("conf entry: int") { + val conf = new SparkConf() + val iConf = ConfigBuilder("spark.int").intConf.createWithDefault(1) + assert(conf.get(iConf) === 1) + conf.set(iConf, 2) + assert(conf.get(iConf) === 2) + } + + test("conf entry: long") { + val conf = new SparkConf() + val lConf = ConfigBuilder("spark.long").longConf.createWithDefault(0L) + conf.set(lConf, 1234L) + assert(conf.get(lConf) === 1234L) + } + + test("conf entry: double") { + val conf = new SparkConf() + val dConf = ConfigBuilder("spark.double").doubleConf.createWithDefault(0.0) + conf.set(dConf, 20.0) + assert(conf.get(dConf) === 20.0) + } + + test("conf entry: boolean") { + val conf = new SparkConf() + val bConf = ConfigBuilder("spark.boolean").booleanConf.createWithDefault(false) + assert(!conf.get(bConf)) + conf.set(bConf, true) + assert(conf.get(bConf)) + } + + test("conf entry: optional") { + val conf = new SparkConf() + val optionalConf = ConfigBuilder("spark.optional").intConf.createOptional + assert(conf.get(optionalConf) === None) + conf.set(optionalConf, 1) + assert(conf.get(optionalConf) === Some(1)) + } + + test("conf entry: fallback") { + val conf = new SparkConf() + val parentConf = ConfigBuilder("spark.int").intConf.createWithDefault(1) + val confWithFallback = ConfigBuilder("spark.fallback").fallbackConf(parentConf) + assert(conf.get(confWithFallback) === 1) + conf.set(confWithFallback, 2) + assert(conf.get(parentConf) === 1) + assert(conf.get(confWithFallback) === 2) + } + + test("conf entry: time") { + val conf = new SparkConf() + val time = ConfigBuilder("spark.time").timeConf(TimeUnit.SECONDS).createWithDefaultString("1h") + assert(conf.get(time) === 3600L) + conf.set(time.key, "1m") + assert(conf.get(time) === 60L) + } + + test("conf entry: bytes") { + val conf = new SparkConf() + val bytes = ConfigBuilder("spark.bytes").bytesConf(ByteUnit.KiB).createWithDefaultString("1m") + assert(conf.get(bytes) === 1024L) + conf.set(bytes.key, "1k") + assert(conf.get(bytes) === 1L) + } + + test("conf entry: string seq") { + val conf = new SparkConf() + val seq = ConfigBuilder("spark.seq").stringConf.toSequence.createWithDefault(Seq()) + conf.set(seq.key, "1,,2, 3 , , 4") + assert(conf.get(seq) === Seq("1", "2", "3", "4")) + conf.set(seq, Seq("1", "2")) + assert(conf.get(seq) === Seq("1", "2")) + } + + test("conf entry: int seq") { + val conf = new SparkConf() + val seq = ConfigBuilder("spark.seq").intConf.toSequence.createWithDefault(Seq()) + conf.set(seq.key, "1,,2, 3 , , 4") + assert(conf.get(seq) === Seq(1, 2, 3, 4)) + conf.set(seq, Seq(1, 2)) + assert(conf.get(seq) === Seq(1, 2)) + } + + test("conf entry: transformation") { + val conf = new SparkConf() + val transformationConf = ConfigBuilder("spark.transformation") + .stringConf + .transform(_.toLowerCase()) + .createWithDefault("FOO") + + assert(conf.get(transformationConf) === "foo") + conf.set(transformationConf, "BAR") + assert(conf.get(transformationConf) === "bar") + } + + test("conf entry: valid values check") { + val conf = new SparkConf() + val enum = ConfigBuilder("spark.enum") + .stringConf + .checkValues(Set("a", "b", "c")) + .createWithDefault("a") + assert(conf.get(enum) === "a") + + conf.set(enum, "b") + assert(conf.get(enum) === "b") + + conf.set(enum, "d") + val enumError = intercept[IllegalArgumentException] { + conf.get(enum) + } + assert(enumError.getMessage === s"The value of ${enum.key} should be one of a, b, c, but was d") + } + + test("conf entry: conversion error") { + val conf = new SparkConf() + val conversionTest = ConfigBuilder("spark.conversionTest").doubleConf.createOptional + conf.set(conversionTest.key, "abc") + val conversionError = intercept[IllegalArgumentException] { + conf.get(conversionTest) + } + assert(conversionError.getMessage === s"${conversionTest.key} should be double, but was abc") + } + + test("default value handling is null-safe") { + val conf = new SparkConf() + val stringConf = ConfigBuilder("spark.string").stringConf.createWithDefault(null) + assert(conf.get(stringConf) === null) + } + +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala new file mode 100644 index 0000000000000..f205d4f0d60b5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.io + +import java.nio.ByteBuffer + +import com.google.common.io.ByteStreams + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.util.ByteArrayWritableChannel +import org.apache.spark.util.io.ChunkedByteBuffer + +class ChunkedByteBufferSuite extends SparkFunSuite { + + test("no chunks") { + val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) + assert(emptyChunkedByteBuffer.size === 0) + assert(emptyChunkedByteBuffer.getChunks().isEmpty) + assert(emptyChunkedByteBuffer.toArray === Array.empty) + assert(emptyChunkedByteBuffer.toByteBuffer.capacity() === 0) + assert(emptyChunkedByteBuffer.toNetty.capacity() === 0) + emptyChunkedByteBuffer.toInputStream(dispose = false).close() + emptyChunkedByteBuffer.toInputStream(dispose = true).close() + } + + test("chunks must be non-empty") { + intercept[IllegalArgumentException] { + new ChunkedByteBuffer(Array(ByteBuffer.allocate(0))) + } + } + + test("getChunks() duplicates chunks") { + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8))) + chunkedByteBuffer.getChunks().head.position(4) + assert(chunkedByteBuffer.getChunks().head.position() === 0) + } + + test("copy() does not affect original buffer's position") { + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8))) + chunkedByteBuffer.copy(ByteBuffer.allocate) + assert(chunkedByteBuffer.getChunks().head.position() === 0) + } + + test("writeFully() does not affect original buffer's position") { + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8))) + chunkedByteBuffer.writeFully(new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt)) + assert(chunkedByteBuffer.getChunks().head.position() === 0) + } + + test("toArray()") { + val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes, bytes)) + assert(chunkedByteBuffer.toArray === bytes.array() ++ bytes.array()) + } + + test("toArray() throws UnsupportedOperationException if size exceeds 2GB") { + val fourMegabyteBuffer = ByteBuffer.allocate(1024 * 1024 * 4) + fourMegabyteBuffer.limit(fourMegabyteBuffer.capacity()) + val chunkedByteBuffer = new ChunkedByteBuffer(Array.fill(1024)(fourMegabyteBuffer)) + assert(chunkedByteBuffer.size === (1024L * 1024L * 1024L * 4L)) + intercept[UnsupportedOperationException] { + chunkedByteBuffer.toArray + } + } + + test("toInputStream()") { + val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte)) + val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte)) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes1, bytes2)) + assert(chunkedByteBuffer.size === bytes1.limit() + bytes2.limit()) + + val inputStream = chunkedByteBuffer.toInputStream(dispose = false) + val bytesFromStream = new Array[Byte](chunkedByteBuffer.size.toInt) + ByteStreams.readFully(inputStream, bytesFromStream) + assert(bytesFromStream === bytes1.array() ++ bytes2.array()) + assert(chunkedByteBuffer.getChunks().head.position() === 0) + } +} diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 1553ab60bddaa..9e9c2b0165e13 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -46,7 +46,7 @@ class CompressionCodecSuite extends SparkFunSuite { test("default compression codec") { val codec = CompressionCodec.createCodec(conf) - assert(codec.getClass === classOf[SnappyCompressionCodec]) + assert(codec.getClass === classOf[LZ4CompressionCodec]) testCodec(codec) } @@ -62,12 +62,10 @@ class CompressionCodecSuite extends SparkFunSuite { testCodec(codec) } - test("lz4 does not support concatenation of serialized streams") { + test("lz4 supports concatenation of serialized streams") { val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName) assert(codec.getClass === classOf[LZ4CompressionCodec]) - intercept[Exception] { - testConcatenationOfSerializedStreams(codec) - } + testConcatenationOfSerializedStreams(codec) } test("lzf compression codec") { diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala index 639d1daa36c73..713560d3ddfa1 100644 --- a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala @@ -26,7 +26,6 @@ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.launcher._ class LauncherBackendSuite extends SparkFunSuite with Matchers { diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 4a9479cf490fb..99d5b496bcd2e 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -19,72 +19,66 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration.Duration +import scala.collection.mutable import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.Duration import org.mockito.Matchers.{any, anyLong} -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.storage.MemoryStore +import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel} +import org.apache.spark.storage.memory.MemoryStore /** * Helper trait for sharing code among [[MemoryManager]] tests. */ -private[memory] trait MemoryManagerSuite extends SparkFunSuite { +private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAfterEach { + + protected val evictedBlocks = new mutable.ArrayBuffer[(BlockId, BlockStatus)] - import MemoryManagerSuite.DEFAULT_ENSURE_FREE_SPACE_CALLED + import MemoryManagerSuite.DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED // Note: Mockito's verify mechanism does not provide a way to reset method call counts // without also resetting stubbed methods. Since our test code relies on the latter, - // we need to use our own variable to track invocations of `ensureFreeSpace`. + // we need to use our own variable to track invocations of `evictBlocksToFreeSpace`. /** - * The amount of free space requested in the last call to [[MemoryStore.ensureFreeSpace]] + * The amount of space requested in the last call to [[MemoryStore.evictBlocksToFreeSpace]]. * - * This set whenever [[MemoryStore.ensureFreeSpace]] is called, and cleared when the test - * code makes explicit assertions on this variable through [[assertEnsureFreeSpaceCalled]]. + * This set whenever [[MemoryStore.evictBlocksToFreeSpace]] is called, and cleared when the test + * code makes explicit assertions on this variable through + * [[assertEvictBlocksToFreeSpaceCalled]]. */ - private val ensureFreeSpaceCalled = new AtomicLong(DEFAULT_ENSURE_FREE_SPACE_CALLED) + private val evictBlocksToFreeSpaceCalled = new AtomicLong(0) + + override def beforeEach(): Unit = { + super.beforeEach() + evictedBlocks.clear() + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) + } /** - * Make a mocked [[MemoryStore]] whose [[MemoryStore.ensureFreeSpace]] method is stubbed. + * Make a mocked [[MemoryStore]] whose [[MemoryStore.evictBlocksToFreeSpace]] method is stubbed. * - * This allows our test code to release storage memory when [[MemoryStore.ensureFreeSpace]] - * is called without relying on [[org.apache.spark.storage.BlockManager]] and all of its - * dependencies. + * This allows our test code to release storage memory when these methods are called + * without relying on [[org.apache.spark.storage.BlockManager]] and all of its dependencies. */ protected def makeMemoryStore(mm: MemoryManager): MemoryStore = { - val ms = mock(classOf[MemoryStore]) - when(ms.ensureFreeSpace(anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 0)) - when(ms.ensureFreeSpace(any(), anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 1)) + val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) + when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())) + .thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) mm.setMemoryStore(ms) ms } /** - * Make an [[Answer]] that stubs [[MemoryStore.ensureFreeSpace]] with the right arguments. - */ - private def ensureFreeSpaceAnswer(mm: MemoryManager, numBytesPos: Int): Answer[Boolean] = { - new Answer[Boolean] { - override def answer(invocation: InvocationOnMock): Boolean = { - val args = invocation.getArguments - require(args.size > numBytesPos, s"bad test: expected >$numBytesPos arguments " + - s"in ensureFreeSpace, found ${args.size}") - require(args(numBytesPos).isInstanceOf[Long], s"bad test: expected ensureFreeSpace " + - s"argument at index $numBytesPos to be a Long: ${args.mkString(", ")}") - val numBytes = args(numBytesPos).asInstanceOf[Long] - mockEnsureFreeSpace(mm, numBytes) - } - } - } - - /** - * Simulate the part of [[MemoryStore.ensureFreeSpace]] that releases storage memory. + * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. * * This is a significant simplification of the real method, which actually drops existing * blocks based on the size of each block. Instead, here we simply release as many bytes @@ -92,133 +86,141 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in * many other dependencies. * - * Every call to this method will set a global variable, [[ensureFreeSpaceCalled]], that + * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that * records the number of bytes this is called with. This variable is expected to be cleared - * by the test code later through [[assertEnsureFreeSpaceCalled]]. + * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. */ - private def mockEnsureFreeSpace(mm: MemoryManager, numBytes: Long): Boolean = mm.synchronized { - require(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, - "bad test: ensure free space variable was not reset") - // Record the number of bytes we freed this call - ensureFreeSpaceCalled.set(numBytes) - if (numBytes <= mm.maxStorageMemory) { - def freeMemory = mm.maxStorageMemory - mm.storageMemoryUsed - val spaceToRelease = numBytes - freeMemory - if (spaceToRelease > 0) { - mm.releaseStorageMemory(spaceToRelease) + private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] = { + new Answer[Long] { + override def answer(invocation: InvocationOnMock): Long = { + val args = invocation.getArguments + val numBytesToFree = args(1).asInstanceOf[Long] + assert(numBytesToFree > 0) + require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "bad test: evictBlocksToFreeSpace() variable was not reset") + evictBlocksToFreeSpaceCalled.set(numBytesToFree) + if (numBytesToFree <= mm.storageMemoryUsed) { + // We can evict enough blocks to fulfill the request for space + mm.releaseStorageMemory(numBytesToFree, MemoryMode.ON_HEAP) + evictedBlocks.append( + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L))) + numBytesToFree + } else { + // No blocks were evicted because eviction would not free enough space. + 0L + } } - freeMemory >= numBytes - } else { - // We attempted to free more bytes than our max allowable memory - false } } /** - * Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters. + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is called with the given parameters. */ - protected def assertEnsureFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { - assert(ensureFreeSpaceCalled.get() === numBytes, - s"expected ensure free space to be called with $numBytes") - ensureFreeSpaceCalled.set(DEFAULT_ENSURE_FREE_SPACE_CALLED) + protected def assertEvictBlocksToFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === numBytes, + s"expected evictBlocksToFreeSpace() to be called with $numBytes") + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) } /** - * Assert that [[MemoryStore.ensureFreeSpace]] is NOT called. + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is NOT called. */ - protected def assertEnsureFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { - assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, - "ensure free space should not have been called!") + protected def assertEvictBlocksToFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "evictBlocksToFreeSpace() should not have been called!") + assert(evictedBlocks.isEmpty) } /** - * Create a MemoryManager with the specified execution memory limit and no storage memory. + * Create a MemoryManager with the specified execution memory limits and no storage memory. */ - protected def createMemoryManager(maxExecutionMemory: Long): MemoryManager + protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long = 0L): MemoryManager // -- Tests of sharing of execution memory between tasks ---------------------------------------- // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite. implicit val ec = ExecutionContext.global - test("single task requesting execution memory") { + test("single task requesting on-heap execution memory") { val manager = createMemoryManager(1000L) val taskMemoryManager = new TaskMemoryManager(manager, 0) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 100L) - assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L) - assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L) - assert(taskMemoryManager.acquireExecutionMemory(200L, null) === 100L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(200L, MemoryMode.ON_HEAP, null) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) - taskMemoryManager.releaseExecutionMemory(500L, null) - assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 300L) - assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 200L) + taskMemoryManager.releaseExecutionMemory(500L, MemoryMode.ON_HEAP, null) + assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 300L) + assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 200L) taskMemoryManager.cleanUpAllAllocatedMemory() - assert(taskMemoryManager.acquireExecutionMemory(1000L, null) === 1000L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) === 1000L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) } - test("two tasks requesting full execution memory") { + test("two tasks requesting full on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // Have both tasks request 500 bytes, then wait until both requests have been granted: - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 500L) assert(Await.result(t2Result1, futureTimeout) === 500L) // Have both tasks each request 500 bytes more; both should immediately return 0 as they are // both now at 1 / N - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, 200.millis) === 0L) assert(Await.result(t2Result2, 200.millis) === 0L) } - test("two tasks cannot grow past 1 / N of execution memory") { + test("two tasks cannot grow past 1 / N of on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // Have both tasks request 250 bytes, then wait until both requests have been granted: - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, null) } - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 250L) assert(Await.result(t2Result1, futureTimeout) === 250L) // Have both tasks each request 500 bytes more. // We should only grant 250 bytes to each of them on this second request - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, futureTimeout) === 250L) assert(Await.result(t2Result2, futureTimeout) === 250L) } - test("tasks can block to get at least 1 / 2N of execution memory") { + test("tasks can block to get at least 1 / 2N of on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 1000L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) - t1MemManager.releaseExecutionMemory(250L, null) + t1MemManager.releaseExecutionMemory(250L, MemoryMode.ON_HEAP, null) // The memory freed from t1 should now be granted to t2. assert(Await.result(t2Result1, futureTimeout) === 250L) // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result2, 200.millis) === 0L) } @@ -229,18 +231,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { val futureTimeout: Duration = 20.seconds // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 1000L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) // t1 releases all of its memory, so t2 should be able to grab all of the memory t1MemManager.cleanUpAllAllocatedMemory() assert(Await.result(t2Result1, futureTimeout) === 500L) - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result2, futureTimeout) === 500L) - val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result3, 200.millis) === 0L) } @@ -251,17 +253,37 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 700L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result1, futureTimeout) === 300L) - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, 200.millis) === 0L) } + + test("off-heap execution allocations cannot exceed limit") { + val memoryManager = createMemoryManager( + maxOnHeapExecutionMemory = 0L, + maxOffHeapExecutionMemory = 1000L) + + val tMemManager = new TaskMemoryManager(memoryManager, 1) + val result1 = Future { tMemManager.acquireExecutionMemory(1000L, MemoryMode.OFF_HEAP, null) } + assert(Await.result(result1, 200.millis) === 1000L) + assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) + + val result2 = Future { tMemManager.acquireExecutionMemory(300L, MemoryMode.OFF_HEAP, null) } + assert(Await.result(result2, 200.millis) === 0L) + + assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) + tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + assert(tMemManager.getMemoryConsumptionForThisTask === 500L) + tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + assert(tMemManager.getMemoryConsumptionForThisTask === 0L) + } } private object MemoryManagerSuite { - private val DEFAULT_ENSURE_FREE_SPACE_CALLED = -1L + private val DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED = -1L } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 4b4c3b0311328..362cd861cc248 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -17,7 +17,9 @@ package org.apache.spark.memory -import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} +import java.util.Properties + +import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} /** * Helper methods for mocking out memory-management-related classes in tests. @@ -31,7 +33,7 @@ object MemoryTestingUtils { taskAttemptId = 0, attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = env.metricsSystem, - internalAccumulators = Seq.empty) + localProperties = new Properties, + metricsSystem = env.metricsSystem) } } diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 885c450d6d4f5..4e31fb5589a9c 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.memory -import scala.collection.mutable.ArrayBuffer - import org.mockito.Mockito.when import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} - +import org.apache.spark.storage.TestBlockId +import org.apache.spark.storage.memory.MemoryStore class StaticMemoryManagerSuite extends MemoryManagerSuite { private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") - private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] /** * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. @@ -36,38 +33,48 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { maxExecutionMem: Long, maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { val mm = new StaticMemoryManager( - conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem, numCores = 1) + conf, + maxOnHeapExecutionMemory = maxExecutionMem, + maxOnHeapStorageMemory = maxStorageMem, + numCores = 1) val ms = makeMemoryStore(mm) (mm, ms) } - override protected def createMemoryManager(maxMemory: Long): MemoryManager = { + override protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long): StaticMemoryManager = { new StaticMemoryManager( - conf, - maxExecutionMemory = maxMemory, - maxStorageMemory = 0, + conf.clone + .set("spark.memory.fraction", "1") + .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString), + maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, + maxOnHeapStorageMemory = 0, numCores = 1) } test("basic execution memory") { val maxExecutionMem = 1000L + val taskAttemptId = 0L val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) + val memoryMode = MemoryMode.ON_HEAP assert(mm.executionMemoryUsed === 0L) - assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.acquireExecutionMemory(10L, taskAttemptId, memoryMode) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, memoryMode) === 100L) // Acquire up to the max - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, memoryMode) === 890L) assert(mm.executionMemoryUsed === maxExecutionMem) - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, memoryMode) === 0L) assert(mm.executionMemoryUsed === maxExecutionMem) - mm.releaseExecutionMemory(800L) + mm.releaseExecutionMemory(800L, taskAttemptId, memoryMode) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, memoryMode) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired - mm.releaseExecutionMemory(maxExecutionMem) + mm.releaseExecutionMemory(maxExecutionMem, taskAttemptId, memoryMode) assert(mm.executionMemoryUsed === 0L) } @@ -75,60 +82,68 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val maxStorageMem = 1000L val dummyBlock = TestBlockId("you can see the world you brought to live") val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) + val memoryMode = MemoryMode.ON_HEAP assert(mm.storageMemoryUsed === 0L) - assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) - // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, 10L) + assert(mm.acquireStorageMemory(dummyBlock, 10L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) - assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted - assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, maxStorageMem + 1L) + assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction - assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1000L) + assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. assert(mm.storageMemoryUsed === 1000L) - mm.releaseStorageMemory(800L) + mm.releaseStorageMemory(800L, memoryMode) assert(mm.storageMemoryUsed === 200L) // Acquire after release - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired - mm.releaseStorageMemory(100L) + mm.releaseStorageMemory(100L, memoryMode) assert(mm.storageMemoryUsed === 0L) } test("execution and storage isolation") { val maxExecutionMem = 200L val maxStorageMem = 1000L + val taskAttemptId = 0L val dummyBlock = TestBlockId("ain't nobody love like you do") val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) + val memoryMode = MemoryMode.ON_HEAP // Only execution memory should increase - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, memoryMode) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 100L) - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, memoryMode) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase - assert(mm.acquireStorageMemory(dummyBlock, 50L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 50L) + assert(mm.acquireStorageMemory(dummyBlock, 50L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 200L) // Only execution memory should be released - mm.releaseExecutionMemory(133L) + mm.releaseExecutionMemory(133L, taskAttemptId, memoryMode) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 67L) // Only storage memory should be released @@ -141,24 +156,34 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val maxStorageMem = 1000L val dummyBlock = TestBlockId("lonely water") val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) - assert(mm.acquireUnrollMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + val memoryMode = MemoryMode.ON_HEAP + assert(mm.acquireUnrollMemory(dummyBlock, 100L, memoryMode)) + when(ms.currentUnrollMemory).thenReturn(100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 100L) - mm.releaseUnrollMemory(40L) + mm.releaseUnrollMemory(40L, memoryMode) assert(mm.storageMemoryUsed === 60L) when(ms.currentUnrollMemory).thenReturn(60L) - assert(mm.acquireUnrollMemory(dummyBlock, 500L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 800L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 860L) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. - // Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes. - assertEnsureFreeSpaceCalled(ms, 340L) - assert(mm.storageMemoryUsed === 560L) - when(ms.currentUnrollMemory).thenReturn(560L) - assert(!mm.acquireUnrollMemory(dummyBlock, 800L, evictedBlocks)) - assert(mm.storageMemoryUsed === 560L) - // We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed - assertEnsureFreeSpaceCalled(ms, 0L) + // As of this point, cache memory is 800 bytes and current unroll memory is 60 bytes. + // Requesting 240 more bytes of unroll memory will leave our total unroll memory at + // 300 bytes, still under the 400-byte limit. Therefore, all 240 bytes are granted. + assert(mm.acquireUnrollMemory(dummyBlock, 240L, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 860 + 240 - 1000 + when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 + assert(mm.storageMemoryUsed === 1000L) + evictedBlocks.clear() + // We already have 300 bytes of unroll memory, so requesting 150 more will leave us + // above the 400-byte limit. Since there is not enough free memory, this request will + // fail even after evicting as much as we can (400 - 300 = 100 bytes). + assert(!mm.acquireUnrollMemory(dummyBlock, 150L, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.storageMemoryUsed === 900L) // Release beyond what was acquired - mm.releaseUnrollMemory(maxStorageMem) + mm.releaseUnrollMemory(maxStorageMem, memoryMode) assert(mm.storageMemoryUsed === 0L) } diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 77e43554ee27c..6a4f409e8e08f 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -17,24 +17,23 @@ package org.apache.spark.memory -import scala.collection.mutable - import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockStatus, BlockId} +import org.apache.spark.storage.BlockId + +class TestMemoryManager(conf: SparkConf) + extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) { -class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) { - private[memory] override def doAcquireExecutionMemory( + override private[memory] def acquireExecutionMemory( numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + taskAttemptId: Long, + memoryMode: MemoryMode): Long = { if (oomOnce) { oomOnce = false 0 } else if (available >= numBytes) { - _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory available -= numBytes numBytes } else { - _executionMemoryUsed += available val grant = available available = 0 grant @@ -43,18 +42,19 @@ class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = override def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + memoryMode: MemoryMode): Boolean = true override def acquireUnrollMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true - override def releaseExecutionMemory(numBytes: Long): Unit = { + memoryMode: MemoryMode): Boolean = true + override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {} + override private[memory] def releaseExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Unit = { available += numBytes - _executionMemoryUsed -= numBytes } - override def releaseStorageMemory(numBytes: Long): Unit = {} - override def maxExecutionMemory: Long = Long.MaxValue - override def maxStorageMemory: Long = Long.MaxValue + override def maxOnHeapStorageMemory: Long = Long.MaxValue private var oomOnce = false private var available = Long.MaxValue diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 0c97f2bd89651..14255818c7b5e 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -17,196 +17,267 @@ package org.apache.spark.memory -import scala.collection.mutable.ArrayBuffer - import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} - +import org.apache.spark.storage.TestBlockId +import org.apache.spark.storage.memory.MemoryStore class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester { - private val conf = new SparkConf().set("spark.memory.storageFraction", "0.5") private val dummyBlock = TestBlockId("--") - private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + + private val storageFraction: Double = 0.5 /** * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies. */ private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = { - val mm = new UnifiedMemoryManager(conf, maxMemory, numCores = 1) + val mm = createMemoryManager(maxMemory) val ms = makeMemoryStore(mm) (mm, ms) } - override protected def createMemoryManager(maxMemory: Long): MemoryManager = { - new UnifiedMemoryManager(conf, maxMemory, numCores = 1) - } - - private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = { - mm invokePrivate PrivateMethod[Long]('storageRegionSize)() - } - - test("storage region size") { - val maxMemory = 1000L - val (mm, _) = makeThings(maxMemory) - val storageFraction = conf.get("spark.memory.storageFraction").toDouble - val expectedStorageRegionSize = maxMemory * storageFraction - val actualStorageRegionSize = getStorageRegionSize(mm) - assert(expectedStorageRegionSize === actualStorageRegionSize) + override protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long): UnifiedMemoryManager = { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString) + .set("spark.memory.storageFraction", storageFraction.toString) + UnifiedMemoryManager(conf, numCores = 1) } test("basic execution memory") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, _) = makeThings(maxMemory) + val memoryMode = MemoryMode.ON_HEAP assert(mm.executionMemoryUsed === 0L) - assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.acquireExecutionMemory(10L, taskAttemptId, memoryMode) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, memoryMode) === 100L) // Acquire up to the max - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, memoryMode) === 890L) assert(mm.executionMemoryUsed === maxMemory) - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, memoryMode) === 0L) assert(mm.executionMemoryUsed === maxMemory) - mm.releaseExecutionMemory(800L) + mm.releaseExecutionMemory(800L, taskAttemptId, memoryMode) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, memoryMode) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired - mm.releaseExecutionMemory(maxMemory) + mm.releaseExecutionMemory(maxMemory, taskAttemptId, memoryMode) assert(mm.executionMemoryUsed === 0L) } test("basic storage memory") { val maxMemory = 1000L val (mm, ms) = makeThings(maxMemory) + val memoryMode = MemoryMode.ON_HEAP assert(mm.storageMemoryUsed === 0L) - assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) - // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, 10L) + assert(mm.acquireStorageMemory(dummyBlock, 10L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) - assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted - assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, maxMemory + 1L) + assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction - assert(mm.acquireStorageMemory(dummyBlock, maxMemory, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1000L) + assert(mm.acquireStorageMemory(dummyBlock, maxMemory, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + assert(mm.acquireStorageMemory(dummyBlock, 1L, memoryMode)) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. assert(mm.storageMemoryUsed === 1000L) - mm.releaseStorageMemory(800L) + mm.releaseStorageMemory(800L, memoryMode) assert(mm.storageMemoryUsed === 200L) // Acquire after release - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) - assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assert(mm.acquireStorageMemory(dummyBlock, 1L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired - mm.releaseStorageMemory(100L) + mm.releaseStorageMemory(100L, memoryMode) assert(mm.storageMemoryUsed === 0L) } test("execution evicts storage") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) - // First, ensure the test classes are set up as expected - val expectedStorageRegionSize = 500L - val expectedExecutionRegionSize = 500L - val storageRegionSize = getStorageRegionSize(mm) - val executionRegionSize = maxMemory - expectedStorageRegionSize - require(storageRegionSize === expectedStorageRegionSize, - "bad test: storage region size is unexpected") - require(executionRegionSize === expectedExecutionRegionSize, - "bad test: storage region size is unexpected") + val memoryMode = MemoryMode.ON_HEAP // Acquire enough storage memory to exceed the storage region - assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 750L) + assert(mm.acquireStorageMemory(dummyBlock, 750L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) - require(mm.storageMemoryUsed > storageRegionSize, - s"bad test: storage memory used should exceed the storage region") // Execution needs to request 250 bytes to evict storage memory - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, memoryMode) === 100L) assert(mm.executionMemoryUsed === 100L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Execution wants 200 bytes but only 150 are free, so storage is evicted - assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) - assertEnsureFreeSpaceCalled(ms, 200L) + assert(mm.acquireExecutionMemory(200L, taskAttemptId, memoryMode) === 200L) assert(mm.executionMemoryUsed === 300L) + assert(mm.storageMemoryUsed === 700L) + assertEvictBlocksToFreeSpaceCalled(ms, 50L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() mm.releaseAllStorageMemory() - require(mm.executionMemoryUsed < executionRegionSize, - s"bad test: execution memory used should be within the execution region") + require(mm.executionMemoryUsed === 300L) require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") // Acquire some storage memory again, but this time keep it within the storage region - assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 400L) - require(mm.storageMemoryUsed < storageRegionSize, - s"bad test: storage memory used should be within the storage region") + assert(mm.acquireStorageMemory(dummyBlock, 400L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 400L) + assert(mm.executionMemoryUsed === 300L) // Execution cannot evict storage because the latter is within the storage fraction, // so grant only what's remaining without evicting anything, i.e. 1000 - 300 - 400 = 300 - assert(mm.doAcquireExecutionMemory(400L, evictedBlocks) === 300L) + assert(mm.acquireExecutionMemory(400L, taskAttemptId, memoryMode) === 300L) assert(mm.executionMemoryUsed === 600L) assert(mm.storageMemoryUsed === 400L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) + } + + test("execution memory requests smaller than free memory should evict storage (SPARK-12165)") { + val maxMemory = 1000L + val taskAttemptId = 0L + val (mm, ms) = makeThings(maxMemory) + val memoryMode = MemoryMode.ON_HEAP + // Acquire enough storage memory to exceed the storage region size + assert(mm.acquireStorageMemory(dummyBlock, 700L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.executionMemoryUsed === 0L) + assert(mm.storageMemoryUsed === 700L) + // SPARK-12165: previously, MemoryStore would not evict anything because it would + // mistakenly think that the 300 bytes of free space was still available even after + // using it to expand the execution pool. Consequently, no storage memory was released + // and the following call granted only 300 bytes to execution. + assert(mm.acquireExecutionMemory(500L, taskAttemptId, memoryMode) === 500L) + assertEvictBlocksToFreeSpaceCalled(ms, 200L) + assert(mm.storageMemoryUsed === 500L) + assert(mm.executionMemoryUsed === 500L) + assert(evictedBlocks.nonEmpty) } test("storage does not evict execution") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) - // First, ensure the test classes are set up as expected - val expectedStorageRegionSize = 500L - val expectedExecutionRegionSize = 500L - val storageRegionSize = getStorageRegionSize(mm) - val executionRegionSize = maxMemory - expectedStorageRegionSize - require(storageRegionSize === expectedStorageRegionSize, - "bad test: storage region size is unexpected") - require(executionRegionSize === expectedExecutionRegionSize, - "bad test: storage region size is unexpected") + val memoryMode = MemoryMode.ON_HEAP // Acquire enough execution memory to exceed the execution region - assert(mm.doAcquireExecutionMemory(800L, evictedBlocks) === 800L) + assert(mm.acquireExecutionMemory(800L, taskAttemptId, memoryMode) === 800L) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 0L) - assertEnsureFreeSpaceNotCalled(ms) - require(mm.executionMemoryUsed > executionRegionSize, - s"bad test: execution memory used should exceed the execution region") + assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should not be able to evict execution - assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) - assertEnsureFreeSpaceCalled(ms, 100L) - assert(!mm.acquireStorageMemory(dummyBlock, 250L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(!mm.acquireStorageMemory(dummyBlock, 250L, memoryMode)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) - assertEnsureFreeSpaceCalled(ms, 250L) - mm.releaseExecutionMemory(maxMemory) - mm.releaseStorageMemory(maxMemory) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) + mm.releaseExecutionMemory(maxMemory, taskAttemptId, memoryMode) + mm.releaseStorageMemory(maxMemory, memoryMode) // Acquire some execution memory again, but this time keep it within the execution region - assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) + assert(mm.acquireExecutionMemory(200L, taskAttemptId, memoryMode) === 200L) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 0L) - assertEnsureFreeSpaceNotCalled(ms) - require(mm.executionMemoryUsed < executionRegionSize, - s"bad test: execution memory used should be within the execution region") + assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should still not be able to evict execution - assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 750L, memoryMode)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceCalled(ms, 750L) - assert(!mm.acquireStorageMemory(dummyBlock, 850L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) // since there were 800 bytes free + assert(!mm.acquireStorageMemory(dummyBlock, 850L, memoryMode)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceCalled(ms, 850L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) + } + + test("small heap") { + val systemMemory = 1024 * 1024 + val reservedMemory = 300 * 1024 + val memoryFraction = 0.8 + val conf = new SparkConf() + .set("spark.memory.fraction", memoryFraction.toString) + .set("spark.testing.memory", systemMemory.toString) + .set("spark.testing.reservedMemory", reservedMemory.toString) + val mm = UnifiedMemoryManager(conf, numCores = 1) + val expectedMaxMemory = ((systemMemory - reservedMemory) * memoryFraction).toLong + assert(mm.maxHeapMemory === expectedMaxMemory) + + // Try using a system memory that's too small + val conf2 = conf.clone().set("spark.testing.memory", (reservedMemory / 2).toString) + val exception = intercept[IllegalArgumentException] { + UnifiedMemoryManager(conf2, numCores = 1) + } + assert(exception.getMessage.contains("increase heap size")) + } + + test("insufficient executor memory") { + val systemMemory = 1024 * 1024 + val reservedMemory = 300 * 1024 + val memoryFraction = 0.8 + val conf = new SparkConf() + .set("spark.memory.fraction", memoryFraction.toString) + .set("spark.testing.memory", systemMemory.toString) + .set("spark.testing.reservedMemory", reservedMemory.toString) + val mm = UnifiedMemoryManager(conf, numCores = 1) + + // Try using an executor memory that's too small + val conf2 = conf.clone().set("spark.executor.memory", (reservedMemory / 2).toString) + val exception = intercept[IllegalArgumentException] { + UnifiedMemoryManager(conf2, numCores = 1) + } + assert(exception.getMessage.contains("increase executor memory")) + } + + test("execution can evict cached blocks when there are multiple active tasks (SPARK-12155)") { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "0") + .set("spark.testing.memory", "1000") + val mm = UnifiedMemoryManager(conf, numCores = 2) + val ms = makeMemoryStore(mm) + val memoryMode = MemoryMode.ON_HEAP + assert(mm.maxHeapMemory === 1000) + // Have two tasks each acquire some execution memory so that the memory pool registers that + // there are two active tasks: + assert(mm.acquireExecutionMemory(100L, 0, memoryMode) === 100L) + assert(mm.acquireExecutionMemory(100L, 1, memoryMode) === 100L) + // Fill up all of the remaining memory with storage. + assert(mm.acquireStorageMemory(dummyBlock, 800L, memoryMode)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 800) + assert(mm.executionMemoryUsed === 200) + // A task should still be able to allocate 100 bytes execution memory by evicting blocks + assert(mm.acquireExecutionMemory(100L, 0, memoryMode) === 100L) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.executionMemoryUsed === 300) + assert(mm.storageMemoryUsed === 700) + assert(evictedBlocks.nonEmpty) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 44eb5a0469122..056e5463a0abf 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -25,17 +25,17 @@ import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, + JobConf, LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, + Reporter, TextInputFormat => OldTextInputFormat} import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit} -import org.apache.hadoop.mapred.{JobConf, Reporter, FileSplit => OldFileSplit, - InputSplit => OldInputSplit, LineRecordReader => OldLineRecordReader, - RecordReader => OldRecordReader, TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, + TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, - RecordReader => NewRecordReader} import org.scalatest.BeforeAndAfter import org.apache.spark.{SharedSparkContext, SparkFunSuite} @@ -98,14 +98,14 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext rdd.coalesce(4).count() } - // for count and coelesce, the same bytes should be read. + // for count and coalesce, the same bytes should be read. assert(bytesRead != 0) assert(bytesRead2 == bytesRead) } /** * This checks the situation where we have interleaved reads from - * different sources. Currently, we only accumulate fron the first + * different sources. Currently, we only accumulate from the first * read method we find in the task. This test uses cartesian to create * the interleaved reads. * @@ -183,7 +183,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext assert(records == numRecords) } - test("input metrics on recordsd read with cache") { + test("input metrics on records read with cache") { // prime the cache manager val rdd = sc.textFile(tmpFilePath, 4).cache() rdd.collect() @@ -212,7 +212,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext metrics.inputMetrics.foreach(inputRead += _.recordsRead) metrics.outputMetrics.foreach(outputWritten += _.recordsWritten) metrics.shuffleReadMetrics.foreach(shuffleRead += _.recordsRead) - metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.shuffleRecordsWritten) + metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.recordsWritten) } }) diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 41f2ff725a17b..b24f5d732f292 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.metrics -import org.apache.spark.SparkConf - import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkConf import org.apache.spark.SparkFunSuite class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 9c389c76bf3bd..5d8554229dbe1 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.metrics +import scala.collection.mutable.ArrayBuffer + +import com.codahale.metrics.MetricRegistry import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.metrics.source.Source -import com.codahale.metrics.MetricRegistry - -import scala.collection.mutable.ArrayBuffer - class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 3940527fb874e..ed15e77ff1421 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -19,23 +19,24 @@ package org.apache.spark.network.netty import java.io.InputStreamReader import java.nio._ -import java.nio.charset.Charset +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit -import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} import com.google.common.io.CharStreams -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.network.{BlockDataManager, BlockTransferService} -import org.apache.spark.storage.{BlockId, ShuffleBlockId} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.ShouldMatchers +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.storage.{BlockId, ShuffleBlockId} + class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers { test("security default off") { val conf = new SparkConf() @@ -102,24 +103,25 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val blockManager = mock[BlockDataManager] val blockId = ShuffleBlockId(0, 1, 2) val blockString = "Hello, world!" - val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes)) + val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap( + blockString.getBytes(StandardCharsets.UTF_8))) when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) val securityManager0 = new SecurityManager(conf0) - val exec0 = new NettyBlockTransferService(conf0, securityManager0, numCores = 1) + val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", numCores = 1) exec0.init(blockManager) val securityManager1 = new SecurityManager(conf1) - val exec1 = new NettyBlockTransferService(conf1, securityManager1, numCores = 1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1, "localhost", numCores = 1) exec1.init(blockManager) val result = fetchBlock(exec0, exec1, "1", blockId) match { case Success(buf) => val actualString = CharStreams.toString( - new InputStreamReader(buf.createInputStream(), Charset.forName("UTF-8"))) + new InputStreamReader(buf.createInputStream(), StandardCharsets.UTF_8)) actualString should equal(blockString) buf.release() - Success() + Success(()) case Failure(t) => Failure(t) } @@ -148,7 +150,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } }) - Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 6f8e8a7ac6033..f3c156e4f709d 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.network.netty -import org.apache.spark.network.BlockDataManager -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito.mock import org.scalatest._ +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.BlockDataManager + class NettyBlockTransferServiceSuite extends SparkFunSuite with BeforeAndAfterEach @@ -31,14 +32,18 @@ class NettyBlockTransferServiceSuite private var service1: NettyBlockTransferService = _ override def afterEach() { - if (service0 != null) { - service0.close() - service0 = null - } + try { + if (service0 != null) { + service0.close() + service0 = null + } - if (service1 != null) { - service1.close() - service1 = null + if (service1 != null) { + service1.close() + service1 = null + } + } finally { + super.afterEach() } } @@ -75,7 +80,7 @@ class NettyBlockTransferServiceSuite .set("spark.blockManager.port", port.toString) val securityManager = new SecurityManager(conf) val blockDataManager = mock(classOf[BlockDataManager]) - val service = new NettyBlockTransferService(conf, securityManager, numCores = 1) + val service = new NettyBlockTransferService(conf, securityManager, "localhost", numCores = 1) service.init(blockDataManager) service } diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala new file mode 100644 index 0000000000000..a79f5b4d74467 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import org.apache.spark._ +import org.apache.spark.util.StatCounter + +class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext { + + test("correct handling of count 1") { + + // setup + val counter = new StatCounter(List(2.0)) + // count of 10 because it's larger than 1, + // and 0.95 because that's the default + val evaluator = new SumEvaluator(10, 0.95) + // arbitrarily assign id 1 + evaluator.merge(1, counter) + + // execute + val res = evaluator.currentResult() + // 38.0 - 7.1E-15 because that's how the maths shakes out + val targetMean = 38.0 - 7.1E-15 + + // Sanity check that equality works on BoundedDouble + assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2)) + + // actual test + assert(res == + new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity)) + } + + test("correct handling of count 0") { + + // setup + val counter = new StatCounter(List()) + // count of 10 because it's larger than 0, + // and 0.95 because that's the default + val evaluator = new SumEvaluator(10, 0.95) + // arbitrarily assign id 1 + evaluator.merge(1, counter) + + // execute + val res = evaluator.currentResult() + // assert + assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)) + } + + test("correct handling of NaN") { + + // setup + val counter = new StatCounter(List(1, Double.NaN, 2)) + // count of 10 because it's larger than 0, + // and 0.95 because that's the default + val evaluator = new SumEvaluator(10, 0.95) + // arbitrarily assign id 1 + evaluator.merge(1, counter) + + // execute + val res = evaluator.currentResult() + // assert - note semantics of == in face of NaN + assert(res.mean.isNaN) + assert(res.confidence == 0.95) + assert(res.low == Double.NegativeInfinity) + assert(res.high == Double.PositiveInfinity) + } + + test("correct handling of > 1 values") { + + // setup + val counter = new StatCounter(List(1, 3, 2)) + // count of 10 because it's larger than 0, + // and 0.95 because that's the default + val evaluator = new SumEvaluator(10, 0.95) + // arbitrarily assign id 1 + evaluator.merge(1, counter) + + // execute + val res = evaluator.currentResult() + + // These vals because that's how the maths shakes out + val targetMean = 78.0 + val targetLow = -117.617 + 2.732357258139473E-5 + val targetHigh = 273.617 - 2.7323572624027292E-5 + val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh) + + + // check that values are within expected tolerance of expectation + assert(res == target) + } + +} diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index ec99f2a1bad66..d18bde790b40a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore -import scala.concurrent.{Await, TimeoutException} +import scala.concurrent._ import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global @@ -27,19 +27,24 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark._ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @transient private var sc: SparkContext = _ override def beforeAll() { + super.beforeAll() sc = new SparkContext("local[2]", "test") } override def afterAll() { - LocalSparkContext.stop(sc) - sc = null + try { + LocalSparkContext.stop(sc) + sc = null + } finally { + super.afterAll() + } } lazy val zeroPartRdd = new EmptyRDD[Int](sc) @@ -197,4 +202,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim Await.result(f, Duration(20, "milliseconds")) } } + + private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { + val executionContextInvoked = Promise[Unit] + val fakeExecutionContext = new ExecutionContext { + override def execute(runnable: Runnable): Unit = { + executionContextInvoked.success(()) + } + override def reportFailure(t: Throwable): Unit = () + } + val starter = Smuggle(new Semaphore(0)) + starter.drainPermits() + val rdd = sc.parallelize(1 to 100, 4).mapPartitions {itr => starter.acquire(1); itr} + val f = action(rdd) + f.onComplete(_ => ())(fakeExecutionContext) + // Here we verify that registering the callback didn't cause a thread to be consumed. + assert(!executionContextInvoked.isCompleted) + // Now allow the executors to proceed with task processing. + starter.release(rdd.partitions.length) + // Waiting for the result verifies that the tasks were successfully processed. + Await.result(executionContextInvoked.future, atMost = 15.seconds) + } + + test("SimpleFutureAction callback must not consume a thread while waiting") { + testAsyncAction(_.countAsync()) + } + + test("ComplexFutureAction callback must not consume a thread while waiting") { + testAsyncAction((_.takeAsync(100))) + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 4e72b89bfcc40..864adddad3426 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -164,8 +164,8 @@ class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { val expectedHistogramResults = Array(4, 2, 1, 2, 3) assert(histogramResults === expectedHistogramResults) } - // Make sure this works with a NaN end bucket and an inifity - test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { + // Make sure this works with a NaN end bucket and an infinity + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfinity") { // Make sure that it works with two unequally spaced buckets and elements in each val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN)) @@ -178,7 +178,7 @@ class DoubleRDDSuite extends SparkFunSuite with SharedSparkContext { test("WorksWithOutOfRangeWithInfiniteBuckets") { // Verify that out of range works with two buckets val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) - val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0) + val buckets = Array(-1.0/0.0, 0.0, 1.0/0.0) val histogramResults = rdd.histogram(buckets) val expectedHistogramResults = Array(1, 1) assert(histogramResults === expectedHistogramResults) diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 5103eb74b2457..2802cd975292c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.rdd -import org.apache.spark.{SparkException, SparkContext, LocalSparkContext, SparkFunSuite} - -import org.mockito.Mockito.spy +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{RDDBlockId, StorageLevel} /** @@ -29,6 +27,7 @@ import org.apache.spark.storage.{RDDBlockId, StorageLevel} class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { override def beforeEach(): Unit = { + super.beforeEach() sc = new SparkContext("local[2]", "test") } @@ -45,10 +44,6 @@ class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { assert(transform(StorageLevel.MEMORY_AND_DISK_SER) === StorageLevel.MEMORY_AND_DISK_SER) assert(transform(StorageLevel.MEMORY_AND_DISK_2) === StorageLevel.MEMORY_AND_DISK_2) assert(transform(StorageLevel.MEMORY_AND_DISK_SER_2) === StorageLevel.MEMORY_AND_DISK_SER_2) - // Off-heap is not supported and Spark should fail fast - intercept[SparkException] { - transform(StorageLevel.OFF_HEAP) - } } test("basic lineage truncation") { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 1321ec84735b5..b0d69de6e2ef4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -17,18 +17,22 @@ package org.apache.spark.rdd -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util.Progressable +import java.io.IOException import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.Random +import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, -OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, -TaskAttemptContext => NewTaskAttempContext} -import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite} +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, + OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, + RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} +import org.apache.hadoop.util.Progressable + +import org.apache.spark._ +import org.apache.spark.Partitioner import org.apache.spark.util.Utils class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { @@ -178,7 +182,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(sums(2) === 1) } - test("reduceByKey with many output partitons") { + test("reduceByKey with many output partitions") { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_ + _, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) @@ -532,6 +536,38 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeOutputCommitter.ran, "OutputCommitter was never called") } + test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + FakeWriterWithCallback.calledBy = "" + FakeWriterWithCallback.exception = null + val e = intercept[SparkException] { + pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored") + } + assert(e.getMessage contains "failed to write") + + assert(FakeWriterWithCallback.calledBy === "write,callback,close") + assert(FakeWriterWithCallback.exception != null, "exception should be captured") + assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") + } + + test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + val conf = new JobConf() + + FakeWriterWithCallback.calledBy = "" + FakeWriterWithCallback.exception = null + val e = intercept[SparkException] { + pairs.saveAsHadoopFile( + "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf) + } + assert(e.getMessage contains "failed to write") + + assert(FakeWriterWithCallback.calledBy === "write,callback,close") + assert(FakeWriterWithCallback.exception != null, "exception should be captured") + assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") + } + test("lookup") { val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) @@ -578,17 +614,36 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { + def assertBinomialSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { + if (exact) { + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new BinomialDistribution(trials, p) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } + } + } + + def assertPoissonSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { if (exact) { - return expected == actual + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new PoissonDistribution(p * trials) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev } def testSampleExact(stratifiedData: RDD[(String, Int)], @@ -613,8 +668,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) + val trials = stratifiedData.countByKey() val fractions = Map("1" -> samplingRate, "0" -> samplingRate) val sample = if (exact) { stratifiedData.sampleByKeyExact(false, fractions, seed) @@ -623,8 +677,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } val sampleCounts = sample.countByKey() val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + sampleCounts.foreach { case (k, v) => + assertBinomialSample(exact = exact, actual = v.toInt, trials = trials(k).toInt, + p = samplingRate) + } assert(takeSample.size === takeSample.toSet.size) takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } } @@ -635,6 +691,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { + val trials = stratifiedData.countByKey() val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -646,7 +703,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val sampleCounts = sample.countByKey() val takeSample = sample.collect() sampleCounts.foreach { case (k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + assertPoissonSample(exact, actual = v.toInt, trials = trials(k).toInt, p = samplingRate) } val groupedByKey = takeSample.groupBy(_._1) for ((key, v) <- groupedByKey) { @@ -657,7 +714,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { if (exact) { assert(v.toSet.size <= expectedSampleSize(key)) } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + assertPoissonSample(false, actual = v.toSet.size, trials(key).toInt, p = samplingRate) } } } @@ -754,6 +811,60 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { } } +object FakeWriterWithCallback { + var calledBy: String = "" + var exception: Throwable = _ + + def onFailure(ctx: TaskContext, e: Throwable): Unit = { + calledBy += "callback," + exception = e + } +} + +class FakeWriterWithCallback extends FakeWriter { + + override def close(p1: Reporter): Unit = { + FakeWriterWithCallback.calledBy += "close" + } + + override def write(p1: Integer, p2: Integer): Unit = { + FakeWriterWithCallback.calledBy += "write," + TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) => + FakeWriterWithCallback.onFailure(t, e) + } + throw new IOException("failed to write") + } +} + +class FakeFormatWithCallback() extends FakeOutputFormat { + override def getRecordWriter( + ignored: FileSystem, + job: JobConf, name: String, + progress: Progressable): RecordWriter[Integer, Integer] = { + new FakeWriterWithCallback() + } +} + +class NewFakeWriterWithCallback extends NewFakeWriter { + override def close(p1: NewTaskAttempContext): Unit = { + FakeWriterWithCallback.calledBy += "close" + } + + override def write(p1: Integer, p2: Integer): Unit = { + FakeWriterWithCallback.calledBy += "write," + TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) => + FakeWriterWithCallback.onFailure(t, e) + } + throw new IOException("failed to write") + } +} + +class NewFakeFormatWithCallback() extends NewFakeFormat { + override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = { + new NewFakeWriterWithCallback() + } +} + class ConfigTestFormat() extends NewFakeFormat() with Configurable { var setConfCalled = false diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index e7cc1617cdf1c..31ce9483cf20a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -101,7 +101,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { val data = 1 until 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_ + _) === 99) + assert(slices.map(_.size).sum === 99) assert(slices.forall(_.isInstanceOf[Range])) } @@ -109,7 +109,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { val data = 1 to 100 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_ + _) === 100) + assert(slices.map(_.size).sum === 100) assert(slices.forall(_.isInstanceOf[Range])) } @@ -202,7 +202,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { val data = 1L until 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_ + _) === 99) + assert(slices.map(_.size).sum === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -210,7 +210,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { val data = 1L to 100L val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_ + _) === 100) + assert(slices.map(_.size).sum === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -218,7 +218,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { val data = 1.0 until 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_ + _) === 99) + assert(slices.map(_.size).sum === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } @@ -226,7 +226,7 @@ class ParallelCollectionSplitSuite extends SparkFunSuite with Checkers { val data = 1.0 to 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_ + _) === 100) + assert(slices.map(_.size).sum === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 132a5fa9a80fb..cb0de1c6beb6b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -29,6 +29,8 @@ class MockSampler extends RandomSampler[Long, Long] { s = seed } + override def sample(): Int = 1 + override def sample(items: Iterator[Long]): Iterator[Long] = { Iterator(s) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 5f73ec8675966..e9cc8195240f0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.rdd import java.io.File -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} - import scala.collection.Map import scala.language.postfixOps import scala.sys.process._ import scala.util.Try +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} + import org.apache.spark._ import org.apache.spark.util.Utils @@ -50,6 +50,27 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("failure in iterating over pipe input") { + if (testCommandAvailable("cat")) { + val nums = + sc.makeRDD(Array(1, 2, 3, 4), 2) + .mapPartitionsWithIndex((index, iterator) => { + new Iterator[Int] { + def hasNext = true + def next() = { + throw new SparkException("Exception to simulate bad scenario") + } + } + }) + + val piped = nums.pipe(Seq("cat")) + + intercept[SparkException] { + piped.collect() + } + } + } + test("advanced pipe") { if (testCommandAvailable("cat")) { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) @@ -113,15 +134,27 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } } - test("pipe with non-zero exit status") { + test("pipe with process which cannot be launched due to bad command") { + if (!testCommandAvailable("some_nonexistent_command")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val command = Seq("some_nonexistent_command") + val piped = nums.pipe(command) + val exception = intercept[SparkException] { + piped.collect() + } + assert(exception.getMessage.contains(command.mkString(" "))) + } + } + + test("pipe with process which is launched but fails with non-zero exit status") { if (testCommandAvailable("cat")) { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) - intercept[SparkException] { + val command = Seq("cat", "nonexistent_file") + val piped = nums.pipe(command) + val exception = intercept[SparkException] { piped.collect() } - } else { - assert(true) + assert(exception.getMessage.contains(command.mkString(" "))) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5f718ea9f7be1..24daedab2090f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.rdd -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} -import com.esotericsoftware.kryo.KryoException - -import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.reflect.ClassTag +import com.esotericsoftware.kryo.KryoException + import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ @@ -34,6 +34,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.getNumPartitions === 2) assert(nums.collect().toList === List(1, 2, 3, 4)) assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) @@ -53,16 +54,16 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(!nums.isEmpty()) assert(nums.max() === 4) assert(nums.min() === 1) - val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) + val partitionSums = nums.mapPartitions(iter => Iterator(iter.sum)) assert(partitionSums.collect().toList === List(3, 7)) val partitionSumsWithSplit = nums.mapPartitionsWithIndex { - case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + case(split, iter) => Iterator((split, iter.sum)) } assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) val partitionSumsWithIndex = nums.mapPartitionsWithIndex { - case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + case(split, iter) => Iterator((split, iter.sum)) } assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) @@ -100,21 +101,21 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) val unionRdd = sc.union(rddWithNoPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[UnionRDD[_]]) } test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) } test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) intercept[IllegalArgumentException] { new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) } @@ -440,66 +441,6 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(prunedData(0) === 10) } - test("mapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val randoms = ones.mapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextDouble * t}.collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(2) === prn42_3) - assert(randoms(5) === prn43_3) - } - - test("flatMapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val randoms = ones.flatMapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => - val random = prng.nextDouble() - Seq(random * t, random * t * 10)}. - collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(5) === prn42_3 * 10) - assert(randoms(11) === prn43_3 * 10) - } - - test("filterWith") { - import java.util.Random - val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val sample = ints.filterWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextInt(3) == 0}. - collect() - val checkSample = { - val prng42 = new Random(42) - val prng43 = new Random(43) - Array(1, 2, 3, 4, 5, 6).filter{i => - if (i < 4) 0 == prng42.nextInt(3) else 0 == prng43.nextInt(3) - } - } - assert(sample.size === checkSample.size) - for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) - } - test("collect large number of empty partitions") { // Regression test for SPARK-4019 assert(sc.makeRDD(0 until 10, 1000).repartition(2001).collect().toSet === (0 until 10).toSet) @@ -541,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(nums.take(501) === (1 to 501).toArray) assert(nums.take(999) === (1 to 999).toArray) assert(nums.take(1000) === (1 to 999).toArray) + + nums = sc.parallelize(1 to 2, 2) + assert(nums.take(2147483638).size === 2) + assert(nums.takeAsync(2147483638).get.size === 2) } test("top with predefined ordering") { @@ -969,6 +914,24 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("RDD.partitions() fails fast when partitions indicies are incorrect (SPARK-13021)") { + class BadRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) { + + override def compute(part: Partition, context: TaskContext): Iterator[T] = { + prev.compute(part, context) + } + + override protected def getPartitions: Array[Partition] = { + prev.partitions.reverse // breaks contract, which is that `rdd.partitions(i).index == i` + } + } + val rdd = new BadRDD(sc.parallelize(1 to 100, 100)) + val e = intercept[IllegalArgumentException] { + rdd.partitions + } + assert(e.getMessage.contains("partitions")) + } + test("nested RDDs are not supported (SPARK-5063)") { val rdd: RDD[Int] = sc.parallelize(1 to 100) val rdd2: RDD[Int] = sc.parallelize(1 to 100) diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index a7de9cabe7cc9..f9a7f151823a2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.rdd import org.scalatest.Matchers -import org.apache.spark.{Logging, SharedSparkContext, SparkFunSuite} +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers with Logging { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 834e4743df866..cebac2097f380 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,18 +17,25 @@ package org.apache.spark.rpc -import java.io.NotSerializableException -import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} +import java.io.{File, NotSerializableException} +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.UUID +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit} import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils /** * Common tests for an RpcEnv implementation. @@ -38,13 +45,23 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { var env: RpcEnv = _ override def beforeAll(): Unit = { + super.beforeAll() val conf = new SparkConf() - env = createRpcEnv(conf, "local", 12345) + env = createRpcEnv(conf, "local", 0) + + val sparkEnv = mock(classOf[SparkEnv]) + when(sparkEnv.rpcEnv).thenReturn(env) + SparkEnv.set(sparkEnv) } override def afterAll(): Unit = { - if (env != null) { - env.shutdown() + try { + if (env != null) { + env.shutdown() + } + SparkEnv.set(null) + } finally { + super.afterAll() } } @@ -76,9 +93,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "send-remotely") try { rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { @@ -110,9 +127,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) val reply = rpcEndpointRef.askWithRetry[String]("hello") @@ -124,15 +140,14 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") try { val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) @@ -147,10 +162,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => Thread.sleep(100) context.reply(msg) - } } }) @@ -158,9 +172,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") - val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-timeout") try { // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause val e = intercept[SparkException] { @@ -300,10 +314,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { - case m => { + case m => self callSelfSuccessfully = true - } } }) @@ -417,9 +430,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") try { val f = rpcEndpointRef.ask[String]("hello") val ack = Await.result(f, 5 seconds) @@ -457,10 +470,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "sendWithReply-remotely-error") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely-error") try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { @@ -473,64 +485,123 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - test("network events") { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + /** + * Setup an [[RpcEndpoint]] to collect all network events. + * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events. + */ + private def setupNetworkEndpoint( + _env: RpcEnv, + name: String): (RpcEndpointRef, ConcurrentLinkedQueue[(Any, Any)]) = { + val events = new ConcurrentLinkedQueue[(Any, Any)] + val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { + override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => - case m => events += "receive" -> m + case m => events.add("receive" -> m) } override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress + events.add("onConnected" -> remoteAddress) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + events.add("onDisconnected" -> remoteAddress) } override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress + events.add("onNetworkError" -> remoteAddress) } }) + (ref, events) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - // anotherEnv is connected in client mode, so the remote address may be unknown depending on - // the implementation. Account for that when doing checks. - if (remoteAddress != null) { - assert(events === List(("onConnected", remoteAddress))) - } else { - assert(events.size === 1) - assert(events(0)._1 === "onConnected") + test("network events in sever RpcEnv when another RpcEnv is in server mode") { + val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false) + val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") + val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") + try { + val serverRefInServer2 = serverEnv1.setupEndpointRef(serverRef2.address, serverRef2.name) + // Send a message to set up the connection + serverRefInServer2.send("hello") + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) + } + + serverEnv2.shutdown() + serverEnv2.awaitTermination() + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) + assert(events.contains(("onDisconnected", serverEnv2.address))) } + } finally { + serverEnv1.shutdown() + serverEnv2.shutdown() + serverEnv1.awaitTermination() + serverEnv2.awaitTermination() } + } - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - // Account for anotherEnv not having an address due to running in client mode. - if (remoteAddress != null) { - assert(events === List( - ("onConnected", remoteAddress), - ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress)) || - events === List( - ("onConnected", remoteAddress), - ("onDisconnected", remoteAddress))) - } else { - val eventNames = events.map(_._1) - assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") || - eventNames === List("onConnected", "onDisconnected")) + test("network events in sever RpcEnv when another RpcEnv is in client mode") { + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + try { + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") + + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.asScala.map(_._1).exists(_ == "onConnected")) + } + + clientEnv.shutdown() + clientEnv.awaitTermination() + + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address but at least we can verify the message type + assert(events.asScala.map(_._1).exists(_ == "onConnected")) + assert(events.asScala.map(_._1).exists(_ == "onDisconnected")) } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() + } + } + + test("network events in client RpcEnv when another RpcEnv is in server mode") { + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") + val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") + try { + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + } + + serverEnv.shutdown() + serverEnv.awaitTermination() + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + assert(events.contains(("onDisconnected", serverEnv.address))) + } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() } } @@ -543,18 +614,16 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "sendWithReply-unserializable-error") + val rpcEndpointRef = + anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[Exception] { Await.result(f, 1 seconds) } - assert(e.isInstanceOf[TimeoutException] || // For Akka - e.isInstanceOf[NotSerializableException] // For Netty - ) + assert(e.isInstanceOf[NotSerializableException]) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -571,8 +640,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { @volatile var message: String = null @@ -583,8 +652,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { case msg: String => message = msg } }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication") + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "send-authentication") rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { assert("hello" === message) @@ -602,21 +670,19 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { override val rpcEnv = localEnv override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication") + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } finally { @@ -713,6 +779,68 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) } + test("file server") { + val conf = new SparkConf() + val tempDir = Utils.createTempDir() + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val fileWithSpecialChars = new File(tempDir, "file name") + Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) + val empty = new File(tempDir, "empty") + Files.write("", empty, UTF_8); + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val dir1 = new File(tempDir, "dir1") + assert(dir1.mkdir()) + val subFile1 = new File(dir1, "file1") + Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + + val dir2 = new File(tempDir, "dir2") + assert(dir2.mkdir()) + val subFile2 = new File(dir2, "file2") + Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) + val emptyUri = env.fileServer.addFile(empty) + val jarUri = env.fileServer.addJar(jar) + val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) + val dir2Uri = env.fileServer.addDirectory("/dir2", dir2) + + // Try registering directories with invalid names. + Seq("/files", "/jars").foreach { uri => + intercept[IllegalArgumentException] { + env.fileServer.addDirectory(uri, dir1) + } + } + + val destDir = Utils.createTempDir() + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + + val files = Seq( + (file, fileUri), + (fileWithSpecialChars, fileWithSpecialCharsUri), + (empty, emptyUri), + (jar, jarUri), + (subFile1, dir1Uri + "/file1"), + (subFile2, dir2Uri + "/file2")) + files.foreach { case (f, uri) => + val destFile = new File(destDir, f.getName()) + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + assert(Files.equal(f, destFile)) + } + + // Try to download files that do not exist. + Seq("files", "jars", "dir1").foreach { root => + intercept[Exception] { + val uri = env.address.toSparkURL + s"/$root/doesNotExist" + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + } + } + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala deleted file mode 100644 index 6478ab51c4da2..0000000000000 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.akka - -import org.apache.spark.rpc._ -import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} - -class AkkaRpcEnvSuite extends RpcEnvSuite { - - override def createRpcEnv(conf: SparkConf, - name: String, - port: Int, - clientMode: Boolean = false): RpcEnv = { - new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode)) - } - - test("setupEndpointRef: systemName, address, endpointName") { - val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { - override val rpcEnv = env - - override def receive = { - case _ => - } - }) - val conf = new SparkConf() - val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false)) - try { - val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") - assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === - newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) - } finally { - newRpcEnv.shutdown() - } - } - - test("uriOf") { - val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") - assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) - } - - test("uriOf: ssl") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val securityManager = new SecurityManager(conf) - val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false)) - try { - val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") - assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) - } finally { - rpcEnv.shutdown() - } - } - -} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 276c077b3d13e..e5539566e4b6f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite -import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint} +import org.apache.spark.rpc.{RpcAddress, TestRpcEndpoint} class InboxSuite extends SparkFunSuite { @@ -35,7 +35,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", false, null) + val message = OneWayMessage(null, "hi") inbox.post(message) inbox.process(dispatcher) assert(inbox.isEmpty) @@ -55,7 +55,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", true, null) + val message = RpcMessage(null, "hi", null) inbox.post(message) inbox.process(dispatcher) assert(inbox.isEmpty) @@ -83,7 +83,7 @@ class InboxSuite extends SparkFunSuite { new Thread { override def run(): Unit = { for (_ <- 0 until 100) { - val message = ContentMessage(null, "hi", false, null) + val message = OneWayMessage(null, "hi") inbox.post(message) } exitLatch.countDown() diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index 56743ba650b41..4fcdb619f9300 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import org.apache.spark.SparkFunSuite +import org.apache.spark.rpc.RpcEndpointAddress class NettyRpcAddressSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index ce83087ec04d6..994a58836bd0d 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -33,9 +33,9 @@ class NettyRpcEnvSuite extends RpcEnvSuite { } test("non-existent endpoint") { - val uri = env.uriOf("test", env.address, "nonexist-endpoint") + val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString val e = intercept[RpcEndpointNotFoundException] { - env.setupEndpointRef("test", env.address, "nonexist-endpoint") + env.setupEndpointRef(env.address, "nonexist-endpoint") } assert(e.getMessage.contains(uri)) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f9d8e80c98b66..0c156fef0ae0f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -18,44 +18,47 @@ package org.apache.spark.rpc.netty import java.net.InetSocketAddress +import java.nio.ByteBuffer import io.netty.channel.Channel -import org.mockito.Mockito._ import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite -import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.client.{TransportClient, TransportResponseHandler} +import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). - thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + val sm = mock(classOf[StreamManager]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) } test("connectionTerminated") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.connectionTerminated(client) + nettyRpcHandler.channelInactive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) verify(dispatcher, times(1)).postToAll( diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index eef6aafa624ee..04cccc67e328e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.scheduler import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} -import org.apache.spark.util.{SerializableBuffer, AkkaUtils} +import org.apache.spark.util.{RpcUtils, SerializableBuffer} class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { - test("serialized task larger than akka frame size") { + test("serialized task larger than max RPC message size") { val conf = new SparkConf - conf.set("spark.akka.frameSize", "1") + conf.set("spark.rpc.message.maxSize", "1") conf.set("spark.default.parallelism", "1") sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) - val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) + val frameSize = RpcUtils.maxMessageSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) val larger = sc.parallelize(Seq(buffer)) val thrown = intercept[SparkException] { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3816b8c4a09aa..fd96fb04f8b29 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.scheduler -import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} +import java.util.Properties + +import scala.annotation.meta.param +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal -import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -30,7 +32,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} -import org.apache.spark.util.CallSite +import org.apache.spark.util.{CallSite, Utils} class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -43,6 +45,13 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) case NonFatal(e) => onError(e) } } + + override def onError(e: Throwable): Unit = { + logError("Error in DAGSchedulerEventLoop: ", e) + dagScheduler.stop() + throw e + } + } /** @@ -59,7 +68,7 @@ class MyRDD( numPartitions: Int, dependencies: List[Dependency[_]], locations: Seq[Seq[String]] = Nil, - @transient tracker: MapOutputTrackerMaster = null) + @(transient @param) tracker: MapOutputTrackerMaster = null) extends RDD[(Int, Int)](sc, dependencies) with Serializable { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = @@ -87,8 +96,7 @@ class MyRDD( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite - extends SparkFunSuite with BeforeAndAfter with LocalSparkContext with Timeouts { +class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeouts { val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ @@ -102,8 +110,10 @@ class DAGSchedulerSuite override def schedulingMode: SchedulingMode = SchedulingMode.NONE override def start() = {} override def stop() = {} - override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], - blockManagerId: BlockManagerId): Boolean = true + override def executorHeartbeatReceived( + execId: String, + accumUpdates: Array[(Long, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) @@ -125,6 +135,7 @@ class DAGSchedulerSuite val successfulStages = new HashSet[Int] val failedStages = new ArrayBuffer[Int] val stageByOrderOfExecution = new ArrayBuffer[Int] + val endedTasks = new HashSet[Long] override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { submittedStageInfos += stageSubmitted.stageInfo @@ -139,6 +150,10 @@ class DAGSchedulerSuite failedStages += stageInfo.stageId } } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + endedTasks += taskEnd.taskInfo.taskId + } } var mapOutputTracker: MapOutputTrackerMaster = null @@ -180,11 +195,13 @@ class DAGSchedulerSuite override def jobFailed(exception: Exception): Unit = { failure = exception } } - before { + override def beforeEach(): Unit = { + super.beforeEach() sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() + sparkListener.endedTasks.clear() failure = null sc.addSparkListener(sparkListener) taskSets.clear() @@ -193,17 +210,21 @@ class DAGSchedulerSuite results.clear() mapOutputTracker = new MapOutputTrackerMaster(conf) scheduler = new DAGScheduler( - sc, - taskScheduler, - sc.listenerBus, - mapOutputTracker, - blockManagerMaster, - sc.env) + sc, + taskScheduler, + sc.listenerBus, + mapOutputTracker, + blockManagerMaster, + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } - after { - scheduler.stop() + override def afterEach(): Unit = { + try { + scheduler.stop() + } finally { + super.afterEach() + } } override def afterAll() { @@ -233,26 +254,31 @@ class DAGSchedulerSuite * directly through CompletionEvents. */ private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) => - it.next.asInstanceOf[Tuple2[_, _]]._1 + it.next.asInstanceOf[Tuple2[_, _]]._1 /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent( - taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(taskSet.tasks(i), result._1, result._2)) } } } - private def completeWithAccumulator(accumId: Long, taskSet: TaskSet, - results: Seq[(TaskEndReason, Any)]) { + private def completeWithAccumulator( + accumId: Long, + taskSet: TaskSet, + results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, - Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent( + taskSet.tasks(i), + result._1, + result._2, + Seq(new AccumulableInfo( + accumId, Some(""), Some(1), None, internal = false, countFailedValues = false)))) } } } @@ -262,9 +288,10 @@ class DAGSchedulerSuite rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - listener: JobListener = jobListener): Int = { + listener: JobListener = jobListener, + properties: Properties = null): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, properties)) jobId } @@ -297,13 +324,18 @@ class DAGSchedulerSuite test("zero split job") { var numResults = 0 + var failureReason: Option[Exception] = None val fakeListener = new JobListener() { - override def taskSucceeded(partition: Int, value: Any) = numResults += 1 - override def jobFailed(exception: Exception) = throw exception + override def taskSucceeded(partition: Int, value: Any): Unit = numResults += 1 + override def jobFailed(exception: Exception): Unit = { + failureReason = Some(exception) + } } val jobId = submit(new MyRDD(sc, 0, Nil), Array(), listener = fakeListener) assert(numResults === 0) cancel(jobId) + assert(failureReason.isDefined) + assert(failureReason.get.getMessage() === "Job 0 cancelled ") } test("run trivial job") { @@ -323,9 +355,12 @@ class DAGSchedulerSuite } test("equals and hashCode AccumulableInfo") { - val accInfo1 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, true) - val accInfo2 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) - val accInfo3 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false) + val accInfo1 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = true, countFailedValues = false) + val accInfo2 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false) + val accInfo3 = new AccumulableInfo( + 1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false) assert(accInfo1 !== accInfo2) assert(accInfo2 === accInfo3) assert(accInfo2.hashCode() === accInfo3.hashCode()) @@ -449,7 +484,7 @@ class DAGSchedulerSuite override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, TaskMetrics)], + accumUpdates: Array[(Long, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def applicationAttemptId(): Option[String] = None @@ -484,8 +519,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) @@ -500,12 +535,12 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( - (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) // this will get called // blockManagerMaster.removeExecutor("exec-hostA") // ask the scheduler to try it again @@ -594,11 +629,17 @@ class DAGSchedulerSuite * @param stageId - The current stageId * @param attemptIdx - The current attempt count */ - private def completeNextResultStageWithSuccess(stageId: Int, attemptIdx: Int): Unit = { + private def completeNextResultStageWithSuccess( + stageId: Int, + attemptIdx: Int, + partitionToResult: Int => Int = _ => 42): Unit = { val stageAttempt = taskSets.last checkStageId(stageId, attemptIdx, stageAttempt) assert(scheduler.stageIdToStage(stageId).isInstanceOf[ResultStage]) - complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map(_ => (Success, 42)).toSeq) + val taskResults = stageAttempt.tasks.zipWithIndex.map { case (task, idx) => + (Success, partitionToResult(idx)) + } + complete(stageAttempt, taskResults.toSeq) } /** @@ -629,7 +670,7 @@ class DAGSchedulerSuite completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = parts) completeNextResultStageWithSuccess(1, 1) - // Confirm job finished succesfully + // Confirm job finished successfully sc.listenerBus.waitUntilEmpty(1000) assert(ended === true) assert(results === (0 until parts).map { idx => idx -> 42 }.toMap) @@ -808,23 +849,17 @@ class DAGSchedulerSuite HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) // The second ResultTask fails, with a fetch failure for the output from the second mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -861,12 +896,9 @@ class DAGSchedulerSuite HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) @@ -879,12 +911,9 @@ class DAGSchedulerSuite assert(countSubmittedMapStageAttempts() === 2) // The second ResultTask fails, with a fetch failure for the output from the second mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(1), FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Another ResubmitFailedStages event should not result in another attempt for the map @@ -899,11 +928,11 @@ class DAGSchedulerSuite } /** - * This tests the case where a late FetchFailed comes in after the map stage has finished getting - * retried and a new reduce stage starts running. - */ + * This tests the case where a late FetchFailed comes in after the map stage has finished getting + * retried and a new reduce stage starts running. + */ test("extremely late fetch failures don't cause multiple concurrent attempts for " + - "the same stage") { + "the same stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId @@ -931,12 +960,9 @@ class DAGSchedulerSuite assert(countSubmittedReduceStageAttempts() === 1) // The first result task fails, with a fetch failure for the output from the first mapper. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Trigger resubmission of the failed map stage and finish the re-started map task. @@ -950,12 +976,9 @@ class DAGSchedulerSuite assert(countSubmittedReduceStageAttempts() === 2) // A late FetchFailed arrives from the second task in the original reduce stage. - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(1), FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), - null, - Map[Long, Any](), - createFakeTaskInfo(), null)) // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because @@ -966,6 +989,52 @@ class DAGSchedulerSuite assert(countSubmittedMapStageAttempts() === 2) } + test("task events always posted in speculation / when stage is killed") { + val baseRdd = new MyRDD(sc, 4, Nil) + val finalRdd = new MyRDD(sc, 4, List(new OneToOneDependency(baseRdd))) + submit(finalRdd, Array(0, 1, 2, 3)) + + // complete two tasks + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(0))) + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(1))) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + // verify stage exists + assert(scheduler.stageIdToStage.contains(0)) + assert(sparkListener.endedTasks.size == 2) + + // finish other 2 tasks + runEvent(makeCompletionEvent( + taskSets(0).tasks(2), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(2))) + runEvent(makeCompletionEvent( + taskSets(0).tasks(3), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(3))) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.endedTasks.size == 4) + + // verify the stage is done + assert(!scheduler.stageIdToStage.contains(0)) + + // Stage should be complete. Finish one other Successful task to simulate what can happen + // with a speculative task and make sure the event is sent out + runEvent(makeCompletionEvent( + taskSets(0).tasks(3), Success, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(5))) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.endedTasks.size == 5) + + // make sure non successful tasks also send out event + runEvent(makeCompletionEvent( + taskSets(0).tasks(3), UnknownReason, 42, + Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(6))) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.endedTasks.size == 6) + } + test("ignore late map task completions") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) @@ -986,48 +1055,36 @@ class DAGSchedulerSuite assert(shuffleStage.numAvailableOutputs === 0) // should be ignored for being too old - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 0) // should work because it's a non-failed host (so the available map outputs will increase) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostB", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostB", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 1) // should be ignored for being too old - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 1) // should work because it's a new epoch, which will increase the number of available map // outputs, and also finish the stage taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSet.tasks(1), Success, - makeMapStatus("hostA", reduceRdd.partitions.size), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1054,6 +1111,47 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * Run two jobs, with a shared dependency. We simulate a fetch failure in the second job, which + * requires regenerating some outputs of the shared dependency. One key aspect of this test is + * that the second job actually uses a different stage for the shared dependency (a "skipped" + * stage). + */ + test("shuffle fetch failure in a reused shuffle dependency") { + // Run the first job successfully, which creates one shuffle dependency + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(0, 0, 2) + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + assertDataStructuresEmpty() + + // submit another job w/ the shared dependency, and have a fetch failure + val reduce2 = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduce2, Array(0, 1)) + // Note that the stage numbering here is only b/c the shared dependency produces a new, skipped + // stage. If instead it reused the existing stage, then this would be stage 2 + completeNextStageWithFetchFailure(3, 0, shuffleDep) + scheduler.resubmitFailedStages() + + // the scheduler now creates a new task set to regenerate the missing map output, but this time + // using a different stage, the "skipped" one + + // SPARK-9809 -- this stage is submitted without a task for each partition (because some of + // the shuffle map output is still available from stage 0); make sure we've still got internal + // accumulators setup + assert(scheduler.stageIdToStage(2).latestInfo.internalAccumulators.nonEmpty) + completeShuffleMapStageSuccessfully(2, 0, 2) + completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) + assert(results === Map(0 -> 1234, 1 -> 1235)) + + assertDataStructuresEmpty() + } + /** * This test runs a three stage job, with a fetch failure in stage 1. but during the retry, we * have completions from both the first & second attempt of stage 1. So all the map output is @@ -1078,12 +1176,9 @@ class DAGSchedulerSuite // then one executor dies, and a task fails in stage 1 runEvent(ExecutorLost("exec-hostA")) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), - null, - null, - createFakeTaskInfo(), null)) // so we resubmit stage 0, which completes happily @@ -1093,13 +1188,10 @@ class DAGSchedulerSuite assert(stage0Resubmit.stageAttemptId === 1) val task = stage0Resubmit.tasks(0) assert(task.partitionId === 2) - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( task, Success, - makeMapStatus("hostC", shuffleMapRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostC", shuffleMapRdd.partitions.length))) // now here is where things get tricky : we will now have a task set representing // the second attempt for stage 1, but we *also* have some tasks for the first attempt for @@ -1112,28 +1204,19 @@ class DAGSchedulerSuite // we'll have some tasks finish from the first attempt, and some finish from the second attempt, // so that we actually have all stage outputs, though no attempt has completed all its // tasks - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(3).tasks(0), Success, - makeMapStatus("hostC", reduceRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) - runEvent(CompletionEvent( + makeMapStatus("hostC", reduceRdd.partitions.length))) + runEvent(makeCompletionEvent( taskSets(3).tasks(1), Success, - makeMapStatus("hostC", reduceRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostC", reduceRdd.partitions.length))) // late task finish from the first attempt - runEvent(CompletionEvent( + runEvent(makeCompletionEvent( taskSets(1).tasks(2), Success, - makeMapStatus("hostB", reduceRdd.partitions.length), - null, - createFakeTaskInfo(), - null)) + makeMapStatus("hostB", reduceRdd.partitions.length))) // What should happen now is that we submit stage 2. However, we might not see an error // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But @@ -1180,21 +1263,21 @@ class DAGSchedulerSuite submit(reduceRdd, Array(0)) // complete some of the tasks from the first stage, on one host - runEvent(CompletionEvent( - taskSets(0).tasks(0), Success, - makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) - runEvent(CompletionEvent( - taskSets(0).tasks(1), Success, - makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), + Success, + makeMapStatus("hostA", reduceRdd.partitions.length))) + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + Success, + makeMapStatus("hostA", reduceRdd.partitions.length))) // now that host goes down runEvent(ExecutorLost("exec-hostA")) // so we resubmit those tasks - runEvent(CompletionEvent( - taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) - runEvent(CompletionEvent( - taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(taskSets(0).tasks(0), Resubmitted, null)) + runEvent(makeCompletionEvent(taskSets(0).tasks(1), Resubmitted, null)) // now complete everything on a different host complete(taskSets(0), Seq( @@ -1275,6 +1358,106 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + def checkJobPropertiesAndPriority(taskSet: TaskSet, expected: String, priority: Int): Unit = { + assert(taskSet.properties != null) + assert(taskSet.properties.getProperty("testProperty") === expected) + assert(taskSet.priority === priority) + } + + def launchJobsThatShareStageAndCancelFirst(): ShuffleDependency[Int, Int, Nothing] = { + val baseRdd = new MyRDD(sc, 1, Nil) + val shuffleDep1 = new ShuffleDependency(baseRdd, new HashPartitioner(1)) + val intermediateRdd = new MyRDD(sc, 1, List(shuffleDep1)) + val shuffleDep2 = new ShuffleDependency(intermediateRdd, new HashPartitioner(1)) + val finalRdd1 = new MyRDD(sc, 1, List(shuffleDep2)) + val finalRdd2 = new MyRDD(sc, 1, List(shuffleDep2)) + val job1Properties = new Properties() + val job2Properties = new Properties() + job1Properties.setProperty("testProperty", "job1") + job2Properties.setProperty("testProperty", "job2") + + // Run jobs 1 & 2, both referencing the same stage, then cancel job1. + // Note that we have to submit job2 before we cancel job1 to have them actually share + // *Stages*, and not just shuffle dependencies, due to skipped stages (at least until + // we address SPARK-10193.) + val jobId1 = submit(finalRdd1, Array(0), properties = job1Properties) + val jobId2 = submit(finalRdd2, Array(0), properties = job2Properties) + assert(scheduler.activeJobs.nonEmpty) + val testProperty1 = scheduler.jobIdToActiveJob(jobId1).properties.getProperty("testProperty") + + // remove job1 as an ActiveJob + cancel(jobId1) + + // job2 should still be running + assert(scheduler.activeJobs.nonEmpty) + val testProperty2 = scheduler.jobIdToActiveJob(jobId2).properties.getProperty("testProperty") + assert(testProperty1 != testProperty2) + // NB: This next assert isn't necessarily the "desired" behavior; it's just to document + // the current behavior. We've already submitted the TaskSet for stage 0 based on job1, but + // even though we have cancelled that job and are now running it because of job2, we haven't + // updated the TaskSet's properties. Changing the properties to "job2" is likely the more + // correct behavior. + val job1Id = 0 // TaskSet priority for Stages run with "job1" as the ActiveJob + checkJobPropertiesAndPriority(taskSets(0), "job1", job1Id) + complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1)))) + + shuffleDep1 + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active + */ + test("stage used by two jobs, the first no longer active (SPARK-6880)") { + launchJobsThatShareStageAndCancelFirst() + + // The next check is the key for SPARK-6880. For the stage which was shared by both job1 and + // job2 but never had any tasks submitted for job1, the properties of job2 are now used to run + // the stage. + checkJobPropertiesAndPriority(taskSets(1), "job2", 1) + + complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1)))) + assert(taskSets(2).properties != null) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active, even when + * there are fetch failures + */ + test("stage used by two jobs, some fetch failures, and the first job no longer active " + + "(SPARK-6880)") { + val shuffleDep1 = launchJobsThatShareStageAndCancelFirst() + val job2Id = 1 // TaskSet priority for Stages run with "job2" as the ActiveJob + + // lets say there is a fetch failure in this task set, which makes us go back and + // run stage 0, attempt 1 + complete(taskSets(1), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // stage 0, attempt 1 should have the properties of job2 + assert(taskSets(2).stageId === 0) + assert(taskSets(2).stageAttemptId === 1) + checkJobPropertiesAndPriority(taskSets(2), "job2", job2Id) + + // run the rest of the stages normally, checking that they have the correct properties + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(3), "job2", job2Id) + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(4), "job2", job2Id) + complete(taskSets(4), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + test("run trivial shuffle with out-of-band failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) @@ -1287,12 +1470,12 @@ class DAGSchedulerSuite // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks // rather than marking it is as failed and waiting. complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -1307,15 +1490,15 @@ class DAGSchedulerSuite submit(finalRdd, Array(0)) // have the first stage complete normally complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) // have the second stage complete normally complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again @@ -1338,15 +1521,15 @@ class DAGSchedulerSuite cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) // complete stage 0 complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) // complete stage 1 complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) // pretend stage 2 failed because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. @@ -1444,6 +1627,25 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("accumulators are updated on exception failures") { + val acc1 = sc.accumulator(0L, "ingenieur") + val acc2 = sc.accumulator(0L, "boulanger") + val acc3 = sc.accumulator(0L, "agriculteur") + assert(Accumulators.get(acc1.id).isDefined) + assert(Accumulators.get(acc2.id).isDefined) + assert(Accumulators.get(acc3.id).isDefined) + val accInfo1 = acc1.toInfo(Some(15L), None) + val accInfo2 = acc2.toInfo(Some(13L), None) + val accInfo3 = acc3.toInfo(Some(18L), None) + val accumUpdates = Seq(accInfo1, accInfo2, accInfo3) + val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates) + submit(new MyRDD(sc, 1, Nil), Array(0)) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(Accumulators.get(acc1.id).get.value === 15L) + assert(Accumulators.get(acc2.id).get.value === 13L) + assert(Accumulators.get(acc3.id).get.value === 18L) + } + test("reduce tasks should be placed locally with map output") { // Create an shuffleMapRdd with 1 partition val shuffleMapRdd = new MyRDD(sc, 1, Nil) @@ -1452,9 +1654,9 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)))) + (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + HashSet(makeBlockManagerId("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -1516,7 +1718,7 @@ class DAGSchedulerSuite } // Does not include message, ONLY stack trace. - val stackTraceString = e.getStackTraceString + val stackTraceString = Utils.exceptionString(e) // should actually include the RDD operation that invoked the method: assert(stackTraceString.contains("org.apache.spark.rdd.RDD.count")) @@ -1525,6 +1727,18 @@ class DAGSchedulerSuite assert(stackTraceString.contains("org.scalatest.FunSuite")) } + test("catch errors in event loop") { + // this is a test of our testing framework -- make sure errors in event loop don't get ignored + + // just run some bad event that will throw an exception -- we'll give a null TaskEndReason + val rdd1 = new MyRDD(sc, 1, Nil) + submit(rdd1, Array(0)) + intercept[Exception] { + complete(taskSets(0), Seq( + (null, makeMapStatus("hostA", 1)))) + } + } + test("simple map stage submission") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) @@ -1710,8 +1924,7 @@ class DAGSchedulerSuite submitMapStage(shuffleDep) val oldTaskSet = taskSets(0) - runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2))) assert(results.size === 0) // Map stage job should not be complete yet // Pretend host A was lost @@ -1721,23 +1934,19 @@ class DAGSchedulerSuite assert(newEpoch > oldEpoch) // Suppose we also get a completed event from task 1 on the same host; this should be ignored - runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2))) assert(results.size === 0) // Map stage job should not be complete yet // A completion from another task should work because it's a non-failed host - runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2))) assert(results.size === 0) // Map stage job should not be complete yet // Now complete tasks in the second task set val newTaskSet = taskSets(1) assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA - runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2))) assert(results.size === 0) // Map stage job should not be complete yet - runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2), - null, createFakeTaskInfo(), null)) + runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2))) assert(results.size === 1) // Map stage job should now finally be complete assertDataStructuresEmpty() @@ -1788,5 +1997,24 @@ class DAGSchedulerSuite info } -} + private def createFakeTaskInfoWithId(taskId: Long): TaskInfo = { + val info = new TaskInfo(taskId, 0, 0, 0L, "", "", TaskLocality.ANY, false) + info.finishTime = 1 // to prevent spurious errors in JobProgressListener + info + } + private def makeCompletionEvent( + task: Task[_], + reason: TaskEndReason, + result: Any, + extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo], + taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = { + val accumUpdates = reason match { + case Success => task.initialAccumulators.map { a => a.toInfo(Some(a.zero), None) } + case ef: ExceptionFailure => ef.accumUpdates + case _ => Seq.empty[AccumulableInfo] + } + CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 5cb2d4225d281..176d8930aad19 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.util.{JsonProtocol, Utils} @@ -67,11 +68,11 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) assert(fileSystem.exists(logPath)) val logStatus = fileSystem.getFileStatus(logPath) - assert(!logStatus.isDir) + assert(!logStatus.isDirectory) // Verify log is renamed after stop() eventLogger.stop() - assert(!fileSystem.getFileStatus(new Path(eventLogger.logPath)).isDir) + assert(!fileSystem.getFileStatus(new Path(eventLogger.logPath)).isDirectory) } test("Basic event logging") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index f7e16af9d3a92..e3e6df6831def 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -17,12 +17,14 @@ package org.apache.spark.scheduler +import java.util.Properties + import org.apache.spark.TaskContext class FakeTask( stageId: Int, prefLocs: Seq[TaskLocation] = Nil) - extends Task[Int](stageId, 0, 0, Seq.empty) { + extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobWaiterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobWaiterSuite.scala new file mode 100644 index 0000000000000..bc8e513fe5bc8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/JobWaiterSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.util.Failure + +import org.apache.spark.SparkFunSuite + +class JobWaiterSuite extends SparkFunSuite { + + test("call jobFailed multiple times") { + val waiter = new JobWaiter[Int](null, 0, totalTasks = 2, null) + + // Should not throw exception if calling jobFailed multiple times + waiter.jobFailed(new RuntimeException("Oops 1")) + waiter.jobFailed(new RuntimeException("Oops 2")) + waiter.jobFailed(new RuntimeException("Oops 3")) + + waiter.completionFuture.value match { + case Some(Failure(e)) => + // We should receive the first exception + assert("Oops 1" === e.getMessage) + case other => fail("Should receiver the first exception but it was " + other) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index b8e466fab4506..759d52fca5ce1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.scheduler -import org.apache.spark.storage.BlockManagerId +import scala.util.Random + +import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer - -import scala.util.Random +import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { @@ -78,7 +79,7 @@ class MapStatusSuite extends SparkFunSuite { test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") { val sizes = Array.tabulate[Long](3000) { i => i.toLong } - val avg = sizes.sum / sizes.filter(_ != 0).length + val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) val status = MapStatus(loc, sizes) val status1 = compressAndDecompressMapStatus(status) @@ -97,4 +98,34 @@ class MapStatusSuite extends SparkFunSuite { val buf = ser.newInstance().serialize(status) ser.newInstance().deserialize[MapStatus](buf) } + + test("RoaringBitmap: runOptimize succeeded") { + val r = new RoaringBitmap + (1 to 200000).foreach(i => + if (i % 200 != 0) { + r.add(i) + } + ) + val size1 = r.getSizeInBytes + val success = r.runOptimize() + r.trim() + val size2 = r.getSizeInBytes + assert(size1 > size2) + assert(success) + } + + test("RoaringBitmap: runOptimize failed") { + val r = new RoaringBitmap + (1 to 200000).foreach(i => + if (i % 200 == 0) { + r.add(i) + } + ) + val size1 = r.getSizeInBytes + val success = r.runOptimize() + r.trim() + val size2 = r.getSizeInBytes + assert(size1 === size2) + assert(!success) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index f33324792495b..76a7087645961 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -17,7 +17,8 @@ package org.apache.spark.scheduler -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.util.Properties import org.apache.spark.TaskContext @@ -25,7 +26,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. */ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { + extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index 1ae5b030f0832..601f1c378c41f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} import org.scalatest.concurrent.Timeouts -import org.scalatest.time.{Span, Seconds} +import org.scalatest.time.{Seconds, Span} -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext} import org.apache.spark.util.Utils /** @@ -38,7 +38,7 @@ class OutputCommitCoordinatorIntegrationSuite super.beforeAll() val conf = new SparkConf() .set("master", "local[2,4]") - .set("spark.speculation", "true") + .set("spark.hadoop.outputCommitCoordination.enabled", "true") .set("spark.hadoop.mapred.output.committer.class", classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) sc = new SparkContext("local[2, 4]", "test", conf) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 7345508bfe995..8e509de7677c3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -20,22 +20,21 @@ package org.apache.spark.scheduler import java.io.File import java.util.concurrent.TimeoutException +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.hadoop.mapred.{JobConf, OutputCommitter, TaskAttemptContext, TaskAttemptID} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter -import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} - import org.apache.spark._ -import org.apache.spark.rdd.{RDD, FakeOutputCommitter} +import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.Utils -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.postfixOps - /** * Unit tests for the output commit coordination functionality. * @@ -78,7 +77,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SparkConf() .setMaster("local[4]") .setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName) - .set("spark.speculation", "true") + .set("spark.hadoop.outputCommitCoordination.enabled", "true") sc = new SparkContext(conf) { override private[spark] def createSparkEnv( conf: SparkConf, diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 103fc19369c97..35215c15ea805 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -23,11 +23,10 @@ import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkContext, SPARK_VERSION} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{JsonProtocol, Utils} +import org.apache.spark.util.{JsonProtocol, JsonProtocolSuite, Utils} /** * Test whether ReplayListenerBus replays events from logs correctly. @@ -115,7 +114,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applications = fileSystem.listStatus(logDirPath) assert(applications != null && applications.size > 0) val eventLog = applications.sortBy(_.getModificationTime).last - assert(!eventLog.isDir) + assert(!eventLog.isDirectory) // Replay events val logData = EventLoggingListener.openEventLog(eventLog.getPath(), fileSystem) @@ -132,7 +131,11 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(sc.eventLogger.isDefined) val originalEvents = sc.eventLogger.get.loggedEvents val replayedEvents = eventMonster.loggedEvents - originalEvents.zip(replayedEvents).foreach { case (e1, e2) => assert(e1 === e2) } + originalEvents.zip(replayedEvents).foreach { case (e1, e2) => + // Don't compare the JSON here because accumulators in StageInfo may be out of order + JsonProtocolSuite.assertEquals( + JsonProtocol.sparkEventFromJson(e1), JsonProtocol.sparkEventFromJson(e2)) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 53102b9f1c936..b854d742b5bdd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -24,9 +24,9 @@ import scala.collection.JavaConverters._ import org.scalatest.Matchers +import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.ResetSystemProperties -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { @@ -36,6 +36,21 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L + test("don't call sc.stop in listener") { + sc = new SparkContext("local", "SparkListenerSuite") + val listener = new SparkContextStoppingListener(sc) + val bus = new LiveListenerBus + bus.addListener(listener) + + // Starting listener bus should flush all buffered events + bus.start(sc) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + bus.stop() + assert(listener.sparkExSeen) + } + test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus @@ -254,7 +269,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match taskMetrics.inputMetrics should not be ('defined) taskMetrics.outputMetrics should not be ('defined) taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0L) + taskMetrics.shuffleWriteMetrics.get.bytesWritten should be > (0L) } if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) @@ -269,18 +284,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("onTaskGettingResult() called when result fetched remotely") { - sc = new SparkContext("local", "SparkListenerSuite") + val conf = new SparkConf().set("spark.rpc.message.maxSize", "1") + sc = new SparkContext("local", "SparkListenerSuite", conf) val listener = new SaveTaskEvents sc.addSparkListener(listener) - // Make a task whose result is larger than the akka frame size - System.setProperty("spark.akka.frameSize", "1") - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + // Make a task whose result is larger than the RPC message size + val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + assert(maxRpcMessageSize === 1024 * 1024) val result = sc.parallelize(Seq(1), 1) - .map { x => 1.to(akkaFrameSize).toArray } + .map { x => 1.to(maxRpcMessageSize).toArray } .reduce { case (x, y) => x } - assert(result === 1.to(akkaFrameSize).toArray) + assert(result === 1.to(maxRpcMessageSize).toArray) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) val TASK_INDEX = 0 @@ -294,7 +309,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listener = new SaveTaskEvents sc.addSparkListener(listener) - // Make a task whose result is larger than the akka frame size + // Make a task whose result is larger than the RPC message size val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } assert(result === 2) @@ -362,13 +377,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("registering listeners via spark.extraListeners") { + val listeners = Seq( + classOf[ListenerThatAcceptsSparkConf], + classOf[FirehoseListenerThatAcceptsSparkConf], + classOf[BasicJobCounter]) val conf = new SparkConf().setMaster("local").setAppName("test") - .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + - classOf[BasicJobCounter].getName) + .set("spark.extraListeners", listeners.map(_.getName).mkString(",")) sc = new SparkContext(conf) sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) sc.listenerBus.listeners.asScala .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) + sc.listenerBus.listeners.asScala + .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } /** @@ -442,7 +462,30 @@ private class BasicJobCounter extends SparkListener { override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } +/** + * A simple listener that tries to stop SparkContext. + */ +private class SparkContextStoppingListener(val sc: SparkContext) extends SparkListener { + @volatile var sparkExSeen = false + override def onJobEnd(job: SparkListenerJobEnd): Unit = { + try { + sc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +} + private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } + +private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkFirehoseListener { + var count = 0 + override def onEvent(event: SparkListenerEvent): Unit = event match { + case job: SparkListenerJobEnd => count += 1 + case _ => + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 450ab7b9fe92b..86911d2211a3a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -17,16 +17,19 @@ package org.apache.spark.scheduler -import org.mockito.Mockito._ -import org.mockito.Matchers.any +import java.util.Properties +import org.mockito.Matchers.any +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.executor.{Executor, TaskMetricsSuite} +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource - +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -57,15 +60,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties) intercept[RuntimeException] { task.run(0, 0, null) } assert(TaskContextSuite.completed === true) } + test("calls TaskFailureListeners after failure") { + TaskContextSuite.lastError = null + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc, List()) { + override def getPartitions = Array[Partition](StubPartition(0)) + override def compute(split: Partition, context: TaskContext) = { + context.addTaskFailureListener((context, error) => TaskContextSuite.lastError = error) + sys.error("damn error") + } + } + val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + val func = (c: TaskContext, i: Iterator[String]) => i.next() + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties) + intercept[RuntimeException] { + task.run(0, 0, null) + } + assert(TaskContextSuite.lastError.getMessage == "damn error") + } + test("all TaskCompletionListeners should be called even if some fail") { val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) @@ -80,6 +104,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark verify(listener, times(1)).onTaskCompletion(any()) } + test("all TaskFailureListeners should be called even if some fail") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskFailureListener]) + context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1")) + context.addTaskFailureListener(listener) + context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3")) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskFailed(new Exception("exception in task")) + } + + // Make sure listener 2 was called. + verify(listener, times(1)).onTaskFailure(any(), any()) + + // also need to check failure in TaskFailureListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") { sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks // Check that attemptIds are 0 for all tasks' initial attempts @@ -99,17 +143,74 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } - test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") { + test("accumulators are updated on exception failures") { + // This means use 1 core and 4 max task failures + sc = new SparkContext("local[1,4]", "test") + val param = AccumulatorParam.LongAccumulatorParam + // Create 2 accumulators, one that counts failed values and another that doesn't + val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true) + val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false) + // Fail first 3 attempts of every task. This means each task should be run 4 times. + sc.parallelize(1 to 10, 10).map { i => + acc1 += 1 + acc2 += 1 + if (TaskContext.get.attemptNumber() <= 2) { + throw new Exception("you did something wrong") + } else { + 0 + } + }.count() + // The one that counts failed values should be 4x the one that didn't, + // since we ran each task 4 times + assert(Accumulators.get(acc1.id).get.value === 40L) + assert(Accumulators.get(acc2.id).get.value === 10L) + } + + test("failed tasks collect only accumulators whose values count during failures") { sc = new SparkContext("local", "test") - val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter => - Seq(TaskContext.get().attemptId).iterator - }.collect() - assert(attemptIds.toSet === Set(0, 1, 2, 3)) + val param = AccumulatorParam.LongAccumulatorParam + val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true) + val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false) + val initialAccums = InternalAccumulator.createAll() + // Create a dummy task. We won't end up running this; we just want to collect + // accumulator updates from it. + val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]], new Properties) { + context = new TaskContextImpl(0, 0, 0L, 0, + new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), + new Properties, + SparkEnv.get.metricsSystem, + initialAccums) + context.taskMetrics.registerAccumulator(acc1) + context.taskMetrics.registerAccumulator(acc2) + override def runTask(tc: TaskContext): Int = 0 + } + // First, simulate task success. This should give us all the accumulators. + val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false) + val accumUpdates2 = (initialAccums ++ Seq(acc1, acc2)).map(TaskMetricsSuite.makeInfo) + TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2) + // Now, simulate task failures. This should give us only the accums that count failed values. + val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true) + val accumUpdates4 = (initialAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo) + TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4) } + + test("localProperties are propagated to executors correctly") { + sc = new SparkContext("local", "test") + sc.setLocalProperty("testPropKey", "testPropValue") + val res = sc.parallelize(Array(1), 1).map(i => i).map(i => { + val inTask = TaskContext.get().getLocalProperty("testPropKey") + val inDeser = Executor.taskDeserializationProps.get().getProperty("testPropKey") + s"$inTask,$inDeser" + }).collect() + assert(res === Array("testPropValue,testPropValue")) + } + } private object TaskContextSuite { @volatile var completed = false + + @volatile var lastError: Throwable = _ } private case class StubPartition(index: Int) extends Partition diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 815caa79ff529..b5385c11a926e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,17 +17,27 @@ package org.apache.spark.scheduler +import java.io.File +import java.net.URL import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal +import com.google.common.util.concurrent.MoreExecutors +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{any, anyLong} +import org.mockito.Mockito.{spy, times, verify} import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId +import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils} + /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -35,7 +45,7 @@ import org.apache.spark.storage.TaskResultBlockId * Used to test the case where a BlockManager evicts the task result (or dies) before the * TaskResult is retrieved. */ -class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) +private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends TaskResultGetter(sparkEnv, scheduler) { var removedResult = false @@ -68,27 +78,52 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule } } + +/** + * A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors + * _before_ modifying the results in any way. + */ +private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl) + extends TaskResultGetter(env, scheduler) { + + // Use the current thread so we can access its results synchronously + protected override val getTaskResultExecutor = MoreExecutors.sameThreadExecutor() + + // DirectTaskResults that we receive from the executors + private val _taskResults = new ArrayBuffer[DirectTaskResult[_]] + + def taskResults: Seq[DirectTaskResult[_]] = _taskResults + + override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = { + // work on a copy since the super class still needs to use the buffer + val newBuffer = data.duplicate() + _taskResults += env.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](newBuffer) + super.enqueueSuccessfulTask(tsm, tid, data) + } +} + + /** * Tests related to handling task results (both direct and indirect). */ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { - // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small + // Set the RPC message size to be as small as possible (it must be an integer, so 1 is as small // as we can make it) so the tests don't take too long. - def conf: SparkConf = new SparkConf().set("spark.akka.frameSize", "1") + def conf: SparkConf = new SparkConf().set("spark.rpc.message.maxSize", "1") - test("handling results smaller than Akka frame size") { + test("handling results smaller than max RPC message size") { sc = new SparkContext("local", "test", conf) val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) assert(result === 2) } - test("handling results larger than Akka frame size") { + test("handling results larger than max RPC message size") { sc = new SparkContext("local", "test", conf) - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt - val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) - assert(result === 1.to(akkaFrameSize).toArray) + val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + val result = + sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x) + assert(result === 1.to(maxRpcMessageSize).toArray) val RESULT_BLOCK_ID = TaskResultBlockId(0) assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0, @@ -110,14 +145,107 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local } val resultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) scheduler.taskResultGetter = resultGetter - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt - val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) + val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + val result = + sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x) assert(resultGetter.removeBlockSuccessfully) - assert(result === 1.to(akkaFrameSize).toArray) + assert(result === 1.to(maxRpcMessageSize).toArray) // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) } + + /** + * Make sure we are using the context classloader when deserializing failed TaskResults instead + * of the Spark classloader. + + * This test compiles a jar containing an exception and tests that when it is thrown on the + * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown + * exception as the cause. + + * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing + * the exception, resulting in an UnknownReason for the TaskEndResult. + */ + test("failed task deserialized with the correct classloader (SPARK-11195)") { + // compile a small jar containing an exception that will be thrown on an executor. + val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "repro/") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + """package repro; + | + |public class MyException extends Exception { + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro")) + + // ensure we reset the classloader after the test completes + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + // load the exception from the jar + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + loader.addURL(jarFile.toURI.toURL) + Thread.currentThread().setContextClassLoader(loader) + val excClass: Class[_] = Utils.classForName("repro.MyException") + + // NOTE: we must run the cluster with "local" so that the executor can load the compiled + // jar. + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + val exc = excClass.newInstance().asInstanceOf[Exception] + throw exc + } + + // the driver should not have any problems resolving the exception class and determining + // why the task failed. + val exceptionMessage = intercept[SparkException] { + rdd.collect() + }.getMessage + + val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r + val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + + assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) + assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } + + test("task result size is set on the driver, not the executors") { + import InternalAccumulator._ + + // Set up custom TaskResultGetter and TaskSchedulerImpl spy + sc = new SparkContext("local", "test", conf) + val scheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + val spyScheduler = spy(scheduler) + val resultGetter = new MyTaskResultGetter(sc.env, spyScheduler) + val newDAGScheduler = new DAGScheduler(sc, spyScheduler) + scheduler.taskResultGetter = resultGetter + sc.dagScheduler = newDAGScheduler + sc.taskScheduler = spyScheduler + sc.taskScheduler.setDAGScheduler(newDAGScheduler) + + // Just run 1 task and capture the corresponding DirectTaskResult + sc.parallelize(1 to 1, 1).count() + val captor = ArgumentCaptor.forClass(classOf[DirectTaskResult[_]]) + verify(spyScheduler, times(1)).handleSuccessfulTask(any(), anyLong(), captor.capture()) + + // When a task finishes, the executor sends a serialized DirectTaskResult to the driver + // without setting the result size so as to avoid serializing the result again. Instead, + // the result size is set later in TaskResultGetter on the driver before passing the + // DirectTaskResult on to TaskSchedulerImpl. In this test, we capture the DirectTaskResult + // before and after the result size is set. + assert(resultGetter.taskResults.size === 1) + val resBefore = resultGetter.taskResults.head + val resAfter = captor.getValue + val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update) + val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update) + assert(resSizeBefore.exists(_ == 0L)) + assert(resSizeAfter.exists(_.toString.toLong > 0L)) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 2afb595e6f10d..a09a602d1368d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark._ +import org.apache.spark.internal.Logging class FakeSchedulerBackend extends SchedulerBackend { def start() {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ecc18fc6e15b4..ade8e84d848f0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.scheduler -import java.util.Random +import java.util.{Properties, Random} import scala.collection.Map import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) @@ -38,9 +38,8 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo) { taskScheduler.endedTasks(taskInfo.index) = reason } @@ -139,7 +138,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. */ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) @@ -167,14 +166,15 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => a.toInfo(Some(0L), None) } // Offer a host with NO_PREF as the constraint, // we should get a nopref task immediately since that's what we only have - var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) // Tell it the task has finished - manager.handleSuccessfulTask(0, createTaskResult(0)) + manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates)) assert(sched.endedTasks(0) === Success) assert(sched.finishedManagers.contains(manager)) } @@ -184,10 +184,13 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task => + task.initialAccumulators.map { a => a.toInfo(Some(0L), None) } + } // First three offers should all find tasks for (i <- 0 until 3) { - var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) val task = taskOption.get assert(task.executorId === "exec1") @@ -198,14 +201,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.resourceOffer("exec1", "host1", NO_PREF) === None) // Finish the first two tasks - manager.handleSuccessfulTask(0, createTaskResult(0)) - manager.handleSuccessfulTask(1, createTaskResult(1)) + manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdatesByTask(0))) + manager.handleSuccessfulTask(1, createTaskResult(1, accumUpdatesByTask(1))) assert(sched.endedTasks(0) === Success) assert(sched.endedTasks(1) === Success) assert(!sched.finishedManagers.contains(manager)) // Finish the last task - manager.handleSuccessfulTask(2, createTaskResult(2)) + manager.handleSuccessfulTask(2, createTaskResult(2, accumUpdatesByTask(2))) assert(sched.endedTasks(2) === Success) assert(sched.finishedManagers.contains(manager)) } @@ -394,7 +397,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val rescheduleDelay = 300L val conf = new SparkConf(). set("spark.scheduler.executorTaskBlacklistTime", rescheduleDelay.toString). - // dont wait to jump locality levels in this test + // don't wait to jump locality levels in this test set("spark.locality.wait", "0") sc = new SparkContext("local", "test", conf) @@ -620,7 +623,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // multiple 1k result val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect() - assert(10 === r.size ) + assert(10 === r.size) // single 10M result val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()} @@ -761,7 +764,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Regression test for SPARK-2931 sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, - ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), @@ -786,8 +789,10 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3")) } - def createTaskResult(id: Int): DirectTaskResult[Int] = { + private def createTaskResult( + id: Int, + accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() - new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) + new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 525ee0d3bdc5a..b18f0eb162b1d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -17,49 +17,206 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util import java.util.Collections -import org.apache.mesos.Protos.Value.Scalar -import org.apache.mesos.Protos._ +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.mockito.Matchers import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter +import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SecurityManager, SparkFunSuite} class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with BeforeAndAfter { - private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { - val builder = Offer.newBuilder() - builder.addResourcesBuilder() - .setName("mem") - .setType(Value.Type.SCALAR) - .setScalar(Scalar.newBuilder().setValue(mem)) - builder.addResourcesBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Scalar.newBuilder().setValue(cpu)) - builder.setId(OfferID.newBuilder() - .setValue(offerId).build()) - .setFrameworkId(FrameworkID.newBuilder() - .setValue("f1")) - .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) - .setHostname(s"host${slaveId}") - .build() + private var sparkConf: SparkConf = _ + private var driver: SchedulerDriver = _ + private var taskScheduler: TaskSchedulerImpl = _ + private var backend: CoarseMesosSchedulerBackend = _ + private var externalShuffleClient: MesosExternalShuffleClient = _ + private var driverEndpoint: RpcEndpointRef = _ + + test("mesos supports killing and limiting executors") { + setBackend() + sparkConf.set("spark.driver.host", "driverHost") + sparkConf.set("spark.driver.port", "1234") + + val minMem = backend.executorMemory(sc) + val minCpu = 4 + val offers = List((minMem, minCpu)) + + // launches a task on a valid offer + offerResources(offers) + verifyTaskLaunched("o1") + + // kills executors + backend.doRequestTotalExecutors(0) + assert(backend.doKillExecutors(Seq("0"))) + val taskID0 = createTaskId("0") + verify(driver, times(1)).killTask(taskID0) + + // doesn't launch a new task when requested executors == 0 + offerResources(offers, 2) + verifyDeclinedOffer(driver, createOfferId("o2")) + + // Launches a new task when requested executors is positive + backend.doRequestTotalExecutors(2) + offerResources(offers, 2) + verifyTaskLaunched("o2") } - private def createSchedulerBackend( - taskScheduler: TaskSchedulerImpl, - driver: SchedulerDriver): CoarseMesosSchedulerBackend = { + test("mesos supports killing and relaunching tasks with executors") { + setBackend() + + // launches a task on a valid offer + val minMem = backend.executorMemory(sc) + 1024 + val minCpu = 4 + val offer1 = (minMem, minCpu) + val offer2 = (minMem, 1) + offerResources(List(offer1, offer2)) + verifyTaskLaunched("o1") + + // accounts for a killed task + val status = createTaskStatus("0", "s1", TaskState.TASK_KILLED) + backend.statusUpdate(driver, status) + verify(driver, times(1)).reviveOffers() + + // Launches a new task on a valid offer from the same slave + offerResources(List(offer2)) + verifyTaskLaunched("o2") + } + + test("mesos supports spark.executor.cores") { + val executorCores = 4 + setBackend(Map("spark.executor.cores" -> executorCores.toString)) + + val executorMemory = backend.executorMemory(sc) + val offers = List((executorMemory * 2, executorCores + 1)) + offerResources(offers) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == executorCores) + } + + test("mesos supports unset spark.executor.cores") { + setBackend() + + val executorMemory = backend.executorMemory(sc) + val offerCores = 10 + offerResources(List((executorMemory * 2, offerCores))) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == offerCores) + } + + test("mesos does not acquire more than spark.cores.max") { + val maxCores = 10 + setBackend(Map("spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List((executorMemory, maxCores + 1))) + + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 1) + + val cpus = backend.getResource(taskInfos.iterator().next().getResourcesList, "cpus") + assert(cpus == maxCores) + } + + test("mesos declines offers that violate attribute constraints") { + setBackend(Map("spark.mesos.constraints" -> "x:true")) + offerResources(List((backend.executorMemory(sc), 4))) + verifyDeclinedOffer(driver, createOfferId("o1"), true) + } + + test("mesos assigns tasks round-robin on offers") { + val executorCores = 4 + val maxCores = executorCores * 2 + setBackend(Map("spark.executor.cores" -> executorCores.toString, + "spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List( + (executorMemory * 2, executorCores * 2), + (executorMemory * 2, executorCores * 2))) + + verifyTaskLaunched("o1") + verifyTaskLaunched("o2") + } + + test("mesos creates multiple executors on a single slave") { + val executorCores = 4 + setBackend(Map("spark.executor.cores" -> executorCores.toString)) + + // offer with room for two executors + val executorMemory = backend.executorMemory(sc) + offerResources(List((executorMemory * 2, executorCores * 2))) + + // verify two executors were started on a single offer + val taskInfos = verifyTaskLaunched("o1") + assert(taskInfos.size() == 2) + } + + test("mesos doesn't register twice with the same shuffle service") { + setBackend(Map("spark.shuffle.service.enabled" -> "true")) + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + verifyTaskLaunched("o1") + + val offer2 = createOffer("o2", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer2).asJava) + verifyTaskLaunched("o2") + + val status1 = createTaskStatus("0", "s1", TaskState.TASK_RUNNING) + backend.statusUpdate(driver, status1) + + val status2 = createTaskStatus("1", "s1", TaskState.TASK_RUNNING) + backend.statusUpdate(driver, status2) + verify(externalShuffleClient, times(1)) + .registerDriverWithShuffleService(anyString, anyInt, anyLong, anyLong) + } + + test("mesos kills an executor when told") { + setBackend() + + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + verifyTaskLaunched("o1") + + backend.doKillExecutors(List("0")) + verify(driver, times(1)).killTask(createTaskId("0")) + } + + test("weburi is set in created scheduler driver") { + setBackend() + val taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val securityManager = mock[SecurityManager] + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( masterUrl: String, @@ -70,118 +227,142 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite webuiUrl: Option[String] = None, checkpoint: Option[Boolean] = None, failoverTimeout: Option[Double] = None, - frameworkId: Option[String] = None): SchedulerDriver = driver - markRegistered() + frameworkId: Option[String] = None): SchedulerDriver = { + markRegistered() + assert(webuiUrl.isDefined) + assert(webuiUrl.get.equals("http://webui")) + driver + } } + backend.start() - backend } - var sparkConf: SparkConf = _ - - before { - sparkConf = (new SparkConf) - .setMaster("local[*]") - .setAppName("test-mesos-dynamic-alloc") - .setSparkHome("/path") - - sc = new SparkContext(sparkConf) + private def verifyDeclinedOffer(driver: SchedulerDriver, + offerId: OfferID, + filter: Boolean = false): Unit = { + if (filter) { + verify(driver, times(1)).declineOffer(Matchers.eq(offerId), anyObject[Filters]) + } else { + verify(driver, times(1)).declineOffer(Matchers.eq(offerId)) + } } - test("mesos supports killing and limiting executors") { - val driver = mock[SchedulerDriver] - when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] - when(taskScheduler.sc).thenReturn(sc) + private def offerResources(offers: List[(Int, Int)], startId: Int = 1): Unit = { + val mesosOffers = offers.zipWithIndex.map {case (offer, i) => + createOffer(s"o${i + startId}", s"s${i + startId}", offer._1, offer._2)} - sparkConf.set("spark.driver.host", "driverHost") - sparkConf.set("spark.driver.port", "1234") - - val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc) - val minCpu = 4 - - val mesosOffers = new java.util.ArrayList[Offer] - mesosOffers.add(createOffer("o1", "s1", minMem, minCpu)) - - val taskID0 = TaskID.newBuilder().setValue("0").build() + backend.resourceOffers(driver, mesosOffers.asJava) + } - backend.resourceOffers(driver, mesosOffers) + private def verifyTaskLaunched(offerId: String): java.util.Collection[TaskInfo] = { + val captor = ArgumentCaptor.forClass(classOf[java.util.Collection[TaskInfo]]) verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), - any[util.Collection[TaskInfo]], - any[Filters]) - - // simulate the allocation manager down-scaling executors - backend.doRequestTotalExecutors(0) - assert(backend.doKillExecutors(Seq("s1/0"))) - verify(driver, times(1)).killTask(taskID0) + Matchers.eq(Collections.singleton(createOfferId(offerId))), + captor.capture()) + captor.getValue + } - val mesosOffers2 = new java.util.ArrayList[Offer] - mesosOffers2.add(createOffer("o2", "s2", minMem, minCpu)) - backend.resourceOffers(driver, mesosOffers2) + private def createTaskStatus(taskId: String, slaveId: String, state: TaskState): TaskStatus = { + TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) + .setState(state) + .build + } - verify(driver, times(1)) - .declineOffer(OfferID.newBuilder().setValue("o2").build()) - // Verify we didn't launch any new executor - assert(backend.slaveIdsWithExecutors.size === 1) + private def createOfferId(offerId: String): OfferID = { + OfferID.newBuilder().setValue(offerId).build() + } - backend.doRequestTotalExecutors(2) - backend.resourceOffers(driver, mesosOffers2) - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(mesosOffers2.get(0).getId)), - any[util.Collection[TaskInfo]], - any[Filters]) + private def createSlaveId(slaveId: String): SlaveID = { + SlaveID.newBuilder().setValue(slaveId).build() + } - assert(backend.slaveIdsWithExecutors.size === 2) - backend.slaveLost(driver, SlaveID.newBuilder().setValue("s1").build()) - assert(backend.slaveIdsWithExecutors.size === 1) + private def createExecutorId(executorId: String): ExecutorID = { + ExecutorID.newBuilder().setValue(executorId).build() } - test("mesos supports killing and relaunching tasks with executors") { - val driver = mock[SchedulerDriver] - when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val taskScheduler = mock[TaskSchedulerImpl] - when(taskScheduler.sc).thenReturn(sc) + private def createTaskId(taskId: String): TaskID = { + TaskID.newBuilder().setValue(taskId).build() + } - val backend = createSchedulerBackend(taskScheduler, driver) - val minMem = backend.calculateTotalMemory(sc) + 1024 - val minCpu = 4 + private def createOffer(offerId: String, slaveId: String, mem: Int, cpu: Int): Offer = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(createOfferId(offerId)) + .setFrameworkId(FrameworkID.newBuilder() + .setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) + .setHostname(s"host${slaveId}") + .build() + } - val mesosOffers = new java.util.ArrayList[Offer] - val offer1 = createOffer("o1", "s1", minMem, minCpu) - mesosOffers.add(offer1) + private def createSchedulerBackend( + taskScheduler: TaskSchedulerImpl, + driver: SchedulerDriver, + shuffleClient: MesosExternalShuffleClient, + endpoint: RpcEndpointRef): CoarseMesosSchedulerBackend = { + val securityManager = mock[SecurityManager] - val offer2 = createOffer("o2", "s1", minMem, 1); + val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { + override protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = driver + + override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient + + override protected def createDriverEndpointRef( + properties: ArrayBuffer[(String, String)]): RpcEndpointRef = endpoint + + // override to avoid race condition with the driver thread on `mesosDriver` + override def startScheduler(newDriver: SchedulerDriver): Unit = { + mesosDriver = newDriver + } - backend.resourceOffers(driver, mesosOffers) + markRegistered() + } + backend.start() + backend + } - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer1.getId)), - anyObject(), - anyObject[Filters]) - - // Simulate task killed, executor no longer running - val status = TaskStatus.newBuilder() - .setTaskId(TaskID.newBuilder().setValue("0").build()) - .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) - .setState(TaskState.TASK_KILLED) - .build + private def setBackend(sparkConfVars: Map[String, String] = null) { + sparkConf = (new SparkConf) + .setMaster("local[*]") + .setAppName("test-mesos-dynamic-alloc") + .setSparkHome("/path") + .set("spark.mesos.driver.webui.url", "http://webui") - backend.statusUpdate(driver, status) - assert(!backend.slaveIdsWithExecutors.contains("s1")) + if (sparkConfVars != null) { + for (attr <- sparkConfVars) { + sparkConf.set(attr._1, attr._2) + } + } - mesosOffers.clear() - mesosOffers.add(offer2) - backend.resourceOffers(driver, mesosOffers) - assert(backend.slaveIdsWithExecutors.contains("s1")) + sc = new SparkContext(sparkConf) - verify(driver, times(1)).launchTasks( - Matchers.eq(Collections.singleton(offer2.getId)), - anyObject(), - anyObject[Filters]) + driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.sc).thenReturn(sc) + externalShuffleClient = mock[MesosExternalShuffleClient] + driverEndpoint = mock[RpcEndpointRef] - verify(driver, times(1)).reviveOffers() + backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient, driverEndpoint) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala new file mode 100644 index 0000000000000..a32423dc4fdeb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util.{Collection, Collections, Date} + +import scala.collection.JavaConverters._ + +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.{Scalar, Type} +import org.apache.mesos.SchedulerDriver +import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription + + +class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + + private val command = new Command("mainClass", Seq("arg"), Map(), Seq(), Seq(), Seq()) + private var scheduler: MesosClusterScheduler = _ + + override def beforeEach(): Unit = { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + } + + test("can queue drivers") { + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val response2 = + scheduler.submitDriver(new MesosDriverDescription( + "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + assert(response2.success) + val state = scheduler.getSchedulerState() + val queuedDrivers = state.queuedDrivers.toList + assert(queuedDrivers(0).submissionId == response.submissionId) + assert(queuedDrivers(1).submissionId == response2.submissionId) + } + + test("can kill queued drivers") { + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val killResponse = scheduler.killDriver(response.submissionId) + assert(killResponse.success) + val state = scheduler.getSchedulerState() + assert(state.queuedDrivers.isEmpty) + } + + test("can handle multiple roles") { + val driver = mock[SchedulerDriver] + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1200, 1.5, true, + command, + Map(("spark.mesos.executor.home", "test"), ("spark.app.name", "test")), + "s1", + new Date())) + assert(response.success) + val offer = Offer.newBuilder() + .addResources( + Resource.newBuilder().setRole("*") + .setScalar(Scalar.newBuilder().setValue(1).build()).setName("cpus").setType(Type.SCALAR)) + .addResources( + Resource.newBuilder().setRole("*") + .setScalar(Scalar.newBuilder().setValue(1000).build()) + .setName("mem") + .setType(Type.SCALAR)) + .addResources( + Resource.newBuilder().setRole("role2") + .setScalar(Scalar.newBuilder().setValue(1).build()).setName("cpus").setType(Type.SCALAR)) + .addResources( + Resource.newBuilder().setRole("role2") + .setScalar(Scalar.newBuilder().setValue(500).build()).setName("mem").setType(Type.SCALAR)) + .setId(OfferID.newBuilder().setValue("o1").build()) + .setFrameworkId(FrameworkID.newBuilder().setValue("f1").build()) + .setSlaveId(SlaveID.newBuilder().setValue("s1").build()) + .setHostname("host1") + .build() + + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) + + when( + driver.launchTasks( + Matchers.eq(Collections.singleton(offer.getId)), + capture.capture()) + ).thenReturn(Status.valueOf(1)) + + scheduler.resourceOffers(driver, Collections.singletonList(offer)) + + val taskInfos = capture.getValue + assert(taskInfos.size() == 1) + val taskInfo = taskInfos.iterator().next() + val resources = taskInfo.getResourcesList + assert(scheduler.getResource(resources, "cpus") == 1.5) + assert(scheduler.getResource(resources, "mem") == 1200) + val resourcesSeq: Seq[Resource] = resources.asScala + val cpus = resourcesSeq.filter(_.getName.equals("cpus")).toList + assert(cpus.size == 2) + assert(cpus.exists(_.getRole().equals("role2"))) + assert(cpus.exists(_.getRole().equals("*"))) + val mem = resourcesSeq.filter(_.getName.equals("mem")).toList + assert(mem.size == 2) + assert(mem.exists(_.getRole().equals("role2"))) + assert(mem.exists(_.getRole().equals("*"))) + + verify(driver, times(1)).launchTasks( + Matchers.eq(Collections.singleton(offer.getId)), + capture.capture() + ) + } + + test("escapes commandline args for the shell") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + val escape = scheduler.shellEscape _ + def wrapped(str: String): String = "\"" + str + "\"" + + // Wrapped in quotes + assert(escape("'should be left untouched'") === "'should be left untouched'") + assert(escape("\"should be left untouched\"") === "\"should be left untouched\"") + + // Harmless + assert(escape("") === "") + assert(escape("harmless") === "harmless") + assert(escape("har-m.l3ss") === "har-m.l3ss") + + // Special Chars escape + assert(escape("should escape this \" quote") === wrapped("should escape this \\\" quote")) + assert(escape("shouldescape\"quote") === wrapped("shouldescape\\\"quote")) + assert(escape("should escape this $ dollar") === wrapped("should escape this \\$ dollar")) + assert(escape("should escape this ` backtick") === wrapped("should escape this \\` backtick")) + assert(escape("""should escape this \ backslash""") + === wrapped("""should escape this \\ backslash""")) + assert(escape("""\"?""") === wrapped("""\\\"?""")) + + + // Special Chars no escape only wrap + List(" ", "'", "<", ">", "&", "|", "?", "*", ";", "!", "#", "(", ")").foreach(char => { + assert(escape(s"onlywrap${char}this") === wrapped(s"onlywrap${char}this")) + }) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index c4dc560031207..7d6b7bde68253 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -26,22 +26,57 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.mesos.Protos.Value.Scalar +import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ -import org.apache.mesos.SchedulerDriver +import org.apache.mesos.Protos.Value.Scalar +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.mockito.{ArgumentCaptor, Matchers} import org.scalatest.mock.MockitoSugar +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, TaskDescription, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.cluster.ExecutorInfo class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + test("weburi is set in created scheduler driver") { + val conf = new SparkConf + conf.set("spark.mesos.driver.webui.url", "http://webui") + conf.set("spark.app.name", "name1") + + val sc = mock[SparkContext] + when(sc.conf).thenReturn(conf) + when(sc.sparkUser).thenReturn("sparkUser1") + when(sc.appName).thenReturn("appName1") + + val taskScheduler = mock[TaskSchedulerImpl] + val driver = mock[SchedulerDriver] + when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) + + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") { + override protected def createSchedulerDriver( + masterUrl: String, + scheduler: Scheduler, + sparkUser: String, + appName: String, + conf: SparkConf, + webuiUrl: Option[String] = None, + checkpoint: Option[Boolean] = None, + failoverTimeout: Option[Double] = None, + frameworkId: Option[String] = None): SchedulerDriver = { + markRegistered() + assert(webuiUrl.isDefined) + assert(webuiUrl.get.equals("http://webui")) + driver + } + } + + backend.start() + } + test("Use configured mesosExecutor.cores for ExecutorInfo") { val mesosExecutorCores = 3 val conf = new SparkConf @@ -76,7 +111,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi test("check spark-class location correctly") { val conf = new SparkConf - conf.set("spark.mesos.executor.home" , "/mesos-home") + conf.set("spark.mesos.executor.home", "/mesos-home") val listenerBus = mock[LiveListenerBus] listenerBus.post( @@ -189,7 +224,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val minMem = backend.calculateTotalMemory(sc) + val minMem = backend.executorMemory(sc) val minCpu = 4 val mesosOffers = new java.util.ArrayList[Offer] diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 2eb43b7313381..ceb3a52983cd8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -41,28 +41,28 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("use at-least minimum overhead") { val f = fixture when(f.sc.executorMemory).thenReturn(512) - utils.calculateTotalMemory(f.sc) shouldBe 896 + utils.executorMemory(f.sc) shouldBe 896 } test("use overhead if it is greater than minimum value") { val f = fixture when(f.sc.executorMemory).thenReturn(4096) - utils.calculateTotalMemory(f.sc) shouldBe 4505 + utils.executorMemory(f.sc) shouldBe 4505 } test("use spark.mesos.executor.memoryOverhead (if set)") { val f = fixture when(f.sc.executorMemory).thenReturn(1024) f.sparkConf.set("spark.mesos.executor.memoryOverhead", "512") - utils.calculateTotalMemory(f.sc) shouldBe 1536 + utils.executorMemory(f.sc) shouldBe 1536 } test("parse a non-empty constraint string correctly") { val expectedMap = Map( - "tachyon" -> Set("true"), + "os" -> Set("centos7"), "zone" -> Set("us-east-1a", "us-east-1b") ) - utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap) + utils.parseConstraintString("os:centos7;zone:us-east-1a,us-east-1b") should be (expectedMap) } test("parse an empty constraint string correctly") { @@ -71,35 +71,35 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("throw an exception when the input is malformed") { an[IllegalArgumentException] should be thrownBy - utils.parseConstraintString("tachyon;zone:us-east") + utils.parseConstraintString("os;zone:us-east") } test("empty values for attributes' constraints matches all values") { - val constraintsStr = "tachyon:" + val constraintsStr = "os:" val parsedConstraints = utils.parseConstraintString(constraintsStr) - parsedConstraints shouldBe Map("tachyon" -> Set()) + parsedConstraints shouldBe Map("os" -> Set()) val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build() - val noTachyonOffer = Map("zone" -> zoneSet) - val tachyonTrueOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) - val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build()) + val noOsOffer = Map("zone" -> zoneSet) + val centosOffer = Map("os" -> Value.Text.newBuilder().setValue("centos").build()) + val ubuntuOffer = Map("os" -> Value.Text.newBuilder().setValue("ubuntu").build()) - utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false - utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true - utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, noOsOffer) shouldBe false + utils.matchesAttributeRequirements(parsedConstraints, centosOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, ubuntuOffer) shouldBe true } test("subset match is performed for set attributes") { val supersetConstraint = Map( - "tachyon" -> Value.Text.newBuilder().setValue("true").build(), + "os" -> Value.Text.newBuilder().setValue("ubuntu").build(), "zone" -> Value.Set.newBuilder() .addItem("us-east-1a") .addItem("us-east-1b") .addItem("us-east-1c") .build()) - val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c" + val zoneConstraintStr = "os:;zone:us-east-1a,us-east-1c" val parsedConstraints = utils.parseConstraintString(zoneConstraintStr) utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true @@ -131,10 +131,10 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS } test("equality match is performed for text attributes") { - val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + val offerAttribs = Map("os" -> Value.Text.newBuilder().setValue("centos7").build()) - val trueConstraint = utils.parseConstraintString("tachyon:true") - val falseConstraint = utils.parseConstraintString("tachyon:false") + val trueConstraint = utils.parseConstraintString("os:centos7") + val falseConstraint = utils.parseConstraintString("os:ubuntu") utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala deleted file mode 100644 index f5cef1caaf1ac..0000000000000 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler.mesos - -import java.util.Date - -import org.scalatest.mock.MockitoSugar - -import org.apache.spark.deploy.Command -import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} - - -class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { - - private val command = new Command("mainClass", Seq("arg"), null, null, null, null) - - test("can queue drivers") { - val conf = new SparkConf() - conf.setMaster("mesos://localhost:5050") - conf.setAppName("spark mesos") - val scheduler = new MesosClusterScheduler( - new BlackHoleMesosClusterPersistenceEngineFactory, conf) { - override def start(): Unit = { ready = true } - } - scheduler.start() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) - assert(response.success) - val response2 = - scheduler.submitDriver(new MesosDriverDescription( - "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) - assert(response2.success) - val state = scheduler.getSchedulerState() - val queuedDrivers = state.queuedDrivers.toList - assert(queuedDrivers(0).submissionId == response.submissionId) - assert(queuedDrivers(1).submissionId == response2.submissionId) - } - - test("can kill queued drivers") { - val conf = new SparkConf() - conf.setMaster("mesos://localhost:5050") - conf.setAppName("spark mesos") - val scheduler = new MesosClusterScheduler( - new BlackHoleMesosClusterPersistenceEngineFactory, conf) { - override def start(): Unit = { ready = true } - } - scheduler.start() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) - assert(response.success) - val killResponse = scheduler.killDriver(response.submissionId) - assert(killResponse.success) - val state = scheduler.getSchedulerState() - assert(state.queuedDrivers.isEmpty) - } -} diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index 87f25e7245e1f..3734f1cb408fe 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.ByteBuffer -import com.esotericsoftware.kryo.io.{Output, Input} -import org.apache.avro.{SchemaBuilder, Schema} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.generic.GenericData.Record -import org.apache.spark.{SparkFunSuite, SharedSparkContext} +import org.apache.spark.{SharedSparkContext, SparkFunSuite} class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index 20f45670bc2ba..6a6ea42797fb6 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -23,13 +23,18 @@ class JavaSerializerSuite extends SparkFunSuite { test("JavaSerializer instances are serializable") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() - instance.deserialize[JavaSerializer](instance.serialize(serializer)) + val obj = instance.deserialize[JavaSerializer](instance.serialize(serializer)) + // enforce class cast + obj.getClass } test("Deserialize object containing a primitive Class as attribute") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() - instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass())) + val obj = instance.deserialize[ContainsPrimitiveClass](instance.serialize( + new ContainsPrimitiveClass())) + // enforce class cast + obj.getClass } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 935a091f14f9b..c1484b0afa85f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.serializer -import org.apache.spark.util.Utils - import com.esotericsoftware.kryo.Kryo import org.apache.spark._ import org.apache.spark.serializer.KryoDistributedTest._ +import org.apache.spark.util.Utils -class KryoSerializerDistributedSuite extends SparkFunSuite { +class KryoSerializerDistributedSuite extends SparkFunSuite with LocalSparkContext { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) @@ -35,7 +34,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) - val sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val original = Thread.currentThread.getContextClassLoader val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) SparkEnv.get.serializer.setDefaultClassLoader(loader) @@ -48,8 +47,6 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { // Join the two RDDs, and force evaluation assert(shuffledRDD.join(cachedRDD).collect().size == 1) - - LocalSparkContext.stop(sc) } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index a9b209ccfc76e..21251f0b93760 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.SparkContext import org.apache.spark.LocalSparkContext +import org.apache.spark.SparkContext import org.apache.spark.SparkException - class KryoSerializerResizableOutputSuite extends SparkFunSuite { // trial and error showed this will not serialize with 1mb buffer diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index afe2e80358ca0..27d063630be9d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,18 +17,21 @@ package org.apache.spark.serializer -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") @@ -144,10 +147,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("Bug: SPARK-10251") { @@ -174,10 +177,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("ranges") { @@ -279,8 +282,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { test("kryo with fold") { val control = 1 :: 2 :: Nil // zeroValue must not be a ClassWithoutNoArgConstructor instance because it will be - // serialized by spark.closure.serializer but spark.closure.serializer only supports - // the default Java serializer. + // serialized by the Java serializer. val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) .fold(null)((t1, t2) => { val t1x = if (t1 == null) 0 else t1.x @@ -322,6 +324,12 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val conf = new SparkConf(false) conf.set("spark.kryo.registrationRequired", "true") + // these cases require knowing the internals of RoaringBitmap a little. Blocks span 2^16 + // values, and they use a bitmap (dense) if they have more than 4096 values, and an + // array (sparse) if they use less. So we just create two cases, one sparse and one dense. + // and we use a roaring bitmap for the empty blocks, so we trigger the dense case w/ mostly + // empty blocks + val ser = new KryoSerializer(conf).newInstance() val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) @@ -344,6 +352,44 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { assert(thrown.getMessage.contains(kryoBufferMaxProperty)) } + test("SPARK-12222: deserialize RoaringBitmap throw Buffer underflow exception") { + val dir = Utils.createTempDir() + val tmpfile = dir.toString + "/RoaringBitmap" + val outStream = new FileOutputStream(tmpfile) + val output = new KryoOutput(outStream) + val bitmap = new RoaringBitmap + bitmap.add(1) + bitmap.add(3) + bitmap.add(5) + // Ignore Kryo because it doesn't use writeObject + bitmap.serialize(new KryoOutputObjectOutputBridge(null, output)) + output.flush() + output.close() + + val inStream = new FileInputStream(tmpfile) + val input = new KryoInput(inStream) + val ret = new RoaringBitmap + // Ignore Kryo because it doesn't use readObject + ret.deserialize(new KryoInputObjectInputBridge(null, input)) + input.close() + assert(ret == bitmap) + Utils.deleteRecursively(dir) + } + + test("KryoOutputObjectOutputBridge.writeObject and KryoInputObjectInputBridge.readObject") { + val kryo = new KryoSerializer(conf).newKryo() + + val bytesOutput = new ByteArrayOutputStream() + val objectOutput = new KryoOutputObjectOutputBridge(kryo, new KryoOutput(bytesOutput)) + objectOutput.writeObject("test") + objectOutput.close() + + val bytesInput = new ByteArrayInputStream(bytesOutput.toByteArray) + val objectInput = new KryoInputObjectInputBridge(kryo, new KryoInput(bytesInput)) + assert(objectInput.readObject() === "test") + objectInput.close() + } + test("getAutoReset") { val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] assert(ser.getAutoReset) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index 2d5e9d66b2e15..f019b1e25900b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.serializer import java.io._ +import scala.annotation.meta.param + import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite @@ -29,6 +31,7 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { import SerializationDebugger.find override def beforeEach(): Unit = { + super.beforeEach() SerializationDebugger.enableDebugging = true } @@ -190,7 +193,7 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { } val originalException = new NotSerializableException("someClass") - // verify thaht original exception is returned on failure + // verify that original exception is returned on failure assert(SerializationDebugger.improveException(o, originalException).eq(originalException)) } } @@ -218,7 +221,7 @@ class SerializableClassWithWriteObject(val objectField: Object) extends Serializ } -class SerializableClassWithWriteReplace(@transient replacementFieldObject: Object) +class SerializableClassWithWriteReplace(@(transient @param) replacementFieldObject: Object) extends Serializable { private def writeReplace(): Object = { replacementFieldObject diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index c1e0a29a34bb1..17037870f7a15 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -17,12 +17,11 @@ package org.apache.spark.serializer -import java.io.{EOFException, OutputStream, InputStream} +import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer import scala.reflect.ClassTag - /** * A serializer implementation that always returns two elements in a deserialization stream. */ diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 26a372d6a905d..dba1172d5fdbd 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,14 +20,11 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer -import org.mockito.Matchers.{eq => meq, _} import org.mockito.Mockito.{mock, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -77,13 +74,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // can ensure retain() and release() are properly called. val blockManager = mock(classOf[BlockManager]) - // Create a return function to use for the mocked wrapForCompression method that just returns - // the original input stream. - val dummyCompressionFunction = new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = - invocation.getArguments()(1).asInstanceOf[InputStream] - } - // Create a buffer with some randomly generated key-value pairs to use as the shuffle data // from each mappers (all mappers return the same shuffle data). val byteOutputStream = new ByteArrayOutputStream() @@ -105,9 +95,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // fetch shuffle data. val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) - when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) - .thenAnswer(dummyCompressionFunction) - managedBuffer } @@ -127,17 +114,24 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Create a mocked shuffle handle to pass into HashShuffleReader. val shuffleHandle = { val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) - when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(None) when(dependency.keyOrdering).thenReturn(None) new BaseShuffleHandle(shuffleId, numMaps, dependency) } + val serializerManager = new SerializerManager( + serializer, + new SparkConf() + .set("spark.shuffle.compress", "false") + .set("spark.shuffle.spill.compress", "false")) + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, TaskContext.empty(), + serializerManager, blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index b92a302806f76..16418f855bbe1 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -23,8 +23,8 @@ import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -32,9 +32,9 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach import org.apache.spark._ -import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} -import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} +import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -55,6 +55,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) taskMetrics = new TaskMetrics @@ -65,9 +66,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte dependency = dependency ) when(dependency.partitioner).thenReturn(new HashPartitioner(7)) - when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf))) + when(dependency.serializer).thenReturn(new JavaSerializer(conf)) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) + doAnswer(new Answer[Void] { + def answer(invocationOnMock: InvocationOnMock): Void = { + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + outputFile.delete + tmp.renameTo(outputFile) + } + null + } + }).when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(blockManager.getDiskWriter( any[BlockId], @@ -84,7 +96,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(3).asInstanceOf[Int], compressStream = identity, syncWrites = false, - args(4).asInstanceOf[ShuffleWriteMetrics] + args(4).asInstanceOf[ShuffleWriteMetrics], + blockId = args(0).asInstanceOf[BlockId] ) } }) @@ -92,7 +105,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte new Answer[(TempShuffleBlockId, File)] { override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = { val blockId = new TempShuffleBlockId(UUID.randomUUID) - val file = File.createTempFile(blockId.toString, null, tempDir) + val file = new File(tempDir, blockId.name) blockIdToFileMap.put(blockId, file) temporaryFilesCreated.append(file) (blockId, file) @@ -107,9 +120,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte } override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) - blockIdToFileMap.clear() - temporaryFilesCreated.clear() + try { + Utils.deleteRecursively(tempDir) + blockIdToFileMap.clear() + temporaryFilesCreated.clear() + } finally { + super.afterEach() + } } test("write empty iterator") { @@ -128,8 +145,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(outputFile.length() === 0) assert(temporaryFilesCreated.isEmpty) val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get - assert(shuffleWriteMetrics.shuffleBytesWritten === 0) - assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) + assert(shuffleWriteMetrics.bytesWritten === 0) + assert(shuffleWriteMetrics.recordsWritten === 0) assert(taskMetrics.diskBytesSpilled === 0) assert(taskMetrics.memoryBytesSpilled === 0) } @@ -149,14 +166,50 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ true) assert(temporaryFilesCreated.nonEmpty) assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get - assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) assert(taskMetrics.diskBytesSpilled === 0) assert(taskMetrics.memoryBytesSpilled === 0) } + test("only generate temp shuffle file for non-empty partition") { + // Using exception to test whether only non-empty partition creates temp shuffle file, + // because temp shuffle file will only be cleaned after calling stop(false) in the failure + // case, so we could use it to validate the temp shuffle files. + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ + (0 until 100000).iterator.map { i => + if (i == 99990) { + throw new SparkException("intentional failure") + } else { + (2, 2) + } + } + + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf + ) + + intercept[SparkException] { + writer.write(records) + } + + assert(temporaryFilesCreated.nonEmpty) + // Only 3 temp shuffle files will be created + assert(temporaryFilesCreated.count(_.exists()) === 3) + + writer.stop( /* success = */ false) + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + } + test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala new file mode 100644 index 0000000000000..d21ce73f4021e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{File, FileInputStream, FileOutputStream} + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage._ +import org.apache.spark.util.Utils + + +class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + + private var tempDir: File = _ + private val conf: SparkConf = new SparkConf(loadDefaults = false) + + override def beforeEach(): Unit = { + super.beforeEach() + tempDir = Utils.createTempDir() + MockitoAnnotations.initMocks(this) + + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer( + new Answer[File] { + override def answer(invocation: InvocationOnMock): File = { + new File(tempDir, invocation.getArguments.head.toString) + } + }) + } + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + test("commit shuffle files multiple times") { + val resolver = new IndexShuffleBlockResolver(conf, blockManager) + val lengths = Array[Long](10, 0, 20) + val dataTmp = File.createTempFile("shuffle", null, tempDir) + val out = new FileOutputStream(dataTmp) + Utils.tryWithSafeFinally { + out.write(new Array[Byte](30)) + } { + out.close() + } + resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + + val dataFile = resolver.getDataFile(1, 2) + assert(dataFile.exists()) + assert(dataFile.length() === 30) + assert(!dataTmp.exists()) + + val lengths2 = new Array[Long](3) + val dataTmp2 = File.createTempFile("shuffle", null, tempDir) + val out2 = new FileOutputStream(dataTmp2) + Utils.tryWithSafeFinally { + out2.write(Array[Byte](1)) + out2.write(new Array[Byte](29)) + } { + out2.close() + } + resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2) + assert(lengths2.toSeq === lengths.toSeq) + assert(dataFile.exists()) + assert(dataFile.length() === 30) + assert(!dataTmp2.exists()) + + // The dataFile should be the previous one + val firstByte = new Array[Byte](1) + val in = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + in.read(firstByte) + } { + in.close() + } + assert(firstByte(0) === 0) + + // remove data file + dataFile.delete() + + val lengths3 = Array[Long](10, 10, 15) + val dataTmp3 = File.createTempFile("shuffle", null, tempDir) + val out3 = new FileOutputStream(dataTmp3) + Utils.tryWithSafeFinally { + out3.write(Array[Byte](2)) + out3.write(new Array[Byte](34)) + } { + out3.close() + } + resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3) + assert(lengths3.toSeq != lengths.toSeq) + assert(dataFile.exists()) + assert(dataFile.length() === 35) + assert(!dataTmp2.exists()) + + // The dataFile should be the previous one + val firstByte2 = new Array[Byte](1) + val in2 = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + in2.read(firstByte2) + } { + in2.close() + } + assert(firstByte2(0) === 2) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 8744a072cb3f6..55cebe7c8b6a8 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -41,7 +41,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { private def shuffleDep( partitioner: Partitioner, - serializer: Option[Serializer], + serializer: Serializer, keyOrdering: Option[Ordering[Any]], aggregator: Option[Aggregator[Any, Any, Any]], mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = { @@ -56,7 +56,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { } test("supported shuffle dependencies for serialized shuffle") { - val kryo = Some(new KryoSerializer(new SparkConf())) + val kryo = new KryoSerializer(new SparkConf()) assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), @@ -88,8 +88,8 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { } test("unsupported shuffle dependencies for serialized shuffle") { - val kryo = Some(new KryoSerializer(new SparkConf())) - val java = Some(new JavaSerializer(new SparkConf())) + val kryo = new KryoSerializer(new SparkConf()) + val java = new JavaSerializer(new SparkConf()) // We only support serializers that support object relocation assert(!canUseSerializedShuffle(shuffleDep( diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala new file mode 100644 index 0000000000000..88817dccf3497 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1 + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} +import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} + +class AllStagesResourceSuite extends SparkFunSuite { + + def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { + val tasks = new HashMap[Long, TaskUIData] + taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => + tasks(idx.toLong) = new TaskUIData( + new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None, None) + } + + val stageUiData = new StageUIData() + stageUiData.taskData = tasks + val status = StageStatus.ACTIVE + val stageInfo = new StageInfo( + 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc", Seq.empty) + val stageData = AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, false) + + stageData.firstTaskLaunchedTime + } + + test("firstTaskLaunchedTime when there are no tasks") { + val result = getFirstTaskLaunchTime(Seq()) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks but none launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, -200L, -300L)) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks and some launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, 1449255596000L, 1449255597000L)) + assert(result == Some(new Date(1449255596000L))) + } + +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala new file mode 100644 index 0000000000000..9d1bd7ec89bc7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.Properties + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkException, SparkFunSuite, TaskContext, TaskContextImpl} + + +class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { + + private implicit val ec = ExecutionContext.global + private var blockInfoManager: BlockInfoManager = _ + + override protected def beforeEach(): Unit = { + super.beforeEach() + blockInfoManager = new BlockInfoManager() + for (t <- 0 to 4) { + blockInfoManager.registerTask(t) + } + } + + override protected def afterEach(): Unit = { + try { + blockInfoManager = null + } finally { + super.afterEach() + } + } + + private implicit def stringToBlockId(str: String): BlockId = { + TestBlockId(str) + } + + private def newBlockInfo(): BlockInfo = { + new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false) + } + + private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { + try { + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + block + } finally { + TaskContext.unset() + } + } + + test("initial memory usage") { + assert(blockInfoManager.size === 0) + } + + test("get non-existent block") { + assert(blockInfoManager.get("non-existent-block").isEmpty) + assert(blockInfoManager.lockForReading("non-existent-block").isEmpty) + assert(blockInfoManager.lockForWriting("non-existent-block").isEmpty) + } + + test("basic lockNewBlockForWriting") { + val initialNumMapEntries = blockInfoManager.getNumberOfMapEntries + val blockInfo = newBlockInfo() + withTaskId(1) { + assert(blockInfoManager.lockNewBlockForWriting("block", blockInfo)) + assert(blockInfoManager.get("block").get eq blockInfo) + assert(blockInfo.readerCount === 0) + assert(blockInfo.writerTask === 1) + // Downgrade lock so that second call doesn't block: + blockInfoManager.downgradeLock("block") + assert(blockInfo.readerCount === 1) + assert(blockInfo.writerTask === BlockInfo.NO_WRITER) + assert(!blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + assert(blockInfo.readerCount === 2) + assert(blockInfoManager.get("block").get eq blockInfo) + assert(blockInfo.readerCount === 2) + assert(blockInfo.writerTask === BlockInfo.NO_WRITER) + blockInfoManager.unlock("block") + blockInfoManager.unlock("block") + assert(blockInfo.readerCount === 0) + assert(blockInfo.writerTask === BlockInfo.NO_WRITER) + } + assert(blockInfoManager.size === 1) + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries + 1) + } + + test("lockNewBlockForWriting blocks while write lock is held, then returns false after release") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val lock1Future = Future { + withTaskId(1) { + blockInfoManager.lockNewBlockForWriting("block", newBlockInfo()) + } + } + val lock2Future = Future { + withTaskId(2) { + blockInfoManager.lockNewBlockForWriting("block", newBlockInfo()) + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.downgradeLock("block") + } + // After downgrading to a read lock, both threads should wake up and acquire the shared + // read lock. + assert(!Await.result(lock1Future, 1.seconds)) + assert(!Await.result(lock2Future, 1.seconds)) + assert(blockInfoManager.get("block").get.readerCount === 3) + } + + test("lockNewBlockForWriting blocks while write lock is held, then returns true after removal") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val lock1Future = Future { + withTaskId(1) { + blockInfoManager.lockNewBlockForWriting("block", newBlockInfo()) + } + } + val lock2Future = Future { + withTaskId(2) { + blockInfoManager.lockNewBlockForWriting("block", newBlockInfo()) + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.removeBlock("block") + } + // After removing the block, the write lock is released. Both threads should wake up but only + // one should acquire the write lock. The second thread should block until the winner of the + // write race releases its lock. + val winningFuture: Future[Boolean] = + Await.ready(Future.firstCompletedOf(Seq(lock1Future, lock2Future)), 1.seconds) + assert(winningFuture.value.get.get) + val winningTID = blockInfoManager.get("block").get.writerTask + assert(winningTID === 1 || winningTID === 2) + val losingFuture: Future[Boolean] = if (winningTID == 1) lock2Future else lock1Future + assert(!losingFuture.isCompleted) + // Once the writer releases its lock, the blocked future should wake up again and complete. + withTaskId(winningTID) { + blockInfoManager.unlock("block") + } + assert(!Await.result(losingFuture, 1.seconds)) + assert(blockInfoManager.get("block").get.readerCount === 1) + } + + test("read locks are reentrant") { + withTaskId(1) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + assert(blockInfoManager.lockForReading("block").isDefined) + assert(blockInfoManager.lockForReading("block").isDefined) + assert(blockInfoManager.get("block").get.readerCount === 2) + assert(blockInfoManager.get("block").get.writerTask === BlockInfo.NO_WRITER) + blockInfoManager.unlock("block") + assert(blockInfoManager.get("block").get.readerCount === 1) + blockInfoManager.unlock("block") + assert(blockInfoManager.get("block").get.readerCount === 0) + } + } + + test("multiple tasks can hold read locks") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(2) { assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(3) { assert(blockInfoManager.lockForReading("block").isDefined) } + withTaskId(4) { assert(blockInfoManager.lockForReading("block").isDefined) } + assert(blockInfoManager.get("block").get.readerCount === 4) + } + + test("single task can hold write lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForWriting("block").isDefined) + assert(blockInfoManager.get("block").get.writerTask === 1) + } + withTaskId(2) { + assert(blockInfoManager.lockForWriting("block", blocking = false).isEmpty) + assert(blockInfoManager.get("block").get.writerTask === 1) + } + } + + test("cannot call lockForWriting while already holding a write lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForWriting("block").isDefined) + intercept[IllegalStateException] { + blockInfoManager.lockForWriting("block") + } + blockInfoManager.assertBlockIsLockedForWriting("block") + } + } + + test("assertBlockIsLockedForWriting throws exception if block is not locked") { + intercept[SparkException] { + blockInfoManager.assertBlockIsLockedForWriting("block") + } + withTaskId(BlockInfo.NON_TASK_WRITER) { + intercept[SparkException] { + blockInfoManager.assertBlockIsLockedForWriting("block") + } + } + } + + test("downgrade lock") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.downgradeLock("block") + } + withTaskId(1) { + assert(blockInfoManager.lockForReading("block").isDefined) + } + assert(blockInfoManager.get("block").get.readerCount === 2) + assert(blockInfoManager.get("block").get.writerTask === BlockInfo.NO_WRITER) + } + + test("write lock will block readers") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val get1Future = Future { + withTaskId(1) { + blockInfoManager.lockForReading("block") + } + } + val get2Future = Future { + withTaskId(2) { + blockInfoManager.lockForReading("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.unlock("block") + } + assert(Await.result(get1Future, 1.seconds).isDefined) + assert(Await.result(get2Future, 1.seconds).isDefined) + assert(blockInfoManager.get("block").get.readerCount === 2) + } + + test("read locks will block writer") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + blockInfoManager.lockForReading("block") + } + val write1Future = Future { + withTaskId(1) { + blockInfoManager.lockForWriting("block") + } + } + val write2Future = Future { + withTaskId(2) { + blockInfoManager.lockForWriting("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.unlock("block") + } + assert( + Await.result(Future.firstCompletedOf(Seq(write1Future, write2Future)), 1.seconds).isDefined) + val firstWriteWinner = if (write1Future.isCompleted) 1 else 2 + withTaskId(firstWriteWinner) { + blockInfoManager.unlock("block") + } + assert(Await.result(write1Future, 1.seconds).isDefined) + assert(Await.result(write2Future, 1.seconds).isDefined) + } + + test("removing a non-existent block throws IllegalArgumentException") { + withTaskId(0) { + intercept[IllegalArgumentException] { + blockInfoManager.removeBlock("non-existent-block") + } + } + } + + test("removing a block without holding any locks throws IllegalStateException") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + intercept[IllegalStateException] { + blockInfoManager.removeBlock("block") + } + } + } + + test("removing a block while holding only a read lock throws IllegalStateException") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + blockInfoManager.unlock("block") + assert(blockInfoManager.lockForReading("block").isDefined) + intercept[IllegalStateException] { + blockInfoManager.removeBlock("block") + } + } + } + + test("removing a block causes blocked callers to receive None") { + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + val getFuture = Future { + withTaskId(1) { + blockInfoManager.lockForReading("block") + } + } + val writeFuture = Future { + withTaskId(2) { + blockInfoManager.lockForWriting("block") + } + } + Thread.sleep(300) // Hack to try to ensure that both future tasks are waiting + withTaskId(0) { + blockInfoManager.removeBlock("block") + } + assert(Await.result(getFuture, 1.seconds).isEmpty) + assert(Await.result(writeFuture, 1.seconds).isEmpty) + } + + test("releaseAllLocksForTask releases write locks") { + val initialNumMapEntries = blockInfoManager.getNumberOfMapEntries + withTaskId(0) { + assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) + } + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries + 3) + blockInfoManager.releaseAllLocksForTask(0) + assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 6e3f500e15dc0..d26df7e760cea 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -26,13 +26,13 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.rpc.RpcEnv import org.apache.spark._ -import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ @@ -60,9 +60,12 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) - val store = new BlockManager(name, rpcEnv, master, serializer, conf, + conf.set("spark.testing.memory", maxMem.toString) + conf.set("spark.memory.offHeap.size", maxMem.toString) + val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1) + val memManager = UnifiedMemoryManager(conf, numCores = 1) + val serializerManager = new SerializerManager(serializer, conf) + val store = new BlockManager(name, rpcEnv, master, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(store.memoryStore) store.initialize("app-id") @@ -75,6 +78,9 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo conf.set("spark.authenticate", "false") conf.set("spark.driver.port", rpcEnv.address.port.toString) + conf.set("spark.testing", "true") + conf.set("spark.memory.fraction", "1") + conf.set("spark.memory.storageFraction", "1") conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -171,6 +177,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo testReplication(5, storageLevels) } + test("block replication - off-heap") { + testReplication(2, Seq(OFF_HEAP, StorageLevel(true, true, true, false, 2))) + } + test("block replication - 2x replication without peers") { intercept[org.scalatest.exceptions.TestFailedException] { testReplication(1, @@ -261,8 +271,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1) - val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf, + conf.set("spark.testing.memory", "10000") + val memManager = UnifiedMemoryManager(conf, numCores = 1) + val serializerManager = new SerializerManager(serializer, conf) + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) memManager.setMemoryStore(failableStore.memoryStore) failableStore.initialize("app-id") @@ -366,7 +378,9 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo testStore => blockLocations.contains(testStore.blockManagerId.executorId) }.foreach { testStore => val testStoreName = testStore.blockManagerId.executorId - assert(testStore.getLocal(blockId).isDefined, s"$blockId was not found in $testStoreName") + assert( + testStore.getLocalValues(blockId).isDefined, s"$blockId was not found in $testStoreName") + testStore.releaseLock(blockId) assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), s"master does not have status for ${blockId.name} in $testStoreName") @@ -388,10 +402,14 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo // If the block is supposed to be in memory, then drop the copy of the block in // this store test whether master is updated with zero memory usage this store if (storageLevel.useMemory) { + val sl = if (storageLevel.useOffHeap) { + StorageLevel(false, true, true, false, 1) + } else { + MEMORY_ONLY_SER + } // Force the block to be dropped by adding a number of dummy blocks (1 to 10).foreach { - i => - testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER) + i => testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), sl) } (1 to 10).foreach { i => testStore.removeBlock(s"dummy-block-$i") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index d49015afcd594..a1c2933584acc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,48 +17,53 @@ package org.apache.spark.storage -import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.Arrays +import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ +import scala.concurrent.Future import scala.language.implicitConversions import scala.language.postfixOps +import scala.reflect.ClassTag -import org.mockito.Mockito.{mock, when} +import org.mockito.{Matchers => mc} +import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.memory.UnifiedMemoryManager +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ - +import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { - private val conf = new SparkConf(false).set("spark.app.id", "test") + import BlockManagerSuite._ + + var conf: SparkConf = null var store: BlockManager = null var store2: BlockManager = null var store3: BlockManager = null var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null - conf.set("spark.authenticate", "false") - val securityMgr = new SecurityManager(conf) - val mapOutputTracker = new MapOutputTrackerMaster(conf) - val shuffleManager = new HashShuffleManager(conf) + val securityMgr = new SecurityManager(new SparkConf(false)) + val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false)) + val shuffleManager = new HashShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer", "1m") - val serializer = new KryoSerializer(conf) + val serializer = new KryoSerializer(new SparkConf(false).set("spark.kryoserializer.buffer", "1m")) // Implicitly convert strings to BlockIds for test clarity. implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) @@ -66,10 +71,17 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private def makeBlockManager( maxMem: Long, - name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) - val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, + name: String = SparkContext.DRIVER_IDENTIFIER, + master: BlockManagerMaster = this.master, + transferService: Option[BlockTransferService] = Option.empty): BlockManager = { + conf.set("spark.testing.memory", maxMem.toString) + conf.set("spark.memory.offHeap.size", maxMem.toString) + val serializer = new KryoSerializer(conf) + val transfer = transferService + .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1)) + val memManager = UnifiedMemoryManager(conf, numCores = 1) + val serializerManager = new SerializerManager(serializer, conf) + val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") @@ -77,15 +89,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } override def beforeEach(): Unit = { - rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) - + super.beforeEach() // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") - conf.set("os.arch", "amd64") - conf.set("spark.test.useCompressedOops", "true") + conf = new SparkConf(false) + .set("spark.app.id", "test") + .set("spark.testing", "true") + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "1") + .set("spark.kryoserializer.buffer", "1m") + .set("spark.test.useCompressedOops", "true") + .set("spark.storage.unrollFraction", "0.4") + .set("spark.storage.unrollMemoryThreshold", "512") + + rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) - conf.set("spark.storage.unrollFraction", "0.4") - conf.set("spark.storage.unrollMemoryThreshold", "512") master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) @@ -95,30 +113,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } override def afterEach(): Unit = { - if (store != null) { - store.stop() - store = null - } - if (store2 != null) { - store2.stop() - store2 = null - } - if (store3 != null) { - store3.stop() - store3 = null + try { + conf = null + if (store != null) { + store.stop() + store = null + } + if (store2 != null) { + store2.stop() + store2 = null + } + if (store3 != null) { + store3.stop() + store3 = null + } + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null + master = null + } finally { + super.afterEach() } - rpcEnv.shutdown() - rpcEnv.awaitTermination() - rpcEnv = null - master = null } test("StorageLevel object caching") { - val level1 = StorageLevel(false, false, false, false, 3) + val level1 = StorageLevel(false, false, false, 3) // this should return the same object as level1 - val level2 = StorageLevel(false, false, false, false, 3) + val level2 = StorageLevel(false, false, false, 3) // this should return a different object - val level3 = StorageLevel(false, false, false, false, 2) + val level3 = StorageLevel(false, false, false, 2) assert(level2 === level1, "level2 is not same as level1") assert(level2.eq(level1), "level2 is not the same object as level1") assert(level3 != level1, "level3 is same as level1") @@ -167,9 +190,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") // Checking whether master knows about the blocks or not assert(master.getLocations("a1").size > 0, "master was not told about a1") @@ -177,10 +200,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) - assert(store.getSingle("a1") === None, "a1 not removed from store") - assert(store.getSingle("a2") === None, "a2 not removed from store") + store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ChunkedByteBuffer]) + store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ChunkedByteBuffer]) + assert(store.getSingleAndReleaseLock("a1") === None, "a1 not removed from store") + assert(store.getSingleAndReleaseLock("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") assert(master.getLocations("a2").size === 0, "master did not remove a2") } @@ -216,9 +239,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memStatus = master.getMemoryStatus.head._2 assert(memStatus._1 == 20000L, "total memory " + memStatus._1 + " should equal 20000") assert(memStatus._2 <= 12000L, "remaining memory " + memStatus._2 + " should <= 12000") - assert(store.getSingle("a1-to-remove").isDefined, "a1 was not in store") - assert(store.getSingle("a2-to-remove").isDefined, "a2 was not in store") - assert(store.getSingle("a3-to-remove").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1-to-remove").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2-to-remove").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3-to-remove").isDefined, "a3 was not in store") // Checking whether master knows about the blocks or not assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1") @@ -231,15 +254,15 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeBlock("a3-to-remove") eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a1-to-remove") should be (None) + assert(!store.hasLocalBlock("a1-to-remove")) master.getLocations("a1-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a2-to-remove") should be (None) + assert(!store.hasLocalBlock("a2-to-remove")) master.getLocations("a2-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a3-to-remove") should not be (None) + assert(store.hasLocalBlock("a3-to-remove")) master.getLocations("a3-to-remove") should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { @@ -261,24 +284,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeRdd(0, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle(rdd(0, 0)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 0)) should be (None) master.getLocations(rdd(0, 0)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle(rdd(0, 1)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 1)) should be (None) master.getLocations(rdd(0, 1)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("nonrddblock") should not be (None) + store.getSingleAndReleaseLock("nonrddblock") should not be (None) master.getLocations("nonrddblock") should have size (1) } store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = true) - store.getSingle(rdd(0, 0)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 0)) should be (None) master.getLocations(rdd(0, 0)) should have size 0 - store.getSingle(rdd(0, 1)) should be (None) + store.getSingleAndReleaseLock(rdd(0, 1)) should be (None) master.getLocations(rdd(0, 1)) should have size 0 } @@ -306,46 +329,46 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // verify whether the blocks exist in both the stores Seq(driverStore, executorStore).foreach { case s => - s.getLocal(broadcast0BlockId) should not be (None) - s.getLocal(broadcast1BlockId) should not be (None) - s.getLocal(broadcast2BlockId) should not be (None) - s.getLocal(broadcast2BlockId2) should not be (None) + assert(s.hasLocalBlock(broadcast0BlockId)) + assert(s.hasLocalBlock(broadcast1BlockId)) + assert(s.hasLocalBlock(broadcast2BlockId)) + assert(s.hasLocalBlock(broadcast2BlockId2)) } // remove broadcast 0 block only from executors master.removeBroadcast(0, removeFromMaster = false, blocking = true) // only broadcast 0 block should be removed from the executor store - executorStore.getLocal(broadcast0BlockId) should be (None) - executorStore.getLocal(broadcast1BlockId) should not be (None) - executorStore.getLocal(broadcast2BlockId) should not be (None) + assert(!executorStore.hasLocalBlock(broadcast0BlockId)) + assert(executorStore.hasLocalBlock(broadcast1BlockId)) + assert(executorStore.hasLocalBlock(broadcast2BlockId)) // nothing should be removed from the driver store - driverStore.getLocal(broadcast0BlockId) should not be (None) - driverStore.getLocal(broadcast1BlockId) should not be (None) - driverStore.getLocal(broadcast2BlockId) should not be (None) + assert(driverStore.hasLocalBlock(broadcast0BlockId)) + assert(driverStore.hasLocalBlock(broadcast1BlockId)) + assert(driverStore.hasLocalBlock(broadcast2BlockId)) // remove broadcast 0 block from the driver as well master.removeBroadcast(0, removeFromMaster = true, blocking = true) - driverStore.getLocal(broadcast0BlockId) should be (None) - driverStore.getLocal(broadcast1BlockId) should not be (None) + assert(!driverStore.hasLocalBlock(broadcast0BlockId)) + assert(driverStore.hasLocalBlock(broadcast1BlockId)) // remove broadcast 1 block from both the stores asynchronously // and verify all broadcast 1 blocks have been removed master.removeBroadcast(1, removeFromMaster = true, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - driverStore.getLocal(broadcast1BlockId) should be (None) - executorStore.getLocal(broadcast1BlockId) should be (None) + assert(!driverStore.hasLocalBlock(broadcast1BlockId)) + assert(!executorStore.hasLocalBlock(broadcast1BlockId)) } // remove broadcast 2 from both the stores asynchronously // and verify all broadcast 2 blocks have been removed master.removeBroadcast(2, removeFromMaster = true, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - driverStore.getLocal(broadcast2BlockId) should be (None) - driverStore.getLocal(broadcast2BlockId2) should be (None) - executorStore.getLocal(broadcast2BlockId) should be (None) - executorStore.getLocal(broadcast2BlockId2) should be (None) + assert(!driverStore.hasLocalBlock(broadcast2BlockId)) + assert(!driverStore.hasLocalBlock(broadcast2BlockId2)) + assert(!executorStore.hasLocalBlock(broadcast2BlockId)) + assert(!executorStore.hasLocalBlock(broadcast2BlockId2)) } executorStore.stop() driverStore.stop() @@ -358,7 +381,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") assert(master.getLocations("a1").size > 0, "master was not told about a1") master.removeExecutor(store.blockManagerId.executorId) @@ -397,7 +420,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { override def run() { - store.putIterator("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator( + "a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } } val t2 = new Thread { @@ -418,8 +442,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE t2.join() t3.join() - store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ChunkedByteBuffer]) + store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ChunkedByteBuffer]) store.waitForAsyncReregister() } } @@ -430,9 +454,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list2 = List(new Array[Byte](500), new Array[Byte](1000), new Array[Byte](1500)) val list1SizeEstimate = SizeEstimator.estimate(list1.iterator.toArray) val list2SizeEstimate = SizeEstimator.estimate(list2.iterator.toArray) - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2memory", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2disk", list2.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + store.putIterator( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator( + "list2memory", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator( + "list2disk", list2.iterator, StorageLevel.DISK_ONLY, tellMaster = true) val list1Get = store.get("list1") assert(list1Get.isDefined, "list1 expected to be in store") assert(list1Get.get.data.size === 2) @@ -451,74 +478,89 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } + test("optimize a location order of blocks") { + val localHost = Utils.localHostName() + val otherHost = "otherHost" + val bmMaster = mock(classOf[BlockManagerMaster]) + val bmId1 = BlockManagerId("id1", localHost, 1) + val bmId2 = BlockManagerId("id2", localHost, 2) + val bmId3 = BlockManagerId("id3", otherHost, 3) + when(bmMaster.getLocations(mc.any[BlockId])).thenReturn(Seq(bmId1, bmId2, bmId3)) + + val blockManager = makeBlockManager(128, "exec", bmMaster) + val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) + val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + assert(locations.map(_.host).toSet === Set(localHost, localHost, otherHost)) + } + test("SPARK-9591: getRemoteBytes from another location when Exception throw") { - val origTimeoutOpt = conf.getOption("spark.network.timeout") - try { - conf.set("spark.network.timeout", "2s") - store = makeBlockManager(8000, "executor1") - store2 = makeBlockManager(8000, "executor2") - store3 = makeBlockManager(8000, "executor3") - val list1 = List(new Array[Byte](4000)) - store2.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store3.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - var list1Get = store.getRemoteBytes("list1") - assert(list1Get.isDefined, "list1Get expected to be fetched") - // block manager exit - store2.stop() - store2 = null - list1Get = store.getRemoteBytes("list1") - // get `list1` block - assert(list1Get.isDefined, "list1Get expected to be fetched") - store3.stop() - store3 = null - // exception throw because there is no locations - intercept[BlockFetchException] { - list1Get = store.getRemoteBytes("list1") - } - } finally { - origTimeoutOpt match { - case Some(t) => conf.set("spark.network.timeout", t) - case None => conf.remove("spark.network.timeout") - } + conf.set("spark.shuffle.io.maxRetries", "0") + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + store3 = makeBlockManager(8000, "executor3") + val list1 = List(new Array[Byte](4000)) + store2.putIterator( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store3.putIterator( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getRemoteBytes("list1").isDefined, "list1Get expected to be fetched") + store2.stop() + store2 = null + assert(store.getRemoteBytes("list1").isDefined, "list1Get expected to be fetched") + store3.stop() + store3 = null + // exception throw because there is no locations + intercept[BlockFetchException] { + store.getRemoteBytes("list1") } } + test("SPARK-14252: getOrElseUpdate should still read from remote storage") { + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + val list1 = List(new Array[Byte](4000)) + store2.putIterator( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getOrElseUpdate( + "list1", + StorageLevel.MEMORY_ONLY, + ClassTag.Any, + () => throw new AssertionError("attempted to compute locally")).isLeft) + } + test("in-memory LRU storage") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") + testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY) } test("in-memory LRU storage with serialization") { + testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY_SER) + } + + test("in-memory LRU storage with off-heap") { + testInMemoryLRUStorage(StorageLevel( + useDisk = false, + useMemory = true, + useOffHeap = true, + deserialized = false, replication = 1)) + } + + private def testInMemoryLRUStorage(storageLevel: StorageLevel): Unit = { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") + store.putSingle("a1", a1, storageLevel) + store.putSingle("a2", a2, storageLevel) + store.putSingle("a3", a3, storageLevel) + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1") === None, "a1 was in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") + store.putSingle("a1", a1, storageLevel) + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3") === None, "a3 was in store") } test("in-memory LRU for partitions of same RDD") { @@ -531,13 +573,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY) // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // from the same RDD - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") - assert(store.getSingle(rdd(0, 1)).isDefined, "rdd_0_1 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 1)).isDefined, "rdd_0_1 was not in store") // Check that rdd_0_3 doesn't replace them even after further accesses - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") - assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingleAndReleaseLock(rdd(0, 3)) === None, "rdd_0_3 was in store") } test("in-memory LRU for partitions of multiple RDDs") { @@ -550,7 +592,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store") assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store") // Do a get() on rdd_0_2 so that it is the most recently used item - assert(store.getSingle(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") + assert(store.getSingleAndReleaseLock(rdd(0, 2)).isDefined, "rdd_0_2 was not in store") // Put in more partitions from RDD 0; they should replace rdd_1_1 store.putSingle(rdd(0, 3), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 4), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) @@ -563,26 +605,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } - test("tachyon storage") { - // TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar. - val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false) - conf.set(ExternalBlockStore.BLOCK_MANAGER_NAME, ExternalBlockStore.DEFAULT_BLOCK_MANAGER_NAME) - if (tachyonUnitTestEnabled) { - store = makeBlockManager(1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.OFF_HEAP) - store.putSingle("a2", a2, StorageLevel.OFF_HEAP) - store.putSingle("a3", a3, StorageLevel.OFF_HEAP) - assert(store.getSingle("a3").isDefined, "a3 was in store") - assert(store.getSingle("a2").isDefined, "a2 was in store") - assert(store.getSingle("a1").isDefined, "a1 was in store") - } else { - info("tachyon storage test disabled.") - } - } - test("on-disk storage") { store = makeBlockManager(1200) val a1 = new Array[Byte](400) @@ -591,69 +613,64 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putSingle("a1", a1, StorageLevel.DISK_ONLY) store.putSingle("a2", a2, StorageLevel.DISK_ONLY) store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - assert(store.getSingle("a2").isDefined, "a2 was in store") - assert(store.getSingle("a3").isDefined, "a3 was in store") - assert(store.getSingle("a1").isDefined, "a1 was in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store") } test("disk and memory storage") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false) } test("disk and memory storage with getLocalBytes") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getLocalBytes("a2").isDefined, "a2 was not in store") - assert(store.getLocalBytes("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true) } test("disk and memory storage with serialization") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false) } test("disk and memory storage with serialization and getLocalBytes") { + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true) + } + + test("disk and off-heap memory storage") { + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false) + } + + test("disk and off-heap memory storage with getLocalBytes") { + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true) + } + + def testDiskAndMemoryStorage( + storageLevel: StorageLevel, + getAsBytes: Boolean): Unit = { store = makeBlockManager(12000) + val accessMethod = + if (getAsBytes) store.getLocalBytesAndReleaseLock else store.getSingleAndReleaseLock val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getLocalBytes("a2").isDefined, "a2 was not in store") - assert(store.getLocalBytes("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + store.putSingle("a1", a1, storageLevel) + store.putSingle("a2", a2, storageLevel) + store.putSingle("a3", a3, storageLevel) + assert(accessMethod("a2").isDefined, "a2 was not in store") + assert(accessMethod("a3").isDefined, "a3 was not in store") + assert(accessMethod("a1").isDefined, "a1 was not in store") + val dataShouldHaveBeenCachedBackIntoMemory = { + if (storageLevel.deserialized) { + !getAsBytes + } else { + // If the block's storage level is serialized, then always cache the bytes in memory, even + // if the caller requested values. + true + } + } + if (dataShouldHaveBeenCachedBackIntoMemory) { + assert(store.memoryStore.contains("a1"), "a1 was not in memory store") + } else { + assert(!store.memoryStore.contains("a1"), "a1 was in memory store") + } } test("LRU with mixed storage levels") { @@ -667,15 +684,15 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) store.putSingle("a3", a3, StorageLevel.DISK_ONLY) // At this point LRU should not kick in because a3 is only on disk - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was not in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") // Now let's add in a4, which uses both disk and memory; a1 should drop out store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a1") == None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a4").isDefined, "a4 was not in store") + assert(store.getSingleAndReleaseLock("a1") == None, "a1 was in store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") + assert(store.getSingleAndReleaseLock("a3").isDefined, "a3 was not in store") + assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store") } test("in-memory LRU with streams") { @@ -683,23 +700,27 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list2").isDefined, "list2 was not in store") + store.putIterator( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator( + "list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator( + "list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list1") === None, "list1 was in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list1").isDefined, "list1 was not in store") + store.putIterator( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3") === None, "list1 was in store") + assert(store.getAndReleaseLock("list3") === None, "list1 was in store") } test("LRU with mixed storage levels and streams") { @@ -709,33 +730,37 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) val list4 = List(new Array[Byte](2000), new Array[Byte](2000)) // First store list1 and list2, both in memory, and list3, on disk only - store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.putIterator("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.putIterator("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) + store.putIterator( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.putIterator( + "list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.putIterator( + "list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) val listForSizeEstimate = new ArrayBuffer[Any] listForSizeEstimate ++= list1.iterator val listSize = SizeEstimator.estimate(listForSizeEstimate) // At this point LRU should not kick in because list3 is only on disk - assert(store.get("list1").isDefined, "list1 was not in store") + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list1").isDefined, "list1 was not in store") + assert(store.getAndReleaseLock("list1").isDefined, "list1 was not in store") assert(store.get("list1").get.data.size === 2) - assert(store.get("list2").isDefined, "list2 was not in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.putIterator("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2").isDefined, "list2 was not in store") + store.putIterator( + "list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) + assert(store.getAndReleaseLock("list1") === None, "list1 was in store") + assert(store.getAndReleaseLock("list2").isDefined, "list2 was not in store") assert(store.get("list2").get.data.size === 2) - assert(store.get("list3").isDefined, "list3 was not in store") + assert(store.getAndReleaseLock("list3").isDefined, "list3 was not in store") assert(store.get("list3").get.data.size === 2) - assert(store.get("list4").isDefined, "list4 was not in store") + assert(store.getAndReleaseLock("list4").isDefined, "list4 was not in store") assert(store.get("list4").get.data.size === 2) } @@ -754,17 +779,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("overly large block") { store = makeBlockManager(5000) store.putSingle("a1", new Array[Byte](10000), StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1") === None, "a1 was in store") + assert(store.getSingleAndReleaseLock("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](10000), StorageLevel.MEMORY_AND_DISK) - assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") + assert(!store.memoryStore.contains("a2"), "a2 was in memory store") + assert(store.getSingleAndReleaseLock("a2").isDefined, "a2 was not in store") } test("block compression") { try { conf.set("spark.shuffle.compress", "true") store = makeBlockManager(20000, "exec1") - store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + store.putSingle( + ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") store.stop() @@ -772,7 +798,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.shuffle.compress", "false") store = makeBlockManager(20000, "exec2") - store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingle( + ShuffleBlockId(0, 0, 0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 10000, "shuffle_0_0_0 was compressed") store.stop() @@ -780,7 +807,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.broadcast.compress", "true") store = makeBlockManager(20000, "exec3") - store.putSingle(BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingle( + BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 1000, "broadcast_0 was not compressed") store.stop() @@ -788,7 +816,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.broadcast.compress", "false") store = makeBlockManager(20000, "exec4") - store.putSingle(BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) + store.putSingle( + BroadcastBlockId(0), new Array[Byte](10000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 10000, "broadcast_0 was compressed") store.stop() store = null @@ -822,14 +851,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. - val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memoryManager = new StaticMemoryManager( - conf, - maxExecutionMemory = Long.MaxValue, - maxStorageMemory = 1200, - numCores = 1) + conf.set("spark.testing.memory", "1200") + val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1) + val memoryManager = UnifiedMemoryManager(conf, numCores = 1) + val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, - new JavaSerializer(conf), conf, memoryManager, mapOutputTracker, + serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memoryManager.setMemoryStore(store.memoryStore) @@ -842,76 +869,50 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // Make sure get a1 doesn't hang and returns None. failAfter(1 second) { - assert(store.getSingle("a1").isEmpty, "a1 should not be in store") - } - } - - test("reads of memory-mapped and non memory-mapped files are equivalent") { - val confKey = "spark.storage.memoryMapThreshold" - - // Create a non-trivial (not all zeros) byte array - var counter = 0.toByte - def incr: Byte = {counter = (counter + 1).toByte; counter;} - val bytes = Array.fill[Byte](1000)(incr) - val byteBuffer = ByteBuffer.wrap(bytes) - - val blockId = BlockId("rdd_1_2") - - // This sequence of mocks makes these tests fairly brittle. It would - // be nice to refactor classes involved in disk storage in a way that - // allows for easier testing. - val blockManager = mock(classOf[BlockManager]) - when(blockManager.conf).thenReturn(conf.clone.set(confKey, "0")) - val diskBlockManager = new DiskBlockManager(blockManager, conf) - - val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) - diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) - val mapped = diskStoreMapped.getBytes(blockId).get - - when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m")) - val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager) - diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) - val notMapped = diskStoreNotMapped.getBytes(blockId).get - - // Not possible to do isInstanceOf due to visibility of HeapByteBuffer - assert(notMapped.getClass.getName.endsWith("HeapByteBuffer"), - "Expected HeapByteBuffer for un-mapped read") - assert(mapped.isInstanceOf[MappedByteBuffer], "Expected MappedByteBuffer for mapped read") - - def arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = { - val array = new Array[Byte](in.remaining()) - in.get(array) - array + assert(store.getSingleAndReleaseLock("a1").isEmpty, "a1 should not be in store") } - - val mappedAsArray = arrayFromByteBuffer(mapped) - val notMappedAsArray = arrayFromByteBuffer(notMapped) - assert(Arrays.equals(mappedAsArray, bytes)) - assert(Arrays.equals(notMappedAsArray, bytes)) } test("updated block statuses") { store = makeBlockManager(12000) + store.registerTask(0) val list = List.fill(2)(new Array[Byte](2000)) val bigList = List.fill(8)(new Array[Byte](2000)) + def getUpdatedBlocks(task: => Unit): Seq[(BlockId, BlockStatus)] = { + val context = TaskContext.empty() + try { + TaskContext.setTaskContext(context) + task + } finally { + TaskContext.unset() + } + context.taskMetrics.updatedBlockStatuses + } + // 1 updated block (i.e. list1) - val updatedBlocks1 = - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + val updatedBlocks1 = getUpdatedBlocks { + store.putIterator( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } assert(updatedBlocks1.size === 1) assert(updatedBlocks1.head._1 === TestBlockId("list1")) assert(updatedBlocks1.head._2.storageLevel === StorageLevel.MEMORY_ONLY) // 1 updated block (i.e. list2) - val updatedBlocks2 = - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + val updatedBlocks2 = getUpdatedBlocks { + store.putIterator( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + } assert(updatedBlocks2.size === 1) assert(updatedBlocks2.head._1 === TestBlockId("list2")) assert(updatedBlocks2.head._2.storageLevel === StorageLevel.MEMORY_ONLY) // 2 updated blocks - list1 is kicked out of memory while list3 is added - val updatedBlocks3 = - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + val updatedBlocks3 = getUpdatedBlocks { + store.putIterator( + "list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } assert(updatedBlocks3.size === 2) updatedBlocks3.foreach { case (id, status) => id match { @@ -923,8 +924,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains("list3"), "list3 was not in memory store") // 2 updated blocks - list2 is kicked out of memory (but put on disk) while list4 is added - val updatedBlocks4 = - store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + val updatedBlocks4 = getUpdatedBlocks { + store.putIterator( + "list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } assert(updatedBlocks4.size === 2) updatedBlocks4.foreach { case (id, status) => id match { @@ -937,8 +940,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains("list4"), "list4 was not in memory store") // No updated blocks - list5 is too big to fit in store and nothing is kicked out - val updatedBlocks5 = - store.putIterator("list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + val updatedBlocks5 = getUpdatedBlocks { + store.putIterator( + "list5", bigList.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } assert(updatedBlocks5.size === 0) // memory store contains only list3 and list4 @@ -954,6 +959,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!store.diskStore.contains("list3"), "list3 was in disk store") assert(!store.diskStore.contains("list4"), "list4 was in disk store") assert(!store.diskStore.contains("list5"), "list5 was in disk store") + + // remove block - list2 should be removed from disk + val updatedBlocks6 = getUpdatedBlocks { + store.removeBlock( + "list2", tellMaster = true) + } + assert(updatedBlocks6.size === 1) + assert(updatedBlocks6.head._1 === TestBlockId("list2")) + assert(updatedBlocks6.head._2.storageLevel == StorageLevel.NONE) + assert(!store.diskStore.contains("list2"), "list2 was in disk store") } test("query block statuses") { @@ -961,9 +976,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list = List.fill(2)(new Array[Byte](2000)) // Tell master. By LRU, only list2 and list3 remains. - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator( + "list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator( + "list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) // getLocations and getBlockStatus should yield the same locations assert(store.master.getLocations("list1").size === 0) @@ -977,9 +995,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. - store.putIterator("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) - store.putIterator("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - store.putIterator("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putIterator( + "list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.putIterator( + "list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIterator( + "list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) // getLocations should return nothing because the master is not informed // getBlockStatus without asking slaves should have the same result @@ -1000,9 +1021,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val list = List.fill(2)(new Array[Byte](100)) // insert some blocks - store.putIterator("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator( + "list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator( + "list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator( + "list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // getLocations and getBlockStatus should yield the same locations assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size @@ -1011,9 +1035,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE === 1) // insert some more blocks - store.putIterator("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - store.putIterator("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - store.putIterator("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIterator( + "newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.putIterator( + "newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.putIterator( + "newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // getLocations and getBlockStatus should yield the same locations assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size @@ -1023,7 +1050,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) blockIds.foreach { blockId => - store.putIterator(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.putIterator( + blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } val matchedBlockIds = store.master.getMatchingBlockIds(_ match { case RDDBlockId(1, _) => true @@ -1037,7 +1065,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putSingle(rdd(0, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // Access rdd_1_0 to ensure it's not least recently used. - assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") + assert(store.getSingleAndReleaseLock(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") // According to the same-RDD rule, rdd_1_0 should be replaced here. store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY) // rdd_1_0 should have been replaced, even it's not least recently used. @@ -1046,158 +1074,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } - test("reserve/release unroll memory") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { - memoryStore.reserveUnrollMemoryForThisTask( - TestBlockId(""), memory, new ArrayBuffer[(BlockId, BlockStatus)]) - } - - // Reserve - assert(reserveUnrollMemoryForThisTask(100)) - assert(memoryStore.currentUnrollMemoryForThisTask === 100) - assert(reserveUnrollMemoryForThisTask(200)) - assert(memoryStore.currentUnrollMemoryForThisTask === 300) - assert(reserveUnrollMemoryForThisTask(500)) - assert(memoryStore.currentUnrollMemoryForThisTask === 800) - assert(!reserveUnrollMemoryForThisTask(1000000)) - assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted - // Release - memoryStore.releaseUnrollMemoryForThisTask(100) - assert(memoryStore.currentUnrollMemoryForThisTask === 700) - memoryStore.releaseUnrollMemoryForThisTask(100) - assert(memoryStore.currentUnrollMemoryForThisTask === 600) - // Reserve again - assert(reserveUnrollMemoryForThisTask(4400)) - assert(memoryStore.currentUnrollMemoryForThisTask === 5000) - assert(!reserveUnrollMemoryForThisTask(20000)) - assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted - // Release again - memoryStore.releaseUnrollMemoryForThisTask(1000) - assert(memoryStore.currentUnrollMemoryForThisTask === 4000) - memoryStore.releaseUnrollMemoryForThisTask() // release all - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - } - - /** - * Verify the result of MemoryStore#unrollSafely is as expected. - */ - private def verifyUnroll( - expected: Iterator[Any], - result: Either[Array[Any], Iterator[Any]], - shouldBeArray: Boolean): Unit = { - val actual: Iterator[Any] = result match { - case Left(arr: Array[Any]) => - assert(shouldBeArray, "expected iterator from unroll!") - arr.iterator - case Right(it: Iterator[Any]) => - assert(!shouldBeArray, "expected array from unroll!") - it - case _ => - fail("unroll returned neither an iterator nor an array...") - } - expected.zip(actual).foreach { case (e, a) => - assert(e === a, "unroll did not return original values!") - } - } - - test("safely unroll blocks") { - store = makeBlockManager(12000) - val smallList = List.fill(40)(new Array[Byte](100)) - val bigList = List.fill(40)(new Array[Byte](1000)) - val memoryStore = store.memoryStore - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Unroll with all the space in the world. This should succeed and return an array. - var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) - verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - memoryStore.releasePendingUnrollMemoryForThisTask() - - // Unroll with not enough space. This should succeed after kicking out someBlock1. - store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) - store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) - unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) - verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - assert(droppedBlocks.size === 1) - assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) - droppedBlocks.clear() - memoryStore.releasePendingUnrollMemoryForThisTask() - - // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = - // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. - // In the mean time, however, we kicked out someBlock2 before giving up. - store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) - unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) - verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - assert(droppedBlocks.size === 1) - assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) - droppedBlocks.clear() - } - - test("safely unroll blocks through putIterator") { - store = makeBlockManager(12000) - val memOnly = StorageLevel.MEMORY_ONLY - val memoryStore = store.memoryStore - val smallList = List.fill(40)(new Array[Byte](100)) - val bigList = List.fill(40)(new Array[Byte](1000)) - def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Unroll with plenty of space. This should succeed and cache both blocks. - val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) - val result2 = memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) - assert(memoryStore.contains("b1")) - assert(memoryStore.contains("b2")) - assert(result1.size > 0) // unroll was successful - assert(result2.size > 0) - assert(result1.data.isLeft) // unroll did not drop this block to disk - assert(result2.data.isLeft) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - - // Re-put these two blocks so block manager knows about them too. Otherwise, block manager - // would not know how to drop them from memory later. - memoryStore.remove("b1") - memoryStore.remove("b2") - store.putIterator("b1", smallIterator, memOnly) - store.putIterator("b2", smallIterator, memOnly) - - // Unroll with not enough space. This should succeed but kick out b1 in the process. - val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) - assert(result3.size > 0) - assert(result3.data.isLeft) - assert(!memoryStore.contains("b1")) - assert(memoryStore.contains("b2")) - assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - memoryStore.remove("b3") - store.putIterator("b3", smallIterator, memOnly) - - // Unroll huge block with not enough space. This should fail and kick out b2 in the process. - val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, returnValues = true) - assert(result4.size === 0) // unroll was unsuccessful - assert(result4.data.isLeft) - assert(!memoryStore.contains("b1")) - assert(!memoryStore.contains("b2")) - assert(memoryStore.contains("b3")) - assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator - } - - /** - * This test is essentially identical to the preceding one, except that it uses MEMORY_AND_DISK. - */ test("safely unroll blocks through putIterator (disk)") { store = makeBlockManager(12000) - val memAndDisk = StorageLevel.MEMORY_AND_DISK val memoryStore = store.memoryStore val diskStore = store.diskStore val smallList = List.fill(40)(new Array[Byte](100)) @@ -1206,13 +1084,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] assert(memoryStore.currentUnrollMemoryForThisTask === 0) - store.putIterator("b1", smallIterator, memAndDisk) - store.putIterator("b2", smallIterator, memAndDisk) + store.putIterator("b1", smallIterator, StorageLevel.MEMORY_AND_DISK) + store.putIterator("b2", smallIterator, StorageLevel.MEMORY_AND_DISK) // Unroll with not enough space. This should succeed but kick out b1 in the process. // Memory store should contain b2 and b3, while disk store should contain only b1 - val result3 = memoryStore.putIterator("b3", smallIterator, memAndDisk, returnValues = true) - assert(result3.size > 0) + val result3 = memoryStore.putIteratorAsValues("b3", smallIterator, ClassTag.Any) + assert(result3.isRight) assert(!memoryStore.contains("b1")) assert(memoryStore.contains("b2")) assert(memoryStore.contains("b3")) @@ -1223,83 +1101,155 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) assert(memoryStore.currentUnrollMemoryForThisTask === 0) - // Unroll huge block with not enough space. This should fail and drop the new block to disk - // directly in addition to kicking out b2 in the process. Memory store should contain only - // b3, while disk store should contain b1, b2 and b4. - val result4 = memoryStore.putIterator("b4", bigIterator, memAndDisk, returnValues = true) - assert(result4.size > 0) - assert(result4.data.isRight) // unroll returned bytes from disk + // Unroll huge block with not enough space. This should fail and return an iterator so that + // the block may be stored to disk. During the unrolling process, block "b2" should be kicked + // out, so the memory store should contain only b3, while the disk store should contain + // b1, b2 and b4. + val result4 = memoryStore.putIteratorAsValues("b4", bigIterator, ClassTag.Any) + assert(result4.isLeft) assert(!memoryStore.contains("b1")) assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(diskStore.contains("b1")) - assert(diskStore.contains("b2")) - assert(!diskStore.contains("b3")) - assert(diskStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } - test("multiple unrolls by the same thread") { + test("read-locked blocks cannot be evicted from memory") { store = makeBlockManager(12000) - val memOnly = StorageLevel.MEMORY_ONLY - val memoryStore = store.memoryStore - val smallList = List.fill(40)(new Array[Byte](100)) - def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisTask === 0) + val arr = new Array[Byte](4000) + // First store a1 and a2, both in memory, and a3, on disk only + store.putSingle("a1", arr, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a2", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a2").isDefined, "a2 was not in store") + // This put should fail because both a1 and a2 should be read-locked: + store.putSingle("a3", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a3").isEmpty, "a3 was in store") + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a2").isDefined, "a2 was not in store") + // Release both pins of block a2: + store.releaseLock("a2") + store.releaseLock("a2") + // Block a1 is the least-recently accessed, so an LRU eviction policy would evict it before + // block a2. However, a1 is still pinned so this put of a3 should evict a2 instead: + store.putSingle("a3", arr, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a2").isEmpty, "a2 was in store") + assert(store.getSingle("a1").isDefined, "a1 was not in store") + assert(store.getSingle("a3").isDefined, "a3 was not in store") + } - // All unroll memory used is released because unrollSafely returned an array - memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) - memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisTask === 0) + test("SPARK-13328: refresh block locations (fetch should fail after hitting a threshold)") { + val mockBlockTransferService = + new MockBlockTransferService(conf.getInt("spark.block.failures.beforeLocationRefresh", 5)) + store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService)) + store.putSingle("item", 999L, StorageLevel.MEMORY_ONLY, tellMaster = true) + intercept[BlockFetchException] { + store.getRemoteBytes("item") + } + } - // Unroll memory is not released because unrollSafely returned an iterator - // that still depends on the underlying vector used in the process - memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB3 > 0) - - // The unroll memory owned by this thread builds on top of its value after the previous unrolls - memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) - - // ... but only to a certain extent (until we run out of free space to grant new unroll memory) - memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask - memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask - memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask - assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) - assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) - assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) + test("SPARK-13328: refresh block locations (fetch should succeed after location refresh)") { + val maxFailuresBeforeLocationRefresh = + conf.getInt("spark.block.failures.beforeLocationRefresh", 5) + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val mockBlockTransferService = + new MockBlockTransferService(maxFailuresBeforeLocationRefresh) + // make sure we have more than maxFailuresBeforeLocationRefresh locations + // so that we have a chance to do location refresh + val blockManagerIds = (0 to maxFailuresBeforeLocationRefresh) + .map { i => BlockManagerId(s"id-$i", s"host-$i", i + 1) } + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockManagerIds) + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, + transferService = Option(mockBlockTransferService)) + val block = store.getRemoteBytes("item") + .asInstanceOf[Option[ByteBuffer]] + assert(block.isDefined) + verify(mockBlockManagerMaster, times(2)).getLocations("item") } - test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - val blockId = BlockId("rdd_3_10") - val result = memoryStore.putBytes(blockId, 13000, () => { - fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") - }) - assert(result.size === 13000) - assert(result.data === null) - assert(result.droppedBlocks === Nil) + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { + var numCalls = 0 + + override def init(blockDataManager: BlockDataManager): Unit = {} + + override def fetchBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener): Unit = { + listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) + } + + override def close(): Unit = {} + + override def hostName: String = { "MockBlockTransferServiceHost" } + + override def port: Int = { 63332 } + + override def uploadBlock( + hostname: String, + port: Int, execId: String, + blockId: BlockId, + blockData: ManagedBuffer, + level: StorageLevel, + classTag: ClassTag[_]): Future[Unit] = { + import scala.concurrent.ExecutionContext.Implicits.global + Future {} + } + + override def fetchBlockSync( + host: String, + port: Int, + execId: String, + blockId: String): ManagedBuffer = { + numCalls += 1 + if (numCalls <= maxFailures) { + throw new RuntimeException("Failing block fetch in the mock block transfer service") + } + super.fetchBlockSync(host, port, execId, blockId) + } } +} - test("put a small ByteBuffer to MemoryStore") { - store = makeBlockManager(12000) - val memoryStore = store.memoryStore - val blockId = BlockId("rdd_3_10") - var bytes: ByteBuffer = null - val result = memoryStore.putBytes(blockId, 10000, () => { - bytes = ByteBuffer.allocate(10000) - bytes - }) - assert(result.size === 10000) - assert(result.data === Right(bytes)) - assert(result.droppedBlocks === Nil) +private object BlockManagerSuite { + + private implicit class BlockManagerTestUtils(store: BlockManager) { + + def dropFromMemoryIfExists( + blockId: BlockId, + data: () => Either[Array[Any], ChunkedByteBuffer]): Unit = { + store.blockInfoManager.lockForWriting(blockId).foreach { info => + val newEffectiveStorageLevel = store.dropFromMemory(blockId, data) + if (newEffectiveStorageLevel.isValid) { + // The block is still present in at least one store, so release the lock + // but don't delete the block info + store.releaseLock(blockId) + } else { + // The block isn't present in any store, so delete the block info so that the + // block can be stored again + store.blockInfoManager.removeBlock(blockId) + } + } + } + + private def wrapGet[T](f: BlockId => Option[T]): BlockId => Option[T] = (blockId: BlockId) => { + val result = f(blockId) + if (result.isDefined) { + store.releaseLock(blockId) + } + result + } + + def hasLocalBlock(blockId: BlockId): Boolean = { + getLocalAndReleaseLock(blockId).isDefined + } + + val getLocalAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.getLocalValues) + val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) + val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) + val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = { + wrapGet(store.getLocalBytes) + } } + } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala index d7ffde1e7864e..06acca3943c20 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockStatusListenerSuite.scala @@ -34,16 +34,14 @@ class BlockStatusListenerSuite extends SparkFunSuite { StreamBlockId(0, 100), StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0))) + diskSize = 100))) // The new block status should be added to the listener val expectedBlock = BlockUIData( StreamBlockId(0, 100), "localhost:10000", StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0 + diskSize = 100 ) val expectedExecutorStreamBlockStatus = Seq( ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) @@ -60,15 +58,13 @@ class BlockStatusListenerSuite extends SparkFunSuite { StreamBlockId(0, 100), StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0))) + diskSize = 100))) val expectedBlock2 = BlockUIData( StreamBlockId(0, 100), "localhost:10001", StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0 + diskSize = 100 ) // Each block manager should contain one block val expectedExecutorStreamBlockStatus2 = Set( @@ -84,8 +80,7 @@ class BlockStatusListenerSuite extends SparkFunSuite { StreamBlockId(0, 100), StorageLevel.NONE, // StorageLevel.NONE means removing it memSize = 0, - diskSize = 0, - externalBlockStoreSize = 0))) + diskSize = 0))) // Only the first block manager contains a block val expectedExecutorStreamBlockStatus3 = Set( ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)), @@ -102,8 +97,7 @@ class BlockStatusListenerSuite extends SparkFunSuite { StreamBlockId(0, 100), StorageLevel.MEMORY_AND_DISK, memSize = 100, - diskSize = 100, - externalBlockStoreSize = 0))) + diskSize = 100))) // The second block manager is removed so we should not see the new block val expectedExecutorStreamBlockStatus4 = Seq( ExecutorStreamBlockStatus("0", "localhost:10000", Seq(expectedBlock)) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 688f56f4665f3..bbfd6df3b6990 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, FileWriter} import scala.language.reflectiveCalls -import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} @@ -33,8 +32,6 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B private var rootDir1: File = _ private var rootDirs: String = _ - val blockManager = mock(classOf[BlockManager]) - when(blockManager.conf).thenReturn(testConf) var diskBlockManager: DiskBlockManager = _ override def beforeAll() { @@ -45,19 +42,27 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B } override def afterAll() { - super.afterAll() - Utils.deleteRecursively(rootDir0) - Utils.deleteRecursively(rootDir1) + try { + Utils.deleteRecursively(rootDir0) + Utils.deleteRecursively(rootDir1) + } finally { + super.afterAll() + } } override def beforeEach() { + super.beforeEach() val conf = testConf.clone conf.set("spark.local.dir", rootDirs) - diskBlockManager = new DiskBlockManager(blockManager, conf) + diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) } override def afterEach() { - diskBlockManager.stop() + try { + diskBlockManager.stop() + } finally { + super.afterEach() + } } test("basic block creation") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 7c19531c18802..8eff3c297035d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -30,11 +30,16 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { var tempDir: File = _ override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() } override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } } test("verify write metrics") { @@ -45,18 +50,18 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write - assert(writeMetrics.shuffleRecordsWritten === 1) + assert(writeMetrics.recordsWritten === 1) // Metrics don't update on every write - assert(writeMetrics.shuffleBytesWritten == 0) + assert(writeMetrics.bytesWritten == 0) // After 32 writes, metrics should update for (i <- 0 until 32) { writer.flush() writer.write(Long.box(i), Long.box(i)) } - assert(writeMetrics.shuffleBytesWritten > 0) - assert(writeMetrics.shuffleRecordsWritten === 33) + assert(writeMetrics.bytesWritten > 0) + assert(writeMetrics.recordsWritten === 33) writer.commitAndClose() - assert(file.length() == writeMetrics.shuffleBytesWritten) + assert(file.length() == writeMetrics.bytesWritten) } test("verify write metrics on revert") { @@ -67,19 +72,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write - assert(writeMetrics.shuffleRecordsWritten === 1) + assert(writeMetrics.recordsWritten === 1) // Metrics don't update on every write - assert(writeMetrics.shuffleBytesWritten == 0) + assert(writeMetrics.bytesWritten == 0) // After 32 writes, metrics should update for (i <- 0 until 32) { writer.flush() writer.write(Long.box(i), Long.box(i)) } - assert(writeMetrics.shuffleBytesWritten > 0) - assert(writeMetrics.shuffleRecordsWritten === 33) + assert(writeMetrics.bytesWritten > 0) + assert(writeMetrics.recordsWritten === 33) writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleBytesWritten == 0) - assert(writeMetrics.shuffleRecordsWritten == 0) + assert(writeMetrics.bytesWritten == 0) + assert(writeMetrics.recordsWritten == 0) } test("Reopening a closed block writer") { @@ -104,11 +109,11 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(i, i) } writer.commitAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - assert(writeMetrics.shuffleRecordsWritten === 1000) + val bytesWritten = writeMetrics.bytesWritten + assert(writeMetrics.recordsWritten === 1000) writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleRecordsWritten === 1000) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.recordsWritten === 1000) + assert(writeMetrics.bytesWritten === bytesWritten) } test("commitAndClose() should be idempotent") { @@ -120,13 +125,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(i, i) } writer.commitAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - val writeTime = writeMetrics.shuffleWriteTime - assert(writeMetrics.shuffleRecordsWritten === 1000) + val bytesWritten = writeMetrics.bytesWritten + val writeTime = writeMetrics.writeTime + assert(writeMetrics.recordsWritten === 1000) writer.commitAndClose() - assert(writeMetrics.shuffleRecordsWritten === 1000) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) - assert(writeMetrics.shuffleWriteTime === writeTime) + assert(writeMetrics.recordsWritten === 1000) + assert(writeMetrics.bytesWritten === bytesWritten) + assert(writeMetrics.writeTime === writeTime) } test("revertPartialWritesAndClose() should be idempotent") { @@ -138,13 +143,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { writer.write(i, i) } writer.revertPartialWritesAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - val writeTime = writeMetrics.shuffleWriteTime - assert(writeMetrics.shuffleRecordsWritten === 0) + val bytesWritten = writeMetrics.bytesWritten + val writeTime = writeMetrics.writeTime + assert(writeMetrics.recordsWritten === 0) writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleRecordsWritten === 0) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) - assert(writeMetrics.shuffleWriteTime === writeTime) + assert(writeMetrics.recordsWritten === 0) + assert(writeMetrics.bytesWritten === bytesWritten) + assert(writeMetrics.writeTime === writeTime) } test("fileSegment() can only be called after commitAndClose() has been called") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala new file mode 100644 index 0000000000000..9ed5016510d56 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.{ByteBuffer, MappedByteBuffer} +import java.util.Arrays + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.io.ChunkedByteBuffer + +class DiskStoreSuite extends SparkFunSuite { + + test("reads of memory-mapped and non memory-mapped files are equivalent") { + val confKey = "spark.storage.memoryMapThreshold" + + // Create a non-trivial (not all zeros) byte array + val bytes = Array.tabulate[Byte](1000)(_.toByte) + val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes)) + + val blockId = BlockId("rdd_1_2") + val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true) + + val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager) + diskStoreMapped.putBytes(blockId, byteBuffer) + val mapped = diskStoreMapped.getBytes(blockId) + assert(diskStoreMapped.remove(blockId)) + + val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager) + diskStoreNotMapped.putBytes(blockId, byteBuffer) + val notMapped = diskStoreNotMapped.getBytes(blockId) + + // Not possible to do isInstanceOf due to visibility of HeapByteBuffer + assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")), + "Expected HeapByteBuffer for un-mapped read") + assert(mapped.getChunks().forall(_.isInstanceOf[MappedByteBuffer]), + "Expected MappedByteBuffer for mapped read") + + def arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = { + val array = new Array[Byte](in.remaining()) + in.get(array) + array + } + + assert(Arrays.equals(mapped.toArray, bytes)) + assert(Arrays.equals(notMapped.toArray, bytes)) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index cc50289c7b3ea..c7074078d8fd2 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.storage import java.io.File -import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.SparkConfWithEnv +import org.apache.spark.util.{SparkConfWithEnv, Utils} /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala new file mode 100644 index 0000000000000..145d432afe85e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.language.implicitConversions +import scala.language.postfixOps +import scala.language.reflectiveCalls +import scala.reflect.ClassTag + +import org.scalatest._ + +import org.apache.spark._ +import org.apache.spark.memory.{MemoryMode, StaticMemoryManager} +import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore, PartiallySerializedBlock, PartiallyUnrolledIterator} +import org.apache.spark.util._ +import org.apache.spark.util.io.ChunkedByteBuffer + +class MemoryStoreSuite + extends SparkFunSuite + with PrivateMethodTester + with BeforeAndAfterEach + with ResetSystemProperties { + + var conf: SparkConf = new SparkConf(false) + .set("spark.test.useCompressedOops", "true") + .set("spark.storage.unrollFraction", "0.4") + .set("spark.storage.unrollMemoryThreshold", "512") + + // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + val serializer = new KryoSerializer(new SparkConf(false).set("spark.kryoserializer.buffer", "1m")) + + val serializerManager = new SerializerManager(serializer, conf) + + // Implicitly convert strings to BlockIds for test clarity. + implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + def rdd(rddId: Int, splitId: Int): RDDBlockId = RDDBlockId(rddId, splitId) + + override def beforeEach(): Unit = { + super.beforeEach() + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + System.setProperty("os.arch", "amd64") + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + } + + def makeMemoryStore(maxMem: Long): (MemoryStore, BlockInfoManager) = { + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) + val blockInfoManager = new BlockInfoManager + val blockEvictionHandler = new BlockEvictionHandler { + var memoryStore: MemoryStore = _ + override private[storage] def dropFromMemory[T: ClassTag]( + blockId: BlockId, + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { + memoryStore.remove(blockId) + StorageLevel.NONE + } + } + val memoryStore = + new MemoryStore(conf, blockInfoManager, serializerManager, memManager, blockEvictionHandler) + memManager.setMemoryStore(memoryStore) + blockEvictionHandler.memoryStore = memoryStore + (memoryStore, blockInfoManager) + } + + test("reserve/release unroll memory") { + val (memoryStore, _) = makeMemoryStore(12000) + assert(memoryStore.currentUnrollMemory === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { + memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory, MemoryMode.ON_HEAP) + } + + // Reserve + assert(reserveUnrollMemoryForThisTask(100)) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + assert(reserveUnrollMemoryForThisTask(200)) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + assert(reserveUnrollMemoryForThisTask(500)) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + assert(!reserveUnrollMemoryForThisTask(1000000)) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted + // Release + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) + // Reserve again + assert(reserveUnrollMemoryForThisTask(4400)) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + assert(!reserveUnrollMemoryForThisTask(20000)) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted + // Release again + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + } + + test("safely unroll blocks") { + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + val ct = implicitly[ClassTag[Array[Byte]]] + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIteratorAsValues[T]( + blockId: BlockId, + iter: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + assert(blockInfoManager.lockNewBlockForWriting( + blockId, + new BlockInfo(StorageLevel.MEMORY_ONLY, classTag, tellMaster = false))) + val res = memoryStore.putIteratorAsValues(blockId, iter, classTag) + blockInfoManager.unlock(blockId) + res + } + + // Unroll with all the space in the world. This should succeed. + var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any) + assert(putResult.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => + assert(e === a, "getValues() did not return original values!") + } + blockInfoManager.lockForWriting("unroll") + assert(memoryStore.remove("unroll")) + blockInfoManager.removeBlock("unroll") + + // Unroll with not enough space. This should succeed after kicking out someBlock1. + assert(putIteratorAsValues("someBlock1", smallList.iterator, ct).isRight) + assert(putIteratorAsValues("someBlock2", smallList.iterator, ct).isRight) + putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any) + assert(putResult.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + assert(memoryStore.contains("someBlock2")) + assert(!memoryStore.contains("someBlock1")) + smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => + assert(e === a, "getValues() did not return original values!") + } + blockInfoManager.lockForWriting("unroll") + assert(memoryStore.remove("unroll")) + blockInfoManager.removeBlock("unroll") + + // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = + // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. + // In the meantime, however, we kicked out someBlock2 before giving up. + assert(putIteratorAsValues("someBlock3", smallList.iterator, ct).isRight) + putResult = putIteratorAsValues("unroll", bigList.iterator, ClassTag.Any) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator + assert(!memoryStore.contains("someBlock2")) + assert(putResult.isLeft) + bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => + assert(e === a, "putIterator() did not return original values!") + } + // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + } + + test("safely unroll blocks through putIteratorAsValues") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIteratorAsValues[T]( + blockId: BlockId, + iter: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + assert(blockInfoManager.lockNewBlockForWriting( + blockId, + new BlockInfo(StorageLevel.MEMORY_ONLY, classTag, tellMaster = false))) + val res = memoryStore.putIteratorAsValues(blockId, iter, classTag) + blockInfoManager.unlock(blockId) + res + } + + // Unroll with plenty of space. This should succeed and cache both blocks. + val result1 = putIteratorAsValues("b1", smallIterator, ClassTag.Any) + val result2 = putIteratorAsValues("b2", smallIterator, ClassTag.Any) + assert(memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(result1.isRight) // unroll was successful + assert(result2.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + // Re-put these two blocks so block manager knows about them too. Otherwise, block manager + // would not know how to drop them from memory later. + blockInfoManager.lockForWriting("b1") + memoryStore.remove("b1") + blockInfoManager.removeBlock("b1") + blockInfoManager.lockForWriting("b2") + memoryStore.remove("b2") + blockInfoManager.removeBlock("b2") + putIteratorAsValues("b1", smallIterator, ClassTag.Any) + putIteratorAsValues("b2", smallIterator, ClassTag.Any) + + // Unroll with not enough space. This should succeed but kick out b1 in the process. + val result3 = putIteratorAsValues("b3", smallIterator, ClassTag.Any) + assert(result3.isRight) + assert(!memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + blockInfoManager.lockForWriting("b3") + assert(memoryStore.remove("b3")) + blockInfoManager.removeBlock("b3") + putIteratorAsValues("b3", smallIterator, ClassTag.Any) + + // Unroll huge block with not enough space. This should fail and kick out b2 in the process. + val result4 = putIteratorAsValues("b4", bigIterator, ClassTag.Any) + assert(result4.isLeft) // unroll was unsuccessful + assert(!memoryStore.contains("b1")) + assert(!memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(!memoryStore.contains("b4")) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator + result4.left.get.close() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // close released the unroll memory + } + + test("safely unroll blocks through putIteratorAsBytes") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val smallList = List.fill(40)(new Array[Byte](100)) + val bigList = List.fill(40)(new Array[Byte](1000)) + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIteratorAsBytes[T]( + blockId: BlockId, + iter: Iterator[T], + classTag: ClassTag[T]): Either[PartiallySerializedBlock[T], Long] = { + assert(blockInfoManager.lockNewBlockForWriting( + blockId, + new BlockInfo(StorageLevel.MEMORY_ONLY_SER, classTag, tellMaster = false))) + val res = memoryStore.putIteratorAsBytes(blockId, iter, classTag, MemoryMode.ON_HEAP) + blockInfoManager.unlock(blockId) + res + } + + // Unroll with plenty of space. This should succeed and cache both blocks. + val result1 = putIteratorAsBytes("b1", smallIterator, ClassTag.Any) + val result2 = putIteratorAsBytes("b2", smallIterator, ClassTag.Any) + assert(memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(result1.isRight) // unroll was successful + assert(result2.isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + // Re-put these two blocks so block manager knows about them too. Otherwise, block manager + // would not know how to drop them from memory later. + blockInfoManager.lockForWriting("b1") + memoryStore.remove("b1") + blockInfoManager.removeBlock("b1") + blockInfoManager.lockForWriting("b2") + memoryStore.remove("b2") + blockInfoManager.removeBlock("b2") + putIteratorAsBytes("b1", smallIterator, ClassTag.Any) + putIteratorAsBytes("b2", smallIterator, ClassTag.Any) + + // Unroll with not enough space. This should succeed but kick out b1 in the process. + val result3 = putIteratorAsBytes("b3", smallIterator, ClassTag.Any) + assert(result3.isRight) + assert(!memoryStore.contains("b1")) + assert(memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + blockInfoManager.lockForWriting("b3") + assert(memoryStore.remove("b3")) + blockInfoManager.removeBlock("b3") + putIteratorAsBytes("b3", smallIterator, ClassTag.Any) + + // Unroll huge block with not enough space. This should fail and kick out b2 in the process. + val result4 = putIteratorAsBytes("b4", bigIterator, ClassTag.Any) + assert(result4.isLeft) // unroll was unsuccessful + assert(!memoryStore.contains("b1")) + assert(!memoryStore.contains("b2")) + assert(memoryStore.contains("b3")) + assert(!memoryStore.contains("b4")) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator + result4.left.get.discard() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // discard released the unroll memory + } + + test("PartiallySerializedBlock.valuesIterator") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val bigList = List.fill(40)(new Array[Byte](1000)) + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] + + // Unroll huge block with not enough space. This should fail. + assert(blockInfoManager.lockNewBlockForWriting( + "b1", + new BlockInfo(StorageLevel.MEMORY_ONLY_SER, ClassTag.Any, tellMaster = false))) + val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any, MemoryMode.ON_HEAP) + blockInfoManager.unlock("b1") + assert(res.isLeft) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) + val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization + valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) => + assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!") + } + // The unroll memory was freed once the iterator was fully traversed. + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + } + + test("PartiallySerializedBlock.finishWritingToStream") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val bigList = List.fill(40)(new Array[Byte](1000)) + def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] + + // Unroll huge block with not enough space. This should fail. + assert(blockInfoManager.lockNewBlockForWriting( + "b1", + new BlockInfo(StorageLevel.MEMORY_ONLY_SER, ClassTag.Any, tellMaster = false))) + val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any, MemoryMode.ON_HEAP) + blockInfoManager.unlock("b1") + assert(res.isLeft) + assert(memoryStore.currentUnrollMemoryForThisTask > 0) + val bos = new ByteBufferOutputStream() + res.left.get.finishWritingToStream(bos) + // The unroll memory was freed once the block was fully written. + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + val deserializationStream = serializerManager.dataDeserializeStream[Any]( + "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any) + deserializationStream.zip(bigList.iterator).foreach { case (e, a) => + assert(e === a, + "PartiallySerializedBlock.finishWritingtoStream() did not write original values!") + } + } + + test("multiple unrolls by the same thread") { + val (memoryStore, _) = makeMemoryStore(12000) + val smallList = List.fill(40)(new Array[Byte](100)) + def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + def putIteratorAsValues( + blockId: BlockId, + iter: Iterator[Any]): Either[PartiallyUnrolledIterator[Any], Long] = { + memoryStore.putIteratorAsValues(blockId, iter, ClassTag.Any) + } + + // All unroll memory used is released because putIterator did not return an iterator + assert(putIteratorAsValues("b1", smallIterator).isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + assert(putIteratorAsValues("b2", smallIterator).isRight) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + + // Unroll memory is not released because putIterator returned an iterator + // that still depends on the underlying vector used in the process + assert(putIteratorAsValues("b3", smallIterator).isLeft) + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB3 > 0) + + // The unroll memory owned by this thread builds on top of its value after the previous unrolls + assert(putIteratorAsValues("b4", smallIterator).isLeft) + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) + + // ... but only to a certain extent (until we run out of free space to grant new unroll memory) + assert(putIteratorAsValues("b5", smallIterator).isLeft) + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask + assert(putIteratorAsValues("b6", smallIterator).isLeft) + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask + assert(putIteratorAsValues("b7", smallIterator).isLeft) + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask + assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) + assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) + assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) + } + + test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { + val (memoryStore, blockInfoManager) = makeMemoryStore(12000) + val blockId = BlockId("rdd_3_10") + blockInfoManager.lockNewBlockForWriting( + blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)) + memoryStore.putBytes(blockId, 13000, MemoryMode.ON_HEAP, () => { + fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") + }) + } + + test("put a small ByteBuffer to MemoryStore") { + val (memoryStore, _) = makeMemoryStore(12000) + val blockId = BlockId("rdd_3_10") + var bytes: ChunkedByteBuffer = null + memoryStore.putBytes(blockId, 10000, MemoryMode.ON_HEAP, () => { + bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000)) + bytes + }) + assert(memoryStore.getSize(blockId) === 10000) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 828153bdbfc44..e3ec99685f73c 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,7 +21,7 @@ import java.io.InputStream import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.future +import scala.concurrent.Future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ @@ -99,7 +99,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - 48 * 1024 * 1024) + 48 * 1024 * 1024, + Int.MaxValue) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -149,7 +150,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - future { + Future { // Return the first two blocks, and wait till task completion before returning the 3rd one listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) @@ -171,7 +172,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - 48 * 1024 * 1024) + 48 * 1024 * 1024, + Int.MaxValue) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -211,7 +213,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - future { + Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) @@ -233,7 +235,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, - 48 * 1024 * 1024) + 48 * 1024 * 1024, + Int.MaxValue) // Continue only after the mock calls onBlockFetchFailure sem.acquire() diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 1a199beb3558f..9835f11a2f7ed 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import org.apache.spark.{SparkFunSuite, Success} +import org.apache.spark.{SparkConf, SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -29,9 +29,11 @@ class StorageStatusListenerSuite extends SparkFunSuite { private val bm2 = BlockManagerId("fat", "duck", 2) private val taskInfo1 = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false) private val taskInfo2 = new TaskInfo(0, 0, 0, 0, "fat", "duck", TaskLocality.ANY, false) + private val conf = new SparkConf() test("block manager added/removed") { - val listener = new StorageStatusListener + conf.set("spark.ui.retainedDeadExecutors", "1") + val listener = new StorageStatusListener(conf) // Block manager add assert(listener.executorIdToStorageStatus.size === 0) @@ -53,14 +55,18 @@ class StorageStatusListenerSuite extends SparkFunSuite { assert(listener.executorIdToStorageStatus.size === 1) assert(!listener.executorIdToStorageStatus.get("big").isDefined) assert(listener.executorIdToStorageStatus.get("fat").isDefined) + assert(listener.deadExecutorStorageStatus.size === 1) + assert(listener.deadExecutorStorageStatus(0).blockManagerId.executorId.equals("big")) listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm2)) assert(listener.executorIdToStorageStatus.size === 0) assert(!listener.executorIdToStorageStatus.get("big").isDefined) assert(!listener.executorIdToStorageStatus.get("fat").isDefined) + assert(listener.deadExecutorStorageStatus.size === 1) + assert(listener.deadExecutorStorageStatus(0).blockManagerId.executorId.equals("fat")) } test("task end without updated blocks") { - val listener = new StorageStatusListener + val listener = new StorageStatusListener(conf) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) val taskMetrics = new TaskMetrics @@ -76,48 +82,51 @@ class StorageStatusListenerSuite extends SparkFunSuite { assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) } - test("task end with updated blocks") { - val listener = new StorageStatusListener + test("updated blocks") { + val listener = new StorageStatusListener(conf) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) - val taskMetrics1 = new TaskMetrics - val taskMetrics2 = new TaskMetrics - val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L)) - val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L, 0L)) - val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L, 0L)) - taskMetrics1.updatedBlocks = Some(Seq(block1, block2)) - taskMetrics2.updatedBlocks = Some(Seq(block3)) - - // Task end with new blocks + + val blockUpdateInfos1 = Seq( + BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L), + BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L) + ) + val blockUpdateInfos2 = + Seq(BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L)) + + // Add some new blocks assert(listener.executorIdToStorageStatus("big").numBlocks === 0) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) + postUpdateBlock(listener, blockUpdateInfos1) assert(listener.executorIdToStorageStatus("big").numBlocks === 2) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2)) + postUpdateBlock(listener, blockUpdateInfos2) assert(listener.executorIdToStorageStatus("big").numBlocks === 2) assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) - // Task end with dropped blocks - val droppedBlock1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) - val droppedBlock2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) - val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)) - taskMetrics1.updatedBlocks = Some(Seq(droppedBlock1, droppedBlock3)) - taskMetrics2.updatedBlocks = Some(Seq(droppedBlock2, droppedBlock3)) + // Dropped the blocks + val droppedBlockInfo1 = Seq( + BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.NONE, 0L, 0L), + BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L) + ) + val droppedBlockInfo2 = Seq( + BlockUpdatedInfo(bm2, RDDBlockId(1, 2), StorageLevel.NONE, 0L, 0L), + BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L) + ) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) + postUpdateBlock(listener, droppedBlockInfo1) assert(listener.executorIdToStorageStatus("big").numBlocks === 1) assert(listener.executorIdToStorageStatus("fat").numBlocks === 1) assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2))) assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0))) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2)) + postUpdateBlock(listener, droppedBlockInfo2) assert(listener.executorIdToStorageStatus("big").numBlocks === 1) assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1))) @@ -126,17 +135,16 @@ class StorageStatusListenerSuite extends SparkFunSuite { } test("unpersist RDD") { - val listener = new StorageStatusListener + val listener = new StorageStatusListener(conf) listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) - val taskMetrics1 = new TaskMetrics - val taskMetrics2 = new TaskMetrics - val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L)) - val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L, 0L)) - val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L, 0L)) - taskMetrics1.updatedBlocks = Some(Seq(block1, block2)) - taskMetrics2.updatedBlocks = Some(Seq(block3)) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1)) - listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics2)) + val blockUpdateInfos1 = Seq( + BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L), + BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L) + ) + val blockUpdateInfos2 = + Seq(BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L)) + postUpdateBlock(listener, blockUpdateInfos1) + postUpdateBlock(listener, blockUpdateInfos2) assert(listener.executorIdToStorageStatus("big").numBlocks === 3) // Unpersist RDD @@ -149,4 +157,11 @@ class StorageStatusListenerSuite extends SparkFunSuite { listener.onUnpersistRDD(SparkListenerUnpersistRDD(1)) assert(listener.executorIdToStorageStatus("big").numBlocks === 0) } + + private def postUpdateBlock( + listener: StorageStatusListener, updateBlockInfos: Seq[BlockUpdatedInfo]): Unit = { + updateBlockInfos.foreach { updateBlockInfo => + listener.onBlockUpdated(SparkListenerBlockUpdated(updateBlockInfo)) + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index 1d5a813a4d336..e5733aebf607c 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -33,10 +33,9 @@ class StorageSuite extends SparkFunSuite { assert(status.memUsed === 0L) assert(status.memRemaining === 1000L) assert(status.diskUsed === 0L) - assert(status.offHeapUsed === 0L) - status.addBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 10L, 20L, 1L)) - status.addBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 10L, 20L, 1L)) - status.addBlock(TestBlockId("faa"), BlockStatus(memAndDisk, 10L, 20L, 1L)) + status.addBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("faa"), BlockStatus(memAndDisk, 10L, 20L)) status } @@ -50,18 +49,16 @@ class StorageSuite extends SparkFunSuite { assert(status.memUsed === 30L) assert(status.memRemaining === 970L) assert(status.diskUsed === 60L) - assert(status.offHeapUsed === 3L) } test("storage status update non-RDD blocks") { val status = storageStatus1 - status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L, 1L)) - status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L, 0L)) + status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L)) + status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L)) assert(status.blocks.size === 3) assert(status.memUsed === 160L) assert(status.memRemaining === 840L) assert(status.diskUsed === 140L) - assert(status.offHeapUsed === 2L) } test("storage status remove non-RDD blocks") { @@ -73,20 +70,19 @@ class StorageSuite extends SparkFunSuite { assert(status.memUsed === 10L) assert(status.memRemaining === 990L) assert(status.diskUsed === 20L) - assert(status.offHeapUsed === 1L) } // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) assert(status.rddBlocks.isEmpty) - status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L, 0L)) - status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L, 0L)) - status.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 10L, 20L, 1L)) - status.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 100L, 200L, 1L)) - status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L, 1L)) - status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L, 0L)) - status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L, 0L)) + status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 100L, 200L)) + status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L)) status } @@ -113,9 +109,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsedByRdd(0) === 20L) assert(status.diskUsedByRdd(1) === 200L) assert(status.diskUsedByRdd(2) === 80L) - assert(status.offHeapUsedByRdd(0) === 1L) - assert(status.offHeapUsedByRdd(1) === 1L) - assert(status.offHeapUsedByRdd(2) === 1L) assert(status.rddStorageLevel(0) === Some(memAndDisk)) assert(status.rddStorageLevel(1) === Some(memAndDisk)) assert(status.rddStorageLevel(2) === Some(memAndDisk)) @@ -124,15 +117,14 @@ class StorageSuite extends SparkFunSuite { assert(status.rddBlocksById(10).isEmpty) assert(status.memUsedByRdd(10) === 0L) assert(status.diskUsedByRdd(10) === 0L) - assert(status.offHeapUsedByRdd(10) === 0L) assert(status.rddStorageLevel(10) === None) } test("storage status update RDD blocks") { val status = storageStatus2 - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L, 0L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L, 0L)) - status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L, 0L)) + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L)) + status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L)) assert(status.blocks.size === 7) assert(status.rddBlocks.size === 5) assert(status.rddBlocksById(0).size === 1) @@ -144,9 +136,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsedByRdd(0) === 0L) assert(status.diskUsedByRdd(1) === 200L) assert(status.diskUsedByRdd(2) === 1060L) - assert(status.offHeapUsedByRdd(0) === 0L) - assert(status.offHeapUsedByRdd(1) === 1L) - assert(status.offHeapUsedByRdd(2) === 0L) } test("storage status remove RDD blocks") { @@ -170,9 +159,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsedByRdd(0) === 20L) assert(status.diskUsedByRdd(1) === 0L) assert(status.diskUsedByRdd(2) === 20L) - assert(status.offHeapUsedByRdd(0) === 1L) - assert(status.offHeapUsedByRdd(1) === 0L) - assert(status.offHeapUsedByRdd(2) === 0L) } test("storage status containsBlock") { @@ -209,17 +195,17 @@ class StorageSuite extends SparkFunSuite { val status = storageStatus2 assert(status.blocks.size === status.numBlocks) assert(status.rddBlocks.size === status.numRddBlocks) - status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L, 100L)) - status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L, 100L)) - status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L, 100L)) + status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L)) + status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L)) + status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) assert(status.blocks.size === status.numBlocks) assert(status.rddBlocks.size === status.numRddBlocks) assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L, 400L)) - status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L, 100L)) - status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L, 100L)) - status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L, 100L)) + status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L)) + status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L)) + status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) + status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L)) assert(status.blocks.size === status.numBlocks) assert(status.rddBlocks.size === status.numRddBlocks) assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) @@ -244,29 +230,24 @@ class StorageSuite extends SparkFunSuite { val status = storageStatus2 def actualMemUsed: Long = status.blocks.values.map(_.memSize).sum def actualDiskUsed: Long = status.blocks.values.map(_.diskSize).sum - def actualOffHeapUsed: Long = status.blocks.values.map(_.externalBlockStoreSize).sum assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - assert(status.offHeapUsed === actualOffHeapUsed) - status.addBlock(TestBlockId("fire"), BlockStatus(memAndDisk, 4000L, 5000L, 6000L)) - status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L, 600L)) - status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L, 60L)) + status.addBlock(TestBlockId("fire"), BlockStatus(memAndDisk, 4000L, 5000L)) + status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L)) + status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - assert(status.offHeapUsed === actualOffHeapUsed) - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L, 6L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L, 6L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L, 6L)) + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L)) + status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - assert(status.offHeapUsed === actualOffHeapUsed) status.removeBlock(TestBlockId("fire")) status.removeBlock(TestBlockId("man")) status.removeBlock(RDDBlockId(2, 2)) status.removeBlock(RDDBlockId(2, 3)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - assert(status.offHeapUsed === actualOffHeapUsed) } // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations @@ -274,14 +255,14 @@ class StorageSuite extends SparkFunSuite { val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L) val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L) - status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status2.addBlock(RDDBlockId(0, 3), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status2.addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status2.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status3.addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L, 0L)) - status3.addBlock(RDDBlockId(1, 2), BlockStatus(memAndDisk, 1L, 2L, 0L)) + status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) + status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L)) + status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L)) + status2.addBlock(RDDBlockId(0, 3), BlockStatus(memAndDisk, 1L, 2L)) + status2.addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) + status2.addBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 1L, 2L)) + status3.addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) + status3.addBlock(RDDBlockId(1, 2), BlockStatus(memAndDisk, 1L, 2L)) Seq(status1, status2, status3) } @@ -334,9 +315,9 @@ class StorageSuite extends SparkFunSuite { test("StorageUtils.getRddBlockLocations with multiple locations") { val storageStatuses = stockStorageStatuses - storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) - storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L, 0L)) - storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L, 0L)) + storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) + storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) + storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) assert(blockLocations0.size === 5) diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala index cc76c141c53cc..74eeca282882a 100644 --- a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala @@ -64,7 +64,13 @@ class PagedTableSuite extends SparkFunSuite { override def row(t: Int): Seq[Node] = Nil - override def goButtonJavascriptFunction: (String, String) = ("", "") + override def pageSizeFormField: String = "pageSize" + + override def prevPageSizeFormField: String = "prevPageSize" + + override def pageNumberFormField: String = "page" + + override def goButtonFormPath: String = "" } assert(pagedTable.pageNavigation(1, 10, 1) === Nil) diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 86699e7f56953..b83ffa3282e4d 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -31,6 +31,8 @@ import org.apache.spark.ui.scope.RDDOperationGraphListener class StagePageSuite extends SparkFunSuite with LocalSparkContext { + private val peakExecutionMemory = 10 + test("peak execution memory only displayed if unsafe is enabled") { val unsafeConf = "spark.sql.unsafe.enabled" val conf = new SparkConf(false).set(unsafeConf, "true") @@ -52,7 +54,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { val conf = new SparkConf(false).set(unsafeConf, "true") val html = renderStagePage(conf).toString().toLowerCase // verify min/25/50/75/max show task value not cumulative values - assert(html.contains("10.0 b" * 5)) + assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } /** @@ -79,14 +81,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { (1 to 2).foreach { taskId => val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) - val peakExecutionMemory = 10 - taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY, - Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true) jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) taskInfo.markSuccessful() + val taskMetrics = TaskMetrics.empty + taskMetrics.incPeakExecutionMemory(peakExecutionMemory) jobListener.onTaskEnd( - SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics)) } jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) page.render(request) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 18eec7da9763e..b0a35fe8c3319 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} -import javax.servlet.http.{HttpServletResponse, HttpServletRequest} +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.io.Source import scala.xml.Node @@ -26,16 +26,16 @@ import scala.xml.Node import com.gargoylesoftware.htmlunit.DefaultCssErrorHandler import org.json4s._ import org.json4s.jackson.JsonMethods -import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.openqa.selenium.{By, WebDriver} +import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ import org.w3c.css.sac.CSSParseException -import org.apache.spark.LocalSparkContext._ import org.apache.spark._ +import org.apache.spark.LocalSparkContext._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.shuffle.FetchFailedException @@ -76,14 +76,19 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B override def beforeAll(): Unit = { + super.beforeAll() webDriver = new HtmlUnitDriver { getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) } } override def afterAll(): Unit = { - if (webDriver != null) { - webDriver.quit() + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() } } @@ -284,7 +289,11 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B JInt(stageId) <- stage \ "stageId" JInt(attemptId) <- stage \ "attemptId" } { - val exp = if (attemptId == 0 && stageId == 1) StageStatus.FAILED else StageStatus.COMPLETE + val exp = if (attemptId.toInt == 0 && stageId.toInt == 1) { + StageStatus.FAILED + } else { + StageStatus.COMPLETE + } status should be (exp.name()) } @@ -615,29 +624,29 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label="Stage 0";\n subgraph ")) assert(stage0.contains("{\n label="parallelize";\n " + - "0 [label="ParallelCollectionRDD [0]"];\n }")) + "0 [label="ParallelCollectionRDD [0]")) assert(stage0.contains("{\n label="map";\n " + - "1 [label="MapPartitionsRDD [1]"];\n }")) + "1 [label="MapPartitionsRDD [1]")) assert(stage0.contains("{\n label="groupBy";\n " + - "2 [label="MapPartitionsRDD [2]"];\n }")) + "2 [label="MapPartitionsRDD [2]")) val stage1 = Source.fromURL(sc.ui.get.appUIAddress + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label="Stage 1";\n subgraph ")) assert(stage1.contains("{\n label="groupBy";\n " + - "3 [label="ShuffledRDD [3]"];\n }")) + "3 [label="ShuffledRDD [3]")) assert(stage1.contains("{\n label="map";\n " + - "4 [label="MapPartitionsRDD [4]"];\n }")) + "4 [label="MapPartitionsRDD [4]")) assert(stage1.contains("{\n label="groupBy";\n " + - "5 [label="MapPartitionsRDD [5]"];\n }")) + "5 [label="MapPartitionsRDD [5]")) val stage2 = Source.fromURL(sc.ui.get.appUIAddress + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label="Stage 2";\n subgraph ")) assert(stage2.contains("{\n label="groupBy";\n " + - "6 [label="ShuffledRDD [6]"];\n }")) + "6 [label="ShuffledRDD [6]")) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 8f9502b5673d1..2b59b48d8bc98 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.ui -import java.net.ServerSocket +import java.net.{BindException, ServerSocket} import scala.io.Source -import scala.util.{Failure, Success, Try} import org.eclipse.jetty.servlet.ServletContextHandler import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark._ import org.apache.spark.LocalSparkContext._ -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class UISuite extends SparkFunSuite { @@ -45,6 +44,20 @@ class UISuite extends SparkFunSuite { sc } + private def sslDisabledConf(): (SparkConf, SSLOptions) = { + val conf = new SparkConf + (conf, new SecurityManager(conf).getSSLOptions("ui")) + } + + private def sslEnabledConf(): (SparkConf, SSLOptions) = { + val conf = new SparkConf() + .set("spark.ssl.ui.enabled", "true") + .set("spark.ssl.ui.keyStore", "./src/test/resources/spark.keystore") + .set("spark.ssl.ui.keyStorePassword", "123456") + .set("spark.ssl.ui.keyPassword", "123456") + (conf, new SecurityManager(conf).getSSLOptions("ui")) + } + ignore("basic ui visibility") { withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible @@ -70,33 +83,92 @@ class UISuite extends SparkFunSuite { } test("jetty selects different port under contention") { - val server = new ServerSocket(0) - val startPort = server.getLocalPort - val serverInfo1 = JettyUtils.startJettyServer( - "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf) - val serverInfo2 = JettyUtils.startJettyServer( - "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf) - // Allow some wiggle room in case ports on the machine are under contention - val boundPort1 = serverInfo1.boundPort - val boundPort2 = serverInfo2.boundPort - assert(boundPort1 != startPort) - assert(boundPort2 != startPort) - assert(boundPort1 != boundPort2) - serverInfo1.server.stop() - serverInfo2.server.stop() - server.close() + var server: ServerSocket = null + var serverInfo1: ServerInfo = null + var serverInfo2: ServerInfo = null + val (conf, sslOptions) = sslDisabledConf() + try { + server = new ServerSocket(0) + val startPort = server.getLocalPort + serverInfo1 = JettyUtils.startJettyServer( + "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf) + serverInfo2 = JettyUtils.startJettyServer( + "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf) + // Allow some wiggle room in case ports on the machine are under contention + val boundPort1 = serverInfo1.boundPort + val boundPort2 = serverInfo2.boundPort + assert(boundPort1 != startPort) + assert(boundPort2 != startPort) + assert(boundPort1 != boundPort2) + } finally { + stopServer(serverInfo1) + stopServer(serverInfo2) + closeSocket(server) + } + } + + test("jetty with https selects different port under contention") { + var server: ServerSocket = null + var serverInfo1: ServerInfo = null + var serverInfo2: ServerInfo = null + try { + server = new ServerSocket(0) + val startPort = server.getLocalPort + val (conf, sslOptions) = sslEnabledConf() + serverInfo1 = JettyUtils.startJettyServer( + "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf, "server1") + serverInfo2 = JettyUtils.startJettyServer( + "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf, "server2") + // Allow some wiggle room in case ports on the machine are under contention + val boundPort1 = serverInfo1.boundPort + val boundPort2 = serverInfo2.boundPort + assert(boundPort1 != startPort) + assert(boundPort2 != startPort) + assert(boundPort1 != boundPort2) + } finally { + stopServer(serverInfo1) + stopServer(serverInfo2) + closeSocket(server) + } } test("jetty binds to port 0 correctly") { - val serverInfo = JettyUtils.startJettyServer( - "0.0.0.0", 0, Seq[ServletContextHandler](), new SparkConf) - val server = serverInfo.server - val boundPort = serverInfo.boundPort - assert(server.getState === "STARTED") - assert(boundPort != 0) - Try { new ServerSocket(boundPort) } match { - case Success(s) => fail("Port %s doesn't seem used by jetty server".format(boundPort)) - case Failure(e) => + var socket: ServerSocket = null + var serverInfo: ServerInfo = null + val (conf, sslOptions) = sslDisabledConf() + try { + serverInfo = JettyUtils.startJettyServer( + "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](), conf) + val server = serverInfo.server + val boundPort = serverInfo.boundPort + assert(server.getState === "STARTED") + assert(boundPort != 0) + intercept[BindException] { + socket = new ServerSocket(boundPort) + } + } finally { + stopServer(serverInfo) + closeSocket(socket) + } + } + + test("jetty with https binds to port 0 correctly") { + var socket: ServerSocket = null + var serverInfo: ServerInfo = null + try { + val (conf, sslOptions) = sslEnabledConf() + serverInfo = JettyUtils.startJettyServer( + "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](), conf) + val server = serverInfo.server + val boundPort = serverInfo.boundPort + assert(server.getState === "STARTED") + assert(boundPort != 0) + intercept[BindException] { + socket = new ServerSocket(boundPort) + } + } finally { + stopServer(serverInfo) + closeSocket(socket) } } @@ -117,4 +189,12 @@ class UISuite extends SparkFunSuite { assert(splitUIAddress(2).toInt == boundPort) } } + + def stopServer(info: ServerInfo): Unit = { + if (info != null && info.server != null) info.server.stop + } + + def closeSocket(socket: ServerSocket): Unit = { + if (socket != null) socket.close + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 2b693c165180f..58beaf103cfb4 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -17,49 +17,129 @@ package org.apache.spark.ui -import scala.xml.Elem +import scala.xml.{Node, Text} import org.apache.spark.SparkFunSuite class UIUtilsSuite extends SparkFunSuite { import UIUtils._ - test("makeDescription") { + test("makeDescription(plainText = false)") { verify( """test text """, test text , - "Correctly formatted text with only anchors and relative links should generate HTML" + "Correctly formatted text with only anchors and relative links should generate HTML", + plainText = false ) verify( """test """, {"""test """}, - "Badly formatted text should make the description be treated as a streaming instead of HTML" + "Badly formatted text should make the description be treated as a string instead of HTML", + plainText = false ) verify( """test text """, {"""test text """}, - "Non-relative links should make the description be treated as a string instead of HTML" + "Non-relative links should make the description be treated as a string instead of HTML", + plainText = false ) verify( """test""", {"""test"""}, - "Non-anchor elements should make the description be treated as a string instead of HTML" + "Non-anchor elements should make the description be treated as a string instead of HTML", + plainText = false ) verify( """test text """, test text , baseUrl = "base", - errorMsg = "Base URL should be prepended to html links" + errorMsg = "Base URL should be prepended to html links", + plainText = false ) } + test("makeDescription(plainText = true)") { + verify( + """test text """, + Text("test text "), + "Correctly formatted text with only anchors and relative links should generate a string " + + "without any html tags", + plainText = true + ) + + verify( + """test text1 text2 """, + Text("test text1 text2 "), + "Correctly formatted text with multiple anchors and relative links should generate a " + + "string without any html tags", + plainText = true + ) + + verify( + """test text """, + Text("test text "), + "Correctly formatted text with nested anchors and relative links and/or spans should " + + "generate a string without any html tags", + plainText = true + ) + + verify( + """test """, + Text("""test """), + "Badly formatted text should make the description be as the same as the original text", + plainText = true + ) + + verify( + """test text """, + Text("""test text """), + "Non-relative links should make the description be as the same as the original text", + plainText = true + ) + + verify( + """test""", + Text("""test"""), + "Non-anchor elements should make the description be as the same as the original text", + plainText = true + ) + } + + test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { + val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div") + val expected = Seq( +
    , +
    + ) + assert(generated.sameElements(expected), + s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated") + } + + test("decodeURLParameter (SPARK-12708: Sorting task error in Stages Page when yarn mode.)") { + val encoded1 = "%252F" + val decoded1 = "/" + val encoded2 = "%253Cdriver%253E" + val decoded2 = "" + + assert(decoded1 === decodeURLParameter(encoded1)) + assert(decoded2 === decodeURLParameter(encoded2)) + + // verify that no affect to decoded URL. + assert(decoded1 === decodeURLParameter(decoded1)) + assert(decoded2 === decodeURLParameter(decoded2)) + } + private def verify( - desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { - val generated = makeDescription(desc, baseUrl) + desc: String, + expected: Node, + errorMsg: String = "", + baseUrl: String = "", + plainText: Boolean): Unit = { + val generated = makeDescription(desc, baseUrl, plainText) assert(generated.sameElements(expected), s"\n$errorMsg\n\nExpected:\n$expected\nGenerated:\n$generated") } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index e02f5a1b20fe3..7d4c0863bc963 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -184,12 +184,12 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val conf = new SparkConf() val listener = new JobProgressListener(conf) val taskMetrics = new TaskMetrics() - val shuffleReadMetrics = new ShuffleReadMetrics() + val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() assert(listener.stageIdToData.size === 0) // finish this task, should get updated shuffleRead shuffleReadMetrics.incRemoteBytesRead(1000) - taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) + taskMetrics.mergeShuffleReadMetrics() var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0) @@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None, None), + ExceptionFailure("Exception", "description", null, null, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0", true, Some("Induced failure")), @@ -269,23 +269,22 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = new TaskMetrics() - val shuffleReadMetrics = new ShuffleReadMetrics() - val shuffleWriteMetrics = new ShuffleWriteMetrics() - taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) - taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) + val accums = InternalAccumulator.createAll() + accums.foreach(Accumulators.register) + val taskMetrics = new TaskMetrics(accums) + val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() + val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics() + val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop) + val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop) shuffleReadMetrics.incRemoteBytesRead(base + 1) shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) - shuffleWriteMetrics.incShuffleBytesWritten(base + 3) + taskMetrics.mergeShuffleReadMetrics() + shuffleWriteMetrics.incBytesWritten(base + 3) taskMetrics.setExecutorRunTime(base + 4) taskMetrics.incDiskBytesSpilled(base + 5) taskMetrics.incMemoryBytesSpilled(base + 6) - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - taskMetrics.setInputMetrics(Some(inputMetrics)) - inputMetrics.incBytesRead(base + 7) - val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) - taskMetrics.outputMetrics = Some(outputMetrics) + inputMetrics.setBytesRead(base + 7) outputMetrics.setBytesWritten(base + 8) taskMetrics } @@ -303,9 +302,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array( - (1234L, 0, 0, makeTaskMetrics(0)), - (1235L, 0, 0, makeTaskMetrics(100)), - (1236L, 1, 0, makeTaskMetrics(200))))) + (1234L, 0, 0, makeTaskMetrics(0).accumulatorUpdates()), + (1235L, 0, 0, makeTaskMetrics(100).accumulatorUpdates()), + (1236L, 1, 0, makeTaskMetrics(200).accumulatorUpdates())))) var stage0Data = listener.stageIdToData.get((0, 0)).get var stage1Data = listener.stageIdToData.get((1, 0)).get @@ -323,11 +322,11 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(stage1Data.inputBytes == 207) assert(stage0Data.outputBytes == 116) assert(stage1Data.outputBytes == 208) - assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 2) - assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1235L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 102) - assert(stage1Data.taskData.get(1236L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage1Data.taskData.get(1236L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 202) // task that was included in a heartbeat @@ -356,9 +355,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(stage1Data.inputBytes == 614) assert(stage0Data.outputBytes == 416) assert(stage1Data.outputBytes == 616) - assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 302) - assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 402) } } diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala index 86b078851851f..3fb78da0c7476 100644 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.ui.scope import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SparkListenerStageSubmitted -import org.apache.spark.scheduler.SparkListenerStageCompleted -import org.apache.spark.scheduler.SparkListenerJobStart /** * Tests that this listener populates and cleans up its data structures properly. diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 3dab15a9d4691..350c174e24742 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.ui.storage -import scala.xml.Utility - import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite @@ -64,26 +62,24 @@ class StoragePageSuite extends SparkFunSuite { "Cached Partitions", "Fraction Cached", "Size in Memory", - "Size in ExternalBlockStore", "Size on Disk") assert((xmlNodes \\ "th").map(_.text) === headers) assert((xmlNodes \\ "tr").size === 3) assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === - Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B", "0.0 B")) + Seq("rdd1", "Memory Deserialized 1x Replicated", "10", "100%", "100.0 B", "0.0 B")) // Check the url assert(((xmlNodes \\ "tr")(0) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=1")) assert(((xmlNodes \\ "tr")(1) \\ "td").map(_.text.trim) === - Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "0.0 B", "200.0 B")) + Seq("rdd2", "Disk Serialized 1x Replicated", "5", "50%", "0.0 B", "200.0 B")) // Check the url assert(((xmlNodes \\ "tr")(1) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=2")) assert(((xmlNodes \\ "tr")(2) \\ "td").map(_.text.trim) === - Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "0.0 B", - "500.0 B")) + Seq("rdd3", "Disk Memory Serialized 1x Replicated", "10", "100%", "400.0 B", "500.0 B")) // Check the url assert(((xmlNodes \\ "tr")(2) \\ "td" \ "a")(0).attribute("href").map(_.text) === Some("http://localhost:4040/storage/rdd?id=3")) @@ -98,16 +94,14 @@ class StoragePageSuite extends SparkFunSuite { "localhost:1111", StorageLevel.MEMORY_ONLY, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0) + diskSize = 0) assert(("Memory", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memoryBlock)) val memorySerializedBlock = BlockUIData(StreamBlockId(0, 0), "localhost:1111", StorageLevel.MEMORY_ONLY_SER, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0) + diskSize = 0) assert(("Memory Serialized", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(memorySerializedBlock)) @@ -115,18 +109,8 @@ class StoragePageSuite extends SparkFunSuite { "localhost:1111", StorageLevel.DISK_ONLY, memSize = 0, - diskSize = 100, - externalBlockStoreSize = 0) + diskSize = 100) assert(("Disk", 100) === storagePage.streamBlockStorageLevelDescriptionAndSize(diskBlock)) - - val externalBlock = BlockUIData(StreamBlockId(0, 0), - "localhost:1111", - StorageLevel.OFF_HEAP, - memSize = 0, - diskSize = 0, - externalBlockStoreSize = 100) - assert(("External", 100) === - storagePage.streamBlockStorageLevelDescriptionAndSize(externalBlock)) } test("receiverBlockTables") { @@ -135,14 +119,12 @@ class StoragePageSuite extends SparkFunSuite { "localhost:10000", StorageLevel.MEMORY_ONLY, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0), + diskSize = 0), BlockUIData(StreamBlockId(1, 1), "localhost:10000", StorageLevel.DISK_ONLY, memSize = 0, - diskSize = 100, - externalBlockStoreSize = 0) + diskSize = 100) ) val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", blocksForExecutor0) @@ -151,20 +133,12 @@ class StoragePageSuite extends SparkFunSuite { "localhost:10001", StorageLevel.MEMORY_ONLY, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0), - BlockUIData(StreamBlockId(2, 2), - "localhost:10001", - StorageLevel.OFF_HEAP, - memSize = 0, - diskSize = 0, - externalBlockStoreSize = 200), + diskSize = 0), BlockUIData(StreamBlockId(1, 1), "localhost:10001", StorageLevel.MEMORY_ONLY_SER, memSize = 100, - diskSize = 0, - externalBlockStoreSize = 0) + diskSize = 0) ) val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", blocksForExecutor1) val xmlNodes = storagePage.receiverBlockTables(Seq(executor0, executor1)) @@ -174,16 +148,15 @@ class StoragePageSuite extends SparkFunSuite { "Executor ID", "Address", "Total Size in Memory", - "Total Size in ExternalBlockStore", "Total Size on Disk", "Stream Blocks") assert((executorTable \\ "th").map(_.text) === executorHeaders) assert((executorTable \\ "tr").size === 2) assert(((executorTable \\ "tr")(0) \\ "td").map(_.text.trim) === - Seq("0", "localhost:10000", "100.0 B", "0.0 B", "100.0 B", "2")) + Seq("0", "localhost:10000", "100.0 B", "100.0 B", "2")) assert(((executorTable \\ "tr")(1) \\ "td").map(_.text.trim) === - Seq("1", "localhost:10001", "200.0 B", "200.0 B", "0.0 B", "3")) + Seq("1", "localhost:10001", "200.0 B", "0.0 B", "2")) val blockTable = (xmlNodes \\ "table")(1) val blockHeaders = Seq( @@ -194,7 +167,7 @@ class StoragePageSuite extends SparkFunSuite { "Size") assert((blockTable \\ "th").map(_.text) === blockHeaders) - assert((blockTable \\ "tr").size === 5) + assert((blockTable \\ "tr").size === 4) assert(((blockTable \\ "tr")(0) \\ "td").map(_.text.trim) === Seq("input-0-0", "2", "localhost:10000", "Memory", "100.0 B")) // Check "rowspan=2" for the first 2 columns @@ -212,17 +185,10 @@ class StoragePageSuite extends SparkFunSuite { assert(((blockTable \\ "tr")(3) \\ "td").map(_.text.trim) === Seq("localhost:10001", "Memory Serialized", "100.0 B")) - - assert(((blockTable \\ "tr")(4) \\ "td").map(_.text.trim) === - Seq("input-2-2", "1", "localhost:10001", "External", "200.0 B")) - // Check "rowspan=1" for the first 2 columns - assert(((blockTable \\ "tr")(4) \\ "td")(0).attribute("rowspan").map(_.text) === Some("1")) - assert(((blockTable \\ "tr")(4) \\ "td")(1).attribute("rowspan").map(_.text) === Some("1")) } test("empty receiverBlockTables") { assert(storagePage.receiverBlockTables(Seq.empty).isEmpty) - val executor0 = ExecutorStreamBlockStatus("0", "localhost:10000", Seq.empty) val executor1 = ExecutorStreamBlockStatus("1", "localhost:10001", Seq.empty) assert(storagePage.receiverBlockTables(Seq(executor0, executor1)).isEmpty) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 37e2670de9685..7d77deeb60618 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.ui.storage import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFunSuite, Success} + +import org.apache.spark.{SparkConf, SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.storage._ @@ -43,7 +44,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { before { bus = new LiveListenerBus - storageStatusListener = new StorageStatusListener + storageStatusListener = new StorageStatusListener(new SparkConf()) storageListener = new StorageListener(storageStatusListener) bus.addListener(storageStatusListener) bus.addListener(storageListener) @@ -105,7 +106,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { assert(storageListener.rddInfoList.size === 0) } - test("task end") { + test("block update") { val myRddInfo0 = rddInfo0 val myRddInfo1 = rddInfo1 val myRddInfo2 = rddInfo2 @@ -119,46 +120,35 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { assert(!storageListener._rddInfoMap(1).isCached) assert(!storageListener._rddInfoMap(2).isCached) - // Task end with no updated blocks. This should not change anything. - bus.postToAll(SparkListenerTaskEnd(0, 0, "obliteration", Success, taskInfo, new TaskMetrics)) - assert(storageListener._rddInfoMap.size === 3) - assert(storageListener.rddInfoList.size === 0) - - // Task end with a few new persisted blocks, some from the same RDD - val metrics1 = new TaskMetrics - metrics1.updatedBlocks = Some(Seq( - (RDDBlockId(0, 100), BlockStatus(memAndDisk, 400L, 0L, 0L)), - (RDDBlockId(0, 101), BlockStatus(memAndDisk, 0L, 400L, 0L)), - (RDDBlockId(0, 102), BlockStatus(memAndDisk, 400L, 0L, 200L)), - (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L, 0L)) - )) - bus.postToAll(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo, metrics1)) - assert(storageListener._rddInfoMap(0).memSize === 800L) + // Some blocks updated + val blockUpdateInfos = Seq( + BlockUpdatedInfo(bm1, RDDBlockId(0, 100), memAndDisk, 400L, 0L), + BlockUpdatedInfo(bm1, RDDBlockId(0, 101), memAndDisk, 0L, 400L), + BlockUpdatedInfo(bm1, RDDBlockId(1, 20), memAndDisk, 0L, 240L) + ) + postUpdateBlocks(bus, blockUpdateInfos) + assert(storageListener._rddInfoMap(0).memSize === 400L) assert(storageListener._rddInfoMap(0).diskSize === 400L) - assert(storageListener._rddInfoMap(0).externalBlockStoreSize === 200L) - assert(storageListener._rddInfoMap(0).numCachedPartitions === 3) + assert(storageListener._rddInfoMap(0).numCachedPartitions === 2) assert(storageListener._rddInfoMap(0).isCached) assert(storageListener._rddInfoMap(1).memSize === 0L) assert(storageListener._rddInfoMap(1).diskSize === 240L) - assert(storageListener._rddInfoMap(1).externalBlockStoreSize === 0L) assert(storageListener._rddInfoMap(1).numCachedPartitions === 1) assert(storageListener._rddInfoMap(1).isCached) assert(!storageListener._rddInfoMap(2).isCached) assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) - // Task end with a few dropped blocks - val metrics2 = new TaskMetrics - metrics2.updatedBlocks = Some(Seq( - (RDDBlockId(0, 100), BlockStatus(none, 0L, 0L, 0L)), - (RDDBlockId(1, 20), BlockStatus(none, 0L, 0L, 0L)), - (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L, 0L)), // doesn't actually exist - (RDDBlockId(4, 80), BlockStatus(none, 0L, 0L, 0L)) // doesn't actually exist - )) - bus.postToAll(SparkListenerTaskEnd(2, 0, "obliteration", Success, taskInfo, metrics2)) - assert(storageListener._rddInfoMap(0).memSize === 400L) + // Drop some blocks + val blockUpdateInfos2 = Seq( + BlockUpdatedInfo(bm1, RDDBlockId(0, 100), none, 0L, 0L), + BlockUpdatedInfo(bm1, RDDBlockId(1, 20), none, 0L, 0L), + BlockUpdatedInfo(bm1, RDDBlockId(2, 40), none, 0L, 0L), // doesn't actually exist + BlockUpdatedInfo(bm1, RDDBlockId(4, 80), none, 0L, 0L) // doesn't actually exist + ) + postUpdateBlocks(bus, blockUpdateInfos2) + assert(storageListener._rddInfoMap(0).memSize === 0L) assert(storageListener._rddInfoMap(0).diskSize === 400L) - assert(storageListener._rddInfoMap(0).externalBlockStoreSize === 200L) - assert(storageListener._rddInfoMap(0).numCachedPartitions === 2) + assert(storageListener._rddInfoMap(0).numCachedPartitions === 1) assert(storageListener._rddInfoMap(0).isCached) assert(!storageListener._rddInfoMap(1).isCached) assert(storageListener._rddInfoMap(2).numCachedPartitions === 0) @@ -172,24 +162,27 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly, Seq(4)) val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), Seq.empty, "details") val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), Seq.empty, "details") - val taskMetrics0 = new TaskMetrics - val taskMetrics1 = new TaskMetrics - val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L, 0L)) - val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L, 0L)) - taskMetrics0.updatedBlocks = Some(Seq(block0)) - taskMetrics1.updatedBlocks = Some(Seq(block1)) + val blockUpdateInfos1 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(0, 1), memOnly, 100L, 0L)) + val blockUpdateInfos2 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(1, 1), memOnly, 200L, 0L)) bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener.rddInfoList.size === 0) - bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0)) + postUpdateBlocks(bus, blockUpdateInfos1) assert(storageListener.rddInfoList.size === 1) bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) assert(storageListener.rddInfoList.size === 1) bus.postToAll(SparkListenerStageCompleted(stageInfo0)) assert(storageListener.rddInfoList.size === 1) - bus.postToAll(SparkListenerTaskEnd(1, 0, "small", Success, taskInfo1, taskMetrics1)) + postUpdateBlocks(bus, blockUpdateInfos2) assert(storageListener.rddInfoList.size === 2) bus.postToAll(SparkListenerStageCompleted(stageInfo1)) assert(storageListener.rddInfoList.size === 2) } + + private def postUpdateBlocks( + bus: SparkListenerBus, blockUpdateInfos: Seq[BlockUpdatedInfo]): Unit = { + blockUpdateInfos.foreach { blockUpdateInfo => + bus.postToAll(SparkListenerBlockUpdated(blockUpdateInfo)) + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala deleted file mode 100644 index 61601016e005e..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ /dev/null @@ -1,359 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable.ArrayBuffer - -import java.util.concurrent.TimeoutException - -import akka.actor.ActorNotFound - -import org.apache.spark._ -import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} -import org.apache.spark.SSLSampleConfigs._ - - -/** - * Test the AkkaUtils with various security settings. - */ -class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { - - test("remote fetch security bad password") { - val conf = new SparkConf - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val securityManager = new SecurityManager(conf) - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.rpc", "akka") - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - conf.set("spark.authenticate.secret", "bad") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "false") - badconf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(badconf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security off - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security pass") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val goodconf = new SparkConf - goodconf.set("spark.authenticate", "true") - goodconf.set("spark.authenticate.secret", "good") - val securityManagerGood = new SecurityManager(goodconf) - - assert(securityManagerGood.isAuthenticationEnabled() === true) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security on and passwords match - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security off client") { - val conf = new SparkConf - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.rpc", "akka") - badconf.set("spark.authenticate", "false") - badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch ssl on") { - val conf = sparkSSLConfig() - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security off - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on and security enabled") { - val conf = sparkSSLConfig() - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - slaveConf.set("spark.authenticate", "true") - slaveConf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on and security enabled - bad credentials") { - val conf = sparkSSLConfig() - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - slaveConf.set("spark.rpc", "akka") - slaveConf.set("spark.authenticate", "true") - slaveConf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on - untrusted server") { - val conf = sparkSSLConfigUntrusted() - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - try { - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - fail("should receive either ActorNotFound or TimeoutException") - } catch { - case e: ActorNotFound => - case e: TimeoutException => - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - -} diff --git a/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala b/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala new file mode 100644 index 0000000000000..4a80e3f1f452d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class CausedBySuite extends SparkFunSuite { + + test("For an error without a cause, should return the error") { + val error = new Exception + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === error) + } + + test("For an error with a cause, should return the cause of the error") { + val cause = new Exception + val error = new Exception(cause) + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === cause) + } + + test("For an error with a cause that itself has a cause, return the root cause") { + val causeOfCause = new Exception + val cause = new Exception(causeOfCause) + val error = new Exception(cause) + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === causeOfCause) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 480722a5ac182..932704c1a3659 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.util import java.io.NotSerializableException -import java.util.Random -import org.apache.spark.LocalSparkContext._ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} +import org.apache.spark.LocalSparkContext._ import org.apache.spark.partial.CountEvaluator import org.apache.spark.rdd.RDD @@ -91,11 +90,6 @@ class ClosureCleanerSuite extends SparkFunSuite { expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithContext(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testFilterWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testForEachWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testMapWith(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) } @@ -269,21 +263,6 @@ private object TestUserClosuresActuallyCleaned { def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = { rdd.mapPartitionsWithIndex { (_, it) => return; it }.count() } - def testFlatMapWith(rdd: RDD[Int]): Unit = { - rdd.flatMapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; Seq() }.count() - } - def testMapWith(rdd: RDD[Int]): Unit = { - rdd.mapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; 0 }.count() - } - def testFilterWith(rdd: RDD[Int]): Unit = { - rdd.filterWith ((index: Int) => new Random(index + 42)){ (_, it) => return; true }.count() - } - def testForEachWith(rdd: RDD[Int]): Unit = { - rdd.foreachWith ((index: Int) => new Random(index + 42)){ (_, it) => return } - } - def testMapPartitionsWithContext(rdd: RDD[Int]): Unit = { - rdd.mapPartitionsWithContext { (_, it) => return; it }.count() - } def testZipPartitions2(rdd: RDD[Int]): Unit = { rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count() } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index a829b099025e9..934385fbcad1b 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -38,14 +38,19 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri private var closureSerializer: SerializerInstance = null override def beforeAll(): Unit = { + super.beforeAll() sc = new SparkContext("local", "test") closureSerializer = sc.env.closureSerializer.newInstance() } override def afterAll(): Unit = { - sc.stop() - sc = null - closureSerializer = null + try { + sc.stop() + sc = null + closureSerializer = null + } finally { + super.afterAll() + } } // Some fields and methods to reference in inner closures later diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala index b207d497f33c2..6f7dddd4f760a 100644 --- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.util -import java.util.concurrent.CountDownLatch +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch} -import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -31,11 +31,11 @@ import org.apache.spark.SparkFunSuite class EventLoopSuite extends SparkFunSuite with Timeouts { test("EventLoop") { - val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int] + val buffer = new ConcurrentLinkedQueue[Int] val eventLoop = new EventLoop[Int]("test") { override def onReceive(event: Int): Unit = { - buffer += event + buffer.add(event) } override def onError(e: Throwable): Unit = {} @@ -43,7 +43,7 @@ class EventLoopSuite extends SparkFunSuite with Timeouts { eventLoop.start() (1 to 100).foreach(eventLoop.post) eventually(timeout(5 seconds), interval(5 millis)) { - assert((1 to 100) === buffer.toSeq) + assert((1 to 100) === buffer.asScala.toSeq) } eventLoop.stop() } diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 2b76ae1f8a24b..4fa9f9a8f590f 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -18,17 +18,22 @@ package org.apache.spark.util import java.io._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.CountDownLatch import scala.collection.mutable.HashSet import scala.reflect._ -import org.scalatest.BeforeAndAfter - -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.apache.log4j.{Appender, Level, Logger} +import org.apache.log4j.spi.LoggingEvent +import org.mockito.ArgumentCaptor +import org.mockito.Mockito.{atLeast, mock, verify} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} -import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.util.logging.{FileAppender, RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy} class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { @@ -44,11 +49,11 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { test("basic file appender") { val testString = (1 to 1000).mkString(", ") - val inputStream = new ByteArrayInputStream(testString.getBytes(UTF_8)) + val inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(Files.toString(testFile, UTF_8) === testString) + assert(Files.toString(testFile, StandardCharsets.UTF_8) === testString) } test("rolling file appender - time-based rolling") { @@ -96,7 +101,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { val allGeneratedFiles = new HashSet[String]() val items = (1 to 10).map { _.toString * 10000 } for (i <- 0 until items.size) { - testOutputStream.write(items(i).getBytes(UTF_8)) + testOutputStream.write(items(i).getBytes(StandardCharsets.UTF_8)) testOutputStream.flush() allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles( testFile.getParentFile.toString, testFile.getName).map(_.toString) @@ -189,6 +194,77 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { testAppenderSelection[FileAppender, Any](rollingStrategy("xyz")) } + test("file appender async close stream abruptly") { + // Test FileAppender reaction to closing InputStream using a mock logging appender + val mockAppender = mock(classOf[Appender]) + val loggingEventCaptor = ArgumentCaptor.forClass(classOf[LoggingEvent]) + + // Make sure only logging errors + val logger = Logger.getRootLogger + val oldLogLevel = logger.getLevel + logger.setLevel(Level.ERROR) + try { + logger.addAppender(mockAppender) + + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) + + // Close the stream before appender tries to read will cause an IOException + testInputStream.close() + testOutputStream.close() + val appender = FileAppender(testInputStream, testFile, new SparkConf) + + appender.awaitTermination() + + // If InputStream was closed without first stopping the appender, an exception will be logged + verify(mockAppender, atLeast(1)).doAppend(loggingEventCaptor.capture) + val loggingEvent = loggingEventCaptor.getValue + assert(loggingEvent.getThrowableInformation !== null) + assert(loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException]) + } finally { + logger.setLevel(oldLogLevel) + } + } + + test("file appender async close stream gracefully") { + // Test FileAppender reaction to closing InputStream using a mock logging appender + val mockAppender = mock(classOf[Appender]) + val loggingEventCaptor = ArgumentCaptor.forClass(classOf[LoggingEvent]) + + // Make sure only logging errors + val logger = Logger.getRootLogger + val oldLogLevel = logger.getLevel + logger.setLevel(Level.ERROR) + try { + logger.addAppender(mockAppender) + + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) with LatchedInputStream + + // Close the stream before appender tries to read will cause an IOException + testInputStream.close() + testOutputStream.close() + val appender = FileAppender(testInputStream, testFile, new SparkConf) + + // Stop the appender before an IOException is called during read + testInputStream.latchReadStarted.await() + appender.stop() + testInputStream.latchReadProceed.countDown() + + appender.awaitTermination() + + // Make sure no IOException errors have been logged as a result of appender closing gracefully + verify(mockAppender, atLeast(0)).doAppend(loggingEventCaptor.capture) + import scala.collection.JavaConverters._ + loggingEventCaptor.getAllValues.asScala.foreach { loggingEvent => + assert(loggingEvent.getThrowableInformation === null + || !loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException]) + } + } finally { + logger.setLevel(oldLogLevel) + } + } + /** * Run the rolling file appender with data and see whether all the data was written correctly * across rolled over files. @@ -202,7 +278,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { // send data to appender through the input stream, and wait for the data to be written val expectedText = textToAppend.mkString("") for (i <- 0 until textToAppend.size) { - outputStream.write(textToAppend(i).getBytes(UTF_8)) + outputStream.write(textToAppend(i).getBytes(StandardCharsets.UTF_8)) outputStream.flush() Thread.sleep(sleepTimeBetweenTexts) } @@ -217,7 +293,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) assert(generatedFiles.size > 1) val allText = generatedFiles.map { file => - Files.toString(file, UTF_8) + Files.toString(file, StandardCharsets.UTF_8) }.mkString("") assert(allText === expectedText) generatedFiles @@ -229,4 +305,15 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { file.getName.startsWith(testFile.getName) }.foreach { _.delete() } } + + /** Used to synchronize when read is called on a stream */ + private trait LatchedInputStream extends PipedInputStream { + val latchReadStarted = new CountDownLatch(1) + val latchReadProceed = new CountDownLatch(1) + abstract override def read(): Int = { + latchReadStarted.countDown() + latchReadProceed.await() + super.read() + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 953456c2caa89..de6f408fa82be 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,26 +19,24 @@ package org.apache.spark.util import java.util.Properties -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.shuffle.MetadataFetchFailedException - import scala.collection.Map import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST.{JArray, JInt, JString, JValue} +import org.json4s.JsonDSL._ +import org.scalatest.Assertions +import org.scalatest.exceptions.TestFailedException import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage._ class JsonProtocolSuite extends SparkFunSuite { - - val jobSubmissionTime = 1421191042750L - val jobCompletionTime = 1421191296660L - - val executorAddedTime = 1421458410000L - val executorRemovedTime = 1421458922000L + import JsonProtocolSuite._ test("SparkListenerEvent") { val stageSubmitted = @@ -83,9 +81,13 @@ class JsonProtocolSuite extends SparkFunSuite { val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") - val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq( - (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, - hasHadoopInput = true, hasOutput = true)))) + val executorMetricsUpdate = { + // Use custom accum ID for determinism + val accumUpdates = + makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true) + .accumulatorUpdates().zipWithIndex.map { case (a, i) => a.copy(id = i) } + SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) + } testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -143,7 +145,7 @@ class JsonProtocolSuite extends SparkFunSuite { "Some exception") val fetchMetadataFailed = new MetadataFetchFailedException(17, 19, "metadata Fetch failed exception").toTaskEndReason - val exceptionFailure = new ExceptionFailure(exception, None) + val exceptionFailure = new ExceptionFailure(exception, Seq.empty[AccumulableInfo]) testTaskEndReason(Success) testTaskEndReason(Resubmitted) testTaskEndReason(fetchFailed) @@ -163,9 +165,12 @@ class JsonProtocolSuite extends SparkFunSuite { testBlockId(StreamBlockId(1, 2L)) } - test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, - None, None) + /* ============================== * + | Backward compatibility tests | + * ============================== */ + + test("ExceptionFailure backward compatibility: full stack trace") { + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) @@ -224,7 +229,7 @@ class JsonProtocolSuite extends SparkFunSuite { .removeField { case (field, _) => field == "Shuffle Records Written" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) assert(newMetrics.shuffleReadMetrics.get.recordsRead == 0) - assert(newMetrics.shuffleWriteMetrics.get.shuffleRecordsWritten == 0) + assert(newMetrics.shuffleWriteMetrics.get.recordsWritten == 0) } test("OutputMetrics backward compatibility") { @@ -270,14 +275,13 @@ class JsonProtocolSuite extends SparkFunSuite { assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } - test("ShuffleReadMetrics: Local bytes read and time taken backwards compatibility") { - // Metrics about local shuffle bytes read and local read time were added in 1.3.1. + test("ShuffleReadMetrics: Local bytes read backwards compatibility") { + // Metrics about local shuffle bytes read were added in 1.3.1. val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = false, hasRecords = false) assert(metrics.shuffleReadMetrics.nonEmpty) val newJson = JsonProtocol.taskMetricsToJson(metrics) val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" } - .removeField { case (field, _) => field == "Local Read Time" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0) } @@ -334,14 +338,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent)) } - test("RDDInfo backward compatibility (scope, parent IDs)") { - // Prior to Spark 1.4.0, RDDInfo did not have the "Scope" and "Parent IDs" properties - val rddInfo = new RDDInfo( - 1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), Some(new RDDOperationScope("fable"))) + test("RDDInfo backward compatibility (scope, parent IDs, callsite)") { + // "Scope" and "Parent IDs" were introduced in Spark 1.4.0 + // "Callsite" was introduced in Spark 1.6.0 + val rddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), + "callsite", Some(new RDDOperationScope("fable"))) val oldRddInfoJson = JsonProtocol.rddInfoToJson(rddInfo) .removeField({ _._1 == "Parent IDs"}) .removeField({ _._1 == "Scope"}) - val expectedRddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq.empty, scope = None) + .removeField({ _._1 == "Callsite"}) + val expectedRddInfo = new RDDInfo( + 1, "one", 100, StorageLevel.NONE, Seq.empty, "", scope = None) assertEquals(expectedRddInfo, JsonProtocol.rddInfoFromJson(oldRddInfoJson)) } @@ -365,22 +372,79 @@ class JsonProtocolSuite extends SparkFunSuite { } test("AccumulableInfo backward compatibility") { - // "Internal" property of AccumulableInfo were added after 1.5.1. - val accumulableInfo = makeAccumulableInfo(1) - val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) - .removeField({ _._1 == "Internal" }) + // "Internal" property of AccumulableInfo was added in 1.5.1 + val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true) + val accumulableInfoJson = JsonProtocol.accumulableInfoToJson(accumulableInfo) + val oldJson = accumulableInfoJson.removeField({ _._1 == "Internal" }) val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson) - assert(false === oldInfo.internal) + assert(!oldInfo.internal) + // "Count Failed Values" property of AccumulableInfo was added in 2.0.0 + val oldJson2 = accumulableInfoJson.removeField({ _._1 == "Count Failed Values" }) + val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2) + assert(!oldInfo2.countFailedValues) + // "Metadata" property of AccumulableInfo was added in 2.0.0 + val oldJson3 = accumulableInfoJson.removeField({ _._1 == "Metadata" }) + val oldInfo3 = JsonProtocol.accumulableInfoFromJson(oldJson3) + assert(oldInfo3.metadata.isEmpty) + } + + test("ExceptionFailure backward compatibility: accumulator updates") { + // "Task Metrics" was replaced with "Accumulator Updates" in 2.0.0. For older event logs, + // we should still be able to fallback to constructing the accumulator updates from the + // "Task Metrics" field, if it exists. + val tm = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = true) + val tmJson = JsonProtocol.taskMetricsToJson(tm) + val accumUpdates = tm.accumulatorUpdates() + val exception = new SparkException("sentimental") + val exceptionFailure = new ExceptionFailure(exception, accumUpdates) + val exceptionFailureJson = JsonProtocol.taskEndReasonToJson(exceptionFailure) + val tmFieldJson: JValue = "Task Metrics" -> tmJson + val oldExceptionFailureJson: JValue = + exceptionFailureJson.removeField { _._1 == "Accumulator Updates" }.merge(tmFieldJson) + val oldExceptionFailure = + JsonProtocol.taskEndReasonFromJson(oldExceptionFailureJson).asInstanceOf[ExceptionFailure] + assert(exceptionFailure.className === oldExceptionFailure.className) + assert(exceptionFailure.description === oldExceptionFailure.description) + assertSeqEquals[StackTraceElement]( + exceptionFailure.stackTrace, oldExceptionFailure.stackTrace, assertStackTraceElementEquals) + assert(exceptionFailure.fullStackTrace === oldExceptionFailure.fullStackTrace) + assertSeqEquals[AccumulableInfo]( + exceptionFailure.accumUpdates, oldExceptionFailure.accumUpdates, (x, y) => x == y) + } + + test("AccumulableInfo value de/serialization") { + import InternalAccumulator._ + val blocks = Seq[(BlockId, BlockStatus)]( + (TestBlockId("meebo"), BlockStatus(StorageLevel.MEMORY_ONLY, 1L, 2L)), + (TestBlockId("feebo"), BlockStatus(StorageLevel.DISK_ONLY, 3L, 4L))) + val blocksJson = JArray(blocks.toList.map { case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> JsonProtocol.blockStatusToJson(status)) + }) + testAccumValue(Some(RESULT_SIZE), 3L, JInt(3)) + testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2)) + testAccumValue(Some(input.READ_METHOD), "aka", JString("aka")) + testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson) + // For anything else, we just cast the value to a string + testAccumValue(Some("anything"), blocks, JString(blocks.toString)) + testAccumValue(Some("anything"), 123, JString("123")) } - /** -------------------------- * - | Helper test running methods | - * --------------------------- */ +} + + +private[spark] object JsonProtocolSuite extends Assertions { + import InternalAccumulator._ + + private val jobSubmissionTime = 1421191042750L + private val jobCompletionTime = 1421191296660L + private val executorAddedTime = 1421458410000L + private val executorRemovedTime = 1421458922000L private def testEvent(event: SparkListenerEvent, jsonString: String) { val actualJsonString = compact(render(JsonProtocol.sparkEventToJson(event))) val newEvent = JsonProtocol.sparkEventFromJson(parse(actualJsonString)) - assertJsonStringEquals(jsonString, actualJsonString) + assertJsonStringEquals(jsonString, actualJsonString, event.getClass.getSimpleName) assertEquals(event, newEvent) } @@ -434,11 +498,19 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } + private def testAccumValue(name: Option[String], value: Any, expectedJson: JValue): Unit = { + val json = JsonProtocol.accumValueToJson(name, value) + assert(json === expectedJson) + val newValue = JsonProtocol.accumValueFromJson(name, json) + val expectedValue = if (name.exists(_.startsWith(METRICS_PREFIX))) value else value.toString + assert(newValue === expectedValue) + } + /** -------------------------------- * | Util methods for comparing events | * --------------------------------- */ - private def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) { + private[spark] def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) { (event1, event2) match { case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) => assert(e1.properties === e2.properties) @@ -472,14 +544,17 @@ class JsonProtocolSuite extends SparkFunSuite { assert(e1.executorId === e1.executorId) case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) => assert(e1.execId === e2.execId) - assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => { - val (taskId1, stageId1, stageAttemptId1, metrics1) = a - val (taskId2, stageId2, stageAttemptId2, metrics2) = b - assert(taskId1 === taskId2) - assert(stageId1 === stageId2) - assert(stageAttemptId1 === stageAttemptId2) - assertEquals(metrics1, metrics2) - }) + assertSeqEquals[(Long, Int, Int, Seq[AccumulableInfo])]( + e1.accumUpdates, + e2.accumUpdates, + (a, b) => { + val (taskId1, stageId1, stageAttemptId1, updates1) = a + val (taskId2, stageId2, stageAttemptId2, updates2) = b + assert(taskId1 === taskId2) + assert(stageId1 === stageId2) + assert(stageAttemptId1 === stageAttemptId2) + assertSeqEquals[AccumulableInfo](updates1, updates2, (a, b) => a.equals(b)) + }) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -538,7 +613,6 @@ class JsonProtocolSuite extends SparkFunSuite { } private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { - assert(metrics1.hostname === metrics2.hostname) assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) assert(metrics1.resultSize === metrics2.resultSize) assert(metrics1.jvmGCTime === metrics2.jvmGCTime) @@ -551,7 +625,7 @@ class JsonProtocolSuite extends SparkFunSuite { metrics1.shuffleWriteMetrics, metrics2.shuffleWriteMetrics, assertShuffleWriteEquals) assertOptionEquals( metrics1.inputMetrics, metrics2.inputMetrics, assertInputMetricsEquals) - assertOptionEquals(metrics1.updatedBlocks, metrics2.updatedBlocks, assertBlocksEquals) + assertBlocksEquals(metrics1.updatedBlockStatuses, metrics2.updatedBlockStatuses) } private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) { @@ -562,8 +636,8 @@ class JsonProtocolSuite extends SparkFunSuite { } private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics) { - assert(metrics1.shuffleBytesWritten === metrics2.shuffleBytesWritten) - assert(metrics1.shuffleWriteTime === metrics2.shuffleWriteTime) + assert(metrics1.bytesWritten === metrics2.bytesWritten) + assert(metrics1.writeTime === metrics2.writeTime) } private def assertEquals(metrics1: InputMetrics, metrics2: InputMetrics) { @@ -595,7 +669,7 @@ class JsonProtocolSuite extends SparkFunSuite { assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) assert(r1.fullStackTrace === r2.fullStackTrace) - assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) + assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b)) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), @@ -631,10 +705,16 @@ class JsonProtocolSuite extends SparkFunSuite { assertStackTraceElementEquals) } - private def assertJsonStringEquals(json1: String, json2: String) { + private def assertJsonStringEquals(expected: String, actual: String, metadata: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - assert(formatJsonString(json1) === formatJsonString(json2), - s"input ${formatJsonString(json1)} got ${formatJsonString(json2)}") + if (formatJsonString(expected) != formatJsonString(actual)) { + // scalastyle:off + // This prints something useful if the JSON strings don't match + println("=== EXPECTED ===\n" + pretty(parse(expected)) + "\n") + println("=== ACTUAL ===\n" + pretty(parse(actual)) + "\n") + // scalastyle:on + throw new TestFailedException(s"$metadata JSON did not equal", 1) + } } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { @@ -713,7 +793,7 @@ class JsonProtocolSuite extends SparkFunSuite { } private def makeRddInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7)) + val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7), a.toString) r.numCachedPartitions = c r.memSize = d r.diskSize = e @@ -740,8 +820,13 @@ class JsonProtocolSuite extends SparkFunSuite { taskInfo } - private def makeAccumulableInfo(id: Int, internal: Boolean = false): AccumulableInfo = - AccumulableInfo(id, " Accumulable " + id, Some("delta" + id), "val" + id, internal) + private def makeAccumulableInfo( + id: Int, + internal: Boolean = false, + countFailedValues: Boolean = false, + metadata: Option[String] = None): AccumulableInfo = + new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"), + internal, countFailedValues, metadata) /** * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is @@ -758,7 +843,6 @@ class JsonProtocolSuite extends SparkFunSuite { hasOutput: Boolean, hasRecords: Boolean = true) = { val t = new TaskMetrics - t.setHostname("localhost") t.setExecutorDeserializeTime(a) t.setExecutorRunTime(b) t.setResultSize(c) @@ -767,35 +851,32 @@ class JsonProtocolSuite extends SparkFunSuite { t.incMemoryBytesSpilled(a + c) if (hasHadoopInput) { - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - inputMetrics.incBytesRead(d + e + f) + val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop) + inputMetrics.setBytesRead(d + e + f) inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) - t.setInputMetrics(Some(inputMetrics)) } else { - val sr = new ShuffleReadMetrics + val sr = t.registerTempShuffleReadMetrics() sr.incRemoteBytesRead(b + d) sr.incLocalBlocksFetched(e) sr.incFetchWaitTime(a + d) sr.incRemoteBlocksFetched(f) sr.incRecordsRead(if (hasRecords) (b + d) / 100 else -1) sr.incLocalBytesRead(a + f) - t.setShuffleReadMetrics(Some(sr)) + t.mergeShuffleReadMetrics() } if (hasOutput) { - val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) + val outputMetrics = t.registerOutputMetrics(DataWriteMethod.Hadoop) outputMetrics.setBytesWritten(a + b + c) outputMetrics.setRecordsWritten(if (hasRecords) (a + b + c)/100 else -1) - t.outputMetrics = Some(outputMetrics) } else { - val sw = new ShuffleWriteMetrics - sw.incShuffleBytesWritten(a + b + c) - sw.incShuffleWriteTime(b + c + d) - sw.setShuffleRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) - t.shuffleWriteMetrics = Some(sw) + val sw = t.registerShuffleWriteMetrics() + sw.incBytesWritten(a + b + c) + sw.incWriteTime(b + c + d) + sw.incRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) } // Make at most 6 blocks - t.updatedBlocks = Some((1 to (e % 5 + 1)).map { i => - (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i, c%i)) + t.setUpdatedBlockStatuses((1 to (e % 5 + 1)).map { i => + (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i)) }.toSeq) t } @@ -823,14 +904,16 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -856,18 +939,17 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 101, | "Name": "mayor", + | "Callsite": "101", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 201, | "Number of Cached Partitions": 301, | "Memory Size": 401, - | "ExternalBlockStore Size": 0, | "Disk Size": 501 | } | ], @@ -879,14 +961,16 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | } @@ -917,21 +1001,24 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | } @@ -960,21 +1047,24 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | } @@ -1009,26 +1099,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1042,7 +1134,7 @@ class JsonProtocolSuite extends SparkFunSuite { | "Fetch Wait Time": 900, | "Remote Bytes Read": 1000, | "Local Bytes Read": 1100, - | "Total Records Read" : 10 + | "Total Records Read": 10 | }, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, @@ -1056,12 +1148,10 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": false, | "Replication": 2 | }, | "Memory Size": 0, - | "ExternalBlockStore Size": 0, | "Disk Size": 0 | } | } @@ -1098,26 +1188,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1142,12 +1234,10 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": false, | "Replication": 2 | }, | "Memory Size": 0, - | "ExternalBlockStore Size": 0, | "Disk Size": 0 | } | } @@ -1184,26 +1274,28 @@ class JsonProtocolSuite extends SparkFunSuite { | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 2, | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 3, | "Name": "Accumulable3", | "Update": "delta3", | "Value": "val3", - | "Internal": true + | "Internal": true, + | "Count Failed Values": false | } | ] | }, | "Task Metrics": { - | "Host Name": "localhost", | "Executor Deserialize Time": 300, | "Executor Run Time": 400, | "Result Size": 500, @@ -1228,12 +1320,10 @@ class JsonProtocolSuite extends SparkFunSuite { | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": false, | "Replication": 2 | }, | "Memory Size": 0, - | "ExternalBlockStore Size": 0, | "Disk Size": 0 | } | } @@ -1258,18 +1348,17 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 1, | "Name": "mayor", + | "Callsite": "1", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 200, | "Number of Cached Partitions": 300, | "Memory Size": 400, - | "ExternalBlockStore Size": 0, | "Disk Size": 500 | } | ], @@ -1278,17 +1367,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1301,35 +1392,33 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 2, | "Name": "mayor", + | "Callsite": "2", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 400, | "Number of Cached Partitions": 600, | "Memory Size": 800, - | "ExternalBlockStore Size": 0, | "Disk Size": 1000 | }, | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 401, | "Number of Cached Partitions": 601, | "Memory Size": 801, - | "ExternalBlockStore Size": 0, | "Disk Size": 1001 | } | ], @@ -1338,17 +1427,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1361,52 +1452,49 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 600, | "Number of Cached Partitions": 900, | "Memory Size": 1200, - | "ExternalBlockStore Size": 0, | "Disk Size": 1500 | }, | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 601, | "Number of Cached Partitions": 901, | "Memory Size": 1201, - | "ExternalBlockStore Size": 0, | "Disk Size": 1501 | }, | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 602, | "Number of Cached Partitions": 902, | "Memory Size": 1202, - | "ExternalBlockStore Size": 0, | "Disk Size": 1502 | } | ], @@ -1415,17 +1503,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | }, @@ -1438,69 +1528,65 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 800, | "Number of Cached Partitions": 1200, | "Memory Size": 1600, - | "ExternalBlockStore Size": 0, | "Disk Size": 2000 | }, | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 801, | "Number of Cached Partitions": 1201, | "Memory Size": 1601, - | "ExternalBlockStore Size": 0, | "Disk Size": 2001 | }, | { | "RDD ID": 6, | "Name": "mayor", + | "Callsite": "6", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 802, | "Number of Cached Partitions": 1202, | "Memory Size": 1602, - | "ExternalBlockStore Size": 0, | "Disk Size": 2002 | }, | { | "RDD ID": 7, | "Name": "mayor", + | "Callsite": "7", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, | "Use Memory": true, - | "Use ExternalBlockStore": false, | "Deserialized": true, | "Replication": 1 | }, | "Number of Partitions": 803, | "Number of Cached Partitions": 1203, | "Memory Size": 1603, - | "ExternalBlockStore Size": 0, | "Disk Size": 2003 | } | ], @@ -1509,17 +1595,19 @@ class JsonProtocolSuite extends SparkFunSuite { | "Accumulables": [ | { | "ID": 2, - | "Name": " Accumulable 2", + | "Name": "Accumulable2", | "Update": "delta2", | "Value": "val2", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | }, | { | "ID": 1, - | "Name": " Accumulable 1", + | "Name": "Accumulable1", | "Update": "delta1", | "Value": "val1", - | "Internal": false + | "Internal": false, + | "Count Failed Values": false | } | ] | } @@ -1671,53 +1759,208 @@ class JsonProtocolSuite extends SparkFunSuite { """ private val executorMetricsUpdateJsonString = - s""" - |{ - | "Event": "SparkListenerExecutorMetricsUpdate", - | "Executor ID": "exec3", - | "Metrics Updated": [ - | { - | "Task ID": 1, - | "Stage ID": 2, - | "Stage Attempt ID": 3, - | "Task Metrics": { - | "Host Name": "localhost", - | "Executor Deserialize Time": 300, - | "Executor Run Time": 400, - | "Result Size": 500, - | "JVM GC Time": 600, - | "Result Serialization Time": 700, - | "Memory Bytes Spilled": 800, - | "Disk Bytes Spilled": 0, - | "Input Metrics": { - | "Data Read Method": "Hadoop", - | "Bytes Read": 2100, - | "Records Read": 21 - | }, - | "Output Metrics": { - | "Data Write Method": "Hadoop", - | "Bytes Written": 1200, - | "Records Written": 12 - | }, - | "Updated Blocks": [ - | { - | "Block ID": "rdd_0_0", - | "Status": { - | "Storage Level": { - | "Use Disk": true, - | "Use Memory": true, - | "Use ExternalBlockStore": false, - | "Deserialized": false, - | "Replication": 2 - | }, - | "Memory Size": 0, - | "ExternalBlockStore Size": 0, - | "Disk Size": 0 - | } - | } - | ] - | } - | }] - |} - """.stripMargin + s""" + |{ + | "Event": "SparkListenerExecutorMetricsUpdate", + | "Executor ID": "exec3", + | "Metrics Updated": [ + | { + | "Task ID": 1, + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Accumulator Updates": [ + | { + | "ID": 0, + | "Name": "$EXECUTOR_DESERIALIZE_TIME", + | "Update": 300, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 1, + | "Name": "$EXECUTOR_RUN_TIME", + | "Update": 400, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 2, + | "Name": "$RESULT_SIZE", + | "Update": 500, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 3, + | "Name": "$JVM_GC_TIME", + | "Update": 600, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 4, + | "Name": "$RESULT_SERIALIZATION_TIME", + | "Update": 700, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 5, + | "Name": "$MEMORY_BYTES_SPILLED", + | "Update": 800, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 6, + | "Name": "$DISK_BYTES_SPILLED", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 7, + | "Name": "$PEAK_EXECUTION_MEMORY", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 8, + | "Name": "$UPDATED_BLOCK_STATUSES", + | "Update": [ + | { + | "BlockID": "rdd_0_0", + | "Status": { + | "StorageLevel": { + | "UseDisk": true, + | "UseMemory": true, + | "Deserialized": false, + | "Replication": 2 + | }, + | "MemorySize": 0, + | "DiskSize": 0 + | } + | } + | ], + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 9, + | "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 10, + | "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 11, + | "Name": "${shuffleRead.REMOTE_BYTES_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 12, + | "Name": "${shuffleRead.LOCAL_BYTES_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 13, + | "Name": "${shuffleRead.FETCH_WAIT_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 14, + | "Name": "${shuffleRead.RECORDS_READ}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 15, + | "Name": "${shuffleWrite.BYTES_WRITTEN}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 16, + | "Name": "${shuffleWrite.RECORDS_WRITTEN}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 17, + | "Name": "${shuffleWrite.WRITE_TIME}", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 18, + | "Name": "${input.READ_METHOD}", + | "Update": "Hadoop", + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 19, + | "Name": "${input.BYTES_READ}", + | "Update": 2100, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 20, + | "Name": "${input.RECORDS_READ}", + | "Update": 21, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 21, + | "Name": "${output.WRITE_METHOD}", + | "Update": "Hadoop", + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 22, + | "Name": "${output.BYTES_WRITTEN}", + | "Update": 1200, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 23, + | "Name": "${output.RECORDS_WRITTEN}", + | "Update": 12, + | "Internal": true, + | "Count Failed Values": true + | }, + | { + | "ID": 24, + | "Name": "$TEST_ACCUM", + | "Update": 0, + | "Internal": true, + | "Count Failed Values": true + | } + | ] + | } + | ] + |} + """.stripMargin } diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index c58db5e606f7c..75e4504850679 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -22,8 +22,6 @@ import java.util.Properties import org.apache.commons.lang3.SerializationUtils import org.scalatest.{BeforeAndAfterEach, Suite} -import org.apache.spark.SparkFunSuite - /** * Mixin for automatically resetting system properties that are modified in ScalaTest tests. * This resets the properties after each individual test. @@ -45,7 +43,7 @@ private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Su var oldProperties: Properties = null override def beforeEach(): Unit = { - // we need SerializationUtils.clone instead of `new Properties(System.getProperties()` because + // we need SerializationUtils.clone instead of `new Properties(System.getProperties())` because // the later way of creating a copy does not copy the properties but it initializes a new // Properties object with the given properties as defaults. They are not recognized at all // by standard Scala wrapper over Java Properties then. diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 20550178fb1bd..c342b68f46656 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.apache.spark.SparkFunSuite @@ -60,6 +60,12 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } +class DummyClass8 extends KnownSizeEstimation { + val x: Int = 0 + + override def estimatedSize: Long = 2015 +} + class SizeEstimatorSuite extends SparkFunSuite with BeforeAndAfterEach @@ -73,6 +79,10 @@ class SizeEstimatorSuite System.setProperty("spark.test.useCompressedOops", "true") } + override def afterEach(): Unit = { + super.afterEach() + } + test("simple classes") { assertResult(16)(SizeEstimator.estimate(new DummyClass1)) assertResult(16)(SizeEstimator.estimate(new DummyClass2)) @@ -214,4 +224,10 @@ class SizeEstimatorSuite // Class should be 32 bytes on s390x if recognised as 64 bit platform assertResult(32)(SizeEstimator.estimate(new DummyClass7)) } + + test("SizeEstimation can provide the estimated size") { + // DummyClass8 provides its size estimation. + assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) + assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) + } } diff --git a/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala b/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala index ddd5edf4f7396..0c8b8cfdd53a1 100644 --- a/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala +++ b/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala @@ -23,9 +23,7 @@ import org.apache.spark.SparkConf * Customized SparkConf that allows env variables to be overridden. */ class SparkConfWithEnv(env: Map[String, String]) extends SparkConf(false) { - override def getenv(name: String): String = { - env.get(name).getOrElse(super.getenv(name)) - } + override def getenv(name: String): String = env.getOrElse(name, super.getenv(name)) override def clone: SparkConf = { new SparkConfWithEnv(env).setAll(getAll) diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 620e4debf4e08..6652a41b6990b 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.duration._ import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ import scala.util.Random +import org.scalatest.concurrent.Eventually._ + import org.apache.spark.SparkFunSuite class ThreadUtilsSuite extends SparkFunSuite { @@ -59,6 +61,49 @@ class ThreadUtilsSuite extends SparkFunSuite { } } + test("newDaemonCachedThreadPool") { + val maxThreadNumber = 10 + val startThreadsLatch = new CountDownLatch(maxThreadNumber) + val latch = new CountDownLatch(1) + val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "ThreadUtilsSuite-newDaemonCachedThreadPool", + maxThreadNumber, + keepAliveSeconds = 2) + try { + for (_ <- 1 to maxThreadNumber) { + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + startThreadsLatch.countDown() + latch.await(10, TimeUnit.SECONDS) + } + }) + } + startThreadsLatch.await(10, TimeUnit.SECONDS) + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 0) + + // Submit a new task and it should be put into the queue since the thread number reaches the + // limitation + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + latch.await(10, TimeUnit.SECONDS) + } + }) + + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 1) + + latch.countDown() + eventually(timeout(10.seconds)) { + // All threads should be stopped after keepAliveSeconds + assert(cachedThreadPool.getActiveCount === 0) + assert(cachedThreadPool.getPoolSize === 0) + } + } finally { + cachedThreadPool.shutdownNow() + } + } + test("sameThread") { val callerThreadName = Thread.currentThread().getName() val f = Future { diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index 9b3169026cda3..25fc15dd54d04 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.util -import java.lang.ref.WeakReference - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -34,10 +32,6 @@ class TimeStampedHashMapSuite extends SparkFunSuite { testMap(new TimeStampedHashMap[String, String]()) testMapThreadSafety(new TimeStampedHashMap[String, String]()) - // Test TimeStampedWeakValueHashMap basic functionality - testMap(new TimeStampedWeakValueHashMap[String, String]()) - testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) - test("TimeStampedHashMap - clearing by timestamp") { // clearing by insertion time val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) @@ -68,86 +62,6 @@ class TimeStampedHashMapSuite extends SparkFunSuite { assert(map1.get("k2").isDefined) } - test("TimeStampedWeakValueHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis - assert(map.getTimestamp("k1").isDefined) - assert(map.getTimestamp("k1").get < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - - // clearing by modification time - val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) - map1("k1") = "v1" - map1("k2") = "v2" - assert(map1("k1") === "v1") - Thread.sleep(10) - val threshTime1 = System.currentTimeMillis - Thread.sleep(10) - assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime - assert(map1.getTimestamp("k1").isDefined) - assert(map1.getTimestamp("k1").get < threshTime1) - assert(map1.getTimestamp("k2").isDefined) - assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) // should only clear k1 - assert(map1.get("k1") === None) - assert(map1.get("k2").isDefined) - } - - test("TimeStampedWeakValueHashMap - clearing weak references") { - var strongRef = new Object - val weakRef = new WeakReference(strongRef) - val map = new TimeStampedWeakValueHashMap[String, Object] - map("k1") = strongRef - map("k2") = "v2" - map("k3") = "v3" - val isEquals = map("k1") == strongRef - assert(isEquals) - - // clear strong reference to "k1" - strongRef = null - val startTime = System.currentTimeMillis - System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. - while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { - System.gc() - System.runFinalization() - Thread.sleep(100) - } - assert(map.getReference("k1").isDefined) - val ref = map.getReference("k1").get - assert(ref.get === null) - assert(map.get("k1") === None) - - // operations should only display non-null entries - assert(map.iterator.forall { case (k, v) => k != "k1" }) - assert(map.filter { case (k, v) => k != "k2" }.size === 1) - assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") - assert(map.toMap.size === 2) - assert(map.toMap.forall { case (k, v) => k != "k1" }) - val buffer = new ArrayBuffer[String] - map.foreach { case (k, v) => buffer += v.toString } - assert(buffer.size === 2) - assert(buffer.forall(_ != "k1")) - val plusMap = map + (("k4", "v4")) - assert(plusMap.size === 3) - assert(plusMap.forall { case (k, v) => k != "k1" }) - val minusMap = map - "k2" - assert(minusMap.size === 1) - assert(minusMap.head._1 == "k3") - - // clear null values - should only clear k1 - map.clearNullValues() - assert(map.getReference("k1") === None) - assert(map.get("k1") === None) - assert(map.get("k2").isDefined) - assert(map.get("k2").get === "v2") - } - /** Test basic operations of a Scala mutable Map. */ def testMap(hashMapConstructor: => mutable.Map[String, String]) { def newMap() = hashMapConstructor diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala new file mode 100644 index 0000000000000..39b31f8ddeaba --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import scala.util.Random + +import com.google.common.util.concurrent.Uninterruptibles + +import org.apache.spark.SparkFunSuite + +class UninterruptibleThreadSuite extends SparkFunSuite { + + /** Sleep millis and return true if it's interrupted */ + private def sleep(millis: Long): Boolean = { + try { + Thread.sleep(millis) + false + } catch { + case _: InterruptedException => + true + } + } + + test("interrupt when runUninterruptibly is running") { + val enterRunUninterruptibly = new CountDownLatch(1) + @volatile var hasInterruptedException = false + @volatile var interruptStatusBeforeExit = false + val t = new UninterruptibleThread("test") { + override def run(): Unit = { + runUninterruptibly { + enterRunUninterruptibly.countDown() + hasInterruptedException = sleep(1000) + } + interruptStatusBeforeExit = Thread.interrupted() + } + } + t.start() + assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout") + t.interrupt() + t.join() + assert(hasInterruptedException === false) + assert(interruptStatusBeforeExit === true) + } + + test("interrupt before runUninterruptibly runs") { + val interruptLatch = new CountDownLatch(1) + @volatile var hasInterruptedException = false + @volatile var interruptStatusBeforeExit = false + val t = new UninterruptibleThread("test") { + override def run(): Unit = { + Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS) + try { + runUninterruptibly { + assert(false, "Should not reach here") + } + } catch { + case _: InterruptedException => hasInterruptedException = true + } + interruptStatusBeforeExit = Thread.interrupted() + } + } + t.start() + t.interrupt() + interruptLatch.countDown() + t.join() + assert(hasInterruptedException === true) + assert(interruptStatusBeforeExit === false) + } + + test("nested runUninterruptibly") { + val enterRunUninterruptibly = new CountDownLatch(1) + val interruptLatch = new CountDownLatch(1) + @volatile var hasInterruptedException = false + @volatile var interruptStatusBeforeExit = false + val t = new UninterruptibleThread("test") { + override def run(): Unit = { + runUninterruptibly { + enterRunUninterruptibly.countDown() + Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS) + hasInterruptedException = sleep(1) + runUninterruptibly { + if (sleep(1)) { + hasInterruptedException = true + } + } + if (sleep(1)) { + hasInterruptedException = true + } + } + interruptStatusBeforeExit = Thread.interrupted() + } + } + t.start() + assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout") + t.interrupt() + interruptLatch.countDown() + t.join() + assert(hasInterruptedException === false) + assert(interruptStatusBeforeExit === true) + } + + test("stress test") { + @volatile var hasInterruptedException = false + val t = new UninterruptibleThread("test") { + override def run(): Unit = { + for (i <- 0 until 100) { + try { + runUninterruptibly { + if (sleep(Random.nextInt(10))) { + hasInterruptedException = true + } + runUninterruptibly { + if (sleep(Random.nextInt(10))) { + hasInterruptedException = true + } + } + if (sleep(Random.nextInt(10))) { + hasInterruptedException = true + } + } + Uninterruptibles.sleepUninterruptibly(Random.nextInt(10), TimeUnit.MILLISECONDS) + // 50% chance to clear the interrupted status + if (Random.nextBoolean()) { + Thread.interrupted() + } + } catch { + case _: InterruptedException => + // The first runUninterruptibly may throw InterruptedException if the interrupt status + // is set before running `f`. + } + } + } + } + t.start() + for (i <- 0 until 400) { + Thread.sleep(Random.nextInt(10)) + t.interrupt() + } + t.join() + assert(hasInterruptedException === false) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 68b0da76bc134..4aa4854c36f3a 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,26 +17,26 @@ package org.apache.spark.util -import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} +import java.nio.charset.StandardCharsets import java.text.DecimalFormatSymbols -import java.util.concurrent.TimeUnit import java.util.Locale +import java.util.concurrent.TimeUnit import scala.collection.mutable.ListBuffer import scala.util.Random -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files - +import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.SparkConf class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -269,7 +269,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val tmpDir2 = Utils.createTempDir() val f1Path = tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) - f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(UTF_8)) + f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(StandardCharsets.UTF_8)) f1.close() // Read first few bytes @@ -296,9 +296,9 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("reading offset bytes across multiple files") { val tmpDir = Utils.createTempDir() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) - Files.write("0123456789", files(0), UTF_8) - Files.write("abcdefghij", files(1), UTF_8) - Files.write("ABCDEFGHIJ", files(2), UTF_8) + Files.write("0123456789", files(0), StandardCharsets.UTF_8) + Files.write("abcdefghij", files(1), StandardCharsets.UTF_8) + Files.write("ABCDEFGHIJ", files(2), StandardCharsets.UTF_8) // Read first few bytes in the 1st file assert(Utils.offsetBytes(files, 0, 5) === "01234") @@ -530,7 +530,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { try { System.setProperty("spark.test.fileNameLoadB", "2") Files.write("spark.test.fileNameLoadA true\n" + - "spark.test.fileNameLoadB 1\n", outFile, UTF_8) + "spark.test.fileNameLoadB 1\n", outFile, StandardCharsets.UTF_8) val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) properties .filter { case (k, v) => k.startsWith("spark.")} @@ -560,7 +560,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") - Files.write("some text", sourceFile, UTF_8) + Files.write("some text", sourceFile, StandardCharsets.UTF_8) val path = if (Utils.isWindows) { @@ -723,6 +723,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("isDynamicAllocationEnabled") { val conf = new SparkConf() + conf.set("spark.master", "yarn-client") assert(Utils.isDynamicAllocationEnabled(conf) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.dynamicAllocation.enabled", "false")) === false) @@ -732,6 +733,94 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { conf.set("spark.executor.instances", "1")) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.executor.instances", "0")) === true) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.master", "local")) === false) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.dynamicAllocation.testing", "true"))) + } + + test("encodeFileNameToURIRawPath") { + assert(Utils.encodeFileNameToURIRawPath("abc") === "abc") + assert(Utils.encodeFileNameToURIRawPath("abc xyz") === "abc%20xyz") + assert(Utils.encodeFileNameToURIRawPath("abc:xyz") === "abc:xyz") } + test("decodeFileNameInURI") { + assert(Utils.decodeFileNameInURI(new URI("files:///abc/xyz")) === "xyz") + assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc") + assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz") + } + + test("Kill process") { + // Verify that we can terminate a process even if it is in a bad state. This is only run + // on UNIX since it does some OS specific things to verify the correct behavior. + if (SystemUtils.IS_OS_UNIX) { + def getPid(p: Process): Int = { + val f = p.getClass().getDeclaredField("pid") + f.setAccessible(true) + f.get(p).asInstanceOf[Int] + } + + def pidExists(pid: Int): Boolean = { + val p = Runtime.getRuntime.exec(s"kill -0 $pid") + p.waitFor() + p.exitValue() == 0 + } + + def signal(pid: Int, s: String): Unit = { + val p = Runtime.getRuntime.exec(s"kill -$s $pid") + p.waitFor() + } + + // Start up a process that runs 'sleep 10'. Terminate the process and assert it takes + // less time and the process is no longer there. + val startTimeMs = System.currentTimeMillis() + val process = new ProcessBuilder("sleep", "10").start() + val pid = getPid(process) + try { + assert(pidExists(pid)) + val terminated = Utils.terminateProcess(process, 5000) + assert(terminated.isDefined) + Utils.waitForProcess(process, 5000) + val durationMs = System.currentTimeMillis() - startTimeMs + assert(durationMs < 5000) + assert(!pidExists(pid)) + } finally { + // Forcibly kill the test process just in case. + signal(pid, "SIGKILL") + } + + val versionParts = System.getProperty("java.version").split("[+.\\-]+", 3) + var majorVersion = versionParts(0).toInt + if (majorVersion == 1) majorVersion = versionParts(1).toInt + if (majorVersion >= 8) { + // Java8 added a way to forcibly terminate a process. We'll make sure that works by + // creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On + // older versions of java, this will *not* terminate. + val file = File.createTempFile("temp-file-name", ".tmp") + val cmd = + s""" + |#!/bin/bash + |trap "" SIGTERM + |sleep 10 + """.stripMargin + Files.write(cmd.getBytes(StandardCharsets.UTF_8), file) + file.getAbsoluteFile.setExecutable(true) + + val process = new ProcessBuilder(file.getAbsolutePath).start() + val pid = getPid(process) + assert(pidExists(pid)) + try { + signal(pid, "SIGSTOP") + val start = System.currentTimeMillis() + val terminated = Utils.terminateProcess(process, 5000) + assert(terminated.isDefined) + Utils.waitForProcess(process, 5000) + val duration = System.currentTimeMillis() - start + assert(duration < 5000) + assert(!pidExists(pid)) + } finally { + signal(pid, "SIGKILL") + } + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala deleted file mode 100644 index 11194cd22a419..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.util.Random - -import org.apache.spark.SparkFunSuite - -/** - * Tests org.apache.spark.util.Vector functionality - */ -@deprecated("suppress compile time deprecation warning", "1.0.0") -class VectorSuite extends SparkFunSuite { - - def verifyVector(vector: Vector, expectedLength: Int): Unit = { - assert(vector.length == expectedLength) - assert(vector.elements.min > 0.0) - assert(vector.elements.max < 1.0) - } - - test("random with default random number generator") { - val vector100 = Vector.random(100) - verifyVector(vector100, 100) - } - - test("random with given random number generator") { - val vector100 = Vector.random(100, new Random(100)) - verifyVector(vector100, 100) - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index b0db0988eeaab..69dbfa9cd7141 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -17,10 +17,7 @@ package org.apache.spark.util.collection -import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream} - import org.apache.spark.SparkFunSuite -import org.apache.spark.util.{Utils => UUtils} class BitSetSuite extends SparkFunSuite { @@ -155,50 +152,4 @@ class BitSetSuite extends SparkFunSuite { assert(bitsetDiff.nextSetBit(85) === 85) assert(bitsetDiff.nextSetBit(86) === -1) } - - test("read and write externally") { - val tempDir = UUtils.createTempDir() - val outputFile = File.createTempFile("bits", null, tempDir) - - val fos = new FileOutputStream(outputFile) - val oos = new ObjectOutputStream(fos) - - // Create BitSet - val setBits = Seq(0, 9, 1, 10, 90, 96) - val bitset = new BitSet(100) - - for (i <- 0 until 100) { - assert(!bitset.get(i)) - } - - setBits.foreach(i => bitset.set(i)) - - for (i <- 0 until 100) { - if (setBits.contains(i)) { - assert(bitset.get(i)) - } else { - assert(!bitset.get(i)) - } - } - assert(bitset.cardinality() === setBits.size) - - bitset.writeExternal(oos) - oos.close() - - val fis = new FileInputStream(outputFile) - val ois = new ObjectInputStream(fis) - - // Read BitSet from the file - val bitset2 = new BitSet(0) - bitset2.readExternal(ois) - - for (i <- 0 until 100) { - if (setBits.contains(i)) { - assert(bitset2.get(i)) - } else { - assert(!bitset2.get(i)) - } - } - assert(bitset2.cardinality() === setBits.size) - } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index d7b2d07a40052..a1a7ac97d924b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.MemoryTestingUtils - import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ +import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} - class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { import TestUtils.{assertNotSpilled, assertSpilled} @@ -112,7 +110,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner _, mergeValue _, mergeCombiners _) val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( - context, Some(agg), None, None, None) + context, Some(agg), None, None) val collisionPairs = Seq( ("Aa", "BB"), // 2112 @@ -163,7 +161,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val context = MemoryTestingUtils.fakeTaskContext(sc.env) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) - val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None, None) + val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1) @@ -194,7 +192,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = - new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None, None) + new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None) sorter.insertAll( (1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) assert(sorter.numSpills > 0, "sorter did not spill") @@ -221,7 +219,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( - context, Some(agg), None, None, None) + context, Some(agg), None, None) sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator( (null.asInstanceOf[String], "1"), @@ -285,25 +283,25 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( - context, Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + context, Some(agg), Some(new HashPartitioner(3)), Some(ord)) assert(sorter.iterator.toSeq === Seq()) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( - context, Some(agg), Some(new HashPartitioner(3)), None, None) + context, Some(agg), Some(new HashPartitioner(3)), None) assert(sorter2.iterator.toSeq === Seq()) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( - context, None, Some(new HashPartitioner(3)), Some(ord), None) + context, None, Some(new HashPartitioner(3)), Some(ord)) assert(sorter3.iterator.toSeq === Seq()) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( - context, None, Some(new HashPartitioner(3)), None, None) + context, None, Some(new HashPartitioner(3)), None) assert(sorter4.iterator.toSeq === Seq()) sorter4.stop() } @@ -322,28 +320,28 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( - context, Some(agg), Some(new HashPartitioner(7)), Some(ord), None) + context, Some(agg), Some(new HashPartitioner(7)), Some(ord)) sorter.insertAll(elements.iterator) assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( - context, Some(agg), Some(new HashPartitioner(7)), None, None) + context, Some(agg), Some(new HashPartitioner(7)), None) sorter2.insertAll(elements.iterator) assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( - context, None, Some(new HashPartitioner(7)), Some(ord), None) + context, None, Some(new HashPartitioner(7)), Some(ord)) sorter3.insertAll(elements.iterator) assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( - context, None, Some(new HashPartitioner(7)), None, None) + context, None, Some(new HashPartitioner(7)), None) sorter4.insertAll(elements.iterator) assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter4.stop() @@ -360,7 +358,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2)) val sorter = new ExternalSorter[Int, Int, Int]( - context, None, Some(new HashPartitioner(7)), Some(ord), None) + context, None, Some(new HashPartitioner(7)), Some(ord)) sorter.insertAll(elements) assert(sorter.numSpills > 0, "sorter did not spill") val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) @@ -444,7 +442,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val expectedSize = if (withFailures) size - 1 else size val context = MemoryTestingUtils.fakeTaskContext(sc.env) val sorter = new ExternalSorter[Int, Int, Int]( - context, None, Some(new HashPartitioner(3)), Some(ord), None) + context, None, Some(new HashPartitioner(3)), Some(ord)) if (withFailures) { intercept[SparkException] { sorter.insertAll((0 until size).iterator.map { i => @@ -514,7 +512,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None val context = MemoryTestingUtils.fakeTaskContext(sc.env) val sorter = - new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord, None) + new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord) sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) }) if (withSpilling) { assert(sorter.numSpills > 0, "sorter did not spill") @@ -553,7 +551,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val context = MemoryTestingUtils.fakeTaskContext(sc.env) val sorter1 = new ExternalSorter[String, String, String]( - context, None, None, Some(wrongOrdering), None) + context, None, None, Some(wrongOrdering)) val thrown = intercept[IllegalArgumentException] { sorter1.insertAll(testData.iterator.map(i => (i, i))) assert(sorter1.numSpills > 0, "sorter did not spill") @@ -575,7 +573,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner, mergeValue, mergeCombiners) val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( - context, Some(agg), None, None, None) + context, Some(agg), None, None) sorter2.insertAll(testData.iterator.map(i => (i, i))) assert(sorter2.numSpills > 0, "sorter did not spill") diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index fefa5165db197..65bf857e22c02 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.util.collection import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging import org.apache.spark.util.random.XORShiftRandom class SorterSuite extends SparkFunSuite with Logging { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 0326ed70b5edb..dda8bee222eca 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.util.collection.unsafe.sort +import java.nio.charset.StandardCharsets + import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks + import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String @@ -86,10 +89,12 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { // scalastyle:on forAll (regressionTests) { (s1: String, s2: String) => - testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) + testPrefixComparison( + s1.getBytes(StandardCharsets.UTF_8), s2.getBytes(StandardCharsets.UTF_8)) } forAll { (s1: String, s2: String) => - testPrefixComparison(s1.getBytes("UTF-8"), s2.getBytes("UTF-8")) + testPrefixComparison( + s1.getBytes(StandardCharsets.UTF_8), s2.getBytes(StandardCharsets.UTF_8)) } } diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala deleted file mode 100644 index 361ec95654f47..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.io - -import scala.util.Random - -import org.apache.spark.SparkFunSuite - - -class ByteArrayChunkOutputStreamSuite extends SparkFunSuite { - - test("empty output") { - val o = new ByteArrayChunkOutputStream(1024) - assert(o.toArrays.length === 0) - } - - test("write a single byte") { - val o = new ByteArrayChunkOutputStream(1024) - o.write(10) - assert(o.toArrays.length === 1) - assert(o.toArrays.head.toSeq === Seq(10.toByte)) - } - - test("write a single near boundary") { - val o = new ByteArrayChunkOutputStream(10) - o.write(new Array[Byte](9)) - o.write(99) - assert(o.toArrays.length === 1) - assert(o.toArrays.head(9) === 99.toByte) - } - - test("write a single at boundary") { - val o = new ByteArrayChunkOutputStream(10) - o.write(new Array[Byte](10)) - o.write(99) - assert(o.toArrays.length === 2) - assert(o.toArrays(1).length === 1) - assert(o.toArrays(1)(0) === 99.toByte) - } - - test("single chunk output") { - val ref = new Array[Byte](8) - Random.nextBytes(ref) - val o = new ByteArrayChunkOutputStream(10) - o.write(ref) - val arrays = o.toArrays - assert(arrays.length === 1) - assert(arrays.head.length === ref.length) - assert(arrays.head.toSeq === ref.toSeq) - } - - test("single chunk output at boundary size") { - val ref = new Array[Byte](10) - Random.nextBytes(ref) - val o = new ByteArrayChunkOutputStream(10) - o.write(ref) - val arrays = o.toArrays - assert(arrays.length === 1) - assert(arrays.head.length === ref.length) - assert(arrays.head.toSeq === ref.toSeq) - } - - test("multiple chunk output") { - val ref = new Array[Byte](26) - Random.nextBytes(ref) - val o = new ByteArrayChunkOutputStream(10) - o.write(ref) - val arrays = o.toArrays - assert(arrays.length === 3) - assert(arrays(0).length === 10) - assert(arrays(1).length === 10) - assert(arrays(2).length === 6) - - assert(arrays(0).toSeq === ref.slice(0, 10)) - assert(arrays(1).toSeq === ref.slice(10, 20)) - assert(arrays(2).toSeq === ref.slice(20, 26)) - } - - test("multiple chunk output at boundary size") { - val ref = new Array[Byte](30) - Random.nextBytes(ref) - val o = new ByteArrayChunkOutputStream(10) - o.write(ref) - val arrays = o.toArrays - assert(arrays.length === 3) - assert(arrays(0).length === 10) - assert(arrays(1).length === 10) - assert(arrays(2).length === 10) - - assert(arrays(0).toSeq === ref.slice(0, 10)) - assert(arrays(1).toSeq === ref.slice(10, 20)) - assert(arrays(2).toSeq === ref.slice(20, 30)) - } -} diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala new file mode 100644 index 0000000000000..226622075a6cc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import java.nio.ByteBuffer + +import scala.util.Random + +import org.apache.spark.SparkFunSuite + + +class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { + + test("empty output") { + val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + assert(o.toChunkedByteBuffer.size === 0) + } + + test("write a single byte") { + val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + o.write(10) + val chunkedByteBuffer = o.toChunkedByteBuffer + assert(chunkedByteBuffer.getChunks().length === 1) + assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte)) + } + + test("write a single near boundary") { + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(new Array[Byte](9)) + o.write(99) + val chunkedByteBuffer = o.toChunkedByteBuffer + assert(chunkedByteBuffer.getChunks().length === 1) + assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte) + } + + test("write a single at boundary") { + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(new Array[Byte](10)) + o.write(99) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 2) + assert(arrays(1).length === 1) + assert(arrays(1)(0) === 99.toByte) + } + + test("single chunk output") { + val ref = new Array[Byte](8) + Random.nextBytes(ref) + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(ref) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("single chunk output at boundary size") { + val ref = new Array[Byte](10) + Random.nextBytes(ref) + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(ref) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("multiple chunk output") { + val ref = new Array[Byte](26) + Random.nextBytes(ref) + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(ref) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 6) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 26)) + } + + test("multiple chunk output at boundary size") { + val ref = new Array[Byte](30) + Random.nextBytes(ref) + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(ref) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 10) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 30)) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index d6af0aebde733..7eb2f56c20585 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.util.random import java.util.Random + import scala.collection.mutable.ArrayBuffer -import org.apache.commons.math3.distribution.PoissonDistribution +import org.apache.commons.math3.distribution.PoissonDistribution import org.scalatest.Matchers import org.apache.spark.SparkFunSuite @@ -128,6 +129,13 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { t(m / 2) } + def replacementSampling(data: Iterator[Int], sampler: PoissonSampler[Int]): Iterator[Int] = { + data.flatMap { item => + val count = sampler.sample() + if (count == 0) Iterator.empty else Iterator.fill(count)(item) + } + } + test("utilities") { val s1 = Array(0, 1, 1, 0, 2) val s2 = Array(1, 0, 3, 2, 1) @@ -188,6 +196,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be > D } + test("bernoulli sampling without iterator") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.5))) + d should be < D + + sampler = new BernoulliSampler[Int](0.7) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.7))) + d should be < D + + sampler = new BernoulliSampler[Int](0.9) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.9))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new BernoulliSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.6))) + d should be > D + } + test("bernoulli sampling with gap sampling optimization") { // Tests expect maximum gap sampling fraction to be this value RandomSampler.defaultMaxGapSamplingFraction should be (0.4) @@ -216,6 +254,37 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be > D } + test("bernoulli sampling (without iterator) with gap sampling optimization") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.01) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), + gaps(sample(Iterator.from(0), 0.01))) + d should be < D + + sampler = new BernoulliSampler[Int](0.1) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + sampler = new BernoulliSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.3))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new BernoulliSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.4))) + d should be > D + } + test("bernoulli boundary cases") { val data = (1 to 100).toArray @@ -232,6 +301,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { sampler.sample(data.iterator).toArray should be (data) } + test("bernoulli (without iterator) boundary cases") { + val data = (1 to 100).toArray + + var sampler = new BernoulliSampler[Int](0.0) + data.filter(_ => sampler.sample() > 0) should be (Array.empty[Int]) + + sampler = new BernoulliSampler[Int](1.0) + data.filter(_ => sampler.sample() > 0) should be (data) + + sampler = new BernoulliSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0)) + data.filter(_ => sampler.sample() > 0) should be (Array.empty[Int]) + + sampler = new BernoulliSampler[Int](1.0 + (RandomSampler.roundingEpsilon / 2.0)) + data.filter(_ => sampler.sample() > 0) should be (data) + } + test("bernoulli data types") { // Tests expect maximum gap sampling fraction to be this value RandomSampler.defaultMaxGapSamplingFraction should be (0.4) @@ -340,6 +425,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be > D } + test("replacement sampling without iterator") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler = new PoissonSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.5))) + d should be < D + + sampler = new PoissonSampler[Int](0.7) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.7))) + d should be < D + + sampler = new PoissonSampler[Int](0.9) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.9))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new PoissonSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.6))) + d should be > D + } + test("replacement sampling with gap sampling") { // Tests expect maximum gap sampling fraction to be this value RandomSampler.defaultMaxGapSamplingFraction should be (0.4) @@ -368,6 +483,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be > D } + test("replacement sampling (without iterator) with gap sampling") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler = new PoissonSampler[Int](0.01) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.01))) + d should be < D + + sampler = new PoissonSampler[Int](0.1) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.1))) + d should be < D + + sampler = new PoissonSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.3))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new PoissonSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.4))) + d should be > D + } + test("replacement boundary cases") { val data = (1 to 100).toArray @@ -382,6 +527,20 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { sampler.sample(data.iterator).length should be > (data.length) } + test("replacement (without) boundary cases") { + val data = (1 to 100).toArray + + var sampler = new PoissonSampler[Int](0.0) + replacementSampling(data.iterator, sampler).toArray should be (Array.empty[Int]) + + sampler = new PoissonSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0)) + replacementSampling(data.iterator, sampler).toArray should be (Array.empty[Int]) + + // sampling with replacement has no upper bound on sampling fraction + sampler = new PoissonSampler[Int](2.0) + replacementSampling(data.iterator, sampler).length should be > (data.length) + } + test("replacement data types") { // Tests expect maximum gap sampling fraction to be this value RandomSampler.defaultMaxGapSamplingFraction should be (0.4) @@ -476,6 +635,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be < D } + test("bernoulli partitioning sampling without iterator") { + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler = new BernoulliCellSampler[Int](0.1, 0.2) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + sampler = new BernoulliCellSampler[Int](0.1, 0.2, true) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.9))) + d should be < D + } + test("bernoulli partitioning boundary cases") { val data = (1 to 100).toArray val d = RandomSampler.roundingEpsilon / 2.0 @@ -499,6 +674,29 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { sampler.sample(data.iterator).toArray should be (Array.empty[Int]) } + test("bernoulli partitioning (without iterator) boundary cases") { + val data = (1 to 100).toArray + val d = RandomSampler.roundingEpsilon / 2.0 + + var sampler = new BernoulliCellSampler[Int](0.0, 0.0) + data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](0.5, 0.5) + data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](1.0, 1.0) + data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](0.0, 1.0) + data.filter(_ => sampler.sample() > 0).toArray should be (data) + + sampler = new BernoulliCellSampler[Int](0.0 - d, 1.0 + d) + data.filter(_ => sampler.sample() > 0).toArray should be (data) + + sampler = new BernoulliCellSampler[Int](0.5, 0.5 - d) + data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int]) + } + test("bernoulli partitioning data") { val seed = rngSeed.nextLong val data = (1 to 100).toArray diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index d26667bf720cf..83eba3690e289 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.util.random -import org.scalatest.Matchers +import scala.language.reflectiveCalls import org.apache.commons.math3.stat.inference.ChiSquareTest +import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils.times -import scala.language.reflectiveCalls - class XORShiftRandomSuite extends SparkFunSuite with Matchers { private def fixture = new { @@ -54,7 +53,7 @@ class XORShiftRandomSuite extends SparkFunSuite with Matchers { * Perform the chi square test on the 5 rows of randomly generated numbers evenly divided into * 10 bins. chiSquareTest returns true iff the null hypothesis (that the classifications * represented by the counts in the columns of the input 2-way table are independent of the - * rows) can be rejected with 100 * (1 - alpha) percent confidence, where alpha is prespeficied + * rows) can be rejected with 100 * (1 - alpha) percent confidence, where alpha is prespecified * as 0.05 */ val chiTest = new ChiSquareTest @@ -65,4 +64,19 @@ class XORShiftRandomSuite extends SparkFunSuite with Matchers { val random = new XORShiftRandom(0L) assert(random.nextInt() != 0) } + + test ("hashSeed has random bits throughout") { + val totalBitCount = (0 until 10).map { seed => + val hashed = XORShiftRandom.hashSeed(seed) + val bitCount = java.lang.Long.bitCount(hashed) + // make sure we have roughly equal numbers of 0s and 1s. Mostly just check that we + // don't have all 0s or 1s in the high bits + bitCount should be > 20 + bitCount should be < 44 + bitCount + }.sum + // and over all the seeds, very close to equal numbers of 0s & 1s + totalBitCount should be > (32 * 10 - 30) + totalBitCount should be < (32 * 10 + 30) + } } diff --git a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala index daa795a043495..2fb09ead4b2d8 100644 --- a/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala +++ b/core/src/test/scala/org/apache/sparktest/ImplicitSuite.scala @@ -26,11 +26,11 @@ package org.apache.sparktest */ class ImplicitSuite { - // We only want to test if `implict` works well with the compiler, so we don't need a real + // We only want to test if `implicit` works well with the compiler, so we don't need a real // SparkContext. def mockSparkContext[T]: org.apache.spark.SparkContext = null - // We only want to test if `implict` works well with the compiler, so we don't need a real RDD. + // We only want to test if `implicit` works well with the compiler, so we don't need a real RDD. def mockRDD[T]: org.apache.spark.rdd.RDD[T] = null def testRddToPairRDDFunctions(): Unit = { diff --git a/data/mllib/als/sample_movielens_movies.txt b/data/mllib/als/sample_movielens_movies.txt deleted file mode 100644 index 934a0253849e1..0000000000000 --- a/data/mllib/als/sample_movielens_movies.txt +++ /dev/null @@ -1,100 +0,0 @@ -0::Movie 0::Romance|Comedy -1::Movie 1::Action|Anime -2::Movie 2::Romance|Thriller -3::Movie 3::Action|Romance -4::Movie 4::Anime|Comedy -5::Movie 5::Action|Action -6::Movie 6::Action|Comedy -7::Movie 7::Anime|Comedy -8::Movie 8::Comedy|Action -9::Movie 9::Anime|Thriller -10::Movie 10::Action|Anime -11::Movie 11::Action|Anime -12::Movie 12::Anime|Comedy -13::Movie 13::Thriller|Action -14::Movie 14::Anime|Comedy -15::Movie 15::Comedy|Thriller -16::Movie 16::Anime|Romance -17::Movie 17::Thriller|Action -18::Movie 18::Action|Comedy -19::Movie 19::Anime|Romance -20::Movie 20::Action|Anime -21::Movie 21::Romance|Thriller -22::Movie 22::Romance|Romance -23::Movie 23::Comedy|Comedy -24::Movie 24::Anime|Action -25::Movie 25::Comedy|Comedy -26::Movie 26::Anime|Romance -27::Movie 27::Anime|Anime -28::Movie 28::Thriller|Anime -29::Movie 29::Anime|Romance -30::Movie 30::Thriller|Romance -31::Movie 31::Thriller|Romance -32::Movie 32::Comedy|Anime -33::Movie 33::Comedy|Comedy -34::Movie 34::Anime|Anime -35::Movie 35::Action|Thriller -36::Movie 36::Anime|Romance -37::Movie 37::Romance|Anime -38::Movie 38::Thriller|Romance -39::Movie 39::Romance|Comedy -40::Movie 40::Action|Anime -41::Movie 41::Comedy|Thriller -42::Movie 42::Comedy|Action -43::Movie 43::Thriller|Anime -44::Movie 44::Anime|Action -45::Movie 45::Comedy|Romance -46::Movie 46::Comedy|Action -47::Movie 47::Romance|Comedy -48::Movie 48::Action|Comedy -49::Movie 49::Romance|Romance -50::Movie 50::Comedy|Romance -51::Movie 51::Action|Action -52::Movie 52::Thriller|Action -53::Movie 53::Action|Action -54::Movie 54::Romance|Thriller -55::Movie 55::Anime|Romance -56::Movie 56::Comedy|Action -57::Movie 57::Action|Anime -58::Movie 58::Thriller|Romance -59::Movie 59::Thriller|Comedy -60::Movie 60::Anime|Comedy -61::Movie 61::Comedy|Action -62::Movie 62::Comedy|Romance -63::Movie 63::Romance|Thriller -64::Movie 64::Romance|Action -65::Movie 65::Anime|Romance -66::Movie 66::Comedy|Action -67::Movie 67::Thriller|Anime -68::Movie 68::Thriller|Romance -69::Movie 69::Action|Comedy -70::Movie 70::Thriller|Thriller -71::Movie 71::Action|Comedy -72::Movie 72::Thriller|Romance -73::Movie 73::Comedy|Action -74::Movie 74::Action|Action -75::Movie 75::Action|Action -76::Movie 76::Comedy|Comedy -77::Movie 77::Comedy|Comedy -78::Movie 78::Comedy|Comedy -79::Movie 79::Thriller|Thriller -80::Movie 80::Comedy|Anime -81::Movie 81::Comedy|Anime -82::Movie 82::Romance|Anime -83::Movie 83::Comedy|Thriller -84::Movie 84::Anime|Action -85::Movie 85::Thriller|Anime -86::Movie 86::Romance|Anime -87::Movie 87::Thriller|Thriller -88::Movie 88::Romance|Thriller -89::Movie 89::Action|Anime -90::Movie 90::Anime|Romance -91::Movie 91::Anime|Thriller -92::Movie 92::Action|Comedy -93::Movie 93::Romance|Thriller -94::Movie 94::Thriller|Comedy -95::Movie 95::Action|Action -96::Movie 96::Thriller|Romance -97::Movie 97::Thriller|Thriller -98::Movie 98::Thriller|Comedy -99::Movie 99::Thriller|Romance diff --git a/data/mllib/streaming_kmeans_data_test.txt b/data/mllib/streaming_kmeans_data_test.txt new file mode 100644 index 0000000000000..649a0d6cf4e22 --- /dev/null +++ b/data/mllib/streaming_kmeans_data_test.txt @@ -0,0 +1,2 @@ +(1.0), [1.7, 0.4, 0.9] +(2.0), [2.2, 1.8, 0.0] diff --git a/data/streaming/AFINN-111.txt b/data/streaming/AFINN-111.txt new file mode 100644 index 0000000000000..0f6fb8ebaa0bf --- /dev/null +++ b/data/streaming/AFINN-111.txt @@ -0,0 +1,2477 @@ +abandon -2 +abandoned -2 +abandons -2 +abducted -2 +abduction -2 +abductions -2 +abhor -3 +abhorred -3 +abhorrent -3 +abhors -3 +abilities 2 +ability 2 +aboard 1 +absentee -1 +absentees -1 +absolve 2 +absolved 2 +absolves 2 +absolving 2 +absorbed 1 +abuse -3 +abused -3 +abuses -3 +abusive -3 +accept 1 +accepted 1 +accepting 1 +accepts 1 +accident -2 +accidental -2 +accidentally -2 +accidents -2 +accomplish 2 +accomplished 2 +accomplishes 2 +accusation -2 +accusations -2 +accuse -2 +accused -2 +accuses -2 +accusing -2 +ache -2 +achievable 1 +aching -2 +acquit 2 +acquits 2 +acquitted 2 +acquitting 2 +acrimonious -3 +active 1 +adequate 1 +admire 3 +admired 3 +admires 3 +admiring 3 +admit -1 +admits -1 +admitted -1 +admonish -2 +admonished -2 +adopt 1 +adopts 1 +adorable 3 +adore 3 +adored 3 +adores 3 +advanced 1 +advantage 2 +advantages 2 +adventure 2 +adventures 2 +adventurous 2 +affected -1 +affection 3 +affectionate 3 +afflicted -1 +affronted -1 +afraid -2 +aggravate -2 +aggravated -2 +aggravates -2 +aggravating -2 +aggression -2 +aggressions -2 +aggressive -2 +aghast -2 +agog 2 +agonise -3 +agonised -3 +agonises -3 +agonising -3 +agonize -3 +agonized -3 +agonizes -3 +agonizing -3 +agree 1 +agreeable 2 +agreed 1 +agreement 1 +agrees 1 +alarm -2 +alarmed -2 +alarmist -2 +alarmists -2 +alas -1 +alert -1 +alienation -2 +alive 1 +allergic -2 +allow 1 +alone -2 +amaze 2 +amazed 2 +amazes 2 +amazing 4 +ambitious 2 +ambivalent -1 +amuse 3 +amused 3 +amusement 3 +amusements 3 +anger -3 +angers -3 +angry -3 +anguish -3 +anguished -3 +animosity -2 +annoy -2 +annoyance -2 +annoyed -2 +annoying -2 +annoys -2 +antagonistic -2 +anti -1 +anticipation 1 +anxiety -2 +anxious -2 +apathetic -3 +apathy -3 +apeshit -3 +apocalyptic -2 +apologise -1 +apologised -1 +apologises -1 +apologising -1 +apologize -1 +apologized -1 +apologizes -1 +apologizing -1 +apology -1 +appalled -2 +appalling -2 +appease 2 +appeased 2 +appeases 2 +appeasing 2 +applaud 2 +applauded 2 +applauding 2 +applauds 2 +applause 2 +appreciate 2 +appreciated 2 +appreciates 2 +appreciating 2 +appreciation 2 +apprehensive -2 +approval 2 +approved 2 +approves 2 +ardent 1 +arrest -2 +arrested -3 +arrests -2 +arrogant -2 +ashame -2 +ashamed -2 +ass -4 +assassination -3 +assassinations -3 +asset 2 +assets 2 +assfucking -4 +asshole -4 +astonished 2 +astound 3 +astounded 3 +astounding 3 +astoundingly 3 +astounds 3 +attack -1 +attacked -1 +attacking -1 +attacks -1 +attract 1 +attracted 1 +attracting 2 +attraction 2 +attractions 2 +attracts 1 +audacious 3 +authority 1 +avert -1 +averted -1 +averts -1 +avid 2 +avoid -1 +avoided -1 +avoids -1 +await -1 +awaited -1 +awaits -1 +award 3 +awarded 3 +awards 3 +awesome 4 +awful -3 +awkward -2 +axe -1 +axed -1 +backed 1 +backing 2 +backs 1 +bad -3 +badass -3 +badly -3 +bailout -2 +bamboozle -2 +bamboozled -2 +bamboozles -2 +ban -2 +banish -1 +bankrupt -3 +bankster -3 +banned -2 +bargain 2 +barrier -2 +bastard -5 +bastards -5 +battle -1 +battles -1 +beaten -2 +beatific 3 +beating -1 +beauties 3 +beautiful 3 +beautifully 3 +beautify 3 +belittle -2 +belittled -2 +beloved 3 +benefit 2 +benefits 2 +benefitted 2 +benefitting 2 +bereave -2 +bereaved -2 +bereaves -2 +bereaving -2 +best 3 +betray -3 +betrayal -3 +betrayed -3 +betraying -3 +betrays -3 +better 2 +bias -1 +biased -2 +big 1 +bitch -5 +bitches -5 +bitter -2 +bitterly -2 +bizarre -2 +blah -2 +blame -2 +blamed -2 +blames -2 +blaming -2 +bless 2 +blesses 2 +blessing 3 +blind -1 +bliss 3 +blissful 3 +blithe 2 +block -1 +blockbuster 3 +blocked -1 +blocking -1 +blocks -1 +bloody -3 +blurry -2 +boastful -2 +bold 2 +boldly 2 +bomb -1 +boost 1 +boosted 1 +boosting 1 +boosts 1 +bore -2 +bored -2 +boring -3 +bother -2 +bothered -2 +bothers -2 +bothersome -2 +boycott -2 +boycotted -2 +boycotting -2 +boycotts -2 +brainwashing -3 +brave 2 +breakthrough 3 +breathtaking 5 +bribe -3 +bright 1 +brightest 2 +brightness 1 +brilliant 4 +brisk 2 +broke -1 +broken -1 +brooding -2 +bullied -2 +bullshit -4 +bully -2 +bullying -2 +bummer -2 +buoyant 2 +burden -2 +burdened -2 +burdening -2 +burdens -2 +calm 2 +calmed 2 +calming 2 +calms 2 +can't stand -3 +cancel -1 +cancelled -1 +cancelling -1 +cancels -1 +cancer -1 +capable 1 +captivated 3 +care 2 +carefree 1 +careful 2 +carefully 2 +careless -2 +cares 2 +cashing in -2 +casualty -2 +catastrophe -3 +catastrophic -4 +cautious -1 +celebrate 3 +celebrated 3 +celebrates 3 +celebrating 3 +censor -2 +censored -2 +censors -2 +certain 1 +chagrin -2 +chagrined -2 +challenge -1 +chance 2 +chances 2 +chaos -2 +chaotic -2 +charged -3 +charges -2 +charm 3 +charming 3 +charmless -3 +chastise -3 +chastised -3 +chastises -3 +chastising -3 +cheat -3 +cheated -3 +cheater -3 +cheaters -3 +cheats -3 +cheer 2 +cheered 2 +cheerful 2 +cheering 2 +cheerless -2 +cheers 2 +cheery 3 +cherish 2 +cherished 2 +cherishes 2 +cherishing 2 +chic 2 +childish -2 +chilling -1 +choke -2 +choked -2 +chokes -2 +choking -2 +clarifies 2 +clarity 2 +clash -2 +classy 3 +clean 2 +cleaner 2 +clear 1 +cleared 1 +clearly 1 +clears 1 +clever 2 +clouded -1 +clueless -2 +cock -5 +cocksucker -5 +cocksuckers -5 +cocky -2 +coerced -2 +collapse -2 +collapsed -2 +collapses -2 +collapsing -2 +collide -1 +collides -1 +colliding -1 +collision -2 +collisions -2 +colluding -3 +combat -1 +combats -1 +comedy 1 +comfort 2 +comfortable 2 +comforting 2 +comforts 2 +commend 2 +commended 2 +commit 1 +commitment 2 +commits 1 +committed 1 +committing 1 +compassionate 2 +compelled 1 +competent 2 +competitive 2 +complacent -2 +complain -2 +complained -2 +complains -2 +comprehensive 2 +conciliate 2 +conciliated 2 +conciliates 2 +conciliating 2 +condemn -2 +condemnation -2 +condemned -2 +condemns -2 +confidence 2 +confident 2 +conflict -2 +conflicting -2 +conflictive -2 +conflicts -2 +confuse -2 +confused -2 +confusing -2 +congrats 2 +congratulate 2 +congratulation 2 +congratulations 2 +consent 2 +consents 2 +consolable 2 +conspiracy -3 +constrained -2 +contagion -2 +contagions -2 +contagious -1 +contempt -2 +contemptuous -2 +contemptuously -2 +contend -1 +contender -1 +contending -1 +contentious -2 +contestable -2 +controversial -2 +controversially -2 +convince 1 +convinced 1 +convinces 1 +convivial 2 +cool 1 +cool stuff 3 +cornered -2 +corpse -1 +costly -2 +courage 2 +courageous 2 +courteous 2 +courtesy 2 +cover-up -3 +coward -2 +cowardly -2 +coziness 2 +cramp -1 +crap -3 +crash -2 +crazier -2 +craziest -2 +crazy -2 +creative 2 +crestfallen -2 +cried -2 +cries -2 +crime -3 +criminal -3 +criminals -3 +crisis -3 +critic -2 +criticism -2 +criticize -2 +criticized -2 +criticizes -2 +criticizing -2 +critics -2 +cruel -3 +cruelty -3 +crush -1 +crushed -2 +crushes -1 +crushing -1 +cry -1 +crying -2 +cunt -5 +curious 1 +curse -1 +cut -1 +cute 2 +cuts -1 +cutting -1 +cynic -2 +cynical -2 +cynicism -2 +damage -3 +damages -3 +damn -4 +damned -4 +damnit -4 +danger -2 +daredevil 2 +daring 2 +darkest -2 +darkness -1 +dauntless 2 +dead -3 +deadlock -2 +deafening -1 +dear 2 +dearly 3 +death -2 +debonair 2 +debt -2 +deceit -3 +deceitful -3 +deceive -3 +deceived -3 +deceives -3 +deceiving -3 +deception -3 +decisive 1 +dedicated 2 +defeated -2 +defect -3 +defects -3 +defender 2 +defenders 2 +defenseless -2 +defer -1 +deferring -1 +defiant -1 +deficit -2 +degrade -2 +degraded -2 +degrades -2 +dehumanize -2 +dehumanized -2 +dehumanizes -2 +dehumanizing -2 +deject -2 +dejected -2 +dejecting -2 +dejects -2 +delay -1 +delayed -1 +delight 3 +delighted 3 +delighting 3 +delights 3 +demand -1 +demanded -1 +demanding -1 +demands -1 +demonstration -1 +demoralized -2 +denied -2 +denier -2 +deniers -2 +denies -2 +denounce -2 +denounces -2 +deny -2 +denying -2 +depressed -2 +depressing -2 +derail -2 +derailed -2 +derails -2 +deride -2 +derided -2 +derides -2 +deriding -2 +derision -2 +desirable 2 +desire 1 +desired 2 +desirous 2 +despair -3 +despairing -3 +despairs -3 +desperate -3 +desperately -3 +despondent -3 +destroy -3 +destroyed -3 +destroying -3 +destroys -3 +destruction -3 +destructive -3 +detached -1 +detain -2 +detained -2 +detention -2 +determined 2 +devastate -2 +devastated -2 +devastating -2 +devoted 3 +diamond 1 +dick -4 +dickhead -4 +die -3 +died -3 +difficult -1 +diffident -2 +dilemma -1 +dipshit -3 +dire -3 +direful -3 +dirt -2 +dirtier -2 +dirtiest -2 +dirty -2 +disabling -1 +disadvantage -2 +disadvantaged -2 +disappear -1 +disappeared -1 +disappears -1 +disappoint -2 +disappointed -2 +disappointing -2 +disappointment -2 +disappointments -2 +disappoints -2 +disaster -2 +disasters -2 +disastrous -3 +disbelieve -2 +discard -1 +discarded -1 +discarding -1 +discards -1 +disconsolate -2 +disconsolation -2 +discontented -2 +discord -2 +discounted -1 +discouraged -2 +discredited -2 +disdain -2 +disgrace -2 +disgraced -2 +disguise -1 +disguised -1 +disguises -1 +disguising -1 +disgust -3 +disgusted -3 +disgusting -3 +disheartened -2 +dishonest -2 +disillusioned -2 +disinclined -2 +disjointed -2 +dislike -2 +dismal -2 +dismayed -2 +disorder -2 +disorganized -2 +disoriented -2 +disparage -2 +disparaged -2 +disparages -2 +disparaging -2 +displeased -2 +dispute -2 +disputed -2 +disputes -2 +disputing -2 +disqualified -2 +disquiet -2 +disregard -2 +disregarded -2 +disregarding -2 +disregards -2 +disrespect -2 +disrespected -2 +disruption -2 +disruptions -2 +disruptive -2 +dissatisfied -2 +distort -2 +distorted -2 +distorting -2 +distorts -2 +distract -2 +distracted -2 +distraction -2 +distracts -2 +distress -2 +distressed -2 +distresses -2 +distressing -2 +distrust -3 +distrustful -3 +disturb -2 +disturbed -2 +disturbing -2 +disturbs -2 +dithering -2 +dizzy -1 +dodging -2 +dodgy -2 +does not work -3 +dolorous -2 +dont like -2 +doom -2 +doomed -2 +doubt -1 +doubted -1 +doubtful -1 +doubting -1 +doubts -1 +douche -3 +douchebag -3 +downcast -2 +downhearted -2 +downside -2 +drag -1 +dragged -1 +drags -1 +drained -2 +dread -2 +dreaded -2 +dreadful -3 +dreading -2 +dream 1 +dreams 1 +dreary -2 +droopy -2 +drop -1 +drown -2 +drowned -2 +drowns -2 +drunk -2 +dubious -2 +dud -2 +dull -2 +dumb -3 +dumbass -3 +dump -1 +dumped -2 +dumps -1 +dupe -2 +duped -2 +dysfunction -2 +eager 2 +earnest 2 +ease 2 +easy 1 +ecstatic 4 +eerie -2 +eery -2 +effective 2 +effectively 2 +elated 3 +elation 3 +elegant 2 +elegantly 2 +embarrass -2 +embarrassed -2 +embarrasses -2 +embarrassing -2 +embarrassment -2 +embittered -2 +embrace 1 +emergency -2 +empathetic 2 +emptiness -1 +empty -1 +enchanted 2 +encourage 2 +encouraged 2 +encouragement 2 +encourages 2 +endorse 2 +endorsed 2 +endorsement 2 +endorses 2 +enemies -2 +enemy -2 +energetic 2 +engage 1 +engages 1 +engrossed 1 +enjoy 2 +enjoying 2 +enjoys 2 +enlighten 2 +enlightened 2 +enlightening 2 +enlightens 2 +ennui -2 +enrage -2 +enraged -2 +enrages -2 +enraging -2 +enrapture 3 +enslave -2 +enslaved -2 +enslaves -2 +ensure 1 +ensuring 1 +enterprising 1 +entertaining 2 +enthral 3 +enthusiastic 3 +entitled 1 +entrusted 2 +envies -1 +envious -2 +envy -1 +envying -1 +erroneous -2 +error -2 +errors -2 +escape -1 +escapes -1 +escaping -1 +esteemed 2 +ethical 2 +euphoria 3 +euphoric 4 +eviction -1 +evil -3 +exaggerate -2 +exaggerated -2 +exaggerates -2 +exaggerating -2 +exasperated 2 +excellence 3 +excellent 3 +excite 3 +excited 3 +excitement 3 +exciting 3 +exclude -1 +excluded -2 +exclusion -1 +exclusive 2 +excuse -1 +exempt -1 +exhausted -2 +exhilarated 3 +exhilarates 3 +exhilarating 3 +exonerate 2 +exonerated 2 +exonerates 2 +exonerating 2 +expand 1 +expands 1 +expel -2 +expelled -2 +expelling -2 +expels -2 +exploit -2 +exploited -2 +exploiting -2 +exploits -2 +exploration 1 +explorations 1 +expose -1 +exposed -1 +exposes -1 +exposing -1 +extend 1 +extends 1 +exuberant 4 +exultant 3 +exultantly 3 +fabulous 4 +fad -2 +fag -3 +faggot -3 +faggots -3 +fail -2 +failed -2 +failing -2 +fails -2 +failure -2 +failures -2 +fainthearted -2 +fair 2 +faith 1 +faithful 3 +fake -3 +fakes -3 +faking -3 +fallen -2 +falling -1 +falsified -3 +falsify -3 +fame 1 +fan 3 +fantastic 4 +farce -1 +fascinate 3 +fascinated 3 +fascinates 3 +fascinating 3 +fascist -2 +fascists -2 +fatalities -3 +fatality -3 +fatigue -2 +fatigued -2 +fatigues -2 +fatiguing -2 +favor 2 +favored 2 +favorite 2 +favorited 2 +favorites 2 +favors 2 +fear -2 +fearful -2 +fearing -2 +fearless 2 +fearsome -2 +fed up -3 +feeble -2 +feeling 1 +felonies -3 +felony -3 +fervent 2 +fervid 2 +festive 2 +fiasco -3 +fidgety -2 +fight -1 +fine 2 +fire -2 +fired -2 +firing -2 +fit 1 +fitness 1 +flagship 2 +flees -1 +flop -2 +flops -2 +flu -2 +flustered -2 +focused 2 +fond 2 +fondness 2 +fool -2 +foolish -2 +fools -2 +forced -1 +foreclosure -2 +foreclosures -2 +forget -1 +forgetful -2 +forgive 1 +forgiving 1 +forgotten -1 +fortunate 2 +frantic -1 +fraud -4 +frauds -4 +fraudster -4 +fraudsters -4 +fraudulence -4 +fraudulent -4 +free 1 +freedom 2 +frenzy -3 +fresh 1 +friendly 2 +fright -2 +frightened -2 +frightening -3 +frikin -2 +frisky 2 +frowning -1 +frustrate -2 +frustrated -2 +frustrates -2 +frustrating -2 +frustration -2 +ftw 3 +fuck -4 +fucked -4 +fucker -4 +fuckers -4 +fuckface -4 +fuckhead -4 +fucking -4 +fucktard -4 +fud -3 +fuked -4 +fuking -4 +fulfill 2 +fulfilled 2 +fulfills 2 +fuming -2 +fun 4 +funeral -1 +funerals -1 +funky 2 +funnier 4 +funny 4 +furious -3 +futile 2 +gag -2 +gagged -2 +gain 2 +gained 2 +gaining 2 +gains 2 +gallant 3 +gallantly 3 +gallantry 3 +generous 2 +genial 3 +ghost -1 +giddy -2 +gift 2 +glad 3 +glamorous 3 +glamourous 3 +glee 3 +gleeful 3 +gloom -1 +gloomy -2 +glorious 2 +glory 2 +glum -2 +god 1 +goddamn -3 +godsend 4 +good 3 +goodness 3 +grace 1 +gracious 3 +grand 3 +grant 1 +granted 1 +granting 1 +grants 1 +grateful 3 +gratification 2 +grave -2 +gray -1 +great 3 +greater 3 +greatest 3 +greed -3 +greedy -2 +green wash -3 +green washing -3 +greenwash -3 +greenwasher -3 +greenwashers -3 +greenwashing -3 +greet 1 +greeted 1 +greeting 1 +greetings 2 +greets 1 +grey -1 +grief -2 +grieved -2 +gross -2 +growing 1 +growth 2 +guarantee 1 +guilt -3 +guilty -3 +gullibility -2 +gullible -2 +gun -1 +ha 2 +hacked -1 +haha 3 +hahaha 3 +hahahah 3 +hail 2 +hailed 2 +hapless -2 +haplessness -2 +happiness 3 +happy 3 +hard -1 +hardier 2 +hardship -2 +hardy 2 +harm -2 +harmed -2 +harmful -2 +harming -2 +harms -2 +harried -2 +harsh -2 +harsher -2 +harshest -2 +hate -3 +hated -3 +haters -3 +hates -3 +hating -3 +haunt -1 +haunted -2 +haunting 1 +haunts -1 +havoc -2 +healthy 2 +heartbreaking -3 +heartbroken -3 +heartfelt 3 +heaven 2 +heavenly 4 +heavyhearted -2 +hell -4 +help 2 +helpful 2 +helping 2 +helpless -2 +helps 2 +hero 2 +heroes 2 +heroic 3 +hesitant -2 +hesitate -2 +hid -1 +hide -1 +hides -1 +hiding -1 +highlight 2 +hilarious 2 +hindrance -2 +hoax -2 +homesick -2 +honest 2 +honor 2 +honored 2 +honoring 2 +honour 2 +honoured 2 +honouring 2 +hooligan -2 +hooliganism -2 +hooligans -2 +hope 2 +hopeful 2 +hopefully 2 +hopeless -2 +hopelessness -2 +hopes 2 +hoping 2 +horrendous -3 +horrible -3 +horrific -3 +horrified -3 +hostile -2 +huckster -2 +hug 2 +huge 1 +hugs 2 +humerous 3 +humiliated -3 +humiliation -3 +humor 2 +humorous 2 +humour 2 +humourous 2 +hunger -2 +hurrah 5 +hurt -2 +hurting -2 +hurts -2 +hypocritical -2 +hysteria -3 +hysterical -3 +hysterics -3 +idiot -3 +idiotic -3 +ignorance -2 +ignorant -2 +ignore -1 +ignored -2 +ignores -1 +ill -2 +illegal -3 +illiteracy -2 +illness -2 +illnesses -2 +imbecile -3 +immobilized -1 +immortal 2 +immune 1 +impatient -2 +imperfect -2 +importance 2 +important 2 +impose -1 +imposed -1 +imposes -1 +imposing -1 +impotent -2 +impress 3 +impressed 3 +impresses 3 +impressive 3 +imprisoned -2 +improve 2 +improved 2 +improvement 2 +improves 2 +improving 2 +inability -2 +inaction -2 +inadequate -2 +incapable -2 +incapacitated -2 +incensed -2 +incompetence -2 +incompetent -2 +inconsiderate -2 +inconvenience -2 +inconvenient -2 +increase 1 +increased 1 +indecisive -2 +indestructible 2 +indifference -2 +indifferent -2 +indignant -2 +indignation -2 +indoctrinate -2 +indoctrinated -2 +indoctrinates -2 +indoctrinating -2 +ineffective -2 +ineffectively -2 +infatuated 2 +infatuation 2 +infected -2 +inferior -2 +inflamed -2 +influential 2 +infringement -2 +infuriate -2 +infuriated -2 +infuriates -2 +infuriating -2 +inhibit -1 +injured -2 +injury -2 +injustice -2 +innovate 1 +innovates 1 +innovation 1 +innovative 2 +inquisition -2 +inquisitive 2 +insane -2 +insanity -2 +insecure -2 +insensitive -2 +insensitivity -2 +insignificant -2 +insipid -2 +inspiration 2 +inspirational 2 +inspire 2 +inspired 2 +inspires 2 +inspiring 3 +insult -2 +insulted -2 +insulting -2 +insults -2 +intact 2 +integrity 2 +intelligent 2 +intense 1 +interest 1 +interested 2 +interesting 2 +interests 1 +interrogated -2 +interrupt -2 +interrupted -2 +interrupting -2 +interruption -2 +interrupts -2 +intimidate -2 +intimidated -2 +intimidates -2 +intimidating -2 +intimidation -2 +intricate 2 +intrigues 1 +invincible 2 +invite 1 +inviting 1 +invulnerable 2 +irate -3 +ironic -1 +irony -1 +irrational -1 +irresistible 2 +irresolute -2 +irresponsible 2 +irreversible -1 +irritate -3 +irritated -3 +irritating -3 +isolated -1 +itchy -2 +jackass -4 +jackasses -4 +jailed -2 +jaunty 2 +jealous -2 +jeopardy -2 +jerk -3 +jesus 1 +jewel 1 +jewels 1 +jocular 2 +join 1 +joke 2 +jokes 2 +jolly 2 +jovial 2 +joy 3 +joyful 3 +joyfully 3 +joyless -2 +joyous 3 +jubilant 3 +jumpy -1 +justice 2 +justifiably 2 +justified 2 +keen 1 +kill -3 +killed -3 +killing -3 +kills -3 +kind 2 +kinder 2 +kiss 2 +kudos 3 +lack -2 +lackadaisical -2 +lag -1 +lagged -2 +lagging -2 +lags -2 +lame -2 +landmark 2 +laugh 1 +laughed 1 +laughing 1 +laughs 1 +laughting 1 +launched 1 +lawl 3 +lawsuit -2 +lawsuits -2 +lazy -1 +leak -1 +leaked -1 +leave -1 +legal 1 +legally 1 +lenient 1 +lethargic -2 +lethargy -2 +liar -3 +liars -3 +libelous -2 +lied -2 +lifesaver 4 +lighthearted 1 +like 2 +liked 2 +likes 2 +limitation -1 +limited -1 +limits -1 +litigation -1 +litigious -2 +lively 2 +livid -2 +lmao 4 +lmfao 4 +loathe -3 +loathed -3 +loathes -3 +loathing -3 +lobby -2 +lobbying -2 +lol 3 +lonely -2 +lonesome -2 +longing -1 +loom -1 +loomed -1 +looming -1 +looms -1 +loose -3 +looses -3 +loser -3 +losing -3 +loss -3 +lost -3 +lovable 3 +love 3 +loved 3 +lovelies 3 +lovely 3 +loving 2 +lowest -1 +loyal 3 +loyalty 3 +luck 3 +luckily 3 +lucky 3 +lugubrious -2 +lunatic -3 +lunatics -3 +lurk -1 +lurking -1 +lurks -1 +mad -3 +maddening -3 +made-up -1 +madly -3 +madness -3 +mandatory -1 +manipulated -1 +manipulating -1 +manipulation -1 +marvel 3 +marvelous 3 +marvels 3 +masterpiece 4 +masterpieces 4 +matter 1 +matters 1 +mature 2 +meaningful 2 +meaningless -2 +medal 3 +mediocrity -3 +meditative 1 +melancholy -2 +menace -2 +menaced -2 +mercy 2 +merry 3 +mess -2 +messed -2 +messing up -2 +methodical 2 +mindless -2 +miracle 4 +mirth 3 +mirthful 3 +mirthfully 3 +misbehave -2 +misbehaved -2 +misbehaves -2 +misbehaving -2 +mischief -1 +mischiefs -1 +miserable -3 +misery -2 +misgiving -2 +misinformation -2 +misinformed -2 +misinterpreted -2 +misleading -3 +misread -1 +misreporting -2 +misrepresentation -2 +miss -2 +missed -2 +missing -2 +mistake -2 +mistaken -2 +mistakes -2 +mistaking -2 +misunderstand -2 +misunderstanding -2 +misunderstands -2 +misunderstood -2 +moan -2 +moaned -2 +moaning -2 +moans -2 +mock -2 +mocked -2 +mocking -2 +mocks -2 +mongering -2 +monopolize -2 +monopolized -2 +monopolizes -2 +monopolizing -2 +moody -1 +mope -1 +moping -1 +moron -3 +motherfucker -5 +motherfucking -5 +motivate 1 +motivated 2 +motivating 2 +motivation 1 +mourn -2 +mourned -2 +mournful -2 +mourning -2 +mourns -2 +mumpish -2 +murder -2 +murderer -2 +murdering -3 +murderous -3 +murders -2 +myth -1 +n00b -2 +naive -2 +nasty -3 +natural 1 +naïve -2 +needy -2 +negative -2 +negativity -2 +neglect -2 +neglected -2 +neglecting -2 +neglects -2 +nerves -1 +nervous -2 +nervously -2 +nice 3 +nifty 2 +niggas -5 +nigger -5 +no -1 +no fun -3 +noble 2 +noisy -1 +nonsense -2 +noob -2 +nosey -2 +not good -2 +not working -3 +notorious -2 +novel 2 +numb -1 +nuts -3 +obliterate -2 +obliterated -2 +obnoxious -3 +obscene -2 +obsessed 2 +obsolete -2 +obstacle -2 +obstacles -2 +obstinate -2 +odd -2 +offend -2 +offended -2 +offender -2 +offending -2 +offends -2 +offline -1 +oks 2 +ominous 3 +once-in-a-lifetime 3 +opportunities 2 +opportunity 2 +oppressed -2 +oppressive -2 +optimism 2 +optimistic 2 +optionless -2 +outcry -2 +outmaneuvered -2 +outrage -3 +outraged -3 +outreach 2 +outstanding 5 +overjoyed 4 +overload -1 +overlooked -1 +overreact -2 +overreacted -2 +overreaction -2 +overreacts -2 +oversell -2 +overselling -2 +oversells -2 +oversimplification -2 +oversimplified -2 +oversimplifies -2 +oversimplify -2 +overstatement -2 +overstatements -2 +overweight -1 +oxymoron -1 +pain -2 +pained -2 +panic -3 +panicked -3 +panics -3 +paradise 3 +paradox -1 +pardon 2 +pardoned 2 +pardoning 2 +pardons 2 +parley -1 +passionate 2 +passive -1 +passively -1 +pathetic -2 +pay -1 +peace 2 +peaceful 2 +peacefully 2 +penalty -2 +pensive -1 +perfect 3 +perfected 2 +perfectly 3 +perfects 2 +peril -2 +perjury -3 +perpetrator -2 +perpetrators -2 +perplexed -2 +persecute -2 +persecuted -2 +persecutes -2 +persecuting -2 +perturbed -2 +pesky -2 +pessimism -2 +pessimistic -2 +petrified -2 +phobic -2 +picturesque 2 +pileup -1 +pique -2 +piqued -2 +piss -4 +pissed -4 +pissing -3 +piteous -2 +pitied -1 +pity -2 +playful 2 +pleasant 3 +please 1 +pleased 3 +pleasure 3 +poised -2 +poison -2 +poisoned -2 +poisons -2 +pollute -2 +polluted -2 +polluter -2 +polluters -2 +pollutes -2 +poor -2 +poorer -2 +poorest -2 +popular 3 +positive 2 +positively 2 +possessive -2 +postpone -1 +postponed -1 +postpones -1 +postponing -1 +poverty -1 +powerful 2 +powerless -2 +praise 3 +praised 3 +praises 3 +praising 3 +pray 1 +praying 1 +prays 1 +prblm -2 +prblms -2 +prepared 1 +pressure -1 +pressured -2 +pretend -1 +pretending -1 +pretends -1 +pretty 1 +prevent -1 +prevented -1 +preventing -1 +prevents -1 +prick -5 +prison -2 +prisoner -2 +prisoners -2 +privileged 2 +proactive 2 +problem -2 +problems -2 +profiteer -2 +progress 2 +prominent 2 +promise 1 +promised 1 +promises 1 +promote 1 +promoted 1 +promotes 1 +promoting 1 +propaganda -2 +prosecute -1 +prosecuted -2 +prosecutes -1 +prosecution -1 +prospect 1 +prospects 1 +prosperous 3 +protect 1 +protected 1 +protects 1 +protest -2 +protesters -2 +protesting -2 +protests -2 +proud 2 +proudly 2 +provoke -1 +provoked -1 +provokes -1 +provoking -1 +pseudoscience -3 +punish -2 +punished -2 +punishes -2 +punitive -2 +pushy -1 +puzzled -2 +quaking -2 +questionable -2 +questioned -1 +questioning -1 +racism -3 +racist -3 +racists -3 +rage -2 +rageful -2 +rainy -1 +rant -3 +ranter -3 +ranters -3 +rants -3 +rape -4 +rapist -4 +rapture 2 +raptured 2 +raptures 2 +rapturous 4 +rash -2 +ratified 2 +reach 1 +reached 1 +reaches 1 +reaching 1 +reassure 1 +reassured 1 +reassures 1 +reassuring 2 +rebellion -2 +recession -2 +reckless -2 +recommend 2 +recommended 2 +recommends 2 +redeemed 2 +refuse -2 +refused -2 +refusing -2 +regret -2 +regretful -2 +regrets -2 +regretted -2 +regretting -2 +reject -1 +rejected -1 +rejecting -1 +rejects -1 +rejoice 4 +rejoiced 4 +rejoices 4 +rejoicing 4 +relaxed 2 +relentless -1 +reliant 2 +relieve 1 +relieved 2 +relieves 1 +relieving 2 +relishing 2 +remarkable 2 +remorse -2 +repulse -1 +repulsed -2 +rescue 2 +rescued 2 +rescues 2 +resentful -2 +resign -1 +resigned -1 +resigning -1 +resigns -1 +resolute 2 +resolve 2 +resolved 2 +resolves 2 +resolving 2 +respected 2 +responsible 2 +responsive 2 +restful 2 +restless -2 +restore 1 +restored 1 +restores 1 +restoring 1 +restrict -2 +restricted -2 +restricting -2 +restriction -2 +restricts -2 +retained -1 +retard -2 +retarded -2 +retreat -1 +revenge -2 +revengeful -2 +revered 2 +revive 2 +revives 2 +reward 2 +rewarded 2 +rewarding 2 +rewards 2 +rich 2 +ridiculous -3 +rig -1 +rigged -1 +right direction 3 +rigorous 3 +rigorously 3 +riot -2 +riots -2 +risk -2 +risks -2 +rob -2 +robber -2 +robed -2 +robing -2 +robs -2 +robust 2 +rofl 4 +roflcopter 4 +roflmao 4 +romance 2 +rotfl 4 +rotflmfao 4 +rotflol 4 +ruin -2 +ruined -2 +ruining -2 +ruins -2 +sabotage -2 +sad -2 +sadden -2 +saddened -2 +sadly -2 +safe 1 +safely 1 +safety 1 +salient 1 +sappy -1 +sarcastic -2 +satisfied 2 +save 2 +saved 2 +scam -2 +scams -2 +scandal -3 +scandalous -3 +scandals -3 +scapegoat -2 +scapegoats -2 +scare -2 +scared -2 +scary -2 +sceptical -2 +scold -2 +scoop 3 +scorn -2 +scornful -2 +scream -2 +screamed -2 +screaming -2 +screams -2 +screwed -2 +screwed up -3 +scumbag -4 +secure 2 +secured 2 +secures 2 +sedition -2 +seditious -2 +seduced -1 +self-confident 2 +self-deluded -2 +selfish -3 +selfishness -3 +sentence -2 +sentenced -2 +sentences -2 +sentencing -2 +serene 2 +severe -2 +sexy 3 +shaky -2 +shame -2 +shamed -2 +shameful -2 +share 1 +shared 1 +shares 1 +shattered -2 +shit -4 +shithead -4 +shitty -3 +shock -2 +shocked -2 +shocking -2 +shocks -2 +shoot -1 +short-sighted -2 +short-sightedness -2 +shortage -2 +shortages -2 +shrew -4 +shy -1 +sick -2 +sigh -2 +significance 1 +significant 1 +silencing -1 +silly -1 +sincere 2 +sincerely 2 +sincerest 2 +sincerity 2 +sinful -3 +singleminded -2 +skeptic -2 +skeptical -2 +skepticism -2 +skeptics -2 +slam -2 +slash -2 +slashed -2 +slashes -2 +slashing -2 +slavery -3 +sleeplessness -2 +slick 2 +slicker 2 +slickest 2 +sluggish -2 +slut -5 +smart 1 +smarter 2 +smartest 2 +smear -2 +smile 2 +smiled 2 +smiles 2 +smiling 2 +smog -2 +sneaky -1 +snub -2 +snubbed -2 +snubbing -2 +snubs -2 +sobering 1 +solemn -1 +solid 2 +solidarity 2 +solution 1 +solutions 1 +solve 1 +solved 1 +solves 1 +solving 1 +somber -2 +some kind 0 +son-of-a-bitch -5 +soothe 3 +soothed 3 +soothing 3 +sophisticated 2 +sore -1 +sorrow -2 +sorrowful -2 +sorry -1 +spam -2 +spammer -3 +spammers -3 +spamming -2 +spark 1 +sparkle 3 +sparkles 3 +sparkling 3 +speculative -2 +spirit 1 +spirited 2 +spiritless -2 +spiteful -2 +splendid 3 +sprightly 2 +squelched -1 +stab -2 +stabbed -2 +stable 2 +stabs -2 +stall -2 +stalled -2 +stalling -2 +stamina 2 +stampede -2 +startled -2 +starve -2 +starved -2 +starves -2 +starving -2 +steadfast 2 +steal -2 +steals -2 +stereotype -2 +stereotyped -2 +stifled -1 +stimulate 1 +stimulated 1 +stimulates 1 +stimulating 2 +stingy -2 +stolen -2 +stop -1 +stopped -1 +stopping -1 +stops -1 +stout 2 +straight 1 +strange -1 +strangely -1 +strangled -2 +strength 2 +strengthen 2 +strengthened 2 +strengthening 2 +strengthens 2 +stressed -2 +stressor -2 +stressors -2 +stricken -2 +strike -1 +strikers -2 +strikes -1 +strong 2 +stronger 2 +strongest 2 +struck -1 +struggle -2 +struggled -2 +struggles -2 +struggling -2 +stubborn -2 +stuck -2 +stunned -2 +stunning 4 +stupid -2 +stupidly -2 +suave 2 +substantial 1 +substantially 1 +subversive -2 +success 2 +successful 3 +suck -3 +sucks -3 +suffer -2 +suffering -2 +suffers -2 +suicidal -2 +suicide -2 +suing -2 +sulking -2 +sulky -2 +sullen -2 +sunshine 2 +super 3 +superb 5 +superior 2 +support 2 +supported 2 +supporter 1 +supporters 1 +supporting 1 +supportive 2 +supports 2 +survived 2 +surviving 2 +survivor 2 +suspect -1 +suspected -1 +suspecting -1 +suspects -1 +suspend -1 +suspended -1 +suspicious -2 +swear -2 +swearing -2 +swears -2 +sweet 2 +swift 2 +swiftly 2 +swindle -3 +swindles -3 +swindling -3 +sympathetic 2 +sympathy 2 +tard -2 +tears -2 +tender 2 +tense -2 +tension -1 +terrible -3 +terribly -3 +terrific 4 +terrified -3 +terror -3 +terrorize -3 +terrorized -3 +terrorizes -3 +thank 2 +thankful 2 +thanks 2 +thorny -2 +thoughtful 2 +thoughtless -2 +threat -2 +threaten -2 +threatened -2 +threatening -2 +threatens -2 +threats -2 +thrilled 5 +thwart -2 +thwarted -2 +thwarting -2 +thwarts -2 +timid -2 +timorous -2 +tired -2 +tits -2 +tolerant 2 +toothless -2 +top 2 +tops 2 +torn -2 +torture -4 +tortured -4 +tortures -4 +torturing -4 +totalitarian -2 +totalitarianism -2 +tout -2 +touted -2 +touting -2 +touts -2 +tragedy -2 +tragic -2 +tranquil 2 +trap -1 +trapped -2 +trauma -3 +traumatic -3 +travesty -2 +treason -3 +treasonous -3 +treasure 2 +treasures 2 +trembling -2 +tremulous -2 +tricked -2 +trickery -2 +triumph 4 +triumphant 4 +trouble -2 +troubled -2 +troubles -2 +true 2 +trust 1 +trusted 2 +tumor -2 +twat -5 +ugly -3 +unacceptable -2 +unappreciated -2 +unapproved -2 +unaware -2 +unbelievable -1 +unbelieving -1 +unbiased 2 +uncertain -1 +unclear -1 +uncomfortable -2 +unconcerned -2 +unconfirmed -1 +unconvinced -1 +uncredited -1 +undecided -1 +underestimate -1 +underestimated -1 +underestimates -1 +underestimating -1 +undermine -2 +undermined -2 +undermines -2 +undermining -2 +undeserving -2 +undesirable -2 +uneasy -2 +unemployment -2 +unequal -1 +unequaled 2 +unethical -2 +unfair -2 +unfocused -2 +unfulfilled -2 +unhappy -2 +unhealthy -2 +unified 1 +unimpressed -2 +unintelligent -2 +united 1 +unjust -2 +unlovable -2 +unloved -2 +unmatched 1 +unmotivated -2 +unprofessional -2 +unresearched -2 +unsatisfied -2 +unsecured -2 +unsettled -1 +unsophisticated -2 +unstable -2 +unstoppable 2 +unsupported -2 +unsure -1 +untarnished 2 +unwanted -2 +unworthy -2 +upset -2 +upsets -2 +upsetting -2 +uptight -2 +urgent -1 +useful 2 +usefulness 2 +useless -2 +uselessness -2 +vague -2 +validate 1 +validated 1 +validates 1 +validating 1 +verdict -1 +verdicts -1 +vested 1 +vexation -2 +vexing -2 +vibrant 3 +vicious -2 +victim -3 +victimize -3 +victimized -3 +victimizes -3 +victimizing -3 +victims -3 +vigilant 3 +vile -3 +vindicate 2 +vindicated 2 +vindicates 2 +vindicating 2 +violate -2 +violated -2 +violates -2 +violating -2 +violence -3 +violent -3 +virtuous 2 +virulent -2 +vision 1 +visionary 3 +visioning 1 +visions 1 +vitality 3 +vitamin 1 +vitriolic -3 +vivacious 3 +vociferous -1 +vulnerability -2 +vulnerable -2 +walkout -2 +walkouts -2 +wanker -3 +want 1 +war -2 +warfare -2 +warm 1 +warmth 2 +warn -2 +warned -2 +warning -3 +warnings -3 +warns -2 +waste -1 +wasted -2 +wasting -2 +wavering -1 +weak -2 +weakness -2 +wealth 3 +wealthy 2 +weary -2 +weep -2 +weeping -2 +weird -2 +welcome 2 +welcomed 2 +welcomes 2 +whimsical 1 +whitewash -3 +whore -4 +wicked -2 +widowed -1 +willingness 2 +win 4 +winner 4 +winning 4 +wins 4 +winwin 3 +wish 1 +wishes 1 +wishing 1 +withdrawal -3 +woebegone -2 +woeful -3 +won 3 +wonderful 4 +woo 3 +woohoo 3 +wooo 4 +woow 4 +worn -1 +worried -3 +worry -3 +worrying -3 +worse -3 +worsen -3 +worsened -3 +worsening -3 +worsens -3 +worshiped 3 +worst -3 +worth 2 +worthless -2 +worthy 2 +wow 4 +wowow 4 +wowww 4 +wrathful -3 +wreck -2 +wrong -2 +wronged -2 +wtf -4 +yeah 1 +yearning 1 +yeees 2 +yes 1 +youthful 2 +yucky -2 +yummy 3 +zealot -2 +zealots -2 +zealous 2 \ No newline at end of file diff --git a/dev/.rat-excludes b/dev/.rat-excludes new file mode 100644 index 0000000000000..8b5061415ff4c --- /dev/null +++ b/dev/.rat-excludes @@ -0,0 +1,100 @@ +target +cache +.gitignore +.gitattributes +.project +.classpath +.mima-excludes +.generated-mima-excludes +.generated-mima-class-excludes +.generated-mima-member-excludes +.rat-excludes +.*md +derby.log +TAGS +RELEASE +control +docs +slaves +spark-env.cmd +bootstrap-tooltip.js +jquery-1.11.1.min.js +d3.min.js +dagre-d3.min.js +graphlib-dot.min.js +sorttable.js +vis.min.js +vis.min.css +dataTables.bootstrap.css +dataTables.bootstrap.min.js +dataTables.rowsGroup.js +jquery.blockUI.min.js +jquery.cookies.2.2.0.min.js +jquery.dataTables.1.10.4.min.css +jquery.dataTables.1.10.4.min.js +jquery.mustache.js +jsonFormatter.min.css +jsonFormatter.min.js +.*avsc +.*txt +.*json +.*data +.*log +cloudpickle.py +heapq3.py +join.py +SparkExprTyper.scala +SparkILoop.scala +SparkILoopInit.scala +SparkIMain.scala +SparkImports.scala +SparkJLineCompletion.scala +SparkJLineReader.scala +SparkMemberHandlers.scala +SparkReplReporter.scala +sbt +sbt-launch-lib.bash +plugins.sbt +work +.*\.q +.*\.qv +golden +test.out/* +.*iml +service.properties +db.lck +build/* +dist/* +.*out +.*ipr +.*iws +logs +.*scalastyle-output.xml +.*dependency-reduced-pom.xml +known_translations +json_expectation +local-1422981759269 +local-1422981780767 +local-1425081759269 +local-1426533911241 +local-1426633911242 +local-1430917381534 +local-1430917381535_1 +local-1430917381535_2 +DESCRIPTION +NAMESPACE +test_support/* +.*Rd +help/* +html/* +INDEX +.lintr +gen-java.* +.*avpr +org.apache.spark.sql.sources.DataSourceRegister +org.apache.spark.scheduler.SparkHistoryListenerFactory +.*parquet +LZ4BlockInputStream.java +spark-deps-.* +.*csv +.*tsv diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md index f72f8c653a265..37b2a0afb7aee 100644 --- a/dev/audit-release/README.md +++ b/dev/audit-release/README.md @@ -1,10 +1,11 @@ -# Test Application Builds -This directory includes test applications which are built when auditing releases. You can -run them locally by setting appropriate environment variables. +Test Application Builds +======================= + +This directory includes test applications which are built when auditing releases. You can run them locally by setting appropriate environment variables. ``` $ cd sbt_app_core -$ SCALA_VERSION=2.10.5 \ +$ SCALA_VERSION=2.11.7 \ SPARK_VERSION=1.0.0-SNAPSHOT \ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ sbt run diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 27d1dd784ce2e..ee72da4df0652 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -35,8 +35,8 @@ RELEASE_KEY = "XXXXXXXX" # Your 8-digit hex RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1033" RELEASE_VERSION = "1.1.1" -SCALA_VERSION = "2.10.5" -SCALA_BINARY_VERSION = "2.10" +SCALA_VERSION = "2.11.7" +SCALA_BINARY_VERSION = "2.11" # Do not set these LOG_FILE_NAME = "spark_audit_%s" % time.strftime("%h_%m_%Y_%I_%M_%S") @@ -115,9 +115,8 @@ def ensure_path_not_present(path): # maven that links against them. This will catch issues with messed up # dependencies within those projects. modules = [ - "spark-core", "spark-bagel", "spark-mllib", "spark-streaming", "spark-repl", + "spark-core", "spark-mllib", "spark-streaming", "spark-repl", "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka", - "spark-streaming-mqtt", "spark-streaming-twitter", "spark-streaming-zeromq", "spark-catalyst", "spark-sql", "spark-hive", "spark-streaming-kinesis-asl" ] modules = map(lambda m: "%s_%s" % (m, SCALA_BINARY_VERSION), modules) diff --git a/dev/check-license b/dev/check-license index 10740cfdc5242..678e73fd60f1f 100755 --- a/dev/check-license +++ b/dev/check-license @@ -58,7 +58,7 @@ else declare java_cmd=java fi -export RAT_VERSION=0.10 +export RAT_VERSION=0.11 export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar mkdir -p "$FWDIR"/lib @@ -67,14 +67,15 @@ mkdir -p "$FWDIR"/lib exit 1 } -$java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt +mkdir target +$java_cmd -jar "$rat_jar" -E "$FWDIR"/dev/.rat-excludes -d "$FWDIR" > target/rat-results.txt if [ $? -ne 0 ]; then echo "RAT exited abnormally" exit 1 fi -ERRORS="$(cat rat-results.txt | grep -e "??")" +ERRORS="$(cat target/rat-results.txt | grep -e "??")" if test ! -z "$ERRORS"; then echo "Could not find Apache license headers in the following files:" diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml new file mode 100644 index 0000000000000..a1a88ac8cdac5 --- /dev/null +++ b/dev/checkstyle-suppressions.xml @@ -0,0 +1,39 @@ + + + + + + + + + + + + diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml new file mode 100644 index 0000000000000..b66dca9041f2f --- /dev/null +++ b/dev/checkstyle.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index cb79e9eba06e2..65e80fc76056a 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -23,8 +23,8 @@ usage: release-build.sh Creates build deliverables from a Spark commit. Top level targets are - package: Create binary packages and copy them to people.apache - docs: Build docs and copy them to people.apache + package: Create binary packages and copy them to home.apache + docs: Build docs and copy them to home.apache publish-snapshot: Publish snapshot release to Apache snapshots publish-release: Publish a release to Apache release repo @@ -64,13 +64,16 @@ for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do fi done +# Explicitly set locale in order to make `sort` output consistent across machines. +# See https://stackoverflow.com/questions/28881 for more details. +export LC_ALL=C + # Commit ref to checkout when building GIT_REF=${GIT_REF:-master} # Destination directory parent on remote server REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} -SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" GPG="gpg --no-tty --batch" NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads @@ -97,7 +100,20 @@ if [ -z "$SPARK_PACKAGE_VERSION" ]; then fi DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" -USER_HOST="$ASF_USERNAME@people.apache.org" + +function LFTP { + SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" + COMMANDS=$(cat <&1 | grep 'Maven home' | awk '{print $NF}'` - ./make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ + ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log cd .. cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . @@ -166,12 +186,10 @@ if [[ "$1" == "package" ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3035" & + make_binary_release "hadoop2.7" "-Psparkr -Phadoop-2.7 -Phive -Phive-thriftserver -Pyarn" "3036" & make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn" "3038" & wait @@ -180,11 +198,15 @@ if [[ "$1" == "package" ]]; then # Copy data dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" echo "Copying release tarballs to $dest_dir" - $SSH $USER_HOST mkdir $dest_dir - rsync -e "$SSH" spark-* $USER_HOST:$dest_dir - echo "Linking /latest to $dest_dir" - $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" - $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + # Put to new directory: + LFTP mkdir -p $dest_dir + LFTP mput -O $dest_dir 'spark-*' + # Delete /latest directory and rename new upload to /latest + LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" + LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" + # Re-upload a second time and leave the files in the timestamped upload directory: + LFTP mkdir -p $dest_dir + LFTP mput -O $dest_dir 'spark-*' exit 0 fi @@ -198,11 +220,15 @@ if [[ "$1" == "docs" ]]; then # TODO: Make configurable to add this: PRODUCTION=1 PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build echo "Copying release documentation to $dest_dir" - $SSH $USER_HOST mkdir $dest_dir - echo "Linking /latest to $dest_dir" - $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" - $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" - rsync -e "$SSH" -r _site/* $USER_HOST:$dest_dir + # Put to new directory: + LFTP mkdir -p $dest_dir + LFTP mirror -R _site $dest_dir + # Delete /latest directory and rename new upload to /latest + LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" + LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" + # Re-upload a second time and leave the files in the timestamped upload directory: + LFTP mkdir -p $dest_dir + LFTP mirror -R _site $dest_dir cd .. exit 0 fi @@ -230,8 +256,8 @@ if [[ "$1" == "publish-snapshot" ]]; then $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ -Phive-thriftserver deploy - ./dev/change-scala-version.sh 2.11 - $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.10 --settings $tmp_settings \ -DskipTests $PUBLISH_PROFILES clean deploy # Clean-up Zinc nailgun process @@ -268,9 +294,9 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ -Phive-thriftserver clean install - ./dev/change-scala-version.sh 2.11 + ./dev/change-scala-version.sh 2.10 - $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \ + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ -DskipTests $PUBLISH_PROFILES clean install # Clean-up Zinc nailgun process diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index b0a3374becc6a..d404939d1caee 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -64,9 +64,6 @@ git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" git tag $RELEASE_TAG -# TODO: It would be nice to do some verifications here -# i.e. check whether ec2 scripts have the new version - # Create next version $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs git commit -a -m "Preparing development version $NEXT_VERSION" diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 7f152b7f53559..5d0ac16b3b0a1 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -159,7 +159,6 @@ def get_commits(tag): "build": CORE_COMPONENT, "deploy": CORE_COMPONENT, "documentation": CORE_COMPONENT, - "ec2": "EC2", "examples": CORE_COMPONENT, "graphx": "GraphX", "input/output": CORE_COMPONENT, diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 new file mode 100644 index 0000000000000..023fba536915d --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -0,0 +1,181 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.5.2-1.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.1.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math-2.1.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.4.0.jar +curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +gmbal-api-only-3.0.0-b023.jar +grizzly-framework-2.1.2.jar +grizzly-http-2.1.2.jar +grizzly-http-server-2.1.2.jar +grizzly-http-servlet-2.1.2.jar +grizzly-rcm-2.1.2.jar +guava-14.0.1.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.2.0.jar +hadoop-auth-2.2.0.jar +hadoop-client-2.2.0.jar +hadoop-common-2.2.0.jar +hadoop-hdfs-2.2.0.jar +hadoop-mapreduce-client-app-2.2.0.jar +hadoop-mapreduce-client-common-2.2.0.jar +hadoop-mapreduce-client-core-2.2.0.jar +hadoop-mapreduce-client-jobclient-2.2.0.jar +hadoop-mapreduce-client-shuffle-2.2.0.jar +hadoop-yarn-api-2.2.0.jar +hadoop-yarn-client-2.2.0.jar +hadoop-yarn-common-2.2.0.jar +hadoop-yarn-server-common-2.2.0.jar +hadoop-yarn-server-web-proxy-2.2.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.5.3.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.11-2.5.3.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javax.servlet-3.1.jar +javax.servlet-api-3.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-grizzly2-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jersey-test-framework-core-1.9.jar +jersey-test-framework-grizzly2-1.9.jar +jets3t-0.7.1.jar +jettison-1.1.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kryo-shaded-3.0.3.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.1.jar +management-api-3.0.0-b012.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.3.0.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-2.1.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.2.jar +pyrolite-4.9.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.8.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snappy-0.2.jar +snappy-java-1.1.2.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +univocity-parsers-2.0.2.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 new file mode 100644 index 0000000000000..003c540d72a08 --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -0,0 +1,172 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.5.2-1.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.4.0.jar +curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +guava-14.0.1.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.3.0.jar +hadoop-auth-2.3.0.jar +hadoop-client-2.3.0.jar +hadoop-common-2.3.0.jar +hadoop-hdfs-2.3.0.jar +hadoop-mapreduce-client-app-2.3.0.jar +hadoop-mapreduce-client-common-2.3.0.jar +hadoop-mapreduce-client-core-2.3.0.jar +hadoop-mapreduce-client-jobclient-2.3.0.jar +hadoop-mapreduce-client-shuffle-2.3.0.jar +hadoop-yarn-api-2.3.0.jar +hadoop-yarn-client-2.3.0.jar +hadoop-yarn-common-2.3.0.jar +hadoop-yarn-server-common-2.3.0.jar +hadoop-yarn-server-web-proxy-2.3.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.5.3.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.11-2.5.3.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kryo-shaded-3.0.3.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.3.0.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-2.1.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.2.jar +pyrolite-4.9.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.8.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snappy-0.2.jar +snappy-java-1.1.2.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +univocity-parsers-2.0.2.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 new file mode 100644 index 0000000000000..80fbaea222388 --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -0,0 +1,173 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.5.2-1.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.4.0.jar +curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +guava-14.0.1.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.4.0.jar +hadoop-auth-2.4.0.jar +hadoop-client-2.4.0.jar +hadoop-common-2.4.0.jar +hadoop-hdfs-2.4.0.jar +hadoop-mapreduce-client-app-2.4.0.jar +hadoop-mapreduce-client-common-2.4.0.jar +hadoop-mapreduce-client-core-2.4.0.jar +hadoop-mapreduce-client-jobclient-2.4.0.jar +hadoop-mapreduce-client-shuffle-2.4.0.jar +hadoop-yarn-api-2.4.0.jar +hadoop-yarn-client-2.4.0.jar +hadoop-yarn-common-2.4.0.jar +hadoop-yarn-server-common-2.4.0.jar +hadoop-yarn-server-web-proxy-2.4.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.5.3.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.11-2.5.3.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kryo-shaded-3.0.3.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.3.0.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-2.1.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.2.jar +pyrolite-4.9.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.8.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snappy-0.2.jar +snappy-java-1.1.2.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +univocity-parsers-2.0.2.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 new file mode 100644 index 0000000000000..b2c2a4caec86f --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -0,0 +1,180 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.5.2-1.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +apacheds-i18n-2.0.0-M15.jar +apacheds-kerberos-codec-2.0.0-M15.jar +api-asn1-api-1.0.0-M20.jar +api-util-1.0.0-M20.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.6.0.jar +curator-framework-2.6.0.jar +curator-recipes-2.6.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.6.0.jar +hadoop-auth-2.6.0.jar +hadoop-client-2.6.0.jar +hadoop-common-2.6.0.jar +hadoop-hdfs-2.6.0.jar +hadoop-mapreduce-client-app-2.6.0.jar +hadoop-mapreduce-client-common-2.6.0.jar +hadoop-mapreduce-client-core-2.6.0.jar +hadoop-mapreduce-client-jobclient-2.6.0.jar +hadoop-mapreduce-client-shuffle-2.6.0.jar +hadoop-yarn-api-2.6.0.jar +hadoop-yarn-client-2.6.0.jar +hadoop-yarn-common-2.6.0.jar +hadoop-yarn-server-common-2.6.0.jar +hadoop-yarn-server-web-proxy-2.6.0.jar +htrace-core-3.0.4.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.5.3.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.11-2.5.3.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kryo-shaded-3.0.3.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.3.0.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-2.1.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.2.jar +pyrolite-4.9.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.8.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snappy-0.2.jar +snappy-java-1.1.2.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +univocity-parsers-2.0.2.jar +xbean-asm5-shaded-4.4.jar +xercesImpl-2.9.1.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.6.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 new file mode 100644 index 0000000000000..71e51883d5abe --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -0,0 +1,181 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.5.2-1.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +apacheds-i18n-2.0.0-M15.jar +apacheds-kerberos-codec-2.0.0-M15.jar +api-asn1-api-1.0.0-M20.jar +api-util-1.0.0-M20.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.11.2.jar +breeze_2.11-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.6.0.jar +curator-framework-2.6.0.jar +curator-recipes-2.6.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.7.0.jar +hadoop-auth-2.7.0.jar +hadoop-client-2.7.0.jar +hadoop-common-2.7.0.jar +hadoop-hdfs-2.7.0.jar +hadoop-mapreduce-client-app-2.7.0.jar +hadoop-mapreduce-client-common-2.7.0.jar +hadoop-mapreduce-client-core-2.7.0.jar +hadoop-mapreduce-client-jobclient-2.7.0.jar +hadoop-mapreduce-client-shuffle-2.7.0.jar +hadoop-yarn-api-2.7.0.jar +hadoop-yarn-client-2.7.0.jar +hadoop-yarn-common-2.7.0.jar +hadoop-yarn-server-common-2.7.0.jar +hadoop-yarn-server-web-proxy-2.7.0.jar +htrace-core-3.1.0-incubating.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.5.3.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.11-2.5.3.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.11-3.2.10.jar +json4s-core_2.11-3.2.10.jar +json4s-jackson_2.11-3.2.10.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kryo-shaded-3.0.3.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.3.0.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-2.1.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.2.jar +pyrolite-4.9.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.2.jar +scalap-2.11.8.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snappy-0.2.jar +snappy-java-1.1.2.4.jar +spire-macros_2.11-0.7.4.jar +spire_2.11-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +univocity-parsers-2.0.2.jar +xbean-asm5-shaded-4.4.jar +xercesImpl-2.9.1.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.6.jar diff --git a/dev/lint-java b/dev/lint-java new file mode 100755 index 0000000000000..fe8ab83d562d1 --- /dev/null +++ b/dev/lint-java @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi diff --git a/dev/lint-python b/dev/lint-python index 0b97213ae3dff..63487043a50b6 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,11 +19,13 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" +PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport" PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" +SPHINXBUILD=${SPHINXBUILD:=sphinx-build} +SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" cd "$SPARK_ROOT_DIR" @@ -35,7 +37,7 @@ compile_status="${PIPESTATUS[0]}" #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 #+ TODOs: #+ - Download pep8 from PyPI. It's more "official". -PEP8_VERSION="1.6.2" +PEP8_VERSION="1.7.0" PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8-$PEP8_VERSION.py" PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/$PEP8_VERSION/pep8.py" @@ -58,27 +60,11 @@ export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" export "PYLINT_HOME=$PYTHONPATH" export "PATH=$PYTHONPATH:$PATH" -# if [ ! -d "$PYLINT_HOME" ]; then -# mkdir "$PYLINT_HOME" -# # Redirect the annoying pylint installation output. -# easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" -# easy_install_status="$?" -# -# if [ "$easy_install_status" -ne 0 ]; then -# echo "Unable to install pylint locally in \"$PYTHONPATH\"." -# cat "$PYLINT_INSTALL_INFO" -# exit "$easy_install_status" -# fi -# -# rm "$PYLINT_INSTALL_INFO" -# -# fi - # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 --config=dev/tox.ini $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" pep8_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then @@ -90,25 +76,32 @@ fi if [ "$lint_status" -ne 0 ]; then echo "PEP8 checks failed." cat "$PEP8_REPORT_PATH" + rm "$PEP8_REPORT_PATH" + exit "$lint_status" else echo "PEP8 checks passed." + rm "$PEP8_REPORT_PATH" fi -rm "$PEP8_REPORT_PATH" - -# for to_be_checked in "$PATHS_TO_CHECK" -# do -# pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" -# done - -# if [ "${PIPESTATUS[0]}" -ne 0 ]; then -# lint_status=1 -# echo "Pylint checks failed." -# cat "$PYLINT_REPORT_PATH" -# else -# echo "Pylint checks passed." -# fi - -# rm "$PYLINT_REPORT_PATH" - -exit "$lint_status" +# Check that the documentation builds acceptably, skip check if sphinx is not installed. +if hash "$SPHINXBUILD" 2> /dev/null; then + cd python/docs + make clean + # Treat warnings as errors so we stop correctly + SPHINXOPTS="-a -W" make html &> "$SPHINX_REPORT_PATH" || lint_status=1 + if [ "$lint_status" -ne 0 ]; then + echo "pydoc checks failed." + cat "$SPHINX_REPORT_PATH" + echo "re-running make html to print full warning list" + make clean + SPHINXOPTS="-a" make html + rm "$SPHINX_REPORT_PATH" + exit "$lint_status" + else + echo "pydoc checks passed." + rm "$SPHINX_REPORT_PATH" + fi + cd ../.. +else + echo >&2 "The $SPHINXBUILD command was not found. Skipping pydoc checks for now" +fi diff --git a/dev/lint-r.R b/dev/lint-r.R index 999eef571b824..87ee36d5c9b68 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -27,7 +27,7 @@ if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { # Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { - devtools::install_github("jimhester/lintr") + devtools::install_github("jimhester/lintr@a769c0b") } library(lintr) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh new file mode 100755 index 0000000000000..4f7544f6ea78b --- /dev/null +++ b/dev/make-distribution.sh @@ -0,0 +1,225 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Script to create a binary distribution for easy deploys of Spark. +# The distribution directory defaults to dist/ but can be overridden below. +# The distribution contains fat (assembly) jars that include the Scala library, +# so it is completely self contained. +# It does not contain source or *.class files. + +set -o pipefail +set -e +set -x + +# Figure out where the Spark framework is installed +SPARK_HOME="$(cd "`dirname "$0"`/.."; pwd)" +DISTDIR="$SPARK_HOME/dist" + +MAKE_TGZ=false +NAME=none +MVN="$SPARK_HOME/build/mvn" + +function exit_with_usage { + echo "make-distribution.sh - tool for making binary distributions of Spark" + echo "" + echo "usage:" + cl_options="[--name] [--tgz] [--mvn ]" + echo "make-distribution.sh $cl_options " + echo "See Spark's \"Building Spark\" doc for correct Maven options." + echo "" + exit 1 +} + +# Parse arguments +while (( "$#" )); do + case $1 in + --hadoop) + echo "Error: '--hadoop' is no longer supported:" + echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." + echo "Error: Related profiles include hadoop-2.2, hadoop-2.3 and hadoop-2.4." + exit_with_usage + ;; + --with-yarn) + echo "Error: '--with-yarn' is no longer supported, use Maven option -Pyarn" + exit_with_usage + ;; + --with-hive) + echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" + exit_with_usage + ;; + --tgz) + MAKE_TGZ=true + ;; + --mvn) + MVN="$2" + shift + ;; + --name) + NAME="$2" + shift + ;; + --help) + exit_with_usage + ;; + *) + break + ;; + esac + shift +done + +if [ -z "$JAVA_HOME" ]; then + # Fall back on JAVA_HOME from rpm, if found + if [ $(command -v rpm) ]; then + RPM_JAVA_HOME="$(rpm -E %java_home 2>/dev/null)" + if [ "$RPM_JAVA_HOME" != "%java_home" ]; then + JAVA_HOME="$RPM_JAVA_HOME" + echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" + fi + fi +fi + +if [ -z "$JAVA_HOME" ]; then + echo "Error: JAVA_HOME is not set, cannot proceed." + exit -1 +fi + +if [ $(command -v git) ]; then + GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) + if [ ! -z "$GITREV" ]; then + GITREVSTRING=" (git revision $GITREV)" + fi + unset GITREV +fi + + +if [ ! "$(command -v "$MVN")" ] ; then + echo -e "Could not locate Maven command: '$MVN'." + echo -e "Specify the Maven command with the --mvn flag" + exit -1; +fi + +VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) +SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | tail -n 1) +SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | tail -n 1) +SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ + | grep -v "INFO"\ + | fgrep --count "hive";\ + # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ + # because we use "set -o pipefail" + echo -n) + +if [ "$NAME" == "none" ]; then + NAME=$SPARK_HADOOP_VERSION +fi + +echo "Spark version is $VERSION" + +if [ "$MAKE_TGZ" == "true" ]; then + echo "Making spark-$VERSION-bin-$NAME.tgz" +else + echo "Making distribution for Spark $VERSION in $DISTDIR..." +fi + +# Build uber fat JAR +cd "$SPARK_HOME" + +export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m}" + +# Store the command as an array because $MVN variable might have spaces in it. +# Normal quoting tricks don't work. +# See: http://mywiki.wooledge.org/BashFAQ/050 +BUILD_COMMAND=("$MVN" clean package -DskipTests $@) + +# Actually build the jar +echo -e "\nBuilding with..." +echo -e "\$ ${BUILD_COMMAND[@]}\n" + +"${BUILD_COMMAND[@]}" + +# Make directories +rm -rf "$DISTDIR" +mkdir -p "$DISTDIR/jars" +echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE" +echo "Build flags: $@" >> "$DISTDIR/RELEASE" + +# Copy jars +cp "$SPARK_HOME"/assembly/target/scala*/jars/* "$DISTDIR/jars/" + +# Only create the yarn directory if the yarn artifacts were build. +if [ -f "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar ]; then + mkdir "$DISTDIR"/yarn + cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/yarn" +fi + +# Copy examples and dependencies +mkdir -p "$DISTDIR/examples/jars" +cp "$SPARK_HOME"/examples/target/scala*/jars/* "$DISTDIR/examples/jars" + +# Deduplicate jars that have already been packaged as part of the main Spark dependencies. +for f in "$DISTDIR/examples/jars/"*; do + name=$(basename "$f") + if [ -f "$DISTDIR/jars/$name" ]; then + rm "$DISTDIR/examples/jars/$name" + fi +done + +# Copy example sources (needed for python and SQL) +mkdir -p "$DISTDIR/examples/src/main" +cp -r "$SPARK_HOME"/examples/src/main "$DISTDIR/examples/src/" + +# Copy license and ASF files +cp "$SPARK_HOME/LICENSE" "$DISTDIR" +cp -r "$SPARK_HOME/licenses" "$DISTDIR" +cp "$SPARK_HOME/NOTICE" "$DISTDIR" + +if [ -e "$SPARK_HOME"/CHANGES.txt ]; then + cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" +fi + +# Copy data files +cp -r "$SPARK_HOME/data" "$DISTDIR" + +# Copy other things +mkdir "$DISTDIR"/conf +cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf +cp "$SPARK_HOME/README.md" "$DISTDIR" +cp -r "$SPARK_HOME/bin" "$DISTDIR" +cp -r "$SPARK_HOME/python" "$DISTDIR" +cp -r "$SPARK_HOME/sbin" "$DISTDIR" +# Copy SparkR if it exists +if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then + mkdir -p "$DISTDIR"/R/lib + cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib + cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib +fi + +if [ "$MAKE_TGZ" == "true" ]; then + TARDIR_NAME=spark-$VERSION-bin-$NAME + TARDIR="$SPARK_HOME/$TARDIR_NAME" + rm -rf "$TARDIR" + cp -r "$DISTDIR" "$TARDIR" + tar czf "spark-$VERSION-bin-$NAME.tgz" -C "$SPARK_HOME" "$TARDIR_NAME" + rm -rf "$TARDIR" +fi diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index bf1a000f46791..5ab285eae99b7 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -355,11 +355,21 @@ def standardize_jira_ref(text): return clean_text + +def get_current_ref(): + ref = run_cmd("git rev-parse --abbrev-ref HEAD").strip() + if ref == 'HEAD': + # The current ref is a detached HEAD, so grab its SHA. + return run_cmd("git rev-parse HEAD").strip() + else: + return ref + + def main(): global original_head os.chdir(SPARK_HOME) - original_head = run_cmd("git rev-parse HEAD")[:8] + original_head = get_current_ref() branches = get_json("%s/branches" % GITHUB_API_BASE) branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) @@ -449,5 +459,8 @@ def main(): (failure_count, test_count) = doctest.testmod() if failure_count: exit(-1) - - main() + try: + main() + except: + clean_up() + raise diff --git a/dev/mima b/dev/mima index 2952fa65d42ff..c3553490451c8 100755 --- a/dev/mima +++ b/dev/mima @@ -24,26 +24,19 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -echo -e "q\n" | build/sbt oldDeps/update -rm -f .generated-mima* - -generate_mima_ignore() { - SPARK_JAVA_OPTS="-XX:MaxPermSize=1g -Xmx2g" \ - ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore -} +SPARK_PROFILES="-Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" +OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" -# Generate Mima Ignore is called twice, first with latest built jars -# on the classpath and then again with previous version jars on the classpath. -# Because of a bug in GenerateMIMAIgnore that when old jars are ahead on classpath -# it did not process the new classes (which are in assembly jar). -generate_mima_ignore - -export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" -echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" +rm -f .generated-mima* -generate_mima_ignore +java \ + -XX:MaxPermSize=1g \ + -Xmx2g \ + -cp "$TOOLS_CLASSPATH:$OLD_DEPS_CLASSPATH" \ + org.apache.spark.tools.GenerateMIMAIgnore -echo -e "q\n" | build/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" +echo -e "q\n" | build/sbt -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving" ret_val=$? if [ $ret_val != 0 ]; then diff --git a/dev/requirements.txt b/dev/requirements.txt new file mode 100644 index 0000000000000..bf042d22a8b47 --- /dev/null +++ b/dev/requirements.txt @@ -0,0 +1,3 @@ +jira==1.0.3 +PyGithub==1.26.0 +Unidecode==0.04.19 diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 623004310e189..a48d918f9dc1f 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -119,10 +119,12 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', + ERROR_CODES["BLOCK_JAVA_STYLE"]: 'Java style tests', ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests', ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', ERROR_CODES["BLOCK_BUILD"]: 'to build', + ERROR_CODES["BLOCK_BUILD_TESTS"]: 'build dependency tests', ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', @@ -162,14 +164,16 @@ def main(): if "test-maven" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_TOOL"] = "maven" # Switch the Hadoop profile based on the PR title: - if "test-hadoop1.0" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop1.0" - if "test-hadoop2.2" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.0" if "test-hadoop2.2" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" if "test-hadoop2.3" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.3" + if "test-hadoop2.4" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.4" + if "test-hadoop2.6" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.6" + if "test-hadoop2.7" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.7" build_display_name = os.environ["BUILD_DISPLAY_NAME"] build_url = os.environ["BUILD_URL"] @@ -196,7 +200,6 @@ def main(): pr_tests = [ "pr_merge_ability", "pr_public_classes" - # DISABLED (pwendell) "pr_new_dependencies" ] # `bind_message_base` returns a function to generate messages for Github posting diff --git a/dev/run-tests.py b/dev/run-tests.py index 9e1abb0697192..cbe347274e62c 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -29,6 +29,7 @@ from sparktestsupport import SPARK_HOME, USER_HOME, ERROR_CODES from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which +from sparktestsupport.toposort import toposort_flatten, toposort import sparktestsupport.modules as modules @@ -43,7 +44,7 @@ def determine_modules_for_files(filenames): If a file is not associated with a more specific submodule, then this method will consider that file to belong to the 'root' module. - >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"])) + >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/core/foo"])) ['pyspark-core', 'sql'] >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] ['root'] @@ -99,24 +100,28 @@ def determine_modules_to_test(changed_modules): Given a set of modules that have changed, compute the transitive closure of those modules' dependent modules in order to determine the set of modules that should be tested. - >>> sorted(x.name for x in determine_modules_to_test([modules.root])) + Returns a topologically-sorted list of modules (ties are broken by sorting on module names). + + >>> [x.name for x in determine_modules_to_test([modules.root])] + ['root'] + >>> [x.name for x in determine_modules_to_test([modules.build])] ['root'] - >>> sorted(x.name for x in determine_modules_to_test([modules.graphx])) - ['examples', 'graphx'] - >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) + >>> [x.name for x in determine_modules_to_test([modules.graphx])] + ['graphx', 'examples'] + >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \ - 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql'] + ['sql', 'hive', 'mllib', 'examples', 'hive-thriftserver', 'pyspark-sql', 'sparkr', + 'pyspark-mllib', 'pyspark-ml'] """ - # If we're going to have to run all of the tests, then we can just short-circuit - # and return 'root'. No module depends on root, so if it appears then it will be - # in changed_modules. - if modules.root in changed_modules: - return [modules.root] modules_to_test = set() for module in changed_modules: modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) - return modules_to_test.union(set(changed_modules)) + modules_to_test = modules_to_test.union(set(changed_modules)) + # If we need to run all of the tests, then we should short-circuit and return 'root' + if modules.root in modules_to_test: + return [modules.root] + return toposort_flatten( + {m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True) def determine_tags_to_exclude(changed_modules): @@ -148,7 +153,7 @@ def determine_java_executable(): return java_exe if java_exe else which("java") -JavaVersion = namedtuple('JavaVersion', ['major', 'minor', 'patch', 'update']) +JavaVersion = namedtuple('JavaVersion', ['major', 'minor', 'patch']) def determine_java_version(java_exe): @@ -164,14 +169,13 @@ def determine_java_version(java_exe): # find raw version string, eg 'java version "1.8.0_25"' raw_version_str = next(x for x in raw_output_lines if " version " in x) - match = re.search('(\d+)\.(\d+)\.(\d+)_(\d+)', raw_version_str) + match = re.search('(\d+)\.(\d+)\.(\d+)', raw_version_str) major = int(match.group(1)) minor = int(match.group(2)) patch = int(match.group(3)) - update = int(match.group(4)) - return JavaVersion(major, minor, patch, update) + return JavaVersion(major, minor, patch) # ------------------------------------------------------------------------------------------------- # Functions for running the other build and test scripts @@ -198,6 +202,11 @@ def run_scala_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-scala")]) +def run_java_style_checks(): + set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + + def run_python_style_checks(): set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) @@ -291,16 +300,16 @@ def exec_sbt(sbt_args=()): def get_hadoop_profiles(hadoop_version): """ - For the given Hadoop version tag, return a list of SBT profile flags for + For the given Hadoop version tag, return a list of Maven/SBT profile flags for building and testing against that Hadoop version. """ sbt_maven_hadoop_profiles = { - "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.2.1"], - "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], - "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], + "hadoop2.3": ["-Pyarn", "-Phadoop-2.3"], + "hadoop2.4": ["-Pyarn", "-Phadoop-2.4"], "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], + "hadoop2.7": ["-Pyarn", "-Phadoop-2.7"], } if hadoop_version in sbt_maven_hadoop_profiles: @@ -327,11 +336,8 @@ def build_spark_sbt(hadoop_version): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", - "assembly/assembly", "streaming-kafka-assembly/assembly", "streaming-flume-assembly/assembly", - "streaming-mqtt-assembly/assembly", - "streaming-mqtt/test:assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals @@ -341,6 +347,16 @@ def build_spark_sbt(hadoop_version): exec_sbt(profiles_and_goals) +def build_spark_assembly_sbt(hadoop_version): + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags + sbt_goals = ["assembly/package"] + profiles_and_goals = build_profiles + sbt_goals + print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ", + " ".join(profiles_and_goals)) + exec_sbt(profiles_and_goals) + + def build_apache_spark(build_tool, hadoop_version): """Will build Spark against Hive v1.2.1 given the passed in build tool (either `sbt` or `maven`). Defaults to using `sbt`.""" @@ -355,9 +371,10 @@ def build_apache_spark(build_tool, hadoop_version): build_spark_sbt(hadoop_version) -def detect_binary_inop_with_mima(): +def detect_binary_inop_with_mima(hadoop_version): + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags set_title_and_block("Detecting binary incompatibilities with MiMa", "BLOCK_MIMA") - run_cmd([os.path.join(SPARK_HOME, "dev", "mima")]) + run_cmd([os.path.join(SPARK_HOME, "dev", "mima")] + build_profiles) def run_scala_tests_maven(test_profiles): @@ -373,12 +390,12 @@ def run_scala_tests_maven(test_profiles): def run_scala_tests_sbt(test_modules, test_profiles): - sbt_test_goals = set(itertools.chain.from_iterable(m.sbt_test_goals for m in test_modules)) + sbt_test_goals = list(itertools.chain.from_iterable(m.sbt_test_goals for m in test_modules)) if not sbt_test_goals: return - profiles_and_goals = test_profiles + list(sbt_test_goals) + profiles_and_goals = test_profiles + sbt_test_goals print("[info] Running Spark tests using SBT with these arguments: ", " ".join(profiles_and_goals)) @@ -415,6 +432,12 @@ def run_python_tests(test_modules, parallelism): run_cmd(command) +def run_build_tests(): + set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") + run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) + pass + + def run_sparkr_tests(): set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") @@ -473,7 +496,7 @@ def main(): if which("R"): run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) else: - print("Can't install SparkR as R is was not found in PATH") + print("Cannot install SparkR as R was not found in PATH") if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables @@ -520,8 +543,16 @@ def main(): run_apache_rat_checks() # style checks - if not changed_files or any(f.endswith(".scala") for f in changed_files): + if not changed_files or any(f.endswith(".scala") + or f.endswith("scalastyle-config.xml") + for f in changed_files): run_scala_style_checks() + if not changed_files or any(f.endswith(".java") + or f.endswith("checkstyle.xml") + or f.endswith("checkstyle-suppressions.xml") + for f in changed_files): + # run_java_style_checks() + pass if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): @@ -532,13 +563,19 @@ def main(): # if "DOCS" in changed_modules and test_env == "amplab_jenkins": # build_spark_documentation() + if any(m.should_run_build_tests for m in test_modules): + run_build_tests() + # spark build build_apache_spark(build_tool, hadoop_version) # backwards compatibility checks if build_tool == "sbt": - # Note: compatiblity tests only supported in sbt for now - detect_binary_inop_with_mima() + # Note: compatibility tests only supported in sbt for now + detect_binary_inop_with_mima(hadoop_version) + # Since we did not build assembly/package before running dev/mima, we need to + # do it here because the tests still rely on it; see SPARK-13294 for details. + build_spark_assembly_sbt(hadoop_version) # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) diff --git a/dev/scalastyle b/dev/scalastyle index ad93f7e85b27c..8fd3604b9f451 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,14 +17,17 @@ # limitations under the License. # -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver scalastyle > scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt -# Check style with YARN built too -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 scalastyle >> scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 test:scalastyle >> scalastyle.txt - -ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') -rm scalastyle.txt +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pyarn \ + -Phive \ + -Phive-thriftserver \ + scalastyle test:scalastyle \ + | awk '{if($1~/error/)print}' \ +) if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 8ab6d9e37ca2f..89015f8c4fb9c 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -31,5 +31,7 @@ "BLOCK_SPARK_UNIT_TESTS": 18, "BLOCK_PYSPARK_UNIT_TESTS": 19, "BLOCK_SPARKR_UNIT_TESTS": 20, + "BLOCK_JAVA_STYLE": 21, + "BLOCK_BUILD_TESTS": 22, "BLOCK_TIMEOUT": 124 } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d65547e04db4b..c844bcff7e4f0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -15,12 +15,14 @@ # limitations under the License. # +from functools import total_ordering import itertools import re all_modules = [] +@total_ordering class Module(object): """ A module is the basic abstraction in our test runner script. Each module consists of a set of @@ -31,7 +33,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - test_tags=(), should_run_r_tests=False): + test_tags=(), should_run_r_tests=False, should_run_build_tests=False): """ Define a new module. @@ -53,6 +55,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param test_tags A set of tags that will be excluded when running unit tests if the module is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. + :param should_run_build_tests: If true, changes in this module will trigger build tests. """ self.name = name self.dependencies = dependencies @@ -64,6 +67,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.blacklisted_python_implementations = blacklisted_python_implementations self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests + self.should_run_build_tests = should_run_build_tests self.dependent_modules = set() for dep in dependencies: @@ -73,20 +77,56 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= def contains_file(self, filename): return any(re.match(p, filename) for p in self.source_file_prefixes) + def __repr__(self): + return "Module<%s>" % self.name + + def __lt__(self, other): + return self.name < other.name + + def __eq__(self, other): + return self.name == other.name + + def __ne__(self, other): + return not (self.name == other.name) + + def __hash__(self): + return hash(self.name) + + +catalyst = Module( + name="catalyst", + dependencies=[], + source_file_regexes=[ + "sql/catalyst/", + ], + sbt_test_goals=[ + "catalyst/test", + ], +) + sql = Module( name="sql", - dependencies=[], + dependencies=[catalyst], + source_file_regexes=[ + "sql/core/", + ], + sbt_test_goals=[ + "sql/test", + ], +) + +hive = Module( + name="hive", + dependencies=[sql], source_file_regexes=[ - "sql/(?!hive-thriftserver)", + "sql/hive/", "bin/spark-sql", ], build_profile_flags=[ "-Phive", ], sbt_test_goals=[ - "catalyst/test", - "sql/test", "hive/test", ], test_tags=[ @@ -97,7 +137,7 @@ def contains_file(self, filename): hive_thriftserver = Module( name="hive-thriftserver", - dependencies=[sql], + dependencies=[hive], source_file_regexes=[ "sql/hive-thriftserver", "sbin/start-thriftserver.sh", @@ -111,6 +151,18 @@ def contains_file(self, filename): ) +sketch = Module( + name="sketch", + dependencies=[], + source_file_regexes=[ + "common/sketch/", + ], + sbt_test_goals=[ + "sketch/test" + ] +) + + graphx = Module( name="graphx", dependencies=[], @@ -143,8 +195,8 @@ def contains_file(self, filename): name="streaming-kinesis-asl", dependencies=[], source_file_regexes=[ - "extras/kinesis-asl/", - "extras/kinesis-asl-assembly/", + "external/kinesis-asl/", + "external/kinesis-asl-assembly/", ], build_profile_flags=[ "-Pkinesis-asl", @@ -158,43 +210,6 @@ def contains_file(self, filename): ) -streaming_zeromq = Module( - name="streaming-zeromq", - dependencies=[streaming], - source_file_regexes=[ - "external/zeromq", - ], - sbt_test_goals=[ - "streaming-zeromq/test", - ] -) - - -streaming_twitter = Module( - name="streaming-twitter", - dependencies=[streaming], - source_file_regexes=[ - "external/twitter", - ], - sbt_test_goals=[ - "streaming-twitter/test", - ] -) - - -streaming_mqtt = Module( - name="streaming-mqtt", - dependencies=[streaming], - source_file_regexes=[ - "external/mqtt", - "external/mqtt-assembly", - ], - sbt_test_goals=[ - "streaming-mqtt/test", - ] -) - - streaming_kafka = Module( name="streaming-kafka", dependencies=[streaming], @@ -241,9 +256,21 @@ def contains_file(self, filename): ) +mllib_local = Module( + name="mllib-local", + dependencies=[], + source_file_regexes=[ + "mllib-local", + ], + sbt_test_goals=[ + "mllib-local/test", + ] +) + + mllib = Module( name="mllib", - dependencies=[streaming, sql], + dependencies=[mllib_local, streaming, sql], source_file_regexes=[ "data/mllib/", "mllib/", @@ -256,7 +283,7 @@ def contains_file(self, filename): examples = Module( name="examples", - dependencies=[graphx, mllib, streaming, sql], + dependencies=[graphx, mllib, streaming, hive], source_file_regexes=[ "examples/", ], @@ -288,7 +315,7 @@ def contains_file(self, filename): pyspark_sql = Module( name="pyspark-sql", - dependencies=[pyspark_core, sql], + dependencies=[pyspark_core, hive], source_file_regexes=[ "python/pyspark/sql" ], @@ -313,7 +340,6 @@ def contains_file(self, filename): streaming, streaming_kafka, streaming_flume_assembly, - streaming_mqtt, streaming_kinesis_asl ], source_file_regexes=[ @@ -378,7 +404,7 @@ def contains_file(self, filename): sparkr = Module( name="sparkr", - dependencies=[sql, mllib], + dependencies=[hive, mllib], source_file_regexes=[ "R/", ], @@ -394,22 +420,22 @@ def contains_file(self, filename): ] ) - -ec2 = Module( - name="ec2", +build = Module( + name="build", dependencies=[], source_file_regexes=[ - "ec2/", - ] + ".*pom.xml", + "dev/test-dependencies.sh", + ], + should_run_build_tests=True ) - yarn = Module( name="yarn", dependencies=[], source_file_regexes=[ "yarn/", - "network/yarn/", + "common/network-yarn/", ], sbt_test_goals=[ "yarn/test", @@ -424,7 +450,7 @@ def contains_file(self, filename): # No other modules should directly depend on this module. root = Module( name="root", - dependencies=[], + dependencies=[build], # Changes to build should trigger all tests. source_file_regexes=[], # In order to run all of the tests, enable every test profile: build_profile_flags=list(set( @@ -433,5 +459,6 @@ def contains_file(self, filename): "test", ], python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)), - should_run_r_tests=True + should_run_r_tests=True, + should_run_build_tests=True ) diff --git a/dev/sparktestsupport/toposort.py b/dev/sparktestsupport/toposort.py new file mode 100644 index 0000000000000..6c67b4504bc3b --- /dev/null +++ b/dev/sparktestsupport/toposort.py @@ -0,0 +1,85 @@ +####################################################################### +# Implements a topological sort algorithm. +# +# Copyright 2014 True Blade Systems, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Notes: +# Based on http://code.activestate.com/recipes/578272-topological-sort +# with these major changes: +# Added unittests. +# Deleted doctests (maybe not the best idea in the world, but it cleans +# up the docstring). +# Moved functools import to the top of the file. +# Changed assert to a ValueError. +# Changed iter[items|keys] to [items|keys], for python 3 +# compatibility. I don't think it matters for python 2 these are +# now lists instead of iterables. +# Copy the input so as to leave it unmodified. +# Renamed function from toposort2 to toposort. +# Handle empty input. +# Switch tests to use set literals. +# +######################################################################## + +from functools import reduce as _reduce + + +__all__ = ['toposort', 'toposort_flatten'] + + +def toposort(data): + """Dependencies are expressed as a dictionary whose keys are items +and whose values are a set of dependent items. Output is a list of +sets in topological order. The first set consists of items with no +dependences, each subsequent set consists of items that depend upon +items in the preceeding sets. +""" + + # Special case empty input. + if len(data) == 0: + return + + # Copy the input so as to leave it unmodified. + data = data.copy() + + # Ignore self dependencies. + for k, v in data.items(): + v.discard(k) + # Find all items that don't depend on anything. + extra_items_in_deps = _reduce(set.union, data.values()) - set(data.keys()) + # Add empty dependences where needed. + data.update({item: set() for item in extra_items_in_deps}) + while True: + ordered = set(item for item, dep in data.items() if len(dep) == 0) + if not ordered: + break + yield ordered + data = {item: (dep - ordered) + for item, dep in data.items() + if item not in ordered} + if len(data) != 0: + raise ValueError('Cyclic dependencies exist among these items: {}'.format( + ', '.join(repr(x) for x in data.items()))) + + +def toposort_flatten(data, sort=True): + """Returns a single list of dependencies. For any set returned by +toposort(), those items are sorted and appended to the result (just to +make the results deterministic).""" + + result = [] + for d in toposort(data): + result.extend((sorted if sort else list)(d)) + return result diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh new file mode 100755 index 0000000000000..924b55287c2dc --- /dev/null +++ b/dev/test-dependencies.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +FWDIR="$(cd "`dirname $0`"/..; pwd)" +cd "$FWDIR" + +# Explicitly set locale in order to make `sort` output consistent across machines. +# See https://stackoverflow.com/questions/28881 for more details. +export LC_ALL=C + +# TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. + +# NOTE: These should match those in the release publishing script +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pyarn -Phive" +MVN="build/mvn --force" +HADOOP_PROFILES=( + hadoop-2.2 + hadoop-2.3 + hadoop-2.4 + hadoop-2.6 + hadoop-2.7 +) + +# We'll switch the version to a temp. one, publish POMs using that new version, then switch back to +# the old version. We need to do this because the `dependency:build-classpath` task needs to +# resolve Spark's internal submodule dependencies. + +# From http://stackoverflow.com/a/26514030 +set +e +OLD_VERSION=$($MVN -q \ + -Dexec.executable="echo" \ + -Dexec.args='${project.version}' \ + --non-recursive \ + org.codehaus.mojo:exec-maven-plugin:1.3.1:exec) +if [ $? != 0 ]; then + echo -e "Error while getting version string from Maven:\n$OLD_VERSION" + exit 1 +fi +set -e +TEMP_VERSION="spark-$(python -S -c "import random; print(random.randrange(100000, 999999))")" + +function reset_version { + # Delete the temporary POMs that we wrote to the local Maven repo: + find "$HOME/.m2/" | grep "$TEMP_VERSION" | xargs rm -rf + + # Restore the original version number: + $MVN -q versions:set -DnewVersion=$OLD_VERSION -DgenerateBackupPoms=false > /dev/null +} +trap reset_version EXIT + +$MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /dev/null + +# Generate manifests for each Hadoop profile: +for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do + echo "Performing Maven install for $HADOOP_PROFILE" + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install clean -q + + echo "Performing Maven validate for $HADOOP_PROFILE" + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE validate -q + + echo "Generating dependency manifest for $HADOOP_PROFILE" + mkdir -p dev/pr-deps + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE dependency:build-classpath -pl assembly \ + | grep "Building Spark Project Assembly" -A 5 \ + | tail -n 1 | tr ":" "\n" | rev | cut -d "/" -f 1 | rev | sort \ + | grep -v spark > dev/pr-deps/spark-deps-$HADOOP_PROFILE +done + +if [[ $@ == **replace-manifest** ]]; then + echo "Replacing manifests and creating new files at dev/deps" + rm -rf dev/deps + mv dev/pr-deps dev/deps + exit 0 +fi + +for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do + set +e + dep_diff="$( + git diff \ + --no-index \ + dev/deps/spark-deps-$HADOOP_PROFILE \ + dev/pr-deps/spark-deps-$HADOOP_PROFILE \ + )" + set -e + if [ "$dep_diff" != "" ]; then + echo "Spark's published dependencies DO NOT MATCH the manifest file (dev/spark-deps)." + echo "To update the manifest file, run './dev/test-dependencies.sh --replace-manifest'." + echo "$dep_diff" + rm -rf dev/pr-deps + exit 1 + fi +done + +exit 0 diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh deleted file mode 100755 index fdfb3c62aff58..0000000000000 --- a/dev/tests/pr_new_dependencies.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# -# This script follows the base format for testing pull requests against -# another branch and returning results to be published. More details can be -# found at dev/run-tests-jenkins. -# -# Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` -# Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` -# Arg3: Current PR Commit Hash -#+ the PR hash for the current commit -# - -ghprbActualCommit="$1" -sha1="$2" -current_pr_head="$3" - -MVN_BIN="build/mvn" -CURR_CP_FILE="my-classpath.txt" -MASTER_CP_FILE="master-classpath.txt" - -# First switch over to the master branch -git checkout -f master -# Find and copy all pom.xml files into a *.gate file that we can check -# against through various `git` changes -find -name "pom.xml" -exec cp {} {}.gate \; -# Switch back to the current PR -git checkout -f "${current_pr_head}" - -# Check if any *.pom files from the current branch are different from the master -difference_q="" -for p in $(find -name "pom.xml"); do - [[ -f "${p}" && -f "${p}.gate" ]] && \ - difference_q="${difference_q}$(diff $p.gate $p)" -done - -# If no pom files were changed we can easily say no new dependencies were added -if [ -z "${difference_q}" ]; then - echo " * This patch does not change any dependencies." -else - # Else we need to manually build spark to determine what, if any, dependencies - # were added into the Spark assembly jar - ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ - sed -n -e '/Building Spark Project Assembly/,$p' | \ - grep --context=1 -m 2 "Dependencies classpath:" | \ - head -n 3 | \ - tail -n 1 | \ - tr ":" "\n" | \ - rev | \ - cut -d "/" -f 1 | \ - rev | \ - sort > ${CURR_CP_FILE} - - # Checkout the master branch to compare against - git checkout -f master - - ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ - sed -n -e '/Building Spark Project Assembly/,$p' | \ - grep --context=1 -m 2 "Dependencies classpath:" | \ - head -n 3 | \ - tail -n 1 | \ - tr ":" "\n" | \ - rev | \ - cut -d "/" -f 1 | \ - rev | \ - sort > ${MASTER_CP_FILE} - - DIFF_RESULTS="`diff ${CURR_CP_FILE} ${MASTER_CP_FILE}`" - - if [ -z "${DIFF_RESULTS}" ]; then - echo " * This patch does not change any dependencies." - else - # Pretty print the new dependencies - added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') - removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') - added_deps_text=" * This patch **adds the following new dependencies:**\n${added_deps}" - removed_deps_text=" * This patch **removes the following dependencies:**\n${removed_deps}" - - # Construct the final returned message with proper - return_mssg="" - [ -n "${added_deps}" ] && return_mssg="${added_deps_text}" - if [ -n "${removed_deps}" ]; then - if [ -n "${return_mssg}" ]; then - return_mssg="${return_mssg}\n${removed_deps_text}" - else - return_mssg="${removed_deps_text}" - fi - fi - echo "${return_mssg}" - fi - - # Remove the files we've left over - [ -f "${CURR_CP_FILE}" ] && rm -f "${CURR_CP_FILE}" - [ -f "${MASTER_CP_FILE}" ] && rm -f "${MASTER_CP_FILE}" - - # Clean up our mess from the Maven builds just in case - ${MVN_BIN} clean &>/dev/null -fi diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh index 927295b88c963..41c5d3ee8cb3c 100755 --- a/dev/tests/pr_public_classes.sh +++ b/dev/tests/pr_public_classes.sh @@ -24,36 +24,44 @@ # # Arg1: The Github Pull Request Actual Commit #+ known as `ghprbActualCommit` in `run-tests-jenkins` -# Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` -# - -# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR -#+ and not anything else added to master since the PR was branched. ghprbActualCommit="$1" -sha1="$2" + +# $ghprbActualCommit is an automatic merge commit generated by GitHub; its parents are some Spark +# master commit and the tip of the pull request branch. + +# By diffing$ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only +# non-test files, we can gets us changes introduced in the PR and not anything else added to master +# since the PR was branched. + +# Handle differences between GNU and BSD sed +if [[ $(uname) == "Darwin" ]]; then + SED='sed -E' +else + SED='sed -r' +fi source_files=$( - git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + git diff $ghprbActualCommit^...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ | grep -v -e "\/test" `# ignore files in test directories` \ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ | tr "\n" " " ) + new_public_classes=$( - git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + git diff $ghprbActualCommit^...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ + | $SED -e "s/^\+//g" `# remove the leading +` \ | grep -e "trait " -e "class " `# filter in lines with these key words` \ | grep -e "{" -e "(" `# filter in lines with these key words, too` \ | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | $SED -e "s/\{.*//g" `# remove from the { onwards` \ + | $SED -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | $SED -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | $SED -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | $SED -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | $SED -e "s/$/\\\n/g" `# append newline to end of line` \ | tr -d "\n" `# remove actual LF characters` ) @@ -61,5 +69,5 @@ if [ -z "$new_public_classes" ]; then echo " * This patch adds no public classes." else public_classes_note=" * This patch adds the following public classes _(experimental)_:" - echo "${public_classes_note}\n${new_public_classes}" + echo -e "${public_classes_note}\n${new_public_classes}" fi diff --git a/tox.ini b/dev/tox.ini similarity index 100% rename from tox.ini rename to dev/tox.ini diff --git a/docs/README.md b/docs/README.md index 1f4fd3e56ed5f..bcea93e1f3b6d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,15 +10,18 @@ whichever version of Spark you currently have checked out of revision control. ## Prerequisites The Spark documentation build uses a number of tools to build HTML docs and API docs in Scala, -Python and R. To get started you can run the following commands +Python and R. - $ sudo gem install jekyll - $ sudo gem install jekyll-redirect-from +You need to have [Ruby](https://www.ruby-lang.org/en/documentation/installation/) and +[Python](https://docs.python.org/2/using/unix.html#getting-and-installing-the-latest-version-of-python) +installed. Also install the following libraries: +```sh + $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments + # Following is needed only for generating API docs $ sudo pip install sphinx $ Rscript -e 'install.packages(c("knitr", "devtools"), repos="http://cran.stat.ucla.edu/")' - - +``` ## Generating the Documentation HTML We include the Spark documentation as part of the source (as opposed to using a hosted wiki, such as @@ -38,14 +41,16 @@ compiled files. $ jekyll build You can modify the default Jekyll build as follows: - +```sh # Skip generating API docs (which takes a while) $ SKIP_API=1 jekyll build + # Serve content locally on port 4000 $ jekyll serve --watch + # Build the site with extra features used on the live page $ PRODUCTION=1 jekyll build - +``` ## API Docs (Scaladoc, Sphinx, roxygen2) @@ -59,7 +64,7 @@ When you run `jekyll` in the `docs` directory, it will also copy over the scalad Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the -PySpark docs [Sphinx](http://sphinx-doc.org/). +PySpark docs using [Sphinx](http://sphinx-doc.org/). NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_config.yml b/docs/_config.yml index 2c70b76be8b7a..8bdc68aeeac7f 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,10 +14,10 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.6.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.6.0 -SCALA_BINARY_VERSION: "2.10" -SCALA_VERSION: "2.10.5" +SPARK_VERSION: 2.0.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.0.0 +SCALA_BINARY_VERSION: "2.11" +SCALA_VERSION: "2.11.7" MESOS_VERSION: 0.21.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml new file mode 100644 index 0000000000000..3fd3ee2823f75 --- /dev/null +++ b/docs/_data/menu-ml.yaml @@ -0,0 +1,12 @@ +- text: "Overview: estimators, transformers and pipelines" + url: ml-guide.html +- text: Extracting, transforming and selecting features + url: ml-features.html +- text: Classification and Regression + url: ml-classification-regression.html +- text: Clustering + url: ml-clustering.html +- text: Collaborative filtering + url: ml-collaborative-filtering.html +- text: Advanced topics + url: ml-advanced.html diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml new file mode 100644 index 0000000000000..12d22abd52826 --- /dev/null +++ b/docs/_data/menu-mllib.yaml @@ -0,0 +1,75 @@ +- text: Data types + url: mllib-data-types.html +- text: Basic statistics + url: mllib-statistics.html + subitems: + - text: Summary statistics + url: mllib-statistics.html#summary-statistics + - text: Correlations + url: mllib-statistics.html#correlations + - text: Stratified sampling + url: mllib-statistics.html#stratified-sampling + - text: Hypothesis testing + url: mllib-statistics.html#hypothesis-testing + - text: Random data generation + url: mllib-statistics.html#random-data-generation +- text: Classification and regression + url: mllib-classification-regression.html + subitems: + - text: Linear models (SVMs, logistic regression, linear regression) + url: mllib-linear-methods.html + - text: Naive Bayes + url: mllib-naive-bayes.html + - text: decision trees + url: mllib-decision-tree.html + - text: ensembles of trees (Random Forests and Gradient-Boosted Trees) + url: mllib-ensembles.html + - text: isotonic regression + url: mllib-isotonic-regression.html +- text: Collaborative filtering + url: mllib-collaborative-filtering.html + subitems: + - text: alternating least squares (ALS) + url: mllib-collaborative-filtering.html#collaborative-filtering +- text: Clustering + url: mllib-clustering.html + subitems: + - text: k-means + url: mllib-clustering.html#k-means + - text: Gaussian mixture + url: mllib-clustering.html#gaussian-mixture + - text: power iteration clustering (PIC) + url: mllib-clustering.html#power-iteration-clustering-pic + - text: latent Dirichlet allocation (LDA) + url: mllib-clustering.html#latent-dirichlet-allocation-lda + - text: streaming k-means + url: mllib-clustering.html#streaming-k-means +- text: Dimensionality reduction + url: mllib-dimensionality-reduction.html + subitems: + - text: singular value decomposition (SVD) + url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd + - text: principal component analysis (PCA) + url: mllib-dimensionality-reduction.html#principal-component-analysis-pca +- text: Feature extraction and transformation + url: mllib-feature-extraction.html +- text: Frequent pattern mining + url: mllib-frequent-pattern-mining.html + subitems: + - text: FP-growth + url: mllib-frequent-pattern-mining.html#fp-growth + - text: association rules + url: mllib-frequent-pattern-mining.html#association-rules + - text: PrefixSpan + url: mllib-frequent-pattern-mining.html#prefix-span +- text: Evaluation metrics + url: mllib-evaluation-metrics.html +- text: PMML model export + url: mllib-pmml-model-export.html +- text: Optimization (developer) + url: mllib-optimization.html + subitems: + - text: stochastic gradient descent + url: mllib-optimization.html#stochastic-gradient-descent-sgd + - text: limited-memory BFGS (L-BFGS) + url: mllib-optimization.html#limited-memory-bfgs-l-bfgs diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html new file mode 100644 index 0000000000000..e2d7eda027c6e --- /dev/null +++ b/docs/_includes/nav-left-wrapper-ml.html @@ -0,0 +1,8 @@ +
    +
    +

    spark.ml package

    + {% include nav-left.html nav=include.nav-ml %} +

    spark.mllib package

    + {% include nav-left.html nav=include.nav-mllib %} +
    +
    \ No newline at end of file diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html new file mode 100644 index 0000000000000..73176f4132554 --- /dev/null +++ b/docs/_includes/nav-left.html @@ -0,0 +1,17 @@ +{% assign navurl = page.url | remove: 'index.html' %} + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 467ff7a03fb70..d493f62f0e578 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -1,3 +1,4 @@ + @@ -71,10 +72,9 @@
  • Spark Programming Guide
  • Spark Streaming
  • -
  • DataFrames and SQL
  • +
  • DataFrames, Datasets and SQL
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • -
  • Bagel (Pregel on Spark)
  • SparkR (R on Spark)
  • @@ -98,8 +98,6 @@
  • Spark Standalone
  • Mesos
  • YARN
  • -
  • -
  • Amazon EC2
  • @@ -124,16 +122,36 @@
    -
    - {% if page.displayTitle %} -

    {{ page.displayTitle }}

    - {% else %} -

    {{ page.title }}

    - {% endif %} +
    + + {% if page.url contains "/ml" %} + {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} + + +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} + + {{ content }} - {{ content }} +
    + {% else %} +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} + + {{ content }} -
    +
    + {% endif %} + +
    diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 01718d98dffe0..f926d67e6beaf 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -27,7 +27,7 @@ cd("..") puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl clean compile unidoc` + system("build/sbt -Pkinesis-asl clean compile unidoc") || raise("Unidoc generation failed") puts "Moving back into docs dir." cd("docs") @@ -37,7 +37,7 @@ # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.10/unidoc" + source = "../target/scala-2.11/unidoc" dest = "api/scala" puts "Making directory " + dest @@ -117,7 +117,7 @@ puts "Moving to python/docs directory and building sphinx." cd("../python/docs") - puts `make html` + system("make html") || raise("Python doc generation failed") puts "Moving back into home dir." cd("../../") @@ -131,7 +131,7 @@ # Build SparkR API docs puts "Moving to R directory and building roxygen docs." cd("R") - puts `./create-docs.sh` + system("./create-docs.sh") || raise("R doc generation failed") puts "Moving back into home dir." cd("../") diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 6ee63a5ac69df..f7485826a762d 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -20,15 +20,15 @@ module Jekyll class IncludeExampleTag < Liquid::Tag - + def initialize(tag_name, markup, tokens) @markup = markup super end - + def render(context) site = context.registers[:site] - config_dir = (site.config['code_dir'] || '../examples/src/main').sub(/^\//,'') + config_dir = '../examples/src/main' @code_dir = File.join(site.source, config_dir) clean_markup = @markup.strip @@ -37,10 +37,15 @@ def render(context) code = File.open(@file).read.encode("UTF-8") code = select_lines(code) - - Pygments.highlight(code, :lexer => @lang) + + rendered_code = Pygments.highlight(code, :lexer => @lang) + + hint = "
    Find full example code at " \ + "\"examples/src/main/#{clean_markup}\" in the Spark repo.
    " + + rendered_code + hint end - + # Trim the code block so as to have the same indention, regardless of their positions in the # code file. def trim_codeblock(lines) @@ -70,10 +75,10 @@ def select_lines(code) .select { |l, i| l.include? "$example off$" } .map { |l, i| i } - raise "Start indices amount is not equal to end indices amount, please check the code." \ + raise "Start indices amount is not equal to end indices amount, see #{@file}." \ unless startIndices.size == endIndices.size - raise "No code is selected by include_example, please check the code." \ + raise "No code is selected by include_example, see #{@file}." \ if startIndices.size == 0 # Select and join code blocks together, with a space line between each of two continuous @@ -81,8 +86,10 @@ def select_lines(code) lastIndex = -1 result = "" startIndices.zip(endIndices).each do |start, endline| - raise "Overlapping between two example code blocks are not allowed." if start <= lastIndex - raise "$example on$ should not be in the same line with $example off$." if start == endline + raise "Overlapping between two example code blocks are not allowed, see #{@file}." \ + if start <= lastIndex + raise "$example on$ should not be in the same line with $example off$, see #{@file}." \ + if start == endline lastIndex = endline range = Range.new(start + 1, endline - 1) result += trim_codeblock(lines[range]).join diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md deleted file mode 100644 index 347ca4a7af989..0000000000000 --- a/docs/bagel-programming-guide.md +++ /dev/null @@ -1,159 +0,0 @@ ---- -layout: global -displayTitle: Bagel Programming Guide -title: Bagel ---- - -**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** - -Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. - -In the Pregel programming model, jobs run as a sequence of iterations called _supersteps_. In each superstep, each vertex in the graph runs a user-specified function that can update state associated with the vertex and send messages to other vertices for use in the *next* iteration. - -This guide shows the programming model and features of Bagel by walking through an example implementation of PageRank on Bagel. - -# Linking with Bagel - -To use Bagel in your program, add the following SBT or Maven dependency: - - groupId = org.apache.spark - artifactId = spark-bagel_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION}} - -# Programming Model - -Bagel operates on a graph represented as a [distributed dataset](programming-guide.html) of (K, V) pairs, where keys are vertex IDs and values are vertices plus their associated state. In each superstep, Bagel runs a user-specified compute function on each vertex that takes as input the current vertex state and a list of messages sent to that vertex during the previous superstep, and returns the new vertex state and a list of outgoing messages. - -For example, we can use Bagel to implement PageRank. Here, vertices represent pages, edges represent links between pages, and messages represent shares of PageRank sent to the pages that a particular page links to. - -We first extend the default `Vertex` class to store a `Double` -representing the current PageRank of the vertex, and similarly extend -the `Message` and `Edge` classes. Note that these need to be marked `@serializable` to allow Spark to transfer them across machines. We also import the Bagel types and implicit conversions. - -{% highlight scala %} -import org.apache.spark.bagel._ -import org.apache.spark.bagel.Bagel._ - -@serializable class PREdge(val targetId: String) extends Edge - -@serializable class PRVertex( - val id: String, val rank: Double, val outEdges: Seq[Edge], - val active: Boolean) extends Vertex - -@serializable class PRMessage( - val targetId: String, val rankShare: Double) extends Message -{% endhighlight %} - -Next, we load a sample graph from a text file as a distributed dataset and package it into `PRVertex` objects. We also cache the distributed dataset because Bagel will use it multiple times and we'd like to avoid recomputing it. - -{% highlight scala %} -val input = sc.textFile("data/mllib/pagerank_data.txt") - -val numVerts = input.count() - -val verts = input.map(line => { - val fields = line.split('\t') - val (id, linksStr) = (fields(0), fields(1)) - val links = linksStr.split(',').map(new PREdge(_)) - (id, new PRVertex(id, 1.0 / numVerts, links, true)) -}).cache -{% endhighlight %} - -We run the Bagel job, passing in `verts`, an empty distributed dataset of messages, and a custom compute function that runs PageRank for 10 iterations. - -{% highlight scala %} -val emptyMsgs = sc.parallelize(List[(String, PRMessage)]()) - -def compute(self: PRVertex, msgs: Option[Seq[PRMessage]], superstep: Int) -: (PRVertex, Iterable[PRMessage]) = { - val msgSum = msgs.getOrElse(List()).map(_.rankShare).sum - val newRank = - if (msgSum != 0) - 0.15 / numVerts + 0.85 * msgSum - else - self.rank - val halt = superstep >= 10 - val msgsOut = - if (!halt) - self.outEdges.map(edge => - new PRMessage(edge.targetId, newRank / self.outEdges.size)) - else - List() - (new PRVertex(self.id, newRank, self.outEdges, !halt), msgsOut) -} -{% endhighlight %} - -val result = Bagel.run(sc, verts, emptyMsgs)()(compute) - -Finally, we print the results. - -{% highlight scala %} -println(result.map(v => "%s\t%s\n".format(v.id, v.rank)).collect.mkString) -{% endhighlight %} - -## Combiners - -Sending a message to another vertex generally involves expensive communication over the network. For certain algorithms, it's possible to reduce the amount of communication using _combiners_. For example, if the compute function receives integer messages and only uses their sum, it's possible for Bagel to combine multiple messages to the same vertex by summing them. - -For combiner support, Bagel can optionally take a set of combiner functions that convert messages to their combined form. - -_Example: PageRank with combiners_ - -## Aggregators - -Aggregators perform a reduce across all vertices after each superstep, and provide the result to each vertex in the next superstep. - -For aggregator support, Bagel can optionally take an aggregator function that reduces across each vertex. - -_Example_ - -## Operations - -Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details. - -### Actions - -{% highlight scala %} -/*** Full form ***/ - -Bagel.run(sc, vertices, messages, combiner, aggregator, partitioner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], aggregated: Option[A], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -/*** Abbreviated forms ***/ - -Bagel.run(sc, vertices, messages, combiner, partitioner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -Bagel.run(sc, vertices, messages, combiner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -Bagel.run(sc, vertices, messages, numSplits)(compute) -// where compute takes (vertex: V, messages: Option[Array[M]], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) -{% endhighlight %} - -### Types - -{% highlight scala %} -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -trait Vertex { - def active: Boolean -} - -trait Message[K] { - def targetId: K -} -{% endhighlight %} diff --git a/docs/building-spark.md b/docs/building-spark.md index 4f73adb85446c..fec442af95e1b 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -7,7 +7,7 @@ redirect_from: "building-with-maven.html" * This will become a table of contents (this text will be scraped). {:toc} -Building Spark using Maven requires Maven 3.3.3 or newer and Java 7+. +Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+. The Spark build can supply a suitable Maven binary; see below. # Building with `build/mvn` @@ -33,14 +33,14 @@ to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/ # Building a Runnable Distribution -To create a Spark distribution like those distributed by the -[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as -to be runnable, use `make-distribution.sh` in the project root directory. It can be configured +To create a Spark distribution like those distributed by the +[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as +to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: - ./make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn - -For more information on usage, run `./make-distribution.sh --help` + ./dev/make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn + +For more information on usage, run `./dev/make-distribution.sh --help` # Setting up Maven's Memory Usage @@ -74,23 +74,14 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro Hadoop versionProfile required - 1.x to 2.1.xhadoop-1 2.2.xhadoop-2.2 2.3.xhadoop-2.3 2.4.xhadoop-2.4 - 2.6.x and later 2.xhadoop-2.6 + 2.6.xhadoop-2.6 + 2.7.x and later 2.xhadoop-2.7 -For Apache Hadoop versions 1.x, Cloudera CDH "mr1" distributions, and other Hadoop versions without YARN, use: - -{% highlight bash %} -# Apache Hadoop 1.2.1 -mvn -Dhadoop.version=1.2.1 -Phadoop-1 -DskipTests clean package - -# Cloudera CDH 4.2.0 with MapReduce v1 -mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -Phadoop-1 -DskipTests clean package -{% endhighlight %} You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. @@ -107,8 +98,11 @@ mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package # Apache Hadoop 2.4.X or 2.5.X mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package -Versions of Hadoop after 2.5.X may or may not work with the -Phadoop-2.4 profile (they were -released after this version of Spark). +# Apache Hadoop 2.6.X +mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -DskipTests clean package + +# Apache Hadoop 2.7.X and later +mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=VERSION -DskipTests clean package # Different versions of HDFS and YARN. mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests clean package @@ -117,19 +111,17 @@ mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` and `Phive-thriftserver` profiles to your existing build options. -By default Spark will build with Hive 0.13.1 bindings. +By default Spark will build with Hive 1.2.1 bindings. {% highlight bash %} -# Apache Hadoop 2.4.X with Hive 13 support +# Apache Hadoop 2.4.X with Hive 1.2.1 support mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package {% endhighlight %} -# Building for Scala 2.11 -To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: - - ./dev/change-scala-version.sh 2.11 - mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package +# Building for Scala 2.10 +To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` property: -Spark does not yet support its JDBC component for Scala 2.11. + ./dev/change-scala-version.sh 2.10 + mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package # Spark Tests in Maven @@ -151,10 +143,10 @@ It's possible to build Spark sub-modules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: {% highlight bash %} -mvn -pl :spark-streaming_2.10 clean install +mvn -pl :spark-streaming_2.11 clean install {% endhighlight %} -where `spark-streaming_2.10` is the `artifactId` as defined in `streaming/pom.xml` file. +where `spark-streaming_2.11` is the `artifactId` as defined in `streaming/pom.xml` file. # Continuous Compilation @@ -188,22 +180,19 @@ For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troub Running only Java 8 tests and nothing else. - mvn install -DskipTests -Pjava8-tests + mvn install -DskipTests + mvn -pl :java8-tests_2.11 test -Java 8 tests are run when `-Pjava8-tests` profile is enabled, they will run in spite of `-DskipTests`. -For these tests to run your system must have a JDK 8 installation. -If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. +or -# Building for PySpark on YARN + sbt java8-tests/test -PySpark on YARN is only supported if the jar is built with Maven. Further, there is a known problem -with building this assembly jar on Red Hat based operating systems (see [SPARK-1753](https://issues.apache.org/jira/browse/SPARK-1753)). If you wish to -run PySpark on a YARN cluster with Red Hat installed, we recommend that you build the jar elsewhere, -then ship it over to the cluster. We are investigating the exact cause for this. +Java 8 tests are automatically enabled when a Java 8 JDK is detected. +If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. # Packaging without Hadoop Dependencies for YARN -The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +The assembly directory produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. # Building with SBT @@ -214,7 +203,7 @@ compilation. More advanced developers may wish to use SBT. The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables can be set to control the SBT build. For example: - build/sbt -Pyarn -Phadoop-2.3 assembly + build/sbt -Pyarn -Phadoop-2.3 package To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt in interactive mode by running `build/sbt`, and then run all build commands at the command @@ -223,9 +212,9 @@ prompt. For more recommendations on reducing build time, refer to the # Testing with SBT -Some of the tests require Spark to be packaged first, so always run `build/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: +Some of the tests require Spark to be packaged first, so always run `build/sbt package` the first time. The following is an example of a correct (build, test) sequence: - build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly + build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver package build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test To run only a specific test suite as follows: diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index faaf154d243f5..814e4406cf435 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -35,7 +35,7 @@ There are several useful things to note about this architecture: processes, and these communicate with each other, it is relatively easy to run it even on a cluster manager that also supports other applications (e.g. Mesos/YARN). 3. The driver program must listen for and accept incoming connections from its executors throughout - its lifetime (e.g., see [spark.driver.port and spark.fileserver.port in the network config + its lifetime (e.g., see [spark.driver.port in the network config section](configuration.html#networking)). As such, the driver program must be network addressable from the worker nodes. 4. Because the driver schedules tasks on the cluster, it should be run close to the worker @@ -53,8 +53,6 @@ The system currently supports three cluster managers: and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone -cluster on Amazon EC2. # Submitting Applications diff --git a/docs/configuration.md b/docs/configuration.md index c276e8e90decf..16d5be62f9e89 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,7 +35,7 @@ val sc = new SparkContext(conf) {% endhighlight %} Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may -actually require one to prevent any sort of starvation issues. +actually require more than 1 thread to prevent any sort of starvation issues. Properties that specify some time duration should be configured with a unit of time. The following format is accepted: @@ -48,7 +48,7 @@ The following format is accepted: 1y (years) -Properties that specify a byte size should be configured with a unit of size. +Properties that specify a byte size should be configured with a unit of size. The following format is accepted: 1b (bytes) @@ -173,7 +173,7 @@ of the most common options to set are: stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. - NOTE: In Spark 1.0 and later this will be overriden by SPARK_LOCAL_DIRS (Standalone, Mesos) or + NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or LOCAL_DIRS (YARN) environment variables set by the cluster manager. @@ -192,6 +192,15 @@ of the most common options to set are: allowed master URL's. + + spark.submit.deployMode + (none) + + The deploy mode of Spark driver program, either "client" or "cluster", + Which means to launch driver program locally ("client") + or remotely ("cluster") on one of the nodes inside the cluster. + + Apart from these, the following properties are also available, and may be useful in some situations: @@ -216,11 +225,14 @@ Apart from these, the following properties are also available, and may be useful (none) A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. + Note that it is illegal to set maximum heap size (-Xmx) settings with this option. Maximum heap + size settings can be set with spark.driver.memory in the cluster mode and through + the --driver-memory command line option in the client mode.
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-java-options command line option or in - your default properties file. + your default properties file. @@ -240,7 +252,7 @@ Apart from these, the following properties are also available, and may be useful false (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading - classes in the the driver. This feature can be used to mitigate conflicts between Spark's + classes in the driver. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. This is used in cluster mode only. @@ -260,9 +272,9 @@ Apart from these, the following properties are also available, and may be useful (none) A string of extra JVM options to pass to executors. For instance, GC settings or other logging. - Note that it is illegal to set Spark properties or heap size settings with this option. Spark - properties should be set using a SparkConf object or the spark-defaults.conf file used with the - spark-submit script. Heap size settings can be set with spark.executor.memory. + Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this + option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file + used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. @@ -305,7 +317,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are daily, hourly, minutely or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +342,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - sc.dump_profiles(path). If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by - passing a profiler class in as a parameter to the SparkContext constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. @@ -364,7 +376,7 @@ Apart from these, the following properties are also available, and may be useful Reuse Python worker or not. If yes, it will use a fixed number of Python workers, does not need to fork() a Python process for every tasks. It will be very useful - if there is large broadcast, then the broadcast will not be needed to transfered + if there is large broadcast, then the broadcast will not be needed to transferred from JVM to Python worker for every task. @@ -382,6 +394,16 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. + + spark.reducer.maxReqsInFlight + Int.MaxValue + + This configuration limits the number of remote requests to fetch blocks at any given point. + When the number of hosts in the cluster increase, it might lead to very large number + of in-bound connections to one or more nodes, causing the workers to fail under load. + By allowing it to limit the number of fetch requests, this scenario can be mitigated. + + spark.shuffle.compress true @@ -565,6 +587,13 @@ Apart from these, the following properties are also available, and may be useful How many finished batches the Spark UI and status APIs remember before garbage collecting. + + spark.ui.retainedDeadExecutors + 100 + + How many dead executors the Spark UI and status APIs remember before garbage collecting. + + #### Compression and Serialization @@ -577,16 +606,9 @@ Apart from these, the following properties are also available, and may be useful Whether to compress broadcast variables before sending them. Generally a good idea. - - spark.closure.serializer - org.apache.spark.serializer.
    JavaSerializer - - Serializer class to use for closures. Currently only the Java serializer is supported. - - spark.io.compression.codec - snappy + lz4 The codec used to compress internal data such as RDD partitions, broadcast variables and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, @@ -647,10 +669,10 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.registrator (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. This + If you use Kryo serialization, give a comma-separated list of classes that register your custom classes with Kryo. This property is useful if you need to register your classes in a custom way, e.g. to specify a custom field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be - set to a class that extends + set to classes that extend KryoRegistrator. See the tuning guide for more details. @@ -679,8 +701,9 @@ Apart from these, the following properties are also available, and may be useful false Whether to compress serialized RDD partitions (e.g. for - StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some - extra CPU time. + StorageLevel.MEMORY_ONLY_SER in Java + and Scala or StorageLevel.MEMORY_ONLY in Python). + Can save substantial space at the cost of some extra CPU time. @@ -719,20 +742,39 @@ Apart from these, the following properties are also available, and may be useful spark.memory.fraction 0.75 - Fraction of the heap space used for execution and storage. The lower this is, the more - frequently spills and cached data eviction occur. The purpose of this config is to set + Fraction of (heap space - 300MB) used for execution and storage. The lower this is, the + more frequently spills and cached data eviction occur. The purpose of this config is to set aside memory for internal metadata, user data structures, and imprecise size estimation - in the case of sparse, unusually large records. + in the case of sparse, unusually large records. Leaving this at the default value is + recommended. For more detail, see + this description. spark.memory.storageFraction 0.5 - T​he size of the storage region within the space set aside by - s​park.memory.fraction. This region is not statically reserved, but dynamically - allocated as cache requests come in. ​Cached data may be evicted only if total storage exceeds - this region. + Amount of storage memory immune to eviction, expressed as a fraction of the size of the + region set aside by s​park.memory.fraction. The higher this is, the less + working memory may be available to execution and tasks may spill to disk more often. + Leaving this at the default value is recommended. For more detail, see + this description. + + + + spark.memory.offHeap.enabled + false + + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. + + + + spark.memory.offHeap.size + 0 + + The absolute amount of memory in bytes which can be used for off-heap allocation. + This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This must be set to a positive value when spark.memory.offHeap.enabled=true. @@ -795,32 +837,19 @@ Apart from these, the following properties are also available, and may be useful - spark.broadcast.factory - org.apache.spark.broadcast.
    TorrentBroadcastFactory - - Which broadcast implementation to use. - - - - spark.cleaner.ttl - (infinite) + spark.executor.cores - Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks - generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be - forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in - case of Spark Streaming applications). Note that any RDD that persists in memory for more than - this duration will be cleared as well. + 1 in YARN mode, all the available cores on the worker in + standalone and Mesos coarse-grained modes. - - - spark.executor.cores - 1 in YARN mode, all the available cores on the worker in standalone mode. - The number of cores to use on each executor. For YARN and standalone mode only. + The number of cores to use on each executor. - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one - executor per application will run on each worker. + In standalone and Mesos coarse-grained modes, setting this + parameter allows an application to run multiple executors on the + same worker, provided that there are enough cores on that + worker. Otherwise, only one executor per application will run on + each worker. @@ -903,82 +932,18 @@ Apart from these, the following properties are also available, and may be useful mapping has high overhead for blocks close to or below the page size of the operating system. - - spark.externalBlockStore.blockManager - org.apache.spark.storage.TachyonBlockManager - - Implementation of external block manager (file system) that store RDDs. The file system's URL is set by - spark.externalBlockStore.url. - - - - spark.externalBlockStore.baseDir - System.getProperty("java.io.tmpdir") - - Directories of the external block store that store RDDs. The file system's URL is set by - spark.externalBlockStore.url It can also be a comma-separated list of multiple - directories on Tachyon file system. - - - - spark.externalBlockStore.url - tachyon://localhost:19998 for Tachyon - - The URL of the underlying external blocker file system in the external block store. - - #### Networking - + - - - - - - - - - - - - - - - - - - - - @@ -988,14 +953,6 @@ Apart from these, the following properties are also available, and may be useful Port for all block managers to listen on. These exist on both the driver and the executors. - - - - - @@ -1012,26 +969,12 @@ Apart from these, the following properties are also available, and may be useful This is used for communicating with the executors and the standalone Master. - - - - - - - - - - - - - - - @@ -1082,7 +1017,7 @@ Apart from these, the following properties are also available, and may be useful
    Property NameDefaultMeaning
    spark.akka.frameSizespark.rpc.message.maxSize 128 Maximum message size (in MB) to allow in "control plane" communication; generally only applies to map output size information sent between executors and the driver. Increase this if you are running - jobs with many thousands of map and reduce tasks and see messages about the frame size. -
    spark.akka.heartbeat.interval1000s - This is set to a larger value to disable the transport failure detector that comes built in to - Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value reduces network overhead and a smaller value ( ~ 1 s) might be more - informative for Akka's failure detector. Tune this in combination of spark.akka.heartbeat.pauses - if you need to. A likely positive use case for using failure detector would be: a sensistive - failure detector can help evict rogue executors quickly. However this is usually not the case - as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling - this leads to a lot of exchanges of heart beats between nodes leading to flooding the network - with those. -
    spark.akka.heartbeat.pauses6000s - This is set to a larger value to disable the transport failure detector that comes built in to Akka. - It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart - beat pause for Akka. This can be used to control sensitivity to GC pauses. Tune - this along with spark.akka.heartbeat.interval if you need to. -
    spark.akka.threads4 - Number of actor threads to use for communication. Can be useful to increase on large clusters - when the driver has a lot of CPU cores. -
    spark.akka.timeout100s - Communication timeout between Spark nodes. + jobs with many thousands of map and reduce tasks and see messages about the RPC message size.
    spark.broadcast.port(random) - Port for the driver's HTTP broadcast server to listen on. - This is not relevant for torrent broadcast. -
    spark.driver.host (local hostname)
    spark.executor.port(random) - Port for the executor to listen on. This is used for communicating with the driver. -
    spark.fileserver.port(random) - Port for the driver's HTTP file server to listen on. -
    spark.network.timeout 120s Default timeout for all network interactions. This config will be used in place of - spark.core.connection.ack.wait.timeout, spark.akka.timeout, + spark.core.connection.ack.wait.timeout, spark.storage.blockManagerSlaveTimeoutMs, spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or spark.rpc.lookupTimeout if they are not configured. @@ -1048,14 +991,6 @@ Apart from these, the following properties are also available, and may be useful to port + maxRetries.
    spark.replClassServer.port(random) - Port for the driver's HTTP class server to listen on. - This is only relevant for the Spark shell. -
    spark.rpc.numRetries 3spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out.
    @@ -1228,8 +1163,8 @@ Apart from these, the following properties are also available, and may be useful false Whether to use dynamic resource allocation, which scales the number of executors registered - with this application up and down based on the workload. Note that this is currently only - available on YARN mode. For more detail, see the description + with this application up and down based on the workload. + For more detail, see the description here.

    This requires spark.shuffle.service.enabled to be set. @@ -1317,7 +1252,7 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of users/administrators that have view and modify access to all Spark jobs. This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. Putting a "*" in the list means any user can have the priviledge + help debug when things work. Putting a "*" in the list means any user can have the privilege of admin. @@ -1421,8 +1356,7 @@ Apart from these, the following properties are also available, and may be useful

    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for particular protocol denoted by YYY. Currently YYY can be - either akka for Akka based connections or fs for broadcast and - file server.

    + only fs for file server.

    @@ -1433,6 +1367,7 @@ Apart from these, the following properties are also available, and may be useful The reference list of protocols one can find on this page. + Note: If not set, it will use the default cipher suites of JVM. @@ -1457,6 +1392,13 @@ Apart from these, the following properties are also available, and may be useful A password to the key-store. + + spark.ssl.keyStoreType + JKS + + The type of the key-store. + + spark.ssl.protocol None @@ -1466,6 +1408,13 @@ Apart from these, the following properties are also available, and may be useful page. + + spark.ssl.needClientAuth + false + + Set true if SSL needs client authentication. + + spark.ssl.trustStore None @@ -1481,6 +1430,13 @@ Apart from these, the following properties are also available, and may be useful A password to the trust-store. + + spark.ssl.trustStoreType + JKS + + The type of the trust-store. + + @@ -1500,6 +1456,14 @@ Apart from these, the following properties are also available, and may be useful if they are set (see below). + + spark.streaming.backpressure.initialRate + not set + + This is the initial maximum receiving rate at which each receiver will receive data for the + first batch when the backpressure mechanism is enabled. + + spark.streaming.blockInterval 200ms @@ -1546,7 +1510,7 @@ Apart from these, the following properties are also available, and may be useful spark.streaming.stopGracefullyOnShutdown false - If true, Spark shuts down the StreamingContext gracefully on JVM + If true, Spark shuts down the StreamingContext gracefully on JVM shutdown rather than immediately. @@ -1577,6 +1541,24 @@ Apart from these, the following properties are also available, and may be useful How many batches the Spark Streaming UI and status APIs remember before garbage collecting. + + spark.streaming.driver.writeAheadLog.closeFileAfterWrite + false + + Whether to close the file after writing a write ahead log record on the driver. Set this to 'true' + when you want to use S3 (or any file system that does not support flushing) for the metadata WAL + on the driver. + + + + spark.streaming.receiver.writeAheadLog.closeFileAfterWrite + false + + Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true' + when you want to use S3 (or any file system that does not support flushing) for the data WAL + on the receivers. + + #### SparkR @@ -1605,6 +1587,29 @@ Apart from these, the following properties are also available, and may be useful +#### Deploy + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.deploy.recoveryModeNONEThe recovery mode setting to recover submitted Spark jobs with cluster mode when it failed and relaunches. + This is only applicable for cluster mode when running with Standalone or Mesos.
    spark.deploy.zookeeper.urlNoneWhen `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper URL to connect to.
    spark.deploy.zookeeper.dirNoneWhen `spark.deploy.recoveryMode` is set to ZOOKEEPER, this configuration is used to set the zookeeper directory to store recovery state.
    + + #### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: @@ -1636,7 +1641,7 @@ The following variables can be set in `spark-env.sh`: PYSPARK_PYTHON - Python binary executable to use for PySpark in both driver and workers (default is python). + Python binary executable to use for PySpark in both driver and workers (default is python2.7 if available, otherwise python). PYSPARK_DRIVER_PYTHON @@ -1663,6 +1668,8 @@ to use on each machine and maximum memory. Since `spark-env.sh` is a shell script, some of these can be set programmatically -- for example, you might compute `SPARK_LOCAL_IP` by looking up the IP of a specific network interface. +Note: When running Spark on YARN in `cluster` mode, environment variables need to be set using the `spark.yarn.appMasterEnv.[EnvironmentVariableName]` property in your `conf/spark-defaults.conf` file. Environment variables that are set in `spark-env.sh` will not be reflected in the YARN Application Master process in `cluster` mode. See the [YARN-related Spark Properties](running-on-yarn.html#spark-properties) for more information. + # Configuring Logging Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can configure it by adding a @@ -1672,7 +1679,7 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config # Overriding configuration directory To specify a different configuration directory other than the default "SPARK_HOME/conf", -you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) +you can set SPARK_CONF_DIR. Spark will use the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. # Inheriting Hadoop Cluster Configuration diff --git a/docs/css/main.css b/docs/css/main.css index d770173be1014..175e8004fca0e 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -39,8 +39,15 @@ margin-left: 10px; } -body #content { - line-height: 1.6; /* Inspired by Github's wiki style */ +body .container-wrapper { + background-color: #FFF; + color: #1D1F22; + max-width: 1024px; + margin-top: 10px; + margin-left: auto; + margin-right: auto; + border-radius: 15px; + position: relative; } .title { @@ -91,6 +98,24 @@ a:hover code { max-width: 914px; } +.content { + z-index: 1; + position: relative; + background-color: #FFF; + max-width: 914px; + line-height: 1.6; /* Inspired by Github's wiki style */ + padding-left: 15px; +} + +.content-with-sidebar { + z-index: 1; + position: relative; + background-color: #FFF; + max-width: 914px; + line-height: 1.6; /* Inspired by Github's wiki style */ + padding-left: 30px; +} + .dropdown-menu { /* Remove the default 2px top margin which causes a small gap between the hover trigger area and the popup menu */ @@ -155,3 +180,110 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { * AnchorJS (anchor links when hovering over headers) */ a.anchorjs-link:hover { text-decoration: none; } + + +/** + * The left navigation bar. + */ +.left-menu-wrapper { + margin-left: 0px; + margin-right: 0px; + background-color: #F0F8FC; + border-top-width: 0px; + border-left-width: 0px; + border-bottom-width: 0px; + margin-top: 0px; + width: 210px; + float: left; + position: absolute; +} + +.left-menu { + padding: 0px; + width: 199px; +} + +.left-menu h3 { + margin-left: 10px; + line-height: 30px; +} + +/** + * The collapsing button for the navigation bar. + */ +.nav-trigger { + position: fixed; + clip: rect(0, 0, 0, 0); +} + +.nav-trigger + label:after { + content: '»'; +} + +label { + z-index: 10; +} + +label[for="nav-trigger"] { + position: fixed; + margin-left: 0px; + padding-top: 100px; + padding-left: 5px; + width: 10px; + height: 80%; + cursor: pointer; + background-size: contain; + background-color: #D4F0FF; +} + +label[for="nav-trigger"]:hover { + background-color: #BEE9FF; +} + +.nav-trigger:checked + label { + margin-left: 200px; +} + +.nav-trigger:checked + label:after { + content: '«'; +} + +.nav-trigger:checked ~ div.content-with-sidebar { + margin-left: 200px; +} + +.nav-trigger + label, div.content-with-sidebar { + transition: left 0.4s; +} + +/** + * Rules to collapse the menu automatically when the screen becomes too thin. + */ + +@media all and (max-width: 780px) { + + div.content-with-sidebar { + margin-left: 200px; + } + .nav-trigger + label:after { + content: '«'; + } + label[for="nav-trigger"] { + margin-left: 200px; + } + + .nav-trigger:checked + label { + margin-left: 0px; + } + .nav-trigger:checked + label:after { + content: '»'; + } + .nav-trigger:checked ~ div.content-with-sidebar { + margin-left: 0px; + } + + div.container-index { + margin-left: -215px; + } + +} \ No newline at end of file diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md deleted file mode 100644 index 7f60f82b966fe..0000000000000 --- a/docs/ec2-scripts.md +++ /dev/null @@ -1,192 +0,0 @@ ---- -layout: global -title: Running Spark on EC2 ---- - -The `spark-ec2` script, located in Spark's `ec2` directory, allows you -to launch, manage and shut down Spark clusters on Amazon EC2. It automatically -sets up Spark and HDFS on the cluster for you. This guide describes -how to use `spark-ec2` to launch clusters, how to run jobs on them, and how -to shut them down. It assumes you've already signed up for an EC2 account -on the [Amazon Web Services site](http://aws.amazon.com/). - -`spark-ec2` is designed to manage multiple named clusters. You can -launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster is -identified by placing its machines into EC2 security groups whose names -are derived from the name of the cluster. For example, a cluster named -`test` will contain a master node in a security group called -`test-master`, and a number of slave nodes in a security group called -`test-slaves`. The `spark-ec2` script will create these security groups -for you based on the cluster name you request. You can also use them to -identify machines belonging to each cluster in the Amazon EC2 Console. - - -# Before You Start - -- Create an Amazon EC2 key pair for yourself. This can be done by - logging into your Amazon Web Services account through the [AWS - console](http://aws.amazon.com/console/), clicking Key Pairs on the - left sidebar, and creating and downloading a key. Make sure that you - set the permissions for the private key file to `600` (i.e. only you - can read and write it) so that `ssh` will work. -- Whenever you want to use the `spark-ec2` script, set the environment - variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to your - Amazon EC2 access key ID and secret access key. These can be - obtained from the [AWS homepage](http://aws.amazon.com/) by clicking - Account \> Security Credentials \> Access Credentials. - -# Launching a Cluster - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run - `./spark-ec2 -k -i -s launch `, - where `` is the name of your EC2 key pair (that you gave it - when you created it), `` is the private key file for your - key pair, `` is the number of slave nodes to launch (try - 1 at first), and `` is the name to give to your - cluster. - - For example: - - ```bash - export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU -export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a launch my-spark-cluster - ``` - -- After everything launches, check that the cluster scheduler is up and sees - all the slaves by going to its web UI, which will be printed at the end of - the script (typically `http://:8080`). - -You can also run `./spark-ec2 --help` to see more usage options. The -following options are worth pointing out: - -- `--instance-type=` can be used to specify an EC2 -instance type to use. For now, the script only supports 64-bit instance -types, and the default type is `m1.large` (which has 2 cores and 7.5 GB -RAM). Refer to the Amazon pages about [EC2 instance -types](http://aws.amazon.com/ec2/instance-types) and [EC2 -pricing](http://aws.amazon.com/ec2/#pricing) for information about other -instance types. -- `--region=` specifies an EC2 region in which to launch -instances. The default region is `us-east-1`. -- `--zone=` can be used to specify an EC2 availability zone -to launch instances in. Sometimes, you will get an error because there -is not enough capacity in one zone, and you should try to launch in -another. -- `--ebs-vol-size=` will attach an EBS volume with a given amount - of space to each node so that you can have a persistent HDFS cluster - on your nodes across cluster restarts (see below). -- `--spot-price=` will launch the worker nodes as - [Spot Instances](http://aws.amazon.com/ec2/spot-instances/), - bidding for the given maximum price (in dollars). -- `--spark-version=` will pre-load the cluster with the - specified version of Spark. The `` can be a version number - (e.g. "0.7.3") or a specific git hash. By default, a recent - version will be used. -- `--spark-git-repo=` will let you run a custom version of - Spark that is built from the given git repository. By default, the - [Apache Github mirror](https://github.com/apache/spark) will be used. - When using a custom Spark version, `--spark-version` must be set to git - commit hash, such as 317e114, instead of a version number. -- If one of your launches fails due to e.g. not having the right -permissions on your private key file, you can run `launch` with the -`--resume` option to restart the setup process on an existing cluster. - -# Launching a Cluster in a VPC - -- Run - `./spark-ec2 -k -i -s --vpc-id= --subnet-id= launch `, - where `` is the name of your EC2 key pair (that you gave it - when you created it), `` is the private key file for your - key pair, `` is the number of slave nodes to launch (try - 1 at first), `` is the name of your VPC, `` is the - name of your subnet, and `` is the name to give to your - cluster. - - For example: - - ```bash - export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU -export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --vpc-id=vpc-a28d24c7 --subnet-id=subnet-4eb27b39 --spark-version=1.1.0 launch my-spark-cluster - ``` - -# Running Applications - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run `./spark-ec2 -k -i login ` to - SSH into the cluster, where `` and `` are as - above. (This is just for convenience; you could also use - the EC2 console.) -- To deploy code or data within your cluster, you can log in and use the - provided script `~/spark-ec2/copy-dir`, which, - given a directory path, RSYNCs it to the same location on all the slaves. -- If your application needs to access large datasets, the fastest way to do - that is to load them from Amazon S3 or an Amazon EBS device into an - instance of the Hadoop Distributed File System (HDFS) on your nodes. - The `spark-ec2` script already sets up a HDFS instance for you. It's - installed in `/root/ephemeral-hdfs`, and can be accessed using the - `bin/hadoop` script in that directory. Note that the data in this - HDFS goes away when you stop and restart a machine. -- There is also a *persistent HDFS* instance in - `/root/persistent-hdfs` that will keep data across cluster restarts. - Typically each node has relatively little space of persistent data - (about 3 GB), but you can use the `--ebs-vol-size` option to - `spark-ec2` to attach a persistent EBS volume to each node for - storing the persistent HDFS. -- Finally, if you get errors while running your application, look at the slave's logs - for that application inside of the scheduler work directory (/root/spark/work). You can - also view the status of the cluster using the web UI: `http://:8080`. - -# Configuration - -You can edit `/root/spark/conf/spark-env.sh` on each machine to set Spark configuration options, such -as JVM options. This file needs to be copied to **every machine** to reflect the change. The easiest way to -do this is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master, -then run `~/spark-ec2/copy-dir /root/spark/conf` to RSYNC it to all the workers. - -The [configuration guide](configuration.html) describes the available configuration options. - -# Terminating a Cluster - -***Note that there is no way to recover data on EC2 nodes after shutting -them down! Make sure you have copied everything important off the nodes -before stopping them.*** - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run `./spark-ec2 destroy `. - -# Pausing and Restarting Clusters - -The `spark-ec2` script also supports pausing a cluster. In this case, -the VMs are stopped but not terminated, so they -***lose all data on ephemeral disks*** but keep the data in their -root partitions and their `persistent-hdfs`. Stopped machines will not -cost you any EC2 cycles, but ***will*** continue to cost money for EBS -storage. - -- To stop one of your clusters, go into the `ec2` directory and run -`./spark-ec2 --region= stop `. -- To restart it later, run -`./spark-ec2 -i --region= start `. -- To ultimately destroy the cluster and stop consuming EBS space, run -`./spark-ec2 --region= destroy ` as described in the previous -section. - -# Limitations - -- Support for "cluster compute" nodes is limited -- there's no way to specify a - locality group. However, you can launch slave nodes in your - `-slaves` group manually and then use `spark-ec2 launch - --resume` to start a cluster with them. - -If you have a patch or suggestion for one of these limitations, feel free to -[contribute](contributing-to-spark.html) it! - -# Accessing Data in S3 - -Spark's file interface allows it to process data in Amazon S3 using the same URI formats that are supported for Hadoop. You can specify a path in S3 as input through a URI of the form `s3n:///path`. To provide AWS credentials for S3 access, launch the Spark cluster with the option `--copy-aws-credentials`. Full instructions on S3 access using the Hadoop input libraries can be found on the [Hadoop S3 page](http://wiki.apache.org/hadoop/AmazonS3). - -In addition to using a single input file, you can also use a directory of files as input by simply giving the path to the directory. diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 6a512ab234bb2..9dea9b5904d2d 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -70,7 +70,7 @@ operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operato ## Migrating from Spark 1.1 -GraphX in Spark {{site.SPARK_VERSION}} contains a few user facing API changes: +GraphX in Spark 1.2 contains a few user facing API changes: 1. To improve performance we have introduced a new version of [`mapReduceTriplets`][Graph.mapReduceTriplets] called diff --git a/docs/hardware-provisioning.md b/docs/hardware-provisioning.md index 790220500a1b3..60ecb4f483afa 100644 --- a/docs/hardware-provisioning.md +++ b/docs/hardware-provisioning.md @@ -63,7 +63,7 @@ from the application's monitoring UI (`http://:4040`). # CPU Cores -Spark scales well to tens of CPU cores per machine because it performes minimal sharing between +Spark scales well to tens of CPU cores per machine because it performs minimal sharing between threads. You should likely provision at least **8-16 cores** per machine. Depending on the CPU cost of your workload, you may also need more: once data is in memory, most applications are either CPU- or network-bound. diff --git a/docs/index.md b/docs/index.md index f1d9e012c6cf0..20eab567a50df 100644 --- a/docs/index.md +++ b/docs/index.md @@ -64,7 +64,7 @@ To run Spark interactively in a R interpreter, use `bin/sparkR`: ./bin/sparkR --master local[2] Example applications are also provided in R. For example, - + ./bin/spark-submit examples/src/main/r/dataframe.R # Launching on a Cluster @@ -73,7 +73,6 @@ The Spark [cluster mode overview](cluster-overview.html) explains the key concep Spark can run both by itself, or over several existing cluster managers. It currently provides several options for deployment: -* [Amazon EC2](ec2-scripts.html): our EC2 scripts let you launch a cluster in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): simplest way to deploy Spark on a private cluster * [Apache Mesos](running-on-mesos.html) * [Hadoop YARN](running-on-yarn.html) @@ -87,7 +86,7 @@ options for deployment: in all supported languages (Scala, Java, Python, R) * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams - * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries + * [Spark SQL, Datasets, and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing @@ -103,7 +102,7 @@ options for deployment: * [Cluster Overview](cluster-overview.html): overview of concepts and components when running on a cluster * [Submitting Applications](submitting-applications.html): packaging and deploying applications * Deployment modes: - * [Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes + * [Amazon EC2](https://github.com/amplab/spark-ec2): scripts that let you launch a cluster on EC2 in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager * [Mesos](running-on-mesos.html): deploy a private cluster using [Apache Mesos](http://mesos.apache.org) @@ -131,8 +130,8 @@ options for deployment: * [StackOverflow tag `apache-spark`](http://stackoverflow.com/questions/tagged/apache-spark) * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and - exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), - [slides](http://ampcamp.berkeley.edu/3/) and [exercises](http://ampcamp.berkeley.edu/3/exercises/) are + exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/6/), + [slides](http://ampcamp.berkeley.edu/6/) and [exercises](http://ampcamp.berkeley.edu/6/exercises/) are available online for free. * [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 8d9c2ba2041b2..083c020caa5db 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -39,7 +39,10 @@ Resource allocation can be configured as follows, based on the cluster type: and optionally set `spark.cores.max` to limit each application's resource share as in the standalone mode. You should also set `spark.executor.memory` to control the executor memory. * **YARN:** The `--num-executors` option to the Spark YARN client controls how many executors it will allocate - on the cluster, while `--executor-memory` and `--executor-cores` control the resources per executor. + on the cluster (`spark.executor.instances` as configuration property), while `--executor-memory` + (`spark.executor.memory` configuration property) and `--executor-cores` (`spark.executor.cores` configuration + property) control the resources per executor. For more information, see the + [YARN Spark Properties](running-on-yarn.html). A second option available on Mesos is _dynamic sharing_ of CPU cores. In this mode, each Spark application still has a fixed and independent memory allocation (set by `spark.executor.memory`), but when the @@ -47,58 +50,56 @@ application is not running tasks on a machine, other applications may run tasks is useful when you expect large numbers of not overly active applications, such as shell sessions from separate users. However, it comes with a risk of less predictable latency, because it may take a while for an application to gain back cores on one node when it has work to do. To use this mode, simply use a -`mesos://` URL without setting `spark.mesos.coarse` to true. +`mesos://` URL and set `spark.mesos.coarse` to false. Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying -the same RDDs. In future releases, in-memory storage systems such as [Tachyon](http://tachyon-project.org) will -provide another approach to share RDDs. +the same RDDs. ## Dynamic Resource Allocation -Spark 1.2 introduces the ability to dynamically scale the set of cluster resources allocated to -your application up and down based on the workload. This means that your application may give -resources back to the cluster if they are no longer used and request them again later when there -is demand. This feature is particularly useful if multiple applications share resources in your -Spark cluster. If a subset of the resources allocated to an application becomes idle, it can be -returned to the cluster's pool of resources and acquired by other applications. In Spark, dynamic -resource allocation is performed on the granularity of the executor and can be enabled through -`spark.dynamicAllocation.enabled`. - -This feature is currently disabled by default and available only on [YARN](running-on-yarn.html). -A future release will extend this to [standalone mode](spark-standalone.html) and -[Mesos coarse-grained mode](running-on-mesos.html#mesos-run-modes). Note that although Spark on -Mesos already has a similar notion of dynamic resource sharing in fine-grained mode, enabling -dynamic allocation allows your Mesos application to take advantage of coarse-grained low-latency -scheduling while sharing cluster resources efficiently. +Spark provides a mechanism to dynamically adjust the resources your application occupies based +on the workload. This means that your application may give resources back to the cluster if they +are no longer used and request them again later when there is demand. This feature is particularly +useful if multiple applications share resources in your Spark cluster. + +This feature is disabled by default and available on all coarse-grained cluster managers, i.e. +[standalone mode](spark-standalone.html), [YARN mode](running-on-yarn.html), and +[Mesos coarse-grained mode](running-on-mesos.html#mesos-run-modes). ### Configuration and Setup -All configurations used by this feature live under the `spark.dynamicAllocation.*` namespace. -To enable this feature, your application must set `spark.dynamicAllocation.enabled` to `true`. -Other relevant configurations are described on the -[configurations page](configuration.html#dynamic-allocation) and in the subsequent sections in -detail. +There are two requirements for using this feature. First, your application must set +`spark.dynamicAllocation.enabled` to `true`. Second, you must set up an *external shuffle service* +on each worker node in the same cluster and set `spark.shuffle.service.enabled` to true in your +application. The purpose of the external shuffle service is to allow executors to be removed +without deleting shuffle files written by them (more detail described +[below](job-scheduling.html#graceful-decommission-of-executors)). The way to set up this service +varies across cluster managers: + +In standalone mode, simply start your workers with `spark.shuffle.service.enabled` set to `true`. -Additionally, your application must use an external shuffle service. The purpose of the service is -to preserve the shuffle files written by executors so the executors can be safely removed (more -detail described [below](job-scheduling.html#graceful-decommission-of-executors)). To enable -this service, set `spark.shuffle.service.enabled` to `true`. In YARN, this external shuffle service -is implemented in `org.apache.spark.yarn.network.YarnShuffleService` that runs in each `NodeManager` -in your cluster. To start this service, follow these steps: +In Mesos coarse-grained mode, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all +slave nodes with `spark.shuffle.service.enabled` set to `true`. For instance, you may do so +through Marathon. + +In YARN mode, start the shuffle service on each `NodeManager` as follows: 1. Build Spark with the [YARN profile](building-spark.html). Skip this step if you are using a pre-packaged distribution. 2. Locate the `spark--yarn-shuffle.jar`. This should be under -`$SPARK_HOME/network/yarn/target/scala-` if you are building Spark yourself, and under +`$SPARK_HOME/common/network-yarn/target/scala-` if you are building Spark yourself, and under `lib` if you are using a distribution. 2. Add this jar to the classpath of all `NodeManager`s in your cluster. 3. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to -`org.apache.spark.network.yarn.YarnShuffleService`. Additionally, set all relevant -`spark.shuffle.service.*` [configurations](configuration.html). +`org.apache.spark.network.yarn.YarnShuffleService`. 4. Restart all `NodeManager`s in your cluster. +All other relevant configurations are optional and under the `spark.dynamicAllocation.*` and +`spark.shuffle.service.*` namespaces. For more detail, see the +[configurations page](configuration.html#dynamic-allocation). + ### Resource Allocation Policy At a high level, Spark should relinquish executors when they are no longer used and acquire diff --git a/docs/js/main.js b/docs/js/main.js index f5d66b16f7b21..2329eb8327dd5 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -83,7 +83,7 @@ $(function() { // Display anchor links when hovering over headers. For documentation of the // configuration options, see the AnchorJS documentation. anchors.options = { - placement: 'left' + placement: 'right' }; anchors.add(); diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md new file mode 100644 index 0000000000000..91731d78a2d43 --- /dev/null +++ b/docs/ml-advanced.md @@ -0,0 +1,13 @@ +--- +layout: global +title: Advanced topics - spark.ml +displayTitle: Advanced topics - spark.ml +--- + +# Optimization of linear methods + +The optimization algorithm underlying the implementation is called +[Orthant-Wise Limited-memory +QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 +regularization and elastic net. diff --git a/docs/ml-ann.md b/docs/ml-ann.md index d5ddd92af1e96..c2d9bd200f62f 100644 --- a/docs/ml-ann.md +++ b/docs/ml-ann.md @@ -1,123 +1,8 @@ --- layout: global -title: Multilayer perceptron classifier - ML -displayTitle: ML - Multilayer perceptron classifier +title: Multilayer perceptron classifier - spark.ml +displayTitle: Multilayer perceptron classifier - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). -MLPC consists of multiple layers of nodes. -Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs -by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. -It can be written in matrix form for MLPC with `$K+1$` layers as follows: -`\[ -\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) -\]` -Nodes in intermediate layers use sigmoid (logistic) function: -`\[ -\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} -\]` -Nodes in the output layer use softmax function: -`\[ -\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} -\]` -The number of nodes `$N$` in the output layer corresponds to the number of classes. - -MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. - -**Examples** - -
    - -
    - -{% highlight scala %} -import org.apache.spark.ml.classification.MultilayerPerceptronClassifier -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.Row - -// Load training data -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt").toDF() -// Split the data into train and test -val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) -val train = splits(0) -val test = splits(1) -// specify layers for the neural network: -// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) -val layers = Array[Int](4, 5, 4, 3) -// create the trainer and set its parameters -val trainer = new MultilayerPerceptronClassifier() - .setLayers(layers) - .setBlockSize(128) - .setSeed(1234L) - .setMaxIter(100) -// train the model -val model = trainer.fit(train) -// compute precision on the test set -val result = model.transform(test) -val predictionAndLabels = result.select("prediction", "label") -val evaluator = new MulticlassClassificationEvaluator() - .setMetricName("precision") -println("Precision:" + evaluator.evaluate(predictionAndLabels)) -{% endhighlight %} - -
    - -
    - -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; -import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; - -// Load training data -String path = "data/mllib/sample_multiclass_classification_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); -DataFrame dataFrame = sqlContext.createDataFrame(data, LabeledPoint.class); -// Split the data into train and test -DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); -DataFrame train = splits[0]; -DataFrame test = splits[1]; -// specify layers for the neural network: -// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) -int[] layers = new int[] {4, 5, 4, 3}; -// create the trainer and set its parameters -MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() - .setLayers(layers) - .setBlockSize(128) - .setSeed(1234L) - .setMaxIter(100); -// train the model -MultilayerPerceptronClassificationModel model = trainer.fit(train); -// compute precision on the test set -DataFrame result = model.transform(test); -DataFrame predictionAndLabels = result.select("prediction", "label"); -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setMetricName("precision"); -System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); -{% endhighlight %} -
    - -
    + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#multilayer-perceptron-classifier). diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md new file mode 100644 index 0000000000000..eaf4f6d843368 --- /dev/null +++ b/docs/ml-classification-regression.md @@ -0,0 +1,819 @@ +--- +layout: global +title: Classification and regression - spark.ml +displayTitle: Classification and regression - spark.ml +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +In `spark.ml`, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods in mllib](mllib-linear-methods.html) for +details about implementation and tuning. We also include a DataFrame API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: +`\[ +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 +\]` +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. +We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. + + +# Classification + +## Logistic regression + +Logistic regression is a popular method to predict a binary response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcome. +For more background and more details about the implementation, refer to the documentation of the [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). + + > The current implementation of logistic regression in `spark.ml` only supports binary classes. Support for multiclass regression will be added in the future. + +**Example** + +The following example shows how to train a logistic regression model +with elastic net regularization. `elasticNetParam` corresponds to +$\alpha$ and `regParam` corresponds to $\lambda$. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %} +
    + +
    +{% include_example python/ml/logistic_regression_with_elastic_net.py %} +
    + +
    + +The `spark.ml` implementation of logistic regression also supports +extracting a summary of the model over the training set. Note that the +predictions and metrics which are stored as `DataFrame` in +`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +only available on the driver. + +
    + +
    + +[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %} +
    + +
    +[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) +provides a summary for a +[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} +
    + + +
    +Logistic regression model summary is not yet supported in Python. +
    + +
    + + +## Decision tree classifier + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %} + +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %} + +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). + +{% include_example python/ml/decision_tree_classification_example.py %} + +
    + +
    + +## Random forest classifier + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. + +{% include_example python/ml/random_forest_classifier_example.py %} +
    +
    + +## Gradient-boosted tree classifier + +Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. + +{% include_example python/ml/gradient_boosted_tree_classifier_example.py %} +
    +
    + +## Multilayer perceptron classifier + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). +MLPC consists of multiple layers of nodes. +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs +by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. +It can be written in matrix form for MLPC with `$K+1$` layers as follows: +`\[ +\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) +\]` +Nodes in intermediate layers use sigmoid (logistic) function: +`\[ +\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} +\]` +Nodes in the output layer use softmax function: +`\[ +\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} +\]` +The number of nodes `$N$` in the output layer corresponds to the number of classes. + +MLPC employs backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. + +**Example** + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} +
    + +
    +{% include_example python/ml/multilayer_perceptron_classification.py %} +
    + +
    + + +## One-vs-Rest classifier (a.k.a. One-vs-All) + +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." + +`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. + +Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. + +**Example** + +The example below demonstrates how to load the +[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.OneVsRest) for more details. + +{% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %} +
    +
    + +## Naive Bayes + +[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple +probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence +assumptions between the features. The spark.ml implementation currently supports both [multinomial +naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). + +**Example** + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes) for more details. + +{% include_example scala/org/apache/spark/examples/ml/NaiveBayesExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/NaiveBayes.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes) for more details. + +{% include_example python/ml/naive_bayes_example.py %} +
    +
    + + +# Regression + +## Linear regression + +The interface for working with linear regression models and model +summaries is similar to the logistic regression case. + +**Example** + +The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %} +
    + +
    + +{% include_example python/ml/linear_regression_with_elastic_net.py %} +
    + +
    + + +## Decision tree regression + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %} +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %} +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). + +{% include_example python/ml/decision_tree_regression_example.py %} +
    + +
    + + +## Random forest regression + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. + +{% include_example python/ml/random_forest_regressor_example.py %} +
    +
    + +## Gradient-boosted tree regression + +Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not +be true in general. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. + +{% include_example python/ml/gradient_boosted_tree_regressor_example.py %} +
    +
    + + +## Survival regression + + +In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) +model which is a parametric survival regression model for censored data. +It describes a model for the log of survival time, so it's often called +log-linear model for survival analysis. Different from +[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model +designed for the same purpose, the AFT model is more easily to parallelize +because each instance contribute to the objective function independently. + +Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of +subjects i = 1, ..., n, with possible right-censoring, +the likelihood function under the AFT model is given as: +`\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]` +Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. +Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +`\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]` +Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, +and $f_{0}(\epsilon_{i})$ is corresponding density function. + +The most commonly used AFT model is based on the Weibull distribution of the survival time. +The Weibull distribution for lifetime corresponding to extreme value distribution for +log of the lifetime, and the $S_{0}(\epsilon)$ function is: +`\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]` +the $f_{0}(\epsilon_{i})$ function is: +`\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]` +The log-likelihood function for AFT model with Weibull distribution of lifetime is: +`\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]` +Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, +the loss function we use to optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +`\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]` +`\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]` + +The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ +that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. +The implementation matches the result from R's survival function +[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) + +**Example** + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} +
    + +
    +{% include_example python/ml/aft_survival_regression.py %} +
    + +
    + + + +# Decision trees + +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical features, extend to the multiclass classification setting, do not require +feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble +algorithms such as random forests and boosting are among the top performers for classification and +regression tasks. + +The `spark.ml` implementation supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions or even billions of instances. + +Users can find more information about the decision tree algorithm in the [MLlib Decision Tree guide](mllib-decision-tree.html). +The main differences between this API and the [original MLlib Decision Tree API](mllib-decision-tree.html) are: + +* support for ML Pipelines +* separation of Decision Trees for classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features + + +The Pipelines API for Decision Trees offers a bit more functionality than the original API. +In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities); +for regression, users can get the biased sample variance of prediction. + +Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the [Tree ensembles section](#tree-ensembles). + +## Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +### Output Columns + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    varianceColDoubleThe biased sample variance of predictionRegression only
    + + +# Tree Ensembles + +The DataFrame API supports two major tree ensemble algorithms: [Random Forests](http://en.wikipedia.org/wiki/Random_forest) and [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting). +Both use [`spark.ml` decision trees](ml-classification-regression.html#decision-trees) as their base models. + +Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). +In this section, we demonstrate the DataFrame API for ensembles. + +The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: + +* support for DataFrames and ML Pipelines +* separation of classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features +* more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification. + +## Random Forests + +[Random forests](http://en.wikipedia.org/wiki/Random_forest) +are ensembles of [decision trees](ml-decision-tree.html). +Random forests combine many decision trees in order to reduce the risk of overfitting. +The `spark.ml` implementation supports random forests for binary and multiclass classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +#### Output Columns (Predictions) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    + + + +## Gradient-Boosted Trees (GBTs) + +[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) +are ensembles of [decision trees](ml-decision-tree.html). +GBTs iteratively train decision trees in order to minimize a loss function. +The `spark.ml` implementation supports GBTs for binary classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    + +Note that `GBTClassifier` currently only supports binary labels. + +#### Output Columns (Predictions) + + + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    + +In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does. + diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 0000000000000..440c455cd077c --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,107 @@ +--- +layout: global +title: Clustering - spark.ml +displayTitle: Clustering - spark.ml +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +## K-means + +[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +most commonly used clustering algorithms that clusters the data points into a +predefined number of clusters. The MLlib implementation includes a parallelized +variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method +called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). + +`KMeans` is implemented as an `Estimator` and generates a `KMeansModel` as the base model. + +### Input Columns + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    featuresColVector"features"Feature vector
    + +### Output Columns + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    predictionColInt"prediction"Predicted cluster center
    + +### Example + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.KMeans) for more details. + +{% include_example scala/org/apache/spark/examples/ml/KMeansExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %} +
    + +
    + + +## Latent Dirichlet allocation (LDA) + +`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by +`EMLDAOptimizer` to a `DistributedLDAModel` if needed. + +
    + +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. + +{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} +
    + +
    \ No newline at end of file diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md new file mode 100644 index 0000000000000..4514a358e12f2 --- /dev/null +++ b/docs/ml-collaborative-filtering.md @@ -0,0 +1,148 @@ +--- +layout: global +title: Collaborative Filtering - spark.ml +displayTitle: Collaborative Filtering - spark.ml +--- + +* Table of contents +{:toc} + +## Collaborative filtering + +[Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) +is commonly used for recommender systems. These techniques aim to fill in the +missing entries of a user-item association matrix. `spark.ml` currently supports +model-based collaborative filtering, in which users and products are described +by a small set of latent factors that can be used to predict missing entries. +`spark.ml` uses the [alternating least squares +(ALS)](http://dl.acm.org/citation.cfm?id=1608614) +algorithm to learn these latent factors. The implementation in `spark.ml` has the +following parameters: + +* *numBlocks* is the number of blocks the users and items will be partitioned into in order to parallelize computation (defaults to 10). +* *rank* is the number of latent factors in the model (defaults to 10). +* *maxIter* is the maximum number of iterations to run (defaults to 10). +* *regParam* specifies the regularization parameter in ALS (defaults to 1.0). +* *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for + *implicit feedback* data (defaults to `false` which means using *explicit feedback*). +* *alpha* is a parameter applicable to the implicit feedback variant of ALS that governs the + *baseline* confidence in preference observations (defaults to 1.0). +* *nonnegative* specifies whether or not to use nonnegative constraints for least squares (defaults to `false`). + +### Explicit vs. implicit feedback + +The standard approach to matrix factorization based collaborative filtering treats +the entries in the user-item matrix as *explicit* preferences given by the user to the item, +for example, users giving ratings to movies. + +It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, +clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken +from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data +as numbers representing the *strength* in observations of user actions (such as the number of clicks, +or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of +confidence in observed user preferences, rather than explicit ratings given to items. The model +then tries to find latent factors that can be used to predict the expected preference of a user for +an item. + +### Scaling of the regularization parameter + +We scale the regularization parameter `regParam` in solving each least squares problem by +the number of ratings the user generated in updating user factors, +or the number of ratings the product received in updating product factors. +This approach is named "ALS-WR" and discussed in the paper +"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". +It makes `regParam` less dependent on the scale of the dataset, so we can apply the +best parameter learned from a sampled subset to the full dataset and expect similar performance. + +## Examples + +
    +
    + +In the following example, we load rating data from the +[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row +consisting of a user, a movie, a rating and a timestamp. +We then train an ALS model which assumes, by default, that the ratings are +explicit (`implicitPrefs` is `false`). +We evaluate the recommendation model by measuring the root-mean-square error of +rating prediction. + +Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.ml.recommendation.ALS) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ALSExample.scala %} + +If the rating matrix is derived from another source of information (i.e. it is +inferred from other signals), you can set `implicitPrefs` to `true` to get +better results: + +{% highlight scala %} +val als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setImplicitPrefs(true) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating") +{% endhighlight %} + +
    + +
    + +In the following example, we load rating data from the +[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row +consisting of a user, a movie, a rating and a timestamp. +We then train an ALS model which assumes, by default, that the ratings are +explicit (`implicitPrefs` is `false`). +We evaluate the recommendation model by measuring the root-mean-square error of +rating prediction. + +Refer to the [`ALS` Java docs](api/java/org/apache/spark/ml/recommendation/ALS.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaALSExample.java %} + +If the rating matrix is derived from another source of information (i.e. it is +inferred from other signals), you can set `implicitPrefs` to `true` to get +better results: + +{% highlight java %} +ALS als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setImplicitPrefs(true) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating"); +{% endhighlight %} + +
    + +
    + +In the following example, we load rating data from the +[MovieLens dataset](http://grouplens.org/datasets/movielens/), each row +consisting of a user, a movie, a rating and a timestamp. +We then train an ALS model which assumes, by default, that the ratings are +explicit (`implicitPrefs` is `False`). +We evaluate the recommendation model by measuring the root-mean-square error of +rating prediction. + +Refer to the [`ALS` Python docs](api/python/pyspark.ml.html#pyspark.ml.recommendation.ALS) +for more details on the API. + +{% include_example python/ml/als_example.py %} + +If the rating matrix is derived from another source of information (i.e. it is +inferred from other signals), you can set `implicitPrefs` to `True` to get +better results: + +{% highlight python %} +als = ALS(maxIter=5, regParam=0.01, implicitPrefs=True, + userCol="userId", itemCol="movieId", ratingCol="rating") +{% endhighlight %} + +
    +
    diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md index 542819e93e6dc..a721d55bc675b 100644 --- a/docs/ml-decision-tree.md +++ b/docs/ml-decision-tree.md @@ -1,493 +1,8 @@ --- layout: global -title: Decision Trees - SparkML -displayTitle: ML - Decision Trees +title: Decision trees - spark.ml +displayTitle: Decision trees - spark.ml --- -**Table of Contents** - -* This will become a table of contents (this text will be scraped). -{:toc} - - -# Overview - -[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) -and their ensembles are popular methods for the machine learning tasks of -classification and regression. Decision trees are widely used since they are easy to interpret, -handle categorical features, extend to the multiclass classification setting, do not require -feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble -algorithms such as random forests and boosting are among the top performers for classification and -regression tasks. - -MLlib supports decision trees for binary and multiclass classification and for regression, -using both continuous and categorical features. The implementation partitions data by rows, -allowing distributed training with millions or even billions of instances. - -Users can find more information about the decision tree algorithm in the [MLlib Decision Tree guide](mllib-decision-tree.html). In this section, we demonstrate the Pipelines API for Decision Trees. - -The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). - -Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described in the [Ensembles guide](ml-ensembles.html). - -# Inputs and Outputs - -We list the input and output (prediction) column types here. -All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. - -## Input Columns - - - - - - - - - - - - - - - - - - - - - - - - -
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    - -## Output Columns - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    - -# Examples - -The below examples demonstrate the Pipelines API for Decision Trees. The main differences between this API and the [original MLlib Decision Tree API](mllib-decision-tree.html) are: - -* support for ML Pipelines -* separation of Decision Trees for classification vs. regression -* use of DataFrame metadata to distinguish continuous and categorical features - - -## Classification - -The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. -We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. - -
    -
    - -More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). - -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.DecisionTreeClassifier -import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -val labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data) -// Automatically identify categorical features, and index them. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a DecisionTree model. -val dt = new DecisionTreeClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - -// Convert indexed labels back to original labels. -val labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels) - -// Chain indexers and tree in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) - -// Train model. This also runs the indexers. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision") -val accuracy = evaluator.evaluate(predictions) -println("Test Error = " + (1.0 - accuracy)) - -val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] -println("Learned classification tree model:\n" + treeModel.toDebugString) -{% endhighlight %} -
    - -
    - -More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). - -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.DecisionTreeClassifier; -import org.apache.spark.ml.classification.DecisionTreeClassificationModel; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); -// Automatically identify categorical features, and index them. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a DecisionTree model. -DecisionTreeClassifier dt = new DecisionTreeClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures"); - -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and tree in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {labelIndexer, featureIndexer, dt, labelConverter}); - -// Train model. This also runs the indexers. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision"); -double accuracy = evaluator.evaluate(predictions); -System.out.println("Test Error = " + (1.0 - accuracy)); - -DecisionTreeClassificationModel treeModel = - (DecisionTreeClassificationModel)(model.stages()[2]); -System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); -{% endhighlight %} -
    - -
    - -More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). - -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import DecisionTreeClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer -from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.mllib.util import MLUtils - -# Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) -# Automatically identify categorical features, and index them. -# We specify maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") - -# Chain indexers and tree in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) - -# Train model. This also runs the indexers. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") -accuracy = evaluator.evaluate(predictions) -print "Test Error = %g" % (1.0 - accuracy) - -treeModel = model.stages[2] -print treeModel # summary only -{% endhighlight %} -
    - -
    - - -## Regression - -The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. -We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. - -
    -
    - -More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). - -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.DecisionTreeRegressor -import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -// Automatically identify categorical features, and index them. -// Here, we treat features with > 4 distinct values as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a DecisionTree model. -val dt = new DecisionTreeRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - -// Chain indexer and tree in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(featureIndexer, dt)) - -// Train model. This also runs the indexer. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse") -val rmse = evaluator.evaluate(predictions) -println("Root Mean Squared Error (RMSE) on test data = " + rmse) - -val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] -println("Learned regression tree model:\n" + treeModel.toDebugString) -{% endhighlight %} -
    - -
    - -More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). - -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.ml.regression.DecisionTreeRegressionModel; -import org.apache.spark.ml.regression.DecisionTreeRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a DecisionTree model. -DecisionTreeRegressor dt = new DecisionTreeRegressor() - .setFeaturesCol("indexedFeatures"); - -// Chain indexer and tree in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {featureIndexer, dt}); - -// Train model. This also runs the indexer. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("label", "features").show(5); - -// Select (prediction, true label) and compute test error -RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse"); -double rmse = evaluator.evaluate(predictions); -System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); - -DecisionTreeRegressionModel treeModel = - (DecisionTreeRegressionModel)(model.stages()[1]); -System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); -{% endhighlight %} -
    - -
    - -More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). - -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.regression import DecisionTreeRegressor -from pyspark.ml.feature import VectorIndexer -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.mllib.util import MLUtils - -# Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -# Automatically identify categorical features, and index them. -# We specify maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -dt = DecisionTreeRegressor(featuresCol="indexedFeatures") - -# Chain indexer and tree in a Pipeline -pipeline = Pipeline(stages=[featureIndexer, dt]) - -# Train model. This also runs the indexer. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = RegressionEvaluator( - labelCol="label", predictionCol="prediction", metricName="rmse") -rmse = evaluator.evaluate(predictions) -print "Root Mean Squared Error (RMSE) on test data = %g" % rmse - -treeModel = model.stages[1] -print treeModel # summary only -{% endhighlight %} -
    - -
    + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#decision-trees). diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 58f566c9b4b55..303773e8038fc 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -1,1044 +1,8 @@ --- layout: global -title: Ensembles -displayTitle: ML - Ensembles +title: Tree ensemble methods - spark.ml +displayTitle: Tree ensemble methods - spark.ml --- -**Table of Contents** - -* This will become a table of contents (this text will be scraped). -{:toc} - -An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) -is a learning algorithm which creates a model composed of a set of other base models. - -## Tree Ensembles - -The Pipelines API supports two major tree ensemble algorithms: [Random Forests](http://en.wikipedia.org/wiki/Random_forest) and [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting). -Both use [MLlib decision trees](ml-decision-tree.html) as their base models. - -Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. - -The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: -* support for ML Pipelines -* separation of classification vs. regression -* use of DataFrame metadata to distinguish continuous and categorical features -* a bit more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification. - -### Random Forests - -[Random forests](http://en.wikipedia.org/wiki/Random_forest) -are ensembles of [decision trees](ml-decision-tree.html). -Random forests combine many decision trees in order to reduce the risk of overfitting. -MLlib supports random forests for binary and multiclass classification and for regression, -using both continuous and categorical features. - -This section gives examples of using random forests with the Pipelines API. -For more information on the algorithm, please see the [main MLlib docs on random forests](mllib-ensembles.html). - -#### Inputs and Outputs - -We list the input and output (prediction) column types here. -All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. - -##### Input Columns - - - - - - - - - - - - - - - - - - - - - - - - -
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    - -##### Output Columns (Predictions) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    rawPredictionColVector"rawPrediction"Vector of length # classes, with the counts of training instance labels at the tree node which makes the predictionClassification only
    probabilityColVector"probability"Vector of length # classes equal to rawPrediction normalized to a multinomial distributionClassification only
    - -#### Example: Classification - -The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. -We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. - -
    -
    - -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. - -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.RandomForestClassifier -import org.apache.spark.ml.classification.RandomForestClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator - -// Load and parse the data file, converting it to a DataFrame. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -val labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data) -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a RandomForest model. -val rf = new RandomForestClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - .setNumTrees(10) - -// Convert indexed labels back to original labels. -val labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels) - -// Chain indexers and forest in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) - -// Train model. This also runs the indexers. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision") -val accuracy = evaluator.evaluate(predictions) -println("Test Error = " + (1.0 - accuracy)) - -val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] -println("Learned classification forest model:\n" + rfModel.toDebugString) -{% endhighlight %} -
    - -
    - -Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. - -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.RandomForestClassifier; -import org.apache.spark.ml.classification.RandomForestClassificationModel; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.ml.feature.*; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a RandomForest model. -RandomForestClassifier rf = new RandomForestClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures"); - -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and forest in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter}); - -// Train model. This also runs the indexers. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision"); -double accuracy = evaluator.evaluate(predictions); -System.out.println("Test Error = " + (1.0 - accuracy)); - -RandomForestClassificationModel rfModel = - (RandomForestClassificationModel)(model.stages()[2]); -System.out.println("Learned classification forest model:\n" + rfModel.toDebugString()); -{% endhighlight %} -
    - -
    - -Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. - -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer -from pyspark.ml.evaluation import MulticlassClassificationEvaluator - -# Load and parse the data file, converting it to a DataFrame. -data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) -# Automatically identify categorical features, and index them. -# Set maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") - -# Chain indexers and forest in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) - -# Train model. This also runs the indexers. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") -accuracy = evaluator.evaluate(predictions) -print "Test Error = %g" % (1.0 - accuracy) - -rfModel = model.stages[2] -print rfModel # summary only -{% endhighlight %} -
    -
    - -#### Example: Regression - -The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. -We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. - -
    -
    - -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. - -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.RandomForestRegressor -import org.apache.spark.ml.regression.RandomForestRegressionModel -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.ml.evaluation.RegressionEvaluator - -// Load and parse the data file, converting it to a DataFrame. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a RandomForest model. -val rf = new RandomForestRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - -// Chain indexer and forest in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(featureIndexer, rf)) - -// Train model. This also runs the indexer. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse") -val rmse = evaluator.evaluate(predictions) -println("Root Mean Squared Error (RMSE) on test data = " + rmse) - -val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] -println("Learned regression forest model:\n" + rfModel.toDebugString) -{% endhighlight %} -
    - -
    - -Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. - -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.ml.regression.RandomForestRegressionModel; -import org.apache.spark.ml.regression.RandomForestRegressor; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a RandomForest model. -RandomForestRegressor rf = new RandomForestRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures"); - -// Chain indexer and forest in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {featureIndexer, rf}); - -// Train model. This also runs the indexer. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse"); -double rmse = evaluator.evaluate(predictions); -System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); - -RandomForestRegressionModel rfModel = - (RandomForestRegressionModel)(model.stages()[1]); -System.out.println("Learned regression forest model:\n" + rfModel.toDebugString()); -{% endhighlight %} -
    - -
    - -Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. - -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.regression import RandomForestRegressor -from pyspark.ml.feature import VectorIndexer -from pyspark.ml.evaluation import RegressionEvaluator - -# Load and parse the data file, converting it to a DataFrame. -data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -# Automatically identify categorical features, and index them. -# Set maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -rf = RandomForestRegressor(featuresCol="indexedFeatures") - -# Chain indexer and forest in a Pipeline -pipeline = Pipeline(stages=[featureIndexer, rf]) - -# Train model. This also runs the indexer. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = RegressionEvaluator( - labelCol="label", predictionCol="prediction", metricName="rmse") -rmse = evaluator.evaluate(predictions) -print "Root Mean Squared Error (RMSE) on test data = %g" % rmse - -rfModel = model.stages[1] -print rfModel # summary only -{% endhighlight %} -
    -
    - -### Gradient-Boosted Trees (GBTs) - -[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) -are ensembles of [decision trees](ml-decision-tree.html). -GBTs iteratively train decision trees in order to minimize a loss function. -MLlib supports GBTs for binary classification and for regression, -using both continuous and categorical features. - -This section gives examples of using GBTs with the Pipelines API. -For more information on the algorithm, please see the [main MLlib docs on GBTs](mllib-ensembles.html). - -#### Inputs and Outputs - -We list the input and output (prediction) column types here. -All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. - -##### Input Columns - - - - - - - - - - - - - - - - - - - - - - - - -
    Param nameType(s)DefaultDescription
    labelColDouble"label"Label to predict
    featuresColVector"features"Feature vector
    - -Note that `GBTClassifier` currently only supports binary labels. - -##### Output Columns (Predictions) - - - - - - - - - - - - - - - - - - - - -
    Param nameType(s)DefaultDescriptionNotes
    predictionColDouble"prediction"Predicted label
    - -In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does. - -#### Example: Classification - -The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. -We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. - -
    -
    - -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. - -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.GBTClassifier -import org.apache.spark.ml.classification.GBTClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator - -// Load and parse the data file, converting it to a DataFrame. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -val labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data) -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a GBT model. -val gbt = new GBTClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - .setMaxIter(10) - -// Convert indexed labels back to original labels. -val labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels) - -// Chain indexers and GBT in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) - -// Train model. This also runs the indexers. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision") -val accuracy = evaluator.evaluate(predictions) -println("Test Error = " + (1.0 - accuracy)) - -val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] -println("Learned classification GBT model:\n" + gbtModel.toDebugString) -{% endhighlight %} -
    - -
    - -Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details. - -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.GBTClassifier; -import org.apache.spark.ml.classification.GBTClassificationModel; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.ml.feature.*; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -DataFrame data sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a GBT model. -GBTClassifier gbt = new GBTClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - .setMaxIter(10); - -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and GBT in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter}); - -// Train model. This also runs the indexers. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision"); -double accuracy = evaluator.evaluate(predictions); -System.out.println("Test Error = " + (1.0 - accuracy)); - -GBTClassificationModel gbtModel = - (GBTClassificationModel)(model.stages()[2]); -System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString()); -{% endhighlight %} -
    - -
    - -Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. - -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import GBTClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer -from pyspark.ml.evaluation import MulticlassClassificationEvaluator - -# Load and parse the data file, converting it to a DataFrame. -data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) -# Automatically identify categorical features, and index them. -# Set maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GBT model. -gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10) - -# Chain indexers and GBT in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt]) - -# Train model. This also runs the indexers. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") -accuracy = evaluator.evaluate(predictions) -print "Test Error = %g" % (1.0 - accuracy) - -gbtModel = model.stages[2] -print gbtModel # summary only -{% endhighlight %} -
    -
    - -#### Example: Regression - -Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not -be true in general. - -
    -
    - -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. - -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.GBTRegressor -import org.apache.spark.ml.regression.GBTRegressionModel -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.ml.evaluation.RegressionEvaluator - -// Load and parse the data file, converting it to a DataFrame. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a GBT model. -val gbt = new GBTRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - .setMaxIter(10) - -// Chain indexer and GBT in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(featureIndexer, gbt)) - -// Train model. This also runs the indexer. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse") -val rmse = evaluator.evaluate(predictions) -println("Root Mean Squared Error (RMSE) on test data = " + rmse) - -val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] -println("Learned regression GBT model:\n" + gbtModel.toDebugString) -{% endhighlight %} -
    - -
    - -Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. - -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.ml.regression.GBTRegressionModel; -import org.apache.spark.ml.regression.GBTRegressor; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a GBT model. -GBTRegressor gbt = new GBTRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - .setMaxIter(10); - -// Chain indexer and GBT in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {featureIndexer, gbt}); - -// Train model. This also runs the indexer. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse"); -double rmse = evaluator.evaluate(predictions); -System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); - -GBTRegressionModel gbtModel = - (GBTRegressionModel)(model.stages()[1]); -System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString()); -{% endhighlight %} -
    - -
    - -Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. - -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.regression import GBTRegressor -from pyspark.ml.feature import VectorIndexer -from pyspark.ml.evaluation import RegressionEvaluator - -# Load and parse the data file, converting it to a DataFrame. -data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -# Automatically identify categorical features, and index them. -# Set maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GBT model. -gbt = GBTRegressor(featuresCol="indexedFeatures", maxIter=10) - -# Chain indexer and GBT in a Pipeline -pipeline = Pipeline(stages=[featureIndexer, gbt]) - -# Train model. This also runs the indexer. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = RegressionEvaluator( - labelCol="label", predictionCol="prediction", metricName="rmse") -rmse = evaluator.evaluate(predictions) -print "Root Mean Squared Error (RMSE) on test data = %g" % rmse - -gbtModel = model.stages[1] -print gbtModel # summary only -{% endhighlight %} -
    -
    - - -## One-vs-Rest (a.k.a. One-vs-All) - -[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." - -`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. - -Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. - -### Example - -The example below demonstrates how to load the -[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. - -
    -
    - -Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. - -{% highlight scala %} -import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{Row, SQLContext} - -val sqlContext = new SQLContext(sc) - -// parse data into dataframe -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_multiclass_classification_data.txt") -val Array(train, test) = data.randomSplit(Array(0.7, 0.3)) - -// instantiate multiclass learner and train -val ovr = new OneVsRest().setClassifier(new LogisticRegression) - -val ovrModel = ovr.fit(train) - -// score model on test data -val predictions = ovrModel.transform(test).select("prediction", "label") -val predictionsAndLabels = predictions.map {case Row(p: Double, l: Double) => (p, l)} - -// compute confusion matrix -val metrics = new MulticlassMetrics(predictionsAndLabels) -println(metrics.confusionMatrix) - -// the Iris DataSet has three classes -val numClasses = 3 - -println("label\tfpr\n") -(0 until numClasses).foreach { index => - val label = index.toDouble - println(label + "\t" + metrics.falsePositiveRate(label)) -} -{% endhighlight %} -
    - -
    - -Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. - -{% highlight java %} -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.classification.OneVsRest; -import org.apache.spark.ml.classification.OneVsRestModel; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - -DataFrame dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_multiclass_classification_data.txt"); - -DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345); -DataFrame train = splits[0]; -DataFrame test = splits[1]; - -// instantiate the One Vs Rest Classifier -OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression()); - -// train the multiclass model -OneVsRestModel ovrModel = ovr.fit(train.cache()); - -// score the model on test data -DataFrame predictions = ovrModel - .transform(test) - .select("prediction", "label"); - -// obtain metrics -MulticlassMetrics metrics = new MulticlassMetrics(predictions); -Matrix confusionMatrix = metrics.confusionMatrix(); - -// output the Confusion Matrix -System.out.println("Confusion Matrix"); -System.out.println(confusionMatrix); - -// compute the false positive rate per label -System.out.println(); -System.out.println("label\tfpr\n"); - -// the Iris DataSet has three classes -int numClasses = 3; -for (int index = 0; index < numClasses; index++) { - double label = (double) index; - System.out.print(label); - System.out.print("\t"); - System.out.print(metrics.falsePositiveRate(label)); - System.out.println(); -} -{% endhighlight %} -
    -
    + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#tree-ensembles). diff --git a/docs/ml-features.md b/docs/ml-features.md index 142afac2f3f95..70812eb5e2292 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction, Transformation, and Selection - SparkML -displayTitle: ML - Features +title: Extracting, transforming and selecting features - spark.ml +displayTitle: Extracting, transforming and selecting features - spark.ml --- This section covers algorithms for working with features, roughly divided into these groups: @@ -63,7 +63,7 @@ the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for mor `Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of all words in the document; this vector can then be used for as features for prediction, document similarity calculations, etc. -Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more details. In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. @@ -149,6 +149,15 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java %} + +
    + +Refer to the [CountVectorizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizer) +and the [CountVectorizerModel Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizerModel) +for more details on the API. + +{% include_example python/ml/count_vectorizer_example.py %} +
    # Feature Transformers @@ -170,25 +179,7 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} - -val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - -val tokenized = tokenizer.transform(sentenceDataFrame) -tokenized.select("words", "label").take(3).foreach(println) -val regexTokenized = regexTokenizer.transform(sentenceDataFrame) -regexTokenized.select("words", "label").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %}
    @@ -197,67 +188,16 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); -for (Row r : wordsDataFrame.select("words", "label").take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); -} - -RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %}
    Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Tokenizer) and -the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) +the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Tokenizer, RegexTokenizer - -sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) -regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") -# alternatively, pattern="\\w+", gaps(False) -{% endhighlight %} +{% include_example python/ml/tokenizer_example.py %}
    @@ -306,19 +246,7 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StopWordsRemover - -val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") -val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) -)).toDF("id", "raw") - -remover.transform(dataSet).show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %}
    @@ -326,34 +254,7 @@ remover.transform(dataSet).show() Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) -)); -StructType schema = new StructType(new StructField[] { - new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame dataset = jsql.createDataFrame(rdd, schema); - -remover.transform(dataset).show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %}
    @@ -361,17 +262,7 @@ remover.transform(dataset).show(); Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StopWordsRemover - -sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) -], ["label", "raw"]) - -remover = StopWordsRemover(inputCol="raw", outputCol="filtered") -remover.transform(sentenceData).show(truncate=False) -{% endhighlight %} +{% include_example python/ml/stopwords_remover_example.py %}
    @@ -388,19 +279,7 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.NGram - -val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) -)).toDF("label", "words") - -val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") -val ngramDataFrame = ngram.transform(wordDataFrame) -ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %}
    @@ -408,38 +287,7 @@ ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(pri Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); -NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); -DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); -for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %}
    @@ -447,19 +295,7 @@ for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import NGram - -wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) -], ["label", "words"]) -ngram = NGram(inputCol="words", outputCol="ngrams") -ngramDataFrame = ngram.transform(wordDataFrame) -for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) -{% endhighlight %} +{% include_example python/ml/n_gram_example.py %}
    @@ -476,26 +312,7 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Binarizer -import org.apache.spark.sql.DataFrame - -val data = Array( - (0, 0.1), - (1, 0.8), - (2, 0.2) -) -val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - -val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - -val binarizedDataFrame = binarizer.transform(dataFrame) -val binarizedFeatures = binarizedDataFrame.select("binarized_feature") -binarizedFeatures.collect().foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %}
    @@ -503,40 +320,7 @@ binarizedFeatures.collect().foreach(println) Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); -Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); -DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); -DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); -for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %}
    @@ -544,20 +328,7 @@ for (Row r : binarizedFeatures.collect()) { Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Binarizer - -continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) -], ["label", "feature"]) -binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") -binarizedDataFrame = binarizer.transform(continuousDataFrame) -binarizedFeatures = binarizedDataFrame.select("binarized_feature") -for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) -{% endhighlight %} +{% include_example python/ml/binarizer_example.py %}
    @@ -571,25 +342,7 @@ for binarized_feature, in binarizedFeatures.collect(): Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) -val pcaDF = pca.transform(df) -val result = pcaDF.select("pcaFeatures") -result.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %}
    @@ -597,42 +350,7 @@ result.show() Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.PCA -import org.apache.spark.ml.feature.PCAModel -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); -DataFrame result = pca.transform(df).select("pcaFeatures"); -result.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %}
    @@ -640,19 +358,7 @@ result.show(); Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] -df = sqlContext.createDataFrame(data,["features"]) -pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") -model = pca.fit(df) -result = model.transform(df).select("pcaFeatures") -result.show(truncate=False) -{% endhighlight %} +{% include_example python/ml/pca_example.py %}
    @@ -666,23 +372,7 @@ result.show(truncate=False) Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) -val polyDF = polynomialExpansion.transform(df) -polyDF.select("polyFeatures").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %}
    @@ -690,43 +380,7 @@ polyDF.select("polyFeatures").take(3).foreach(println) Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DataFrame polyDF = polyExpansion.transform(df); -Row[] row = polyDF.select("polyFeatures").take(3); -for (Row r : row) { - System.out.println(r.get(0)); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %}
    @@ -734,20 +388,7 @@ for (Row r : row) { Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors - -df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) -px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") -polyDF = px.transform(df) -for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) -{% endhighlight %} +{% include_example python/ml/polynomial_expansion_example.py %}
    @@ -771,22 +412,7 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors - -val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) -val dctDf = dct.transform(df) -dctDf.select("featuresDCT").show(3) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %}
    @@ -794,39 +420,15 @@ dctDf.select("featuresDCT").show(3) Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); -DataFrame dctDf = dct.transform(df); -dctDf.select("featuresDCT").show(3); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %} +
    + +
    + +Refer to the [DCT Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.DCT) +for more details on the API. + +{% include_example python/ml/dct_example.py %}
    @@ -835,10 +437,10 @@ dctDf.select("featuresDCT").show(3); `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies. So the most frequent label gets index `0`. -If the input column is numeric, we cast it to string and index the string -values. When downstream pipeline components such as `Estimator` or -`Transformer` make use of this string-indexed label, you must set the input -column of the component to this string-indexed column name. In many cases, +If the input column is numeric, we cast it to string and index the string +values. When downstream pipeline components such as `Estimator` or +`Transformer` make use of this string-indexed label, you must set the input +column of the component to this string-indexed column name. In many cases, you can set the input column with `setInputCol`. **Examples** @@ -874,6 +476,42 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. +Additionally, there are two strategies regarding how `StringIndexer` will handle +unseen labels when you have fit a `StringIndexer` on one dataset and then use it +to transform another: + +- throw an exception (which is the default) +- skip the row containing the unseen label entirely + +**Examples** + +Let's go back to our previous example but this time reuse our previously defined +`StringIndexer` on the following dataset: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | d +~~~~ + +If you've not set how `StringIndexer` handles unseen labels or set it to +"error", an exception will be thrown. +However, if you had called `setHandleInvalid("skip")`, the following dataset +will be generated: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 +~~~~ + +Notice that the row containing "d" does not appear. +
    @@ -881,18 +519,7 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StringIndexer - -val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) -).toDF("id", "category") -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") -val indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %}
    @@ -900,37 +527,7 @@ indexed.show() Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[] { - createStructField("id", DoubleType, false), - createStructField("category", StringType, false) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); -DataFrame indexed = indexer.fit(df).transform(df); -indexed.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %}
    @@ -938,52 +535,90 @@ indexed.show(); Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StringIndexer - -df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) -indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example python/ml/string_indexer_example.py %}
    -## OneHotEncoder -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features +## IndexToString + +Symmetrically to `StringIndexer`, `IndexToString` maps a column of label indices +back to a column containing the original labels as strings. The common use case +is to produce indices from labels with `StringIndexer`, train a model with those +indices and retrieve the original labels from the column of predicted indices +with `IndexToString`. However, you are free to supply your own labels. + +**Examples** + +Building on the `StringIndexer` example, let's assume we have the following +DataFrame with columns `id` and `categoryIndex`: + +~~~~ + id | categoryIndex +----|--------------- + 0 | 0.0 + 1 | 2.0 + 2 | 1.0 + 3 | 0.0 + 4 | 0.0 + 5 | 1.0 +~~~~ + +Applying `IndexToString` with `categoryIndex` as the input column, +`originalCategory` as the output column, we are able to retrieve our original +labels (they will be inferred from the columns' metadata): + +~~~~ + id | categoryIndex | originalCategory +----|---------------|----------------- + 0 | 0.0 | a + 1 | 2.0 | b + 2 | 1.0 | c + 3 | 0.0 | a + 4 | 0.0 | a + 5 | 1.0 | c +~~~~
    -Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) +Refer to the [IndexToString Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IndexToString) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/IndexToStringExample.scala %} + +
    + +
    + +Refer to the [IndexToString Java docs](api/java/org/apache/spark/ml/feature/IndexToString.html) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +{% include_example java/org/apache/spark/examples/ml/JavaIndexToStringExample.java %} -val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -)).toDF("id", "category") +
    -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) -val indexed = indexer.transform(df) +
    + +Refer to the [IndexToString Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IndexToString) +for more details on the API. -val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") -val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").foreach(println) -{% endhighlight %} +{% include_example python/ml/index_to_string_example.py %} + +
    +
    + +## OneHotEncoder + +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features + +
    +
    + +Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
    @@ -991,45 +626,7 @@ encoded.select("id", "categoryVec").foreach(println) Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); -DataFrame indexed = indexer.transform(df); - -OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); -DataFrame encoded = encoder.transform(indexed); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
    @@ -1037,24 +634,7 @@ DataFrame encoded = encoder.transform(indexed); Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import OneHotEncoder, StringIndexer - -df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -], ["id", "category"]) - -stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -model = stringIndexer.fit(df) -indexed = model.transform(df) -encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") -encoded = encoder.transform(indexed) -{% endhighlight %} +{% include_example python/ml/onehot_encoder_example.py %}
    @@ -1078,23 +658,7 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.VectorIndexer - -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) -val indexerModel = indexer.fit(data) -val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet -println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - -// Create new column "indexed" with categorical values transformed to indices -val indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %}
    @@ -1102,30 +666,7 @@ val indexedData = indexerModel.transform(data) Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Map; - -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); -VectorIndexerModel indexerModel = indexer.fit(data); -Map> categoryMaps = indexerModel.javaCategoryMaps(); -System.out.print("Chose " + categoryMaps.size() + "categorical features:"); -for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); -} -System.out.println(); - -// Create new column "indexed" with categorical values transformed to indices -DataFrame indexedData = indexerModel.transform(data); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %}
    @@ -1133,17 +674,7 @@ DataFrame indexedData = indexerModel.transform(data); Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import VectorIndexer - -data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) -indexerModel = indexer.fit(data) - -# Create new column "indexed" with categorical values transformed to indices -indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example python/ml/vector_indexer_example.py %}
    @@ -1155,72 +686,28 @@ indexedData = indexerModel.transform(data) The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^2$ norm and unit $L^\infty$ norm.
    -
    +
    Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Normalizer - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -// Normalize each Vector using $L^1$ norm. -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) -val l1NormData = normalizer.transform(dataFrame) - -// Normalize each Vector using $L^\infty$ norm. -val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    -
    +
    Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Normalize each Vector using $L^1$ norm. -Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); -DataFrame l1NormData = normalizer.transform(dataFrame); - -// Normalize each Vector using $L^\infty$ norm. -DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    -
    +
    Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Normalizer - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -# Normalize each Vector using $L^1$ norm. -normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) -l1NormData = normalizer.transform(dataFrame) - -# Normalize each Vector using $L^\infty$ norm. -lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) -{% endhighlight %} +{% include_example python/ml/normalizer_example.py %}
    @@ -1232,82 +719,35 @@ lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) * `withStd`: True by default. Scales the data to unit standard deviation. * `withMean`: False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. -`StandardScaler` is a `Model` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. +`StandardScaler` is an `Estimator` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. Note that if the standard deviation of a feature is zero, it will return default `0.0` value in the `Vector` for that feature. The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation.
    -
    +
    Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StandardScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - -// Compute summary statistics by fitting the StandardScaler -val scalerModel = scaler.fit(dataFrame) - -// Normalize each feature to have unit standard deviation. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    -
    +
    Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - -// Compute summary statistics by fitting the StandardScaler -StandardScalerModel scalerModel = scaler.fit(dataFrame); - -// Normalize each feature to have unit standard deviation. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    -
    +
    Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StandardScaler - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - -# Compute summary statistics by fitting the StandardScaler -scalerModel = scaler.fit(dataFrame) - -# Normalize each feature to have unit standard deviation. -scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/standard_scaler_example.py %}
    @@ -1337,47 +777,64 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.MinMaxScaler +{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %} +
    + +
    + +Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMaxScaler.html) +and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %} +
    + +
    + +Refer to the [MinMaxScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MinMaxScaler) +for more details on the API. + +{% include_example python/ml/min_max_scaler_example.py %} +
    +
    + + +## MaxAbsScaler + +`MaxAbsScaler` transforms a dataset of `Vector` rows, rescaling each feature to range [-1, 1] +by dividing through the maximum absolute value in each feature. It does not shift/center the +data, and thus does not destroy any sparsity. -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") +`MaxAbsScaler` computes summary statistics on a data set and produces a `MaxAbsScalerModel`. The +model can then transform each feature individually to range [-1, 1]. -// Compute summary statistics and generate MinMaxScalerModel -val scalerModel = scaler.fit(dataFrame) +The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [-1, 1]. -// rescale each feature to range [min, max]. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +
    +
    + +Refer to the [MaxAbsScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MaxAbsScaler) +and the [MaxAbsScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MaxAbsScalerModel) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala %}
    -Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMaxScaler.html) -and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) +Refer to the [MaxAbsScaler Java docs](api/java/org/apache/spark/ml/feature/MaxAbsScaler.html) +and the [MaxAbsScalerModel Java docs](api/java/org/apache/spark/ml/feature/MaxAbsScalerModel.html) for more details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; +{% include_example java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java %} +
    -DataFrame dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); +
    -// Compute summary statistics and generate MinMaxScalerModel -MinMaxScalerModel scalerModel = scaler.fit(dataFrame); +Refer to the [MaxAbsScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MaxAbsScaler) +for more details on the API. -// rescale each feature to range [min, max]. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example python/ml/max_abs_scaler_example.py %}
    @@ -1387,7 +844,7 @@ DataFrame scaledData = scalerModel.transform(dataFrame); * `splits`: Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. Splits should be strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; Otherwise, values outside the splits specified will be treated as errors. Two examples of `splits` are `Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)` and `Array(0.0, 1.0, 2.0)`. -Note that if you have no idea of the upper bound and lower bound of the targeted column, you would better add the `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potenial out of Bucketizer bounds exception. +Note that if you have no idea of the upper bound and lower bound of the targeted column, you would better add the `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potential out of Bucketizer bounds exception. Note also that the splits that you provided have to be in strictly increasing order, i.e. `s0 < s1 < s2 < ... < sn`. @@ -1396,87 +853,28 @@ More details can be found in the API docs for [Bucketizer](api/scala/index.html# The following example demonstrates how to bucketize a column of `Double`s into another index-wised column.
    -
    +
    Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Bucketizer -import org.apache.spark.sql.DataFrame - -val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - -val data = Array(-0.5, -0.3, 0.0, 0.2) -val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - -val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - -// Transform original data into its bucket index. -val bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    -
    +
    Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame dataFrame = jsql.createDataFrame(data, schema); - -Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - -// Transform original data into its bucket index. -DataFrame bucketedData = bucketizer.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    -
    +
    Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Bucketizer - -splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - -data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] -dataFrame = sqlContext.createDataFrame(data, ["features"]) - -bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - -# Transform original data into its bucket index. -bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/bucketizer_example.py %}
    @@ -1508,25 +906,7 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %}
    @@ -1534,41 +914,7 @@ transformer.transform(dataFrame).show() Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -// Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) -)); -List fields = new ArrayList(2); -fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); -fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); -StructType schema = DataTypes.createStructType(fields); -DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %}
    @@ -1576,19 +922,67 @@ transformer.transform(dataFrame).show(); Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors +{% include_example python/ml/elementwise_product_example.py %} +
    +
    + +## SQLTransformer + +`SQLTransformer` implements the transformations which are defined by SQL statement. +Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +where `"__THIS__"` represents the underlying table of the input dataset. +The select clause specifies the fields, constants, and expressions to display in +the output, it can be any select clause that Spark SQL supports. Users can also +use Spark SQL built-in function and UDFs to operate on these selected columns. +For example, `SQLTransformer` supports statements like: + +* `SELECT a, a + b AS a_b FROM __THIS__` +* `SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5` +* `SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b` -data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] -df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") -transformer.transform(df).show() +**Examples** + +Assume that we have the following DataFrame with columns `id`, `v1` and `v2`: + +~~~~ + id | v1 | v2 +----|-----|----- + 0 | 1.0 | 3.0 + 2 | 2.0 | 5.0 +~~~~ + +This is the output of the `SQLTransformer` with statement `"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"`: + +~~~~ + id | v1 | v2 | v3 | v4 +----|-----|-----|-----|----- + 0 | 1.0 | 3.0 | 4.0 | 3.0 + 2 | 2.0 | 5.0 | 7.0 |10.0 +~~~~ -{% endhighlight %} +
    +
    + +Refer to the [SQLTransformer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.SQLTransformer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/SQLTransformerExample.scala %}
    +
    + +Refer to the [SQLTransformer Java docs](api/java/org/apache/spark/ml/feature/SQLTransformer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java %} +
    + +
    + +Refer to the [SQLTransformer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.SQLTransformer) for more details on the API. + +{% include_example python/ml/sql_transformer.py %} +
    ## VectorAssembler @@ -1632,19 +1026,7 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.feature.VectorAssembler - -val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) -).toDF("id", "hour", "mobile", "userFeatures", "clicked") -val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") -val output = assembler.transform(dataset) -println(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %}
    @@ -1652,56 +1034,80 @@ println(output.select("features", "clicked").first()) Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; +{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %} +
    -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) -}); -Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); -JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); +
    -VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); +Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) +for more details on the API. -DataFrame output = assembler.transform(dataset); -System.out.println(output.select("features", "clicked").first()); -{% endhighlight %} +{% include_example python/ml/vector_assembler_example.py %} +
    -
    +## QuantileDiscretizer -Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) +`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned +categorical features. +The bin ranges are chosen by taking a sample of the data and dividing it into roughly equal parts. +The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values. +This attempts to find `numBuckets` partitions based on a sample of the given input data, but it may +find fewer depending on the data sample values. + +Note that the result may be different every time you run it, since the sample strategy behind it is +non-deterministic. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `hour`: + +~~~ + id | hour +----|------ + 0 | 18.0 +----|------ + 1 | 19.0 +----|------ + 2 | 8.0 +----|------ + 3 | 5.0 +----|------ + 4 | 2.2 +~~~ + +`hour` is a continuous feature with `Double` type. We want to turn the continuous feature into +categorical one. Given `numBuckets = 3`, we should get the following DataFrame: + +~~~ + id | hour | result +----|------|------ + 0 | 18.0 | 2.0 +----|------|------ + 1 | 19.0 | 2.0 +----|------|------ + 2 | 8.0 | 1.0 +----|------|------ + 3 | 5.0 | 1.0 +----|------|------ + 4 | 2.2 | 0.0 +~~~ + +
    +
    + +Refer to the [QuantileDiscretizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.QuantileDiscretizer) for more details on the API. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler +{% include_example scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala %} +
    + +
    + +Refer to the [QuantileDiscretizer Java docs](api/java/org/apache/spark/ml/feature/QuantileDiscretizer.html) +for more details on the API. -dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) -assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") -output = assembler.transform(dataset) -print(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java %}
    @@ -1713,15 +1119,15 @@ print(output.select("features", "clicked").first()) sub-array of the original features. It is useful for extracting features from a vector column. `VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column -whose values are selected via those indices. There are two types of indices, +whose values are selected via those indices. There are two types of indices, 1. Integer indices that represents the indices into the vector, `setIndices()`; - 2. String indices that represents the names of features into the vector, `setNames()`. + 2. String indices that represents the names of features into the vector, `setNames()`. *This requires the vector column to have an `AttributeGroup` since the implementation matches on the name field of an `Attribute`.* -Specification by integer and string are both acceptable. Moreover, you can use integer index and +Specification by integer and string are both acceptable. Moreover, you can use integer index and string name simultaneously. At least one feature must be selected. Duplicate features are not allowed, so there can be no overlap between selected indices and names. Note that if names of features are selected, an exception will be threw out when encountering with empty input attributes. @@ -1734,9 +1140,9 @@ followed by the selected names (in the order given). Suppose that we have a DataFrame with the column `userFeatures`: ~~~ - userFeatures + userFeatures ------------------ - [0.0, 10.0, 0.5] + [0.0, 10.0, 0.5] ~~~ `userFeatures` is a vector column that contains three user features. Assuming that the first column @@ -1750,7 +1156,7 @@ column named `features`: [0.0, 10.0, 0.5] | [10.0, 0.5] ~~~ -Suppose also that we have a potential input attributes for the `userFeatures`, i.e. +Suppose also that we have a potential input attributes for the `userFeatures`, i.e. `["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them. ~~~ @@ -1766,33 +1172,7 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3, 0.0) -) - -val defaultAttr = NumericAttribute.defaultAttr -val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) -val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - -val dataRDD = sc.parallelize(data).map(Row.apply) -val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) - -val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - -slicer.setIndices(1).setNames("f3") -// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - -val output = slicer.transform(dataset) -println(output.select("userFeatures", "features").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %}
    @@ -1800,47 +1180,31 @@ println(output.select("userFeatures", "features").first()) Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") -}; -AttributeGroup group = new AttributeGroup("userFeatures", attrs); +{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %} +
    +
    -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) -)); +## RFormula -DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); +`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). +Currently we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. +The basic operators are: -VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); +* `~` separate target and terms +* `+` concat terms, "+ 0" means removing intercept +* `-` remove a term, "- 1" means removing intercept +* `:` interaction (multiplication for numeric values, or binarized categorical values) +* `.` all columns except target -vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); -// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) +Suppose `a` and `b` are double columns, we use the following simple examples to illustrate the effect of `RFormula`: -DataFrame output = vectorSlicer.transform(dataset); +* `y ~ a + b` means model `y ~ w0 + w1 * a + w2 * b` where `w0` is the intercept and `w1, w2` are coefficients. +* `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` are coefficients. -System.out.println(output.select("userFeatures", "features").first()); -{% endhighlight %} -
    -
    - -## RFormula - -`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula. +`RFormula` produces a vector column of features and a double or string column of label. +Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. +If the label column is of type string, it will be first transformed to double with `StringIndexer`. +If the label column does not exist in the DataFrame, the output label column will be created from the specified response variable in the formula. **Examples** @@ -1871,21 +1235,7 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.RFormula - -val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) -)).toDF("id", "country", "hour", "clicked") -val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") -val output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %}
    @@ -1893,38 +1243,7 @@ output.select("features", "label").show() Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) -}); -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) -)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - -DataFrame output = formula.fit(dataset).transform(dataset); -output.select("features", "label").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %}
    @@ -1932,20 +1251,56 @@ output.select("features", "label").show(); Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import RFormula - -dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) -formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") -output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example python/ml/rformula_example.py %} +
    + + +## ChiSqSelector + +`ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with +categorical features. ChiSqSelector orders features based on a +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) +from the class, and then filters (selects) the top features which the class label depends on the +most. This is akin to yielding the features with the most predictive power. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `features`, and `clicked`, which is used as +our target to be predicted: + +~~~ +id | features | clicked +---|-----------------------|--------- + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 +~~~ + +If we use `ChiSqSelector` with a `numTopFeatures = 1`, then according to our label `clicked` the +last column in our `features` chosen as the most useful feature: + +~~~ +id | features | clicked | selectedFeatures +---|-----------------------|---------|------------------ + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 | [1.0] + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 | [0.0] + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 | [0.1] +~~~ + +
    +
    + +Refer to the [ChiSqSelector Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ChiSqSelector) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala %} +
    + +
    + +Refer to the [ChiSqSelector Java docs](api/java/org/apache/spark/ml/feature/ChiSqSelector.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java %}
    diff --git a/docs/ml-guide.md b/docs/ml-guide.md index fd3a6167bc65e..99167873cd02d 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -1,8 +1,10 @@ --- layout: global -title: Spark ML Programming Guide +title: "Overview: estimators, transformers and pipelines - spark.ml" +displayTitle: "Overview: estimators, transformers and pipelines - spark.ml" --- + `\[ \newcommand{\R}{\mathbb{R}} \newcommand{\E}{\mathbb{E}} @@ -32,19 +34,6 @@ See the [algorithm guides](#algorithm-guides) section below for guides on sub-pa * This will become a table of contents (this text will be scraped). {:toc} -# Algorithm guides - -We provide several algorithm guides specific to the Pipelines API. -Several of these algorithms, such as certain feature transformers, are not in the `spark.mllib` API. -Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., random forests -provide class probabilities, and linear models provide model summaries. - -* [Feature extraction, transformation, and selection](ml-features.html) -* [Decision Trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) - # Main concepts in Pipelines @@ -203,6 +192,10 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +## Saving and Loading Pipelines + +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. + # Code examples This section gives code examples illustrating the functionality discussed above. @@ -220,209 +213,15 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.
    -{% highlight scala %} -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.sql.Row - -// Prepare training data from a list of (label, features) tuples. -val training = sqlContext.createDataFrame(Seq( - (1.0, Vectors.dense(0.0, 1.1, 0.1)), - (0.0, Vectors.dense(2.0, 1.0, -1.0)), - (0.0, Vectors.dense(2.0, 1.3, 1.0)), - (1.0, Vectors.dense(0.0, 1.2, -0.5)) -)).toDF("label", "features") - -// Create a LogisticRegression instance. This instance is an Estimator. -val lr = new LogisticRegression() -// Print out the parameters, documentation, and any default values. -println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") - -// We may set parameters using setter methods. -lr.setMaxIter(10) - .setRegParam(0.01) - -// Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training) -// Since model1 is a Model (i.e., a Transformer produced by an Estimator), -// we can view the parameters it used during fit(). -// This prints the parameter (name: value) pairs, where names are unique IDs for this -// LogisticRegression instance. -println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) - -// We may alternatively specify parameters using a ParamMap, -// which supports several methods for specifying parameters. -val paramMap = ParamMap(lr.maxIter -> 20) - .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. - .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. - -// One can also combine ParamMaps. -val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name -val paramMapCombined = paramMap ++ paramMap2 - -// Now learn a new model using the paramMapCombined parameters. -// paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training, paramMapCombined) -println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) - -// Prepare test data. -val test = sqlContext.createDataFrame(Seq( - (1.0, Vectors.dense(-1.0, 1.5, 1.3)), - (0.0, Vectors.dense(3.0, 2.0, -0.1)), - (1.0, Vectors.dense(0.0, 2.2, -1.5)) -)).toDF("label", "features") - -// Make predictions on test data using the Transformer.transform() method. -// LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'myProbability' column instead of the usual -// 'probability' column since we renamed the lr.probabilityCol parameter previously. -model2.transform(test) - .select("features", "label", "myProbability", "prediction") - .collect() - .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => - println(s"($features, $label) -> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala %}
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Prepare training data. -// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans -// into DataFrames, where it uses the bean metadata to infer the schema. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) -), LabeledPoint.class); - -// Create a LogisticRegression instance. This instance is an Estimator. -LogisticRegression lr = new LogisticRegression(); -// Print out the parameters, documentation, and any default values. -System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); - -// We may set parameters using setter methods. -lr.setMaxIter(10) - .setRegParam(0.01); - -// Learn a LogisticRegression model. This uses the parameters stored in lr. -LogisticRegressionModel model1 = lr.fit(training); -// Since model1 is a Model (i.e., a Transformer produced by an Estimator), -// we can view the parameters it used during fit(). -// This prints the parameter (name: value) pairs, where names are unique IDs for this -// LogisticRegression instance. -System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); - -// We may alternatively specify parameters using a ParamMap. -ParamMap paramMap = new ParamMap() - .put(lr.maxIter().w(20)) // Specify 1 Param. - .put(lr.maxIter(), 30) // This overwrites the original maxIter. - .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. - -// One can also combine ParamMaps. -ParamMap paramMap2 = new ParamMap() - .put(lr.probabilityCol().w("myProbability")); // Change output column name -ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); - -// Now learn a new model using the paramMapCombined parameters. -// paramMapCombined overrides all parameters set earlier via lr.set* methods. -LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); -System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); - -// Prepare test documents. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) -), LabeledPoint.class); - -// Make predictions on test documents using the Transformer.transform() method. -// LogisticRegression.transform will only use the 'features' column. -// Note that model2.transform() outputs a 'myProbability' column instead of the usual -// 'probability' column since we renamed the lr.probabilityCol parameter previously. -DataFrame results = model2.transform(test); -for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java %}
    -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.classification import LogisticRegression -from pyspark.ml.param import Param, Params - -# Prepare training data from a list of (label, features) tuples. -training = sqlContext.createDataFrame([ - (1.0, Vectors.dense([0.0, 1.1, 0.1])), - (0.0, Vectors.dense([2.0, 1.0, -1.0])), - (0.0, Vectors.dense([2.0, 1.3, 1.0])), - (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) - -# Create a LogisticRegression instance. This instance is an Estimator. -lr = LogisticRegression(maxIter=10, regParam=0.01) -# Print out the parameters, documentation, and any default values. -print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" - -# Learn a LogisticRegression model. This uses the parameters stored in lr. -model1 = lr.fit(training) - -# Since model1 is a Model (i.e., a transformer produced by an Estimator), -# we can view the parameters it used during fit(). -# This prints the parameter (name: value) pairs, where names are unique IDs for this -# LogisticRegression instance. -print "Model 1 was fit using parameters: " -print model1.extractParamMap() - -# We may alternatively specify parameters using a Python dictionary as a paramMap -paramMap = {lr.maxIter: 20} -paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. -paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. - -# You can combine paramMaps, which are python dictionaries. -paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name -paramMapCombined = paramMap.copy() -paramMapCombined.update(paramMap2) - -# Now learn a new model using the paramMapCombined parameters. -# paramMapCombined overrides all parameters set earlier via lr.set* methods. -model2 = lr.fit(training, paramMapCombined) -print "Model 2 was fit using parameters: " -print model2.extractParamMap() - -# Prepare test data -test = sqlContext.createDataFrame([ - (1.0, Vectors.dense([-1.0, 1.5, 1.3])), - (0.0, Vectors.dense([3.0, 2.0, -0.1])), - (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) - -# Make predictions on test data using the Transformer.transform() method. -# LogisticRegression.transform will only use the 'features' column. -# Note that model2.transform() outputs a "myProbability" column instead of the usual -# 'probability' column since we renamed the lr.probabilityCol parameter previously. -prediction = model2.transform(test) -selected = prediction.select("features", "label", "myProbability", "prediction") -for row in selected.collect(): - print row - -{% endhighlight %} +{% include_example python/ml/estimator_transformer_param_example.py %}
    @@ -434,182 +233,15 @@ This example follows the simple text document `Pipeline` illustrated in the figu
    -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.Row - -// Prepare training documents from a list of (id, text, label) tuples. -val training = sqlContext.createDataFrame(Seq( - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0) -)).toDF("id", "text", "label") - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") -val hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") -val lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01) -val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - -// Fit the pipeline to training documents. -val model = pipeline.fit(training) - -// Prepare test documents, which are unlabeled (id, text) tuples. -val test = sqlContext.createDataFrame(Seq( - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") -)).toDF("id", "text") - -// Make predictions on test documents. -model.transform(test) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %}
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// Fit the pipeline to training documents. -PipelineModel model = pipeline.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. -DataFrame predictions = model.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %}
    -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import LogisticRegression -from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row - -# Prepare training documents from a list of (id, text, label) tuples. -LabeledDocument = Row("id", "text", "label") -training = sqlContext.createDataFrame([ - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) - -# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. -tokenizer = Tokenizer(inputCol="text", outputCol="words") -hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") -lr = LogisticRegression(maxIter=10, regParam=0.01) -pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) - -# Fit the pipeline to training documents. -model = pipeline.fit(training) - -# Prepare test documents, which are unlabeled (id, text) tuples. -test = sqlContext.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")], ["id", "text"]) - -# Make predictions on test documents and print columns of interest. -prediction = model.transform(test) -selected = prediction.select("id", "text", "prediction") -for row in selected.collect(): - print(row) - -{% endhighlight %} +{% include_example python/ml/pipeline_example.py %}
    @@ -625,8 +257,8 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName` method in each of these evaluators. The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. @@ -644,201 +276,16 @@ However, it is also a well-established method for choosing parameters which is m
    -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.Row - -// Prepare training data from a list of (id, text, label) tuples. -val training = sqlContext.createDataFrame(Seq( - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0), - (4L, "b spark who", 1.0), - (5L, "g d a y", 0.0), - (6L, "spark fly", 1.0), - (7L, "was mapreduce", 0.0), - (8L, "e spark program", 1.0), - (9L, "a e c l", 0.0), - (10L, "spark compile", 1.0), - (11L, "hadoop software", 0.0) -)).toDF("id", "text", "label") - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") -val hashingTF = new HashingTF() - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") -val lr = new LogisticRegression() - .setMaxIter(10) -val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -val paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) - .addGrid(lr.regParam, Array(0.1, 0.01)) - .build() - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -val cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2) // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -val cvModel = cv.fit(training) - -// Prepare test documents, which are unlabeled (id, text) tuples. -val test = sqlContext.createDataFrame(Seq( - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") -)).toDF("id", "text") - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -CrossValidator cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2); // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -CrossValidatorModel cvModel = cv.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -DataFrame predictions = cvModel.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %} +
    + +
    + +{% include_example python/ml/cross_validator.py %}
    @@ -862,97 +309,15 @@ The `ParamMap` which produces the best evaluation metric is selected as the best
    -{% highlight scala %} -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} -import org.apache.spark.mllib.util.MLUtils - -// Prepare training and test data. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() -val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) - -val lr = new LinearRegression() - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -val paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept) - .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) - .build() - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -val trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator) - .setEstimatorParamMaps(paramGrid) - // 80% of the data will be used for training and the remaining 20% for validation. - .setTrainRatio(0.8) - -// Run train validation split, and choose the best set of parameters. -val model = trainValidationSplit.fit(training) - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
    -{% highlight java %} -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -DataFrame data = sqlContext.createDataFrame( - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), - LabeledPoint.class); - -// Prepare training and test data. -DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); -DataFrame training = splits[0]; -DataFrame test = splits[1]; - -LinearRegression lr = new LinearRegression(); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation - -// Run train validation split, and choose the best set of parameters. -TrainValidationSplitModel model = trainValidationSplit.fit(training); - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %} +
    + +
    +{% include_example python/ml/train_validation_split.py %}
    diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 16e2ee71293ae..a8754835cab95 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -1,350 +1,8 @@ --- layout: global -title: Linear Methods - ML -displayTitle: ML - Linear Methods +title: Linear methods - spark.ml +displayTitle: Linear methods - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -In MLlib, we implement popular linear methods such as logistic -regression and linear least squares with $L_1$ or $L_2$ regularization. -Refer to [the linear methods in mllib](mllib-linear-methods.html) for -details. In `spark.ml`, we also include Pipelines API for [Elastic -net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid -of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization -and variable selection via the elastic -net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). -Mathematically, it is defined as a convex combination of the $L_1$ and -the $L_2$ regularization terms: -`\[ -\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 -\]` -By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ -regularization as special cases. For example, if a [linear -regression](https://en.wikipedia.org/wiki/Linear_regression) model is -trained with the elastic net parameter $\alpha$ set to $1$, it is -equivalent to a -[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. -On the other hand, if $\alpha$ is set to $0$, the trained model reduces -to a [ridge -regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. -We implement Pipelines API for both linear regression and logistic -regression with elastic net regularization. - -## Example: Logistic Regression - -The following example shows how to train a logistic regression model -with elastic net regularization. `elasticNetParam` corresponds to -$\alpha$ and `regParam` corresponds to $\lambda$. - -
    - -
    -{% highlight scala %} -import org.apache.spark.ml.classification.LogisticRegression - -// Load training data -val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -val lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8) - -// Fit the model -val lrModel = lr.fit(training) - -// Print the coefficients and intercept for logistic regression -println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") -{% endhighlight %} -
    - -
    -{% highlight java %} -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -public class LogisticRegressionWithElasticNetExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf() - .setAppName("Logistic Regression with Elastic Net Example"); - - SparkContext sc = new SparkContext(conf); - SQLContext sql = new SQLContext(sc); - String path = "data/mllib/sample_libsvm_data.txt"; - - // Load training data - DataFrame training = sqlContext.read.format("libsvm").load(path); - - LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8); - - // Fit the model - LogisticRegressionModel lrModel = lr.fit(training); - - // Print the coefficients and intercept for logistic regression - System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); - } -} -{% endhighlight %} -
    - -
    -{% highlight python %} -from pyspark.ml.classification import LogisticRegression - -# Load training data -training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) - -# Fit the model -lrModel = lr.fit(training) - -# Print the coefficients and intercept for logistic regression -print("Coefficients: " + str(lrModel.coefficients)) -print("Intercept: " + str(lrModel.intercept)) -{% endhighlight %} -
    - -
    - -The `spark.ml` implementation of logistic regression also supports -extracting a summary of the model over the training set. Note that the -predictions and metrics which are stored as `Dataframe` in -`BinaryLogisticRegressionSummary` are annotated `@transient` and hence -only available on the driver. - -
    - -
    - -[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) -provides a summary for a -[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). -This will likely change when multiclass classification is supported. - -Continuing the earlier example: - -{% highlight scala %} -import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary - -// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example -val trainingSummary = lrModel.summary - -// Obtain the objective per iteration. -val objectiveHistory = trainingSummary.objectiveHistory -objectiveHistory.foreach(loss => println(loss)) - -// Obtain the metrics useful to judge performance on test data. -// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a -// binary classification problem. -val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] - -// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. -val roc = binarySummary.roc -roc.show() -println(binarySummary.areaUnderROC) - -// Set the model threshold to maximize F-Measure -val fMeasure = binarySummary.fMeasureByThreshold -val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) -val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). - select("threshold").head().getDouble(0) -lrModel.setThreshold(bestThreshold) -{% endhighlight %} -
    - -
    -[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) -provides a summary for a -[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -This will likely change when multiclass classification is supported. - -Continuing the earlier example: - -{% highlight java %} -import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; -import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; -import org.apache.spark.sql.functions; - -// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example -LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); - -// Obtain the loss per iteration. -double[] objectiveHistory = trainingSummary.objectiveHistory(); -for (double lossPerIteration : objectiveHistory) { - System.out.println(lossPerIteration); -} - -// Obtain the metrics useful to judge performance on test data. -// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a -// binary classification problem. -BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary; - -// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. -DataFrame roc = binarySummary.roc(); -roc.show(); -roc.select("FPR").show(); -System.out.println(binarySummary.areaUnderROC()); - -// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with -// this selected threshold. -DataFrame fMeasure = binarySummary.fMeasureByThreshold(); -double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); -double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). - select("threshold").head().getDouble(0); -lrModel.setThreshold(bestThreshold); -{% endhighlight %} -
    - - -
    -Logistic regression model summary is not yet supported in Python. -
    - -
    - -## Example: Linear Regression - -The interface for working with linear regression models and model -summaries is similar to the logistic regression case. The following -example demonstrates training an elastic net regularized linear -regression model and extracting model summary statistics. - -
    - -
    -{% highlight scala %} -import org.apache.spark.ml.regression.LinearRegression - -// Load training data -val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -val lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8) - -// Fit the model -val lrModel = lr.fit(training) - -// Print the coefficients and intercept for linear regression -println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") - -// Summarize the model over the training set and print out some metrics -val trainingSummary = lrModel.summary -println(s"numIterations: ${trainingSummary.totalIterations}") -println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}") -trainingSummary.residuals.show() -println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") -println(s"r2: ${trainingSummary.r2}") -{% endhighlight %} -
    - -
    -{% highlight java %} -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.regression.LinearRegressionModel; -import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -public class LinearRegressionWithElasticNetExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf() - .setAppName("Linear Regression with Elastic Net Example"); - - SparkContext sc = new SparkContext(conf); - SQLContext sql = new SQLContext(sc); - String path = "data/mllib/sample_libsvm_data.txt"; - - // Load training data - DataFrame training = sqlContext.read.format("libsvm").load(path); - - LinearRegression lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(0.3) - .setElasticNetParam(0.8); - - // Fit the model - LinearRegressionModel lrModel = lr.fit(training); - - // Print the coefficients and intercept for linear regression - System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); - - // Summarize the model over the training set and print out some metrics - LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); - System.out.println("numIterations: " + trainingSummary.totalIterations()); - System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); - trainingSummary.residuals().show(); - System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError()); - System.out.println("r2: " + trainingSummary.r2()); - } -} -{% endhighlight %} -
    - -
    - -{% highlight python %} -from pyspark.ml.regression import LinearRegression - -# Load training data -training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - -lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) - -# Fit the model -lrModel = lr.fit(training) - -# Print the coefficients and intercept for linear regression -print("Coefficients: " + str(lrModel.coefficients)) -print("Intercept: " + str(lrModel.intercept)) - -# Linear regression model summary is not yet supported in Python. -{% endhighlight %} -
    - -
    - -# Optimization - -The optimization algorithm underlying the implementation is called -[Orthant-Wise Limited-memory -QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) -(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 -regularization and elastic net. - + > This section has been moved into the + [classification and regression section](ml-classification-regression.html). diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md new file mode 100644 index 0000000000000..856ceb2f4e7f6 --- /dev/null +++ b/docs/ml-survival-regression.md @@ -0,0 +1,8 @@ +--- +layout: global +title: Survival Regression - spark.ml +displayTitle: Survival Regression - spark.ml +--- + + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#survival-regression). diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 0210950b89906..aaf8bd465c9ab 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -1,10 +1,10 @@ --- layout: global -title: Classification and Regression - MLlib -displayTitle: MLlib - Classification and Regression +title: Classification and Regression - spark.mllib +displayTitle: Classification and Regression - spark.mllib --- -MLlib supports various methods for +The `spark.mllib` package supports various methods for [binary classification](http://en.wikipedia.org/wiki/Binary_classification), [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification), and diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 8fbced6c87d9f..6897ba4a5d57d 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -1,7 +1,7 @@ --- layout: global -title: Clustering - MLlib -displayTitle: MLlib - Clustering +title: Clustering - spark.mllib +displayTitle: Clustering - spark.mllib --- [Clustering](https://en.wikipedia.org/wiki/Cluster_analysis) is an unsupervised learning problem whereby we aim to group subsets @@ -10,19 +10,19 @@ often used for exploratory analysis and/or as a component of a hierarchical [supervised learning](https://en.wikipedia.org/wiki/Supervised_learning) pipeline (in which distinct classifiers or regression models are trained for each cluster). -MLlib supports the following models: +The `spark.mllib` package supports the following models: * Table of contents {:toc} ## K-means -[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +[K-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the most commonly used clustering algorithms that clusters the data points into a -predefined number of clusters. The MLlib implementation includes a parallelized +predefined number of clusters. The `spark.mllib` implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -The implementation in MLlib has the following parameters: +The implementation in `spark.mllib` has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -49,27 +49,7 @@ optimal *k* is usually one where there is an "elbow" in the WSSSE graph. Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`KMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeansModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.clustering.{KMeans, KMeansModel} -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() - -// Cluster the data into two classes using KMeans -val numClusters = 2 -val numIterations = 20 -val clusters = KMeans.train(parsedData, numClusters, numIterations) - -// Evaluate clustering by computing Within Set Sum of Squared Errors -val WSSSE = clusters.computeCost(parsedData) -println("Within Set Sum of Squared Errors = " + WSSSE) - -// Save and load model -clusters.save(sc, "myModelPath") -val sameModel = KMeansModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/KMeansExample.scala %}
    @@ -81,51 +61,7 @@ that is equivalent to the provided example in Scala is given below: Refer to the [`KMeans` Java docs](api/java/org/apache/spark/mllib/clustering/KMeans.html) and [`KMeansModel` Java docs](api/java/org/apache/spark/mllib/clustering/KMeansModel.html) for details on the API. -{% highlight java %} -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.clustering.KMeans; -import org.apache.spark.mllib.clustering.KMeansModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.SparkConf; - -public class KMeansExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("K-means Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse data - String path = "data/mllib/kmeans_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) - values[i] = Double.parseDouble(sarray[i]); - return Vectors.dense(values); - } - } - ); - parsedData.cache(); - - // Cluster the data into two classes using KMeans - int numClusters = 2; - int numIterations = 20; - KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations); - - // Evaluate clustering by computing Within Set Sum of Squared Errors - double WSSSE = clusters.computeCost(parsedData.rdd()); - System.out.println("Within Set Sum of Squared Errors = " + WSSSE); - - // Save and load model - clusters.save(sc.sc(), "myModelPath"); - KMeansModel sameModel = KMeansModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaKMeansExample.java %}
    @@ -138,31 +74,7 @@ fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph Refer to the [`KMeans` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.KMeans) and [`KMeansModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.KMeansModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.clustering import KMeans, KMeansModel -from numpy import array -from math import sqrt - -# Load and parse the data -data = sc.textFile("data/mllib/kmeans_data.txt") -parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')])) - -# Build the model (cluster the data) -clusters = KMeans.train(parsedData, 2, maxIterations=10, - runs=10, initializationMode="random") - -# Evaluate clustering by computing Within Set Sum of Squared Errors -def error(point): - center = clusters.centers[clusters.predict(point)] - return sqrt(sum([x**2 for x in (point - center)])) - -WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y) -print("Within Set Sum of Squared Error = " + str(WSSSE)) - -# Save and load model -clusters.save(sc, "myModelPath") -sameModel = KMeansModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/k_means_example.py %}
    @@ -171,7 +83,7 @@ sameModel = KMeansModel.load(sc, "myModelPath") A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, -each with its own probability. The MLlib implementation uses the +each with its own probability. The `spark.mllib` implementation uses the [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) algorithm to induce the maximum-likelihood model given a set of samples. The implementation has the following parameters: @@ -192,29 +104,7 @@ to the algorithm. We then output the parameters of the mixture model. Refer to the [`GaussianMixture` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture) and [`GaussianMixtureModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixtureModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.clustering.GaussianMixture -import org.apache.spark.mllib.clustering.GaussianMixtureModel -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/gmm_data.txt") -val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))).cache() - -// Cluster the data into two classes using GaussianMixture -val gmm = new GaussianMixture().setK(2).run(parsedData) - -// Save and load model -gmm.save(sc, "myGMMModel") -val sameModel = GaussianMixtureModel.load(sc, "myGMMModel") - -// output parameters of max-likelihood model -for (i <- 0 until gmm.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format - (gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma)) -} - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala %}
    @@ -226,50 +116,7 @@ that is equivalent to the provided example in Scala is given below: Refer to the [`GaussianMixture` Java docs](api/java/org/apache/spark/mllib/clustering/GaussianMixture.html) and [`GaussianMixtureModel` Java docs](api/java/org/apache/spark/mllib/clustering/GaussianMixtureModel.html) for details on the API. -{% highlight java %} -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.clustering.GaussianMixture; -import org.apache.spark.mllib.clustering.GaussianMixtureModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.SparkConf; - -public class GaussianMixtureExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("GaussianMixture Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse data - String path = "data/mllib/gmm_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.trim().split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) - values[i] = Double.parseDouble(sarray[i]); - return Vectors.dense(values); - } - } - ); - parsedData.cache(); - - // Cluster the data into two classes using GaussianMixture - GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); - - // Save and load GaussianMixtureModel - gmm.save(sc.sc(), "myGMMModel"); - GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel"); - // Output the parameters of the mixture model - for(int j=0; j
    @@ -280,23 +127,7 @@ to the algorithm. We then output the parameters of the mixture model. Refer to the [`GaussianMixture` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixture) and [`GaussianMixtureModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixtureModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.clustering import GaussianMixture -from numpy import array - -# Load and parse the data -data = sc.textFile("data/mllib/gmm_data.txt") -parsedData = data.map(lambda line: array([float(x) for x in line.strip().split(' ')])) - -# Build the model (cluster the data) -gmm = GaussianMixture.train(parsedData, 2) - -# output parameters of model -for i in range(2): - print ("weight = ", gmm.weights[i], "mu = ", gmm.gaussians[i].mu, - "sigma = ", gmm.gaussians[i].sigma.toArray()) - -{% endhighlight %} +{% include_example python/mllib/gaussian_mixture_example.py %}
    @@ -304,17 +135,17 @@ for i in range(2): ## Power iteration clustering (PIC) Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering vertices of a -graph given pairwise similarties as edge properties, +graph given pairwise similarities as edge properties, described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via [power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. -MLlib includes an implementation of PIC using GraphX as its backend. +`spark.mllib` includes an implementation of PIC using GraphX as its backend. It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments. The similarities must be nonnegative. PIC assumes that the similarity measure is symmetric. A pair `(srcId, dstId)` regardless of the ordering should appear at most once in the input data. If a pair is missing from input, their similarity is treated as zero. -MLlib's PIC implementation takes the following (hyper-)parameters: +`spark.mllib`'s PIC implementation takes the following (hyper-)parameters: * `k`: number of clusters * `maxIterations`: maximum number of power iterations @@ -323,7 +154,7 @@ MLlib's PIC implementation takes the following (hyper-)parameters: **Examples** -In the following, we show code snippets to demonstrate how to use PIC in MLlib. +In the following, we show code snippets to demonstrate how to use PIC in `spark.mllib`.
    @@ -338,31 +169,7 @@ which contains the computed clustering assignments. Refer to the [`PowerIterationClustering` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClustering) and [`PowerIterationClusteringModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClusteringModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel} -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/pic_data.txt") -val similarities = data.map { line => - val parts = line.split(' ') - (parts(0).toLong, parts(1).toLong, parts(2).toDouble) -} - -// Cluster the data into two classes using PowerIterationClustering -val pic = new PowerIterationClustering() - .setK(2) - .setMaxIterations(10) -val model = pic.run(similarities) - -model.assignments.foreach { a => - println(s"${a.id} -> ${a.cluster}") -} - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala %} A full example that produces the experiment described in the PIC paper can be found under [`examples/`](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala). @@ -381,40 +188,7 @@ which contains the computed clustering assignments. Refer to the [`PowerIterationClustering` Java docs](api/java/org/apache/spark/mllib/clustering/PowerIterationClustering.html) and [`PowerIterationClusteringModel` Java docs](api/java/org/apache/spark/mllib/clustering/PowerIterationClusteringModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; -import scala.Tuple3; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.clustering.PowerIterationClustering; -import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; - -// Load and parse the data -JavaRDD data = sc.textFile("data/mllib/pic_data.txt"); -JavaRDD> similarities = data.map( - new Function>() { - public Tuple3 call(String line) { - String[] parts = line.split(" "); - return new Tuple3<>(new Long(parts[0]), new Long(parts[1]), new Double(parts[2])); - } - } -); - -// Cluster the data into two classes using PowerIterationClustering -PowerIterationClustering pic = new PowerIterationClustering() - .setK(2) - .setMaxIterations(10); -PowerIterationClusteringModel model = pic.run(similarities); - -for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { - System.out.println(a.id() + " -> " + a.cluster()); -} - -// Save and load model -model.save(sc.sc(), "myModelPath"); -PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java %}
    @@ -429,23 +203,7 @@ which contains the computed clustering assignments. Refer to the [`PowerIterationClustering` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering) and [`PowerIterationClusteringModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClusteringModel) for more details on the API. -{% highlight python %} -from __future__ import print_function -from pyspark.mllib.clustering import PowerIterationClustering, PowerIterationClusteringModel - -# Load and parse the data -data = sc.textFile("data/mllib/pic_data.txt") -similarities = data.map(lambda line: tuple([float(x) for x in line.split(' ')])) - -# Cluster the data into two classes using PowerIterationClustering -model = PowerIterationClustering.train(similarities, 2, 10) - -model.assignments().foreach(lambda x: print(str(x.id) + " -> " + str(x.cluster))) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/power_iteration_clustering_example.py %}
    @@ -493,7 +251,7 @@ checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. -All of MLlib's LDA models support: +All of `spark.mllib`'s LDA models support: * `describeTopics`: Returns topics as arrays of most important terms and term weights @@ -591,137 +349,68 @@ to the algorithm. We then output the topics, represented as probability distribu
    Refer to the [`LDA` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) and [`DistributedLDAModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.DistributedLDAModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel} -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/sample_lda_data.txt") -val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) -// Index documents with unique IDs -val corpus = parsedData.zipWithIndex.map(_.swap).cache() - -// Cluster the documents into three topics using LDA -val ldaModel = new LDA().setK(3).run(corpus) - -// Output topics. Each is a distribution over words (matching word count vectors) -println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") -val topics = ldaModel.topicsMatrix -for (topic <- Range(0, 3)) { - print("Topic " + topic + ":") - for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } - println() -} - -// Save and load model. -ldaModel.save(sc, "myLDAModel") -val sameModel = DistributedLDAModel.load(sc, "myLDAModel") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala %}
    Refer to the [`LDA` Java docs](api/java/org/apache/spark/mllib/clustering/LDA.html) and [`DistributedLDAModel` Java docs](api/java/org/apache/spark/mllib/clustering/DistributedLDAModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.clustering.DistributedLDAModel; -import org.apache.spark.mllib.clustering.LDA; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.SparkConf; - -public class JavaLDAExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("LDA Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_lda_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.trim().split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) - values[i] = Double.parseDouble(sarray[i]); - return Vectors.dense(values); - } - } - ); - // Index documents with unique IDs - JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 doc_id) { - return doc_id.swap(); - } - } - )); - corpus.cache(); - - // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); - - // Output topics. Each is a distribution over words (matching word count vectors) - System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() - + " words):"); - Matrix topics = ldaModel.topicsMatrix(); - for (int topic = 0; topic < 3; topic++) { - System.out.print("Topic " + topic + ":"); - for (int word = 0; word < ldaModel.vocabSize(); word++) { - System.out.print(" " + topics.apply(word, topic)); - } - System.out.println(); - } - - ldaModel.save(sc.sc(), "myLDAModel"); - DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLatentDirichletAllocationExample.java %}
    Refer to the [`LDA` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.LDA) and [`LDAModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.LDAModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.clustering import LDA, LDAModel -from pyspark.mllib.linalg import Vectors - -# Load and parse the data -data = sc.textFile("data/mllib/sample_lda_data.txt") -parsedData = data.map(lambda line: Vectors.dense([float(x) for x in line.strip().split(' ')])) -# Index documents with unique IDs -corpus = parsedData.zipWithIndex().map(lambda x: [x[1], x[0]]).cache() - -# Cluster the documents into three topics using LDA -ldaModel = LDA.train(corpus, k=3) - -# Output topics. Each is a distribution over words (matching word count vectors) -print("Learned topics (as distributions over vocab of " + str(ldaModel.vocabSize()) + " words):") -topics = ldaModel.topicsMatrix() -for topic in range(3): - print("Topic " + str(topic) + ":") - for word in range(0, ldaModel.vocabSize()): - print(" " + str(topics[word][topic])) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = LDAModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/latent_dirichlet_allocation_example.py %}
    +## Bisecting k-means + +Bisecting K-means can often be much faster than regular K-means, but it will generally produce a different clustering. + +Bisecting k-means is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering). +Hierarchical clustering is one of the most commonly used method of cluster analysis which seeks to build a hierarchy of clusters. +Strategies for hierarchical clustering generally fall into two types: + +- Agglomerative: This is a "bottom up" approach: each observation starts in its own cluster, and pairs of clusters are merged as one moves up the hierarchy. +- Divisive: This is a "top down" approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy. + +Bisecting k-means algorithm is a kind of divisive algorithms. +The implementation in MLlib has the following parameters: + +* *k*: the desired number of leaf clusters (default: 4). The actual number could be smaller if there are no divisible leaf clusters. +* *maxIterations*: the max number of k-means iterations to split clusters (default: 20) +* *minDivisibleClusterSize*: the minimum number of points (if >= 1.0) or the minimum proportion of points (if < 1.0) of a divisible cluster (default: 1) +* *seed*: a random seed (default: hash value of the class name) + +**Examples** + +
    +
    +Refer to the [`BisectingKMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.BisectingKMeans) and [`BisectingKMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.BisectingKMeansModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala %} +
    + +
    +Refer to the [`BisectingKMeans` Java docs](api/java/org/apache/spark/mllib/clustering/BisectingKMeans.html) and [`BisectingKMeansModel` Java docs](api/java/org/apache/spark/mllib/clustering/BisectingKMeansModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java %} +
    + +
    +Refer to the [`BisectingKMeans` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.BisectingKMeans) and [`BisectingKMeansModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.BisectingKMeansModel) for more details on the API. + +{% include_example python/mllib/bisecting_k_means_example.py %} +
    +
    + ## Streaming k-means When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, +updating them as new data arrive. `spark.mllib` provides support for streaming k-means clustering, with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: @@ -754,96 +443,16 @@ This example shows how to estimate clusters on streaming data.
    Refer to the [`StreamingKMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.StreamingKMeans) for details on the API. +And Refer to [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for details on StreamingContext. -First we import the neccessary classes. - -{% highlight scala %} - -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.clustering.StreamingKMeans - -{% endhighlight %} - -Then we make an input stream of vectors for training, as well as a stream of labeled data -points for testing. We assume a StreamingContext `ssc` has been created, see -[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. - -{% highlight scala %} - -val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) -val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) - -{% endhighlight %} - -We create a model with random clusters and specify the number of clusters to find - -{% highlight scala %} - -val numDimensions = 3 -val numClusters = 2 -val model = new StreamingKMeans() - .setK(numClusters) - .setDecayFactor(1.0) - .setRandomCenters(numDimensions, 0.0) - -{% endhighlight %} - -Now register the streams for training and testing and start the job, printing -the predicted cluster assignments on new data points as they arrive. - -{% highlight scala %} - -model.trainOn(trainingData) -model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() - -ssc.start() -ssc.awaitTermination() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala %}
    Refer to the [`StreamingKMeans` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.clustering.StreamingKMeans) for more details on the API. +And Refer to [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for details on StreamingContext. -First we import the neccessary classes. - -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.clustering import StreamingKMeans -{% endhighlight %} - -Then we make an input stream of vectors for training, as well as a stream of labeled data -points for testing. We assume a StreamingContext `ssc` has been created, see -[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. - -{% highlight python %} -def parse(lp): - label = float(lp[lp.find('(') + 1: lp.find(',')]) - vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) - return LabeledPoint(label, vec) - -trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) -testData = ssc.textFileStream("/testing/data/dir").map(parse) -{% endhighlight %} - -We create a model with random clusters and specify the number of clusters to find - -{% highlight python %} -model = StreamingKMeans(k=2, decayFactor=1.0).setRandomCenters(3, 1.0, 0) -{% endhighlight %} - -Now register the streams for training and testing and start the job, printing -the predicted cluster assignments on new data points as they arrive. - -{% highlight python %} -model.trainOn(trainingData) -print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) - -ssc.start() -ssc.awaitTermination() -{% endhighlight %} +{% include_example python/mllib/streaming_k_means_example.py %}
    diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 1ad52123c74aa..5c33292aaf086 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -1,7 +1,7 @@ --- layout: global -title: Collaborative Filtering - MLlib -displayTitle: MLlib - Collaborative Filtering +title: Collaborative Filtering - spark.mllib +displayTitle: Collaborative Filtering - spark.mllib --- * Table of contents @@ -11,17 +11,18 @@ displayTitle: MLlib - Collaborative Filtering [Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) is commonly used for recommender systems. These techniques aim to fill in the -missing entries of a user-item association matrix. MLlib currently supports +missing entries of a user-item association matrix. `spark.mllib` currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -MLlib uses the [alternating least squares +`spark.mllib` uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) -algorithm to learn these latent factors. The implementation in MLlib has the +algorithm to learn these latent factors. The implementation in `spark.mllib` has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). * *rank* is the number of latent factors in the model. -* *iterations* is the number of iterations to run. +* *iterations* is the number of iterations of ALS to run. ALS typically converges to a reasonable + solution in 20 iterations or less. * *lambda* specifies the regularization parameter in ALS. * *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for *implicit feedback* data. @@ -31,17 +32,18 @@ following parameters: ### Explicit vs. implicit feedback The standard approach to matrix factorization based collaborative filtering treats -the entries in the user-item matrix as *explicit* preferences given by the user to the item. +the entries in the user-item matrix as *explicit* preferences given by the user to the item, +for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, -clicks, purchases, likes, shares etc.). The approach used in MLlib to deal with such data is taken -from -[Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). -Essentially instead of trying to model the matrix of ratings directly, this approach treats the data -as a combination of binary preferences and *confidence values*. The ratings are then related to the -level of confidence in observed user preferences, rather than explicit ratings given to items. The -model then tries to find latent factors that can be used to predict the expected preference of a -user for an item. +clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken +from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). +Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data +as numbers representing the *strength* in observations of user actions (such as the number of clicks, +or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of +confidence in observed user preferences, rather than explicit ratings given to items. The model +then tries to find latent factors that can be used to predict the expected preference of a user for +an item. ### Scaling of the regularization parameter @@ -50,9 +52,8 @@ the number of ratings the user generated in updating user factors, or the number of ratings the product received in updating product factors. This approach is named "ALS-WR" and discussed in the paper "[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)". -It makes `lambda` less dependent on the scale of the dataset. -So we can apply the best parameter learned from a sampled subset to the full dataset -and expect similar performance. +It makes `lambda` less dependent on the scale of the dataset, so we can apply the +best parameter learned from a sampled subset to the full dataset and expect similar performance. ## Examples @@ -64,47 +65,11 @@ We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.rec method which assumes ratings are explicit. We evaluate the recommendation model by measuring the Mean Squared Error of rating prediction. -Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for details on the API. +Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.recommendation.ALS -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel -import org.apache.spark.mllib.recommendation.Rating - -// Load and parse the data -val data = sc.textFile("data/mllib/als/test.data") -val ratings = data.map(_.split(',') match { case Array(user, item, rate) => - Rating(user.toInt, item.toInt, rate.toDouble) - }) - -// Build the recommendation model using ALS -val rank = 10 -val numIterations = 10 -val model = ALS.train(ratings, rank, numIterations, 0.01) - -// Evaluate the model on rating data -val usersProducts = ratings.map { case Rating(user, product, rate) => - (user, product) -} -val predictions = - model.predict(usersProducts).map { case Rating(user, product, rate) => - ((user, product), rate) - } -val ratesAndPreds = ratings.map { case Rating(user, product, rate) => - ((user, product), rate) -}.join(predictions) -val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => - val err = (r1 - r2) - err * err -}.mean() -println("Mean Squared Error = " + MSE) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RecommendationExample.scala %} -If the rating matrix is derived from another source of information (e.g., it is inferred from +If the rating matrix is derived from another source of information (i.e. it is inferred from other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} @@ -121,83 +86,9 @@ Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given below: -Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API. - -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.mllib.recommendation.Rating; -import org.apache.spark.SparkConf; - -public class CollaborativeFiltering { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Collaborative Filtering Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/als/test.data"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String s) { - String[] sarray = s.split(","); - return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), - Double.parseDouble(sarray[2])); - } - } - ); - - // Build the recommendation model using ALS - int rank = 10; - int numIterations = 10; - MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); - - // Evaluate the model on rating data - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - Double err = pair._1() - pair._2(); - return err * err; - } - } - ).rdd()).mean(); - System.out.println("Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for more details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaRecommendationExample.java %}
    @@ -207,31 +98,9 @@ recommendation by measuring the Mean Squared Error of rating prediction. Refer to the [`ALS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating - -# Load and parse the data -data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) - -# Build the recommendation model using Alternating Least Squares -rank = 10 -numIterations = 10 -model = ALS.train(ratings, rank, numIterations) - -# Evaluate the model on training data -testdata = ratings.map(lambda p: (p[0], p[1])) -predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) -ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() -print("Mean Squared Error = " + str(MSE)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/recommendation_example.py %} -If the rating matrix is derived from other source of information (i.e., it is inferred from other +If the rating matrix is derived from other source of information (i.e. it is inferred from other signals), you can use the trainImplicit method to get better results. {% highlight python %} @@ -251,4 +120,4 @@ a dependency. ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for -[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). +[personalized movie recommendation with `spark.mllib`](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 3c0c0479674df..5e3ee472a72c3 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -1,7 +1,7 @@ --- layout: global title: Data Types - MLlib -displayTitle: MLlib - Data Types +displayTitle: Data Types - MLlib --- * Table of contents @@ -11,7 +11,7 @@ MLlib supports local vectors and matrices stored on a single machine, as well as distributed matrices backed by one or more RDDs. Local vectors and local matrices are simple data models that serve as public interfaces. The underlying linear algebra operations are provided by -[Breeze](http://www.scalanlp.org/) and [jblas](http://jblas.org/). +[Breeze](http://www.scalanlp.org/). A training example used in supervised learning is called a "labeled point" in MLlib. ## Local vector diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index f31c4f88936bd..9af48357b3dfc 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Trees - MLlib -displayTitle: MLlib - Decision Trees +title: Decision Trees - spark.mllib +displayTitle: Decision Trees - spark.mllib --- * Table of contents @@ -15,7 +15,7 @@ feature scaling, and are able to capture non-linearities and feature interaction algorithms such as random forests and boosting are among the top performers for classification and regression tasks. -MLlib supports decision trees for binary and multiclass classification and for regression, +`spark.mllib` supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, allowing distributed training with millions of instances. @@ -121,12 +121,12 @@ The parameters are listed below roughly in order of descending importance. New These parameters describe the problem you want to solve and your dataset. They should be specified and do not require tuning. -* **`algo`**: `Classification` or `Regression` +* **`algo`**: Type of decision tree, either `Classification` or `Regression`. -* **`numClasses`**: Number of classes (for `Classification` only) +* **`numClasses`**: Number of classes (for `Classification` only). * **`categoricalFeaturesInfo`**: Specifies which features are categorical and how many categorical values each of those features can take. This is given as a map from feature indices to feature arity (number of categories). Any features not in this map are treated as continuous. - * E.g., `Map(0 -> 2, 4 -> 10)` specifies that feature `0` is binary (taking values `0` or `1`) and that feature `4` has 10 categories (values `{0, 1, ..., 9}`). Note that feature indices are 0-based: features `0` and `4` are the 1st and 5th elements of an instance's feature vector. + * For example, `Map(0 -> 2, 4 -> 10)` specifies that feature `0` is binary (taking values `0` or `1`) and that feature `4` has 10 categories (values `{0, 1, ..., 9}`). Note that feature indices are 0-based: features `0` and `4` are the 1st and 5th elements of an instance's feature vector. * Note that you do not have to specify `categoricalFeaturesInfo`. The algorithm will still run and may get reasonable results. However, performance should be better if categorical features are properly designated. ### Stopping criteria @@ -194,137 +194,19 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
    Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "gini" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala %}
    Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Integer numClasses = 2; -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "gini"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model for classification. -final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java %}
    Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, - impurity='gini', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/decision_tree_classification_example.py %}
    @@ -343,142 +225,19 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
    Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "variance" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, - maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala %}
    Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "variance"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model. -final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java %}
    Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, - impurity='variance', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/decision_tree_regression_example.py %}
    diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index ac3526908a9f4..cceddce9f79a6 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -1,7 +1,7 @@ --- layout: global -title: Dimensionality Reduction - MLlib -displayTitle: MLlib - Dimensionality Reduction +title: Dimensionality Reduction - spark.mllib +displayTitle: Dimensionality Reduction - spark.mllib --- * Table of contents @@ -11,7 +11,7 @@ displayTitle: MLlib - Dimensionality Reduction of reducing the number of variables under consideration. It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -MLlib provides support for dimensionality reduction on the RowMatrix class. +`spark.mllib` provides support for dimensionality reduction on the RowMatrix class. ## Singular value decomposition (SVD) @@ -57,26 +57,14 @@ passes, $O(n)$ storage on each executor, and $O(n k)$ storage on the driver. ### SVD Example -MLlib provides SVD functionality to row-oriented matrices, provided in the +`spark.mllib` provides SVD functionality to row-oriented matrices, provided in the RowMatrix class.
    Refer to the [`SingularValueDecomposition` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.SingularValueDecomposition) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.SingularValueDecomposition - -val mat: RowMatrix = ... - -// Compute the top 20 singular values and corresponding singular vectors. -val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(20, computeU = true) -val U: RowMatrix = svd.U // The U factor is a RowMatrix. -val s: Vector = svd.s // The singular values are stored in a local dense vector. -val V: Matrix = svd.V // The V factor is a local dense matrix. -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SVDExample.scala %} The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. @@ -84,43 +72,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an
    Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/mllib/linalg/SingularValueDecomposition.html) for details on the API. -{% highlight java %} -import java.util.LinkedList; - -import org.apache.spark.api.java.*; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.SingularValueDecomposition; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.rdd.RDD; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class SVD { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVD Example"); - SparkContext sc = new SparkContext(conf); - - double[][] array = ... - LinkedList rowsList = new LinkedList(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); - - // Create a RowMatrix from JavaRDD. - RowMatrix mat = new RowMatrix(rows.rdd()); - - // Compute the top 4 singular values and corresponding singular vectors. - SingularValueDecomposition svd = mat.computeSVD(4, true, 1.0E-9d); - RowMatrix U = svd.U(); - Vector s = svd.s(); - Matrix V = svd.V(); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaSVDExample.java %} The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. @@ -141,7 +93,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. +`spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors.
    @@ -151,36 +103,14 @@ and use them to project the vectors into a low-dimensional space. Refer to the [`RowMatrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix - -val mat: RowMatrix = ... - -// Compute the top 10 principal components. -val pc: Matrix = mat.computePrincipalComponents(10) // Principal components are stored in a local dense matrix. - -// Project the rows to the linear space spanned by the top 10 principal components. -val projected: RowMatrix = mat.multiply(pc) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala %} The following code demonstrates how to compute principal components on source vectors and use them to project the vectors into a low-dimensional space while keeping associated labels: Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.PCA) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.feature.PCA - -val data: RDD[LabeledPoint] = ... - -// Compute the top 10 principal components. -val pca = new PCA(10).fit(data.map(_.features)) - -// Project vectors to the linear space spanned by the top 10 principal components, keeping the label -val projected = data.map(p => p.copy(features = pca.transform(p.features))) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala %}
    @@ -192,40 +122,7 @@ The number of columns should be small, e.g, less than 1000. Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. -{% highlight java %} -import java.util.LinkedList; - -import org.apache.spark.api.java.*; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.rdd.RDD; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class PCA { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("PCA Example"); - SparkContext sc = new SparkContext(conf); - - double[][] array = ... - LinkedList rowsList = new LinkedList(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); - - // Create a RowMatrix from JavaRDD. - RowMatrix mat = new RowMatrix(rows.rdd()); - - // Compute the top 3 principal components. - Matrix pc = mat.computePrincipalComponents(3); - RowMatrix projected = mat.multiply(pc); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %}
    diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index fc587298f7d2e..2416b6fa0aeb3 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -1,7 +1,7 @@ --- layout: global -title: Ensembles - MLlib -displayTitle: MLlib - Ensembles +title: Ensembles - spark.mllib +displayTitle: Ensembles - spark.mllib --- * Table of contents @@ -9,7 +9,7 @@ displayTitle: MLlib - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +`spark.mllib` supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests @@ -33,9 +33,9 @@ Like decision trees, random forests handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports random forests for binary and multiclass classification and for regression, +`spark.mllib` supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -MLlib implements random forests using the existing [decision tree](mllib-decision-tree.html) +`spark.mllib` implements random forests using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. ### Basic algorithm @@ -98,144 +98,19 @@ The test error is calculated to measure the algorithm accuracy.
    Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.RandomForest -import org.apache.spark.mllib.tree.model.RandomForestModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val numTrees = 3 // Use more in practice. -val featureSubsetStrategy = "auto" // Let the algorithm choose. -val impurity = "gini" -val maxDepth = 4 -val maxBins = 32 - -val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, - numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification forest model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala %}
    Refer to the [`RandomForest` Java docs](api/java/org/apache/spark/mllib/tree/RandomForest.html) and [`RandomForestModel` Java docs](api/java/org/apache/spark/mllib/tree/model/RandomForestModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassification"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Integer numClasses = 2; -HashMap categoricalFeaturesInfo = new HashMap(); -Integer numTrees = 3; // Use more in practice. -String featureSubsetStrategy = "auto"; // Let the algorithm choose. -String impurity = "gini"; -Integer maxDepth = 5; -Integer maxBins = 32; -Integer seed = 12345; - -final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification forest model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java %}
    Refer to the [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForest) and [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForestModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.tree import RandomForest, RandomForestModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -# Note: Use larger numTrees in practice. -# Setting featureSubsetStrategy="auto" lets the algorithm choose. -model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='gini', maxDepth=4, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification forest model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/random_forest_classification_example.py %}
    @@ -254,147 +129,19 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.RandomForest -import org.apache.spark.mllib.tree.model.RandomForestModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val numTrees = 3 // Use more in practice. -val featureSubsetStrategy = "auto" // Let the algorithm choose. -val impurity = "variance" -val maxDepth = 4 -val maxBins = 32 - -val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo, - numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression forest model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala %}
    Refer to the [`RandomForest` Java docs](api/java/org/apache/spark/mllib/tree/RandomForest.html) and [`RandomForestModel` Java docs](api/java/org/apache/spark/mllib/tree/model/RandomForestModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForest"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "variance"; -Integer maxDepth = 4; -Integer maxBins = 32; - -// Train a RandomForest model. -final RandomForestModel model = RandomForest.trainRegressor(trainingData, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression forest model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java %}
    Refer to the [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForest) and [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForestModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.tree import RandomForest, RandomForestModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -# Note: Use larger numTrees in practice. -# Setting featureSubsetStrategy="auto" lets the algorithm choose. -model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='variance', maxDepth=4, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression forest model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/random_forest_regression_example.py %}
    @@ -408,9 +155,9 @@ Like decision trees, GBTs handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports GBTs for binary classification and for regression, +`spark.mllib` supports GBTs for binary classification and for regression, using both continuous and categorical features. -MLlib implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. +`spark.mllib` implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. *Note*: GBTs do not yet support multiclass classification. For multiclass problems, please use [decision trees](mllib-decision-tree.html) or [Random Forests](mllib-ensembles.html#Random-Forest). @@ -424,7 +171,7 @@ The specific mechanism for re-labeling instances is defined by a loss function ( #### Losses -The table below lists the losses currently supported by GBTs in MLlib. +The table below lists the losses currently supported by GBTs in `spark.mllib`. Note that each loss is applicable to one of classification or regression, not both. Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = features of instance $i$. $F(x_i)$ = model's predicted label for instance $i$. @@ -492,141 +239,19 @@ The test error is calculated to measure the algorithm accuracy.
    Refer to the [`GradientBoostedTrees` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.GradientBoostedTreesModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a GradientBoostedTrees model. -// The defaultParams for Classification use LogLoss by default. -val boostingStrategy = BoostingStrategy.defaultParams("Classification") -boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.numClasses = 2 -boostingStrategy.treeStrategy.maxDepth = 5 -// Empty categoricalFeaturesInfo indicates all features are continuous. -boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() - -val model = GradientBoostedTrees.train(trainingData, boostingStrategy) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification GBT model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala %}
    Refer to the [`GradientBoostedTrees` Java docs](api/java/org/apache/spark/mllib/tree/GradientBoostedTrees.html) and [`GradientBoostedTreesModel` Java docs](api/java/org/apache/spark/mllib/tree/model/GradientBoostedTreesModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import java.util.Map; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a GradientBoostedTrees model. -// The defaultParams for Classification use LogLoss by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); -boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. -boostingStrategy.getTreeStrategy().setNumClassesForClassification(2); -boostingStrategy.getTreeStrategy().setMaxDepth(5); -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); - -final GradientBoostedTreesModel model = - GradientBoostedTrees.train(trainingData, boostingStrategy); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification GBT model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java %}
    Refer to the [`GradientBoostedTrees` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTreesModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GradientBoostedTrees model. -# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. -# (b) Use more iterations in practice. -model = GradientBoostedTrees.trainClassifier(trainingData, - categoricalFeaturesInfo={}, numIterations=3) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification GBT model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/gradient_boosting_classification_example.py %}
    @@ -645,146 +270,19 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    Refer to the [`GradientBoostedTrees` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.GradientBoostedTreesModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a GradientBoostedTrees model. -// The defaultParams for Regression use SquaredError by default. -val boostingStrategy = BoostingStrategy.defaultParams("Regression") -boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.maxDepth = 5 -// Empty categoricalFeaturesInfo indicates all features are continuous. -boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() - -val model = GradientBoostedTrees.train(trainingData, boostingStrategy) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression GBT model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala %}
    Refer to the [`GradientBoostedTrees` Java docs](api/java/org/apache/spark/mllib/tree/GradientBoostedTrees.html) and [`GradientBoostedTreesModel` Java docs](api/java/org/apache/spark/mllib/tree/model/GradientBoostedTreesModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import java.util.Map; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a GradientBoostedTrees model. -// The defaultParams for Regression use SquaredError by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); -boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. -boostingStrategy.getTreeStrategy().setMaxDepth(5); -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); - -final GradientBoostedTreesModel model = - GradientBoostedTrees.train(trainingData, boostingStrategy); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression GBT model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java %}
    Refer to the [`GradientBoostedTrees` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTreesModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GradientBoostedTrees model. -# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. -# (b) Use more iterations in practice. -model = GradientBoostedTrees.trainRegressor(trainingData, - categoricalFeaturesInfo={}, numIterations=3) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression GBT model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/gradient_boosting_regression_example.py %}
    diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index f73eff637dc36..a269dbf030e7c 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -1,20 +1,20 @@ --- layout: global -title: Evaluation Metrics - MLlib -displayTitle: MLlib - Evaluation Metrics +title: Evaluation Metrics - spark.mllib +displayTitle: Evaluation Metrics - spark.mllib --- * Table of contents {:toc} -Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +`spark.mllib` comes with a number of machine learning algorithms that can be used to learn from and make predictions on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance -of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +of the model on some criteria, which depends on the application and its requirements. `spark.mllib` also provides a suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those -metrics that are currently available in Spark's MLlib are detailed in this section. +metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation @@ -67,7 +67,7 @@ plots (recall, false positive rate) points. - Precision (Postive Predictive Value) + Precision (Positive Predictive Value) $PPV=\frac{TP}{TP + FP}$ @@ -104,214 +104,21 @@ data, and evaluate the performance of the algorithm by several binary evaluation
    Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`BinaryClassificationMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training) - -// Clear the prediction threshold so the model will return probabilities -model.clearThreshold - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new BinaryClassificationMetrics(predictionAndLabels) - -// Precision by threshold -val precision = metrics.precisionByThreshold -precision.foreach { case (t, p) => - println(s"Threshold: $t, Precision: $p") -} - -// Recall by threshold -val recall = metrics.recallByThreshold -recall.foreach { case (t, r) => - println(s"Threshold: $t, Recall: $r") -} - -// Precision-Recall Curve -val PRC = metrics.pr - -// F-measure -val f1Score = metrics.fMeasureByThreshold -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 1") -} - -val beta = 0.5 -val fScore = metrics.fMeasureByThreshold(beta) -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 0.5") -} - -// AUPRC -val auPRC = metrics.areaUnderPR -println("Area under precision-recall curve = " + auPRC) - -// Compute thresholds used in ROC and PR curves -val thresholds = precision.map(_._1) - -// ROC Curve -val roc = metrics.roc - -// AUROC -val auROC = metrics.areaUnderROC -println("Area under ROC = " + auROC) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala %}
    Refer to the [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) and [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class BinaryClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_binary_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training.rdd()); - - // Clear the prediction threshold so the model will return probabilities - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); - - // Precision by threshold - JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); - System.out.println("Precision by threshold: " + precision.toArray()); - - // Recall by threshold - JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); - System.out.println("Recall by threshold: " + recall.toArray()); - - // F Score by threshold - JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); - System.out.println("F1 Score by threshold: " + f1Score.toArray()); - - JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); - System.out.println("F2 Score by threshold: " + f2Score.toArray()); - - // Precision-recall curve - JavaRDD> prc = metrics.pr().toJavaRDD(); - System.out.println("Precision-recall curve: " + prc.toArray()); - - // Thresholds - JavaRDD thresholds = precision.map( - new Function, Double>() { - public Double call (Tuple2 t) { - return new Double(t._1().toString()); - } - } - ); - - // ROC Curve - JavaRDD> roc = metrics.roc().toJavaRDD(); - System.out.println("ROC curve: " + roc.toArray()); - - // AUPRC - System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); - - // AUROC - System.out.println("Area under ROC = " + metrics.areaUnderROC()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java %}
    Refer to the [`BinaryClassificationMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.BinaryClassificationMetrics) and [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.evaluation import BinaryClassificationMetrics -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils - -# Several of the methods available in scala are currently missing from pyspark - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = BinaryClassificationMetrics(predictionAndLabels) - -# Area under precision-recall curve -print("Area under PR = %s" % metrics.areaUnderPR) - -# Area under ROC curve -print("Area under ROC = %s" % metrics.areaUnderROC) - -{% endhighlight %} - +{% include_example python/mllib/binary_classification_metrics_example.py %}
    @@ -433,204 +240,21 @@ the data, and evaluate the performance of the algorithm by several multiclass cl
    Refer to the [`MulticlassMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training) - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new MulticlassMetrics(predictionAndLabels) - -// Confusion matrix -println("Confusion matrix:") -println(metrics.confusionMatrix) - -// Overall Statistics -val precision = metrics.precision -val recall = metrics.recall // same as true positive rate -val f1Score = metrics.fMeasure -println("Summary Statistics") -println(s"Precision = $precision") -println(s"Recall = $recall") -println(s"F1 Score = $f1Score") - -// Precision by label -val labels = metrics.labels -labels.foreach { l => - println(s"Precision($l) = " + metrics.precision(l)) -} - -// Recall by label -labels.foreach { l => - println(s"Recall($l) = " + metrics.recall(l)) -} - -// False positive rate by label -labels.foreach { l => - println(s"FPR($l) = " + metrics.falsePositiveRate(l)) -} - -// F-measure by label -labels.foreach { l => - println(s"F1-Score($l) = " + metrics.fMeasure(l)) -} - -// Weighted stats -println(s"Weighted precision: ${metrics.weightedPrecision}") -println(s"Weighted recall: ${metrics.weightedRecall}") -println(s"Weighted F1 score: ${metrics.weightedFMeasure}") -println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala %}
    Refer to the [`MulticlassMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MulticlassMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MulticlassClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_multiclass_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - - // Confusion matrix - Matrix confusion = metrics.confusionMatrix(); - System.out.println("Confusion matrix: \n" + confusion); - - // Overall statistics - System.out.println("Precision = " + metrics.precision()); - System.out.println("Recall = " + metrics.recall()); - System.out.println("F1 Score = " + metrics.fMeasure()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); - } - - //Weighted stats - System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); - System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); - System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); - System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} + {% include_example java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java %}
    Refer to the [`MulticlassMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MulticlassMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.util import MLUtils -from pyspark.mllib.evaluation import MulticlassMetrics - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training, numClasses=3) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = MulticlassMetrics(predictionAndLabels) - -# Overall statistics -precision = metrics.precision() -recall = metrics.recall() -f1Score = metrics.fMeasure() -print("Summary Stats") -print("Precision = %s" % precision) -print("Recall = %s" % recall) -print("F1 Score = %s" % f1Score) - -# Statistics by class -labels = data.map(lambda lp: lp.label).distinct().collect() -for label in sorted(labels): - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) - -# Weighted stats -print("Weighted recall = %s" % metrics.weightedRecall) -print("Weighted precision = %s" % metrics.weightedPrecision) -print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) -print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) -print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) -{% endhighlight %} +{% include_example python/mllib/multi_class_metrics_example.py %}
    @@ -736,7 +360,7 @@ $$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{ca **Examples** -The following code snippets illustrate how to evaluate the performance of a multilabel classifer. The examples +The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples use the fake prediction and label data for multilabel classification that is shown below. Document predictions: @@ -766,154 +390,21 @@ True classes:
    Refer to the [`MultilabelMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MultilabelMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.MultilabelMetrics -import org.apache.spark.rdd.RDD; - -val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( - Seq((Array(0.0, 1.0), Array(0.0, 2.0)), - (Array(0.0, 2.0), Array(0.0, 1.0)), - (Array(), Array(0.0)), - (Array(2.0), Array(2.0)), - (Array(2.0, 0.0), Array(2.0, 0.0)), - (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), - (Array(1.0), Array(1.0, 2.0))), 2) - -// Instantiate metrics object -val metrics = new MultilabelMetrics(scoreAndLabels) - -// Summary stats -println(s"Recall = ${metrics.recall}") -println(s"Precision = ${metrics.precision}") -println(s"F1 measure = ${metrics.f1Measure}") -println(s"Accuracy = ${metrics.accuracy}") - -// Individual label stats -metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) -metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) -metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) - -// Micro stats -println(s"Micro recall = ${metrics.microRecall}") -println(s"Micro precision = ${metrics.microPrecision}") -println(s"Micro F1 measure = ${metrics.microF1Measure}") - -// Hamming loss -println(s"Hamming loss = ${metrics.hammingLoss}") - -// Subset accuracy -println(s"Subset accuracy = ${metrics.subsetAccuracy}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala %}
    Refer to the [`MultilabelMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MultilabelMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.evaluation.MultilabelMetrics; -import org.apache.spark.SparkConf; -import java.util.Arrays; -import java.util.List; - -public class MultilabelClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - - List> data = Arrays.asList( - new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), - new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{}, new double[]{0.0}), - new Tuple2(new double[]{2.0}, new double[]{2.0}), - new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), - new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) - ); - JavaRDD> scoreAndLabels = sc.parallelize(data); - - // Instantiate metrics object - MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); - - // Summary stats - System.out.format("Recall = %f\n", metrics.recall()); - System.out.format("Precision = %f\n", metrics.precision()); - System.out.format("F1 measure = %f\n", metrics.f1Measure()); - System.out.format("Accuracy = %f\n", metrics.accuracy()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); - } - - // Micro stats - System.out.format("Micro recall = %f\n", metrics.microRecall()); - System.out.format("Micro precision = %f\n", metrics.microPrecision()); - System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); - - // Hamming loss - System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); - - // Subset accuracy - System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); - - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java %}
    Refer to the [`MultilabelMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MultilabelMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.evaluation import MultilabelMetrics - -scoreAndLabels = sc.parallelize([ - ([0.0, 1.0], [0.0, 2.0]), - ([0.0, 2.0], [0.0, 1.0]), - ([], [0.0]), - ([2.0], [2.0]), - ([2.0, 0.0], [2.0, 0.0]), - ([0.0, 1.0, 2.0], [0.0, 1.0]), - ([1.0], [1.0, 2.0])]) - -# Instantiate metrics object -metrics = MultilabelMetrics(scoreAndLabels) - -# Summary stats -print("Recall = %s" % metrics.recall()) -print("Precision = %s" % metrics.precision()) -print("F1 measure = %s" % metrics.f1Measure()) -print("Accuracy = %s" % metrics.accuracy) - -# Individual label stats -labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() -for label in labels: - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) - -# Micro stats -print("Micro precision = %s" % metrics.microPrecision) -print("Micro recall = %s" % metrics.microRecall) -print("Micro F1 measure = %s" % metrics.microF1Measure) - -# Hamming loss -print("Hamming loss = %s" % metrics.hammingLoss) - -# Subset accuracy -print("Subset accuracy = %s" % metrics.subsetAccuracy) - -{% endhighlight %} +{% include_example python/mllib/multi_label_metrics_example.py %}
    @@ -1027,280 +518,21 @@ expanded world of non-positive weights are "the same as never having interacted
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RankingMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} -import org.apache.spark.mllib.recommendation.{ALS, Rating} - -// Read in the ratings data -val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => - val fields = line.split("::") - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) -}.cache() - -// Map ratings to 1 or 0, 1 indicating a movie that should be recommended -val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() - -// Summarize ratings -val numRatings = ratings.count() -val numUsers = ratings.map(_.user).distinct().count() -val numMovies = ratings.map(_.product).distinct().count() -println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") - -// Build the model -val numIterations = 10 -val rank = 10 -val lambda = 0.01 -val model = ALS.train(ratings, rank, numIterations, lambda) - -// Define a function to scale ratings from 0 to 1 -def scaledRating(r: Rating): Rating = { - val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) - Rating(r.user, r.product, scaledRating) -} - -// Get sorted top ten predictions for each user and then scale from [0, 1] -val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => - (user, recs.map(scaledRating)) -} - -// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document -// Compare with top ten most relevant documents -val userMovies = binarizedRatings.groupBy(_.user) -val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => - (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) -} - -// Instantiate metrics object -val metrics = new RankingMetrics(relevantDocuments) - -// Precision at K -Array(1, 3, 5).foreach{ k => - println(s"Precision at $k = ${metrics.precisionAt(k)}") -} - -// Mean average precision -println(s"Mean average precision = ${metrics.meanAveragePrecision}") - -// Normalized discounted cumulative gain -Array(1, 3, 5).foreach{ k => - println(s"NDCG at $k = ${metrics.ndcgAt(k)}") -} - -// Get predictions for each data point -val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) -val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) -val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => - (predicted, actual) -} - -// Get the RMSE using regression metrics -val regressionMetrics = new RegressionMetrics(predictionsAndLabels) -println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${regressionMetrics.r2}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) and [`RankingMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RankingMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; -import java.util.*; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.mllib.evaluation.RankingMetrics; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.Rating; - -// Read in the ratings data -public class Ranking { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - String path = "data/mllib/sample_movielens_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String line) { - String[] parts = line.split("::"); - return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); - } - } - ); - ratings.cache(); - - // Train an ALS model - final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); - - // Get top 10 recommendations for every user and scale ratings from 0 to 1 - JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); - JavaRDD> userRecsScaled = userRecs.map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 t) { - Rating[] scaledRatings = new Rating[t._2().length]; - for (int i = 0; i < scaledRatings.length; i++) { - double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); - scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); - } - return new Tuple2(t._1(), scaledRatings); - } - } - ); - JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); - - // Map ratings to 1 or 0, 1 indicating a movie that should be recommended - JavaRDD binarizedRatings = ratings.map( - new Function() { - public Rating call(Rating r) { - double binaryRating; - if (r.rating() > 0.0) { - binaryRating = 1.0; - } - else { - binaryRating = 0.0; - } - return new Rating(r.user(), r.product(), binaryRating); - } - } - ); - - // Group ratings by common user - JavaPairRDD> userMovies = binarizedRatings.groupBy( - new Function() { - public Object call(Rating r) { - return r.user(); - } - } - ); - - // Get true relevant documents from all user ratings - JavaPairRDD> userMoviesList = userMovies.mapValues( - new Function, List>() { - public List call(Iterable docs) { - List products = new ArrayList(); - for (Rating r : docs) { - if (r.rating() > 0.0) { - products.add(r.product()); - } - } - return products; - } - } - ); - - // Extract the product id from each recommendation - JavaPairRDD> userRecommendedList = userRecommended.mapValues( - new Function>() { - public List call(Rating[] docs) { - List products = new ArrayList(); - for (Rating r : docs) { - products.add(r.product()); - } - return products; - } - } - ); - JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); - - // Instantiate the metrics object - RankingMetrics metrics = RankingMetrics.of(relevantDocs); - - // Precision and NDCG at k - Integer[] kVector = {1, 3, 5}; - for (Integer k : kVector) { - System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); - System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); - } - - // Mean average precision - System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); - - // Evaluate the model using numerical ratings and regression metrics - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - - // Create regression metrics object - RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); - - // Root mean squared error - System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R-squared = %f\n", regressionMetrics.r2()); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, Rating -from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics - -# Read in the ratings data -lines = sc.textFile("data/mllib/sample_movielens_data.txt") - -def parseLine(line): - fields = line.split("::") - return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) -ratings = lines.map(lambda r: parseLine(r)) - -# Train a model on to predict user-product ratings -model = ALS.train(ratings, 10, 10, 0.01) - -# Get predicted ratings on all existing user-product pairs -testData = ratings.map(lambda p: (p.user, p.product)) -predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) - -ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) -scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) - -# Instantiate regression metrics to compare predicted and actual ratings -metrics = RegressionMetrics(scoreAndLabels) - -# Root mean sqaured error -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -{% endhighlight %} +{% include_example python/mllib/ranking_metrics_example.py %}
    @@ -1326,7 +558,7 @@ variable from a number of independent variables. $RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$ - Mean Absoloute Error (MAE) + Mean Absolute Error (MAE) $MAE=\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$ @@ -1350,163 +582,21 @@ and evaluate the performance of the algorithm by several regression metrics.
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load the data -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() - -// Build the model -val numIterations = 100 -val model = LinearRegressionWithSGD.train(data, numIterations) - -// Get predictions -val valuesAndPreds = data.map{ point => - val prediction = model.predict(point.features) - (prediction, point.label) -} - -// Instantiate metrics object -val metrics = new RegressionMetrics(valuesAndPreds) - -// Squared error -println(s"MSE = ${metrics.meanSquaredError}") -println(s"RMSE = ${metrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${metrics.r2}") - -// Mean absolute error -println(s"MAE = ${metrics.meanAbsoluteError}") - -// Explained variance -println(s"Explained variance = ${metrics.explainedVariance}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_linear_regression_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(" "); - double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) - v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - - // Instantiate metrics object - RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); - - // Squared error - System.out.format("MSE = %f\n", metrics.meanSquaredError()); - System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R Squared = %f\n", metrics.r2()); - - // Mean absolute error - System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); - - // Explained variance - System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD -from pyspark.mllib.evaluation import RegressionMetrics -from pyspark.mllib.linalg import DenseVector - -# Load and parse the data -def parsePoint(line): - values = line.split() - return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) - -data = sc.textFile("data/mllib/sample_linear_regression_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData) - -# Get predictions -valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) - -# Instantiate metrics object -metrics = RegressionMetrics(valuesAndPreds) - -# Squared Error -print("MSE = %s" % metrics.meanSquaredError) -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -# Mean absolute error -print("MAE = %s" % metrics.meanAbsoluteError) - -# Explained variance -print("Explained variance = %s" % metrics.explainedVariance) - -{% endhighlight %} +{% include_example python/mllib/regression_metrics_example.py %}
    diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 5bee170c61fe9..7a97285032655 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction and Transformation - MLlib -displayTitle: MLlib - Feature Extraction and Transformation +title: Feature Extraction and Transformation - spark.mllib +displayTitle: Feature Extraction and Transformation - spark.mllib --- * Table of contents @@ -31,7 +31,7 @@ The TF-IDF measure is simply the product of TF and IDF: TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D). \]` There are several variants on the definition of term frequency and document frequency. -In MLlib, we separate TF and IDF to make them flexible. +In `spark.mllib`, we separate TF and IDF to make them flexible. Our implementation of term frequency utilizes the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing). @@ -44,7 +44,7 @@ To reduce the chance of collision, we can increase the target feature dimension, the number of buckets of the hash table. The default feature dimension is `$2^{20} = 1,048,576$`. -**Note:** MLlib doesn't provide tools for text segmentation. +**Note:** `spark.mllib` doesn't provide tools for text segmentation. We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and [scalanlp/chalk](https://github.com/scalanlp/chalk). @@ -58,46 +58,7 @@ Each record could be an iterable of strings or other types. Refer to the [`HashingTF` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.HashingTF) for details on the API. - -{% highlight scala %} -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext -import org.apache.spark.mllib.feature.HashingTF -import org.apache.spark.mllib.linalg.Vector - -val sc: SparkContext = ... - -// Load documents (one per line). -val documents: RDD[Seq[String]] = sc.textFile("...").map(_.split(" ").toSeq) - -val hashingTF = new HashingTF() -val tf: RDD[Vector] = hashingTF.transform(documents) -{% endhighlight %} - -While applying `HashingTF` only needs a single pass to the data, applying `IDF` needs two passes: -first to compute the IDF vector and second to scale the term frequencies by IDF. - -{% highlight scala %} -import org.apache.spark.mllib.feature.IDF - -// ... continue from the previous example -tf.cache() -val idf = new IDF().fit(tf) -val tfidf: RDD[Vector] = idf.transform(tf) -{% endhighlight %} - -MLlib's IDF implementation provides an option for ignoring terms which occur in less than a -minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature -can be used by passing the `minDocFreq` value to the IDF constructor. - -{% highlight scala %} -import org.apache.spark.mllib.feature.IDF - -// ... continue from the previous example -tf.cache() -val idf = new IDF(minDocFreq = 2).fit(tf) -val tfidf: RDD[Vector] = idf.transform(tf) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/TFIDFExample.scala %}
    @@ -109,41 +70,7 @@ Each record could be an iterable of strings or other types. Refer to the [`HashingTF` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.HashingTF) for details on the API. -{% highlight python %} -from pyspark import SparkContext -from pyspark.mllib.feature import HashingTF - -sc = SparkContext() - -# Load documents (one per line). -documents = sc.textFile("...").map(lambda line: line.split(" ")) - -hashingTF = HashingTF() -tf = hashingTF.transform(documents) -{% endhighlight %} - -While applying `HashingTF` only needs a single pass to the data, applying `IDF` needs two passes: -first to compute the IDF vector and second to scale the term frequencies by IDF. - -{% highlight python %} -from pyspark.mllib.feature import IDF - -# ... continue from the previous example -tf.cache() -idf = IDF().fit(tf) -tfidf = idf.transform(tf) -{% endhighlight %} - -MLLib's IDF implementation provides an option for ignoring terms which occur in less than a -minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature -can be used by passing the `minDocFreq` value to the IDF constructor. - -{% highlight python %} -# ... continue from the previous example -tf.cache() -idf = IDF(minDocFreq=2).fit(tf) -tfidf = idf.transform(tf) -{% endhighlight %} +{% include_example python/mllib/tf_idf_example.py %}
    @@ -192,47 +119,12 @@ Here we assume the extracted file is `text8` and in same directory as you run th
    Refer to the [`Word2Vec` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.Word2Vec) for details on the API. -{% highlight scala %} -import org.apache.spark._ -import org.apache.spark.rdd._ -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} - -val input = sc.textFile("text8").map(line => line.split(" ").toSeq) - -val word2vec = new Word2Vec() - -val model = word2vec.fit(input) - -val synonyms = model.findSynonyms("china", 40) - -for((synonym, cosineSimilarity) <- synonyms) { - println(s"$synonym $cosineSimilarity") -} - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = Word2VecModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/Word2VecExample.scala %}
    Refer to the [`Word2Vec` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.Word2Vec) for more details on the API. -{% highlight python %} -from pyspark import SparkContext -from pyspark.mllib.feature import Word2Vec - -sc = SparkContext(appName='Word2Vec') -inp = sc.textFile("text8_lines").map(lambda row: row.split(" ")) - -word2vec = Word2Vec() -model = word2vec.fit(inp) - -synonyms = model.findSynonyms('china', 40) - -for word, cosine_distance in synonyms: - print("{}: {}".format(word, cosine_distance)) -{% endhighlight %} +{% include_example python/mllib/word2vec_example.py %}
    @@ -277,55 +169,13 @@ so that the new features have unit standard deviation and/or zero mean.
    Refer to the [`StandardScaler` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.StandardScaler -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") - -val scaler1 = new StandardScaler().fit(data.map(x => x.features)) -val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features)) -// scaler3 is an identical model to scaler2, and will produce identical transformations -val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean) - -// data1 will be unit variance. -val data1 = data.map(x => (x.label, scaler1.transform(x.features))) - -// Without converting the features into dense vectors, transformation with zero mean will raise -// exception on sparse vector. -// data2 will be unit variance and zero mean. -val data2 = data.map(x => (x.label, scaler2.transform(Vectors.dense(x.features.toArray)))) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/StandardScalerExample.scala %}
    Refer to the [`StandardScaler` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.StandardScaler) for more details on the API. -{% highlight python %} -from pyspark.mllib.util import MLUtils -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.feature import StandardScaler - -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -label = data.map(lambda x: x.label) -features = data.map(lambda x: x.features) - -scaler1 = StandardScaler().fit(features) -scaler2 = StandardScaler(withMean=True, withStd=True).fit(features) -# scaler3 is an identical model to scaler2, and will produce identical transformations -scaler3 = StandardScalerModel(scaler2.std, scaler2.mean) - - -# data1 will be unit variance. -data1 = label.zip(scaler1.transform(features)) - -# Without converting the features into dense vectors, transformation with zero mean will raise -# exception on sparse vector. -# data2 will be unit variance and zero mean. -data2 = label.zip(scaler1.transform(features.map(lambda x: Vectors.dense(x.toArray())))) -{% endhighlight %} +{% include_example python/mllib/standard_scaler_example.py %}
    @@ -355,46 +205,13 @@ with $L^2$ norm, and $L^\infty$ norm.
    Refer to the [`Normalizer` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.Normalizer) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.Normalizer -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") - -val normalizer1 = new Normalizer() -val normalizer2 = new Normalizer(p = Double.PositiveInfinity) - -// Each sample in data1 will be normalized using $L^2$ norm. -val data1 = data.map(x => (x.label, normalizer1.transform(x.features))) - -// Each sample in data2 will be normalized using $L^\infty$ norm. -val data2 = data.map(x => (x.label, normalizer2.transform(x.features))) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/NormalizerExample.scala %}
    Refer to the [`Normalizer` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.Normalizer) for more details on the API. -{% highlight python %} -from pyspark.mllib.util import MLUtils -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.feature import Normalizer - -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -labels = data.map(lambda x: x.label) -features = data.map(lambda x: x.features) - -normalizer1 = Normalizer() -normalizer2 = Normalizer(p=float("inf")) - -# Each sample in data1 will be normalized using $L^2$ norm. -data1 = labels.zip(normalizer1.transform(features)) - -# Each sample in data2 will be normalized using $L^\infty$ norm. -data2 = labels.zip(normalizer2.transform(features)) -{% endhighlight %} +{% include_example python/mllib/normalizer_example.py %}
    @@ -435,29 +252,7 @@ The following example shows the basic use of ChiSqSelector. The data set used ha Refer to the [`ChiSqSelector` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.feature.ChiSqSelector - -// Load some data in libsvm format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Discretize data in 16 equal bins since ChiSqSelector requires categorical features -// Even though features are doubles, the ChiSqSelector treats each unique value as a category -val discretizedData = data.map { lp => - LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) ) -} -// Create ChiSqSelector that will select top 50 of 692 features -val selector = new ChiSqSelector(50) -// Create ChiSqSelector model (selecting features) -val transformer = selector.fit(discretizedData) -// Filter the top 50 features from each feature vector -val filteredData = discretizedData.map { lp => - LabeledPoint(lp.label, transformer.transform(lp.features)) -} -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/ChiSqSelectorExample.scala %}
    @@ -465,52 +260,7 @@ val filteredData = discretizedData.map { lp => Refer to the [`ChiSqSelector` Java docs](api/java/org/apache/spark/mllib/feature/ChiSqSelector.html) for details on the API. -{% highlight java %} -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.feature.ChiSqSelector; -import org.apache.spark.mllib.feature.ChiSqSelectorModel; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaChiSqSelector"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); -JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), - "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); - -// Discretize data in 16 equal bins since ChiSqSelector requires categorical features -// Even though features are doubles, the ChiSqSelector treats each unique value as a category -JavaRDD discretizedData = points.map( - new Function() { - @Override - public LabeledPoint call(LabeledPoint lp) { - final double[] discretizedFeatures = new double[lp.features().size()]; - for (int i = 0; i < lp.features().size(); ++i) { - discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); - } - return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); - } - }); - -// Create ChiSqSelector that will select top 50 of 692 features -ChiSqSelector selector = new ChiSqSelector(50); -// Create ChiSqSelector model (selecting features) -final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); -// Filter the top 50 features from each feature vector -JavaRDD filteredData = discretizedData.map( - new Function() { - @Override - public LabeledPoint call(LabeledPoint lp) { - return new LabeledPoint(lp.label(), transformer.transform(lp.features())); - } - } -); - -sc.stop(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaChiSqSelectorExample.java %}
    @@ -554,78 +304,19 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [`ElementwiseProduct` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct(transformingVector) - -// Batch transform and per-row transform give the same results: -val transformedData = transformer.transform(data) -val transformedData2 = data.map(x => transformer.transform(x)) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/ElementwiseProductExample.scala %}
    Refer to the [`ElementwiseProduct` Java docs](api/java/org/apache/spark/mllib/feature/ElementwiseProduct.html) for details on the API. -{% highlight java %} -import java.util.Arrays; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; - -// Create some vector data; also works for sparse vectors -JavaRDD data = sc.parallelize(Arrays.asList( - Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); - -// Batch transform and per-row transform give the same results: -JavaRDD transformedData = transformer.transform(data); -JavaRDD transformedData2 = data.map( - new Function() { - @Override - public Vector call(Vector v) { - return transformer.transform(v); - } - } -); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaElementwiseProductExample.java %}
    Refer to the [`ElementwiseProduct` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.feature.ElementwiseProduct) for more details on the API. -{% highlight python %} -from pyspark import SparkContext -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.feature import ElementwiseProduct - -# Load and parse the data -sc = SparkContext() -data = sc.textFile("data/mllib/kmeans_data.txt") -parsedData = data.map(lambda x: [float(t) for t in x.split(" ")]) - -# Create weight vector. -transformingVector = Vectors.dense([0.0, 1.0, 2.0]) -transformer = ElementwiseProduct(transformingVector) - -# Batch transform -transformedData = transformer.transform(parsedData) -# Single-row transform -transformedData2 = transformer.transform(parsedData.first()) - -{% endhighlight %} +{% include_example python/mllib/elementwise_product_example.py %}
    @@ -645,44 +336,6 @@ for calculation a [Linear Regression]((mllib-linear-methods.html))
    Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.PCA) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.feature.PCA - -val data = sc.textFile("data/mllib/ridge-data/lpsa.data").map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) -}.cache() - -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0).cache() -val test = splits(1) - -val pca = new PCA(training.first().features.size/2).fit(data.map(_.features)) -val training_pca = training.map(p => p.copy(features = pca.transform(p.features))) -val test_pca = test.map(p => p.copy(features = pca.transform(p.features))) - -val numIterations = 100 -val model = LinearRegressionWithSGD.train(training, numIterations) -val model_pca = LinearRegressionWithSGD.train(training_pca, numIterations) - -val valuesAndPreds = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -val valuesAndPreds_pca = test_pca.map { point => - val score = model_pca.predict(point.features) - (score, point.label) -} - -val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() -val MSE_pca = valuesAndPreds_pca.map{case(v, p) => math.pow((v - p), 2)}.mean() - -println("Mean Squared Error = " + MSE) -println("PCA Mean Squared Error = " + MSE_pca) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PCAExample.scala %}
    diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index fe42896a05d8e..a7b55dc5e5668 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -1,7 +1,7 @@ --- layout: global -title: Frequent Pattern Mining - MLlib -displayTitle: MLlib - Frequent Pattern Mining +title: Frequent Pattern Mining - spark.mllib +displayTitle: Frequent Pattern Mining - spark.mllib --- Mining frequent items, itemsets, subsequences, or other substructures is usually among the @@ -9,7 +9,7 @@ first steps to analyze a large-scale dataset, which has been an active research data mining for years. We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) for more information. -MLlib provides a parallel implementation of FP-growth, +`spark.mllib` provides a parallel implementation of FP-growth, a popular algorithm to mining frequent itemsets. ## FP-growth @@ -22,13 +22,13 @@ Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) al the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. -In MLlib, we implemented a parallel version of FP-growth called PFP, +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffices of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. -MLlib's FP-growth implementation takes the following (hyper-)parameters: +`spark.mllib`'s FP-growth implementation takes the following (hyper-)parameters: * `minSupport`: the minimum support for an itemset to be identified as frequent. For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. @@ -126,7 +126,7 @@ PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. -MLlib's PrefixSpan implementation takes the following parameters: +`spark.mllib`'s PrefixSpan implementation takes the following parameters: * `minSupport`: the minimum support required to be considered a frequent sequential pattern. @@ -135,7 +135,7 @@ MLlib's PrefixSpan implementation takes the following parameters: included in the results. * `maxLocalProjDBSize`: the maximum number of items allowed in a prefix-projected database before local iterative processing of the - projected databse begins. This parameter should be tuned with respect + projected database begins. This parameter should be tuned with respect to the size of your executors. **Examples** diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 91e50ccfecec4..fa5e90603505d 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -34,6 +34,7 @@ We list major functionality from both below, with links to detailed guides. * [correlations](mllib-statistics.html#correlations) * [stratified sampling](mllib-statistics.html#stratified-sampling) * [hypothesis testing](mllib-statistics.html#hypothesis-testing) + * [streaming significance testing](mllib-statistics.html#streaming-significance-testing) * [random data generation](mllib-statistics.html#random-data-generation) * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) @@ -48,6 +49,7 @@ We list major functionality from both below, with links to detailed guides. * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [bisecting k-means](mllib-clustering.html#bisecting-kmeans) * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) * [singular value decomposition (SVD)](mllib-dimensionality-reduction.html#singular-value-decomposition-svd) @@ -65,14 +67,15 @@ We list major functionality from both below, with links to detailed guides. # spark.ml: high-level APIs for ML pipelines -**[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major -concepts. It also contains sections on using algorithms within the Pipelines API, for example: +* [Overview: estimators, transformers and pipelines](ml-guide.html) +* [Extracting, transforming and selecting features](ml-features.html) +* [Classification and regression](ml-classification-regression.html) +* [Clustering](ml-clustering.html) +* [Collaborative filtering](ml-collaborative-filtering.html) +* [Advanced topics](ml-advanced.html) -* [Feature extraction, transformation, and selection](ml-features.html) -* [Decision trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) +Some techniques are not available yet in spark.ml, most notably dimensionality reduction +Users can seamlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. # Dependencies @@ -99,24 +102,32 @@ MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. -## From 1.4 to 1.5 +## From 1.5 to 1.6 -In the `spark.mllib` package, there are no break API changes but several behavior changes: +There are no breaking API changes in the `spark.mllib` or `spark.ml` packages, but there are +deprecations and changes of behavior. -* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): - `RegressionMetrics.explainedVariance` returns the average regression sum of squares. -* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become - sorted. -* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default - convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. +Deprecations: -In the `spark.ml` package, there exists one break API change and one behavior change: +* [SPARK-11358](https://issues.apache.org/jira/browse/SPARK-11358): + In `spark.mllib.clustering.KMeans`, the `runs` parameter has been deprecated. +* [SPARK-10592](https://issues.apache.org/jira/browse/SPARK-10592): + In `spark.ml.classification.LogisticRegressionModel` and + `spark.ml.regression.LinearRegressionModel`, the `weights` field has been deprecated in favor of + the new name `coefficients`. This helps disambiguate from instance (row) "weights" given to + algorithms. -* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed - from `Params.setDefault` due to a - [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). -* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is - added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. +Changes of behavior: + +* [SPARK-7770](https://issues.apache.org/jira/browse/SPARK-7770): + `spark.mllib.tree.GradientBoostedTrees`: `validationTol` has changed semantics in 1.6. + Previously, it was a threshold for absolute change in error. Now, it resembles the behavior of + `GradientDescent`'s `convergenceTol`: For large errors, it uses relative error (relative to the + previous error); for small errors (`< 0.01`), it uses absolute error. +* [SPARK-11069](https://issues.apache.org/jira/browse/SPARK-11069): + `spark.ml.feature.RegexTokenizer`: Previously, it did not convert strings to lowercase before + tokenizing. Now, it converts to lowercase by default, with an option not to. This matches the + behavior of the simpler `Tokenizer` transformer. ## Previous Spark versions diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 85f9226b43416..8ede4407d5843 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -1,7 +1,7 @@ --- layout: global -title: Isotonic regression - MLlib -displayTitle: MLlib - Regression +title: Isotonic regression - spark.mllib +displayTitle: Regression - spark.mllib --- ## Isotonic regression @@ -23,7 +23,7 @@ Essentially isotonic regression is a [monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) best fitting the original data points. -MLlib supports a +`spark.mllib` supports a [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 0c76e6e999465..63665c49bc972 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -1,7 +1,7 @@ --- layout: global -title: Linear Methods - MLlib -displayTitle: MLlib - Linear Methods +title: Linear Methods - spark.mllib +displayTitle: Linear Methods - spark.mllib --- * Table of contents @@ -41,7 +41,7 @@ the objective function is of the form Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several of MLlib's classification and regression algorithms fall into this category, +Several of `spark.mllib`'s classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: @@ -55,7 +55,7 @@ training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions The following table summarizes the loss functions and their gradients or sub-gradients for the -methods MLlib supports: +methods `spark.mllib` supports: @@ -83,7 +83,7 @@ methods MLlib supports: The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to encourage simple models and avoid overfitting. We support the following -regularizers in MLlib: +regularizers in `spark.mllib`:
    @@ -115,27 +115,30 @@ especially when the number of training examples is small. ### Optimization -Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. +Under the hood, linear methods use convex optimization methods to optimize the objective functions. +`spark.mllib` uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). +Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. +Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. ## Classification [Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into categories. The most common classification type is -[binary classificaion](http://en.wikipedia.org/wiki/Binary_classification), where there are two +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), where there are two categories, usually named positive and negative. If there are more than two categories, it is called [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). -MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs) +`spark.mllib` supports two linear methods for classification: linear Support Vector Machines (SVMs) and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. -For both methods, MLlib supports L1 and L2 regularized variants. +For both methods, `spark.mllib` supports L1 and L2 regularized variants. The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either $+1$ (positive) or $-1$ (negative), which is convenient for the formulation. -*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with +*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with multiclass labeling. ### Linear Support Vector Machines (SVMs) @@ -167,52 +170,18 @@ error. Refer to the [`SVMWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.SVMWithSGD) and [`SVMModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.SVMModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0).cache() -val test = splits(1) - -// Run training algorithm to build the model -val numIterations = 100 -val model = SVMWithSGD.train(training, numIterations) - -// Clear the default threshold. -model.clearThreshold() - -// Compute raw scores on the test set. -val scoreAndLabels = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -// Get evaluation metrics. -val metrics = new BinaryClassificationMetrics(scoreAndLabels) -val auROC = metrics.areaUnderROC() - -println("Area under ROC = " + auROC) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = SVMModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala %} The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. {% highlight scala %} + import org.apache.spark.mllib.optimization.L1Updater val svmAlg = new SVMWithSGD() @@ -234,66 +203,12 @@ that is equivalent to the provided example in Scala is given below: Refer to the [`SVMWithSGD` Java docs](api/java/org/apache/spark/mllib/classification/SVMWithSGD.html) and [`SVMModel` Java docs](api/java/org/apache/spark/mllib/classification/SVMModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.*; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; - -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class SVMClassifier { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD training = data.sample(false, 0.6, 11L); - training.cache(); - JavaRDD test = data.subtract(training); - - // Run training algorithm to build the model. - int numIterations = 100; - final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); - - // Clear the default threshold. - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - } - ); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = - new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); - double auROC = metrics.areaUnderROC(); - - System.out.println("Area under ROC = " + auROC); - - // Save and load model - model.save(sc, "myModelPath"); - SVMModel sameModel = SVMModel.load(sc, "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java %} The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -322,30 +237,7 @@ and make predictions with the resulting model to compute the training error. Refer to the [`SVMWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.SVMWithSGD) and [`SVMModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.SVMModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import SVMWithSGD, SVMModel -from pyspark.mllib.regression import LabeledPoint - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/sample_svm_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = SVMWithSGD.train(parsedData, iterations=100) - -# Evaluating the model on training data -labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) -print("Training Error = " + str(trainErr)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = SVMModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/svm_with_sgd_example.py %} @@ -375,7 +267,7 @@ Binary logistic regression can be generalized into train and predict multiclass classification problems. For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the other $K - 1$ outcomes can be separately regressed against the pivot outcome. -In MLlib, the first class $0$ is chosen as the "pivot" class. +In `spark.mllib`, the first class $0$ is chosen as the "pivot" class. See Section 4.4 of [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for references. @@ -403,42 +295,7 @@ Then the model is evaluated against the test dataset and saved to disk. Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`LogisticRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionModel) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel} -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0).cache() -val test = splits(1) - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(10) - .run(training) - -// Compute raw scores on the test set. -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Get evaluation metrics. -val metrics = new MulticlassMetrics(predictionAndLabels) -val precision = metrics.precision -println("Precision = " + precision) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = LogisticRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala %} @@ -451,57 +308,7 @@ Then the model is evaluated against the test dataset and saved to disk. Refer to the [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) and [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MultinomialLogisticRegressionExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(10) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - double precision = metrics.precision(); - System.out.println("Precision = " + precision); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java %}
    @@ -513,30 +320,7 @@ will in the future. Refer to the [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) and [`LogisticRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel -from pyspark.mllib.regression import LabeledPoint - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/sample_svm_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LogisticRegressionWithLBFGS.train(parsedData) - -# Evaluating the model on training data -labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) -print("Training Error = " + str(trainErr)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = LogisticRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/logistic_regression_with_lbfgs_example.py %}
    @@ -572,35 +356,7 @@ values. We compute the mean squared error at the end to evaluate Refer to the [`LinearRegressionWithSGD` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/ridge-data/lpsa.data") -val parsedData = data.map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) -}.cache() - -// Building the model -val numIterations = 100 -val model = LinearRegressionWithSGD.train(parsedData, numIterations) - -// Evaluate model on training examples and compute training error -val valuesAndPreds = parsedData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() -println("training Mean Squared Error = " + MSE) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = LinearRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala %} [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) can be used in a similar fashion as `LinearRegressionWithSGD`. @@ -616,69 +372,7 @@ the Scala snippet provided, is presented below: Refer to the [`LinearRegressionWithSGD` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionWithSGD.html) and [`LinearRegressionModel` Java docs](api/java/org/apache/spark/mllib/regression/LinearRegressionModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/ridge-data/lpsa.data"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(","); - String[] features = parts[1].split(" "); - double[] v = new double[features.length]; - for (int i = 0; i < features.length - 1; i++) - v[i] = Double.parseDouble(features[i]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - double MSE = new JavaDoubleRDD(valuesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - return Math.pow(pair._1() - pair._2(), 2.0); - } - } - ).rdd()).mean(); - System.out.println("training Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java %}
    @@ -691,29 +385,7 @@ Note that the Python API does not yet support model save/load but will in the fu Refer to the [`LinearRegressionWithSGD` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionWithSGD) and [`LinearRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.LinearRegressionModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel - -# Load and parse the data -def parsePoint(line): - values = [float(x) for x in line.replace(',', ' ').split(' ')] - return LabeledPoint(values[0], values[1:]) - -data = sc.textFile("data/mllib/ridge-data/lpsa.data") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData) - -# Evaluate the model on training data -valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) -MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() -print("Mean Squared Error = " + str(MSE)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = LinearRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/linear_regression_with_sgd_example.py %}
    @@ -726,7 +398,7 @@ a dependency. ###Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +updating the parameters of the model as new data arrives. `spark.mllib` currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -743,108 +415,50 @@ online to the first stream, and make predictions on the second stream. First, we import the necessary classes for parsing our input data and creating the model. -{% highlight scala %} - -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD - -{% endhighlight %} - Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. -{% highlight scala %} - -val trainingData = ssc.textFileStream("/training/data/dir").map(LabeledPoint.parse).cache() -val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) - -{% endhighlight %} +We create our model by initializing the weights to zero and register the streams for training and +testing then start the job. Printing predictions alongside true labels lets us easily see the +result. -We create our model by initializing the weights to 0 - -{% highlight scala %} - -val numFeatures = 3 -val model = new StreamingLinearRegressionWithSGD() - .setInitialWeights(Vectors.zeros(numFeatures)) - -{% endhighlight %} - -Now we register the streams for training and testing and start the job. -Printing predictions alongside true labels lets us easily see the result. - -{% highlight scala %} - -model.trainOn(trainingData) -model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() - -ssc.start() -ssc.awaitTermination() - -{% endhighlight %} - -We can now save text files with data to the training or testing folders. +Finally we can save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)` +the model will update. Anytime a text file is placed in `args(1)` you will see predictions. As you feed more data to the training directory, the predictions will get better! +Here is a complete example: +{% include_example scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala %} +
    First, we import the necessary classes for parsing our input data and creating the model. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.regression import StreamingLinearRegressionWithSGD -{% endhighlight %} - Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. -{% highlight python %} -def parse(lp): - label = float(lp[lp.find('(') + 1: lp.find(',')]) - vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) - return LabeledPoint(label, vec) - -trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() -testData = ssc.textFileStream("/testing/data/dir").map(parse) -{% endhighlight %} - -We create our model by initializing the weights to 0 - -{% highlight python %} -numFeatures = 3 -model = StreamingLinearRegressionWithSGD() -model.setInitialWeights([0.0, 0.0, 0.0]) -{% endhighlight %} +We create our model by initializing the weights to 0. Now we register the streams for training and testing and start the job. -{% highlight python %} -model.trainOn(trainingData) -print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) - -ssc.start() -ssc.awaitTermination() -{% endhighlight %} - We can now save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +and `x1,x2,x3` are the features. Anytime a text file is placed in `sys.argv[1]` +the model will update. Anytime a text file is placed in `sys.argv[2]` you will see predictions. As you feed more data to the training directory, the predictions will get better! +Here a complete example: +{% include_example python/mllib/streaming_linear_regression_example.py %} +
    @@ -852,7 +466,7 @@ will get better! # Implementation (developer) -Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent +Behind the scene, `spark.mllib` implements a simple distributed version of stochastic gradient descent (SGD), building on the underlying gradient descent primitive (as described in the optimization section). All provided algorithms take as input a regularization parameter (`regParam`) along with various parameters associated with stochastic diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 774b85d1f773a..f3daef2dbadbe 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -1,12 +1,31 @@ --- layout: global -title: Old Migration Guides - MLlib -displayTitle: MLlib - Old Migration Guides +title: Old Migration Guides - spark.mllib +displayTitle: Old Migration Guides - spark.mllib description: MLlib migration guides from before Spark SPARK_VERSION_SHORT --- The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.4 to 1.5 + +In the `spark.mllib` package, there are no breaking API changes but several behavior changes: + +* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): + `RegressionMetrics.explainedVariance` returns the average regression sum of squares. +* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become + sorted. +* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default + convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. + +In the `spark.ml` package, there exists one breaking API change and one behavior change: + +* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed + from `Params.setDefault` due to a + [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). +* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is + added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. + ## From 1.3 to 1.4 In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 60ac6c7e5bb1a..d0d594af6a4ad 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -1,7 +1,7 @@ --- layout: global -title: Naive Bayes - MLlib -displayTitle: MLlib - Naive Bayes +title: Naive Bayes - spark.mllib +displayTitle: Naive Bayes - spark.mllib --- [Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple @@ -12,7 +12,7 @@ distribution of each feature given label, and then it applies Bayes' theorem to compute the conditional probability distribution of label given an observation and use it for prediction. -MLlib supports [multinomial naive +`spark.mllib` supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index a3bd130ba077c..f90b66f8e2c44 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -1,7 +1,7 @@ --- layout: global -title: Optimization - MLlib -displayTitle: MLlib - Optimization +title: Optimization - spark.mllib +displayTitle: Optimization - spark.mllib --- * Table of contents @@ -87,7 +87,7 @@ in the `$t$`-th iteration, with the input parameter `$s=$ stepSize`. Note that s step-size for SGD methods can often be delicate in practice and is a topic of active research. **Gradients.** -A table of (sub)gradients of the machine learning methods implemented in MLlib, is available in +A table of (sub)gradients of the machine learning methods implemented in `spark.mllib`, is available in the classification and regression section. @@ -140,7 +140,7 @@ other first-order optimization. ### Choosing an Optimization Method -[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in MLlib support both SGD and L-BFGS. +[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in `spark.mllib` support both SGD and L-BFGS. Different optimization methods can have different convergence guarantees depending on the properties of the objective function, and we cannot cover the literature here. In general, when L-BFGS is available, we recommend using it instead of SGD since L-BFGS tends to converge faster (in fewer iterations). @@ -220,154 +220,13 @@ L-BFGS optimizer.
    Refer to the [`LBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS) and [`SquaredL2Updater` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.SquaredL2Updater) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val numFeatures = data.take(1)(0).features.size - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) - -// Append 1 into the training data as intercept. -val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() - -val test = splits(1) - -// Run training algorithm to build the model -val numCorrections = 10 -val convergenceTol = 1e-4 -val maxNumIterations = 20 -val regParam = 0.1 -val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) - -val (weightsWithIntercept, loss) = LBFGS.runLBFGS( - training, - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept) - -val model = new LogisticRegressionModel( - Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), - weightsWithIntercept(weightsWithIntercept.size - 1)) - -// Clear the default threshold. -model.clearThreshold() - -// Compute raw scores on the test set. -val scoreAndLabels = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -// Get evaluation metrics. -val metrics = new BinaryClassificationMetrics(scoreAndLabels) -val auROC = metrics.areaUnderROC() - -println("Loss of each step in training process") -loss.foreach(println) -println("Area under ROC = " + auROC) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LBFGSExample.scala %}
    Refer to the [`LBFGS` Java docs](api/java/org/apache/spark/mllib/optimization/LBFGS.html) and [`SquaredL2Updater` Java docs](api/java/org/apache/spark/mllib/optimization/SquaredL2Updater.html) for details on the API. -{% highlight java %} -import java.util.Arrays; -import java.util.Random; - -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.optimization.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class LBFGSExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - int numFeatures = data.take(1).get(0).features().size(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD trainingInit = data.sample(false, 0.6, 11L); - JavaRDD test = data.subtract(trainingInit); - - // Append 1 into the training data as intercept. - JavaRDD> training = data.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - return new Tuple2(p.label(), MLUtils.appendBias(p.features())); - } - }); - training.cache(); - - // Run training algorithm to build the model. - int numCorrections = 10; - double convergenceTol = 1e-4; - int maxNumIterations = 20; - double regParam = 0.1; - Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); - - Tuple2 result = LBFGS.runLBFGS( - training.rdd(), - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept); - Vector weightsWithIntercept = result._1(); - double[] loss = result._2(); - - final LogisticRegressionModel model = new LogisticRegressionModel( - Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), - (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); - - // Clear the default threshold. - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - }); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = - new BinaryClassificationMetrics(scoreAndLabels.rdd()); - double auROC = metrics.areaUnderROC(); - - System.out.println("Loss of each step in training process"); - for (double l : loss) - System.out.println(l); - System.out.println("Area under ROC = " + auROC); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLBFGSExample.java %}
    diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index 615287125c032..58ed5a0e9d702 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -1,21 +1,21 @@ --- layout: global -title: PMML model export - MLlib -displayTitle: MLlib - PMML model export +title: PMML model export - spark.mllib +displayTitle: PMML model export - spark.mllib --- * Table of contents {:toc} -## MLlib supported models +## `spark.mllib` supported models -MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). +`spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). -The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. +The table below outlines the `spark.mllib` models that can be exported to PMML and their equivalent PMML model.
    - + @@ -45,41 +45,12 @@ The table below outlines the MLlib models that can be exported to PMML and their
    To export a supported `model` (see table above) to PMML, simply call `model.toPMML`. +As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats. + Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`Vectors` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors) for details on the API. Here a complete example of building a KMeansModel and print it out in PMML format: -{% highlight scala %} -import org.apache.spark.mllib.clustering.KMeans -import org.apache.spark.mllib.linalg.Vectors - -// Load and parse the data -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() - -// Cluster the data into two classes using KMeans -val numClusters = 2 -val numIterations = 20 -val clusters = KMeans.train(parsedData, numClusters, numIterations) - -// Export to PMML -println("PMML Model:\n" + clusters.toPMML) -{% endhighlight %} - -As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats: - -{% highlight scala %} -// Export the model to a String in PMML format -clusters.toPMML - -// Export the model to a local file in PMML format -clusters.toPMML("/tmp/kmeans.xml") - -// Export the model to a directory on a distributed file system in PMML format -clusters.toPMML(sc,"/tmp/kmeans") - -// Export the model to the OutputStream in PMML format -clusters.toPMML(System.out) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala %} For unsupported models, either you will not find a `.toPMML` method or an `IllegalArgumentException` will be thrown. diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 2c7c9ed693fd4..02b81f153bf7f 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -1,7 +1,7 @@ --- layout: global -title: Basic Statistics - MLlib -displayTitle: MLlib - Basic Statistics +title: Basic Statistics - spark.mllib +displayTitle: Basic Statistics - spark.mllib --- * Table of contents @@ -10,24 +10,24 @@ displayTitle: MLlib - Basic Statistics `\[ \newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} +\newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} \newcommand{\av}{\mathbf{\alpha}} \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \]` -## Summary statistics +## Summary statistics -We provide column summary statistics for `RDD[Vector]` through the function `colStats` +We provide column summary statistics for `RDD[Vector]` through the function `colStats` available in `Statistics`.
    @@ -40,19 +40,7 @@ total count. Refer to the [`MultivariateStatisticalSummary` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} - -val observations: RDD[Vector] = ... // an RDD of Vectors - -// Compute column summary statistics. -val summary: MultivariateStatisticalSummary = Statistics.colStats(observations) -println(summary.mean) // a dense vector containing the mean value for each column -println(summary.variance) // column-wise variance -println(summary.numNonzeros) // number of nonzeros in each column - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SummaryStatisticsExample.scala %}
    @@ -64,24 +52,7 @@ total count. Refer to the [`MultivariateStatisticalSummary` Java docs](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html) for details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; -import org.apache.spark.mllib.stat.Statistics; - -JavaSparkContext jsc = ... - -JavaRDD mat = ... // an RDD of Vectors - -// Compute column summary statistics. -MultivariateStatisticalSummary summary = Statistics.colStats(mat.rdd()); -System.out.println(summary.mean()); // a dense vector containing the mean value for each column -System.out.println(summary.variance()); // column-wise variance -System.out.println(summary.numNonzeros()); // number of nonzeros in each column - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaSummaryStatisticsExample.java %}
    @@ -92,306 +63,124 @@ total count. Refer to the [`MultivariateStatisticalSummary` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.MultivariateStatisticalSummary) for more details on the API. -{% highlight python %} -from pyspark.mllib.stat import Statistics - -sc = ... # SparkContext - -mat = ... # an RDD of Vectors - -# Compute column summary statistics. -summary = Statistics.colStats(mat) -print(summary.mean()) -print(summary.variance()) -print(summary.numNonzeros()) - -{% endhighlight %} +{% include_example python/mllib/summary_statistics_example.py %}
    ## Correlations -Calculating the correlation between two series of data is a common operation in Statistics. In MLlib -we provide the flexibility to calculate pairwise correlations among many series. The supported +Calculating the correlation between two series of data is a common operation in Statistics. In `spark.mllib` +we provide the flexibility to calculate pairwise correlations among many series. The supported correlation methods are currently Pearson's and Spearman's correlation. - +
    -[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to -calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or +[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to +calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. Refer to the [`Statistics` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.Statistics) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.stat.Statistics - -val sc: SparkContext = ... - -val seriesX: RDD[Double] = ... // a series -val seriesY: RDD[Double] = ... // must have the same number of partitions and cardinality as seriesX - -// compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a -// method is not specified, Pearson's method will be used by default. -val correlation: Double = Statistics.corr(seriesX, seriesY, "pearson") - -val data: RDD[Vector] = ... // note that each Vector is a row and not a column - -// calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. -// If a method is not specified, Pearson's method will be used by default. -val correlMatrix: Matrix = Statistics.corr(data, "pearson") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/CorrelationsExample.scala %}
    -[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to -calculate correlations between series. Depending on the type of input, two `JavaDoubleRDD`s or +[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to +calculate correlations between series. Depending on the type of input, two `JavaDoubleRDD`s or a `JavaRDD`, the output will be a `Double` or the correlation `Matrix` respectively. Refer to the [`Statistics` Java docs](api/java/org/apache/spark/mllib/stat/Statistics.html) for details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.*; -import org.apache.spark.mllib.stat.Statistics; - -JavaSparkContext jsc = ... - -JavaDoubleRDD seriesX = ... // a series -JavaDoubleRDD seriesY = ... // must have the same number of partitions and cardinality as seriesX - -// compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a -// method is not specified, Pearson's method will be used by default. -Double correlation = Statistics.corr(seriesX.srdd(), seriesY.srdd(), "pearson"); - -JavaRDD data = ... // note that each Vector is a row and not a column - -// calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. -// If a method is not specified, Pearson's method will be used by default. -Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson"); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaCorrelationsExample.java %}
    -[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to -calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or +[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to +calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively. Refer to the [`Statistics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) for more details on the API. -{% highlight python %} -from pyspark.mllib.stat import Statistics - -sc = ... # SparkContext - -seriesX = ... # a series -seriesY = ... # must have the same number of partitions and cardinality as seriesX - -# Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a -# method is not specified, Pearson's method will be used by default. -print(Statistics.corr(seriesX, seriesY, method="pearson")) - -data = ... # an RDD of Vectors -# calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. -# If a method is not specified, Pearson's method will be used by default. -print(Statistics.corr(data, method="pearson")) - -{% endhighlight %} +{% include_example python/mllib/correlations_example.py %}
    ## Stratified sampling -Unlike the other statistics functions, which reside in MLlib, stratified sampling methods, +Unlike the other statistics functions, which reside in `spark.mllib`, stratified sampling methods, `sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified -sampling, the keys can be thought of as a label and the value as a specific attribute. For example -the key can be man or woman, or document ids, and the respective values can be the list of ages -of the people in the population or the list of words in the documents. The `sampleByKey` method -will flip a coin to decide whether an observation will be sampled or not, therefore requires one -pass over the data, and provides an *expected* sample size. `sampleByKeyExact` requires significant +sampling, the keys can be thought of as a label and the value as a specific attribute. For example +the key can be man or woman, or document ids, and the respective values can be the list of ages +of the people in the population or the list of words in the documents. The `sampleByKey` method +will flip a coin to decide whether an observation will be sampled or not, therefore requires one +pass over the data, and provides an *expected* sample size. `sampleByKeyExact` requires significant more resources than the per-stratum simple random sampling used in `sampleByKey`, but will provide -the exact sampling size with 99.99% confidence. `sampleByKeyExact` is currently not supported in +the exact sampling size with 99.99% confidence. `sampleByKeyExact` is currently not supported in python.
    [`sampleByKeyExact()`](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions) allows users to -sample exactly $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired +sample exactly $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of -keys. Sampling without replacement requires one additional pass over the RDD to guarantee sample +keys. Sampling without replacement requires one additional pass over the RDD to guarantee sample size, whereas sampling with replacement requires two additional passes. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.PairRDDFunctions - -val sc: SparkContext = ... - -val data = ... // an RDD[(K, V)] of any key value pairs -val fractions: Map[K, Double] = ... // specify the exact fraction desired from each key - -// Get an exact sample from each stratum -val approxSample = data.sampleByKey(withReplacement = false, fractions) -val exactSample = data.sampleByKeyExact(withReplacement = false, fractions) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala %}
    [`sampleByKeyExact()`](api/java/org/apache/spark/api/java/JavaPairRDD.html) allows users to -sample exactly $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired +sample exactly $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of -keys. Sampling without replacement requires one additional pass over the RDD to guarantee sample +keys. Sampling without replacement requires one additional pass over the RDD to guarantee sample size, whereas sampling with replacement requires two additional passes. -{% highlight java %} -import java.util.Map; - -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaSparkContext; - -JavaSparkContext jsc = ... - -JavaPairRDD data = ... // an RDD of any key value pairs -Map fractions = ... // specify the exact fraction desired from each key - -// Get an exact sample from each stratum -JavaPairRDD approxSample = data.sampleByKey(false, fractions); -JavaPairRDD exactSample = data.sampleByKeyExact(false, fractions); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java %}
    [`sampleByKey()`](api/python/pyspark.html#pyspark.RDD.sampleByKey) allows users to -sample approximately $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the -desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the +sample approximately $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the +desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the set of keys. *Note:* `sampleByKeyExact()` is currently not supported in Python. -{% highlight python %} - -sc = ... # SparkContext - -data = ... # an RDD of any key value pairs -fractions = ... # specify the exact fraction desired from each key as a dictionary - -approxSample = data.sampleByKey(False, fractions); - -{% endhighlight %} +{% include_example python/mllib/stratified_sampling_example.py %}
    ## Hypothesis testing -Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically -significant, whether this result occurred by chance or not. MLlib currently supports Pearson's +Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically +significant, whether this result occurred by chance or not. `spark.mllib` currently supports Pearson's chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine -whether the goodness of fit or the independence test is conducted. The goodness of fit test requires +whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. -MLlib also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared +`spark.mllib` also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared independence tests.
    -[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to -run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +[`Statistics`](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret hypothesis tests. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.stat.Statistics._ - -val sc: SparkContext = ... - -val vec: Vector = ... // a vector composed of the frequencies of events - -// compute the goodness of fit. If a second vector to test against is not supplied as a parameter, -// the test runs against a uniform distribution. -val goodnessOfFitTestResult = Statistics.chiSqTest(vec) -println(goodnessOfFitTestResult) // summary of the test including the p-value, degrees of freedom, - // test statistic, the method used, and the null hypothesis. - -val mat: Matrix = ... // a contingency matrix - -// conduct Pearson's independence test on the input contingency matrix -val independenceTestResult = Statistics.chiSqTest(mat) -println(independenceTestResult) // summary of the test including the p-value, degrees of freedom... - -val obs: RDD[LabeledPoint] = ... // (feature, label) pairs. - -// The contingency table is constructed from the raw (feature, label) pairs and used to conduct -// the independence test. Returns an array containing the ChiSquaredTestResult for every feature -// against the label. -val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) -var i = 1 -featureTestResults.foreach { result => - println(s"Column $i:\n$result") - i += 1 -} // summary of the test - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala %}
    -[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to -run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +[`Statistics`](api/java/org/apache/spark/mllib/stat/Statistics.html) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret hypothesis tests. Refer to the [`ChiSqTestResult` Java docs](api/java/org/apache/spark/mllib/stat/test/ChiSqTestResult.html) for details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.stat.Statistics; -import org.apache.spark.mllib.stat.test.ChiSqTestResult; - -JavaSparkContext jsc = ... - -Vector vec = ... // a vector composed of the frequencies of events - -// compute the goodness of fit. If a second vector to test against is not supplied as a parameter, -// the test runs against a uniform distribution. -ChiSqTestResult goodnessOfFitTestResult = Statistics.chiSqTest(vec); -// summary of the test including the p-value, degrees of freedom, test statistic, the method used, -// and the null hypothesis. -System.out.println(goodnessOfFitTestResult); - -Matrix mat = ... // a contingency matrix - -// conduct Pearson's independence test on the input contingency matrix -ChiSqTestResult independenceTestResult = Statistics.chiSqTest(mat); -// summary of the test including the p-value, degrees of freedom... -System.out.println(independenceTestResult); - -JavaRDD obs = ... // an RDD of labeled points - -// The contingency table is constructed from the raw (feature, label) pairs and used to conduct -// the independence test. Returns an array containing the ChiSquaredTestResult for every feature -// against the label. -ChiSqTestResult[] featureTestResults = Statistics.chiSqTest(obs.rdd()); -int i = 1; -for (ChiSqTestResult result : featureTestResults) { - System.out.println("Column " + i + ":"); - System.out.println(result); // summary of the test - i++; -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java %}
    @@ -401,50 +190,18 @@ hypothesis tests. Refer to the [`Statistics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) for more details on the API. -{% highlight python %} -from pyspark import SparkContext -from pyspark.mllib.linalg import Vectors, Matrices -from pyspark.mllib.regresssion import LabeledPoint -from pyspark.mllib.stat import Statistics - -sc = SparkContext() - -vec = Vectors.dense(...) # a vector composed of the frequencies of events - -# compute the goodness of fit. If a second vector to test against is not supplied as a parameter, -# the test runs against a uniform distribution. -goodnessOfFitTestResult = Statistics.chiSqTest(vec) -print(goodnessOfFitTestResult) # summary of the test including the p-value, degrees of freedom, - # test statistic, the method used, and the null hypothesis. - -mat = Matrices.dense(...) # a contingency matrix - -# conduct Pearson's independence test on the input contingency matrix -independenceTestResult = Statistics.chiSqTest(mat) -print(independenceTestResult) # summary of the test including the p-value, degrees of freedom... - -obs = sc.parallelize(...) # LabeledPoint(feature, label) . - -# The contingency table is constructed from an RDD of LabeledPoint and used to conduct -# the independence test. Returns an array containing the ChiSquaredTestResult for every feature -# against the label. -featureTestResults = Statistics.chiSqTest(obs) - -for i, result in enumerate(featureTestResults): - print("Column $d:" % (i + 1)) - print(result) -{% endhighlight %} +{% include_example python/mllib/hypothesis_testing_example.py %}
    -Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +Additionally, `spark.mllib` provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test for equality of probability distributions. By providing the name of a theoretical distribution -(currently solely supported for the normal distribution) and its parameters, or a function to +(currently solely supported for the normal distribution) and its parameters, or a function to calculate the cumulative distribution according to a given theoretical distribution, the user can test the null hypothesis that their sample is drawn from that distribution. In the case that the user tests against the normal distribution (`distName="norm"`), but does not provide distribution -parameters, the test initializes to the standard normal distribution and logs an appropriate +parameters, the test initializes to the standard normal distribution and logs an appropriate message.
    @@ -455,21 +212,7 @@ and interpret the hypothesis tests. Refer to the [`Statistics` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.Statistics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.stat.Statistics - -val data: RDD[Double] = ... // an RDD of sample data - -// run a KS test for the sample versus a standard normal distribution -val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) -println(testResult) // summary of the test including the p-value, test statistic, - // and null hypothesis - // if our p-value indicates significance, we can reject the null hypothesis - -// perform a KS test using a cumulative distribution function of our making -val myCDF: Double => Double = ... -val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/HypothesisTestingKolmogorovSmirnovTestExample.scala %}
    @@ -479,23 +222,7 @@ and interpret the hypothesis tests. Refer to the [`Statistics` Java docs](api/java/org/apache/spark/mllib/stat/Statistics.html) for details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaSparkContext; - -import org.apache.spark.mllib.stat.Statistics; -import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; - -JavaSparkContext jsc = ... -JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, ...)); -KolmogorovSmirnovTestResult testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0); -// summary of the test including the p-value, test statistic, -// and null hypothesis -// if our p-value indicates significance, we can reject the null hypothesis -System.out.println(testResult); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaHypothesisTestingKolmogorovSmirnovTestExample.java %}
    @@ -505,19 +232,39 @@ and interpret the hypothesis tests. Refer to the [`Statistics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) for more details on the API. -{% highlight python %} -from pyspark.mllib.stat import Statistics +{% include_example python/mllib/hypothesis_testing_kolmogorov_smirnov_test_example.py %} +
    + -parallelData = sc.parallelize([1.0, 2.0, ... ]) +### Streaming Significance Testing +`spark.mllib` provides online implementations of some tests to support use cases +like A/B testing. These tests may be performed on a Spark Streaming +`DStream[(Boolean,Double)]` where the first element of each tuple +indicates control group (`false`) or treatment group (`true`) and the +second element is the value of an observation. -# run a KS test for the sample versus a standard normal distribution -testResult = Statistics.kolmogorovSmirnovTest(parallelData, "norm", 0, 1) -print(testResult) # summary of the test including the p-value, test statistic, - # and null hypothesis - # if our p-value indicates significance, we can reject the null hypothesis -# Note that the Scala functionality of calling Statistics.kolmogorovSmirnovTest with -# a lambda to calculate the CDF is not made available in the Python API -{% endhighlight %} +Streaming significance testing supports the following parameters: + +* `peacePeriod` - The number of initial data points from the stream to +ignore, used to mitigate novelty effects. +* `windowSize` - The number of past batches to perform hypothesis +testing over. Setting to `0` will perform cumulative processing using +all prior batches. + + +
    +
    +[`StreamingTest`](api/scala/index.html#org.apache.spark.mllib.stat.test.StreamingTest) +provides streaming hypothesis testing. + +{% include_example scala/org/apache/spark/examples/mllib/StreamingTestExample.scala %} +
    + +
    +[`StreamingTest`](api/java/index.html#org.apache.spark.mllib.stat.test.StreamingTest) +provides streaming hypothesis testing. + +{% include_example java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java %}
    @@ -525,7 +272,7 @@ print(testResult) # summary of the test including the p-value, test statistic, ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. -MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +`spark.mllib` supports generating random RDDs with i.i.d. values drawn from a given distribution: uniform, standard normal, or Poisson.
    @@ -594,7 +341,7 @@ sc = ... # SparkContext # Generate a random double RDD that contains 1 million i.i.d. values drawn from the # standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. -u = RandomRDDs.uniformRDD(sc, 1000000L, 10) +u = RandomRDDs.normalRDD(sc, 1000000L, 10) # Apply a transform to get a random double RDD following `N(1, 4)`. v = u.map(lambda x: 1.0 + 2.0 * x) {% endhighlight %} @@ -619,21 +366,7 @@ to do so. Refer to the [`KernelDensity` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.KernelDensity) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.stat.KernelDensity -import org.apache.spark.rdd.RDD - -val data: RDD[Double] = ... // an RDD of sample data - -// Construct the density estimator with the sample data and a standard deviation for the Gaussian -// kernels -val kd = new KernelDensity() - .setSample(data) - .setBandwidth(3.0) - -// Find density estimates for the given values -val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/KernelDensityEstimationExample.scala %}
    @@ -643,21 +376,7 @@ to do so. Refer to the [`KernelDensity` Java docs](api/java/org/apache/spark/mllib/stat/KernelDensity.html) for details on the API. -{% highlight java %} -import org.apache.spark.mllib.stat.KernelDensity; -import org.apache.spark.rdd.RDD; - -RDD data = ... // an RDD of sample data - -// Construct the density estimator with the sample data and a standard deviation for the Gaussian -// kernels -KernelDensity kd = new KernelDensity() - .setSample(data) - .setBandwidth(3.0); - -// Find density estimates for the given values -double[] densities = kd.estimate(new double[] {-1.0, 2.0, 5.0}); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaKernelDensityEstimationExample.java %}
    @@ -667,20 +386,7 @@ to do so. Refer to the [`KernelDensity` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.stat.KernelDensity) for more details on the API. -{% highlight python %} -from pyspark.mllib.stat import KernelDensity - -data = ... # an RDD of sample data - -# Construct the density estimator with the sample data and a standard deviation for the Gaussian -# kernels -kd = KernelDensity() -kd.setSample(data) -kd.setBandwidth(3.0) - -# Find density estimates for the given values -densities = kd.estimate([-1.0, 2.0, 5.0]) -{% endhighlight %} +{% include_example python/mllib/kernel_density_estimation_example.py %}
    diff --git a/docs/monitoring.md b/docs/monitoring.md index cedceb2958023..32d2e02e93eeb 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -8,7 +8,7 @@ There are several ways to monitor Spark applications: web UIs, metrics, and exte # Web Interfaces -Every SparkContext launches a web UI, by default on port 4040, that +Every SparkContext launches a web UI, by default on port 4040, that displays useful information about the application. This includes: * A list of scheduler stages and tasks @@ -32,17 +32,31 @@ Spark's Standalone Mode cluster manager also has its own the course of its lifetime, then the Standalone master's web UI will automatically re-render the application's UI after the application has finished. -If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of a finished +If Spark is run on Mesos or YARN, it is still possible to construct the UI of an application through Spark's history server, provided that the application's event logs exist. You can start the history server by executing: ./sbin/start-history-server.sh -When using the file-system provider class (see spark.history.provider below), the base logging -directory must be supplied in the spark.history.fs.logDirectory configuration option, -and should contain sub-directories that each represents an application's event logs. This creates a -web interface at `http://:18080` by default. The history server can be configured as -follows: +This creates a web interface at `http://:18080` by default, listing incomplete +and completed applications and attempts. + +When using the file-system provider class (see `spark.history.provider` below), the base logging +directory must be supplied in the `spark.history.fs.logDirectory` configuration option, +and should contain sub-directories that each represents an application's event logs. + +The spark jobs themselves must be configured to log events, and to log them to the same shared, +writeable directory. For example, if the server was configured with a log directory of +`hdfs://namenode/shared/spark-logs`, then the client-side options would be: + +``` +spark.eventLog.enabled true +spark.eventLog.dir hdfs://namenode/shared/spark-logs +``` + +The history server can be configured as follows: + +### Environment Variables
    MLlib modelPMML model
    `spark.mllib` modelPMML model
    @@ -69,11 +83,13 @@ follows:
    Environment VariableMeaning
    +### Spark configuration options + - + @@ -82,15 +98,21 @@ follows: @@ -112,10 +134,10 @@ follows: @@ -137,12 +159,12 @@ follows: @@ -156,15 +178,15 @@ follows:
    Property NameDefaultMeaning
    spark.history.providerorg.apache.spark.deploy.history.FsHistoryProviderorg.apache.spark.deploy.history.FsHistoryProvider Name of the class implementing the application history backend. Currently there is only one implementation, provided by Spark, which looks for application logs stored in the file system.spark.history.fs.logDirectory file:/tmp/spark-events - Directory that contains application event logs to be loaded by the history server + For the filesystem history provider, the URL to the directory containing application event + logs to load. This can be a local file:// path, + an HDFS path hdfs://namenode/shared/spark-logs + or that of an alternative filesystem supported by the Hadoop APIs.
    spark.history.fs.update.interval 10s - The period at which information displayed by this history server is updated. - Each update checks for any changes made to the event logs in persisted storage. + The period at which the filesystem history provider checks for new or + updated logs in the log directory. A shorter interval detects new applications faster, + at the expense of more server load re-reading updated applications. + As soon as an update has completed, listings of the completed and incomplete applications + will reflect the changes.
    spark.history.kerberos.enabled false - Indicates whether the history server should use kerberos to login. This is useful - if the history server is accessing HDFS files on a secure Hadoop cluster. If this is + Indicates whether the history server should use kerberos to login. This is required + if the history server is accessing HDFS files on a secure Hadoop cluster. If this is true, it uses the configs spark.history.kerberos.principal and - spark.history.kerberos.keytab. + spark.history.kerberos.keytab.
    false Specifies whether acls should be checked to authorize users viewing the applications. - If enabled, access control checks are made regardless of what the individual application had + If enabled, access control checks are made regardless of what the individual application had set for spark.ui.acls.enable when the application was run. The application owner - will always have authorization to view their own application and any users specified via + will always have authorization to view their own application and any users specified via spark.ui.view.acls when the application was run will also have authorization - to view that application. - If disabled, no access control checks are made. + to view that application. + If disabled, no access control checks are made.
    spark.history.fs.cleaner.interval 1d - How often the job history cleaner checks for files to delete. - Files are only deleted if they are older than spark.history.fs.cleaner.maxAge. + How often the filesystem job history cleaner checks for files to delete. + Files are only deleted if they are older than spark.history.fs.cleaner.maxAge
    spark.history.fs.cleaner.maxAge 7d - Job history files older than this will be deleted when the history cleaner runs. + Job history files older than this will be deleted when the filesystem history cleaner runs.
    @@ -172,7 +194,25 @@ follows: Note that in all of these UIs, the tables are sortable by clicking their headers, making it easy to identify slow tasks, data skew, etc. -Note that the history server only displays completed Spark jobs. One way to signal the completion of a Spark job is to stop the Spark Context explicitly (`sc.stop()`), or in Python using the `with SparkContext() as sc:` to handle the Spark Context setup and tear down, and still show the job history on the UI. +Note + +1. The history server displays both completed and incomplete Spark jobs. If an application makes +multiple attempts after failures, the failed attempts will be displayed, as well as any ongoing +incomplete attempt or the final successful attempt. + +2. Incomplete applications are only updated intermittently. The time between updates is defined +by the interval between checks for changed files (`spark.history.fs.update.interval`). +On larger clusters the update interval may be set to large values. +The way to view a running application is actually to view its own web UI. + +3. Applications which exited without registering themselves as completed will be listed +as incomplete —even though they are no longer running. This can happen if an application +crashes. + +2. One way to signal the completion of a Spark job is to stop the Spark Context +explicitly (`sc.stop()`), or in Python using the `with SparkContext() as sc:` construct +to handle the Spark Context setup and tear down. + ## REST API @@ -249,7 +289,7 @@ These endpoints have been strongly versioned to make it easier to develop applic * New endpoints may be added * New fields may be added to existing endpoints * New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. -* Api versions may be dropped, but only after at least one minor release of co-existing with a new api version +* Api versions may be dropped, but only after at least one minor release of co-existing with a new api version. Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the @@ -258,14 +298,14 @@ keep the paths consistent in both modes. # Metrics -Spark has a configurable metrics system based on the -[Coda Hale Metrics Library](http://metrics.codahale.com/). -This allows users to report Spark metrics to a variety of sinks including HTTP, JMX, and CSV -files. The metrics system is configured via a configuration file that Spark expects to be present -at `$SPARK_HOME/conf/metrics.properties`. A custom file location can be specified via the +Spark has a configurable metrics system based on the +[Coda Hale Metrics Library](http://metrics.codahale.com/). +This allows users to report Spark metrics to a variety of sinks including HTTP, JMX, and CSV +files. The metrics system is configured via a configuration file that Spark expects to be present +at `$SPARK_HOME/conf/metrics.properties`. A custom file location can be specified via the `spark.metrics.conf` [configuration property](configuration.html#spark-properties). -Spark's metrics are decoupled into different -_instances_ corresponding to Spark components. Within each instance, you can configure a +Spark's metrics are decoupled into different +_instances_ corresponding to Spark components. Within each instance, you can configure a set of sinks to which metrics are reported. The following instances are currently supported: * `master`: The Spark standalone master process. @@ -290,26 +330,26 @@ licensing restrictions: * `GangliaSink`: Sends metrics to a Ganglia node or multicast group. To install the `GangliaSink` you'll need to perform a custom build of Spark. _**Note that -by embedding this library you will include [LGPL](http://www.gnu.org/copyleft/lesser.html)-licensed -code in your Spark package**_. For sbt users, set the -`SPARK_GANGLIA_LGPL` environment variable before building. For Maven users, enable +by embedding this library you will include [LGPL](http://www.gnu.org/copyleft/lesser.html)-licensed +code in your Spark package**_. For sbt users, set the +`SPARK_GANGLIA_LGPL` environment variable before building. For Maven users, enable the `-Pspark-ganglia-lgpl` profile. In addition to modifying the cluster's Spark build user applications will need to link to the `spark-ganglia-lgpl` artifact. -The syntax of the metrics configuration file is defined in an example configuration file, +The syntax of the metrics configuration file is defined in an example configuration file, `$SPARK_HOME/conf/metrics.properties.template`. # Advanced Instrumentation Several external tools can be used to help profile the performance of Spark jobs: -* Cluster-wide monitoring tools, such as [Ganglia](http://ganglia.sourceforge.net/), can provide -insight into overall cluster utilization and resource bottlenecks. For instance, a Ganglia -dashboard can quickly reveal whether a particular workload is disk bound, network bound, or +* Cluster-wide monitoring tools, such as [Ganglia](http://ganglia.sourceforge.net/), can provide +insight into overall cluster utilization and resource bottlenecks. For instance, a Ganglia +dashboard can quickly reveal whether a particular workload is disk bound, network bound, or CPU bound. -* OS profiling tools such as [dstat](http://dag.wieers.com/home-made/dstat/), -[iostat](http://linux.die.net/man/1/iostat), and [iotop](http://linux.die.net/man/1/iotop) +* OS profiling tools such as [dstat](http://dag.wieers.com/home-made/dstat/), +[iostat](http://linux.die.net/man/1/iostat), and [iotop](http://linux.die.net/man/1/iotop) can provide fine-grained profiling on individual nodes. -* JVM utilities such as `jstack` for providing stack traces, `jmap` for creating heap-dumps, -`jstat` for reporting time-series statistics and `jconsole` for visually exploring various JVM +* JVM utilities such as `jstack` for providing stack traces, `jmap` for creating heap-dumps, +`jstat` for reporting time-series statistics and `jconsole` for visually exploring various JVM properties are useful for those comfortable with JVM internals. diff --git a/docs/programming-guide.md b/docs/programming-guide.md index f823b89a4b5e9..2f0ed5eca2b2b 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -629,7 +629,7 @@ class MyClass { } {% endhighlight %} -is equilvalent to writing `rdd.map(x => this.field + x)`, which references all of `this`. To avoid this +is equivalent to writing `rdd.map(x => this.field + x)`, which references all of `this`. To avoid this issue, the simplest way is to copy `field` into a local variable instead of accessing it externally: {% highlight scala %} @@ -755,7 +755,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which may behave differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
    @@ -789,9 +789,12 @@ counter = 0 rdd = sc.parallelize(data) # Wrong: Don't do this!! -rdd.foreach(lambda x: counter += x) +def increment_counter(x): + global counter + counter += x +rdd.foreach(increment_counter) -print("Counter value: " + counter) +print("Counter value: ", counter) {% endhighlight %}
    @@ -800,13 +803,13 @@ print("Counter value: " + counter) #### Local vs. cluster modes -The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. +The behavior of the above code is undefined, and may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks, each of which is executed by an executor. Prior to execution, Spark computes the task's **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. -However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on separate worker nodes each have their own copy of the closure. +The variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -What is happening here is that the variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. +In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. -To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#AccumLink). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. +To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. @@ -1091,7 +1094,7 @@ for details. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1174,7 +1177,7 @@ that originally created it. In addition, each persisted RDD can be stored using a different *storage level*, allowing you, for example, to persist the dataset on disk, persist it in memory but as serialized Java objects (to save space), -replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-project.org/). +replicate it across nodes. These levels are set by passing a `StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel), [Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html), @@ -1196,14 +1199,14 @@ storage levels is: partitions that don't fit on disk, and read them from there when they're needed. - MEMORY_ONLY_SER + MEMORY_ONLY_SER
    (Java and Scala) Store RDD as serialized Java objects (one byte array per partition). This is generally more space-efficient than deserialized objects, especially when using a fast serializer, but more CPU-intensive to read. - MEMORY_AND_DISK_SER + MEMORY_AND_DISK_SER
    (Java and Scala) Similar to MEMORY_ONLY_SER, but spill partitions that don't fit in memory to disk instead of recomputing them on the fly each time they're needed. @@ -1215,22 +1218,11 @@ storage levels is: MEMORY_ONLY_2, MEMORY_AND_DISK_2, etc. Same as the levels above, but replicate each partition on two cluster nodes. - - OFF_HEAP (experimental) - Store RDD in serialized format in Tachyon. - Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors - to be smaller and to share a pool of memory, making it attractive in environments with - large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, - the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory - in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts - from memory. If you plan to use Tachyon as the off heap store, Spark is compatible with Tachyon - out-of-the-box. Please refer to this page - for the suggested version pairings. - - -**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, so it does not matter whether you choose a serialized level.* +**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, +so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, +`MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, and `DISK_ONLY_2`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1243,7 +1235,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. (Java and Scala) * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from @@ -1254,11 +1246,6 @@ requests from a web application). *All* the storage levels provide full fault to recomputing lost data, but the replicated ones let you continue running tasks on the RDD without waiting to recompute a lost partition. -* In environments with high amounts of memory or multiple applications, the experimental `OFF_HEAP` -mode has several advantages: - * It allows multiple executors to share the same pool of memory in Tachyon. - * It significantly reduces garbage collection costs. - * Cached data is not lost if individual executors crash. ### Removing Data @@ -1336,9 +1323,9 @@ run on the cluster so that `v` is not shipped to the nodes more than once. In ad `v` should not be modified after it is broadcast in order to ensure that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped to a new node later). -## Accumulators +## Accumulators -Accumulators are variables that are only "added" to through an associative operation and can +Accumulators are variables that are only "added" to through an associative and commutative operation and can therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index ec5a44d79212b..4a0ab623c1082 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -32,7 +32,7 @@ To get started, follow the steps below to install Mesos and deploy Spark jobs vi # Installing Mesos -Spark {{site.SPARK_VERSION}} is designed for use with Mesos {{site.MESOS_VERSION}} and does not +Spark {{site.SPARK_VERSION}} is designed for use with Mesos {{site.MESOS_VERSION}} or newer and does not require any special patches of Mesos. If you already have a Mesos cluster running, you can skip this Mesos installation step. @@ -98,17 +98,17 @@ To host on HDFS, use the Hadoop fs put command: `hadoop fs -put spark-{{site.SPA Or if you are using a custom-compiled version of Spark, you will need to create a package using -the `make-distribution.sh` script included in a Spark source tarball/checkout. +the `dev/make-distribution.sh` script included in a Spark source tarball/checkout. 1. Download and build Spark using the instructions [here](index.html) -2. Create a binary package using `make-distribution.sh --tgz`. +2. Create a binary package using `./dev/make-distribution.sh --tgz`. 3. Upload archive to http/s3/hdfs ## Using a Mesos Master URL The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos -cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper. +cluster, or `mesos://zk://host1:2181,host2:2181,host3:2181/mesos` for a multi-master Mesos cluster using ZooKeeper. ## Client Mode @@ -150,32 +150,45 @@ it does not need to be redundantly passed in as a system property. Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client can find the results of the driver from the Mesos Web UI. -To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, -passing in the Mesos master url (e.g: mesos://host:5050). +To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url -to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. + +The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. +For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy]. + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL +to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the Spark cluster Web UI. -Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves. +For example: +{% highlight bash %} +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master mesos://207.184.161.138:7077 \ + --deploy-mode cluster \ + --supervise \ + --executor-memory 20G \ + --total-executor-cores 100 \ + http://path/to/examples.jar \ + 1000 +{% endhighlight %} -# Mesos Run Modes -Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained". +Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves, as the Spark driver doesn't automatically upload local jars. -In "fine-grained" mode (default), each Spark task runs as a separate Mesos task. This allows -multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, -where each application gets more or fewer machines as it ramps up and down, but it comes with an -additional overhead in launching each task. This mode may be inappropriate for low-latency -requirements like interactive queries or serving web requests. +# Mesos Run Modes -The "coarse-grained" mode will instead launch only *one* long-running Spark task on each Mesos +Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained". + +The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup overhead, but at the cost of reserving the Mesos resources for the complete duration of the application. -To run in coarse-grained mode, set the `spark.mesos.coarse` property in your -[SparkConf](configuration.html#spark-properties): +Coarse-grained is the default mode. You can also set `spark.mesos.coarse` property to true +to turn it on explicitly in [SparkConf](configuration.html#spark-properties): {% highlight scala %} conf.set("spark.mesos.coarse", "true") @@ -186,13 +199,26 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows +multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, +where each application gets more or fewer machines as it ramps up and down, but it comes with an +additional overhead in launching each task. This mode may be inappropriate for low-latency +requirements like interactive queries or serving web requests. + +To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your +[SparkConf](configuration.html#spark-properties): + +{% highlight scala %} +conf.set("spark.mesos.coarse", "false") +{% endhighlight %} + You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} -conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false") +conf.set("spark.mesos.constraints", "os:centos7;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `os:centos7;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. # Mesos Docker Support @@ -220,18 +246,15 @@ In either case, HDFS runs separately from Hadoop MapReduce, without being schedu # Dynamic Resource Allocation with Mesos -Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics -of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down -since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler -can scale back up to the same amount of executors when Spark signals more executors are needed. +Mesos supports dynamic allocation only with coarse-grain mode, which can resize the number of +executors based on statistics of the application. For general information, +see [Dynamic Resource Allocation](job-scheduling.html#dynamic-resource-allocation). -Users that like to utilize this feature should launch the Mesos Shuffle Service that -provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's -termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh -scripts accordingly. +The External Shuffle Service to use is the Mesos Shuffle Service. It provides shuffle data cleanup functionality +on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's +termination. To launch it, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all slave nodes, with `spark.shuffle.service.enabled` set to `true`. -The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos -is to launch the Shuffle Service with Marathon with a unique host constraint. +This can also be achieved through Marathon, using a unique host constraint, and the following command: `bin/spark-class org.apache.spark.deploy.mesos.MesosExternalShuffleService`. # Configuration @@ -243,22 +266,22 @@ See the [configuration page](configuration.html) for information on Spark config Property NameDefaultMeaning spark.mesos.coarse - false + true - If set to true, runs over Mesos clusters in - "coarse-grained" sharing mode, - where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per - Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use - for the whole duration of the Spark job. + If set to true, runs over Mesos clusters in "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task on each machine. + If set to false, runs over Mesos cluster in "fine-grained" sharing mode, where one Mesos task is created per Spark task. + Detailed information in 'Mesos Run Modes'. spark.mesos.extra.cores 0 - Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode. - The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured. - Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + Set the extra number of cores for an executor to advertise. This + does not result in more cores allocated. It instead means that an + executor will "pretend" it has more cores, so that the driver will + send it more tasks. Use this to increase parallelism. This + setting is only used for Mesos coarse-grained mode. @@ -278,7 +301,7 @@ See the [configuration page](configuration.html) for information on Spark config Set the name of the docker image that the Spark executors will run in. The selected image must have Spark installed, as well as a compatible version of the Mesos library. The installed path of Spark in the image can be specified with spark.mesos.executor.home; - the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_LIBRARY. + the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_JAVA_LIBRARY. @@ -326,8 +349,9 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.uris (none) - A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. - This applies to both coarse-grain and fine-grain mode. + A comma-separated list of URIs to be downloaded to the sandbox + when driver or executor is launched by Mesos. This applies to + both coarse-grained and fine-grained mode. @@ -361,11 +385,27 @@ See the [configuration page](configuration.html) for information on Spark config
  • Scalar constraints are matched with "less than equal" semantics i.e. value in the constraint must be less than or equal to the value in the resource offer.
  • Range constraints are matched with "contains" semantics i.e. value in the constraint must be within the resource offer's value.
  • Set constraints are matched with "subset of" semantics i.e. value in the constraint must be a subset of the resource offer's value.
  • -
  • Text constraints are metched with "equality" semantics i.e. value in the constraint must be exactly equal to the resource offer's value.
  • +
  • Text constraints are matched with "equality" semantics i.e. value in the constraint must be exactly equal to the resource offer's value.
  • In case there is no value present as a part of the constraint any offer with the corresponding attribute will be accepted (without value check).
  • + + spark.mesos.driver.webui.url + (none) + + Set the Spark Mesos driver webui_url for interacting with the framework. + If unset it will point to Spark's internal web UI. + + + + spark.mesos.dispatcher.webui.url + (none) + + Set the Spark Mesos dispatcher webui_url for interacting with the framework. + If unset it will point to Spark's internal web UI. + + # Troubleshooting and Debugging diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index db6bfa69ee0fe..09701abdb0574 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -49,8 +49,8 @@ In `cluster` mode, the driver runs on a different machine than the client, so `S $ ./bin/spark-submit --class my.main.Class \ --master yarn \ --deploy-mode cluster \ - --jars my-other-jar.jar,my-other-other-jar.jar - my-main-jar.jar + --jars my-other-jar.jar,my-other-other-jar.jar \ + my-main-jar.jar \ app_arg1 app_arg2 @@ -113,6 +113,19 @@ If you need a reference to the proper location to put log files in the YARN so t Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. + + spark.driver.memory + 1g + + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 1g, 2g). + +
    Note: In client mode, this config must not be set through the SparkConf + directly in your application, because the driver JVM has already started at that point. + Instead, please set this through the --driver-memory command line option + or in your default properties file. + + spark.driver.cores 1 @@ -146,6 +159,13 @@ If you need a reference to the proper location to put log files in the YARN so t HDFS replication level for the files uploaded into HDFS for the application. These include things like the Spark jar, the app jar, and any distributed cache files/archives. + + spark.yarn.stagingDir + Current user's home directory in the filesystem + + Staging directory used while submitting applications. + + spark.yarn.preserve.staging.files false @@ -202,6 +222,20 @@ If you need a reference to the proper location to put log files in the YARN so t Comma-separated list of files to be placed in the working directory of each executor. + + spark.yarn.dist.jars + (none) + + Comma-separated list of jars to be placed in the working directory of each executor. + + + + spark.executor.cores + 1 in YARN mode, all the available cores on the worker in standalone mode. + + The number of cores to use on each executor. For YARN and standalone mode only. + + spark.executor.instances 2 @@ -209,6 +243,13 @@ If you need a reference to the proper location to put log files in the YARN so t The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. + + spark.executor.memory + 1g + + Amount of memory to use per executor process (e.g. 2g, 8g). + + spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 @@ -245,14 +286,25 @@ If you need a reference to the proper location to put log files in the YARN so t - spark.yarn.jar + spark.yarn.jars (none) - The location of the Spark jar file, in case overriding the default location is desired. - By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be + List of libraries containing Spark code to distribute to YARN containers. + By default, Spark on YARN will use Spark jars installed locally, but the Spark jars can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't - need to be distributed each time an application runs. To point to a jar on HDFS, for example, - set this configuration to hdfs:///some/path. + need to be distributed each time an application runs. To point to jars on HDFS, for example, + set this configuration to hdfs:///some/path. Globs are allowed. + + + + spark.yarn.archive + (none) + + An archive containing needed Spark jars for distribution to the YARN cache. If set, this + configuration replaces spark.yarn.jars and the archive is used in all the + application's containers. The archive should contain jar files in its root directory. + Like with the previous option, the archive can also be hosted on HDFS to speed up file + distribution. @@ -260,10 +312,10 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A comma-separated list of secure HDFS namenodes your Spark application is going to access. For - example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032. - The Spark application must have access to the namenodes listed and Kerberos must - be properly configured to be able to access them (either in the same realm or in - a trusted realm). Spark acquires security tokens for each of the namenodes so that + example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032, + webhdfs://nn3.com:50070. The Spark application must have access to the namenodes listed + and Kerberos must be properly configured to be able to access them (either in the same realm + or in a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. @@ -290,7 +342,9 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use spark.driver.extraJavaOptions instead. Note that it is illegal + to set maximum heap size (-Xmx) settings with this option. Maximum heap size settings can be set + with spark.yarn.am.memory @@ -326,6 +380,15 @@ If you need a reference to the proper location to put log files in the YARN so t Otherwise, the client process will exit after submission. + + spark.yarn.am.nodeLabelExpression + (none) + + A YARN node label expression that restricts the set of nodes AM will be scheduled on. + Only versions of YARN greater than or equal to 2.6 support node label expressions, so when + running against earlier versions, this property will be ignored. + + spark.yarn.executor.nodeLabelExpression (none) @@ -349,14 +412,14 @@ If you need a reference to the proper location to put log files in the YARN so t The full path to the file that contains the keytab for the principal specified above. This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. + for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) spark.yarn.principal (none) - Principal to be used to login to KDC, while running on secure HDFS. + Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) diff --git a/docs/security.md b/docs/security.md index 177109415180b..32c33d285747a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -6,15 +6,19 @@ title: Security Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: -* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. +* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. ## Web UI -The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting +and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via the `spark.ui.https.enabled` setting. -Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable` and `spark.modify.acls`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. +### Authentication +A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. + +Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable` and `spark.modify.acls`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the config `spark.admin.acls`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. ## Event Logging @@ -23,8 +27,7 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is -supported for the block transfer service. Encryption is not yet supported for the WebUI. +Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service. Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle files, cached data, and other application files. If encrypting this data is desired, a workaround is @@ -32,8 +35,37 @@ to configure your cluster manager to store application data on encrypted disks. ### SSL Configuration -Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings +which will be used for all the supported communication protocols unless they are overwritten by +protocol-specific settings. This way the user can easily provide the common settings for all the +protocols without disabling the ability to configure each one individually. The common SSL settings +are at `spark.ssl` namespace in Spark configuration. The following table describes the +component-specific configuration namespaces used to override the default settings: + + + + + + + + + + + + + + + + + + + + + + +
    Config NamespaceComponent
    spark.ssl.fsHTTP file server and broadcast server
    spark.ssl.uiSpark application Web UI
    spark.ssl.standaloneStandalone Master / Worker Web UI
    spark.ssl.historyServerHistory Server Web UI
    +The full breakdown of available SSL options can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. ### YARN mode @@ -100,7 +132,7 @@ configure those ports. 7077 Submit job to cluster /
    Join cluster SPARK_MASTER_PORT - Akka-based. Set to "0" to choose a port randomly. Standalone mode only. + Set to "0" to choose a port randomly. Standalone mode only. Standalone Master @@ -108,7 +140,7 @@ configure those ports. (random) Schedule executors SPARK_WORKER_PORT - Akka-based. Set to "0" to choose a port randomly. Standalone mode only. + Set to "0" to choose a port randomly. Standalone mode only. @@ -141,40 +173,7 @@ configure those ports. (random) Connect to application /
    Notify executor state changes spark.driver.port - Akka-based. Set to "0" to choose a port randomly. - - - Driver - Executor - (random) - Schedule tasks - spark.executor.port - Akka-based. Set to "0" to choose a port randomly. - - - Executor - Driver - (random) - File server for files and jars - spark.fileserver.port - Jetty-based - - - Executor - Driver - (random) - HTTP Broadcast - spark.broadcast.port - Jetty-based. Not used by TorrentBroadcast, which sends data through the block manager - instead. - - - Executor - Driver - (random) - Class file server - spark.replClassServer.port - Jetty-based. Only used in Spark shells. + Set to "0" to choose a port randomly. Executor / Driver diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 2fe9ec3542b28..fd94c34d1638d 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -112,8 +112,8 @@ You can optionally configure the cluster further by setting environment variable SPARK_LOCAL_DIRS - Directory to use for "scratch" space in Spark, including map output files and RDDs that get - stored on disk. This should be on a fast, local disk in your system. It can also be a + Directory to use for "scratch" space in Spark, including map output files and RDDs that get + stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. @@ -335,29 +335,14 @@ By default, standalone scheduling clusters are resilient to Worker failures (ins **Overview** -Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected. +Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected. Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/trunk/zookeeperStarted.html). **Configuration** -In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration: - - - - - - - - - - - - - - - -
    System propertyMeaning
    spark.deploy.recoveryModeSet to ZOOKEEPER to enable standby Master recovery mode (default: NONE).
    spark.deploy.zookeeper.urlThe ZooKeeper cluster url (e.g., 192.168.1.100:2181,192.168.1.101:2181).
    spark.deploy.zookeeper.dirThe directory in ZooKeeper to store recovery state (default: /spark).
    +In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. +For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy] Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently). diff --git a/docs/sparkr.md b/docs/sparkr.md index 437bd4756c276..73e38b8c70f01 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -148,7 +148,7 @@ printSchema(people) The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example -to a Parquet file using `write.df` +to a Parquet file using `write.df` (Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API)
    {% highlight r %} @@ -286,24 +286,37 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. + +The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). + +* For gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (It only available when model fitted by normal solver.) +* For binomial GLM model, it returns a list with 'coefficients' component which gives the estimated coefficients. + +The examples below show the use of building gaussian GLM model and binomial GLM model using SparkR. + +## Gaussian GLM model
    {% highlight r %} # Create the DataFrame df <- createDataFrame(sqlContext, iris) -# Fit a linear model over the dataset. +# Fit a gaussian GLM model over the dataset. model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") -# Model coefficients are returned in a similar format to R's native glm(). +# Model summary are returned in a similar format to R's native glm(). summary(model) +##$devianceResiduals +## Min Max +## -1.307112 1.412532 +## ##$coefficients -## Estimate -##(Intercept) 2.2513930 -##Sepal_Width 0.8035609 -##Species_versicolor 1.4587432 -##Species_virginica 1.9468169 +## Estimate Std. Error t value Pr(>|t|) +##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09 +##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12 +##Species_versicolor 1.458743 0.1121079 13.01195 0 +##Species_virginica 1.946817 0.100015 19.46525 0 # Make predictions based on the model. predictions <- predict(model, newData = df) @@ -317,3 +330,64 @@ head(select(predictions, "Sepal_Length", "prediction")) ##6 5.4 5.385281 {% endhighlight %}
    + +## Binomial GLM model + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) +training <- filter(df, df$Species != "setosa") + +# Fit a binomial GLM model over the dataset. +model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) -13.046005 +##Sepal_Length 1.902373 +##Sepal_Width 0.404655 +{% endhighlight %} +
    + +# R Function Name Conflicts + +When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a +function is masking another function. + +The following functions are masked by the SparkR package: + + + + + + + + + + + + + + + +
    Masked functionHow to Access
    cov in package:stats
    stats::cov(x, y = NULL, use = "everything",
    +           method = c("pearson", "kendall", "spearman"))
    filter in package:stats
    stats::filter(x, filter, method = c("convolution", "recursive"),
    +              sides = 2, circular = FALSE, init)
    sample in package:basebase::sample(x, size, replace = FALSE, prob = NULL)
    + +Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. + +You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + + +# Migration Guide + +## Upgrading From SparkR 1.5.x to 1.6 + + - Before Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.6.0 to `error` to match the Scala API. + +## Upgrading From SparkR 1.6.x to 2.0 + + - The method `table` has been removed and replaced by `tableToDF`. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2fe5c36338899..77887f4ca36be 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Spark SQL and DataFrame Guide +displayTitle: Spark SQL, DataFrames and Datasets Guide title: Spark SQL and DataFrames --- @@ -9,18 +9,51 @@ title: Spark SQL and DataFrames # Overview -Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided +by Spark SQL provide Spark with more information about the structure of both the data and the computation being performed. Internally, +Spark SQL uses this extra information to perform extra optimizations. There are several ways to +interact with Spark SQL including SQL, the DataFrames API and the Datasets API. When computing a result +the same execution engine is used, independent of which API/language you are using to express the +computation. This unification means that developers can easily switch back and forth between the +various APIs based on which provides the most natural way to express a given transformation. -Spark SQL can also be used to read data from an existing Hive installation. For more on how to configure this feature, please refer to the [Hive Tables](#hive-tables) section. +All of the examples on this page use sample data included in the Spark distribution and can be run in +the `spark-shell`, `pyspark` shell, or `sparkR` shell. -# DataFrames +## SQL -A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. +One use of Spark SQL is to execute SQL queries written using either a basic SQL syntax or HiveQL. +Spark SQL can also be used to read data from an existing Hive installation. For more on how to +configure this feature, please refer to the [Hive Tables](#hive-tables) section. When running +SQL from within another programming language the results will be returned as a [DataFrame](#DataFrames). +You can also interact with the SQL interface using the [command-line](#running-the-spark-sql-cli) +or over [JDBC/ODBC](#running-the-thrift-jdbcodbc-server). -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). +## DataFrames -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R/Python, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of [sources](#data-sources) such +as: structured data files, tables in Hive, external databases, or existing RDDs. +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), +[Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), +[Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). + +## Datasets + +A Dataset is a new experimental interface added in Spark 1.6 that tries to provide the benefits of +RDDs (strong typing, ability to use powerful lambda functions) with the benefits of Spark SQL's +optimized execution engine. A Dataset can be [constructed](#creating-datasets) from JVM objects and then manipulated +using functional transformations (map, flatMap, filter, etc.). + +The unified Dataset API can be used both in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset) and +[Java](api/java/index.html?org/apache/spark/sql/Dataset.html). Python does not yet have support for +the Dataset API, but due to its dynamic nature many of the benefits are already available (i.e. you can +access the field of a row by name naturally `row.columnName`). Full python support will be added +in a future release. + +# Getting Started ## Starting Point: SQLContext @@ -29,7 +62,7 @@ All of the examples on this page use sample data included in the Spark distribut The entry point into all functionality in Spark SQL is the [`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight scala %} val sc: SparkContext // An existing SparkContext. @@ -45,7 +78,7 @@ import sqlContext.implicits._ The entry point into all functionality in Spark SQL is the [`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. @@ -58,7 +91,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); The entry point into all relational functionality in Spark is the [`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one -of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight python %} from pyspark.sql import SQLContext @@ -70,7 +103,7 @@ sqlContext = SQLContext(sc)
    The entry point into all relational functionality in Spark is the -`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight r %} sqlContext <- sparkRSQL.init(sc) @@ -82,20 +115,13 @@ sqlContext <- sparkRSQL.init(sc) In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a superset of the functionality provided by the basic `SQLContext`. Additional features include the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the -ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an +ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an existing Hive setup, and all of the data sources available to a `SQLContext` are still available. `HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using `HiveContext` -is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up +Spark build. If these dependencies are not a problem for your application then using `HiveContext` +is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up to feature parity with a `HiveContext`. -The specific variant of SQL that is used to parse queries can also be selected using the -`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on -a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect -available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the -default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, -this is recommended for most use cases. - ## Creating DataFrames @@ -215,7 +241,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$).
    @@ -270,7 +296,7 @@ df.groupBy("age").count().show(); For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html).
    @@ -331,7 +357,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). @@ -385,7 +411,7 @@ showDF(count(groupBy(df, "age"))) For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). @@ -398,14 +424,14 @@ The `sql` function on a `SQLContext` enables applications to run SQL queries pro
    {% highlight scala %} -val sqlContext = ... // An existing SQLContext +val sqlContext = ... // An existing SQLContext val df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
    {% highlight java %} -SQLContext sqlContext = ... // An existing SQLContext +SQLContext sqlContext = ... // An existing SQLContext DataFrame df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
    @@ -428,15 +454,54 @@ df <- sql(sqlContext, "SELECT * FROM table")
    +## Creating Datasets + +Datasets are similar to RDDs, however, instead of using Java Serialization or Kryo they use +a specialized [Encoder](api/scala/index.html#org.apache.spark.sql.Encoder) to serialize the objects +for processing or transmitting over the network. While both encoders and standard serialization are +responsible for turning an object into bytes, encoders are code generated dynamically and use a format +that allows Spark to perform many operations like filtering, sorting and hashing without deserializing +the bytes back into an object. + +
    +
    + +{% highlight scala %} +// Encoders for most common types are automatically provided by importing sqlContext.implicits._ +val ds = Seq(1, 2, 3).toDS() +ds.map(_ + 1).collect() // Returns: Array(2, 3, 4) + +// Encoders are also created for case classes. +case class Person(name: String, age: Long) +val ds = Seq(Person("Andy", 32)).toDS() + +// DataFrames can be converted to a Dataset by providing a class. Mapping will be done by name. +val path = "examples/src/main/resources/people.json" +val people = sqlContext.read.json(path).as[Person] + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +{% endhighlight %} + +
    +
    + ## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first -method uses reflection to infer the schema of an RDD that contains specific types of objects. This +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first +method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating DataFrames is through a programmatic interface that allows you to -construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows +construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection @@ -445,11 +510,11 @@ you to construct DataFrames when the columns and their types are not known until
    The Scala interface for Spark SQL supports automatically converting an RDD containing case classes -to a DataFrame. The case class -defines the schema of the table. The names of the arguments to the case class are read using +to a DataFrame. The case class +defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be -registered as a table. Tables can be used in subsequent SQL statements. +registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. @@ -486,9 +551,9 @@ teenagers.map(_.getValuesMap[Any](List("name", "age"))).collect().foreach(printl
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain -nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a +nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. {% highlight java %} @@ -559,9 +624,9 @@ List teenagerNames = teenagers.javaRDD().map(new Function()
    -Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, -and the types are inferred by looking at the first row. Since we currently only look at the first +and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we plan to more completely infer the schema by looking at more data, similar to the inference that is performed on JSON files. @@ -688,7 +753,7 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); String schemaString = "name age"; // Generate the schema based on the string of schema -List fields = new ArrayList(); +List fields = new ArrayList<>(); for (String fieldName: schemaString.split(" ")) { fields.add(DataTypes.createStructField(fieldName, DataTypes.StringType, true)); } @@ -780,7 +845,7 @@ for name in names.collect(): Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a DataFrame as a table allows you to run SQL queries over its data. This section +Registering a DataFrame as a table allows you to run SQL queries over its data. This section describes the general methods for loading and saving data using the Spark Data Sources and then goes into specific options that are available for the built-in data sources. @@ -834,9 +899,9 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet") ### Manually Specifying Options You can also manually specify the data source that will be used along with any extra options -that you would like to pass to the data source. Data sources are specified by their fully qualified +that you would like to pass to the data source. Data sources are specified by their fully qualified name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short -names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
    @@ -923,8 +988,8 @@ df <- sql(sqlContext, "SELECT * FROM parquet.`examples/src/main/resources/users. ### Save Modes Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if -present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +present. It is important to realize that these save modes do not utilize any locking and are not +atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the new data. @@ -960,7 +1025,7 @@ new data.
    Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL.
    @@ -968,21 +1033,22 @@ new data. ### Saving to Persistent Tables When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the -`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the -contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables +`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the +contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables will still exist even after your Spark program has restarted, as long as you maintain your connection -to the same metastore. A DataFrame for a persistent table can be created by calling the `table` +to the same metastore. A DataFrame for a persistent table can be created by calling the `table` method on a `SQLContext` with the name of the table. By default `saveAsTable` will create a "managed table", meaning that the location of the data will -be controlled by the metastore. Managed tables will also have their data deleted automatically +be controlled by the metastore. Managed tables will also have their data deleted automatically when a table is dropped. ## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. +of the original data. When writing Parquet files, all columns are automatically converted to be nullable for +compatibility reasons. ### Loading Data Programmatically @@ -1002,7 +1068,7 @@ val people: RDD[Person] = ... // An RDD of case class objects, from the previous // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.write.parquet("people.parquet") -// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.read.parquet("people.parquet") @@ -1024,7 +1090,7 @@ DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write().parquet("people.parquet"); -// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); @@ -1050,7 +1116,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write.parquet("people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile = sqlContext.read.parquet("people.parquet") @@ -1074,7 +1140,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. saveAsParquetFile(schemaPeople, "people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile <- parquetFile(sqlContext, "people.parquet") @@ -1089,15 +1155,6 @@ for (teenName in collect(teenNames)) {
    -
    - -{% highlight python %} -# sqlContext is an existing HiveContext -sqlContext.sql("REFRESH TABLE my_table") -{% endhighlight %} - -
    -
    {% highlight sql %} @@ -1118,10 +1175,10 @@ SELECT * FROM parquetTable ### Partition Discovery -Table partitioning is a common optimization approach used in systems like Hive. In a partitioned +Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in -the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For example, we can store all our previously used +the path of each partition directory. The Parquet data source is now able to discover and infer +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -1163,22 +1220,29 @@ root {% endhighlight %} -Notice that the data types of the partitioning columns are automatically inferred. Currently, +Notice that the data types of the partitioning columns are automatically inferred. Currently, numeric data types and string type are supported. Sometimes users may not want to automatically infer the data types of the partitioning columns. For these use cases, the automatic type inference can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to `true`. When type inference is disabled, string type will be used for the partitioning columns. +Starting from Spark 1.6.0, partition discovery only finds partitions under the given paths +by default. For the above example, if users pass `path/to/table/gender=male` to either +`SQLContext.read.parquet` or `SQLContext.read.load`, `gender` will not be considered as a +partitioning column. If users need to specify the base path that partition discovery +should start with, they can set `basePath` in the data source options. For example, +when `path/to/table/gender=male` is the path of the data and +users set `basePath` to `path/to/table/`, `gender` will be a partitioning column. ### Schema Merging -Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with -a simple schema, and gradually add more columns to the schema as needed. In this way, users may end -up with multiple Parquet files with different but mutually compatible schemas. The Parquet data +Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with +a simple schema, and gradually add more columns to the schema as needed. In this way, users may end +up with multiple Parquet files with different but mutually compatible schemas. The Parquet data source is now able to automatically detect this case and merge schemas of all these files. Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we -turned it off by default starting from 1.5.0. You may enable it by +turned it off by default starting from 1.5.0. You may enable it by 1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the examples below), or @@ -1292,22 +1356,22 @@ processing. 1. Hive considers all columns nullable, while nullability in Parquet is significant Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a -Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: 1. Fields that have the same name in both schema must have the same data type regardless of - nullability. The reconciled field should have the data type of the Parquet side, so that + nullability. The reconciled field should have the data type of the Parquet side, so that nullability is respected. 1. The reconciled schema contains exactly those fields defined in Hive metastore schema. - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. - - Any fileds that only appear in the Hive metastore schema are added as nullable field in the + - Any fields that only appear in the Hive metastore schema are added as nullable field in the reconciled schema. #### Metadata Refreshing -Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table -conversion is enabled, metadata of those converted tables are also cached. If these tables are +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are updated by Hive or other external tools, you need to refresh them manually to ensure consistent metadata. @@ -1370,7 +1434,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.int96AsTimestamp true - Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This + Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. @@ -1402,37 +1466,6 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` support. - - spark.sql.parquet.output.committer.class - org.apache.parquet.hadoop.
    ParquetOutputCommitter
    - -

    - The output committer class used by Parquet. The specified class needs to be a subclass of - org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a - subclass of org.apache.parquet.hadoop.ParquetOutputCommitter. -

    -

    - Note: -

      -
    • - This option is automatically ignored if spark.speculation is turned on. -
    • -
    • - This option must be set via Hadoop Configuration rather than Spark - SQLConf. -
    • -
    • - This option overrides spark.sql.sources.
      outputCommitterClass
      . -
    • -
    -

    -

    - Spark SQL comes with a builtin - org.apache.spark.sql.
    parquet.DirectParquetOutputCommitter
    , which can be more - efficient then the default Parquet output committer when writing data to S3. -

    - - spark.sql.parquet.mergeSchema false @@ -1469,7 +1502,7 @@ val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() // root -// |-- age: integer (nullable = true) +// |-- age: long (nullable = true) // |-- name: string (nullable = true) // Register this DataFrame as a table. @@ -1507,7 +1540,7 @@ DataFrame people = sqlContext.read().json("examples/src/main/resources/people.js // The inferred schema can be visualized using the printSchema() method. people.printSchema(); // root -// |-- age: integer (nullable = true) +// |-- age: long (nullable = true) // |-- name: string (nullable = true) // Register this DataFrame as a table. @@ -1545,7 +1578,7 @@ people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() # root -# |-- age: integer (nullable = true) +# |-- age: long (nullable = true) # |-- name: string (nullable = true) # Register this DataFrame as a table. @@ -1584,7 +1617,7 @@ people <- jsonFile(sqlContext, path) # The inferred schema can be visualized using the printSchema() method. printSchema(people) # root -# |-- age: integer (nullable = true) +# |-- age: long (nullable = true) # |-- name: string (nullable = true) # Register this DataFrame as a table. @@ -1618,16 +1651,12 @@ SELECT * FROM jsonTable Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. Hive support is enabled by adding the `-Phive` and `-Phive-thriftserver` flags to Spark's build. -This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present +This command builds a new assembly directory that includes Hive. Note that this Hive assembly directory must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running -the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory -and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the -YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the -`spark-submit` command. - +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), +`hdfs-site.xml` (for HDFS configuration) file in `conf/`.
    @@ -1635,9 +1664,11 @@ YARN cluster. The convenient way to do this is adding them through the `--jars` When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a `HiveContext`. When not configured by the -hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current -directory. +not have an existing Hive deployment can still create a `HiveContext`. When not configured by the +hive-site.xml, the context automatically creates `metastore_db` in the current directory and +creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. +Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts +the spark application. {% highlight scala %} // sc is an existing SparkContext. @@ -1739,14 +1770,14 @@ The following options can be used to configure the version of Hive that is used property can be one of three options:
    1. builtin
    2. - Use Hive 1.2.1, which is bundled with the Spark assembly jar when -Phive is + Use Hive 1.2.1, which is bundled with the Spark assembly when -Phive is enabled. When this option is chosen, spark.sql.hive.metastore.version must be either 1.2.1 or not defined.
    3. maven
    4. - Use Hive jars of specified version downloaded from Maven repositories. This configuration + Use Hive jars of specified version downloaded from Maven repositories. This configuration is not generally recommended for production deployments. -
    5. A classpath in the standard format for the JVM. This classpath must include all of Hive - and its dependencies, including the correct version of Hadoop. These jars only need to be +
    6. A classpath in the standard format for the JVM. This classpath must include all of Hive + and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure they are packaged with you application.
    @@ -1781,7 +1812,7 @@ The following options can be used to configure the version of Hive that is used ## JDBC To Other Databases -Spark SQL also includes a data source that can read data from other databases using JDBC. This +Spark SQL also includes a data source that can read data from other databases using JDBC. This functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD). This is because the results are returned as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources. @@ -1791,15 +1822,15 @@ provide a ClassTag. run queries using Spark SQL). To get started you will need to include the JDBC driver for you particular database on the -spark classpath. For example, to connect to postgres from the Spark Shell you would run the +spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: {% highlight bash %} -SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell +bin/spark-shell --driver-class-path postgresql-9.4.1207.jar --jars postgresql-9.4.1207.jar {% endhighlight %} Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using -the Data Sources API. The following options are supported: +the Data Sources API. The following options are supported: @@ -1812,8 +1843,8 @@ the Data Sources API. The following options are supported: @@ -1821,15 +1852,14 @@ the Data Sources API. The following options are supported: + + + + + +
    Property NameMeaning
    dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a + The JDBC table that should be read. Note that anything that is valid in a FROM clause of + a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses.
    driver - The class name of the JDBC driver needed to connect to this URL. This class will be loaded - on the master and workers before running an JDBC commands to allow the driver to - register itself with the JDBC subsystem. + The class name of the JDBC driver to use to connect to this URL.
    partitionColumn, lowerBound, upperBound, numPartitions - These options must all be specified if any of them is specified. They describe how to + These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. partitionColumn must be a numeric column from the table in question. Notice that lowerBound and upperBound are just used to decide the @@ -1837,6 +1867,13 @@ the Data Sources API. The following options are supported: partitioned and returned.
    fetchSize + The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). +
    @@ -1855,7 +1892,7 @@ val jdbcDF = sqlContext.read.format("jdbc").options( {% highlight java %} -Map options = new HashMap(); +Map options = new HashMap<>(); options.put("url", "jdbc:postgresql:dbserver"); options.put("dbtable", "schema.tablename"); @@ -1935,7 +1972,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ spark.sql.inMemoryColumnarStorage.batchSize 10000 - Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization + Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data. @@ -1944,7 +1981,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ ## Other Configuration Options -The following options can also be used to tune the performance of query execution. It is possible +The following options can also be used to tune the performance of query execution. It is possible that these options will be deprecated in future release as more optimizations are performed automatically. @@ -1954,7 +1991,7 @@ that these options will be deprecated in future release as more optimizations ar @@ -1992,8 +2029,8 @@ To start the JDBC/ODBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to -specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of -all available options. By default, the server listens on localhost:10000. You may override this +specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of +all available options. By default, the server listens on localhost:10000. You may override this behaviour via either environment variables, i.e.: {% highlight bash %} @@ -2026,7 +2063,7 @@ Beeline will ask you for a username and password. In non-secure mode, simply ent your machine and a blank password. For secure mode, please follow the instructions given in the [beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients). -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may also use the beeline script that comes with Hive. @@ -2051,39 +2088,59 @@ To start the Spark SQL CLI, run the following in the Spark directory: ./bin/spark-sql -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. # Migration Guide +## Upgrading From Spark SQL 1.5 to 1.6 + + - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + connection owns a copy of their own SQL configuration and temporary function registry. Cached + tables are still shared though. If you prefer to run the Thrift server in the old single-session + mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add + this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`: + + {% highlight bash %} + ./sbin/start-thriftserver.sh \ + --conf spark.sql.hive.thriftServer.singleSession=true \ + ... + {% endhighlight %} + - Since 1.6.1, withColumn method in sparkR supports adding a new column to or replacing existing columns + of the same name of a DataFrame. + + - From Spark 1.6, LongType casts to TimestampType expect seconds instead of microseconds. This + change was made to match the behavior of Hive 1.2 for more consistent type casting to TimestampType + from numeric types. See [SPARK-11724](https://issues.apache.org/jira/browse/SPARK-11724) for + details. + ## Upgrading From Spark SQL 1.4 to 1.5 - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with - code generation for expression evaluation. These features can both be disabled by setting + code generation for expression evaluation. These features can both be disabled by setting `spark.sql.tungsten.enabled` to `false`. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if - your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). + access nested values. For example `df['table.column.nestedField']`. However, this means that if + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum - precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now + precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`. - Timestamps are now stored at a precision of 1us, rather than 1ns - - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains + - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains unchanged. - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). - - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe - and thus this output committer will not be used when speculation is on, independent of configuration. - JSON data source will not automatically load new files that are created by other applications (i.e. files that are not inserted to the dataset through Spark SQL). For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore), users can use `REFRESH TABLE` SQL command or `HiveContext`'s `refreshTable` method to include those new files to the table. For a DataFrame representing a JSON dataset, users need to recreate the DataFrame and the new DataFrame will include new files. + - DataFrame.withColumn method in pySpark supports adding a new column or replacing existing columns of the same name. ## Upgrading from Spark SQL 1.3 to 1.4 @@ -2163,41 +2220,51 @@ sqlContext.setConf("spark.sql.retainGroupColumns", "false") +#### Behavior change on DataFrame.withColumn + +Prior to 1.4, DataFrame.withColumn() supports adding a column only. The column will always be added +as a new column with its specified name in the result DataFrame even if there may be any existing +columns of the same name. Since 1.4, DataFrame.withColumn() supports adding a column of a different +name from names of all existing columns or replacing existing columns of the same name. + +Note that this change is only for Scala API, not for PySpark and SparkR. + + ## Upgrading from Spark SQL 1.0-1.2 to 1.3 In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the -available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other -releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked +available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other +releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked as unstable (i.e., DeveloperAPI or Experimental). #### Rename of SchemaRDD to DataFrame The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has -been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD +been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD directly, but instead provide most of the functionality that RDDs provide though their own -implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. +implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for -some use cases. It is still recommended that users update their code to use `DataFrame` instead. +some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. #### Unification of the Java and Scala APIs Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) -that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users +of either language should use `SQLContext` and `DataFrame`. In general theses classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. #### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only) Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought -all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit +all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`. Users should now write `import sqlContext.implicits._`. @@ -2205,7 +2272,7 @@ Additionally, the implicit conversions now only augment RDDs that are composed o case classes or tuples) with a method `toDF`, instead of applying automatically. When using function inside of the DSL (now replaced with the `DataFrame` API) users used to import -`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: +`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: `import org.apache.spark.sql.functions._`. #### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only) @@ -2244,57 +2311,12 @@ Python UDF registration is unchanged. When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of referencing a singleton. -## Migration Guide for Shark Users - -### Scheduling -To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, -users can set the `spark.sql.thriftserver.scheduler.pool` variable: - - SET spark.sql.thriftserver.scheduler.pool=accounting; - -### Reducer number - -In Shark, default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark -SQL deprecates this property in favor of `spark.sql.shuffle.partitions`, whose default value -is 200. Users may customize this property via `SET`: - - SET spark.sql.shuffle.partitions=10; - SELECT page, count(*) c - FROM logs_last_month_cached - GROUP BY page ORDER BY c DESC LIMIT 10; - -You may also put this property in `hive-site.xml` to override the default value. - -For now, the `mapred.reduce.tasks` property is still recognized, and is converted to -`spark.sql.shuffle.partitions` automatically. - -### Caching - -The `shark.cache` table property no longer exists, and tables whose name end with `_cached` are no -longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to -let user control table caching explicitly: - - CACHE TABLE logs_last_month; - UNCACHE TABLE logs_last_month; - -**NOTE:** `CACHE TABLE tbl` is now __eager__ by default not __lazy__. Don’t need to trigger cache materialization manually anymore. - -Spark SQL newly introduced a statement to let user control table caching whether or not lazy since Spark 1.2.0: - - CACHE [LAZY] TABLE [AS SELECT] ... - -Several caching related features are not supported yet: - -* User defined partition level cache eviction policy -* RDD reloading -* In-memory cache write through policy - ## Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). +(from 0.12.0 to 1.2.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index a75587a92adc7..a4e17fd24eac2 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -72,7 +72,8 @@ class CustomReceiver(host: String, port: Int) socket = new Socket(host, port) // Until stopped or connection broken continue reading - val reader = new BufferedReader(new InputStreamReader(socket.getInputStream(), "UTF-8")) + val reader = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() while(!isStopped && userInput != null) { store(userInput) @@ -135,7 +136,8 @@ public class JavaCustomReceiver extends Receiver { // connect to the server socket = new Socket(host, port); - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + BufferedReader reader = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); // Until stopped or connection broken continue reading while (!isStopped() && (userInput = reader.readLine()) != null) { @@ -254,28 +256,3 @@ The following table summarizes the characteristics of both types of receivers
    10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when - performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently + performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.
    - -## Implementing and Using a Custom Actor-based Receiver - -Custom [Akka Actors](http://doc.akka.io/docs/akka/2.2.4/scala/actors.html) can also be used to -receive data. The [`ActorHelper`](api/scala/index.html#org.apache.spark.streaming.receiver.ActorHelper) -trait can be applied on any Akka actor, which allows received data to be stored in Spark using - `store(...)` methods. The supervisor strategy of this actor can be configured to handle failures, etc. - -{% highlight scala %} -class CustomActor extends Actor with ActorHelper { - def receive = { - case data: String => store(data) - } -} -{% endhighlight %} - -And a new input stream can be created with this custom actor as - -{% highlight scala %} -// Assuming ssc is the StreamingContext -val lines = ssc.actorStream[String](Props(new CustomActor()), "CustomReceiver") -{% endhighlight %} - -See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) -for an end-to-end example. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 383d954409ce4..8eeeee75dbf40 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -30,7 +30,7 @@ See the [Flume's documentation](https://flume.apache.org/documentation.html) for configuring Flume agents. #### Configuring Spark Streaming Application -1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). +1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). groupId = org.apache.spark artifactId = spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} @@ -71,7 +71,16 @@ configuring Flume agents. cluster (Mesos, YARN or Spark Standalone), so that resource allocation can match the names and launch the receiver in the right machine. -3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-flume-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-flume-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Pull-based Approach using a Custom Sink Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. @@ -157,7 +166,7 @@ configuring Flume agents. Note that each input DStream can be configured to receive data from multiple sinks. -3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). +3. **Deploying:** This is same as the first approach. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index ab7f0117c0b7f..015a2f1fa0bdc 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -71,10 +71,10 @@ Next, we discuss how to use this approach in your streaming application. ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-assembly` from the - [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. This approach has the following advantages over the receiver-based approach (i.e. Approach 1). @@ -104,6 +104,7 @@ Next, we discuss how to use this approach in your streaming application. [key class], [value class], [key decoder class], [value decoder class] ]( streamingContext, [map of Kafka parameters], [set of topics to consume]) + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
    @@ -115,6 +116,7 @@ Next, we discuss how to use this approach in your streaming application. [key class], [value class], [key decoder class], [value decoder class], [map of Kafka parameters], [set of topics to consume]); + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). @@ -123,6 +125,7 @@ Next, we discuss how to use this approach in your streaming application. from pyspark.streaming.kafka import KafkaUtils directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type. By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py).
    @@ -181,7 +184,20 @@ Next, we discuss how to use this approach in your streaming application. );
    - Not supported yet + offsetRanges = [] + + def storeOffsetRanges(rdd): + global offsetRanges + offsetRanges = rdd.offsetRanges() + return rdd + + def printOffsetRanges(rdd): + for o in offsetRanges: + print "%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset) + + directKafkaStream\ + .transform(storeOffsetRanges)\ + .foreachRDD(printOffsetRanges)
    @@ -191,4 +207,4 @@ Next, we discuss how to use this approach in your streaming application. Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. -3. **Deploying:** This is same as the first approach, for Scala, Java and Python. +3. **Deploying:** This is same as the first approach. diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 238a911a9199f..5b9a7554d2e64 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -15,15 +15,16 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Configuring Spark Streaming Application -1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). +1. **Linking:** For Scala/Java applications using SBT/Maven project definitions, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). groupId = org.apache.spark artifactId = spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} + For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. **Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your application.** -2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream as follows: +2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream of byte array as follows:
    @@ -36,7 +37,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the [Running the Example](#running-the-example) subsection for instructions on how to run the example.
    @@ -49,7 +50,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2); See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the [Running the Example](#running-the-example) subsection for instructions to run the example.
    @@ -60,18 +61,47 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the [Running the Example](#running-the-example) subsection for instructions to run the example.
    - - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + You may also provide a "message handler function" that takes a Kinesis `Record` and returns a generic object `T`, in case you would like to use other data included in a `Record` such as partition key. This is currently only supported in Scala and Java. - - `[Kineiss app name]`: The application name that will be used to checkpoint the Kinesis +
    +
    + + import org.apache.spark.streaming.Duration + import org.apache.spark.streaming.kinesis._ + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + val kinesisStream = KinesisUtils.createStream[T]( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2, + [message handler]) + +
    +
    + + import org.apache.spark.streaming.Duration; + import org.apache.spark.streaming.kinesis.*; + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + + JavaReceiverInputDStream kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2, + [message handler], [class T]); + +
    +
    + + - `streamingContext`: StreamingContext containing an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + + - `[Kinesis app name]`: The application name that will be used to checkpoint the Kinesis sequence numbers in DynamoDB table. - The application name must be unique for a given account and region. - If the table exists but has incorrect checkpoint information (for a different stream, or - old expired sequenced numbers), then there may be temporary errors. + old expired sequenced numbers), then there may be temporary errors. - `[Kinesis stream name]`: The Kinesis stream that this streaming application will pull data from. @@ -83,9 +113,20 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + - `[message handler]`: A function that takes a Kinesis `Record` and outputs generic `T`. + In other versions of the API, you can also specify the AWS access key and secret key directly. -3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kinesis-asl-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kinesis-asl-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. *Points to remember at runtime:* @@ -99,7 +140,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m Spark Streaming Kinesis Architecture

    @@ -149,9 +190,9 @@ To run the example,
    - bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\ + bin/spark-submit --jars external/kinesis-asl/target/scala-*/\ spark-streaming-kinesis-asl-assembly_*.jar \ - extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ [Kinesis app name] [Kinesis stream name] [endpoint URL] [region name]
    @@ -165,11 +206,16 @@ To run the example, This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. +#### Record De-aggregation + +When data is generated using the [Kinesis Producer Library (KPL)](http://docs.aws.amazon.com/kinesis/latest/dev/developing-producers-with-kpl.html), messages may be aggregated for cost savings. Spark Streaming will automatically +de-aggregate records during consumption. + #### Kinesis Checkpointing - Each Kinesis input DStream periodically stores the current position of the stream in the backing DynamoDB table. This allows the system to recover from failures and continue processing where the DStream left off. - Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy. -- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable. -- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). +- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPositionInStream.LATEST). This is configurable. +- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). - InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c751dbb41785a..7f6c0ed6994ba 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -11,7 +11,7 @@ description: Spark Streaming programming guide and tutorial for Spark SPARK_VERS # Overview Spark Streaming is an extension of the core Spark API that enables scalable, high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ, Kinesis, or TCP sockets, and can be processed using complex +like Kafka, Flume, Kinesis, or TCP sockets, and can be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's @@ -158,15 +158,15 @@ JavaReceiverInputDStream lines = jssc.socketTextStream("localhost", 9999 {% endhighlight %} This `lines` DStream represents the stream of data that will be received from the data -server. Each record in this stream is a line of text. Then, we want to split the the lines by +server. Each record in this stream is a line of text. Then, we want to split the lines by space into words. {% highlight java %} // Split each line into words JavaDStream words = lines.flatMap( new FlatMapFunction() { - @Override public Iterable call(String x) { - return Arrays.asList(x.split(" ")); + @Override public Iterator call(String x) { + return Arrays.asList(x.split(" ")).iterator(); } }); {% endhighlight %} @@ -186,7 +186,7 @@ Next, we want to count these words. JavaPairDStream pairs = words.mapToPair( new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } }); JavaPairDStream wordCounts = pairs.reduceByKey( @@ -419,9 +419,6 @@ some of the common ones are as follows. Kafka spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} Flume spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} Kinesis
    spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} [Amazon Software License] - Twitter spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}} - ZeroMQ spark-streaming-zeromq_{{site.SCALA_BINARY_VERSION}} - MQTT spark-streaming-mqtt_{{site.SCALA_BINARY_VERSION}} @@ -594,8 +591,8 @@ data from a source and stores it in Spark's memory for processing. Spark Streaming provides two categories of built-in streaming sources. - *Basic sources*: Sources directly available in the StreamingContext API. - Examples: file systems, socket connections, and Akka actors. -- *Advanced sources*: Sources like Kafka, Flume, Kinesis, Twitter, etc. are available through + Examples: file systems, and socket connections. +- *Advanced sources*: Sources like Kafka, Flume, Kinesis, etc. are available through extra utility classes. These require linking against extra dependencies as discussed in the [linking](#linking) section. @@ -631,7 +628,7 @@ as well as to run the receiver(s). We have already taken a look at the `ssc.socketTextStream(...)` in the [quick example](#a-quick-example) which creates a DStream from text data received over a TCP socket connection. Besides sockets, the StreamingContext API provides -methods for creating DStreams from files and Akka actors as input sources. +methods for creating DStreams from files as input sources. - **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as: @@ -658,17 +655,12 @@ methods for creating DStreams from files and Akka actors as input sources. Python API `fileStream` is not available in the Python API, only `textFileStream` is available. -- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka - actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver - Guide](streaming-custom-receivers.html) for more details. - - Python API Since actors are available only in the Java and Scala - libraries, `actorStream` is not available in the Python API. +- **Streams based on Custom Receivers:** DStreams can be created with data streams received through custom receivers. See the [Custom Receiver + Guide](streaming-custom-receivers.html) and [DStream Akka](https://github.com/spark-packages/dstream-akka) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. -For more details on streams from sockets, files, and actors, -see the API documentations of the relevant functions in +For more details on streams from sockets and files, see the API documentations of the relevant functions in [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) for Scala, [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) for Python. @@ -677,38 +669,12 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, Kafka, Kinesis, Flume and MQTT are available in the Python API. +out of these sources, Kafka, Kinesis and Flume are available in the Python API. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts of dependencies, the functionality to create DStreams from these sources has been moved to separate -libraries that can be [linked](#linking) to explicitly when necessary. For example, if you want to -create a DStream using data from Twitter's stream of tweets, you have to do the following: - -1. *Linking*: Add the artifact `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` to the - SBT/Maven project dependencies. -1. *Programming*: Import the `TwitterUtils` class and create a DStream with - `TwitterUtils.createStream` as shown below. -1. *Deploying*: Generate an uber JAR with all the dependencies (including the dependency - `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` and its transitive dependencies) and - then deploy the application. This is further explained in the [Deploying section](#deploying-applications). - -
    -
    -{% highlight scala %} -import org.apache.spark.streaming.twitter._ - -TwitterUtils.createStream(ssc, None) -{% endhighlight %} -
    -
    -{% highlight java %} -import org.apache.spark.streaming.twitter.*; - -TwitterUtils.createStream(jssc); -{% endhighlight %} -
    -
    +libraries that can be [linked](#linking) to explicitly when necessary. Note that these advanced sources are not available in the Spark shell, hence applications based on these advanced sources cannot be tested in the shell. If you really want to use them in the Spark @@ -723,15 +689,6 @@ Some of these advanced sources are as follows. - **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. -- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using - [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information - can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by - Twitter4J library. You can either get the public stream, or get the filtered stream based on a - keywords. See the API documentation ([Scala](api/scala/index.html#org.apache.spark.streaming.twitter.TwitterUtils$), - [Java](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html)) and examples - ([TwitterPopularTags]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala) - and [TwitterAlgebirdCMS]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala)). - ### Custom Sources {:.no_toc} @@ -798,7 +755,7 @@ Some of the common ones are as follows. reduce(func) Return a new DStream of single-element RDDs by aggregating the elements in each RDD of the source DStream using a function func (which takes two arguments and returns one). - The function should be associative so that it can be computed in parallel. + The function should be associative and commutative so that it can be computed in parallel. countByValue() @@ -872,16 +829,12 @@ val runningCounts = pairs.updateStateByKey[Int](updateFunction _) {% endhighlight %} The update function will be called for each word, with `newValues` having a sequence of 1's (from -the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete -Scala code, take a look at the example -[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache -/spark/examples/streaming/StatefulNetworkWordCount.scala). +the `(word, 1)` pairs) and the `runningCount` having the previous count.
    {% highlight java %} -import com.google.common.base.Optional; Function2, Optional, Optional> updateFunction = new Function2, Optional, Optional>() { @Override public Optional call(List values, Optional state) { @@ -1073,7 +1026,7 @@ said two parameters - windowLength and slideInterval. reduceByWindow(func, windowLength, slideInterval) Return a new single-element stream, created by aggregating elements in the stream over a - sliding interval using func. The function should be associative so that it can be computed + sliding interval using func. The function should be associative and commutative so that it can be computed correctly in parallel. @@ -1415,6 +1368,171 @@ Note that the connections in the pool should be lazily created on demand and tim *** +## Accumulators and Broadcast Variables + +[Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) cannot be recovered from checkpoint in Spark Streaming. If you enable checkpointing and use [Accumulators](programming-guide.html#accumulators) or [Broadcast variables](programming-guide.html#broadcast-variables) as well, you'll have to create lazily instantiated singleton instances for [Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) so that they can be re-instantiated after the driver restarts on failure. This is shown in the following example. + +
    +
    +{% highlight scala %} + +object WordBlacklist { + + @volatile private var instance: Broadcast[Seq[String]] = null + + def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { + if (instance == null) { + synchronized { + if (instance == null) { + val wordBlacklist = Seq("a", "b", "c") + instance = sc.broadcast(wordBlacklist) + } + } + } + instance + } +} + +object DroppedWordsCounter { + + @volatile private var instance: Accumulator[Long] = null + + def getInstance(sc: SparkContext): Accumulator[Long] = { + if (instance == null) { + synchronized { + if (instance == null) { + instance = sc.accumulator(0L, "WordsInBlacklistCounter") + } + } + } + instance + } +} + +wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { + // Get or register the blacklist Broadcast + val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the droppedWordsCounter Accumulator + val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) + // Use blacklist to drop words and use droppedWordsCounter to count them + val counts = rdd.filter { case (word, count) => + if (blacklist.value.contains(word)) { + droppedWordsCounter += count + false + } else { + true + } + }.collect() + val output = "Counts at time " + time + " " + counts +}) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). +
    +
    +{% highlight java %} + +class JavaWordBlacklist { + + private static volatile Broadcast> instance = null; + + public static Broadcast> getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaWordBlacklist.class) { + if (instance == null) { + List wordBlacklist = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordBlacklist); + } + } + } + return instance; + } +} + +class JavaDroppedWordsCounter { + + private static volatile Accumulator instance = null; + + public static Accumulator getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaDroppedWordsCounter.class) { + if (instance == null) { + instance = jsc.accumulator(0, "WordsInBlacklistCounter"); + } + } + } + return instance; + } +} + +wordCounts.foreachRDD(new Function2, Time, Void>() { + @Override + public Void call(JavaPairRDD rdd, Time time) throws IOException { + // Get or register the blacklist Broadcast + final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the droppedWordsCounter Accumulator + final Accumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + // Use blacklist to drop words and use droppedWordsCounter to count them + String counts = rdd.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 wordCount) throws Exception { + if (blacklist.value().contains(wordCount._1())) { + droppedWordsCounter.add(wordCount._2()); + return false; + } else { + return true; + } + } + }).collect().toString(); + String output = "Counts at time " + time + " " + counts; + } +} + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). +
    +
    +{% highlight python %} + +def getWordBlacklist(sparkContext): + if ('wordBlacklist' not in globals()): + globals()['wordBlacklist'] = sparkContext.broadcast(["a", "b", "c"]) + return globals()['wordBlacklist'] + +def getDroppedWordsCounter(sparkContext): + if ('droppedWordsCounter' not in globals()): + globals()['droppedWordsCounter'] = sparkContext.accumulator(0) + return globals()['droppedWordsCounter'] + +def echo(time, rdd): + # Get or register the blacklist Broadcast + blacklist = getWordBlacklist(rdd.context) + # Get or register the droppedWordsCounter Accumulator + droppedWordsCounter = getDroppedWordsCounter(rdd.context) + + # Use blacklist to drop words and use droppedWordsCounter to count them + def filterFunc(wordCount): + if wordCount[0] in blacklist.value: + droppedWordsCounter.add(wordCount[1]) + False + else: + True + + counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect()) + +wordCounts.foreachRDD(echo) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/recoverable_network_wordcount.py). + +
    +
    + +*** + ## DataFrame and SQL Operations You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SQLContext using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SQLContext. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. @@ -1771,10 +1889,10 @@ To run a Spark Streaming applications, you need to have the following. - *Package the application JAR* - You have to compile your streaming application into a JAR. If you are using [`spark-submit`](submitting-applications.html) to start the application, then you will not need to provide Spark and Spark Streaming in the JAR. However, - if your application uses [advanced sources](#advanced-sources) (e.g. Kafka, Flume, Twitter), + if your application uses [advanced sources](#advanced-sources) (e.g. Kafka, Flume), then you will have to package the extra artifact they link to, along with their dependencies, - in the JAR that is used to deploy the application. For example, an application using `TwitterUtils` - will have to include `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` and all its + in the JAR that is used to deploy the application. For example, an application using `KafkaUtils` + will have to include `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and all its transitive dependencies in the application JAR. - *Configuring sufficient memory for the executors* - Since the received data must be stored in @@ -1820,7 +1938,11 @@ To run a Spark Streaming applications, you need to have the following. to increase aggregate throughput. Additionally, it is recommended that the replication of the received data within Spark be disabled when the write ahead log is enabled as the log is already stored in a replicated storage system. This can be done by setting the storage level for the - input stream to `StorageLevel.MEMORY_AND_DISK_SER`. + input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that + does not support flushing) for _write ahead logs_, please remember to enable + `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and + `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See + [Spark Streaming Configuration](configuration.html#spark-streaming) for more details. - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming application to process data as fast as it is being received, the receivers can be rate limited @@ -1858,12 +1980,6 @@ contains serialized Scala/Java/Python objects and trying to deserialize objects modified classes may lead to errors. In this case, either start the upgraded app with a different checkpoint directory, or delete the previous checkpoint directory. -### Other Considerations -{:.no_toc} -If the data is being received by the receivers faster than what can be processed, -you can limit the rate by setting the [configuration parameter](configuration.html#spark-streaming) -`spark.streaming.receiver.maxRate`. - *** ## Monitoring Applications @@ -1936,7 +2052,7 @@ unifiedStream.print()
    {% highlight java %} int numStreams = 5; -List> kafkaStreams = new ArrayList>(numStreams); +List> kafkaStreams = new ArrayList<>(numStreams); for (int i = 0; i < numStreams; i++) { kafkaStreams.add(KafkaUtils.createStream(...)); } @@ -1948,8 +2064,8 @@ unifiedStream.print(); {% highlight python %} numStreams = 5 kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] -unifiedStream = streamingContext.union(kafkaStreams) -unifiedStream.print() +unifiedStream = streamingContext.union(*kafkaStreams) +unifiedStream.pprint() {% endhighlight %}
    @@ -2001,9 +2117,6 @@ If the number of tasks launched per second is high (say, 50 or more per second), of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task - sizes, and therefore reduce the time taken to send them to the slaves. - * **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details. @@ -2065,7 +2178,7 @@ overall processing throughput of the system, its use is still recommended to ach consistent batch processing times. Make sure you set the CMS GC on both the driver (using `--driver-java-options` in `spark-submit`) and the executors (using [Spark configuration](configuration.html#runtime-environment) `spark.executor.extraJavaOptions`). * **Other tips**: To further reduce GC overheads, here are some more tips to try. - - Use Tachyon for off-heap storage of persisted RDDs. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence). + - Persist RDDs using the `OFF_HEAP` storage level. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence). - Use more executors with smaller heap sizes. This will reduce the GC pressure within each JVM heap. @@ -2247,8 +2360,7 @@ additional effort may be necessary to achieve exactly-once semantics. There are Between Spark 0.9.1 and Spark 1.0, there were a few API changes made to ensure future API stability. This section elaborates the steps required to migrate your existing code to 1.0. -**Input DStreams**: All operations that create an input stream (e.g., `StreamingContext.socketStream`, -`FlumeUtils.createStream`, etc.) now returns +**Input DStreams**: All operations that create an input stream (e.g., `StreamingContext.socketStream`, `FlumeUtils.createStream`, etc.) now returns [InputDStream](api/scala/index.html#org.apache.spark.streaming.dstream.InputDStream) / [ReceiverInputDStream](api/scala/index.html#org.apache.spark.streaming.dstream.ReceiverInputDStream) (instead of DStream) for Scala, and [JavaInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaInputDStream.html) / @@ -2283,13 +2395,8 @@ that can be called to store the data in Spark. So, to migrate your custom networ BlockGenerator object (does not exist any more in Spark 1.0 anyway), and use `store(...)` methods on received data. -**Actor-based Receivers**: Data could have been received using any Akka Actors by extending the actor class with -`org.apache.spark.streaming.receivers.Receiver` trait. This has been renamed to -[`org.apache.spark.streaming.receiver.ActorHelper`](api/scala/index.html#org.apache.spark.streaming.receiver.ActorHelper) -and the `pushBlock(...)` methods to store received data has been renamed to `store(...)`. Other helper classes in -the `org.apache.spark.streaming.receivers` package were also moved -to [`org.apache.spark.streaming.receiver`](api/scala/index.html#org.apache.spark.streaming.receiver.package) -package and renamed for better clarity. +**Actor-based Receivers**: The Actor-based Receiver APIs have been moved to [DStream Akka](https://github.com/spark-packages/dstream-akka). +Please refer to the project for more details. *************************************************************************************************** *************************************************************************************************** @@ -2297,9 +2404,13 @@ package and renamed for better clarity. # Where to Go from Here * Additional guides - [Kafka Integration Guide](streaming-kafka-integration.html) - - [Flume Integration Guide](streaming-flume-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) +* External DStream data sources: + - [DStream MQTT](https://github.com/spark-packages/dstream-mqtt) + - [DStream Twitter](https://github.com/spark-packages/dstream-twitter) + - [DStream Akka](https://github.com/spark-packages/dstream-akka) + - [DStream ZeroMQ](https://github.com/spark-packages/dstream-zeromq) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and @@ -2307,9 +2418,6 @@ package and renamed for better clarity. * [KafkaUtils](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$), [FlumeUtils](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$), [KinesisUtils](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$), - [TwitterUtils](api/scala/index.html#org.apache.spark.streaming.twitter.TwitterUtils$), - [ZeroMQUtils](api/scala/index.html#org.apache.spark.streaming.zeromq.ZeroMQUtils$), and - [MQTTUtils](api/scala/index.html#org.apache.spark.streaming.mqtt.MQTTUtils$) - Java docs * [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html), [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and @@ -2317,9 +2425,6 @@ package and renamed for better clarity. * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), [FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html), [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) - [TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html), - [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and - [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) - Python docs * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) and [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) * [KafkaUtils](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index ac2a14eb56fea..100ff0b147efd 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -30,7 +30,7 @@ dependencies, and can support different cluster managers and deploy modes that S {% highlight bash %} ./bin/spark-submit \ - --class + --class \ --master \ --deploy-mode \ --conf = \ @@ -58,8 +58,7 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Note that `cluster` mode is currently not supported for -Mesos clusters. Currently only YARN supports cluster mode for Python applications. +the drivers and the executors. Currently only YARN supports cluster mode for Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. @@ -92,8 +91,8 @@ run it with `--help`. Here are a few examples of common options: ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ --master spark://207.184.161.138:7077 \ - --deploy-mode cluster - --supervise + --deploy-mode cluster \ + --supervise \ --executor-memory 20G \ --total-executor-cores 100 \ /path/to/examples.jar \ @@ -115,6 +114,18 @@ export HADOOP_CONF_DIR=XXX --master spark://207.184.161.138:7077 \ examples/src/main/python/pi.py \ 1000 + +# Run on a Mesos cluster in cluster deploy mode with supervise +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master mesos://207.184.161.138:7077 \ + --deploy-mode cluster \ + --supervise \ + --executor-memory 20G \ + --total-executor-cores 100 \ + http://path/to/examples.jar \ + 1000 + {% endhighlight %} # Master URLs @@ -132,17 +143,12 @@ The master URL passed to Spark can be in one of the following formats: mesos://HOST:PORT Connect to the given Mesos cluster. The port must be whichever one your is configured to use, which is 5050 by default. Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... + To submit with --deploy-mode cluster, the HOST:PORT should be configured to connect to the MesosClusterDispatcher. yarn Connect to a YARN cluster in - client or cluster mode depending on the value of --deploy-mode. + client or cluster mode depending on the value of --deploy-mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable. - yarn-client Equivalent to yarn with --deploy-mode client, - which is preferred to `yarn-client` - - yarn-cluster Equivalent to yarn with --deploy-mode cluster, - which is preferred to `yarn-cluster` - @@ -164,8 +170,9 @@ debugging information by running `spark-submit` with the `--verbose` option. # Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. Spark uses the following URL scheme to allow -different strategies for disseminating jars: +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. + +Spark uses the following URL scheme to allow different strategies for disseminating jars: - **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and every executor pulls the file from the driver HTTP server. diff --git a/docs/tuning.md b/docs/tuning.md index 6936912a6be54..e73ed69ffbbf8 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -61,8 +61,8 @@ The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes mor registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` -config property. The default is 2, but this value needs to be large enough to hold the *largest* -object you will serialize. +[config](configuration.html#compression-and-serialization). This value needs to be large enough +to hold the *largest* object you will serialize. Finally, if you don't register your custom classes, Kryo will still work, but it will have to store the full class name with each object, which is wasteful. @@ -88,9 +88,39 @@ than the "raw" data inside their fields. This is due to several reasons: but also pointers (typically 8 bytes each) to the next object in the list. * Collections of primitive types often store them as "boxed" objects such as `java.lang.Integer`. -This section will discuss how to determine the memory usage of your objects, and how to improve -it -- either by changing your data structures, or by storing data in a serialized format. -We will then cover tuning Spark's cache size and the Java garbage collector. +This section will start with an overview of memory management in Spark, then discuss specific +strategies the user can take to make more efficient use of memory in his/her application. In +particular, we will describe how to determine the memory usage of your objects, and how to +improve it -- either by changing your data structures, or by storing data in a serialized +format. We will then cover tuning Spark's cache size and the Java garbage collector. + +## Memory Management Overview + +Memory usage in Spark largely falls under one of two categories: execution and storage. +Execution memory refers to that used for computation in shuffles, joins, sorts and aggregations, +while storage memory refers to that used for caching and propagating internal data across the +cluster. In Spark, execution and storage share a unified region (M). When no execution memory is +used, storage can acquire all the available memory and vice versa. Execution may evict storage +if necessary, but only until total storage memory usage falls under a certain threshold (R). +In other words, `R` describes a subregion within `M` where cached blocks are never evicted. +Storage may not evict execution due to complexities in implementation. + +This design ensures several desirable properties. First, applications that do not use caching +can use the entire space for execution, obviating unnecessary disk spills. Second, applications +that do use caching can reserve a minimum storage space (R) where their data blocks are immune +to being evicted. Lastly, this approach provides reasonable out-of-the-box performance for a +variety of workloads without requiring user expertise of how memory is divided internally. + +Although there are two relevant configurations, the typical user should not need to adjust them +as the default values are applicable to most workloads: + +* `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB) +(default 0.75). The rest of the space (25%) is reserved for user data structures, internal +metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually +large records. +* `spark.memory.storageFraction` expresses the size of `R` as a fraction of `M` (default 0.5). +`R` is the storage space within `M` where cached blocks immune to being evicted by execution. + ## Determining Memory Consumption @@ -151,18 +181,6 @@ time spent GC. This can be done by adding `-verbose:gc -XX:+PrintGCDetails -XX:+ each time a garbage collection occurs. Note these logs will be on your cluster's worker nodes (in the `stdout` files in their work directories), *not* on your driver program. -**Cache Size Tuning** - -One important configuration parameter for GC is the amount of memory that should be used for caching RDDs. -By default, Spark uses 60% of the configured executor memory (`spark.executor.memory`) to -cache RDDs. This means that 40% of memory is available for any objects created during task execution. - -In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of -memory, lowering this value will help reduce the memory consumption. To change this to, say, 50%, you can call -`conf.set("spark.storage.memoryFraction", "0.5")` on your SparkConf. Combined with the use of serialized caching, -using a smaller cache should be sufficient to mitigate most of the garbage collection problems. -In case you are interested in further tuning the Java GC, continue reading below. - **Advanced GC Tuning** To further tune garbage collection, we first need to understand some basic information about memory management in the JVM: @@ -183,9 +201,9 @@ temporary objects created during task execution. Some steps which may be useful * Check if there are too many garbage collections by collecting GC stats. If a full GC is invoked multiple times for before a task completes, it means that there isn't enough memory available for executing tasks. -* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of memory used for caching. - This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow - down task execution! +* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of + memory used for caching by lowering `spark.memory.storageFraction`; it is better to cache fewer + objects than to slow down task execution! * If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You can set the size of the Eden to be an over-estimate of how much memory each task will need. If the size of Eden diff --git a/ec2/README b/ec2/README deleted file mode 100644 index 72434f24bf98d..0000000000000 --- a/ec2/README +++ /dev/null @@ -1,4 +0,0 @@ -This folder contains a script, spark-ec2, for launching Spark clusters on -Amazon EC2. Usage instructions are available online at: - -http://spark.apache.org/docs/latest/ec2-scripts.html diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh deleted file mode 100644 index 4f3e8da809f7f..0000000000000 --- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# These variables are automatically filled in by the spark-ec2 script. -export MASTERS="{{master_list}}" -export SLAVES="{{slave_list}}" -export HDFS_DATA_DIRS="{{hdfs_data_dirs}}" -export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" -export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" -export MODULES="{{modules}}" -export SPARK_VERSION="{{spark_version}}" -export TACHYON_VERSION="{{tachyon_version}}" -export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" -export SWAP_MB="{{swap}}" -export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" -export SPARK_MASTER_OPTS="{{spark_master_opts}}" -export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" -export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 deleted file mode 100755 index 26e7d22655694..0000000000000 --- a/ec2/spark-ec2 +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/sh - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Preserve the user's CWD so that relative paths are passed correctly to -#+ the underlying Python script. -SPARK_EC2_DIR="$(dirname "$0")" - -python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py deleted file mode 100755 index 9327e21e43db7..0000000000000 --- a/ec2/spark_ec2.py +++ /dev/null @@ -1,1520 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import division, print_function, with_statement - -import codecs -import hashlib -import itertools -import logging -import os -import os.path -import pipes -import random -import shutil -import string -from stat import S_IRUSR -import subprocess -import sys -import tarfile -import tempfile -import textwrap -import time -import warnings -from datetime import datetime -from optparse import OptionParser -from sys import stderr - -if sys.version < "3": - from urllib2 import urlopen, Request, HTTPError -else: - from urllib.request import urlopen, Request - from urllib.error import HTTPError - raw_input = input - xrange = range - -SPARK_EC2_VERSION = "1.5.0" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) - -VALID_SPARK_VERSIONS = set([ - "0.7.3", - "0.8.0", - "0.8.1", - "0.9.0", - "0.9.1", - "0.9.2", - "1.0.0", - "1.0.1", - "1.0.2", - "1.1.0", - "1.1.1", - "1.2.0", - "1.2.1", - "1.3.0", - "1.3.1", - "1.4.0", - "1.4.1", - "1.5.0" -]) - -SPARK_TACHYON_MAP = { - "1.0.0": "0.4.1", - "1.0.1": "0.4.1", - "1.0.2": "0.4.1", - "1.1.0": "0.5.0", - "1.1.1": "0.5.0", - "1.2.0": "0.5.0", - "1.2.1": "0.5.0", - "1.3.0": "0.5.0", - "1.3.1": "0.5.0", - "1.4.0": "0.6.4", - "1.4.1": "0.6.4", - "1.5.0": "0.7.1" -} - -DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION -DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" - -# Default location to get the spark-ec2 scripts (and ami-list) from -DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.5" - - -def setup_external_libs(libs): - """ - Download external libraries from PyPI to SPARK_EC2_DIR/lib/ and prepend them to our PATH. - """ - PYPI_URL_PREFIX = "https://pypi.python.org/packages/source" - SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") - - if not os.path.exists(SPARK_EC2_LIB_DIR): - print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( - path=SPARK_EC2_LIB_DIR - )) - print("This should be a one-time operation.") - os.mkdir(SPARK_EC2_LIB_DIR) - - for lib in libs: - versioned_lib_name = "{n}-{v}".format(n=lib["name"], v=lib["version"]) - lib_dir = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name) - - if not os.path.isdir(lib_dir): - tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") - print(" - Downloading {lib}...".format(lib=lib["name"])) - download_stream = urlopen( - "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( - prefix=PYPI_URL_PREFIX, - first_letter=lib["name"][:1], - lib_name=lib["name"], - lib_version=lib["version"] - ) - ) - with open(tgz_file_path, "wb") as tgz_file: - tgz_file.write(download_stream.read()) - with open(tgz_file_path, "rb") as tar: - if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: - print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) - sys.exit(1) - tar = tarfile.open(tgz_file_path) - tar.extractall(path=SPARK_EC2_LIB_DIR) - tar.close() - os.remove(tgz_file_path) - print(" - Finished downloading {lib}.".format(lib=lib["name"])) - sys.path.insert(1, lib_dir) - - -# Only PyPI libraries are supported. -external_libs = [ - { - "name": "boto", - "version": "2.34.0", - "md5": "5556223d2d0cc4d06dd4829e671dcecd" - } -] - -setup_external_libs(external_libs) - -import boto -from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType -from boto import ec2 - - -class UsageError(Exception): - pass - - -# Configure and parse our command-line arguments -def parse_args(): - parser = OptionParser( - prog="spark-ec2", - version="%prog {v}".format(v=SPARK_EC2_VERSION), - usage="%prog [options] \n\n" - + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") - - parser.add_option( - "-s", "--slaves", type="int", default=1, - help="Number of slaves to launch (default: %default)") - parser.add_option( - "-w", "--wait", type="int", - help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") - parser.add_option( - "-k", "--key-pair", - help="Key pair to use on instances") - parser.add_option( - "-i", "--identity-file", - help="SSH private key file to use for logging into instances") - parser.add_option( - "-p", "--profile", default=None, - help="If you have multiple profiles (AWS or boto config), you can configure " + - "additional, named profiles by using this option (default: %default)") - parser.add_option( - "-t", "--instance-type", default="m1.large", - help="Type of instance to launch (default: %default). " + - "WARNING: must be 64-bit; small instances won't work") - parser.add_option( - "-m", "--master-instance-type", default="", - help="Master instance type (leave empty for same as instance-type)") - parser.add_option( - "-r", "--region", default="us-east-1", - help="EC2 region used to launch instances in, or to find them in (default: %default)") - parser.add_option( - "-z", "--zone", default="", - help="Availability zone to launch instances in, or 'all' to spread " + - "slaves across multiple (an additional $0.01/Gb for bandwidth" + - "between zones applies) (default: a single zone chosen at random)") - parser.add_option( - "-a", "--ami", - help="Amazon Machine Image ID to use") - parser.add_option( - "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, - help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") - parser.add_option( - "--spark-git-repo", - default=DEFAULT_SPARK_GITHUB_REPO, - help="Github repo from which to checkout supplied commit hash (default: %default)") - parser.add_option( - "--spark-ec2-git-repo", - default=DEFAULT_SPARK_EC2_GITHUB_REPO, - help="Github repo from which to checkout spark-ec2 (default: %default)") - parser.add_option( - "--spark-ec2-git-branch", - default=DEFAULT_SPARK_EC2_BRANCH, - help="Github repo branch of spark-ec2 to use (default: %default)") - parser.add_option( - "--deploy-root-dir", - default=None, - help="A directory to copy into / on the first master. " + - "Must be absolute. Note that a trailing slash is handled as per rsync: " + - "If you omit it, the last directory of the --deploy-root-dir path will be created " + - "in / before copying its contents. If you append the trailing slash, " + - "the directory is not created and its contents are copied directly into /. " + - "(default: %default).") - parser.add_option( - "--hadoop-major-version", default="1", - help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + - "(Hadoop 2.4.0) (default: %default)") - parser.add_option( - "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", - help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + - "the given local address (for use with login)") - parser.add_option( - "--resume", action="store_true", default=False, - help="Resume installation on a previously launched cluster " + - "(for debugging)") - parser.add_option( - "--ebs-vol-size", metavar="SIZE", type="int", default=0, - help="Size (in GB) of each EBS volume.") - parser.add_option( - "--ebs-vol-type", default="standard", - help="EBS volume type (e.g. 'gp2', 'standard').") - parser.add_option( - "--ebs-vol-num", type="int", default=1, - help="Number of EBS volumes to attach to each node as /vol[x]. " + - "The volumes will be deleted when the instances terminate. " + - "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0. " + - "Only support up to 8 EBS volumes.") - parser.add_option( - "--placement-group", type="string", default=None, - help="Which placement group to try and launch " + - "instances into. Assumes placement group is already " + - "created.") - parser.add_option( - "--swap", metavar="SWAP", type="int", default=1024, - help="Swap space to set up per node, in MB (default: %default)") - parser.add_option( - "--spot-price", metavar="PRICE", type="float", - help="If specified, launch slaves as spot instances with the given " + - "maximum price (in dollars)") - parser.add_option( - "--ganglia", action="store_true", default=True, - help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " + - "the Ganglia page will be publicly accessible") - parser.add_option( - "--no-ganglia", action="store_false", dest="ganglia", - help="Disable Ganglia monitoring for the cluster") - parser.add_option( - "-u", "--user", default="root", - help="The SSH user you want to connect as (default: %default)") - parser.add_option( - "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created") - parser.add_option( - "--use-existing-master", action="store_true", default=False, - help="Launch fresh slaves, but use an existing stopped master if possible") - parser.add_option( - "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " + - "is used as Hadoop major version (default: %default)") - parser.add_option( - "--master-opts", type="string", default="", - help="Extra options to give to master through SPARK_MASTER_OPTS variable " + - "(e.g -Dspark.worker.timeout=180)") - parser.add_option( - "--user-data", type="string", default="", - help="Path to a user-data file (most AMIs interpret this as an initialization script)") - parser.add_option( - "--authorized-address", type="string", default="0.0.0.0/0", - help="Address to authorize on created security groups (default: %default)") - parser.add_option( - "--additional-security-group", type="string", default="", - help="Additional security group to place the machines in") - parser.add_option( - "--additional-tags", type="string", default="", - help="Additional tags to set on the machines; tags are comma-separated, while name and " + - "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") - parser.add_option( - "--copy-aws-credentials", action="store_true", default=False, - help="Add AWS credentials to hadoop configuration to allow Spark to access S3") - parser.add_option( - "--subnet-id", default=None, - help="VPC subnet to launch instances in") - parser.add_option( - "--vpc-id", default=None, - help="VPC to launch instances in") - parser.add_option( - "--private-ips", action="store_true", default=False, - help="Use private IPs for instances rather than public if VPC/subnet " + - "requires that.") - parser.add_option( - "--instance-initiated-shutdown-behavior", default="stop", - choices=["stop", "terminate"], - help="Whether instances should terminate when shut down or just stop") - parser.add_option( - "--instance-profile-name", default=None, - help="IAM profile name to launch instances under") - - (opts, args) = parser.parse_args() - if len(args) != 2: - parser.print_help() - sys.exit(1) - (action, cluster_name) = args - - # Boto config check - # http://boto.cloudhackers.com/en/latest/boto_config_tut.html - home_dir = os.getenv('HOME') - if home_dir is None or not os.path.isfile(home_dir + '/.boto'): - if not os.path.isfile('/etc/boto.cfg'): - # If there is no boto config, check aws credentials - if not os.path.isfile(home_dir + '/.aws/credentials'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) - return (opts, action, cluster_name) - - -# Get the EC2 security group of the given name, creating it if it doesn't exist -def get_or_make_group(conn, name, vpc_id): - groups = conn.get_all_security_groups() - group = [g for g in groups if g.name == name] - if len(group) > 0: - return group[0] - else: - print("Creating security group " + name) - return conn.create_security_group(name, "Spark EC2 group", vpc_id) - - -def get_validate_spark_version(version, repo): - if "." in version: - version = version.replace("v", "") - if version not in VALID_SPARK_VERSIONS: - print("Don't know about Spark version: {v}".format(v=version), file=stderr) - sys.exit(1) - return version - else: - github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - request = Request(github_commit_url) - request.get_method = lambda: 'HEAD' - try: - response = urlopen(request) - except HTTPError as e: - print("Couldn't validate Spark commit: {url}".format(url=github_commit_url), - file=stderr) - print("Received HTTP response code of {code}.".format(code=e.code), file=stderr) - sys.exit(1) - return version - - -# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-06-19 -# For easy maintainability, please keep this manually-inputted dictionary sorted by key. -EC2_INSTANCE_TYPES = { - "c1.medium": "pvm", - "c1.xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "c4.large": "hvm", - "c4.xlarge": "hvm", - "c4.2xlarge": "hvm", - "c4.4xlarge": "hvm", - "c4.8xlarge": "hvm", - "cc1.4xlarge": "hvm", - "cc2.8xlarge": "hvm", - "cg1.4xlarge": "hvm", - "cr1.8xlarge": "hvm", - "d2.xlarge": "hvm", - "d2.2xlarge": "hvm", - "d2.4xlarge": "hvm", - "d2.8xlarge": "hvm", - "g2.2xlarge": "hvm", - "g2.8xlarge": "hvm", - "hi1.4xlarge": "pvm", - "hs1.8xlarge": "pvm", - "i2.xlarge": "hvm", - "i2.2xlarge": "hvm", - "i2.4xlarge": "hvm", - "i2.8xlarge": "hvm", - "m1.small": "pvm", - "m1.medium": "pvm", - "m1.large": "pvm", - "m1.xlarge": "pvm", - "m2.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", - "m3.medium": "hvm", - "m3.large": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", - "m4.large": "hvm", - "m4.xlarge": "hvm", - "m4.2xlarge": "hvm", - "m4.4xlarge": "hvm", - "m4.10xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", - "r3.2xlarge": "hvm", - "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm", - "t1.micro": "pvm", - "t2.micro": "hvm", - "t2.small": "hvm", - "t2.medium": "hvm", - "t2.large": "hvm", -} - - -def get_tachyon_version(spark_version): - return SPARK_TACHYON_MAP.get(spark_version, "") - - -# Attempt to resolve an appropriate AMI given the architecture and region of the request. -def get_spark_ami(opts): - if opts.instance_type in EC2_INSTANCE_TYPES: - instance_type = EC2_INSTANCE_TYPES[opts.instance_type] - else: - instance_type = "pvm" - print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr) - - # URL prefix from which to fetch AMI information - ami_prefix = "{r}/{b}/ami-list".format( - r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), - b=opts.spark_ec2_git_branch) - - ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) - reader = codecs.getreader("ascii") - try: - ami = reader(urlopen(ami_path)).read().strip() - except: - print("Could not resolve AMI at: " + ami_path, file=stderr) - sys.exit(1) - - print("Spark AMI: " + ami) - return ami - - -# Launch a cluster of the given name, by setting up its security groups, -# and then starting new instances in them. -# Returns a tuple of EC2 reservation objects for the master and slaves -# Fails if there already instances running in the cluster's groups. -def launch_cluster(conn, opts, cluster_name): - if opts.identity_file is None: - print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) - sys.exit(1) - - if opts.key_pair is None: - print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) - sys.exit(1) - - user_data_content = None - if opts.user_data: - with open(opts.user_data) as user_data_file: - user_data_content = user_data_file.read() - - print("Setting up security groups...") - master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) - slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) - authorized_address = opts.authorized_address - if master_group.rules == []: # Group was just now created - if opts.vpc_id is None: - master_group.authorize(src_group=master_group) - master_group.authorize(src_group=slave_group) - else: - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize('tcp', 22, 22, authorized_address) - master_group.authorize('tcp', 8080, 8081, authorized_address) - master_group.authorize('tcp', 18080, 18080, authorized_address) - master_group.authorize('tcp', 19999, 19999, authorized_address) - master_group.authorize('tcp', 50030, 50030, authorized_address) - master_group.authorize('tcp', 50070, 50070, authorized_address) - master_group.authorize('tcp', 60070, 60070, authorized_address) - master_group.authorize('tcp', 4040, 4045, authorized_address) - # Rstudio (GUI for R) needs port 8787 for web access - master_group.authorize('tcp', 8787, 8787, authorized_address) - # HDFS NFS gateway requires 111,2049,4242 for tcp & udp - master_group.authorize('tcp', 111, 111, authorized_address) - master_group.authorize('udp', 111, 111, authorized_address) - master_group.authorize('tcp', 2049, 2049, authorized_address) - master_group.authorize('udp', 2049, 2049, authorized_address) - master_group.authorize('tcp', 4242, 4242, authorized_address) - master_group.authorize('udp', 4242, 4242, authorized_address) - # RM in YARN mode uses 8088 - master_group.authorize('tcp', 8088, 8088, authorized_address) - if opts.ganglia: - master_group.authorize('tcp', 5080, 5080, authorized_address) - if slave_group.rules == []: # Group was just now created - if opts.vpc_id is None: - slave_group.authorize(src_group=master_group) - slave_group.authorize(src_group=slave_group) - else: - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize('tcp', 22, 22, authorized_address) - slave_group.authorize('tcp', 8080, 8081, authorized_address) - slave_group.authorize('tcp', 50060, 50060, authorized_address) - slave_group.authorize('tcp', 50075, 50075, authorized_address) - slave_group.authorize('tcp', 60060, 60060, authorized_address) - slave_group.authorize('tcp', 60075, 60075, authorized_address) - - # Check if instances are already running in our groups - existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, - die_on_error=False) - if existing_slaves or (existing_masters and not opts.use_existing_master): - print("ERROR: There are already instances running in group %s or %s" % - (master_group.name, slave_group.name), file=stderr) - sys.exit(1) - - # Figure out Spark AMI - if opts.ami is None: - opts.ami = get_spark_ami(opts) - - # we use group ids to work around https://github.com/boto/boto/issues/350 - additional_group_ids = [] - if opts.additional_security_group: - additional_group_ids = [sg.id - for sg in conn.get_all_security_groups() - if opts.additional_security_group in (sg.name, sg.id)] - print("Launching instances...") - - try: - image = conn.get_all_images(image_ids=[opts.ami])[0] - except: - print("Could not find AMI " + opts.ami, file=stderr) - sys.exit(1) - - # Create block device mapping so that we can add EBS volumes if asked to. - # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... /dev/sdz - block_map = BlockDeviceMapping() - if opts.ebs_vol_size > 0: - for i in range(opts.ebs_vol_num): - device = EBSBlockDeviceType() - device.size = opts.ebs_vol_size - device.volume_type = opts.ebs_vol_type - device.delete_on_termination = True - block_map["/dev/sd" + chr(ord('s') + i)] = device - - # AWS ignores the AMI-specified block device mapping for M3 (see SPARK-3342). - if opts.instance_type.startswith('m3.'): - for i in range(get_num_disks(opts.instance_type)): - dev = BlockDeviceType() - dev.ephemeral_name = 'ephemeral%d' % i - # The first ephemeral drive is /dev/sdb. - name = '/dev/sd' + string.letters[i + 1] - block_map[name] = dev - - # Launch slaves - if opts.spot_price is not None: - # Launch spot instances with the requested price - print("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - my_req_ids = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - slave_reqs = conn.request_spot_instances( - price=opts.spot_price, - image_id=opts.ami, - launch_group="launch-group-%s" % cluster_name, - placement=zone, - count=num_slaves_this_zone, - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_profile_name=opts.instance_profile_name) - my_req_ids += [req.id for req in slave_reqs] - i += 1 - - print("Waiting for spot instances to be granted...") - try: - while True: - time.sleep(10) - reqs = conn.get_all_spot_instance_requests() - id_to_req = {} - for r in reqs: - id_to_req[r.id] = r - active_instance_ids = [] - for i in my_req_ids: - if i in id_to_req and id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) - if len(active_instance_ids) == opts.slaves: - print("All %d slaves granted" % opts.slaves) - reservations = conn.get_all_reservations(active_instance_ids) - slave_nodes = [] - for r in reservations: - slave_nodes += r.instances - break - else: - print("%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves)) - except: - print("Canceling spot instance requests") - conn.cancel_spot_instance_requests(my_req_ids) - # Log a warning if any of these requests actually launched instances: - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - running = len(master_nodes) + len(slave_nodes) - if running: - print(("WARNING: %d instances are still running" % running), file=stderr) - sys.exit(0) - else: - # Launch non-spot instances - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - slave_nodes = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - if num_slaves_this_zone > 0: - slave_res = image.run( - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, - instance_profile_name=opts.instance_profile_name) - slave_nodes += slave_res.instances - print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( - s=num_slaves_this_zone, - plural_s=('' if num_slaves_this_zone == 1 else 's'), - z=zone, - r=slave_res.id)) - i += 1 - - # Launch or resume masters - if existing_masters: - print("Starting master...") - for inst in existing_masters: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - master_nodes = existing_masters - else: - master_type = opts.master_instance_type - if master_type == "": - master_type = opts.instance_type - if opts.zone == 'all': - opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run( - key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, - instance_profile_name=opts.instance_profile_name) - - master_nodes = master_res.instances - print("Launched master in %s, regid = %s" % (zone, master_res.id)) - - # This wait time corresponds to SPARK-4983 - print("Waiting for AWS to propagate instance metadata...") - time.sleep(15) - - # Give the instances descriptive names and set additional tags - additional_tags = {} - if opts.additional_tags.strip(): - additional_tags = dict( - map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') - ) - - for master in master_nodes: - master.add_tags( - dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) - ) - - for slave in slave_nodes: - slave.add_tags( - dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) - ) - - # Return all the instances - return (master_nodes, slave_nodes) - - -def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - """ - Get the EC2 instances in an existing cluster if available. - Returns a tuple of lists of EC2 instance objects for the masters and slaves. - """ - print("Searching for existing cluster {c} in region {r}...".format( - c=cluster_name, r=opts.region)) - - def get_instances(group_names): - """ - Get all non-terminated instances that belong to any of the provided security groups. - - EC2 reservation filters and instance states are documented here: - http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options - """ - reservations = conn.get_all_reservations( - filters={"instance.group-name": group_names}) - instances = itertools.chain.from_iterable(r.instances for r in reservations) - return [i for i in instances if i.state not in ["shutting-down", "terminated"]] - - master_instances = get_instances([cluster_name + "-master"]) - slave_instances = get_instances([cluster_name + "-slaves"]) - - if any((master_instances, slave_instances)): - print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( - m=len(master_instances), - plural_m=('' if len(master_instances) == 1 else 's'), - s=len(slave_instances), - plural_s=('' if len(slave_instances) == 1 else 's'))) - - if not master_instances and die_on_error: - print("ERROR: Could not find a master for cluster {c} in region {r}.".format( - c=cluster_name, r=opts.region), file=sys.stderr) - sys.exit(1) - - return (master_instances, slave_instances) - - -# Deploy configuration files and run setup scripts on a newly launched -# or started EC2 cluster. -def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = get_dns_name(master_nodes[0], opts.private_ips) - if deploy_ssh_key: - print("Generating cluster's SSH key on master...") - key_setup = """ - [ -f ~/.ssh/id_rsa ] || - (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && - cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys) - """ - ssh(master, opts, key_setup) - dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print("Transferring cluster's SSH key to slaves...") - for slave in slave_nodes: - slave_address = get_dns_name(slave, opts.private_ips) - print(slave_address) - ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) - - modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] - - if opts.hadoop_major_version == "1": - modules = list(filter(lambda x: x != "mapreduce", modules)) - - if opts.ganglia: - modules.append('ganglia') - - # Clear SPARK_WORKER_INSTANCES if running on YARN - if opts.hadoop_major_version == "yarn": - opts.worker_instances = "" - - # NOTE: We should clone the repository before running deploy_files to - # prevent ec2-variables.sh from being overwritten - print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( - r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) - ssh( - host=master, - opts=opts, - command="rm -rf spark-ec2" - + " && " - + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, - b=opts.spark_ec2_git_branch) - ) - - print("Deploying files to master...") - deploy_files( - conn=conn, - root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", - opts=opts, - master_nodes=master_nodes, - slave_nodes=slave_nodes, - modules=modules - ) - - if opts.deploy_root_dir is not None: - print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) - deploy_user_files( - root_dir=opts.deploy_root_dir, - opts=opts, - master_nodes=master_nodes - ) - - print("Running setup on master...") - setup_spark_cluster(master, opts) - print("Done!") - - -def setup_spark_cluster(master, opts): - ssh(master, opts, "chmod u+x spark-ec2/setup.sh") - ssh(master, opts, "spark-ec2/setup.sh") - print("Spark standalone cluster started at http://%s:8080" % master) - - if opts.ganglia: - print("Ganglia started at http://%s:5080/ganglia" % master) - - -def is_ssh_available(host, opts, print_ssh_output=True): - """ - Check if SSH is available on a host. - """ - s = subprocess.Popen( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order - ) - cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout - - if s.returncode != 0 and print_ssh_output: - # extra leading newline is for spacing in wait_for_cluster_state() - print(textwrap.dedent("""\n - Warning: SSH connection error. (This could be temporary.) - Host: {h} - SSH return code: {r} - SSH output: {o} - """).format( - h=host, - r=s.returncode, - o=cmd_output.strip() - )) - - return s.returncode == 0 - - -def is_cluster_ssh_available(cluster_instances, opts): - """ - Check if SSH is available on all the instances in a cluster. - """ - for i in cluster_instances: - dns_name = get_dns_name(i, opts.private_ips) - if not is_ssh_available(host=dns_name, opts=opts): - return False - else: - return True - - -def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): - """ - Wait for all the instances in the cluster to reach a designated state. - - cluster_instances: a list of boto.ec2.instance.Instance - cluster_state: a string representing the desired state of all the instances in the cluster - value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as - 'running', 'terminated', etc. - (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) - """ - sys.stdout.write( - "Waiting for cluster to enter '{s}' state.".format(s=cluster_state) - ) - sys.stdout.flush() - - start_time = datetime.now() - num_attempts = 0 - - while True: - time.sleep(5 * num_attempts) # seconds - - for i in cluster_instances: - i.update() - - max_batch = 100 - statuses = [] - for j in xrange(0, len(cluster_instances), max_batch): - batch = [i.id for i in cluster_instances[j:j + max_batch]] - statuses.extend(conn.get_all_instance_status(instance_ids=batch)) - - if cluster_state == 'ssh-ready': - if all(i.state == 'running' for i in cluster_instances) and \ - all(s.system_status.status == 'ok' for s in statuses) and \ - all(s.instance_status.status == 'ok' for s in statuses) and \ - is_cluster_ssh_available(cluster_instances, opts): - break - else: - if all(i.state == cluster_state for i in cluster_instances): - break - - num_attempts += 1 - - sys.stdout.write(".") - sys.stdout.flush() - - sys.stdout.write("\n") - - end_time = datetime.now() - print("Cluster is now in '{s}' state. Waited {t} seconds.".format( - s=cluster_state, - t=(end_time - start_time).seconds - )) - - -# Get number of local disks available for a given EC2 instance type. -def get_num_disks(instance_type): - # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-06-19 - # For easy maintainability, please keep this manually-inputted dictionary sorted by key. - disks_by_instance = { - "c1.medium": 1, - "c1.xlarge": 4, - "c3.large": 2, - "c3.xlarge": 2, - "c3.2xlarge": 2, - "c3.4xlarge": 2, - "c3.8xlarge": 2, - "c4.large": 0, - "c4.xlarge": 0, - "c4.2xlarge": 0, - "c4.4xlarge": 0, - "c4.8xlarge": 0, - "cc1.4xlarge": 2, - "cc2.8xlarge": 4, - "cg1.4xlarge": 2, - "cr1.8xlarge": 2, - "d2.xlarge": 3, - "d2.2xlarge": 6, - "d2.4xlarge": 12, - "d2.8xlarge": 24, - "g2.2xlarge": 1, - "g2.8xlarge": 2, - "hi1.4xlarge": 2, - "hs1.8xlarge": 24, - "i2.xlarge": 1, - "i2.2xlarge": 2, - "i2.4xlarge": 4, - "i2.8xlarge": 8, - "m1.small": 1, - "m1.medium": 1, - "m1.large": 2, - "m1.xlarge": 4, - "m2.xlarge": 1, - "m2.2xlarge": 1, - "m2.4xlarge": 2, - "m3.medium": 1, - "m3.large": 1, - "m3.xlarge": 2, - "m3.2xlarge": 2, - "m4.large": 0, - "m4.xlarge": 0, - "m4.2xlarge": 0, - "m4.4xlarge": 0, - "m4.10xlarge": 0, - "r3.large": 1, - "r3.xlarge": 1, - "r3.2xlarge": 1, - "r3.4xlarge": 1, - "r3.8xlarge": 2, - "t1.micro": 0, - "t2.micro": 0, - "t2.small": 0, - "t2.medium": 0, - "t2.large": 0, - } - if instance_type in disks_by_instance: - return disks_by_instance[instance_type] - else: - print("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type, file=stderr) - return 1 - - -# Deploy the configuration file templates in a given local directory to -# a cluster, filling in any template parameters with information about the -# cluster (e.g. lists of masters and slaves). Files are only deployed to -# the first master instance in the cluster, and we expect the setup -# script to be run on that instance to copy them to other nodes. -# -# root_dir should be an absolute path to the directory with the files we want to deploy. -def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = get_dns_name(master_nodes[0], opts.private_ips) - - num_disks = get_num_disks(opts.instance_type) - hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" - mapred_local_dirs = "/mnt/hadoop/mrlocal" - spark_local_dirs = "/mnt/spark" - if num_disks > 1: - for i in range(2, num_disks + 1): - hdfs_data_dirs += ",/mnt%d/ephemeral-hdfs/data" % i - mapred_local_dirs += ",/mnt%d/hadoop/mrlocal" % i - spark_local_dirs += ",/mnt%d/spark" % i - - cluster_url = "%s:7077" % active_master - - if "." in opts.spark_version: - # Pre-built Spark deploy - spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - tachyon_v = get_tachyon_version(spark_v) - else: - # Spark-only custom deploy - spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) - tachyon_v = "" - print("Deploying Spark via git hash; Tachyon won't be set up") - modules = filter(lambda x: x != "tachyon", modules) - - master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] - slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] - worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" - template_vars = { - "master_list": '\n'.join(master_addresses), - "active_master": active_master, - "slave_list": '\n'.join(slave_addresses), - "cluster_url": cluster_url, - "hdfs_data_dirs": hdfs_data_dirs, - "mapred_local_dirs": mapred_local_dirs, - "spark_local_dirs": spark_local_dirs, - "swap": str(opts.swap), - "modules": '\n'.join(modules), - "spark_version": spark_v, - "tachyon_version": tachyon_v, - "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": worker_instances_str, - "spark_master_opts": opts.master_opts - } - - if opts.copy_aws_credentials: - template_vars["aws_access_key_id"] = conn.aws_access_key_id - template_vars["aws_secret_access_key"] = conn.aws_secret_access_key - else: - template_vars["aws_access_key_id"] = "" - template_vars["aws_secret_access_key"] = "" - - # Create a temp directory in which we will place all the files to be - # deployed after we substitue template parameters in them - tmp_dir = tempfile.mkdtemp() - for path, dirs, files in os.walk(root_dir): - if path.find(".svn") == -1: - dest_dir = os.path.join('/', path[len(root_dir):]) - local_dir = tmp_dir + dest_dir - if not os.path.exists(local_dir): - os.makedirs(local_dir) - for filename in files: - if filename[0] not in '#.~' and filename[-1] != '~': - dest_file = os.path.join(dest_dir, filename) - local_file = tmp_dir + dest_file - with open(os.path.join(path, filename)) as src: - with open(local_file, "w") as dest: - text = src.read() - for key in template_vars: - text = text.replace("{{" + key + "}}", template_vars[key]) - dest.write(text) - dest.close() - # rsync the whole directory over to the master machine - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s/" % tmp_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - # Remove the temp directory we created above - shutil.rmtree(tmp_dir) - - -# Deploy a given local directory to a cluster, WITHOUT parameter substitution. -# Note that unlike deploy_files, this works for binary files. -# Also, it is up to the user to add (or not) the trailing slash in root_dir. -# Files are only deployed to the first master instance in the cluster. -# -# root_dir should be an absolute path. -def deploy_user_files(root_dir, opts, master_nodes): - active_master = get_dns_name(master_nodes[0], opts.private_ips) - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s" % root_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - - -def stringify_command(parts): - if isinstance(parts, str): - return parts - else: - return ' '.join(map(pipes.quote, parts)) - - -def ssh_args(opts): - parts = ['-o', 'StrictHostKeyChecking=no'] - parts += ['-o', 'UserKnownHostsFile=/dev/null'] - if opts.identity_file is not None: - parts += ['-i', opts.identity_file] - return parts - - -def ssh_command(opts): - return ['ssh'] + ssh_args(opts) - - -# Run a command on a host through ssh, retrying up to five times -# and then throwing an exception if ssh continues to fail. -def ssh(host, opts, command): - tries = 0 - while True: - try: - return subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), - stringify_command(command)]) - except subprocess.CalledProcessError as e: - if tries > 5: - # If this was an ssh failure, provide the user with hints. - if e.returncode == 255: - raise UsageError( - "Failed to SSH to remote host {0}.\n" - "Please check that you have provided the correct --identity-file and " - "--key-pair parameters and try again.".format(host)) - else: - raise e - print("Error executing remote command, retrying after 30 seconds: {0}".format(e), - file=stderr) - time.sleep(30) - tries = tries + 1 - - -# Backported from Python 2.7 for compatiblity with 2.6 (See SPARK-1990) -def _check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - -def ssh_read(host, opts, command): - return _check_output( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) - - -def ssh_write(host, opts, command, arguments): - tries = 0 - while True: - proc = subprocess.Popen( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], - stdin=subprocess.PIPE) - proc.stdin.write(arguments) - proc.stdin.close() - status = proc.wait() - if status == 0: - break - elif tries > 5: - raise RuntimeError("ssh_write failed with error %s" % proc.returncode) - else: - print("Error {0} while executing remote command, retrying after 30 seconds". - format(status), file=stderr) - time.sleep(30) - tries = tries + 1 - - -# Gets a list of zones to launch instances in -def get_zones(conn, opts): - if opts.zone == 'all': - zones = [z.name for z in conn.get_all_zones()] - else: - zones = [opts.zone] - return zones - - -# Gets the number of items in a partition -def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total // num_partitions - if (total % num_partitions) - current_partitions > 0: - num_slaves_this_zone += 1 - return num_slaves_this_zone - - -# Gets the IP address, taking into account the --private-ips flag -def get_ip_address(instance, private_ips=False): - ip = instance.ip_address if not private_ips else \ - instance.private_ip_address - return ip - - -# Gets the DNS name, taking into account the --private-ips flag -def get_dns_name(instance, private_ips=False): - dns = instance.public_dns_name if not private_ips else \ - instance.private_ip_address - return dns - - -def real_main(): - (opts, action, cluster_name) = parse_args() - - # Input parameter validation - get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - - if opts.wait is not None: - # NOTE: DeprecationWarnings are silent in 2.7+ by default. - # To show them, run Python with the -Wdefault switch. - # See: https://docs.python.org/3.5/whatsnew/2.7.html - warnings.warn( - "This option is deprecated and has no effect. " - "spark-ec2 automatically waits as long as necessary for clusters to start up.", - DeprecationWarning - ) - - if opts.identity_file is not None: - if not os.path.exists(opts.identity_file): - print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), - file=stderr) - sys.exit(1) - - file_mode = os.stat(opts.identity_file).st_mode - if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': - print("ERROR: The identity file must be accessible only by you.", file=stderr) - print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), - file=stderr) - sys.exit(1) - - if opts.instance_type not in EC2_INSTANCE_TYPES: - print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( - t=opts.instance_type), file=stderr) - - if opts.master_instance_type != "": - if opts.master_instance_type not in EC2_INSTANCE_TYPES: - print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( - t=opts.master_instance_type), file=stderr) - # Since we try instance types even if we can't resolve them, we check if they resolve first - # and, if they do, see if they resolve to the same virtualization type. - if opts.instance_type in EC2_INSTANCE_TYPES and \ - opts.master_instance_type in EC2_INSTANCE_TYPES: - if EC2_INSTANCE_TYPES[opts.instance_type] != \ - EC2_INSTANCE_TYPES[opts.master_instance_type]: - print("Error: spark-ec2 currently does not support having a master and slaves " - "with different AMI virtualization types.", file=stderr) - print("master instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr) - print("slave instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr) - sys.exit(1) - - if opts.ebs_vol_num > 8: - print("ebs-vol-num cannot be greater than 8", file=stderr) - sys.exit(1) - - # Prevent breaking ami_prefix (/, .git and startswith checks) - # Prevent forks with non spark-ec2 names for now. - if opts.spark_ec2_git_repo.endswith("/") or \ - opts.spark_ec2_git_repo.endswith(".git") or \ - not opts.spark_ec2_git_repo.startswith("https://github.com") or \ - not opts.spark_ec2_git_repo.endswith("spark-ec2"): - print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. " - "Furthermore, we currently only support forks named spark-ec2.", file=stderr) - sys.exit(1) - - if not (opts.deploy_root_dir is None or - (os.path.isabs(opts.deploy_root_dir) and - os.path.isdir(opts.deploy_root_dir) and - os.path.exists(opts.deploy_root_dir))): - print("--deploy-root-dir must be an absolute path to a directory that exists " - "on the local file system", file=stderr) - sys.exit(1) - - try: - if opts.profile is None: - conn = ec2.connect_to_region(opts.region) - else: - conn = ec2.connect_to_region(opts.region, profile_name=opts.profile) - except Exception as e: - print((e), file=stderr) - sys.exit(1) - - # Select an AZ at random if it was not specified. - if opts.zone == "": - opts.zone = random.choice(conn.get_all_zones()).name - - if action == "launch": - if opts.slaves <= 0: - print("ERROR: You have to start at least 1 slave", file=sys.stderr) - sys.exit(1) - if opts.resume: - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - else: - (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - setup_cluster(conn, master_nodes, slave_nodes, opts, True) - - elif action == "destroy": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - - if any(master_nodes + slave_nodes): - print("The following instances will be terminated:") - for inst in master_nodes + slave_nodes: - print("> %s" % get_dns_name(inst, opts.private_ips)) - print("ALL DATA ON ALL NODES WILL BE LOST!!") - - msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) - response = raw_input(msg) - if response == "y": - print("Terminating master...") - for inst in master_nodes: - inst.terminate() - print("Terminating slaves...") - for inst in slave_nodes: - inst.terminate() - - # Delete security groups as well - if opts.delete_groups: - group_names = [cluster_name + "-master", cluster_name + "-slaves"] - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='terminated' - ) - print("Deleting security groups (this will take some time)...") - attempt = 1 - while attempt <= 3: - print("Attempt %d" % attempt) - groups = [g for g in conn.get_all_security_groups() if g.name in group_names] - success = True - # Delete individual rules in all groups before deleting groups to - # remove dependencies between them - for group in groups: - print("Deleting rules in security group " + group.name) - for rule in group.rules: - for grant in rule.grants: - success &= group.revoke(ip_protocol=rule.ip_protocol, - from_port=rule.from_port, - to_port=rule.to_port, - src_group=grant) - - # Sleep for AWS eventual-consistency to catch up, and for instances - # to terminate - time.sleep(30) # Yes, it does have to be this long :-( - for group in groups: - try: - # It is needed to use group_id to make it work with VPC - conn.delete_security_group(group_id=group.id) - print("Deleted security group %s" % group.name) - except boto.exception.EC2ResponseError: - success = False - print("Failed to delete security group %s" % group.name) - - # Unfortunately, group.revoke() returns True even if a rule was not - # deleted, so this needs to be rerun if something fails - if success: - break - - attempt += 1 - - if not success: - print("Failed to delete all security groups after 3 tries.") - print("Try re-running in a few minutes.") - - elif action == "login": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - if not master_nodes[0].public_dns_name and not opts.private_ips: - print("Master has no public DNS name. Maybe you meant to specify --private-ips?") - else: - master = get_dns_name(master_nodes[0], opts.private_ips) - print("Logging into master " + master + "...") - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) - - elif action == "reboot-slaves": - response = raw_input( - "Are you sure you want to reboot the cluster " + - cluster_name + " slaves?\n" + - "Reboot cluster slaves " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print("Rebooting slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - print("Rebooting " + inst.id) - inst.reboot() - - elif action == "get-master": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - if not master_nodes[0].public_dns_name and not opts.private_ips: - print("Master has no public DNS name. Maybe you meant to specify --private-ips?") - else: - print(get_dns_name(master_nodes[0], opts.private_ips)) - - elif action == "stop": - response = raw_input( - "Are you sure you want to stop the cluster " + - cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " + - "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" + - "AMAZON EBS IF IT IS EBS-BACKED!!\n" + - "All data on spot-instance slaves will be lost.\n" + - "Stop cluster " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print("Stopping master...") - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.stop() - print("Stopping slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - if inst.spot_instance_request_id: - inst.terminate() - else: - inst.stop() - - elif action == "start": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print("Starting slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - print("Starting master...") - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - - # Determine types of running instances - existing_master_type = master_nodes[0].instance_type - existing_slave_type = slave_nodes[0].instance_type - # Setting opts.master_instance_type to the empty string indicates we - # have the same instance type for the master and the slaves - if existing_master_type == existing_slave_type: - existing_master_type = "" - opts.master_instance_type = existing_master_type - opts.instance_type = existing_slave_type - - setup_cluster(conn, master_nodes, slave_nodes, opts, False) - - else: - print("Invalid action: %s" % action, file=stderr) - sys.exit(1) - - -def main(): - try: - real_main() - except UsageError as e: - print("\nError:\n", e, file=stderr) - sys.exit(1) - - -if __name__ == "__main__": - logging.basicConfig() - main() diff --git a/examples/pom.xml b/examples/pom.xml index f5ab2a7fdc098..4a20370f0668d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,20 +20,23 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-examples_2.10 - - examples - + spark-examples_2.11 jar Spark Project Examples http://spark.apache.org/ + + examples + none + package + + org.apache.spark @@ -53,12 +56,6 @@ ${project.version} provided - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - provided - org.apache.spark spark-hive_${scala.binary.version} @@ -71,54 +68,16 @@ ${project.version} provided - - org.apache.spark - spark-streaming-twitter_${scala.binary.version} - ${project.version} - org.apache.spark spark-streaming-flume_${scala.binary.version} ${project.version} - - org.apache.spark - spark-streaming-mqtt_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming-zeromq_${scala.binary.version} - ${project.version} - - - org.spark-project.protobuf - protobuf-java - - - org.apache.spark spark-streaming-kafka_${scala.binary.version} ${project.version} - - org.apache.hbase - hbase-testing-util - ${hbase.version} - ${hbase.deps.scope} - - - - org.apache.hbase - hbase-annotations - - - org.jruby - jruby-complete - - - org.apache.hbase hbase-protocol @@ -166,6 +125,10 @@ org.apache.hbase hbase-annotations + + org.apache.hbase + hbase-common + org.apache.hadoop hadoop-core @@ -235,13 +198,6 @@ ${hbase.version} ${hbase.deps.scope} - - org.apache.hbase - hbase-hadoop-compat - ${hbase.version} - test-jar - test - org.apache.commons commons-math3 @@ -250,7 +206,7 @@ com.twitter algebird-core_${scala.binary.version} - 0.9.0 + 0.11.0 org.scalacheck @@ -260,7 +216,7 @@ org.apache.cassandra cassandra-all - 1.2.6 + 1.2.19 com.google.guava @@ -319,19 +275,8 @@ com.github.scopt scopt_${scala.binary.version} - 3.2.0 + 3.3.0 - - - - org.scala-lang - scala-library - provided - - @@ -352,38 +297,6 @@ true - - org.apache.maven.plugins - maven-shade-plugin - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - reference.conf - - - log4j.properties - - - - diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 812e9d5580cbf..ebb0687b14ae0 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -28,14 +28,13 @@ import org.apache.spark.api.java.function.PairFunction; import java.io.Serializable; -import java.util.Collections; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; /** * Executes a roll up-style query against Apache logs. - * + * * Usage: JavaLogQuery [logFile] */ public final class JavaLogQuery { @@ -83,10 +82,10 @@ public static Tuple3 extractKey(String line) { String user = m.group(3); String query = m.group(5); if (!user.equalsIgnoreCase("-")) { - return new Tuple3(ip, user, query); + return new Tuple3<>(ip, user, query); } } - return new Tuple3(null, null, null); + return new Tuple3<>(null, null, null); } public static Stats extractStats(String line) { @@ -109,7 +108,7 @@ public static void main(String[] args) { JavaPairRDD, Stats> extracted = dataSet.mapToPair(new PairFunction, Stats>() { @Override public Tuple2, Stats> call(String s) { - return new Tuple2, Stats>(extractKey(s), extractStats(s)); + return new Tuple2<>(extractKey(s), extractStats(s)); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index a5db8accdf138..229d1234414e5 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -17,7 +17,10 @@ package org.apache.spark.examples; - +import java.util.ArrayList; +import java.util.List; +import java.util.Iterator; +import java.util.regex.Pattern; import scala.Tuple2; @@ -32,11 +35,6 @@ import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; -import java.util.ArrayList; -import java.util.List; -import java.util.Iterator; -import java.util.regex.Pattern; - /** * Computes the PageRank of URLs from an input file. Input file should * be in format of: @@ -86,13 +84,14 @@ public static void main(String[] args) throws Exception { JavaRDD lines = ctx.textFile(args[0], 1); // Loads all URLs from input file and initialize their neighbors. - JavaPairRDD> links = lines.mapToPair(new PairFunction() { - @Override - public Tuple2 call(String s) { - String[] parts = SPACES.split(s); - return new Tuple2(parts[0], parts[1]); - } - }).distinct().groupByKey().cache(); + JavaPairRDD> links = lines.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + String[] parts = SPACES.split(s); + return new Tuple2<>(parts[0], parts[1]); + } + }).distinct().groupByKey().cache(); // Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. JavaPairRDD ranks = links.mapValues(new Function, Double>() { @@ -108,13 +107,13 @@ public Double call(Iterable rs) { JavaPairRDD contribs = links.join(ranks).values() .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { @Override - public Iterable> call(Tuple2, Double> s) { + public Iterator> call(Tuple2, Double> s) { int urlCount = Iterables.size(s._1); - List> results = new ArrayList>(); + List> results = new ArrayList<>(); for (String n : s._1) { - results.add(new Tuple2(n, s._2() / urlCount)); + results.add(new Tuple2<>(n, s._2() / urlCount)); } - return results; + return results.iterator(); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 0f07cb4098325..04a57a6bfb58b 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -26,7 +26,7 @@ import java.util.ArrayList; import java.util.List; -/** +/** * Computes an approximation to pi * Usage: JavaSparkPi [slices] */ @@ -38,7 +38,7 @@ public static void main(String[] args) throws Exception { int slices = (args.length == 1) ? Integer.parseInt(args[0]) : 2; int n = 100000 * slices; - List l = new ArrayList(n); + List l = new ArrayList<>(n); for (int i = 0; i < n; i++) { l.add(i); } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index 2563fcdd234bb..ca10384212da2 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -41,16 +41,16 @@ public final class JavaTC { private static final Random rand = new Random(42); static List> generateGraph() { - Set> edges = new HashSet>(numEdges); + Set> edges = new HashSet<>(numEdges); while (edges.size() < numEdges) { int from = rand.nextInt(numVertices); int to = rand.nextInt(numVertices); - Tuple2 e = new Tuple2(from, to); + Tuple2 e = new Tuple2<>(from, to); if (from != to) { edges.add(e); } } - return new ArrayList>(edges); + return new ArrayList<>(edges); } static class ProjectFn implements PairFunction>, @@ -59,7 +59,7 @@ static class ProjectFn implements PairFunction call(Tuple2> triple) { - return new Tuple2(triple._2()._2(), triple._2()._1()); + return new Tuple2<>(triple._2()._2(), triple._2()._1()); } } @@ -79,7 +79,7 @@ public static void main(String[] args) { new PairFunction, Integer, Integer>() { @Override public Tuple2 call(Tuple2 e) { - return new Tuple2(e._2(), e._1()); + return new Tuple2<>(e._2(), e._1()); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 9a6a944f7edef..3ff5412b934f0 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -27,6 +27,7 @@ import org.apache.spark.api.java.function.PairFunction; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -46,24 +47,26 @@ public static void main(String[] args) throws Exception { JavaRDD words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String s) { - return Arrays.asList(SPACE.split(s)); + public Iterator call(String s) { + return Arrays.asList(SPACE.split(s)).iterator(); } }); - JavaPairRDD ones = words.mapToPair(new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2(s, 1); - } - }); + JavaPairRDD ones = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2<>(s, 1); + } + }); - JavaPairRDD counts = ones.reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); + JavaPairRDD counts = ones.reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); List> output = counts.collect(); for (Tuple2 tuple : output) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java new file mode 100644 index 0000000000000..22b93a3a85c52 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.regression.AFTSurvivalRegression; +import org.apache.spark.ml.regression.AFTSurvivalRegressionModel; +import org.apache.spark.mllib.linalg.*; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaAFTSurvivalRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaAFTSurvivalRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)), + RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)), + RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)), + RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)), + RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226)) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + Dataset training = jsql.createDataFrame(data, schema); + double[] quantileProbabilities = new double[]{0.3, 0.6}; + AFTSurvivalRegression aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles"); + + AFTSurvivalRegressionModel model = aft.fit(training); + + // Print the coefficients, intercept and scale parameter for AFT survival regression + System.out.println("Coefficients: " + model.coefficients() + " Intercept: " + + model.intercept() + " Scale: " + model.scale()); + model.transform(training).show(false); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java new file mode 100644 index 0000000000000..088037d427f5b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.io.Serializable; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.recommendation.ALS; +import org.apache.spark.ml.recommendation.ALSModel; +import org.apache.spark.sql.types.DataTypes; +// $example off$ + +public class JavaALSExample { + + // $example on$ + public static class Rating implements Serializable { + private int userId; + private int movieId; + private float rating; + private long timestamp; + + public Rating() {} + + public Rating(int userId, int movieId, float rating, long timestamp) { + this.userId = userId; + this.movieId = movieId; + this.rating = rating; + this.timestamp = timestamp; + } + + public int getUserId() { + return userId; + } + + public int getMovieId() { + return movieId; + } + + public float getRating() { + return rating; + } + + public long getTimestamp() { + return timestamp; + } + + public static Rating parseRating(String str) { + String[] fields = str.split("::"); + if (fields.length != 4) { + throw new IllegalArgumentException("Each line must contain 4 fields"); + } + int userId = Integer.parseInt(fields[0]); + int movieId = Integer.parseInt(fields[1]); + float rating = Float.parseFloat(fields[2]); + long timestamp = Long.parseLong(fields[3]); + return new Rating(userId, movieId, rating, timestamp); + } + } + // $example off$ + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaALSExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD ratingsRDD = jsc.textFile("data/mllib/als/sample_movielens_ratings.txt") + .map(new Function() { + public Rating call(String str) { + return Rating.parseRating(str); + } + }); + Dataset ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class); + Dataset[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); + Dataset training = splits[0]; + Dataset test = splits[1]; + + // Build the recommendation model using ALS on the training data + ALS als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating"); + ALSModel model = als.fit(training); + + // Evaluate the model by computing the RMSE on the test data + Dataset rawPredictions = model.transform(test); + Dataset predictions = rawPredictions + .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType)) + .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType)); + + RegressionEvaluator evaluator = new RegressionEvaluator() + .setMetricName("rmse") + .setLabelCol("rating") + .setPredictionCol("prediction"); + Double rmse = evaluator.evaluate(predictions); + System.out.println("Root-mean-square error = " + rmse); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java new file mode 100644 index 0000000000000..0a6e9c2a1f93c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBinarizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + Dataset continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); + Dataset binarizedDataFrame = binarizer.transform(continuousDataFrame); + Dataset binarizedFeatures = binarizedDataFrame.select("binarized_feature"); + for (Row r : binarizedFeatures.collectAsList()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); + } + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java new file mode 100644 index 0000000000000..1d1a518bbca12 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +// $example on$ +import org.apache.spark.ml.clustering.BisectingKMeans; +import org.apache.spark.ml.clustering.BisectingKMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + + +/** + * An example demonstrating a bisecting k-means clustering. + */ +public class JavaBisectingKMeansExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBisectingKMeansExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)), + RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)), + RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)), + RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)), + RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)), + RowFactory.create(Vectors.dense(18.9, 20.0, 19.7)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + Dataset dataset = jsql.createDataFrame(data, schema); + + BisectingKMeans bkm = new BisectingKMeans().setK(2); + BisectingKMeansModel model = bkm.fit(dataset); + + System.out.println("Compute Cost: " + model.computeCost(dataset)); + + Vector[] clusterCenters = model.clusterCenters(); + for (int i = 0; i < clusterCenters.length; i++) { + Vector clusterCenter = clusterCenters[i]; + System.out.println("Cluster Center " + i + ": " + clusterCenter); + } + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java new file mode 100644 index 0000000000000..68ffa702ea5e2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Bucketizer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) + }); + Dataset dataFrame = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + + // Transform original data into its bucket index. + Dataset bucketedData = bucketizer.transform(dataFrame); + bucketedData.show(); + // $example off$ + jsc.stop(); + } +} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java new file mode 100644 index 0000000000000..b1bf1cfeb2153 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.ChiSqSelector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaChiSqSelectorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaChiSqSelectorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + RowFactory.create(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + RowFactory.create(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = sqlContext.createDataFrame(jrdd, schema); + + ChiSqSelector selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures"); + + Dataset result = selector.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java index ac33adb65292f..ec3ac202bea4e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java @@ -25,7 +25,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.CountVectorizer; import org.apache.spark.ml.feature.CountVectorizerModel; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -48,7 +48,7 @@ public static void main(String[] args) { StructType schema = new StructType(new StructField [] { new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - DataFrame df = sqlContext.createDataFrame(jrdd, schema); + Dataset df = sqlContext.createDataFrame(jrdd, schema); // fit a CountVectorizerModel from the corpus CountVectorizerModel cvModel = new CountVectorizer() @@ -65,5 +65,7 @@ public static void main(String[] args) { cvModel.transform(df).show(); // $example off$ + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java deleted file mode 100644 index 9bbc14ea40875..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import java.util.List; - -import com.google.common.collect.Lists; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - -/** - * A simple example demonstrating model selection using CrossValidator. - * This example also demonstrates how Pipelines are Estimators. - * - * This example uses the Java bean classes {@link org.apache.spark.examples.ml.LabeledDocument} and - * {@link org.apache.spark.examples.ml.Document} defined in the Scala example - * {@link org.apache.spark.examples.ml.SimpleTextClassificationPipeline}. - * - * Run with - *
    - * bin/run-example ml.JavaCrossValidatorExample
    - * 
    - */ -public class JavaCrossValidatorExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // Prepare training documents, which are labeled. - List localTraining = Lists.newArrayList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0)); - DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); - - // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. - Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); - HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); - LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); - Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - - // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. - // This will allow us to jointly choose parameters for all Pipeline stages. - // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. - CrossValidator crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()); - // We use a ParamGridBuilder to construct a grid of parameters to search over. - // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, - // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. - ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - crossval.setEstimatorParamMaps(paramGrid); - crossval.setNumFolds(2); // Use 3+ in practice - - // Run cross-validation, and choose the best set of parameters. - CrossValidatorModel cvModel = crossval.fit(training); - - // Prepare test documents, which are unlabeled. - List localTest = Lists.newArrayList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); - DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); - - // Make predictions on test documents. cvModel uses the best model found (lrModel). - DataFrame predictions = cvModel.transform(test); - for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); - } - - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java new file mode 100644 index 0000000000000..4b15fde9c35fa --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaDCTExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + Dataset df = jsql.createDataFrame(data, schema); + DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); + Dataset dctDf = dct.transform(df); + dctDf.select("featuresDCT").show(3); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java new file mode 100644 index 0000000000000..8214952f80695 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.DecisionTreeClassifier; +import org.apache.spark.ml.classification.DecisionTreeClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeClassificationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + Dataset data = sqlContext + .read() + .format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + + // Automatically identify categorical features, and index them. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + Dataset[] splits = data.randomSplit(new double[]{0.7, 0.3}); + Dataset trainingData = splits[0]; + Dataset testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + + // Convert indexed labels back to original labels. + IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + Dataset predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + DecisionTreeClassificationModel treeModel = + (DecisionTreeClassificationModel) (model.stages()[2]); + System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java new file mode 100644 index 0000000000000..a4f3e97bf318a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; +import org.apache.spark.ml.regression.DecisionTreeRegressor; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + Dataset data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + Dataset[] splits = data.randomSplit(new double[]{0.7, 0.3}); + Dataset trainingData = splits[0]; + Dataset testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures"); + + // Chain indexer and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{featureIndexer, dt}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + Dataset predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + DecisionTreeRegressionModel treeModel = + (DecisionTreeRegressionModel) (model.stages()[1]); + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 0b4c0d9ba9f8b..0ba94786d4e5f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -61,7 +61,8 @@ public static void main(String[] args) throws Exception { new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + Dataset training = jsql.createDataFrame( + jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. MyJavaLogisticRegression lr = new MyJavaLogisticRegression(); @@ -79,17 +80,17 @@ public static void main(String[] args) throws Exception { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); + Dataset test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - DataFrame results = model.transform(test); + Dataset results = model.transform(test); double sumPredictions = 0; - for (Row r : results.select("features", "label", "prediction").collect()) { + for (Row r : results.select("features", "label", "prediction").collectAsList()) { sumPredictions += r.getDouble(2); } if (sumPredictions != 0.0) { throw new Exception("MyJavaLogisticRegression predicted something other than 0," + - " even though all weights are 0!"); + " even though all coefficients are 0!"); } jsc.stop(); @@ -106,11 +107,11 @@ public static void main(String[] args) throws Exception { class MyJavaLogisticRegression extends Classifier { - public MyJavaLogisticRegression() { + MyJavaLogisticRegression() { init(); } - public MyJavaLogisticRegression(String uid) { + MyJavaLogisticRegression(String uid) { this.uid_ = uid; init(); } @@ -145,16 +146,16 @@ MyJavaLogisticRegression setMaxIter(int value) { // This method is used by fit(). // In Java, we have to make it public since Java does not understand Scala's protected modifier. - public MyJavaLogisticRegressionModel train(DataFrame dataset) { + public MyJavaLogisticRegressionModel train(Dataset dataset) { // Extract columns from data using helper method. JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. int numFeatures = oldDataset.take(1).get(0).features().size(); - Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + Vector coefficients = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. - return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); + return new MyJavaLogisticRegressionModel(uid(), coefficients).setParent(this); } @Override @@ -173,12 +174,12 @@ public MyJavaLogisticRegression copy(ParamMap extra) { class MyJavaLogisticRegressionModel extends ClassificationModel { - private Vector weights_; - public Vector weights() { return weights_; } + private Vector coefficients_; + public Vector coefficients() { return coefficients_; } - public MyJavaLogisticRegressionModel(String uid, Vector weights) { + MyJavaLogisticRegressionModel(String uid, Vector coefficients) { this.uid_ = uid; - this.weights_ = weights; + this.coefficients_ = coefficients; } private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); @@ -208,7 +209,7 @@ public String uid() { * modifier. */ public Vector predictRaw(Vector features) { - double margin = BLAS.dot(features, weights_); + double margin = BLAS.dot(features, coefficients_); // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). return Vectors.dense(-margin, margin); @@ -222,20 +223,20 @@ public Vector predictRaw(Vector features) { /** * Number of features the model was trained on. */ - public int numFeatures() { return weights_.size(); } + public int numFeatures() { return coefficients_.size(); } /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. *

    - * This is used for the defaul implementation of [[transform()]]. + * This is used for the default implementation of [[transform()]]. * * In Java, we have to make this method public since Java does not understand Scala's protected * modifier. */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + return copyValues(new MyJavaLogisticRegressionModel(uid(), coefficients_), extra) .setParent(parent()); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java new file mode 100644 index 0000000000000..6459dabc0698b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.io.Serializable; + +/** + * Unlabeled instance type, Spark SQL can infer schema from Java Beans. + */ +@SuppressWarnings("serial") +public class JavaDocument implements Serializable { + + private long id; + private String text; + + public JavaDocument(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { + return this.id; + } + + public String getText() { + return this.text; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java new file mode 100644 index 0000000000000..37de9cf3596a9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) + )); + + List fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); + + StructType schema = DataTypes.createStructType(fields); + + Dataset dataFrame = sqlContext.createDataFrame(jrdd, schema); + + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + + ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java new file mode 100644 index 0000000000000..604b193dd489b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Estimator, Transformer, and Param. + */ +public class JavaEstimatorTransformerParamExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaEstimatorTransformerParamExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training data. + // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into + // DataFrames, where it uses the bean metadata to infer the schema. + Dataset training = sqlContext.createDataFrame( + Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) + ), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + LogisticRegression lr = new LogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10).setRegParam(0.01); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + LogisticRegressionModel model1 = lr.fit(training); + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); + + // We may alternatively specify parameters using a ParamMap. + ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + + // One can also combine ParamMaps. + ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name + ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); + + // Now learn a new model using the paramMapCombined parameters. + // paramMapCombined overrides all parameters set earlier via lr.set* methods. + LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); + System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); + + // Prepare test documents. + Dataset test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) + ), LabeledPoint.class); + + // Make predictions on test documents using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + Dataset results = model2.transform(test); + Dataset rows = results.select("features", "label", "myProbability", "prediction"); + for (Row r: rows.collectAsList()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java new file mode 100644 index 0000000000000..553070dace882 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.GBTClassificationModel; +import org.apache.spark.ml.classification.GBTClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaGradientBoostedTreeClassifierExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + Dataset data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + Dataset[] splits = data.randomSplit(new double[] {0.7, 0.3}); + Dataset trainingData = splits[0]; + Dataset testData = splits[1]; + + // Train a GBT model. + GBTClassifier gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + + // Convert indexed labels back to original labels. + IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and GBT in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + Dataset predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + GBTClassificationModel gbtModel = (GBTClassificationModel)(model.stages()[2]); + System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java new file mode 100644 index 0000000000000..83fd89e3bd59b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.GBTRegressionModel; +import org.apache.spark.ml.regression.GBTRegressor; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaGradientBoostedTreeRegressorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeRegressorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + Dataset data = + sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + Dataset[] splits = data.randomSplit(new double[] {0.7, 0.3}); + Dataset trainingData = splits[0]; + Dataset testData = splits[1]; + + // Train a GBT model. + GBTRegressor gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + + // Chain indexer and GBT in a Pipeline + Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureIndexer, gbt}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + Dataset predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + GBTRegressionModel gbtModel = (GBTRegressionModel)(model.stages()[1]); + System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java new file mode 100644 index 0000000000000..9b8c22f3bdfde --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.IndexToString; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaIndexToStringExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + Dataset df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + Dataset indexed = indexer.transform(df); + + IndexToString converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory"); + Dataset converted = converter.transform(indexed); + converted.select("id", "originalCategory").show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index be2bf0c7b465c..c5022f4c0b8fe 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -23,25 +23,27 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +// $example on$ import org.apache.spark.ml.clustering.KMeansModel; import org.apache.spark.ml.clustering.KMeans; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +// $example off$ /** * An example demonstrating a k-means clustering. * Run with *

    - * bin/run-example ml.JavaSimpleParamsExample  
    + * bin/run-example ml.JavaKMeansExample  
      * 
    */ public class JavaKMeansExample { @@ -74,11 +76,12 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(jsc); + // $example on$ // Loads data JavaRDD points = jsc.textFile(inputFile).map(new ParsePoint()); StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; StructType schema = new StructType(fields); - DataFrame dataset = sqlContext.createDataFrame(points, schema); + Dataset dataset = sqlContext.createDataFrame(points, schema); // Trains a k-means model KMeans kmeans = new KMeans() @@ -91,6 +94,7 @@ public static void main(String[] args) { for (Vector center: centers) { System.out.println(center); } + // $example off$ jsc.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 0000000000000..351bc401180cc --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; +// $example on$ +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +/** + * An example demonstrating LDA + * Run with + *
    + * bin/run-example ml.JavaLDAExample
    + * 
    + */ +public class JavaLDAExample { + + // $example on$ + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + Dataset dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + Dataset topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } + // $example off$ +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java new file mode 100644 index 0000000000000..68d1caf6ad3f3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.io.Serializable; + +/** + * Labeled instance type, Spark SQL can infer schema from Java Beans. + */ +@SuppressWarnings("serial") +public class JavaLabeledDocument extends JavaDocument implements Serializable { + + private double label; + + public JavaLabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { + return this.label; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java new file mode 100644 index 0000000000000..08fce89359fc5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.regression.LinearRegressionModel; +import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaLinearRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithElasticNetExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load training data + Dataset training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt"); + + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LinearRegressionModel lrModel = lr.fit(training); + + // Print the coefficients and intercept for linear regression + System.out.println("Coefficients: " + + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + + // Summarize the model over the training set and print out some metrics + LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); + System.out.println("numIterations: " + trainingSummary.totalIterations()); + System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); + trainingSummary.residuals().show(); + System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError()); + System.out.println("r2: " + trainingSummary.r2()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java new file mode 100644 index 0000000000000..73b028fb44409 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.functions; +// $example off$ + +public class JavaLogisticRegressionSummaryExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionSummaryExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Load training data + Dataset training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // $example on$ + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier + // example + LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + + // Obtain the loss per iteration. + double[] objectiveHistory = trainingSummary.objectiveHistory(); + for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); + } + + // Obtain the metrics useful to judge performance on test data. + // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary + // classification problem. + BinaryLogisticRegressionSummary binarySummary = + (BinaryLogisticRegressionSummary) trainingSummary; + + // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. + Dataset roc = binarySummary.roc(); + roc.show(); + roc.select("FPR").show(); + System.out.println(binarySummary.areaUnderROC()); + + // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with + // this selected threshold. + Dataset fMeasure = binarySummary.fMeasureByThreshold(); + double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); + double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)) + .select("threshold").head().getDouble(0); + lrModel.setThreshold(bestThreshold); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java new file mode 100644 index 0000000000000..691166852206c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaLogisticRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithElasticNetExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load training data + Dataset training = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LogisticRegressionModel lrModel = lr.fit(training); + + // Print the coefficients and intercept for logistic regression + System.out.println("Coefficients: " + + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java new file mode 100644 index 0000000000000..a2a072b253f39 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.feature.MaxAbsScaler; +import org.apache.spark.ml.feature.MaxAbsScalerModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +public class JavaMaxAbsScalerExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaMaxAbsScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + MaxAbsScaler scaler = new MaxAbsScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + + // Compute summary statistics and generate MaxAbsScalerModel + MaxAbsScalerModel scalerModel = scaler.fit(dataFrame); + + // rescale each feature to range [-1, 1]. + Dataset scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } + +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java new file mode 100644 index 0000000000000..4aee18eeabfcf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ + +public class JavaMinMaxScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + + // Compute summary statistics and generate MinMaxScalerModel + MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + + // rescale each feature to range [min, max]. + Dataset scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java new file mode 100644 index 0000000000000..c4122d1247a94 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.tuning.CrossValidator; +import org.apache.spark.ml.tuning.CrossValidatorModel; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Model Selection via Cross Validation. + */ +public class JavaModelSelectionViaCrossValidationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaModelSelectionViaCrossValidationExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training documents, which are labeled. + Dataset training = sqlContext.createDataFrame(Arrays.asList( + new JavaLabeledDocument(0L, "a b c d e spark", 1.0), + new JavaLabeledDocument(1L, "b d", 0.0), + new JavaLabeledDocument(2L,"spark f g h", 1.0), + new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0), + new JavaLabeledDocument(4L, "b spark who", 1.0), + new JavaLabeledDocument(5L, "g d a y", 0.0), + new JavaLabeledDocument(6L, "spark fly", 1.0), + new JavaLabeledDocument(7L, "was mapreduce", 0.0), + new JavaLabeledDocument(8L, "e spark program", 1.0), + new JavaLabeledDocument(9L, "a e c l", 0.0), + new JavaLabeledDocument(10L, "spark compile", 1.0), + new JavaLabeledDocument(11L, "hadoop software", 0.0) + ), JavaLabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000}) + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .build(); + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. + // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric + // is areaUnderROC. + CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid).setNumFolds(2); // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + CrossValidatorModel cvModel = cv.fit(training); + + // Prepare test documents, which are unlabeled. + Dataset test = sqlContext.createDataFrame(Arrays.asList( + new JavaDocument(4L, "spark i j k"), + new JavaDocument(5L, "l m n"), + new JavaDocument(6L, "mapreduce spark"), + new JavaDocument(7L, "apache hadoop") + ), JavaDocument.class); + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + Dataset predictions = cvModel.transform(test); + for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java new file mode 100644 index 0000000000000..4994f8f9fa857 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.ml.tuning.TrainValidationSplit; +import org.apache.spark.ml.tuning.TrainValidationSplitModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example demonstrating model selection using TrainValidationSplit. + * + * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} + * using linear regression. + * + * Run with + * {{{ + * bin/run-example ml.JavaModelSelectionViaTrainValidationSplitExample + * }}} + */ +public class JavaModelSelectionViaTrainValidationSplitExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaModelSelectionViaTrainValidationSplitExample"); + SparkContext sc = new SparkContext(conf); + SQLContext jsql = new SQLContext(sc); + + // $example on$ + Dataset data = jsql.read().format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt"); + + // Prepare training and test data. + Dataset[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); + Dataset training = splits[0]; + Dataset test = splits[1]; + + LinearRegression lr = new LinearRegression(); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + + // Run train validation split, and choose the best set of parameters. + TrainValidationSplitModel model = trainValidationSplit.fit(training); + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show(); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java new file mode 100644 index 0000000000000..0ca528d8cd079 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +// $example off$ + +/** + * An example for Multilayer Perceptron Classification. + */ +public class JavaMultilayerPerceptronClassifierExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaMultilayerPerceptronClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + // Load training data + String path = "data/mllib/sample_multiclass_classification_data.txt"; + Dataset dataFrame = jsql.read().format("libsvm").load(path); + // Split the data into train and test + Dataset[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); + Dataset train = splits[0]; + Dataset test = splits[1]; + // specify layers for the neural network: + // input layer of size 4 (features), two intermediate of size 5 and 4 + // and output of size 3 (classes) + int[] layers = new int[] {4, 5, 4, 3}; + // create the trainer and set its parameters + MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100); + // train the model + MultilayerPerceptronClassificationModel model = trainer.fit(train); + // compute precision on the test set + Dataset result = model.transform(test); + Dataset predictionAndLabels = result.select("prediction", "label"); + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); + System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java new file mode 100644 index 0000000000000..608bd80285655 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaNGramExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField( + "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + Dataset wordDataFrame = sqlContext.createDataFrame(jrdd, schema); + + NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); + + Dataset ngramDataFrame = ngramTransformer.transform(wordDataFrame); + + for (Row r : ngramDataFrame.select("ngrams", "label").takeAsList(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); + } + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java new file mode 100644 index 0000000000000..41d7ad75b9d45 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.NaiveBayes; +import org.apache.spark.ml.classification.NaiveBayesModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +/** + * An example for Naive Bayes Classification. + */ +public class JavaNaiveBayesExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNaiveBayesExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + // Load training data + Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + // Split the data into train and test + Dataset[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); + Dataset train = splits[0]; + Dataset test = splits[1]; + + // create the trainer and set its parameters + NaiveBayes nb = new NaiveBayes(); + // train the model + NaiveBayesModel model = nb.fit(train); + // compute precision on the test set + Dataset result = model.transform(test); + Dataset predictionAndLabels = result.select("prediction", "label"); + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); + System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java new file mode 100644 index 0000000000000..31cd752136689 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ + +public class JavaNormalizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Normalize each Vector using $L^1$ norm. + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); + + Dataset l1NormData = normalizer.transform(dataFrame); + l1NormData.show(); + + // Normalize each Vector using $L^\infty$ norm. + Dataset lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java new file mode 100644 index 0000000000000..882438ca28eb7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaOneHotEncoderExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + + Dataset df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + Dataset indexed = indexer.transform(df); + + OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); + Dataset encoded = encoder.transform(indexed); + encoded.select("id", "categoryVec").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java index e7f2f6f615070..1f13b48bf82ae 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -21,18 +21,19 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; +// $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.OneVsRest; import org.apache.spark.ml.classification.OneVsRestModel; import org.apache.spark.ml.util.MetadataUtils; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.StructField; +// $example off$ /** * An example runner for Multiclass to Binary Reduction with One Vs Rest. @@ -63,6 +64,7 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); + // $example on$ // configure the base classifier LogisticRegression classifier = new LogisticRegression() .setMaxIter(params.maxIter) @@ -80,31 +82,30 @@ public static void main(String[] args) { OneVsRest ovr = new OneVsRest().setClassifier(classifier); String input = params.input; - RDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); - RDD train; - RDD test; + Dataset inputData = jsql.read().format("libsvm").load(input); + Dataset train; + Dataset test; // compute the train/ test split: if testInput is not provided use part of input String testInput = params.testInput; if (testInput != null) { train = inputData; // compute the number of features in the training set. - int numFeatures = inputData.first().features().size(); - test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); + int numFeatures = inputData.first().getAs(1).size(); + test = jsql.read().format("libsvm").option("numFeatures", + String.valueOf(numFeatures)).load(testInput); } else { double f = params.fracTest; - RDD[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); + Dataset[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); train = tmp[0]; test = tmp[1]; } // train the multiclass model - DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); - OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + OneVsRestModel ovrModel = ovr.fit(train.cache()); // score the model on test data - DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); - DataFrame predictions = ovrModel.transform(testDataFrame.cache()) + Dataset predictions = ovrModel.transform(test.cache()) .select("prediction", "label"); // obtain metrics @@ -128,6 +129,7 @@ public static void main(String[] args) { System.out.println(confusionMatrix); System.out.println(); System.out.println(results); + // $example off$ jsc.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java new file mode 100644 index 0000000000000..a792fd7d47cc9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + Dataset df = jsql.createDataFrame(data, schema); + + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); + + Dataset result = pca.transform(df).select("pcaFeatures"); + result.show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java new file mode 100644 index 0000000000000..305420f208b79 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for simple text document 'Pipeline'. + */ +public class JavaPipelineExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPipelineExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training documents, which are labeled. + Dataset training = sqlContext.createDataFrame(Arrays.asList( + new JavaLabeledDocument(0L, "a b c d e spark", 1.0), + new JavaLabeledDocument(1L, "b d", 0.0), + new JavaLabeledDocument(2L, "spark f g h", 1.0), + new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0) + ), JavaLabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // Fit the pipeline to training documents. + PipelineModel model = pipeline.fit(training); + + // Prepare test documents, which are unlabeled. + Dataset test = sqlContext.createDataFrame(Arrays.asList( + new JavaDocument(4L, "spark i j k"), + new JavaDocument(5L, "l m n"), + new JavaDocument(6L, "mapreduce spark"), + new JavaDocument(7L, "apache hadoop") + ), JavaDocument.class); + + // Make predictions on test documents. + Dataset predictions = model.transform(test); + for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java new file mode 100644 index 0000000000000..48fc3c8acb0c0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PolynomialExpansion; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPolynomialExpansionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + Dataset df = jsql.createDataFrame(data, schema); + Dataset polyDF = polyExpansion.transform(df); + + List rows = polyDF.select("polyFeatures").takeAsList(3); + for (Row r : rows) { + System.out.println(r.get(0)); + } + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java new file mode 100644 index 0000000000000..7b226fede9968 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.QuantileDiscretizer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaQuantileDiscretizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaQuantileDiscretizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize( + Arrays.asList( + RowFactory.create(0, 18.0), + RowFactory.create(1, 19.0), + RowFactory.create(2, 8.0), + RowFactory.create(3, 5.0), + RowFactory.create(4, 2.2) + ) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("hour", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = sqlContext.createDataFrame(jrdd, schema); + + QuantileDiscretizer discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3); + + Dataset result = discretizer.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java new file mode 100644 index 0000000000000..8c453bf80d645 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaRFormulaExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) + }); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) + )); + + Dataset dataset = sqlContext.createDataFrame(rdd, schema); + RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + Dataset output = formula.fit(dataset).transform(dataset); + output.select("features", "label").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java new file mode 100644 index 0000000000000..05c2bc9622e1b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.RandomForestClassificationModel; +import org.apache.spark.ml.classification.RandomForestClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaRandomForestClassifierExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRandomForestClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + Dataset data = + sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + Dataset[] splits = data.randomSplit(new double[] {0.7, 0.3}); + Dataset trainingData = splits[0]; + Dataset testData = splits[1]; + + // Train a RandomForest model. + RandomForestClassifier rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + + // Convert indexed labels back to original labels. + IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and forest in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + Dataset predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + RandomForestClassificationModel rfModel = (RandomForestClassificationModel)(model.stages()[2]); + System.out.println("Learned classification forest model:\n" + rfModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java new file mode 100644 index 0000000000000..d366967083a19 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.RandomForestRegressionModel; +import org.apache.spark.ml.regression.RandomForestRegressor; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaRandomForestRegressorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRandomForestRegressorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + Dataset data = + sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + Dataset[] splits = data.randomSplit(new double[] {0.7, 0.3}); + Dataset trainingData = splits[0]; + Dataset testData = splits[1]; + + // Train a RandomForest model. + RandomForestRegressor rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures"); + + // Chain indexer and forest in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {featureIndexer, rf}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + Dataset predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + RandomForestRegressionModel rfModel = (RandomForestRegressionModel)(model.stages()[1]); + System.out.println("Learned regression forest model:\n" + rfModel.toDebugString()); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java new file mode 100644 index 0000000000000..7e3ca99d7cb93 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.SQLTransformer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaSQLTransformerExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaSQLTransformerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 1.0, 3.0), + RowFactory.create(2, 2.0, 5.0) + )); + StructType schema = new StructType(new StructField [] { + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("v1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("v2", DataTypes.DoubleType, false, Metadata.empty()) + }); + Dataset df = sqlContext.createDataFrame(jrdd, schema); + + SQLTransformer sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"); + + sqlTrans.transform(df).show(); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 94beeced3d479..cb911ef5ef586 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,7 +28,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -54,7 +54,8 @@ public static void main(String[] args) { new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + Dataset training = + jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -77,7 +78,7 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - double thresholds[] = {0.45, 0.55}; + double[] thresholds = {0.45, 0.55}; paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. @@ -95,14 +96,15 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); + Dataset test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. - DataFrame results = model2.transform(test); - for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { + Dataset results = model2.transform(test); + Dataset rows = results.select("features", "label", "myProbability", "prediction"); + for (Row r: rows.collectAsList()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + ", prediction=" + r.get(3)); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 54738813d0016..a18a60f448166 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -29,7 +29,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -54,7 +54,8 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + Dataset training = + jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,11 +80,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "spark hadoop spark"), new Document(7L, "apache hadoop")); - DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + Dataset test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. - DataFrame predictions = model.transform(test); - for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + Dataset predictions = model.transform(test); + for (Row r: predictions.select("id", "text", "probability", "prediction").collectAsList()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3)); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java new file mode 100644 index 0000000000000..e2dd759c0a40c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ + +public class JavaStandardScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + Dataset scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java new file mode 100644 index 0000000000000..0ff3782cb3e90 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaStopWordsRemoverExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField( + "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + Dataset dataset = jsql.createDataFrame(rdd, schema); + remover.transform(dataset).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java new file mode 100644 index 0000000000000..ceacbb4fb3f33 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaStringIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) + }); + Dataset df = sqlContext.createDataFrame(jrdd, schema); + StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); + Dataset indexed = indexer.fit(df).transform(df); + indexed.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java index a41a5ec9bff05..37a3d0d84dae2 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java @@ -28,7 +28,7 @@ import org.apache.spark.ml.feature.IDFModel; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -54,19 +54,19 @@ public static void main(String[] args) { new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema); + Dataset sentenceData = sqlContext.createDataFrame(jrdd, schema); Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); - DataFrame wordsData = tokenizer.transform(sentenceData); + Dataset wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); - DataFrame featurizedData = hashingTF.transform(wordsData); + Dataset featurizedData = hashingTF.transform(wordsData); IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); IDFModel idfModel = idf.fit(featurizedData); - DataFrame rescaledData = idfModel.transform(featurizedData); - for (Row r : rescaledData.select("features", "label").take(3)) { + Dataset rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").takeAsList(3)) { Vector features = r.getAs(0); Double label = r.getDouble(1); System.out.println(features); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java new file mode 100644 index 0000000000000..9225fe2262f57 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTokenizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + Dataset sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + + Dataset wordsDataFrame = tokenizer.transform(sentenceDataFrame); + for (Row r : wordsDataFrame.select("words", "label").takeAsList(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); + } + + RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java deleted file mode 100644 index 23f834ab4332b..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; - -/** - * A simple example demonstrating model selection using TrainValidationSplit. - * - * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} - * using linear regression. - * - * Run with - * {{{ - * bin/run-example ml.JavaTrainValidationSplitExample - * }}} - */ -public class JavaTrainValidationSplitExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - DataFrame data = jsql.createDataFrame( - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), - LabeledPoint.class); - - // Prepare training and test data. - DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); - DataFrame training = splits[0]; - DataFrame test = splits[1]; - - LinearRegression lr = new LinearRegression(); - - // We use a ParamGridBuilder to construct a grid of parameters to search over. - // TrainValidationSplit will try all combinations of values and determine best model using - // the evaluator. - ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - - // In this case the estimator is simply the linear regression. - // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. - TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid); - - // 80% of the data will be used for training and the remaining 20% for validation. - trainValidationSplit.setTrainRatio(0.8); - - // Run train validation split, and choose the best set of parameters. - TrainValidationSplitModel model = trainValidationSplit.fit(training); - - // Make predictions on test data. model is the model with combination of parameters - // that performed best. - model.transform(test) - .select("features", "label", "prediction") - .show(); - - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java new file mode 100644 index 0000000000000..953ad455b1dcd --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorAssemblerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + Dataset dataset = sqlContext.createDataFrame(rdd, schema); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + Dataset output = assembler.transform(dataset); + System.out.println(output.select("features", "clicked").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java new file mode 100644 index 0000000000000..b3b5953ee7bbe --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ + +public class JavaVectorIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Dataset data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); + VectorIndexerModel indexerModel = indexer.fit(data); + + Map> categoryMaps = indexerModel.javaCategoryMaps(); + System.out.print("Chose " + categoryMaps.size() + " categorical features:"); + + for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); + } + System.out.println(); + + // Create new column "indexed" with categorical values transformed to indices + Dataset indexedData = indexerModel.transform(data); + indexedData.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java new file mode 100644 index 0000000000000..2ae57c3577eff --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaVectorSlicerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + Dataset dataset = + jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + + Dataset output = vectorSlicer.transform(dataset); + + System.out.println(output.select("userFeatures", "features").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java index d472375ca9825..c5bb1eaaa3446 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java @@ -25,7 +25,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.Word2Vec; import org.apache.spark.ml.feature.Word2VecModel; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -49,7 +49,7 @@ public static void main(String[] args) { StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + Dataset documentDF = sqlContext.createDataFrame(jrdd, schema); // Learn a mapping from words to Vectors. Word2Vec word2Vec = new Word2Vec() @@ -58,10 +58,12 @@ public static void main(String[] args) { .setVectorSize(3) .setMinCount(0); Word2VecModel model = word2Vec.fit(documentDF); - DataFrame result = model.transform(documentDF); - for (Row r : result.select("result").take(3)) { + Dataset result = model.transform(documentDF); + for (Row r : result.select("result").takeAsList(3)) { System.out.println(r); } // $example off$ + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java index 4d0f989819ace..189560e3fe1f1 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java @@ -52,5 +52,7 @@ public static void main(String[] args) { rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); } // $example off$ + + sc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java new file mode 100644 index 0000000000000..7561a1f6535d6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaBinaryClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = + data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + @Override + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.collect()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.collect()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.collect()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.collect()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.collect()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + @Override + public Double call(Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.collect()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java new file mode 100644 index 0000000000000..c600094947d5a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import java.util.ArrayList; + +// $example on$ +import com.google.common.collect.Lists; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.clustering.BisectingKMeans; +import org.apache.spark.mllib.clustering.BisectingKMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +// $example off$ + +/** + * Java example for bisecting k-means clustering. + */ +public class JavaBisectingKMeansExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaBisectingKMeansExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + // $example on$ + ArrayList localData = Lists.newArrayList( + Vectors.dense(0.1, 0.1), Vectors.dense(0.3, 0.3), + Vectors.dense(10.1, 10.1), Vectors.dense(10.3, 10.3), + Vectors.dense(20.1, 20.1), Vectors.dense(20.3, 20.3), + Vectors.dense(30.1, 30.1), Vectors.dense(30.3, 30.3) + ); + JavaRDD data = sc.parallelize(localData, 2); + + BisectingKMeans bkm = new BisectingKMeans() + .setK(4); + BisectingKMeansModel model = bkm.run(data); + + System.out.println("Compute Cost: " + model.computeCost(data)); + + Vector[] clusterCenters = model.clusterCenters(); + for (int i = 0; i < clusterCenters.length; i++) { + Vector clusterCenter = clusterCenters[i]; + System.out.println("Cluster Center " + i + ": " + clusterCenter); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaChiSqSelectorExample.java new file mode 100644 index 0000000000000..ad44acb4cd6e3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaChiSqSelectorExample.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.VoidFunction; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.feature.ChiSqSelector; +import org.apache.spark.mllib.feature.ChiSqSelectorModel; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaChiSqSelectorExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaChiSqSelectorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + JavaRDD points = MLUtils.loadLibSVMFile(jsc.sc(), + "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); + + // Discretize data in 16 equal bins since ChiSqSelector requires categorical features + // Although features are doubles, the ChiSqSelector treats each unique value as a category + JavaRDD discretizedData = points.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + final double[] discretizedFeatures = new double[lp.features().size()]; + for (int i = 0; i < lp.features().size(); ++i) { + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); + } + return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); + } + } + ); + + // Create ChiSqSelector that will select top 50 of 692 features + ChiSqSelector selector = new ChiSqSelector(50); + // Create ChiSqSelector model (selecting features) + final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); + // Filter the top 50 features from each feature vector + JavaRDD filteredData = discretizedData.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + return new LabeledPoint(lp.label(), transformer.transform(lp.features())); + } + } + ); + // $example off$ + + System.out.println("filtered data: "); + filteredData.foreach(new VoidFunction() { + @Override + public void call(LabeledPoint labeledPoint) throws Exception { + System.out.println(labeledPoint.toString()); + } + }); + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaCorrelationsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaCorrelationsExample.java new file mode 100644 index 0000000000000..c0fa0b3cac1e9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaCorrelationsExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.stat.Statistics; +// $example off$ + +public class JavaCorrelationsExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaCorrelationsExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + JavaDoubleRDD seriesX = jsc.parallelizeDoubles( + Arrays.asList(1.0, 2.0, 3.0, 3.0, 5.0)); // a series + + // must have the same number of partitions and cardinality as seriesX + JavaDoubleRDD seriesY = jsc.parallelizeDoubles( + Arrays.asList(11.0, 22.0, 33.0, 33.0, 555.0)); + + // compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. + // If a method is not specified, Pearson's method will be used by default. + Double correlation = Statistics.corr(seriesX.srdd(), seriesY.srdd(), "pearson"); + System.out.println("Correlation is: " + correlation); + + // note that each Vector is a row and not a column + JavaRDD data = jsc.parallelize( + Arrays.asList( + Vectors.dense(1.0, 10.0, 100.0), + Vectors.dense(2.0, 20.0, 200.0), + Vectors.dense(5.0, 33.0, 366.0) + ) + ); + + // calculate the correlation matrix using Pearson's method. + // Use "spearman" for Spearman's method. + // If a method is not specified, Pearson's method will be used by default. + Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson"); + System.out.println(correlMatrix.toString()); + // $example off$ + + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java deleted file mode 100644 index 1f82e3f4cb18e..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.HashMap; - -import scala.Tuple2; - -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -/** - * Classification and regression using decision trees. - */ -public final class JavaDecisionTree { - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - if (args.length == 1) { - datapath = args[0]; - } else if (args.length > 1) { - System.err.println("Usage: JavaDecisionTree "); - System.exit(1); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - - // Set parameters. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - String impurity = "gini"; - Integer maxDepth = 5; - Integer maxBins = 32; - - // Train a DecisionTree model for classification. - final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - - // Train a DecisionTree model for regression. - impurity = "variance"; - final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD regressorPredictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(regressionModel.predict(p.features()), p.label()); - } - }); - Double trainMSE = - regressorPredictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + regressionModel); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java new file mode 100644 index 0000000000000..66387b9df51c7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeClassificationExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Integer numClasses = 2; + Map categoricalFeaturesInfo = new HashMap<>(); + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model for classification. + final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2<>(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java new file mode 100644 index 0000000000000..904e7f7e9505e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeRegressionExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap<>(); + String impurity = "variance"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model. + final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2<>(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaElementwiseProductExample.java new file mode 100644 index 0000000000000..c8ce6ab284b07 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaElementwiseProductExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +// $example off$ +import org.apache.spark.api.java.function.VoidFunction; + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD data = jsc.parallelize(Arrays.asList( + Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + final ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + + // Batch transform and per-row transform give the same results: + JavaRDD transformedData = transformer.transform(data); + JavaRDD transformedData2 = data.map( + new Function() { + @Override + public Vector call(Vector v) { + return transformer.transform(v); + } + } + ); + // $example off$ + + System.out.println("transformedData: "); + transformedData.foreach(new VoidFunction() { + @Override + public void call(Vector vector) throws Exception { + System.out.println(vector.toString()); + } + }); + + System.out.println("transformedData2: "); + transformedData2.foreach(new VoidFunction() { + @Override + public void call(Vector vector) throws Exception { + System.out.println(vector.toString()); + } + }); + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java deleted file mode 100644 index 36baf5868736c..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.ArrayList; - -import com.google.common.base.Joiner; -import com.google.common.collect.Lists; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.fpm.FPGrowth; -import org.apache.spark.mllib.fpm.FPGrowthModel; - -/** - * Java example for mining frequent itemsets using FP-growth. - * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt - */ -public class JavaFPGrowthExample { - - public static void main(String[] args) { - String inputFile; - double minSupport = 0.3; - int numPartition = -1; - if (args.length < 1) { - System.err.println( - "Usage: JavaFPGrowth [minSupport] [numPartition]"); - System.exit(1); - } - inputFile = args[0]; - if (args.length >= 2) { - minSupport = Double.parseDouble(args[1]); - } - if (args.length >= 3) { - numPartition = Integer.parseInt(args[2]); - } - - SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD> transactions = sc.textFile(inputFile).map( - new Function>() { - @Override - public ArrayList call(String s) { - return Lists.newArrayList(s.split(" ")); - } - } - ); - - FPGrowthModel model = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartition) - .run(transactions); - - for (FPGrowth.FreqItemset s: model.freqItemsets().toJavaRDD().collect()) { - System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); - } - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGaussianMixtureExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGaussianMixtureExample.java new file mode 100644 index 0000000000000..3124411c8227c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGaussianMixtureExample.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.GaussianMixture; +import org.apache.spark.mllib.clustering.GaussianMixtureModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +// $example off$ + +public class JavaGaussianMixtureExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaGaussianMixtureExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + // Load and parse data + String path = "data/mllib/gmm_data.txt"; + JavaRDD data = jsc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) { + values[i] = Double.parseDouble(sarray[i]); + } + return Vectors.dense(values); + } + } + ); + parsedData.cache(); + + // Cluster the data into two classes using GaussianMixture + GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); + + // Save and load GaussianMixtureModel + gmm.save(jsc.sc(), "target/org/apache/spark/JavaGaussianMixtureExample/GaussianMixtureModel"); + GaussianMixtureModel sameModel = GaussianMixtureModel.load(jsc.sc(), + "target/org.apache.spark.JavaGaussianMixtureExample/GaussianMixtureModel"); + + // Output the parameters of the mixture model + for (int j = 0; j < gmm.k(); j++) { + System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n", + gmm.weights()[j], gmm.gaussians()[j].mu(), gmm.gaussians()[j].sigma()); + } + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java deleted file mode 100644 index a1844d5d07ad4..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -/** - * Classification and regression using gradient-boosted decision trees. - */ -public final class JavaGradientBoostedTreesRunner { - - private static void usage() { - System.err.println("Usage: JavaGradientBoostedTreesRunner " + - " "); - System.exit(-1); - } - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - String algo = "Classification"; - if (args.length >= 1) { - datapath = args[0]; - } - if (args.length >= 2) { - algo = args[1]; - } - if (args.length > 2) { - usage(); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Set parameters. - // Note: All features are treated as continuous. - BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); - boostingStrategy.setNumIterations(10); - boostingStrategy.treeStrategy().setMaxDepth(5); - - if (algo.equals("Classification")) { - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - boostingStrategy.treeStrategy().setNumClasses(numClasses); - - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - } else if (algo.equals("Regression")) { - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainMSE = - predictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + model); - } else { - usage(); - } - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java new file mode 100644 index 0000000000000..213949e525dc2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoostedTrees; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaGradientBoostingClassificationExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf() + .setAppName("JavaGradientBoostedTreesClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a GradientBoostedTrees model. + // The defaultParams for Classification use LogLoss by default. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); + boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. + boostingStrategy.getTreeStrategy().setNumClasses(2); + boostingStrategy.getTreeStrategy().setMaxDepth(5); + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap<>(); + boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); + + final GradientBoostedTreesModel model = + GradientBoostedTrees.train(trainingData, boostingStrategy); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2<>(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification GBT model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel"); + GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), + "target/tmp/myGradientBoostingClassificationModel"); + // $example off$ + + jsc.stop(); + } + +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java new file mode 100644 index 0000000000000..78db442dbc99d --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoostedTrees; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaGradientBoostingRegressionExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf() + .setAppName("JavaGradientBoostedTreesRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a GradientBoostedTrees model. + // The defaultParams for Regression use SquaredError by default. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); + boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. + boostingStrategy.getTreeStrategy().setMaxDepth(5); + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap<>(); + boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); + + final GradientBoostedTreesModel model = + GradientBoostedTrees.train(trainingData, boostingStrategy); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2<>(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression GBT model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel"); + GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), + "target/tmp/myGradientBoostingRegressionModel"); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java new file mode 100644 index 0000000000000..b48b95ff1d2a3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingExample.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.Matrices; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.Statistics; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; +// $example off$ + +public class JavaHypothesisTestingExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaHypothesisTestingExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + // a vector composed of the frequencies of events + Vector vec = Vectors.dense(0.1, 0.15, 0.2, 0.3, 0.25); + + // compute the goodness of fit. If a second vector to test against is not supplied + // as a parameter, the test runs against a uniform distribution. + ChiSqTestResult goodnessOfFitTestResult = Statistics.chiSqTest(vec); + // summary of the test including the p-value, degrees of freedom, test statistic, + // the method used, and the null hypothesis. + System.out.println(goodnessOfFitTestResult + "\n"); + + // Create a contingency matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) + Matrix mat = Matrices.dense(3, 2, new double[]{1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + + // conduct Pearson's independence test on the input contingency matrix + ChiSqTestResult independenceTestResult = Statistics.chiSqTest(mat); + // summary of the test including the p-value, degrees of freedom... + System.out.println(independenceTestResult + "\n"); + + // an RDD of labeled points + JavaRDD obs = jsc.parallelize( + Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)), + new LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 0.0)), + new LabeledPoint(-1.0, Vectors.dense(-1.0, 0.0, -0.5)) + ) + ); + + // The contingency table is constructed from the raw (feature, label) pairs and used to conduct + // the independence test. Returns an array containing the ChiSquaredTestResult for every feature + // against the label. + ChiSqTestResult[] featureTestResults = Statistics.chiSqTest(obs.rdd()); + int i = 1; + for (ChiSqTestResult result : featureTestResults) { + System.out.println("Column " + i + ":"); + System.out.println(result + "\n"); // summary of the test + i++; + } + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingKolmogorovSmirnovTestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingKolmogorovSmirnovTestExample.java new file mode 100644 index 0000000000000..fe611c9ae67c9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaHypothesisTestingKolmogorovSmirnovTestExample.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.mllib.stat.Statistics; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; +// $example off$ + +public class JavaHypothesisTestingKolmogorovSmirnovTestExample { + public static void main(String[] args) { + + SparkConf conf = + new SparkConf().setAppName("JavaHypothesisTestingKolmogorovSmirnovTestExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.1, 0.15, 0.2, 0.3, 0.25)); + KolmogorovSmirnovTestResult testResult = + Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0); + // summary of the test including the p-value, test statistic, and null hypothesis + // if our p-value indicates significance, we can reject the null hypothesis + System.out.println(testResult); + // $example off$ + + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java index 37e709b4cbc03..c6361a3729988 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java @@ -48,7 +48,8 @@ public Tuple3 call(String line) { ); // Split data into training (60%) and test (40%) sets. - JavaRDD>[] splits = parsedData.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD>[] splits = + parsedData.randomSplit(new double[]{0.6, 0.4}, 11L); JavaRDD> training = splits[0]; JavaRDD> test = splits[1]; @@ -62,7 +63,7 @@ public Tuple3 call(String line) { @Override public Tuple2 call(Tuple3 point) { Double predictedLabel = model.predict(point._2()); - return new Tuple2(predictedLabel, point._1()); + return new Tuple2<>(predictedLabel, point._1()); } } ); @@ -80,7 +81,10 @@ public Object call(Tuple2 pl) { // Save and load model model.save(jsc.sc(), "target/tmp/myIsotonicRegressionModel"); - IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(jsc.sc(), "target/tmp/myIsotonicRegressionModel"); + IsotonicRegressionModel sameModel = + IsotonicRegressionModel.load(jsc.sc(), "target/tmp/myIsotonicRegressionModel"); // $example off$ + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java deleted file mode 100644 index e575eedeb465c..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; - -import org.apache.spark.mllib.clustering.KMeans; -import org.apache.spark.mllib.clustering.KMeansModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; - -/** - * Example using MLlib KMeans from Java. - */ -public final class JavaKMeans { - - private static class ParsePoint implements Function { - private static final Pattern SPACE = Pattern.compile(" "); - - @Override - public Vector call(String line) { - String[] tok = SPACE.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - return Vectors.dense(point); - } - } - - public static void main(String[] args) { - if (args.length < 3) { - System.err.println( - "Usage: JavaKMeans []"); - System.exit(1); - } - String inputFile = args[0]; - int k = Integer.parseInt(args[1]); - int iterations = Integer.parseInt(args[2]); - int runs = 1; - - if (args.length >= 4) { - runs = Integer.parseInt(args[3]); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - JavaRDD lines = sc.textFile(inputFile); - - JavaRDD points = lines.map(new ParsePoint()); - - KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL()); - - System.out.println("Cluster centers:"); - for (Vector center : model.clusterCenters()) { - System.out.println(" " + center); - } - double cost = model.computeCost(points.rdd()); - System.out.println("Cost: " + cost); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java new file mode 100644 index 0000000000000..2d89c768fcfca --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.KMeans; +import org.apache.spark.mllib.clustering.KMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +// $example off$ + +public class JavaKMeansExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaKMeansExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + // Load and parse data + String path = "data/mllib/kmeans_data.txt"; + JavaRDD data = jsc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) { + values[i] = Double.parseDouble(sarray[i]); + } + return Vectors.dense(values); + } + } + ); + parsedData.cache(); + + // Cluster the data into two classes using KMeans + int numClusters = 2; + int numIterations = 20; + KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations); + + System.out.println("Cluster centers:"); + for (Vector center: clusters.clusterCenters()) { + System.out.println(" " + center); + } + double cost = clusters.computeCost(parsedData.rdd()); + System.out.println("Cost: " + cost); + + // Evaluate clustering by computing Within Set Sum of Squared Errors + double WSSSE = clusters.computeCost(parsedData.rdd()); + System.out.println("Within Set Sum of Squared Errors = " + WSSSE); + + // Save and load model + clusters.save(jsc.sc(), "target/org/apache/spark/JavaKMeansExample/KMeansModel"); + KMeansModel sameModel = KMeansModel.load(jsc.sc(), + "target/org/apache/spark/JavaKMeansExample/KMeansModel"); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKernelDensityEstimationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKernelDensityEstimationExample.java new file mode 100644 index 0000000000000..41de0d90eccd7 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKernelDensityEstimationExample.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.stat.KernelDensity; +// $example off$ + +public class JavaKernelDensityEstimationExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaKernelDensityEstimationExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + // an RDD of sample data + JavaRDD data = jsc.parallelize( + Arrays.asList(1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 6.0, 7.0, 8.0, 9.0, 9.0)); + + // Construct the density estimator with the sample data + // and a standard deviation for the Gaussian kernels + KernelDensity kd = new KernelDensity().setSample(data).setBandwidth(3.0); + + // Find density estimates for the given values + double[] densities = kd.estimate(new double[]{-1.0, 2.0, 5.0}); + + System.out.println(Arrays.toString(densities)); + // $example off$ + + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java new file mode 100644 index 0000000000000..355883f61bd64 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.optimization.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example off$ + +public class JavaLBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + int numFeatures = data.take(1).get(0).features().size(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD trainingInit = data.sample(false, 0.6, 11L); + JavaRDD test = data.subtract(trainingInit); + + // Append 1 into the training data as intercept. + JavaRDD> training = data.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + return new Tuple2(p.label(), MLUtils.appendBias(p.features())); + } + }); + training.cache(); + + // Run training algorithm to build the model. + int numCorrections = 10; + double convergenceTol = 1e-4; + int maxNumIterations = 20; + double regParam = 0.1; + Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); + + Tuple2 result = LBFGS.runLBFGS( + training.rdd(), + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept); + Vector weightsWithIntercept = result._1(); + double[] loss = result._2(); + + final LogisticRegressionModel model = new LogisticRegressionModel( + Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), + (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + }); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(scoreAndLabels.rdd()); + double auROC = metrics.areaUnderROC(); + + System.out.println("Loss of each step in training process"); + for (double l : loss) + System.out.println(l); + System.out.println("Area under ROC = " + auROC); + // $example off$ + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java deleted file mode 100644 index fd53c81cc4974..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.clustering.DistributedLDAModel; -import org.apache.spark.mllib.clustering.LDA; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.SparkConf; - -public class JavaLDAExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("LDA Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_lda_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.trim().split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) - values[i] = Double.parseDouble(sarray[i]); - return Vectors.dense(values); - } - } - ); - // Index documents with unique IDs - JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 doc_id) { - return doc_id.swap(); - } - } - )); - corpus.cache(); - - // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); - - // Output topics. Each is a distribution over words (matching word count vectors) - System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() - + " words):"); - Matrix topics = ldaModel.topicsMatrix(); - for (int topic = 0; topic < 3; topic++) { - System.out.print("Topic " + topic + ":"); - for (int word = 0; word < ldaModel.vocabSize(); word++) { - System.out.print(" " + topics.apply(word, topic)); - } - System.out.println(); - } - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java deleted file mode 100644 index eceb6927d5551..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; - -import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; - -/** - * Logistic regression based classification using ML Lib. - */ -public final class JavaLR { - - static class ParsePoint implements Function { - private static final Pattern COMMA = Pattern.compile(","); - private static final Pattern SPACE = Pattern.compile(" "); - - @Override - public LabeledPoint call(String line) { - String[] parts = COMMA.split(line); - double y = Double.parseDouble(parts[0]); - String[] tok = SPACE.split(parts[1]); - double[] x = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - x[i] = Double.parseDouble(tok[i]); - } - return new LabeledPoint(y, Vectors.dense(x)); - } - } - - public static void main(String[] args) { - if (args.length != 3) { - System.err.println("Usage: JavaLR "); - System.exit(1); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaLR"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - JavaRDD lines = sc.textFile(args[0]); - JavaRDD points = lines.map(new ParsePoint()).cache(); - double stepSize = Double.parseDouble(args[1]); - int iterations = Integer.parseInt(args[2]); - - // Another way to configure LogisticRegression - // - // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD(); - // lr.optimizer().setNumIterations(iterations) - // .setStepSize(stepSize) - // .setMiniBatchFraction(1.0); - // lr.setIntercept(true); - // LogisticRegressionModel model = lr.train(points.rdd()); - - LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), - iterations, stepSize); - - System.out.print("Final w: " + model.weights()); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLatentDirichletAllocationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLatentDirichletAllocationExample.java new file mode 100644 index 0000000000000..578564eeb23dd --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLatentDirichletAllocationExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.clustering.DistributedLDAModel; +import org.apache.spark.mllib.clustering.LDA; +import org.apache.spark.mllib.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +// $example off$ + +public class JavaLatentDirichletAllocationExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaKLatentDirichletAllocationExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + // Load and parse the data + String path = "data/mllib/sample_lda_data.txt"; + JavaRDD data = jsc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public Vector call(String s) { + String[] sarray = s.trim().split(" "); + double[] values = new double[sarray.length]; + for (int i = 0; i < sarray.length; i++) { + values[i] = Double.parseDouble(sarray[i]); + } + return Vectors.dense(values); + } + } + ); + // Index documents with unique IDs + JavaPairRDD corpus = + JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 doc_id) { + return doc_id.swap(); + } + } + ) + ); + corpus.cache(); + + // Cluster the documents into three topics using LDA + LDAModel ldaModel = new LDA().setK(3).run(corpus); + + // Output topics. Each is a distribution over words (matching word count vectors) + System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() + + " words):"); + Matrix topics = ldaModel.topicsMatrix(); + for (int topic = 0; topic < 3; topic++) { + System.out.print("Topic " + topic + ":"); + for (int word = 0; word < ldaModel.vocabSize(); word++) { + System.out.print(" " + topics.apply(word, topic)); + } + System.out.println(); + } + + ldaModel.save(jsc.sc(), + "target/org/apache/spark/JavaLatentDirichletAllocationExample/LDAModel"); + DistributedLDAModel sameModel = DistributedLDAModel.load(jsc.sc(), + "target/org/apache/spark/JavaLatentDirichletAllocationExample/LDAModel"); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java new file mode 100644 index 0000000000000..9ca9a7847c463 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +// $example off$ + +/** + * Example for LinearRegressionWithSGD. + */ +public class JavaLinearRegressionWithSGDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithSGDExample"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // $example on$ + // Load and parse the data + String path = "data/mllib/ridge-data/lpsa.data"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(","); + String[] features = parts[1].split(" "); + double[] v = new double[features.length]; + for (int i = 0; i < features.length - 1; i++) { + v[i] = Double.parseDouble(features[i]); + } + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + double stepSize = 0.00000001; + final LinearRegressionModel model = + LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2<>(prediction, point.label()); + } + } + ); + double MSE = new JavaDoubleRDD(valuesAndPreds.map( + new Function, Object>() { + public Object call(Tuple2 pair) { + return Math.pow(pair._1() - pair._2(), 2.0); + } + } + ).rdd()).mean(); + System.out.println("training Mean Squared Error = " + MSE); + + // Save and load model + model.save(sc.sc(), "target/tmp/javaLinearRegressionWithSGDModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/javaLinearRegressionWithSGDModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java new file mode 100644 index 0000000000000..9d8e4a90dbc99 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +/** + * Example for LogisticRegressionWithLBFGS. + */ +public class JavaLogisticRegressionWithLBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithLBFGSExample"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + double precision = metrics.precision(); + System.out.println("Precision = " + precision); + + // Save and load model + model.save(sc, "target/tmp/javaLogisticRegressionWithLBFGSModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/javaLogisticRegressionWithLBFGSModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java new file mode 100644 index 0000000000000..bc99dc023fa7b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaMultiLabelClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + List> data = Arrays.asList( + new Tuple2<>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2<>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2<>(new double[]{}, new double[]{0.0}), + new Tuple2<>(new double[]{2.0}, new double[]{2.0}), + new Tuple2<>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2<>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2<>(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision( + metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure( + metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java new file mode 100644 index 0000000000000..5247c9c748618 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaMulticlassClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. + MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision( + metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure( + metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java index e6a5904bd71f0..2b17dbb96365e 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java @@ -46,7 +46,7 @@ public static void main(String[] args) { test.mapToPair(new PairFunction() { @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); + return new Tuple2<>(model.predict(p.features()), p.label()); } }); double accuracy = predictionAndLabel.filter(new Function, Boolean>() { @@ -60,5 +60,7 @@ public Boolean call(Tuple2 pl) { model.save(jsc.sc(), "target/tmp/myNaiveBayesModel"); NaiveBayesModel sameModel = NaiveBayesModel.load(jsc.sc(), "target/tmp/myNaiveBayesModel"); // $example off$ + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java new file mode 100644 index 0000000000000..a42c29f52fb65 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.LinkedList; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +// $example off$ + +/** + * Example for compute principal components on a 'RowMatrix'. + */ +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("PCA Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; + LinkedList rowsList = new LinkedList<>(); + for (int i = 0; i < array.length; i++) { + Vector currentRow = Vectors.dense(array[i]); + rowsList.add(currentRow); + } + JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + + // Create a RowMatrix from JavaRDD. + RowMatrix mat = new RowMatrix(rows.rdd()); + + // Compute the top 3 principal components. + Matrix pc = mat.computePrincipalComponents(3); + RowMatrix projected = mat.multiply(pc); + // $example off$ + Vector[] collectPartitions = (Vector[])projected.rows().collect(); + System.out.println("Projected vector of principal component:"); + for (Vector vector : collectPartitions) { + System.out.println("\t" + vector); + } + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java index 6c6f9768f015e..91c3bd72da3a7 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -24,8 +24,10 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +// $example on$ import org.apache.spark.mllib.clustering.PowerIterationClustering; import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; +// $example off$ /** * Java example for graph clustering using power iteration clustering (PIC). @@ -36,12 +38,13 @@ public static void main(String[] args) { JavaSparkContext sc = new JavaSparkContext(sparkConf); @SuppressWarnings("unchecked") + // $example on$ JavaRDD> similarities = sc.parallelize(Lists.newArrayList( - new Tuple3(0L, 1L, 0.9), - new Tuple3(1L, 2L, 0.9), - new Tuple3(2L, 3L, 0.9), - new Tuple3(3L, 4L, 0.1), - new Tuple3(4L, 5L, 0.9))); + new Tuple3<>(0L, 1L, 0.9), + new Tuple3<>(1L, 2L, 0.9), + new Tuple3<>(2L, 3L, 0.9), + new Tuple3<>(3L, 4L, 0.1), + new Tuple3<>(4L, 5L, 0.9))); PowerIterationClustering pic = new PowerIterationClustering() .setK(2) @@ -51,6 +54,7 @@ public static void main(String[] args) { for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { System.out.println(a.id() + " -> " + a.cluster()); } + // $example off$ sc.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java index 68ec7c1e6ebe0..1634075941291 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java @@ -51,5 +51,7 @@ public static void main(String[] args) { System.out.println(freqSeq.javaSequence() + ", " + freqSeq.freq()); } // $example off$ + + sc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java new file mode 100644 index 0000000000000..24af5d0180ce4 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.RandomForest; +import org.apache.spark.mllib.tree.model.RandomForestModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaRandomForestClassificationExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Integer numClasses = 2; + HashMap categoricalFeaturesInfo = new HashMap<>(); + Integer numTrees = 3; // Use more in practice. + String featureSubsetStrategy = "auto"; // Let the algorithm choose. + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + Integer seed = 12345; + + final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, + seed); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2<>(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification forest model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel"); + RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), + "target/tmp/myRandomForestClassificationModel"); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java deleted file mode 100644 index 89a4e092a5af7..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import java.util.HashMap; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; - -public final class JavaRandomForestExample { - - /** - * Note: This example illustrates binary classification. - * For information on multiclass classification, please refer to the JavaDecisionTree.java - * example. - */ - private static void testClassification(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - Integer numClasses = 2; - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. - String impurity = "gini"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); - System.out.println("Test Error: " + testErr); - System.out.println("Learned classification forest model:\n" + model.toDebugString()); - } - - private static void testRegression(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. - String impurity = "variance"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainRegressor(trainingData, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); - System.out.println("Test Mean Squared Error: " + testMSE); - System.out.println("Learned regression forest model:\n" + model.toDebugString()); - } - - public static void main(String[] args) { - SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestExample"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - // Load and parse the data file. - String datapath = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); - // Split the data into training and test sets (30% held out for testing) - JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); - JavaRDD trainingData = splits[0]; - JavaRDD testData = splits[1]; - - System.out.println("\nRunning example of classification using RandomForest\n"); - testClassification(trainingData, testData); - - System.out.println("\nRunning example of regression using RandomForest\n"); - testRegression(trainingData, testData); - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java new file mode 100644 index 0000000000000..afa9045878db3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.RandomForest; +import org.apache.spark.mllib.tree.model.RandomForestModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRandomForestRegressionExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap<>(); + Integer numTrees = 3; // Use more in practice. + String featureSubsetStrategy = "auto"; // Let the algorithm choose. + String impurity = "variance"; + Integer maxDepth = 4; + Integer maxBins = 32; + Integer seed = 12345; + // Train a RandomForest model. + final RandomForestModel model = RandomForest.trainRegressor(trainingData, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2<>(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / testData.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression forest model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myRandomForestRegressionModel"); + RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), + "target/tmp/myRandomForestRegressionModel"); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java new file mode 100644 index 0000000000000..54dfc404ca6e9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.*; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaRankingMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Ranking Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + @Override + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double + .parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + @Override + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2<>(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + @Override + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + @Override + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + @Override + public List call(Iterable docs) { + List products = new ArrayList<>(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + @Override + public List call(Rating[] docs) { + List products = new ArrayList<>(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join( + userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + @Override + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + @Override + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2<>(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + @Override + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2<>(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java new file mode 100644 index 0000000000000..f69aa4b75a56c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRecommendationExample { + public static void main(String[] args) { + // $example on$ + SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/als/test.data"; + JavaRDD data = jsc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String s) { + String[] sarray = s.split(","); + return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), + Double.parseDouble(sarray[2])); + } + } + ); + + // Build the recommendation model using ALS + int rank = 10; + int numIterations = 10; + MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); + + // Evaluate the model on rating data + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( + new Function, Object>() { + public Object call(Tuple2 pair) { + Double err = pair._1() - pair._2(); + return err * err; + } + } + ).rdd()).mean(); + System.out.println("Mean Squared Error = " + MSE); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myCollaborativeFilter"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(jsc.sc(), + "target/tmp/myCollaborativeFilter"); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java new file mode 100644 index 0000000000000..b3e5c04759575 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRegressionMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) { + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + } + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), + numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "target/tmp/LogisticRegressionModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/LogisticRegressionModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java new file mode 100644 index 0000000000000..3730e60f68803 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.LinkedList; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.SingularValueDecomposition; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +// $example off$ + +/** + * Example for SingularValueDecomposition. + */ +public class JavaSVDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("SVD Example"); + SparkContext sc = new SparkContext(conf); + JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); + + // $example on$ + double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; + LinkedList rowsList = new LinkedList<>(); + for (int i = 0; i < array.length; i++) { + Vector currentRow = Vectors.dense(array[i]); + rowsList.add(currentRow); + } + JavaRDD rows = jsc.parallelize(rowsList); + + // Create a RowMatrix from JavaRDD. + RowMatrix mat = new RowMatrix(rows.rdd()); + + // Compute the top 3 singular values and corresponding singular vectors. + SingularValueDecomposition svd = mat.computeSVD(3, true, 1.0E-9d); + RowMatrix U = svd.U(); + Vector s = svd.s(); + Matrix V = svd.V(); + // $example off$ + Vector[] collectPartitions = (Vector[]) U.rows().collect(); + System.out.println("U factor is:"); + for (Vector vector : collectPartitions) { + System.out.println("\t" + vector); + } + System.out.println("Singular values are: " + s); + System.out.println("V factor is:\n" + V); + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java new file mode 100644 index 0000000000000..720b167b2cadf --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVMWithSGDExample.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.SVMModel; +import org.apache.spark.mllib.classification.SVMWithSGD; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +/** + * Example for SVMWithSGD. + */ +public class JavaSVMWithSGDExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaSVMWithSGDExample"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD training = data.sample(false, 0.6, 11L); + training.cache(); + JavaRDD test = data.subtract(training); + + // Run training algorithm to build the model. + int numIterations = 100; + final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + } + ); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels)); + double auROC = metrics.areaUnderROC(); + + System.out.println("Area under ROC = " + auROC); + + // Save and load model + model.save(sc, "target/tmp/javaSVMWithSGDModel"); + SVMModel sameModel = SVMModel.load(sc, "target/tmp/javaSVMWithSGDModel"); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java index 72edaca5e95b1..7f4fe600422b2 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java @@ -67,5 +67,7 @@ public List call(String line) { rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); } // $example off$ + + sc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java new file mode 100644 index 0000000000000..72bbb2a8fa464 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import com.google.common.collect.ImmutableMap; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; + +// $example on$ +import java.util.*; + +import scala.Tuple2; + +import org.apache.spark.api.java.JavaPairRDD; +// $example off$ + +public class JavaStratifiedSamplingExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaStratifiedSamplingExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + @SuppressWarnings("unchecked") + // $example on$ + List> list = new ArrayList<>( + Arrays.>asList( + new Tuple2(1, 'a'), + new Tuple2(1, 'b'), + new Tuple2(2, 'c'), + new Tuple2(2, 'd'), + new Tuple2(2, 'e'), + new Tuple2(3, 'f') + ) + ); + + JavaPairRDD data = jsc.parallelizePairs(list); + + // specify the exact fraction desired from each key Map + ImmutableMap fractions = + ImmutableMap.of(1, (Object)0.1, 2, (Object) 0.6, 3, (Object) 0.3); + + // Get an approximate sample from each stratum + JavaPairRDD approxSample = data.sampleByKey(false, fractions); + // Get an exact sample from each stratum + JavaPairRDD exactSample = data.sampleByKeyExact(false, fractions); + // $example off$ + + System.out.println("approxSample size is " + approxSample.collect().size()); + for (Tuple2 t : approxSample.collect()) { + System.out.println(t._1() + " " + t._2()); + } + + System.out.println("exactSample size is " + exactSample.collect().size()); + for (Tuple2 t : exactSample.collect()) { + System.out.println(t._1() + " " + t._2()); + } + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java new file mode 100644 index 0000000000000..984909cb947a1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + + +import org.apache.spark.api.java.function.VoidFunction; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +// $example on$ +import org.apache.spark.mllib.stat.test.BinarySample; +import org.apache.spark.mllib.stat.test.StreamingTest; +import org.apache.spark.mllib.stat.test.StreamingTestResult; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.Seconds; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.util.Utils; + + +/** + * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data + * stream arrives as text files in a directory. Stops when the two groups are statistically + * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. + * + * The rows of the text files must be in the form `Boolean, Double`. For example: + * false, -3.92 + * true, 99.32 + * + * Usage: + * JavaStreamingTestExample + * + * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and + * a timeout after 100 insignificant batches, call: + * $ bin/run-example mllib.JavaStreamingTestExample dataDir 5 100 + * + * As you add text files to `dataDir` the significance test wil continually update every + * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of + * batches processed exceeds `numBatchesTimeout`. + */ +public class JavaStreamingTestExample { + + private static int timeoutCounter = 0; + + public static void main(String[] args) { + if (args.length != 3) { + System.err.println("Usage: JavaStreamingTestExample " + + " "); + System.exit(1); + } + + String dataDir = args[0]; + Duration batchDuration = Seconds.apply(Long.valueOf(args[1])); + int numBatchesTimeout = Integer.valueOf(args[2]); + + SparkConf conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample"); + JavaStreamingContext ssc = new JavaStreamingContext(conf, batchDuration); + + ssc.checkpoint(Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark").toString()); + + // $example on$ + JavaDStream data = ssc.textFileStream(dataDir).map( + new Function() { + @Override + public BinarySample call(String line) { + String[] ts = line.split(","); + boolean label = Boolean.valueOf(ts[0]); + double value = Double.valueOf(ts[1]); + return new BinarySample(label, value); + } + }); + + StreamingTest streamingTest = new StreamingTest() + .setPeacePeriod(0) + .setWindowSize(0) + .setTestMethod("welch"); + + JavaDStream out = streamingTest.registerStream(data); + out.print(); + // $example off$ + + // Stop processing if test becomes significant or we time out + timeoutCounter = numBatchesTimeout; + + out.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + timeoutCounter -= 1; + + boolean anySignificant = !rdd.filter(new Function() { + @Override + public Boolean call(StreamingTestResult v) { + return v.pValue() < 0.05; + } + }).isEmpty(); + + if (timeoutCounter <= 0 || anySignificant) { + rdd.context().stop(); + } + } + }); + + ssc.start(); + ssc.awaitTermination(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSummaryStatisticsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSummaryStatisticsExample.java new file mode 100644 index 0000000000000..278706bc8f6ed --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSummaryStatisticsExample.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; +import org.apache.spark.mllib.stat.Statistics; +// $example off$ + +public class JavaSummaryStatisticsExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaSummaryStatisticsExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // $example on$ + JavaRDD mat = jsc.parallelize( + Arrays.asList( + Vectors.dense(1.0, 10.0, 100.0), + Vectors.dense(2.0, 20.0, 200.0), + Vectors.dense(3.0, 30.0, 300.0) + ) + ); // an RDD of Vectors + + // Compute column summary statistics. + MultivariateStatisticalSummary summary = Statistics.colStats(mat.rdd()); + System.out.println(summary.mean()); // a dense vector containing the mean value for each column + System.out.println(summary.variance()); // column-wise variance + System.out.println(summary.numNonzeros()); // number of nonzeros in each column + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index afee279ec32b1..354a5306ed45f 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -74,11 +74,12 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); + Dataset schemaPeople = sqlContext.createDataFrame(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + Dataset teenagers = + sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -99,11 +100,11 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); + Dataset parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - DataFrame teenagers2 = + Dataset teenagers2 = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override @@ -120,7 +121,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlContext.read().json(path); + Dataset peopleFromJsonFile = sqlContext.read().json(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -134,7 +135,8 @@ public String call(Row row) { peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. - DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + Dataset teenagers3 = + sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. @@ -151,7 +153,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd()); + Dataset peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); @@ -164,7 +166,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); + Dataset peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 99df259b4e8e6..4544ad2b42ca7 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -17,7 +17,7 @@ package org.apache.spark.examples.streaming; -import com.google.common.collect.Lists; +import com.google.common.io.Closeables; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -36,6 +36,9 @@ import java.io.InputStreamReader; import java.net.ConnectException; import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; /** @@ -73,14 +76,14 @@ public static void main(String[] args) { new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } }).reduceByKey(new Function2() { @Override @@ -121,23 +124,24 @@ public void onStop() { /** Create a socket connection and receive data until receiver is stopped */ private void receive() { - Socket socket = null; - String userInput = null; - try { - // connect to the server - socket = new Socket(host, port); - - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); - - // Until stopped or connection broken continue reading - while (!isStopped() && (userInput = reader.readLine()) != null) { - System.out.println("Received data '" + userInput + "'"); - store(userInput); + Socket socket = null; + BufferedReader reader = null; + String userInput = null; + try { + // connect to the server + socket = new Socket(host, port); + reader = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); + // Until stopped or connection broken continue reading + while (!isStopped() && (userInput = reader.readLine()) != null) { + System.out.println("Received data '" + userInput + "'"); + store(userInput); + } + } finally { + Closeables.close(reader, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - reader.close(); - socket.close(); - // Restart in an attempt to connect again when server is active again restart("Trying to connect again"); } catch(ConnectException ce) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index bab9f2478e779..769b21cecfb80 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -20,11 +20,11 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; import scala.Tuple2; -import com.google.common.collect.Lists; import kafka.serializer.StringDecoder; import org.apache.spark.SparkConf; @@ -35,12 +35,13 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: DirectKafkaWordCount + * Usage: JavaDirectKafkaWordCount * is a list of one or more Kafka brokers * is a list of one or more kafka topics to consume from * * Example: - * $ bin/run-example streaming.KafkaWordCount broker1-host:port,broker2-host:port topic1,topic2 + * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \ + * topic1,topic2 */ public final class JavaDirectKafkaWordCount { @@ -48,7 +49,7 @@ public final class JavaDirectKafkaWordCount { public static void main(String[] args) { if (args.length < 2) { - System.err.println("Usage: DirectKafkaWordCount \n" + + System.err.println("Usage: JavaDirectKafkaWordCount \n" + " is a list of one or more Kafka brokers\n" + " is a list of one or more kafka topics to consume from\n\n"); System.exit(1); @@ -59,12 +60,12 @@ public static void main(String[] args) { String brokers = args[0]; String topics = args[1]; - // Create context with 2 second batch interval + // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2)); - HashSet topicsSet = new HashSet(Arrays.asList(topics.split(","))); - HashMap kafkaParams = new HashMap(); + HashSet topicsSet = new HashSet<>(Arrays.asList(topics.split(","))); + HashMap kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", brokers); // Create direct kafka stream with brokers and topics @@ -87,15 +88,15 @@ public String call(Tuple2 tuple2) { }); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } }).reduceByKey( new Function2() { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java index da56637fe891a..bae4b78ac2f47 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java @@ -19,7 +19,6 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.Function; -import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.*; import org.apache.spark.streaming.api.java.*; import org.apache.spark.streaming.flume.FlumeUtils; @@ -58,7 +57,8 @@ public static void main(String[] args) { Duration batchInterval = new Duration(2000); SparkConf sparkConf = new SparkConf().setAppName("JavaFlumeEventCount"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, batchInterval); - JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(ssc, host, port); + JavaReceiverInputDStream flumeStream = + FlumeUtils.createStream(ssc, host, port); flumeStream.count(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index 16ae9a3319ee2..655da6840cc57 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -17,20 +17,19 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; import java.util.Map; import java.util.HashMap; import java.util.regex.Pattern; - import scala.Tuple2; -import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -66,11 +65,11 @@ public static void main(String[] args) { StreamingExamples.setStreamingLogLevels(); SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaWordCount"); - // Create the context with a 1 second batch size + // Create the context with 2 seconds batch size JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); int numThreads = Integer.parseInt(args[3]); - Map topicMap = new HashMap(); + Map topicMap = new HashMap<>(); String[] topics = args[2].split(","); for (String topic: topics) { topicMap.put(topic, numThreads); @@ -88,8 +87,8 @@ public String call(Tuple2 tuple2) { JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); @@ -97,7 +96,7 @@ public Iterable call(String x) { new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } }).reduceByKey(new Function2() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java index 3e9f0f4b8f127..5761da684b467 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java @@ -17,8 +17,11 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; +import java.util.regex.Pattern; + import scala.Tuple2; -import com.google.common.collect.Lists; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -31,8 +34,6 @@ import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import java.util.regex.Pattern; - /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. * @@ -67,15 +68,15 @@ public static void main(String[] args) { args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } }).reduceByKey(new Function2() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java index 4ce8437f82705..62413b4606ff2 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java @@ -30,7 +30,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -50,7 +49,7 @@ public static void main(String[] args) throws Exception { // Create the queue through which RDDs can be pushed to // a QueueInputDStream - Queue> rddQueue = new LinkedList>(); + Queue> rddQueue = new LinkedList<>(); // Create and push some RDDs into the queue List list = Lists.newArrayList(); @@ -68,7 +67,7 @@ public static void main(String[] args) throws Exception { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 10, 1); + return new Tuple2<>(i % 10, 1); } }); JavaPairDStream reducedStream = mappedStream.reduceByKey( diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index bceda97f058ea..e5fb2bfbfae7b 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -21,27 +21,70 @@ import java.io.IOException; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Iterator; +import java.util.List; import java.util.regex.Pattern; import scala.Tuple2; -import com.google.common.collect.Lists; + import com.google.common.io.Files; +import org.apache.spark.Accumulator; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.Time; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.apache.spark.streaming.api.java.JavaStreamingContextFactory; /** - * Counts words in text encoded with UTF8 received from the network every second. + * Use this singleton to get or register a Broadcast variable. + */ +class JavaWordBlacklist { + + private static volatile Broadcast> instance = null; + + public static Broadcast> getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaWordBlacklist.class) { + if (instance == null) { + List wordBlacklist = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordBlacklist); + } + } + } + return instance; + } +} + +/** + * Use this singleton to get or register an Accumulator. + */ +class JavaDroppedWordsCounter { + + private static volatile Accumulator instance = null; + + public static Accumulator getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaDroppedWordsCounter.class) { + if (instance == null) { + instance = jsc.accumulator(0, "WordsInBlacklistCounter"); + } + } + } + return instance; + } +} + +/** + * Counts words in text encoded with UTF8 received from the network every second. This example also + * shows how to use lazily instantiated singleton instances for Accumulator and Broadcast so that + * they can be registered on driver failures. * * Usage: JavaRecoverableNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive @@ -91,15 +134,15 @@ private static JavaStreamingContext createContext(String ip, JavaReceiverInputDStream lines = ssc.socketTextStream(ip, port); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); JavaPairDStream wordCounts = words.mapToPair( new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } }).reduceByKey(new Function2() { @Override @@ -108,14 +151,32 @@ public Integer call(Integer i1, Integer i2) { } }); - wordCounts.foreachRDD(new Function2, Time, Void>() { + wordCounts.foreachRDD(new VoidFunction2, Time>() { @Override - public Void call(JavaPairRDD rdd, Time time) throws IOException { - String counts = "Counts at time " + time + " " + rdd.collect(); - System.out.println(counts); + public void call(JavaPairRDD rdd, Time time) throws IOException { + // Get or register the blacklist Broadcast + final Broadcast> blacklist = + JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the droppedWordsCounter Accumulator + final Accumulator droppedWordsCounter = + JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + // Use blacklist to drop words and use droppedWordsCounter to count them + String counts = rdd.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 wordCount) { + if (blacklist.value().contains(wordCount._1())) { + droppedWordsCounter.add(wordCount._2()); + return false; + } else { + return true; + } + } + }).collect().toString(); + String output = "Counts at time " + time + " " + counts; + System.out.println(output); + System.out.println("Dropped " + droppedWordsCounter.value() + " word(s) totally"); System.out.println("Appending to " + outputFile.getAbsolutePath()); - Files.append(counts + "\n", outputFile, Charset.defaultCharset()); - return null; + Files.append(output + "\n", outputFile, Charset.defaultCharset()); } }); @@ -141,13 +202,18 @@ public static void main(String[] args) { final int port = Integer.parseInt(args[1]); final String checkpointDirectory = args[2]; final String outputPath = args[3]; - JavaStreamingContextFactory factory = new JavaStreamingContextFactory() { + + // Function to create JavaStreamingContext without any output operations + // (used to detect the new context) + Function0 createContextFunc = new Function0() { @Override - public JavaStreamingContext create() { + public JavaStreamingContext call() { return createContext(ip, port, checkpointDirectory, outputPath); } }; - JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, factory); + + JavaStreamingContext ssc = + JavaStreamingContext.getOrCreate(checkpointDirectory, createContextFunc); ssc.start(); ssc.awaitTermination(); } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 46562ddbbcb57..4b9d9efc8549a 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -17,18 +17,19 @@ package org.apache.spark.examples.streaming; +import java.util.Arrays; +import java.util.Iterator; import java.util.regex.Pattern; -import com.google.common.collect.Lists; - import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.VoidFunction2; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.DataFrame; import org.apache.spark.api.java.StorageLevels; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.Time; @@ -72,36 +73,36 @@ public static void main(String[] args) { args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); // Convert RDDs of the words DStream to DataFrame and run SQL query - words.foreachRDD(new Function2, Time, Void>() { + words.foreachRDD(new VoidFunction2, Time>() { @Override - public Void call(JavaRDD rdd, Time time) { + public void call(JavaRDD rdd, Time time) { SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame JavaRDD rowRDD = rdd.map(new Function() { + @Override public JavaRecord call(String word) { JavaRecord record = new JavaRecord(); record.setWord(word); return record; } }); - DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class); + Dataset wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class); // Register as table wordsDataFrame.registerTempTable("words"); // Do word count on table using SQL and print it - DataFrame wordCountsDataFrame = + Dataset wordCountsDataFrame = sqlContext.sql("select word, count(*) as total from words group by word"); System.out.println("========= " + time + "========="); wordCountsDataFrame.show(); - return null; } }); @@ -112,8 +113,8 @@ public JavaRecord call(String word) { /** Lazily instantiated singleton instance of SQLContext */ class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { + private static transient SQLContext instance = null; + public static SQLContext getInstance(SparkContext sparkContext) { if (instance == null) { instance = new SQLContext(sparkContext); } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 99b63a2590ae2..4230dab52e5d4 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -18,26 +18,21 @@ package org.apache.spark.examples.streaming; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; import scala.Tuple2; -import com.google.common.base.Optional; -import com.google.common.collect.Lists; - -import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.StorageLevels; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.Durations; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.State; +import org.apache.spark.streaming.StateSpec; +import org.apache.spark.streaming.api.java.*; /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every @@ -63,28 +58,15 @@ public static void main(String[] args) { StreamingExamples.setStreamingLogLevels(); - // Update the cumulative count function - final Function2, Optional, Optional> updateFunction = - new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - Integer newSum = state.or(0); - for (Integer value : values) { - newSum += value; - } - return Optional.of(newSum); - } - }; - // Create the context with a 1 second batch size SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); ssc.checkpoint("."); - // Initial RDD input to updateStateByKey + // Initial state RDD input to mapWithState @SuppressWarnings("unchecked") - List> tuples = Arrays.asList(new Tuple2("hello", 1), - new Tuple2("world", 1)); + List> tuples = + Arrays.asList(new Tuple2<>("hello", 1), new Tuple2<>("world", 1)); JavaPairRDD initialRDD = ssc.sparkContext().parallelizePairs(tuples); JavaReceiverInputDStream lines = ssc.socketTextStream( @@ -92,8 +74,8 @@ public Optional call(List values, Optional state) { JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Lists.newArrayList(SPACE.split(x)); + public Iterator call(String x) { + return Arrays.asList(SPACE.split(x)).iterator(); } }); @@ -101,13 +83,26 @@ public Iterable call(String x) { new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } }); - // This will give a Dstream made of state (which is the cumulative count of the words) - JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, - new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD); + // Update the cumulative count function + Function3, State, Tuple2> mappingFunc = + new Function3, State, Tuple2>() { + @Override + public Tuple2 call(String word, Optional one, + State state) { + int sum = one.orElse(0) + (state.exists() ? state.get() : 0); + Tuple2 output = new Tuple2<>(word, sum); + state.update(sum); + return output; + } + }; + + // DStream made of get cumulative counts that get updated in every batch + JavaMapWithStateDStream> stateDstream = + wordsDstream.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD)); stateDstream.print(); ssc.start(); diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 1c3a787bd0e94..205ca02962bee 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,7 +36,7 @@ def rmse(R, ms, us): diff = R - ms * us.T - return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + return np.sqrt(np.sum(np.power(diff, 2)) / (M * U)) def update(i, vec, mat, ratings): diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py new file mode 100644 index 0000000000000..0ee01fd8258df --- /dev/null +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.regression import AFTSurvivalRegression +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="AFTSurvivalRegressionExample") + sqlContext = SQLContext(sc) + + # $example on$ + training = sqlContext.createDataFrame([ + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"]) + quantileProbabilities = [0.3, 0.6] + aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities, + quantilesCol="quantiles") + + model = aft.fit(training) + + # Print the coefficients, intercept and scale parameter for AFT survival regression + print("Coefficients: " + str(model.coefficients)) + print("Intercept: " + str(model.intercept)) + print("Scale: " + str(model.scale)) + model.transform(training).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py new file mode 100644 index 0000000000000..922173308c6aa --- /dev/null +++ b/examples/src/main/python/ml/als_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext + +# $example on$ +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.recommendation import ALS +from pyspark.sql import Row +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ALSExample") + sqlContext = SQLContext(sc) + + # $example on$ + lines = sc.textFile("data/mllib/als/sample_movielens_ratings.txt") + parts = lines.map(lambda l: l.split("::")) + ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]), + rating=float(p[2]), timestamp=long(p[3]))) + ratings = sqlContext.createDataFrame(ratingsRDD) + (training, test) = ratings.randomSplit([0.8, 0.2]) + + # Build the recommendation model using ALS on the training data + als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating") + model = als.fit(training) + + # Evaluate the model by computing the RMSE on the test data + rawPredictions = model.transform(test) + predictions = rawPredictions\ + .withColumn("rating", rawPredictions.rating.cast("double"))\ + .withColumn("prediction", rawPredictions.prediction.cast("double")) + evaluator =\ + RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction") + rmse = evaluator.evaluate(predictions) + print("Root-mean-square error = " + str(rmse)) + # $example off$ + sc.stop() diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py new file mode 100644 index 0000000000000..317cfa638a5a9 --- /dev/null +++ b/examples/src/main/python/ml/binarizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Binarizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinarizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) + ], ["label", "feature"]) + binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") + binarizedDataFrame = binarizer.transform(continuousDataFrame) + binarizedFeatures = binarizedDataFrame.select("binarized_feature") + for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bisecting_k_means_example.py b/examples/src/main/python/ml/bisecting_k_means_example.py new file mode 100644 index 0000000000000..e6f6bfd7e84ed --- /dev/null +++ b/examples/src/main/python/ml/bisecting_k_means_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.clustering import BisectingKMeans, BisectingKMeansModel +from pyspark.mllib.linalg import VectorUDT, _convert_to_vector, Vectors +from pyspark.mllib.linalg import Vectors +from pyspark.sql.types import Row +# $example off$ +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a bisecting k-means clustering. +""" + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonBisectingKMeansExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sc.textFile("data/mllib/kmeans_data.txt") + parsed = data.map(lambda l: Row(features=Vectors.dense([float(x) for x in l.split(' ')]))) + training = sqlContext.createDataFrame(parsed) + + kmeans = BisectingKMeans().setK(2).setSeed(1).setFeaturesCol("features") + + model = kmeans.fit(training) + + # Evaluate clustering + cost = model.computeCost(training) + print("Bisecting K-means Cost = " + str(cost)) + + centers = model.clusterCenters() + print("Cluster Centers: ") + for center in centers: + print(center) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py new file mode 100644 index 0000000000000..4304255f350db --- /dev/null +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Bucketizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BucketizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + + data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] + dataFrame = sqlContext.createDataFrame(data, ["features"]) + + bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + + # Transform original data into its bucket index. + bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/count_vectorizer_example.py b/examples/src/main/python/ml/count_vectorizer_example.py new file mode 100644 index 0000000000000..e839f645f70b5 --- /dev/null +++ b/examples/src/main/python/ml/count_vectorizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import CountVectorizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="CountVectorizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Input data: Each row is a bag of words with a ID. + df = sqlContext.createDataFrame([ + (0, "a b c".split(" ")), + (1, "a b b c a".split(" ")) + ], ["id", "words"]) + + # fit a CountVectorizerModel from the corpus. + cv = CountVectorizer(inputCol="words", outputCol="features", vocabSize=3, minDF=2.0) + model = cv.fit(df) + result = model.transform(df) + result.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index f0ca97c724940..5f0ef20218c4a 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -18,12 +18,14 @@ from __future__ import print_function from pyspark import SparkContext +# $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import BinaryClassificationEvaluator from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.ml.tuning import CrossValidator, ParamGridBuilder from pyspark.sql import Row, SQLContext +# $example off$ """ A simple example demonstrating model selection using CrossValidator. @@ -36,7 +38,7 @@ if __name__ == "__main__": sc = SparkContext(appName="CrossValidatorExample") sqlContext = SQLContext(sc) - + # $example on$ # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") training = sc.parallelize([(0, "a b c d e spark", 1.0), @@ -92,5 +94,6 @@ selected = prediction.select("id", "text", "probability", "prediction") for row in selected.collect(): print(row) + # $example off$ sc.stop() diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py new file mode 100644 index 0000000000000..d2644ca335654 --- /dev/null +++ b/examples/src/main/python/ml/dataframe_example.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example of how to use DataFrame for ML. Run with:: + bin/spark-submit examples/src/main/python/ml/dataframe_example.py +""" +from __future__ import print_function + +import os +import sys +import tempfile +import shutil + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.mllib.stat import Statistics + +if __name__ == "__main__": + if len(sys.argv) > 2: + print("Usage: dataframe_example.py ", file=sys.stderr) + exit(-1) + sc = SparkContext(appName="DataFrameExample") + sqlContext = SQLContext(sc) + if len(sys.argv) == 2: + input = sys.argv[1] + else: + input = "data/mllib/sample_libsvm_data.txt" + + # Load input data + print("Loading LIBSVM file with UDT from " + input + ".") + df = sqlContext.read.format("libsvm").load(input).cache() + print("Schema from LIBSVM:") + df.printSchema() + print("Loaded training data as a DataFrame with " + + str(df.count()) + " records.") + + # Show statistical summary of labels. + labelSummary = df.describe("label") + labelSummary.show() + + # Convert features column to an RDD of vectors. + features = df.select("features").map(lambda r: r.features) + summary = Statistics.colStats(features) + print("Selected features column with average values:\n" + + str(summary.mean())) + + # Save the records in a parquet file. + tempdir = tempfile.NamedTemporaryFile(delete=False).name + os.unlink(tempdir) + print("Saving to " + tempdir + " as Parquet file.") + df.write.parquet(tempdir) + + # Load the records back. + print("Loading Parquet file with UDT from " + tempdir) + newDF = sqlContext.read.parquet(tempdir) + print("Schema from Parquet:") + newDF.printSchema() + shutil.rmtree(tempdir) + + sc.stop() diff --git a/examples/src/main/python/ml/dct_example.py b/examples/src/main/python/ml/dct_example.py new file mode 100644 index 0000000000000..264d47f404cb1 --- /dev/null +++ b/examples/src/main/python/ml/dct_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import DCT +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="DCTExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (Vectors.dense([0.0, 1.0, -2.0, 3.0]),), + (Vectors.dense([-1.0, 2.0, 4.0, -7.0]),), + (Vectors.dense([14.0, -2.0, -5.0, 1.0]),)], ["features"]) + + dct = DCT(inverse=False, inputCol="features", outputCol="featuresDCT") + + dctDf = dct.transform(df) + + for dcts in dctDf.select("featuresDCT").take(3): + print(dcts) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/decision_tree_classification_example.py b/examples/src/main/python/ml/decision_tree_classification_example.py new file mode 100644 index 0000000000000..86bdc65392bbb --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_classification_example.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +# $example on$ +from pyspark import SparkContext, SQLContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import DecisionTreeClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load the data stored in LIBSVM format as a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + + # Chain indexers and tree in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g " % (1.0 - accuracy)) + + treeModel = model.stages[2] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py new file mode 100644 index 0000000000000..8e20d5d8572a5 --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_regression_example.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import DecisionTreeRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load the data stored in LIBSVM format as a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeRegressor(featuresCol="indexedFeatures") + + # Chain indexer and tree in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, dt]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + treeModel = model.stages[1] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py new file mode 100644 index 0000000000000..c85cb0d89543c --- /dev/null +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] + df = sqlContext.createDataFrame(data, ["vector"]) + transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") + transformer.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/estimator_transformer_param_example.py b/examples/src/main/python/ml/estimator_transformer_param_example.py new file mode 100644 index 0000000000000..9a8993dac4f65 --- /dev/null +++ b/examples/src/main/python/ml/estimator_transformer_param_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Estimator Transformer Param Example. +""" +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.classification import LogisticRegression +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="EstimatorTransformerParamExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Prepare training data from a list of (label, features) tuples. + training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) + + # Create a LogisticRegression instance. This instance is an Estimator. + lr = LogisticRegression(maxIter=10, regParam=0.01) + # Print out the parameters, documentation, and any default values. + print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print "Model 1 was fit using parameters: " + print model1.extractParamMap() + + # We may alternatively specify parameters using a Python dictionary as a paramMap + paramMap = {lr.maxIter: 20} + paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. + paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + + # You can combine paramMaps, which are python dictionaries. + paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name + paramMapCombined = paramMap.copy() + paramMapCombined.update(paramMap2) + + # Now learn a new model using the paramMapCombined parameters. + # paramMapCombined overrides all parameters set earlier via lr.set* methods. + model2 = lr.fit(training, paramMapCombined) + print "Model 2 was fit using parameters: " + print model2.extractParamMap() + + # Prepare test data + test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegression.transform will only use the 'features' column. + # Note that model2.transform() outputs a "myProbability" column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. + prediction = model2.transform(test) + selected = prediction.select("features", "label", "myProbability", "prediction") + for row in selected.collect(): + print row + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py new file mode 100644 index 0000000000000..f7e842f4b303a --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Tree Classifier Example. +""" +from __future__ import print_function + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="gradient_boosted_tree_classifier_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GBT model. + gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10) + + # Chain indexers and GBT in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g" % (1.0 - accuracy)) + + gbtModel = model.stages[2] + print(gbtModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py new file mode 100644 index 0000000000000..f8b4de651c768 --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Tree Regressor Example. +""" +from __future__ import print_function + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import GBTRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="gradient_boosted_tree_regressor_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GBT model. + gbt = GBTRegressor(featuresCol="indexedFeatures", maxIter=10) + + # Chain indexer and GBT in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, gbt]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + gbtModel = model.stages[1] + print(gbtModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py deleted file mode 100644 index 6446f0fe5eeab..0000000000000 --- a/examples/src/main/python/ml/gradient_boosted_trees.py +++ /dev/null @@ -1,83 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import GBTClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import GBTRegressor -from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. -Note: GBTClassifier only supports binary classification currently -Run with: - bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py -""" - - -def testClassification(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = BinaryClassificationMetrics(predictionAndLabels) - print("AUC %.3f" % metrics.areaUnderROC) - - -def testRegression(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGBTExample") - sqlContext = SQLContext(sc) - - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py new file mode 100644 index 0000000000000..fb0ba2950bbd6 --- /dev/null +++ b/examples/src/main/python/ml/index_to_string_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import IndexToString, StringIndexer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="IndexToStringExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + + converter = IndexToString(inputCol="categoryIndex", outputCol="originalCategory") + converted = converter.transform(indexed) + + converted.select("id", "originalCategory").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py index 150dadd42f33e..fa57a4d3ada1b 100644 --- a/examples/src/main/python/ml/kmeans_example.py +++ b/examples/src/main/python/ml/kmeans_example.py @@ -18,7 +18,6 @@ from __future__ import print_function import sys -import re import numpy as np from pyspark import SparkContext diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py new file mode 100644 index 0000000000000..a4cd40cf26726 --- /dev/null +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.regression import LinearRegression +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="LinearRegressionWithElasticNet") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + training = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") + + lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # Print the coefficients and intercept for linear regression + print("Coefficients: " + str(lrModel.coefficients)) + print("Intercept: " + str(lrModel.intercept)) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py deleted file mode 100644 index 55afe1b207fe0..0000000000000 --- a/examples/src/main/python/ml/logistic_regression.py +++ /dev/null @@ -1,67 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.evaluation import MulticlassMetrics -from pyspark.ml.feature import StringIndexer -from pyspark.mllib.util import MLUtils -from pyspark.sql import SQLContext - -""" -A simple example demonstrating a logistic regression with elastic net regularization Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression.py -""" - -if __name__ == "__main__": - - if len(sys.argv) > 1: - print("Usage: logistic_regression", file=sys.stderr) - exit(-1) - - sc = SparkContext(appName="PythonLogisticRegressionExample") - sqlContext = SQLContext(sc) - - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [training, test] = td.randomSplit([0.7, 0.3]) - - lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") - lr.setElasticNetParam(0.8) - - # Fit the model - lrModel = lr.fit(training) - - predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py new file mode 100644 index 0000000000000..b0b1d27e13bb0 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import LogisticRegression +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="LogisticRegressionWithElasticNet") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # Print the coefficients and intercept for logistic regression + print("Coefficients: " + str(lrModel.coefficients)) + print("Intercept: " + str(lrModel.intercept)) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/max_abs_scaler_example.py b/examples/src/main/python/ml/max_abs_scaler_example.py new file mode 100644 index 0000000000000..d9b69eef1cd84 --- /dev/null +++ b/examples/src/main/python/ml/max_abs_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import MaxAbsScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="MaxAbsScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + scaler = MaxAbsScaler(inputCol="features", outputCol="scaledFeatures") + + # Compute summary statistics and generate MaxAbsScalerModel + scalerModel = scaler.fit(dataFrame) + + # rescale each feature to range [-1, 1]. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/min_max_scaler_example.py b/examples/src/main/python/ml/min_max_scaler_example.py new file mode 100644 index 0000000000000..2f8e4ade468b9 --- /dev/null +++ b/examples/src/main/python/ml/min_max_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import MinMaxScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="MinMaxScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + scaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures") + + # Compute summary statistics and generate MinMaxScalerModel + scalerModel = scaler.fit(dataFrame) + + # rescale each feature to range [min, max]. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/multilayer_perceptron_classification.py b/examples/src/main/python/ml/multilayer_perceptron_classification.py new file mode 100644 index 0000000000000..f84588f547fff --- /dev/null +++ b/examples/src/main/python/ml/multilayer_perceptron_classification.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import MultilayerPerceptronClassifier +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="multilayer_perceptron_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + data = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_multiclass_classification_data.txt") + # Split the data into train and test + splits = data.randomSplit([0.6, 0.4], 1234) + train = splits[0] + test = splits[1] + # specify layers for the neural network: + # input layer of size 4 (features), two intermediate of size 5 and 4 + # and output of size 3 (classes) + layers = [4, 5, 4, 3] + # create the trainer and set its parameters + trainer = MultilayerPerceptronClassifier(maxIter=100, layers=layers, blockSize=128, seed=1234) + # train the model + model = trainer.fit(train) + # compute precision on the test set + result = model.transform(test) + predictionAndLabels = result.select("prediction", "label") + evaluator = MulticlassClassificationEvaluator(metricName="precision") + print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py new file mode 100644 index 0000000000000..f2d85f53e7219 --- /dev/null +++ b/examples/src/main/python/ml/n_gram_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import NGram +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NGramExample") + sqlContext = SQLContext(sc) + + # $example on$ + wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) + ], ["label", "words"]) + ngram = NGram(inputCol="words", outputCol="ngrams") + ngramDataFrame = ngram.transform(wordDataFrame) + for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/naive_bayes_example.py b/examples/src/main/python/ml/naive_bayes_example.py new file mode 100644 index 0000000000000..db8fbea9bf9b1 --- /dev/null +++ b/examples/src/main/python/ml/naive_bayes_example.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import NaiveBayes +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="naive_bayes_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + data = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") + # Split the data into train and test + splits = data.randomSplit([0.6, 0.4], 1234) + train = splits[0] + test = splits[1] + + # create the trainer and set its parameters + nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + + # train the model + model = nb.fit(train) + # compute precision on the test set + result = model.transform(test) + predictionAndLabels = result.select("prediction", "label") + evaluator = MulticlassClassificationEvaluator(metricName="precision") + print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py new file mode 100644 index 0000000000000..d490221474c24 --- /dev/null +++ b/examples/src/main/python/ml/normalizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Normalizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Normalize each Vector using $L^1$ norm. + normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) + l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + # Normalize each Vector using $L^\infty$ norm. + lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) + lInfNormData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 0000000000000..0f94c26638d35 --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py new file mode 100644 index 0000000000000..a17181f1b8a51 --- /dev/null +++ b/examples/src/main/python/ml/pca_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PCAExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + df = sqlContext.createDataFrame(data, ["features"]) + pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") + model = pca.fit(df) + result = model.transform(df).select("pcaFeatures") + result.show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pipeline_example.py b/examples/src/main/python/ml/pipeline_example.py new file mode 100644 index 0000000000000..3288568f0c287 --- /dev/null +++ b/examples/src/main/python/ml/pipeline_example.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Pipeline Example. +""" +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PipelineExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Prepare training documents from a list of (id, text, label) tuples. + training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) + + # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10, regParam=0.01) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # Fit the pipeline to training documents. + model = pipeline.fit(training) + + # Prepare test documents, which are unlabeled (id, text) tuples. + test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) + + # Make predictions on test documents and print columns of interest. + prediction = model.transform(test) + selected = prediction.select("id", "text", "prediction") + for row in selected.collect(): + print(row) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py new file mode 100644 index 0000000000000..89f5cbe8f2f41 --- /dev/null +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PolynomialExpansionExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext\ + .createDataFrame([(Vectors.dense([-2.0, 2.3]),), + (Vectors.dense([0.0, 0.0]),), + (Vectors.dense([0.6, -1.1]),)], + ["features"]) + px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + polyDF = px.transform(df) + for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_classifier_example.py b/examples/src/main/python/ml/random_forest_classifier_example.py new file mode 100644 index 0000000000000..c3570438c51d9 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_classifier_example.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Classifier Example. +""" +from __future__ import print_function + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="random_forest_classifier_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + + # Chain indexers and forest in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g" % (1.0 - accuracy)) + + rfModel = model.stages[2] + print(rfModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py deleted file mode 100644 index c7730e1bfacd9..0000000000000 --- a/examples/src/main/python/ml/random_forest_example.py +++ /dev/null @@ -1,87 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import RandomForestRegressor -from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a RandomForest Classification/Regression Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/random_forest_example.py -""" - - -def testClassification(train, test): - # Train a RandomForest model. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - # Note: Use larger numTrees in practice. - - rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - -def testRegression(train, test): - # Train a RandomForest model. - # Note: Use larger numTrees in practice. - - rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - sqlContext = SQLContext(sc) - - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() diff --git a/examples/src/main/python/ml/random_forest_regressor_example.py b/examples/src/main/python/ml/random_forest_regressor_example.py new file mode 100644 index 0000000000000..b77014f379237 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_regressor_example.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Regressor Example. +""" +from __future__ import print_function + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import RandomForestRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="random_forest_regressor_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # Set maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + rf = RandomForestRegressor(featuresCol="indexedFeatures") + + # Chain indexer and forest in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, rf]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + rfModel = model.stages[1] + print(rfModel) # summary only + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py new file mode 100644 index 0000000000000..b544a14700762 --- /dev/null +++ b/examples/src/main/python/ml/rformula_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import RFormula +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="RFormulaExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) + formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") + output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/sql_transformer.py b/examples/src/main/python/ml/sql_transformer.py new file mode 100644 index 0000000000000..9575d728d8159 --- /dev/null +++ b/examples/src/main/python/ml/sql_transformer.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import SQLTransformer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="SQLTransformerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, 1.0, 3.0), + (2, 2.0, 5.0) + ], ["id", "v1", "v2"]) + sqlTrans = SQLTransformer( + statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + sqlTrans.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py new file mode 100644 index 0000000000000..ae7aa85005bcd --- /dev/null +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StandardScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + + # Compute summary statistics by fitting the StandardScaler + scalerModel = scaler.fit(dataFrame) + + # Normalize each feature to have unit standard deviation. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py new file mode 100644 index 0000000000000..01f94af8ca752 --- /dev/null +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StopWordsRemover +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StopWordsRemoverExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) + ], ["label", "raw"]) + + remover = StopWordsRemover(inputCol="raw", outputCol="filtered") + remover.transform(sentenceData).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py new file mode 100644 index 0000000000000..58a8cb5d56b73 --- /dev/null +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StringIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + indexed = indexer.fit(df).transform(df) + indexed.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py new file mode 100644 index 0000000000000..ce9b225be5357 --- /dev/null +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Tokenizer, RegexTokenizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TokenizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsDataFrame = tokenizer.transform(sentenceDataFrame) + for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) + regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") + # alternatively, pattern="\\w+", gaps(False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py new file mode 100644 index 0000000000000..161a200c61b6d --- /dev/null +++ b/examples/src/main/python/ml/train_validation_split.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.regression import LinearRegression +from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit +from pyspark.sql import SQLContext +# $example off$ + +""" +This example demonstrates applying TrainValidationSplit to split data +and preform model selection. +Run with: + + bin/spark-submit examples/src/main/python/ml/train_validation_split.py +""" + +if __name__ == "__main__": + sc = SparkContext(appName="TrainValidationSplit") + sqlContext = SQLContext(sc) + # $example on$ + # Prepare training and test data. + data = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") + train, test = data.randomSplit([0.7, 0.3]) + lr = LinearRegression(maxIter=10, regParam=0.1) + + # We use a ParamGridBuilder to construct a grid of parameters to search over. + # TrainValidationSplit will try all combinations of values and determine best model using + # the evaluator. + paramGrid = ParamGridBuilder()\ + .addGrid(lr.regParam, [0.1, 0.01]) \ + .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])\ + .build() + + # In this case the estimator is simply the linear regression. + # A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + tvs = TrainValidationSplit(estimator=lr, + estimatorParamMaps=paramGrid, + evaluator=RegressionEvaluator(), + # 80% of the data will be used for training, 20% for validation. + trainRatio=0.8) + + # Run TrainValidationSplit, and choose the best set of parameters. + model = tvs.fit(train) + # Make predictions on test data. model is the model with combination of parameters + # that performed best. + prediction = model.transform(test) + for row in prediction.take(5): + print(row) + # $example off$ + sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py new file mode 100644 index 0000000000000..04f64839f188d --- /dev/null +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorAssemblerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + output = assembler.transform(dataset) + print(output.select("features", "clicked").first()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py new file mode 100644 index 0000000000000..146f41c1dd903 --- /dev/null +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import VectorIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) + indexerModel = indexer.fit(data) + + # Create new column "indexed" with categorical values transformed to indices + indexedData = indexerModel.transform(data) + indexedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py new file mode 100644 index 0000000000000..4e7ea289b2532 --- /dev/null +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Binary Classification Metrics Example. +""" +from __future__ import print_function +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinaryClassificationMetricsExample") + sqlContext = SQLContext(sc) + # $example on$ + # Several of the methods available in scala are currently missing from pyspark + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = BinaryClassificationMetrics(predictionAndLabels) + + # Area under precision-recall curve + print("Area under PR = %s" % metrics.areaUnderPR) + + # Area under ROC curve + print("Area under ROC = %s" % metrics.areaUnderROC) + # $example off$ diff --git a/examples/src/main/python/mllib/bisecting_k_means_example.py b/examples/src/main/python/mllib/bisecting_k_means_example.py new file mode 100644 index 0000000000000..7f4d0402d620c --- /dev/null +++ b/examples/src/main/python/mllib/bisecting_k_means_example.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from numpy import array +# $example off$ + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.clustering import BisectingKMeans, BisectingKMeansModel +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonBisectingKMeansExample") # SparkContext + + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/kmeans_data.txt") + parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')])) + + # Build the model (cluster the data) + model = BisectingKMeans.train(parsedData, 2, maxIterations=5) + + # Evaluate clustering + cost = model.computeCost(parsedData) + print("Bisecting K-means Cost = " + str(cost)) + + # Save and load model + path = "target/org/apache/spark/PythonBisectingKMeansExample/BisectingKMeansModel" + model.save(sc, path) + sameModel = BisectingKMeansModel.load(sc, path) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/correlations_example.py b/examples/src/main/python/mllib/correlations_example.py new file mode 100644 index 0000000000000..66d18f6e5df17 --- /dev/null +++ b/examples/src/main/python/mllib/correlations_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import numpy as np + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.stat import Statistics +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="CorrelationsExample") # SparkContext + + # $example on$ + seriesX = sc.parallelize([1.0, 2.0, 3.0, 3.0, 5.0]) # a series + # seriesY must have the same number of partitions and cardinality as seriesX + seriesY = sc.parallelize([11.0, 22.0, 33.0, 33.0, 555.0]) + + # Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. + # If a method is not specified, Pearson's method will be used by default. + print("Correlation is: " + str(Statistics.corr(seriesX, seriesY, method="pearson"))) + + data = sc.parallelize( + [np.array([1.0, 10.0, 100.0]), np.array([2.0, 20.0, 200.0]), np.array([5.0, 33.0, 366.0])] + ) # an RDD of Vectors + + # calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. + # If a method is not specified, Pearson's method will be used by default. + print(Statistics.corr(data, method="pearson")) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py deleted file mode 100644 index e23ecc0c5d302..0000000000000 --- a/examples/src/main/python/mllib/dataset_example.py +++ /dev/null @@ -1,63 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -An example of how to use DataFrame as a dataset for ML. Run with:: - bin/spark-submit examples/src/main/python/mllib/dataset_example.py -""" -from __future__ import print_function - -import os -import sys -import tempfile -import shutil - -from pyspark import SparkContext -from pyspark.sql import SQLContext -from pyspark.mllib.util import MLUtils -from pyspark.mllib.stat import Statistics - - -def summarize(dataset): - print("schema: %s" % dataset.schema().json()) - labels = dataset.map(lambda r: r.label) - print("label average: %f" % labels.mean()) - features = dataset.map(lambda r: r.features) - summary = Statistics.colStats(features) - print("features average: %r" % summary.mean()) - -if __name__ == "__main__": - if len(sys.argv) > 2: - print("Usage: dataset_example.py ", file=sys.stderr) - exit(-1) - sc = SparkContext(appName="DatasetExample") - sqlContext = SQLContext(sc) - if len(sys.argv) == 2: - input = sys.argv[1] - else: - input = "data/mllib/sample_libsvm_data.txt" - points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() - summarize(dataset0) - tempdir = tempfile.NamedTemporaryFile(delete=False).name - os.unlink(tempdir) - print("Save dataset as a Parquet file to %s." % tempdir) - dataset0.saveAsParquetFile(tempdir) - print("Load it back and summarize it again.") - dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() - summarize(dataset1) - shutil.rmtree(tempdir) diff --git a/examples/src/main/python/mllib/decision_tree_classification_example.py b/examples/src/main/python/mllib/decision_tree_classification_example.py new file mode 100644 index 0000000000000..1b529768b6c62 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_classification_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeClassificationExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, + impurity='gini', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/decision_tree_regression_example.py b/examples/src/main/python/mllib/decision_tree_regression_example.py new file mode 100644 index 0000000000000..cf518eac67e81 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_regression_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeRegressionExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, + impurity='variance', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py deleted file mode 100755 index 513ed8fd51450..0000000000000 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ /dev/null @@ -1,144 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Decision tree classification and regression using MLlib. - -This example requires NumPy (http://www.numpy.org/). -""" -from __future__ import print_function - -import numpy -import os -import sys - -from operator import add - -from pyspark import SparkContext -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree -from pyspark.mllib.util import MLUtils - - -def getAccuracy(dtModel, data): - """ - Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + (x[0] == x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainCorrect / (0.0 + data.count()) - - -def getMSE(dtModel, data): - """ - Return mean squared error (MSE) of DecisionTreeModel on the given - RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainMSE / (0.0 + data.count()) - - -def reindexClassLabels(data): - """ - Re-index class labels in a dataset to the range {0,...,numClasses-1}. - If all labels in that range already appear at least once, - then the returned RDD is the same one (without a mapping). - Note: If a label simply does not appear in the data, - the index will not include it. - Be aware of this when reindexing subsampled data. - :param data: RDD of LabeledPoint where labels are integer values - denoting labels for a classification problem. - :return: Pair (reindexedData, origToNewLabels) where - reindexedData is an RDD of LabeledPoint with labels in - the range {0,...,numClasses-1}, and - origToNewLabels is a dictionary mapping original labels - to new labels. - """ - # classCounts: class --> # examples in class - classCounts = data.map(lambda x: x.label).countByValue() - numExamples = sum(classCounts.values()) - sortedClasses = sorted(classCounts.keys()) - numClasses = len(classCounts) - # origToNewLabels: class --> index in 0,...,numClasses-1 - if (numClasses < 2): - print("Dataset for classification should have at least 2 classes." - " The given dataset had only %d classes." % numClasses, file=sys.stderr) - exit(1) - origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - - print("numClasses = %d" % numClasses) - print("Per-class example fractions, counts:") - print("Class\tFrac\tCount") - for c in sortedClasses: - frac = classCounts[c] / (numExamples + 0.0) - print("%g\t%g\t%d" % (c, frac, classCounts[c])) - - if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): - return (data, origToNewLabels) - else: - reindexedData = \ - data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) - return (reindexedData, origToNewLabels) - - -def usage(): - print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) - exit(1) - - -if __name__ == "__main__": - if len(sys.argv) > 2: - usage() - sc = SparkContext(appName="PythonDT") - - # Load data. - dataPath = 'data/mllib/sample_libsvm_data.txt' - if len(sys.argv) == 2: - dataPath = sys.argv[1] - if not os.path.isfile(dataPath): - sc.stop() - usage() - points = MLUtils.loadLibSVMFile(sc, dataPath) - - # Re-index class labels if needed. - (reindexedData, origToNewLabels) = reindexClassLabels(points) - numClasses = len(origToNewLabels) - - # Train a classifier. - categoricalFeaturesInfo = {} # no categorical features - model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, - categoricalFeaturesInfo=categoricalFeaturesInfo) - # Print learned tree and stats. - print("Trained DecisionTree for classification:") - print(" Model numNodes: %d" % model.numNodes()) - print(" Model depth: %d" % model.depth()) - print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) - if model.numNodes() < 20: - print(model.toDebugString()) - else: - print(model) - - sc.stop() diff --git a/examples/src/main/python/mllib/elementwise_product_example.py b/examples/src/main/python/mllib/elementwise_product_example.py new file mode 100644 index 0000000000000..6d8bf6d42e08d --- /dev/null +++ b/examples/src/main/python/mllib/elementwise_product_example.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") # SparkContext + + # $example on$ + data = sc.textFile("data/mllib/kmeans_data.txt") + parsedData = data.map(lambda x: [float(t) for t in x.split(" ")]) + + # Create weight vector. + transformingVector = Vectors.dense([0.0, 1.0, 2.0]) + transformer = ElementwiseProduct(transformingVector) + + # Batch transform + transformedData = transformer.transform(parsedData) + # Single-row transform + transformedData2 = transformer.transform(parsedData.first()) + # $example off$ + + print("transformedData:") + for each in transformedData.collect(): + print(each) + + print("transformedData2:") + for each in transformedData2.collect(): + print(each) + + sc.stop() diff --git a/examples/src/main/python/mllib/gaussian_mixture_example.py b/examples/src/main/python/mllib/gaussian_mixture_example.py new file mode 100644 index 0000000000000..a60e799d62eb1 --- /dev/null +++ b/examples/src/main/python/mllib/gaussian_mixture_example.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from numpy import array +# $example off$ + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.clustering import GaussianMixture, GaussianMixtureModel +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="GaussianMixtureExample") # SparkContext + + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/gmm_data.txt") + parsedData = data.map(lambda line: array([float(x) for x in line.strip().split(' ')])) + + # Build the model (cluster the data) + gmm = GaussianMixture.train(parsedData, 2) + + # Save and load model + gmm.save(sc, "target/org/apache/spark/PythonGaussianMixtureExample/GaussianMixtureModel") + sameModel = GaussianMixtureModel\ + .load(sc, "target/org/apache/spark/PythonGaussianMixtureExample/GaussianMixtureModel") + + # output parameters of model + for i in range(2): + print("weight = ", gmm.weights[i], "mu = ", gmm.gaussians[i].mu, + "sigma = ", gmm.gaussians[i].sigma.toArray()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py index 2cb8010cdc07f..69e836fc1d06a 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_model.py +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -62,5 +62,9 @@ def parseVector(line): for i in range(args.k): print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, "sigma = ", model.gaussians[i].sigma.toArray())) + print("\n") + print(("The membership value of each vector to all mixture components (first 100): ", + model.predictSoft(data).take(100))) + print("\n") print(("Cluster labels (first 100): ", model.predict(data).take(100))) sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py deleted file mode 100644 index 781bd61c9d2b5..0000000000000 --- a/examples/src/main/python/mllib/gradient_boosted_trees.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Gradient boosted Trees classification and regression using MLlib. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import GradientBoostedTrees -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification ensemble model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ - / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression ensemble model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGradientBoostedTrees") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using GradientBoostedTrees\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using GradientBoostedTrees\n') - testRegression(trainingData, testData) - - sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosting_classification_example.py b/examples/src/main/python/mllib/gradient_boosting_classification_example.py new file mode 100644 index 0000000000000..b204cd1b31c86 --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosting_classification_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Trees Classification Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonGradientBoostedTreesClassificationExample") + # $example on$ + # Load and parse the data file. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GradientBoostedTrees model. + # Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. + # (b) Use more iterations in practice. + model = GradientBoostedTrees.trainClassifier(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification GBT model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myGradientBoostingClassificationModel") + sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/gradient_boosting_regression_example.py b/examples/src/main/python/mllib/gradient_boosting_regression_example.py new file mode 100644 index 0000000000000..758e224a9e21d --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosting_regression_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Trees Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonGradientBoostedTreesRegressionExample") + # $example on$ + # Load and parse the data file. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GradientBoostedTrees model. + # Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. + # (b) Use more iterations in practice. + model = GradientBoostedTrees.trainRegressor(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression GBT model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myGradientBoostingRegressionModel") + sameModel = GradientBoostedTreesModel.load(sc, "target/tmp/myGradientBoostingRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/hypothesis_testing_example.py b/examples/src/main/python/mllib/hypothesis_testing_example.py new file mode 100644 index 0000000000000..e566ead0d318d --- /dev/null +++ b/examples/src/main/python/mllib/hypothesis_testing_example.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Matrices, Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="HypothesisTestingExample") + + # $example on$ + vec = Vectors.dense(0.1, 0.15, 0.2, 0.3, 0.25) # a vector composed of the frequencies of events + + # compute the goodness of fit. If a second vector to test against + # is not supplied as a parameter, the test runs against a uniform distribution. + goodnessOfFitTestResult = Statistics.chiSqTest(vec) + + # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. + print("%s\n" % goodnessOfFitTestResult) + + mat = Matrices.dense(3, 2, [1.0, 3.0, 5.0, 2.0, 4.0, 6.0]) # a contingency matrix + + # conduct Pearson's independence test on the input contingency matrix + independenceTestResult = Statistics.chiSqTest(mat) + + # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. + print("%s\n" % independenceTestResult) + + obs = sc.parallelize( + [LabeledPoint(1.0, [1.0, 0.0, 3.0]), + LabeledPoint(1.0, [1.0, 2.0, 0.0]), + LabeledPoint(1.0, [-1.0, 0.0, -0.5])] + ) # LabeledPoint(feature, label) + + # The contingency table is constructed from an RDD of LabeledPoint and used to conduct + # the independence test. Returns an array containing the ChiSquaredTestResult for every feature + # against the label. + featureTestResults = Statistics.chiSqTest(obs) + + for i, result in enumerate(featureTestResults): + print("Column %d:\n%s" % (i + 1, result)) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/hypothesis_testing_kolmogorov_smirnov_test_example.py b/examples/src/main/python/mllib/hypothesis_testing_kolmogorov_smirnov_test_example.py new file mode 100644 index 0000000000000..ef380dee79d3d --- /dev/null +++ b/examples/src/main/python/mllib/hypothesis_testing_kolmogorov_smirnov_test_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.stat import Statistics +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="HypothesisTestingKolmogorovSmirnovTestExample") + + # $example on$ + parallelData = sc.parallelize([0.1, 0.15, 0.2, 0.3, 0.25]) + + # run a KS test for the sample versus a standard normal distribution + testResult = Statistics.kolmogorovSmirnovTest(parallelData, "norm", 0, 1) + # summary of the test including the p-value, test statistic, and null hypothesis + # if our p-value indicates significance, we can reject the null hypothesis + # Note that the Scala functionality of calling Statistics.kolmogorovSmirnovTest with + # a lambda to calculate the CDF is not made available in the Python API + print(testResult) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/k_means_example.py b/examples/src/main/python/mllib/k_means_example.py new file mode 100644 index 0000000000000..5c397e62ef10e --- /dev/null +++ b/examples/src/main/python/mllib/k_means_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from numpy import array +from math import sqrt +# $example off$ + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.clustering import KMeans, KMeansModel +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="KMeansExample") # SparkContext + + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/kmeans_data.txt") + parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')])) + + # Build the model (cluster the data) + clusters = KMeans.train(parsedData, 2, maxIterations=10, + runs=10, initializationMode="random") + + # Evaluate clustering by computing Within Set Sum of Squared Errors + def error(point): + center = clusters.centers[clusters.predict(point)] + return sqrt(sum([x**2 for x in (point - center)])) + + WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y) + print("Within Set Sum of Squared Error = " + str(WSSSE)) + + # Save and load model + clusters.save(sc, "target/org/apache/spark/PythonKMeansExample/KMeansModel") + sameModel = KMeansModel.load(sc, "target/org/apache/spark/PythonKMeansExample/KMeansModel") + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/kernel_density_estimation_example.py b/examples/src/main/python/mllib/kernel_density_estimation_example.py new file mode 100644 index 0000000000000..3e8f7241a4a1e --- /dev/null +++ b/examples/src/main/python/mllib/kernel_density_estimation_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.stat import KernelDensity +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="KernelDensityEstimationExample") # SparkContext + + # $example on$ + # an RDD of sample data + data = sc.parallelize([1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 6.0, 7.0, 8.0, 9.0, 9.0]) + + # Construct the density estimator with the sample data and a standard deviation for the Gaussian + # kernels + kd = KernelDensity() + kd.setSample(data) + kd.setBandwidth(3.0) + + # Find density estimates for the given values + densities = kd.estimate([-1.0, 2.0, 5.0]) + # $example off$ + + print(densities) + + sc.stop() diff --git a/examples/src/main/python/mllib/latent_dirichlet_allocation_example.py b/examples/src/main/python/mllib/latent_dirichlet_allocation_example.py new file mode 100644 index 0000000000000..2a1bef5f207b7 --- /dev/null +++ b/examples/src/main/python/mllib/latent_dirichlet_allocation_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.clustering import LDA, LDAModel +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="LatentDirichletAllocationExample") # SparkContext + + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/sample_lda_data.txt") + parsedData = data.map(lambda line: Vectors.dense([float(x) for x in line.strip().split(' ')])) + # Index documents with unique IDs + corpus = parsedData.zipWithIndex().map(lambda x: [x[1], x[0]]).cache() + + # Cluster the documents into three topics using LDA + ldaModel = LDA.train(corpus, k=3) + + # Output topics. Each is a distribution over words (matching word count vectors) + print("Learned topics (as distributions over vocab of " + str(ldaModel.vocabSize()) + + " words):") + topics = ldaModel.topicsMatrix() + for topic in range(3): + print("Topic " + str(topic) + ":") + for word in range(0, ldaModel.vocabSize()): + print(" " + str(topics[word][topic])) + + # Save and load model + ldaModel.save(sc, "target/org/apache/spark/PythonLatentDirichletAllocationExample/LDAModel") + sameModel = LDAModel\ + .load(sc, "target/org/apache/spark/PythonLatentDirichletAllocationExample/LDAModel") + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/linear_regression_with_sgd_example.py b/examples/src/main/python/mllib/linear_regression_with_sgd_example.py new file mode 100644 index 0000000000000..6fbaeff0cd5a0 --- /dev/null +++ b/examples/src/main/python/mllib/linear_regression_with_sgd_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Linear Regression With SGD Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonLinearRegressionWithSGDExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.replace(',', ' ').split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/ridge-data/lpsa.data") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData, iterations=100, step=0.00000001) + + # Evaluate the model on training data + valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + MSE = valuesAndPreds \ + .map(lambda (v, p): (v - p)**2) \ + .reduce(lambda x, y: x + y) / valuesAndPreds.count() + print("Mean Squared Error = " + str(MSE)) + + # Save and load model + model.save(sc, "target/tmp/pythonLinearRegressionWithSGDModel") + sameModel = LinearRegressionModel.load(sc, "target/tmp/pythonLinearRegressionWithSGDModel") + # $example off$ diff --git a/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py b/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py new file mode 100644 index 0000000000000..e030b74ba6b15 --- /dev/null +++ b/examples/src/main/python/mllib/logistic_regression_with_lbfgs_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Logistic Regression With LBFGS Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel +from pyspark.mllib.regression import LabeledPoint +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/sample_svm_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LogisticRegressionWithLBFGS.train(parsedData) + + # Evaluating the model on training data + labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) + print("Training Error = " + str(trainErr)) + + # Save and load model + model.save(sc, "target/tmp/pythonLogisticRegressionWithLBFGSModel") + sameModel = LogisticRegressionModel.load(sc, + "target/tmp/pythonLogisticRegressionWithLBFGSModel") + # $example off$ diff --git a/examples/src/main/python/mllib/multi_class_metrics_example.py b/examples/src/main/python/mllib/multi_class_metrics_example.py new file mode 100644 index 0000000000000..cd56b3c97c778 --- /dev/null +++ b/examples/src/main/python/mllib/multi_class_metrics_example.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiClassMetricsExample") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = MulticlassMetrics(predictionAndLabels) + + # Overall statistics + precision = metrics.precision() + recall = metrics.recall() + f1Score = metrics.fMeasure() + print("Summary Stats") + print("Precision = %s" % precision) + print("Recall = %s" % recall) + print("F1 Score = %s" % f1Score) + + # Statistics by class + labels = data.map(lambda lp: lp.label).distinct().collect() + for label in sorted(labels): + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) + + # Weighted stats + print("Weighted recall = %s" % metrics.weightedRecall) + print("Weighted precision = %s" % metrics.weightedPrecision) + print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) + print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) + print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_label_metrics_example.py b/examples/src/main/python/mllib/multi_label_metrics_example.py new file mode 100644 index 0000000000000..960ade6597379 --- /dev/null +++ b/examples/src/main/python/mllib/multi_label_metrics_example.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.evaluation import MultilabelMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiLabelMetricsExample") + # $example on$ + scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + + # Instantiate metrics object + metrics = MultilabelMetrics(scoreAndLabels) + + # Summary stats + print("Recall = %s" % metrics.recall()) + print("Precision = %s" % metrics.precision()) + print("F1 measure = %s" % metrics.f1Measure()) + print("Accuracy = %s" % metrics.accuracy) + + # Individual label stats + labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() + for label in labels: + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) + + # Micro stats + print("Micro precision = %s" % metrics.microPrecision) + print("Micro recall = %s" % metrics.microRecall) + print("Micro F1 measure = %s" % metrics.microF1Measure) + + # Hamming loss + print("Hamming loss = %s" % metrics.hammingLoss) + + # Subset accuracy + print("Subset accuracy = %s" % metrics.subsetAccuracy) + # $example off$ diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py index a2e7dacf25491..35724f7d6a92d 100644 --- a/examples/src/main/python/mllib/naive_bayes_example.py +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -17,9 +17,16 @@ """ NaiveBayes Example. + +Usage: + `spark-submit --master local[4] examples/src/main/python/mllib/naive_bayes_example.py` """ + from __future__ import print_function +import shutil + +from pyspark import SparkContext # $example on$ from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors @@ -40,7 +47,7 @@ def parseLine(line): # $example on$ data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) - # Split data aproximately into training (60%) and test (40%) + # Split data approximately into training (60%) and test (40%) training, test = data.randomSplit([0.6, 0.4], seed=0) # Train a naive Bayes model. @@ -49,8 +56,15 @@ def parseLine(line): # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p: (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + print('model accuracy {}'.format(accuracy)) # Save and load model - model.save(sc, "target/tmp/myNaiveBayesModel") - sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel") + output_dir = 'target/tmp/myNaiveBayesModel' + shutil.rmtree(output_dir, ignore_errors=True) + model.save(sc, output_dir) + sameModel = NaiveBayesModel.load(sc, output_dir) + predictionAndLabel = test.map(lambda p: (sameModel.predict(p.features), p.label)) + accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + print('sameModel accuracy {}'.format(accuracy)) + # $example off$ diff --git a/examples/src/main/python/mllib/normalizer_example.py b/examples/src/main/python/mllib/normalizer_example.py new file mode 100644 index 0000000000000..a4e028ca9af8b --- /dev/null +++ b/examples/src/main/python/mllib/normalizer_example.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.feature import Normalizer +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") # SparkContext + + # $example on$ + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + labels = data.map(lambda x: x.label) + features = data.map(lambda x: x.features) + + normalizer1 = Normalizer() + normalizer2 = Normalizer(p=float("inf")) + + # Each sample in data1 will be normalized using $L^2$ norm. + data1 = labels.zip(normalizer1.transform(features)) + + # Each sample in data2 will be normalized using $L^\infty$ norm. + data2 = labels.zip(normalizer2.transform(features)) + # $example off$ + + print("data1:") + for each in data1.collect(): + print(each) + + print("data2:") + for each in data2.collect(): + print(each) + + sc.stop() diff --git a/examples/src/main/python/mllib/power_iteration_clustering_example.py b/examples/src/main/python/mllib/power_iteration_clustering_example.py new file mode 100644 index 0000000000000..ca19c0ccb60c8 --- /dev/null +++ b/examples/src/main/python/mllib/power_iteration_clustering_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.clustering import PowerIterationClustering, PowerIterationClusteringModel +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PowerIterationClusteringExample") # SparkContext + + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/pic_data.txt") + similarities = data.map(lambda line: tuple([float(x) for x in line.split(' ')])) + + # Cluster the data into two classes using PowerIterationClustering + model = PowerIterationClustering.train(similarities, 2, 10) + + model.assignments().foreach(lambda x: print(str(x.id) + " -> " + str(x.cluster))) + + # Save and load model + model.save(sc, "target/org/apache/spark/PythonPowerIterationClusteringExample/PICModel") + sameModel = PowerIterationClusteringModel\ + .load(sc, "target/org/apache/spark/PythonPowerIterationClusteringExample/PICModel") + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_classification_example.py b/examples/src/main/python/mllib/random_forest_classification_example.py new file mode 100644 index 0000000000000..9e5a8dcaabb0e --- /dev/null +++ b/examples/src/main/python/mllib/random_forest_classification_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Classification Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import RandomForest, RandomForestModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonRandomForestClassificationExample") + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + # Note: Use larger numTrees in practice. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, + numTrees=3, featureSubsetStrategy="auto", + impurity='gini', maxDepth=4, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification forest model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myRandomForestClassificationModel") + sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py deleted file mode 100755 index 4cfdad868c66e..0000000000000 --- a/examples/src/main/python/mllib/random_forest_example.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Random Forest classification and regression using MLlib. - -Note: This example illustrates binary classification. - For information on multiclass classification, please refer to the decision_tree_runner.py - example. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import RandomForest -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - model = RandomForest.trainClassifier(trainingData, numClasses=2, - categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='gini', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification forest model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='variance', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ - .sum() / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression forest model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using RandomForest\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using RandomForest\n') - testRegression(trainingData, testData) - - sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_regression_example.py b/examples/src/main/python/mllib/random_forest_regression_example.py new file mode 100644 index 0000000000000..2e1be34c1a29a --- /dev/null +++ b/examples/src/main/python/mllib/random_forest_regression_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import RandomForest, RandomForestModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonRandomForestRegressionExample") + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + # Note: Use larger numTrees in practice. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, + numTrees=3, featureSubsetStrategy="auto", + impurity='variance', maxDepth=4, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression forest model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myRandomForestRegressionModel") + sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/ranking_metrics_example.py b/examples/src/main/python/mllib/ranking_metrics_example.py new file mode 100644 index 0000000000000..21333deded35d --- /dev/null +++ b/examples/src/main/python/mllib/ranking_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.recommendation import ALS, Rating +from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Ranking Metrics Example") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Read in the ratings data + lines = sc.textFile("data/mllib/sample_movielens_data.txt") + + def parseLine(line): + fields = line.split("::") + return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) + ratings = lines.map(lambda r: parseLine(r)) + + # Train a model on to predict user-product ratings + model = ALS.train(ratings, 10, 10, 0.01) + + # Get predicted ratings on all existing user-product pairs + testData = ratings.map(lambda p: (p.user, p.product)) + predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) + + ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) + scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) + + # Instantiate regression metrics to compare predicted and actual ratings + metrics = RegressionMetrics(scoreAndLabels) + + # Root mean squared error + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + # $example off$ diff --git a/examples/src/main/python/mllib/recommendation_example.py b/examples/src/main/python/mllib/recommendation_example.py new file mode 100644 index 0000000000000..00e683c3ae938 --- /dev/null +++ b/examples/src/main/python/mllib/recommendation_example.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Collaborative Filtering Classification Example. +""" +from __future__ import print_function + +from pyspark import SparkContext + +# $example on$ +from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonCollaborativeFilteringExample") + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/als/test.data") + ratings = data.map(lambda l: l.split(','))\ + .map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) + + # Build the recommendation model using Alternating Least Squares + rank = 10 + numIterations = 10 + model = ALS.train(ratings, rank, numIterations) + + # Evaluate the model on training data + testdata = ratings.map(lambda p: (p[0], p[1])) + predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) + ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) + MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() + print("Mean Squared Error = " + str(MSE)) + + # Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + # $example off$ diff --git a/examples/src/main/python/mllib/regression_metrics_example.py b/examples/src/main/python/mllib/regression_metrics_example.py new file mode 100644 index 0000000000000..a3a83aafd7a1f --- /dev/null +++ b/examples/src/main/python/mllib/regression_metrics_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Regression Metrics Example") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), + DenseVector([float(x.split(':')[1]) for x in values[1:]])) + + data = sc.textFile("data/mllib/sample_linear_regression_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData) + + # Get predictions + valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + + # Instantiate metrics object + metrics = RegressionMetrics(valuesAndPreds) + + # Squared Error + print("MSE = %s" % metrics.meanSquaredError) + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + + # Mean absolute error + print("MAE = %s" % metrics.meanAbsoluteError) + + # Explained variance + print("Explained variance = %s" % metrics.explainedVariance) + # $example off$ diff --git a/examples/src/main/python/mllib/standard_scaler_example.py b/examples/src/main/python/mllib/standard_scaler_example.py new file mode 100644 index 0000000000000..20a77a470850f --- /dev/null +++ b/examples/src/main/python/mllib/standard_scaler_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.feature import StandardScaler, StandardScalerModel +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") # SparkContext + + # $example on$ + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + label = data.map(lambda x: x.label) + features = data.map(lambda x: x.features) + + scaler1 = StandardScaler().fit(features) + scaler2 = StandardScaler(withMean=True, withStd=True).fit(features) + + # data1 will be unit variance. + data1 = label.zip(scaler1.transform(features)) + + # Without converting the features into dense vectors, transformation with zero mean will raise + # exception on sparse vector. + # data2 will be unit variance and zero mean. + data2 = label.zip(scaler2.transform(features.map(lambda x: Vectors.dense(x.toArray())))) + # $example off$ + + print("data1:") + for each in data1.collect(): + print(each) + + print("data2:") + for each in data2.collect(): + print(each) + + sc.stop() diff --git a/examples/src/main/python/mllib/stratified_sampling_example.py b/examples/src/main/python/mllib/stratified_sampling_example.py new file mode 100644 index 0000000000000..a13f8f08dd68b --- /dev/null +++ b/examples/src/main/python/mllib/stratified_sampling_example.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="StratifiedSamplingExample") # SparkContext + + # $example on$ + # an RDD of any key value pairs + data = sc.parallelize([(1, 'a'), (1, 'b'), (2, 'c'), (2, 'd'), (2, 'e'), (3, 'f')]) + + # specify the exact fraction desired from each key as a dictionary + fractions = {1: 0.1, 2: 0.6, 3: 0.3} + + approxSample = data.sampleByKey(False, fractions) + # $example off$ + + for each in approxSample.collect(): + print(each) + + sc.stop() diff --git a/examples/src/main/python/mllib/streaming_k_means_example.py b/examples/src/main/python/mllib/streaming_k_means_example.py new file mode 100644 index 0000000000000..e82509ad3ffb6 --- /dev/null +++ b/examples/src/main/python/mllib/streaming_k_means_example.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.clustering import StreamingKMeans +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StreamingKMeansExample") # SparkContext + ssc = StreamingContext(sc, 1) + + # $example on$ + # we make an input stream of vectors for training, + # as well as a stream of vectors for testing + def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(')')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + + return LabeledPoint(label, vec) + + trainingData = sc.textFile("data/mllib/kmeans_data.txt")\ + .map(lambda line: Vectors.dense([float(x) for x in line.strip().split(' ')])) + + testingData = sc.textFile("data/mllib/streaming_kmeans_data_test.txt").map(parse) + + trainingQueue = [trainingData] + testingQueue = [testingData] + + trainingStream = ssc.queueStream(trainingQueue) + testingStream = ssc.queueStream(testingQueue) + + # We create a model with random clusters and specify the number of clusters to find + model = StreamingKMeans(k=2, decayFactor=1.0).setRandomCenters(3, 1.0, 0) + + # Now register the streams for training and testing and start the job, + # printing the predicted cluster assignments on new data points as they arrive. + model.trainOn(trainingStream) + + result = model.predictOnValues(testingStream.map(lambda lp: (lp.label, lp.features))) + result.pprint() + + ssc.start() + ssc.stop(stopSparkContext=True, stopGraceFully=True) + # $example off$ + + print("Final centers: " + str(model.latestModel().centers)) diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py new file mode 100644 index 0000000000000..f600496867c11 --- /dev/null +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Streaming Linear Regression Example. +""" +from __future__ import print_function + +# $example on$ +import sys +# $example off$ + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +# $example off$ + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: streaming_linear_regression_example.py ", + file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") + ssc = StreamingContext(sc, 1) + + # $example on$ + def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + + trainingData = ssc.textFileStream(sys.argv[1]).map(parse).cache() + testData = ssc.textFileStream(sys.argv[2]).map(parse) + + numFeatures = 3 + model = StreamingLinearRegressionWithSGD() + model.setInitialWeights([0.0, 0.0, 0.0]) + + model.trainOn(trainingData) + print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + + ssc.start() + ssc.awaitTermination() + # $example off$ diff --git a/examples/src/main/python/mllib/summary_statistics_example.py b/examples/src/main/python/mllib/summary_statistics_example.py new file mode 100644 index 0000000000000..d55d1a2c2d0e1 --- /dev/null +++ b/examples/src/main/python/mllib/summary_statistics_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +import numpy as np + +from pyspark.mllib.stat import Statistics +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="SummaryStatisticsExample") # SparkContext + + # $example on$ + mat = sc.parallelize( + [np.array([1.0, 10.0, 100.0]), np.array([2.0, 20.0, 200.0]), np.array([3.0, 30.0, 300.0])] + ) # an RDD of Vectors + + # Compute column summary statistics. + summary = Statistics.colStats(mat) + print(summary.mean()) # a dense vector containing the mean value for each column + print(summary.variance()) # column-wise variance + print(summary.numNonzeros()) # number of nonzeros in each column + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/mllib/svm_with_sgd_example.py b/examples/src/main/python/mllib/svm_with_sgd_example.py new file mode 100644 index 0000000000000..309ab09cc375a --- /dev/null +++ b/examples/src/main/python/mllib/svm_with_sgd_example.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.classification import SVMWithSGD, SVMModel +from pyspark.mllib.regression import LabeledPoint +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonSVMWithSGDExample") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = [float(x) for x in line.split(' ')] + return LabeledPoint(values[0], values[1:]) + + data = sc.textFile("data/mllib/sample_svm_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = SVMWithSGD.train(parsedData, iterations=100) + + # Evaluating the model on training data + labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) + trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) + print("Training Error = " + str(trainErr)) + + # Save and load model + model.save(sc, "target/tmp/pythonSVMWithSGDModel") + sameModel = SVMModel.load(sc, "target/tmp/pythonSVMWithSGDModel") + # $example off$ diff --git a/examples/src/main/python/mllib/tf_idf_example.py b/examples/src/main/python/mllib/tf_idf_example.py new file mode 100644 index 0000000000000..c4d53333a95a9 --- /dev/null +++ b/examples/src/main/python/mllib/tf_idf_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.feature import HashingTF, IDF +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TFIDFExample") # SparkContext + + # $example on$ + # Load documents (one per line). + documents = sc.textFile("data/mllib/kmeans_data.txt").map(lambda line: line.split(" ")) + + hashingTF = HashingTF() + tf = hashingTF.transform(documents) + + # While applying HashingTF only needs a single pass to the data, applying IDF needs two passes: + # First to compute the IDF vector and second to scale the term frequencies by IDF. + tf.cache() + idf = IDF().fit(tf) + tfidf = idf.transform(tf) + + # spark.mllib's IDF implementation provides an option for ignoring terms + # which occur in less than a minimum number of documents. + # In such cases, the IDF for these terms is set to 0. + # This feature can be used by passing the minDocFreq value to the IDF constructor. + idfIgnore = IDF(minDocFreq=2).fit(tf) + tfidfIgnore = idf.transform(tf) + # $example off$ + + print("tfidf:") + for each in tfidf.collect(): + print(each) + + print("tfidfIgnore:") + for each in tfidfIgnore.collect(): + print(each) + + sc.stop() diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py index 40d1b887927e0..4e7d4f7610c24 100644 --- a/examples/src/main/python/mllib/word2vec.py +++ b/examples/src/main/python/mllib/word2vec.py @@ -16,7 +16,7 @@ # # This example uses text8 file from http://mattmahoney.net/dc/text8.zip -# The file was downloadded, unziped and split into multiple lines using +# The file was downloaded, unzipped and split into multiple lines using # # wget http://mattmahoney.net/dc/text8.zip # unzip text8.zip diff --git a/examples/src/main/python/mllib/word2vec_example.py b/examples/src/main/python/mllib/word2vec_example.py new file mode 100644 index 0000000000000..ad1090c77ee11 --- /dev/null +++ b/examples/src/main/python/mllib/word2vec_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.feature import Word2Vec +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="Word2VecExample") # SparkContext + + # $example on$ + inp = sc.textFile("data/mllib/sample_lda_data.txt").map(lambda row: row.split(" ")) + + word2vec = Word2Vec() + model = word2vec.fit(inp) + + synonyms = model.findSynonyms('1', 5) + + for word, cosine_distance in synonyms: + print("{}: {}".format(word, cosine_distance)) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index f6b0ecb02c100..b6c2916254056 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -30,7 +30,7 @@ lines = sc.textFile(sys.argv[1], 1) sortedCount = lines.flatMap(lambda x: x.split(' ')) \ .map(lambda x: (int(x), 1)) \ - .sortByKey(lambda x: x) + .sortByKey() # This is just a demo on how to bring all the sorted data back to a single node. # In reality, we wouldn't want to collect all the data to the driver node. output = sortedCount.collect() diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index ea20678b9acad..7097f7f4502bd 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -28,6 +28,7 @@ examples/src/main/python/streaming/direct_kafka_wordcount.py \ localhost:9092 test` """ +from __future__ import print_function import sys diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py deleted file mode 100644 index abf9c0e21d307..0000000000000 --- a/examples/src/main/python/streaming/mqtt_wordcount.py +++ /dev/null @@ -1,59 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" - A sample wordcount with MqttStream stream - Usage: mqtt_wordcount.py - - To run this in your local machine, you need to setup a MQTT broker and publisher first, - Mosquitto is one of the open source MQTT Brokers, see - http://mosquitto.org/ - Eclipse paho project provides number of clients and utilities for working with MQTT, see - http://www.eclipse.org/paho/#getting-started - - and then run the example - `$ bin/spark-submit --jars \ - external/mqtt-assembly/target/scala-*/spark-streaming-mqtt-assembly-*.jar \ - examples/src/main/python/streaming/mqtt_wordcount.py \ - tcp://localhost:1883 foo` -""" - -import sys - -from pyspark import SparkContext -from pyspark.streaming import StreamingContext -from pyspark.streaming.mqtt import MQTTUtils - -if __name__ == "__main__": - if len(sys.argv) != 3: - print >> sys.stderr, "Usage: mqtt_wordcount.py " - exit(-1) - - sc = SparkContext(appName="PythonStreamingMQTTWordCount") - ssc = StreamingContext(sc, 1) - - brokerUrl = sys.argv[1] - topic = sys.argv[2] - - lines = MQTTUtils.createStream(ssc, brokerUrl, topic) - counts = lines.flatMap(lambda line: line.split(" ")) \ - .map(lambda word: (word, 1)) \ - .reduceByKey(lambda a, b: a+b) - counts.pprint() - - ssc.start() - ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py new file mode 100644 index 0000000000000..b85517dfdd913 --- /dev/null +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Shows the most positive words in UTF8 encoded, '\n' delimited text directly received the network + every 5 seconds. The streaming data is joined with a static RDD of the AFINN word list + (http://neuro.imm.dtu.dk/wiki/AFINN) + + Usage: network_wordjoinsentiments.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/network_wordjoinsentiments.py \ + localhost 9999` +""" + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + + +def print_happiest_words(rdd): + top_list = rdd.take(5) + print("Happiest topics in the last 5 seconds (%d total):" % rdd.count()) + for tuple in top_list: + print("%s (%d happiness)" % (tuple[1], tuple[0])) + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: network_wordjoinsentiments.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments") + ssc = StreamingContext(sc, 5) + + # Read in the word-sentiment list and create a static RDD from it + word_sentiments_file_path = "data/streaming/AFINN-111.txt" + word_sentiments = ssc.sparkContext.textFile(word_sentiments_file_path) \ + .map(lambda line: tuple(line.split("\t"))) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + + word_counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a + b) + + # Determine the words with the highest sentiment values by joining the streaming RDD + # with the static RDD inside the transform() method and then multiplying + # the frequency of the words by its sentiment value + happiest_words = word_counts.transform(lambda rdd: word_sentiments.join(rdd)) \ + .map(lambda (word, tuple): (word, float(tuple[0]) * tuple[1])) \ + .map(lambda (word, happiness): (happiness, word)) \ + .transform(lambda rdd: rdd.sortByKey(False)) + + happiest_words.foreachRDD(print_happiest_words) + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index ac91f0a06b172..52b2639cdf55c 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -44,6 +44,20 @@ from pyspark.streaming import StreamingContext +# Get or register a Broadcast variable +def getWordBlacklist(sparkContext): + if ('wordBlacklist' not in globals()): + globals()['wordBlacklist'] = sparkContext.broadcast(["a", "b", "c"]) + return globals()['wordBlacklist'] + + +# Get or register an Accumulator +def getDroppedWordsCounter(sparkContext): + if ('droppedWordsCounter' not in globals()): + globals()['droppedWordsCounter'] = sparkContext.accumulator(0) + return globals()['droppedWordsCounter'] + + def createContext(host, port, outputPath): # If you do not see this printed, that means the StreamingContext has been loaded # from the new checkpoint @@ -60,8 +74,22 @@ def createContext(host, port, outputPath): wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) def echo(time, rdd): - counts = "Counts at time %s %s" % (time, rdd.collect()) + # Get or register the blacklist Broadcast + blacklist = getWordBlacklist(rdd.context) + # Get or register the droppedWordsCounter Accumulator + droppedWordsCounter = getDroppedWordsCounter(rdd.context) + + # Use blacklist to drop words and use droppedWordsCounter to count them + def filterFunc(wordCount): + if wordCount[0] in blacklist.value: + droppedWordsCounter.add(wordCount[1]) + False + else: + True + + counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect()) print(counts) + print("Dropped %d word(s) totally" % droppedWordsCounter.value) print("Appending to " + os.path.abspath(outputPath)) with open(outputPath, 'a') as f: f.write(counts + "\n") diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index da90c07dbd82f..1ba5e9fb78993 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -29,7 +29,6 @@ """ from __future__ import print_function -import os import sys from pyspark import SparkContext diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 16ef646b7c42e..f8bbc659c2ea7 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -44,13 +44,16 @@ ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") + # RDD with initial state (key, value) pairs + initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)]) + def updateFunc(new_values, last_sum): return sum(new_values) + (last_sum or 0) lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) running_counts = lines.flatMap(lambda line: line.split(" "))\ .map(lambda word: (word, 1))\ - .updateStateByKey(updateFunc) + .updateStateByKey(updateFunc, initialRDD=initialStateRDD) running_counts.pprint() diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 7bf5fb6ddfe29..3d61250d8b230 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -30,8 +30,8 @@ def generateGraph(): edges = set() while len(edges) < numEdges: - src = rand.randrange(0, numEdges) - dst = rand.randrange(0, numEdges) + src = rand.randrange(0, numVertices) + dst = rand.randrange(0, numVertices) if src != dst: edges.add((src, dst)) return edges diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 53b817144f6ac..62f60e57eebe6 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -35,7 +35,7 @@ printSchema(df) # Create a DataFrame from a JSON file path <- file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json") -peopleDF <- jsonFile(sqlContext, path) +peopleDF <- read.json(sqlContext, path) printSchema(peopleDF) # Register this DataFrame as a table. diff --git a/examples/src/main/r/ml.R b/examples/src/main/r/ml.R new file mode 100644 index 0000000000000..a0c903939cbbb --- /dev/null +++ b/examples/src/main/r/ml.R @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/sparkR examples/src/main/r/ml.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkContext and SQLContext +sc <- sparkR.init(appName="SparkR-ML-example") +sqlContext <- sparkRSQL.init(sc) + +# Train GLM of family 'gaussian' +training1 <- suppressWarnings(createDataFrame(sqlContext, iris)) +test1 <- training1 +model1 <- glm(Sepal_Length ~ Sepal_Width + Species, training1, family = "gaussian") + +# Model summary +summary(model1) + +# Prediction +predictions1 <- predict(model1, test1) +head(select(predictions1, "Sepal_Length", "prediction")) + +# Train GLM of family 'binomial' +training2 <- filter(training1, training1$Species != "setosa") +test2 <- training2 +model2 <- glm(Species ~ Sepal_Length + Sepal_Width, data = training2, family = "binomial") + +# Model summary +summary(model2) + +# Prediction (Currently the output of prediction for binomial GLM is the indexed label, +# we need to transform back to the original string label later) +predictions2 <- predict(model2, test2) +head(select(predictions2, "Species", "prediction")) + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index d812262fd87dc..af5a815f6ec76 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -21,16 +21,14 @@ package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: BroadcastTest [slices] [numElem] [broadcastAlgo] [blockSize] - */ + * Usage: BroadcastTest [slices] [numElem] [blockSize] + */ object BroadcastTest { def main(args: Array[String]) { - val bcName = if (args.length > 2) args(2) else "Http" - val blockSize = if (args.length > 3) args(3) else "4096" + val blockSize = if (args.length > 2) args(2) else "4096" val sparkConf = new SparkConf().setAppName("Broadcast Test") - .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory") .set("spark.broadcast.blockSize", blockSize) val sc = new SparkContext(sparkConf) @@ -44,7 +42,7 @@ object BroadcastTest { println("===========") val startTime = System.nanoTime val barr1 = sc.broadcast(arr1) - val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.size) + val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.length) // Collect the small RDD so we can print the observed sizes locally. observedSizes.collect().foreach(i => println(i)) println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6)) diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index d1b9b8d398dd8..ca4eea235683a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -16,22 +16,20 @@ */ // scalastyle:off println - // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer import java.util.Collections import org.apache.cassandra.hadoop.ConfigHelper -import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat import org.apache.cassandra.hadoop.cql3.CqlConfigHelper import org.apache.cassandra.hadoop.cql3.CqlOutputFormat +import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} - /* Need to create following keyspace and column family in cassandra before running this example Start CQL shell using ./bin/cqlsh and execute following commands @@ -80,7 +78,7 @@ object CassandraCQLTest { val InputColumnFamily = "ordercf" val OutputColumnFamily = "salecount" - val job = new Job() + val job = Job.getInstance() job.setInputFormatClass(classOf[CqlPagingInputFormat]) val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) @@ -108,9 +106,8 @@ object CassandraCQLTest { println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { - case (key, value) => { + case (key, value) => (ByteBufferUtil.string(value.get("prod_id")), ByteBufferUtil.toInt(value.get("quantity"))) - } } val aggregatedRDD = productSaleRDD.reduceByKey(_ + _) aggregatedRDD.collect().foreach { @@ -118,11 +115,10 @@ object CassandraCQLTest { } val casoutputCF = aggregatedRDD.map { - case (productId, saleCount) => { + case (productId, saleCount) => val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) (outKey, outVal) - } } casoutputCF.saveAsNewAPIHadoopFile( @@ -137,4 +133,3 @@ object CassandraCQLTest { } } // scalastyle:on println -// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 1e679bfb55343..eff840d36e8d4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -16,7 +16,6 @@ */ // scalastyle:off println -// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -24,9 +23,9 @@ import java.util.Arrays import java.util.SortedMap import org.apache.cassandra.db.IColumn +import org.apache.cassandra.hadoop.ColumnFamilyInputFormat import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat import org.apache.cassandra.hadoop.ConfigHelper -import org.apache.cassandra.hadoop.ColumnFamilyInputFormat import org.apache.cassandra.thrift._ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job @@ -59,7 +58,7 @@ object CassandraTest { val sc = new SparkContext(sparkConf) // Build the job configuration with ConfigHelper provided by Cassandra - val job = new Job() + val job = Job.getInstance() job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) val host: String = args(1) @@ -91,9 +90,8 @@ object CassandraTest { // Let us first get all the paragraphs from the retrieved rows val paraRdd = casRdd.map { - case (key, value) => { + case (key, value) => ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) - } } // Lets get the word count in paras @@ -104,7 +102,7 @@ object CassandraTest { } counts.map { - case (word, count) => { + case (word, count) => val colWord = new org.apache.cassandra.thrift.Column() colWord.setName(ByteBufferUtil.bytes("word")) colWord.setValue(ByteBufferUtil.bytes(word)) @@ -123,7 +121,6 @@ object CassandraTest { mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(1).column_or_supercolumn.setColumn(colCount) (outputkey, mutations) - } }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], classOf[ColumnFamilyOutputFormat], job.getConfiguration) @@ -131,7 +128,6 @@ object CassandraTest { } } // scalastyle:on println -// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index d651fe4d6ee75..7bf023667dcae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -22,20 +22,19 @@ import java.io.File import scala.io.Source._ -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.SparkContext._ +import org.apache.spark.{SparkConf, SparkContext} /** - * Simple test for reading and writing to a distributed - * file system. This example does the following: - * - * 1. Reads local file - * 2. Computes word count on local file - * 3. Writes local file to a DFS - * 4. Reads the file back from the DFS - * 5. Computes word count on the file using Spark - * 6. Compares the word count results - */ + * Simple test for reading and writing to a distributed + * file system. This example does the following: + * + * 1. Reads local file + * 2. Computes word count on local file + * 3. Writes local file to a DFS + * 4. Reads the file back from the DFS + * 5. Computes word count on the file using Spark + * 6. Compares the word count results + */ object DFSReadWriteTest { private var localFilePath: File = new File(".") @@ -88,7 +87,7 @@ object DFSReadWriteTest { def runLocalWordCount(fileContents: List[String]): Int = { fileContents.flatMap(_.split(" ")) .flatMap(_.split("\t")) - .filter(_.size > 0) + .filter(_.nonEmpty) .groupBy(w => w) .mapValues(_.size) .values @@ -119,7 +118,7 @@ object DFSReadWriteTest { val dfsWordCount = readFileRDD .flatMap(_.split(" ")) .flatMap(_.split("\t")) - .filter(_.size > 0) + .filter(_.nonEmpty) .map(w => (w, 1)) .countByKey() .values diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index bec61f3cd4296..d12ef642bd2cd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -22,11 +22,13 @@ import scala.collection.JavaConverters._ import org.apache.spark.util.Utils -/** Prints out environmental information, sleeps, and then exits. Made to - * test driver submission in the standalone scheduler. */ +/** + * Prints out environmental information, sleeps, and then exits. Made to + * test driver submission in the standalone scheduler. + */ object DriverSubmissionTest { def main(args: Array[String]) { - if (args.size < 1) { + if (args.length < 1) { println("Usage: DriverSubmissionTest ") System.exit(0) } diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index fa4a3afeecd19..4db229b5dec32 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -21,11 +21,10 @@ package org.apache.spark.examples import java.util.Random import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /** - * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] - */ + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object GroupByTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("GroupBy Test") diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 244742327a907..65d7489586062 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.hadoop.hbase.client.HBaseAdmin import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor, TableName} +import org.apache.hadoop.hbase.client.HBaseAdmin import org.apache.hadoop.hbase.mapreduce.TableInputFormat import org.apache.spark._ - object HBaseTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("HBaseTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index af5f216f28ba4..fa1010195551a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -104,16 +104,14 @@ object LocalALS { def main(args: Array[String]) { args match { - case Array(m, u, f, iters) => { + case Array(m, u, f, iters) => M = m.toInt U = u.toInt F = f.toInt ITERATIONS = iters.toInt - } - case _ => { + case _ => System.err.println("Usage: LocalALS ") System.exit(1) - } } showWarning() diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 9c8aae53cf48d..bec89f7c3dff0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import java.util.Random -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} /** * Logistic regression based classification. @@ -30,7 +30,7 @@ import breeze.linalg.{Vector, DenseVector} * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object LocalFileLR { - val D = 10 // Numer of dimensions + val D = 10 // Number of dimensions val rand = new Random(42) case class DataPoint(x: Vector[Double], y: Double) @@ -58,7 +58,7 @@ object LocalFileLR { val ITERATIONS = args(1).toInt // Initialize w to a random value - var w = DenseVector.fill(D){2 * rand.nextDouble - 1} + var w = DenseVector.fill(D) {2 * rand.nextDouble - 1} println("Initial w: " + w) for (i <- 1 to ITERATIONS) { diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index e7b28d38bdfc6..f8961847f3df2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -23,9 +23,7 @@ import java.util.Random import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import breeze.linalg.{Vector, DenseVector, squaredDistance} - -import org.apache.spark.SparkContext._ +import breeze.linalg.{squaredDistance, DenseVector, Vector} /** * K-means clustering. @@ -43,7 +41,7 @@ object LocalKMeans { def generateData: Array[DenseVector[Double]] = { def generatePoint(i: Int): DenseVector[Double] = { - DenseVector.fill(D){rand.nextDouble * R} + DenseVector.fill(D) {rand.nextDouble * R} } Array.tabulate(N)(generatePoint) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 4f6b092a59ca5..0baf6db607ad9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import java.util.Random -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} /** * Logistic regression based classification. @@ -41,7 +41,7 @@ object LocalLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { val y = if (i % 2 == 0) -1 else 1 - val x = DenseVector.fill(D){rand.nextGaussian + y * R} + val x = DenseVector.fill(D) {rand.nextGaussian + y * R} DataPoint(x, y) } Array.tabulate(N)(generatePoint) @@ -62,7 +62,7 @@ object LocalLR { val data = generateData // Initialize w to a random value - var w = DenseVector.fill(D){2 * rand.nextDouble - 1} + var w = DenseVector.fill(D) {2 * rand.nextDouble - 1} println("Initial w: " + w) for (i <- 1 to ITERATIONS) { diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index 3d923625f11b6..720d92fb9d029 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -20,9 +20,6 @@ package org.apache.spark.examples import scala.math.random -import org.apache.spark._ -import org.apache.spark.SparkContext._ - object LocalPi { def main(args: Array[String]) { var count = 0 diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index a80de10f4610a..c55b68e033964 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -19,7 +19,6 @@ package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 61ce9db914f9f..3eb0c2772337a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.rdd.RDD /** - * Usage: MultiBroadcastTest [slices] [numElem] - */ + * Usage: MultiBroadcastTest [slices] [numElem] + */ object MultiBroadcastTest { def main(args: Array[String]) { @@ -46,7 +46,7 @@ object MultiBroadcastTest { val barr1 = sc.broadcast(arr1) val barr2 = sc.broadcast(arr2) val observedSizes: RDD[(Int, Int)] = sc.parallelize(1 to 10, slices).map { _ => - (barr1.value.size, barr2.value.size) + (barr1.value.length, barr2.value.length) } // Collect the small RDD so we can print the observed sizes locally. observedSizes.collect().foreach(i => println(i)) diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 3b0b00fe4dd0a..ec07e6323ee9a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -21,11 +21,10 @@ package org.apache.spark.examples import java.util.Random import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /** - * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] - */ + * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] + */ object SimpleSkewedGroupByTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 719e2176fed3f..8e4c2b6229755 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -21,11 +21,10 @@ package org.apache.spark.examples import java.util.Random import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /** - * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] - */ + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object SkewedGroupByTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("GroupBy Test") @@ -39,7 +38,7 @@ object SkewedGroupByTest { val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random - // map output sizes lineraly increase from the 1st to the last + // map output sizes linearly increase from the 1st to the last numKVPairs = (1.0 * (p + 1) / numMappers * numKVPairs).toInt var arr1 = new Array[(Int, Array[Byte])](numKVPairs) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 69799b7c2bb30..4263680c6fde3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -58,7 +58,7 @@ object SparkALS { } def update(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = { - val U = us.size + val U = us.length val F = us(0).getDimension var XtX: RealMatrix = new Array2DRowRealMatrix(F, F) var Xty: RealVector = new ArrayRealVector(F) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 505ea5a4c7a85..7463b868ff19b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -22,12 +22,10 @@ import java.util.Random import scala.math.exp -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} import org.apache.hadoop.conf.Configuration import org.apache.spark._ -import org.apache.spark.scheduler.InputFormatInfo - /** * Logistic regression based classification. @@ -37,7 +35,7 @@ import org.apache.spark.scheduler.InputFormatInfo * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkHdfsLR { - val D = 10 // Numer of dimensions + val D = 10 // Number of dimensions val rand = new Random(42) case class DataPoint(x: Vector[Double], y: Double) @@ -74,16 +72,13 @@ object SparkHdfsLR { val sparkConf = new SparkConf().setAppName("SparkHdfsLR") val inputPath = args(0) val conf = new Configuration() - val sc = new SparkContext(sparkConf, - InputFormatInfo.computePreferredLocations( - Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) - )) + val sc = new SparkContext(sparkConf) val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint _).cache() + val points = lines.map(parsePoint).cache() val ITERATIONS = args(1).toInt // Initialize w to a random value - var w = DenseVector.fill(D){2 * rand.nextDouble - 1} + var w = DenseVector.fill(D) {2 * rand.nextDouble - 1} println("Initial w: " + w) for (i <- 1 to ITERATIONS) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index c56e1124ad415..d9f94a42b1a0b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -18,10 +18,9 @@ // scalastyle:off println package org.apache.spark.examples -import breeze.linalg.{Vector, DenseVector, squaredDistance} +import breeze.linalg.{squaredDistance, DenseVector, Vector} import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /** * K-means clustering. diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index d265c227f4ed2..acd8656b65a69 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -22,7 +22,7 @@ import java.util.Random import scala.math.exp -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} import org.apache.spark._ @@ -36,7 +36,7 @@ import org.apache.spark._ */ object SparkLR { val N = 10000 // Number of data points - val D = 10 // Numer of dimensions + val D = 10 // Number of dimensions val R = 0.7 // Scaling factor val ITERATIONS = 5 val rand = new Random(42) @@ -46,7 +46,7 @@ object SparkLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { val y = if (i % 2 == 0) -1 else 1 - val x = DenseVector.fill(D){rand.nextGaussian + y * R} + val x = DenseVector.fill(D) {rand.nextGaussian + y * R} DataPoint(x, y) } Array.tabulate(N)(generatePoint) @@ -71,7 +71,7 @@ object SparkLR { val points = sc.parallelize(generateData, numSlices).cache() // Initialize w to a random value - var w = DenseVector.fill(D){2 * rand.nextDouble - 1} + var w = DenseVector.fill(D) {2 * rand.nextDouble - 1} println("Initial w: " + w) for (i <- 1 to ITERATIONS) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 0fd79660dd196..2664ddbb87d23 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.SparkContext._ import org.apache.spark.{SparkConf, SparkContext} /** diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 95072071ccddb..fc7a1f859f602 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -18,11 +18,10 @@ // scalastyle:off println package org.apache.spark.examples -import scala.util.Random import scala.collection.mutable +import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /** * Transitive closure on a graph. diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala deleted file mode 100644 index cfbdae02212a5..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples - -import java.util.Random - -import scala.math.exp - -import breeze.linalg.{Vector, DenseVector} -import org.apache.hadoop.conf.Configuration - -import org.apache.spark._ -import org.apache.spark.scheduler.InputFormatInfo -import org.apache.spark.storage.StorageLevel - - -/** - * Logistic regression based classification. - * This example uses Tachyon to persist rdds during computation. - * - * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. - */ -object SparkTachyonHdfsLR { - val D = 10 // Numer of dimensions - val rand = new Random(42) - - def showWarning() { - System.err.println( - """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS - |for more conventional use. - """.stripMargin) - } - - case class DataPoint(x: Vector[Double], y: Double) - - def parsePoint(line: String): DataPoint = { - val tok = new java.util.StringTokenizer(line, " ") - var y = tok.nextToken.toDouble - var x = new Array[Double](D) - var i = 0 - while (i < D) { - x(i) = tok.nextToken.toDouble; i += 1 - } - DataPoint(new DenseVector(x), y) - } - - def main(args: Array[String]) { - - showWarning() - - val inputPath = args(0) - val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") - val conf = new Configuration() - val sc = new SparkContext(sparkConf, - InputFormatInfo.computePreferredLocations( - Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) - )) - val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint _).persist(StorageLevel.OFF_HEAP) - val ITERATIONS = args(1).toInt - - // Initialize w to a random value - var w = DenseVector.fill(D){2 * rand.nextDouble - 1} - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - val gradient = points.map { p => - p.x * (1 / (1 + exp(-p.y * (w.dot(p.x)))) - 1) * p.y - }.reduce(_ + _) - w -= gradient - } - - println("Final w: " + w) - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala deleted file mode 100644 index e46ac655beb58..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples - -import scala.math.random - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -/** - * Computes an approximation to pi - * This example uses Tachyon to persist rdds during computation. - */ -object SparkTachyonPi { - def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("SparkTachyonPi") - val spark = new SparkContext(sparkConf) - - val slices = if (args.length > 0) args(0).toInt else 2 - val n = 100000 * slices - - val rdd = spark.parallelize(1 to n, slices) - rdd.persist(StorageLevel.OFF_HEAP) - val count = rdd.map { i => - val x = random * 2 - 1 - val y = random * 2 - 1 - if (x * x + y * y < 1) 1 else 0 - }.reduce(_ + _) - println("Pi is roughly " + 4.0 * count / n) - - spark.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 8dd6c9706e7df..619e585b6ca17 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -19,11 +19,13 @@ package org.apache.spark.examples.graphx import scala.collection.mutable + import org.apache.spark._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.PartitionStrategy._ +import org.apache.spark.graphx.lib._ +import org.apache.spark.internal.Logging +import org.apache.spark.storage.StorageLevel /** * Driver program for running graph algorithms. diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 46e52aacd90bb..6d2228c8742aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.graphx -import org.apache.spark.SparkContext._ +import java.io.{FileOutputStream, PrintWriter} + +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy} -import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.graphx.util.GraphGenerators -import java.io.{PrintWriter, FileOutputStream} /** * The SynthBenchmark application can be used to run various GraphX algorithms on diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala new file mode 100644 index 0000000000000..21f58ddf3cfb7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.regression.AFTSurvivalRegression +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext + +/** + * An example for AFTSurvivalRegression. + */ +object AFTSurvivalRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("AFTSurvivalRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val training = sqlContext.createDataFrame(Seq( + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226)) + )).toDF("label", "censor", "features") + val quantileProbabilities = Array(0.3, 0.6) + val aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + + val model = aft.fit(training) + + // Print the coefficients, intercept and scale parameter for AFT survival regression + println(s"Coefficients: ${model.coefficients} Intercept: " + + s"${model.intercept} Scale: ${model.scale}") + model.transform(training).show(false) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala new file mode 100644 index 0000000000000..a79e15c767e1f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.recommendation.ALS +// $example off$ +import org.apache.spark.sql.SQLContext +// $example on$ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType +// $example off$ + +object ALSExample { + + // $example on$ + case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) + object Rating { + def parseRating(str: String): Rating = { + val fields = str.split("::") + assert(fields.size == 4) + Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) + } + } + // $example off$ + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("ALSExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // $example on$ + val ratings = sc.textFile("data/mllib/als/sample_movielens_ratings.txt") + .map(Rating.parseRating) + .toDF() + val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2)) + + // Build the recommendation model using ALS on the training data + val als = new ALS() + .setMaxIter(5) + .setRegParam(0.01) + .setUserCol("userId") + .setItemCol("movieId") + .setRatingCol("rating") + val model = als.fit(training) + + // Evaluate the model by computing the RMSE on the test data + val predictions = model.transform(test) + .withColumn("rating", col("rating").cast(DoubleType)) + .withColumn("prediction", col("prediction").cast(DoubleType)) + + val evaluator = new RegressionEvaluator() + .setMetricName("rmse") + .setLabelCol("rating") + .setPredictionCol("prediction") + val rmse = evaluator.evaluate(predictions) + println(s"Root-mean-square error = $rmse") + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala new file mode 100644 index 0000000000000..2ed8101c133cf --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.Binarizer +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} + +object BinarizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BinarizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + + val binarizedDataFrame = binarizer.transform(dataFrame) + val binarizedFeatures = binarizedDataFrame.select("binarized_feature") + binarizedFeatures.collect().foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala new file mode 100644 index 0000000000000..6f6236a2b0588 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.Bucketizer +// $example off$ +import org.apache.spark.sql.SQLContext + +object BucketizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BucketizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + + val data = Array(-0.5, -0.3, 0.0, 0.2) + val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + + // Transform original data into its bucket index. + val bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala new file mode 100644 index 0000000000000..2be61537e613a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.ChiSqSelector +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext + +object ChiSqSelectorExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("ChiSqSelectorExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Seq( + (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + ) + + val df = sc.parallelize(data).toDF("id", "features", "clicked") + + val selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures") + + val result = selector.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala index ba916f66c4c07..7d07fc7dd113a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - object CountVectorizerExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 14b358d46f6ab..bca301d412f4c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala new file mode 100644 index 0000000000000..dc26b55a768a7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext + +object DCTExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DCTExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) + + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) + + val dctDf = dct.transform(df) + dctDf.select("featuresDCT").show(3) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala new file mode 100644 index 0000000000000..7e608a281203e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +/** + * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with + * {{{ + * ./bin/run-example ml.DataFrameExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DataFrameExample { + + case class Params(input: String = "data/mllib/sample_libsvm_data.txt") + extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DataFrameExample") { + head("DataFrameExample: an example app using DataFrame for ML.") + opt[String]("input") + .text(s"input path to dataframe") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DataFrameExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // Load input data + println(s"Loading LIBSVM file with UDT from ${params.input}.") + val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + println("Schema from LIBSVM:") + df.printSchema() + println(s"Loaded training data as a DataFrame with ${df.count()} records.") + + // Show statistical summary of labels. + val labelSummary = df.describe("label") + labelSummary.show() + + // Convert features column to an RDD of vectors. + val features = df.select("features").rdd.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + // Save the records in a parquet file. + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataframe").toString + println(s"Saving to $outputDir as Parquet file.") + df.write.parquet(outputDir) + + // Load the records back. + println(s"Loading Parquet file with UDT from $outputDir.") + val newDF = sqlContext.read.parquet(outputDir) + println(s"Schema from Parquet:") + newDF.printSchema() + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala new file mode 100644 index 0000000000000..224d8da5f0ec3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext + +object DecisionTreeClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] + println("Learned classification tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index f28671f7869fc..d2560cc00ba07 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -27,17 +27,13 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer} import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} -import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} import org.apache.spark.ml.util.MetadataUtils -import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} +import org.apache.spark.mllib.evaluation.{MulticlassMetrics, RegressionMetrics} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.StringType -import org.apache.spark.sql.{SQLContext, DataFrame} - +import org.apache.spark.sql.{DataFrame, SQLContext} /** * An example runner for decision trees. Run with @@ -138,15 +134,18 @@ object DecisionTreeExample { /** Load a dataset from the given path, using the given format */ private[ml] def loadData( - sc: SparkContext, + sqlContext: SQLContext, path: String, format: String, - expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = { + expectedNumFeatures: Option[Int] = None): DataFrame = { + import sqlContext.implicits._ + format match { - case "dense" => MLUtils.loadLabeledPoints(sc, path) + case "dense" => MLUtils.loadLabeledPoints(sqlContext.sparkContext, path).toDF() case "libsvm" => expectedNumFeatures match { - case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures) - case None => MLUtils.loadLibSVMFile(sc, path) + case Some(numFeatures) => sqlContext.read.option("numFeatures", numFeatures.toString) + .format("libsvm").load(path) + case None => sqlContext.read.format("libsvm").load(path) } case _ => throw new IllegalArgumentException(s"Bad data format: $format") } @@ -169,36 +168,22 @@ object DecisionTreeExample { algo: String, fracTest: Double): (DataFrame, DataFrame) = { val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // Load training data - val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat) + val origExamples: DataFrame = loadData(sqlContext, input, dataFormat) // Load or create test set - val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { + val dataframes: Array[DataFrame] = if (testInput != "") { // Load testInput. - val numFeatures = origExamples.take(1)(0).features.size - val origTestExamples: RDD[LabeledPoint] = - loadData(sc, testInput, dataFormat, Some(numFeatures)) + val numFeatures = origExamples.first().getAs[Vector](1).size + val origTestExamples: DataFrame = + loadData(sqlContext, testInput, dataFormat, Some(numFeatures)) Array(origExamples, origTestExamples) } else { // Split input into training, test. origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) } - // For classification, convert labels to Strings since we will index them later with - // StringIndexer. - def labelsToStrings(data: DataFrame): DataFrame = { - algo.toLowerCase match { - case "classification" => - data.withColumn("labelString", data("label").cast(StringType)) - case "regression" => - data - case _ => - throw new IllegalArgumentException("Algo ${params.algo} not supported.") - } - } - val dataframes = splits.map(_.toDF()).map(labelsToStrings) val training = dataframes(0).cache() val test = dataframes(1).cache() @@ -230,7 +215,7 @@ object DecisionTreeExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } @@ -325,8 +310,8 @@ object DecisionTreeExample { data: DataFrame, labelColName: String): Unit = { val fullPredictions = model.transform(data).cache() - val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) - val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) // Print number of classes for reference val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { case Some(n) => n @@ -350,8 +335,8 @@ object DecisionTreeExample { data: DataFrame, labelColName: String): Unit = { val fullPredictions = model.transform(data).cache() - val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) - val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError println(s" Root mean squared error (RMSE): $RMSE") } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala new file mode 100644 index 0000000000000..ad32e5635a3ea --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.regression.DecisionTreeRegressor +// $example off$ +import org.apache.spark.sql.SQLContext + +object DecisionTreeRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Here, we treat features with > 4 distinct values as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + + // Chain indexer and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, dt)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] + println("Learned regression tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 3758edc56198a..8d127f9b35420 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} /** * A simple example demonstrating how to write your own learning algorithm using Estimator, @@ -75,7 +75,7 @@ object DeveloperApiExample { prediction }.sum assert(sumPredictions == 0.0, - "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + "MyLogisticRegression predicted something other than 0, even though all coefficients are 0!") sc.stop() } @@ -120,16 +120,16 @@ private class MyLogisticRegression(override val uid: String) def setMaxIter(value: Int): this.type = set(maxIter, value) // This method is used by fit() - override protected def train(dataset: DataFrame): MyLogisticRegressionModel = { + override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = { // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset) - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. val numFeatures = oldDataset.take(1)(0).features.size - val weights = Vectors.zeros(numFeatures) // Learning would happen here. + val coefficients = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(uid, weights).setParent(this) + new MyLogisticRegressionModel(uid, coefficients).setParent(this) } override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) @@ -142,7 +142,7 @@ private class MyLogisticRegression(override val uid: String) */ private class MyLogisticRegressionModel( override val uid: String, - val weights: Vector) + val coefficients: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -163,7 +163,7 @@ private class MyLogisticRegressionModel( * confidence for that label. */ override protected def predictRaw(features: Vector): Vector = { - val margin = BLAS.dot(features, weights) + val margin = BLAS.dot(features, coefficients) // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). Vectors.dense(-margin, margin) @@ -173,7 +173,7 @@ private class MyLogisticRegressionModel( override val numClasses: Int = 2 /** Number of features the model was trained on. */ - override val numFeatures: Int = weights.size + override val numFeatures: Int = coefficients.size /** * Create a copy of the model. @@ -182,7 +182,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) + copyValues(new MyLogisticRegressionModel(uid, coefficients), extra).setParent(parent) } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala new file mode 100644 index 0000000000000..629d322c4357f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext + +object ElementwiseProductExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Create some vector data; also works for sparse vectors + val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala new file mode 100644 index 0000000000000..65e3c365abb3f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object EstimatorTransformerParamExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("EstimatorTransformerParamExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training data from a list of (label, features) tuples. + val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) + )).toDF("label", "features") + + // Create a LogisticRegression instance. This instance is an Estimator. + val lr = new LogisticRegression() + // Print out the parameters, documentation, and any default values. + println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + // We may set parameters using setter methods. + lr.setMaxIter(10) + .setRegParam(0.01) + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model1 = lr.fit(training) + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + + // We may alternatively specify parameters using a ParamMap, + // which supports several methods for specifying parameters. + val paramMap = ParamMap(lr.maxIter -> 20) + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + + // One can also combine ParamMaps. + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name + val paramMapCombined = paramMap ++ paramMap2 + + // Now learn a new model using the paramMapCombined parameters. + // paramMapCombined overrides all parameters set earlier via lr.set* methods. + val model2 = lr.fit(training, paramMapCombined) + println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + + // Prepare test data. + val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) + )).toDF("label", "features") + + // Make predictions on test data using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + model2.transform(test) + .select("features", "label", "myProbability", "prediction") + .collect() + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index f4a15f806ea81..6b0be0f34e196 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -153,7 +153,7 @@ object GBTExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala new file mode 100644 index 0000000000000..cd62a803820cf --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext + +object GradientBoostedTreeClassifierExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreeClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a GBT model. + val gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and GBT in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] + println("Learned classification GBT model:\n" + gbtModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala new file mode 100644 index 0000000000000..b8cf9629bbdab --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +// $example off$ +import org.apache.spark.sql.SQLContext + +object GradientBoostedTreeRegressorExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreeRegressorExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a GBT model. + val gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + + // Chain indexer and GBT in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, gbt)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] + println("Learned regression GBT model:\n" + gbtModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala new file mode 100644 index 0000000000000..4cea09ba12656 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.{IndexToString, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext + +object IndexToStringExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("IndexToStringExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory") + + val converted = converter.transform(indexed) + converted.select("id", "originalCategory").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index 5ce38462d1181..7af011571f76e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -17,57 +17,54 @@ package org.apache.spark.examples.ml -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -import org.apache.spark.ml.clustering.KMeans -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} +// scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.{DataFrame, SQLContext} +// $example off$ /** * An example demonstrating a k-means clustering. * Run with * {{{ - * bin/run-example ml.KMeansExample + * bin/run-example ml.KMeansExample * }}} */ object KMeansExample { - final val FEATURES_COL = "features" - def main(args: Array[String]): Unit = { - if (args.length != 2) { - // scalastyle:off println - System.err.println("Usage: ml.KMeansExample ") - // scalastyle:on println - System.exit(1) - } - val input = args(0) - val k = args(1).toInt - // Creates a Spark context and a SQL context val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) + // $example on$ + // Crates a DataFrame + val dataset: DataFrame = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(0.0, 0.0, 0.0)), + (2, Vectors.dense(0.1, 0.1, 0.1)), + (3, Vectors.dense(0.2, 0.2, 0.2)), + (4, Vectors.dense(9.0, 9.0, 9.0)), + (5, Vectors.dense(9.1, 9.1, 9.1)), + (6, Vectors.dense(9.2, 9.2, 9.2)) + )).toDF("id", "features") // Trains a k-means model val kmeans = new KMeans() - .setK(k) - .setFeaturesCol(FEATURES_COL) + .setK(2) + .setFeaturesCol("features") + .setPredictionCol("prediction") val model = kmeans.fit(dataset) // Shows the result - // scalastyle:off println println("Final Centers: ") model.clusterCenters.foreach(println) - // scalastyle:on println + // $example off$ sc.stop() } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 0000000000000..f9ddac77090ec --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating a LDA of ML pipeline. + * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index b73299fb12d3f..25be87811da90 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -18,15 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.ml -import scala.collection.mutable import scala.language.reflectiveCalls import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.{Pipeline, PipelineStage} -import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.sql.DataFrame /** @@ -131,7 +129,7 @@ object LinearRegressionExample { println(s"Training time: $elapsedTime seconds") // Print the weights and intercept for linear regression. - println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + println(s"Weights: ${lirModel.coefficients} Intercept: ${lirModel.intercept}") println("Training data results:") DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala new file mode 100644 index 0000000000000..c7352b3e7ab9c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.regression.LinearRegression +// $example off$ +import org.apache.spark.sql.SQLContext + +object LinearRegressionWithElasticNetExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LinearRegressionWithElasticNetExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + + // $example on$ + // Load training data + val training = sqlCtx.read.format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt") + + val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // Print the coefficients and intercept for linear regression + println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") + + // Summarize the model over the training set and print out some metrics + val trainingSummary = lrModel.summary + println(s"numIterations: ${trainingSummary.totalIterations}") + println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}") + trainingSummary.residuals.show() + println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") + println(s"r2: ${trainingSummary.r2}") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 8e3760ddb50a9..a380c90662a50 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -125,7 +125,7 @@ object LogisticRegressionExample { val stages = new mutable.ArrayBuffer[PipelineStage]() val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol("indexedLabel") stages += labelIndexer @@ -149,7 +149,7 @@ object LogisticRegressionExample { val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] // Print the weights and intercept for logistic regression. - println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}") + println(s"Weights: ${lorModel.coefficients} Intercept: ${lorModel.intercept}") println("Training data results:") DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala new file mode 100644 index 0000000000000..04c60c0c1d067 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions.max + +object LogisticRegressionSummaryExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionSummaryExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + import sqlCtx.implicits._ + + // Load training data + val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // $example on$ + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier + // example + val trainingSummary = lrModel.summary + + // Obtain the objective per iteration. + val objectiveHistory = trainingSummary.objectiveHistory + objectiveHistory.foreach(loss => println(loss)) + + // Obtain the metrics useful to judge performance on test data. + // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a + // binary classification problem. + val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] + + // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. + val roc = binarySummary.roc + roc.show() + println(binarySummary.areaUnderROC) + + // Set the model threshold to maximize F-Measure + val fMeasure = binarySummary.fMeasureByThreshold + val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) + val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure) + .select("threshold").head().getDouble(0) + lrModel.setThreshold(bestThreshold) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala new file mode 100644 index 0000000000000..f632960f26ae5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression +// $example off$ +import org.apache.spark.sql.SQLContext + +object LogisticRegressionWithElasticNetExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionWithElasticNetExample") + val sc = new SparkContext(conf) + val sqlCtx = new SQLContext(sc) + + // $example on$ + // Load training data + val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + + // Fit the model + val lrModel = lr.fit(training) + + // Print the coefficients and intercept for logistic regression + println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala new file mode 100644 index 0000000000000..aafb5efd698e4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.MaxAbsScaler +// $example off$ +import org.apache.spark.sql.SQLContext + +object MaxAbsScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MaxAbsScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val scaler = new MaxAbsScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MaxAbsScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [-1, 1] + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala new file mode 100644 index 0000000000000..9a03f69f5af03 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler +// $example off$ +import org.apache.spark.sql.SQLContext + +object MinMaxScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MinMaxScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MinMaxScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [min, max]. + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala new file mode 100644 index 0000000000000..0331d6e7b35df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object ModelSelectionViaCrossValidationExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ModelSelectionViaCrossValidationExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training data from a list of (id, text, label) tuples. + val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) + )).toDF("id", "text", "label") + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + val paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) + .addGrid(lr.regParam, Array(0.1, 0.01)) + .build() + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. + // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric + // is areaUnderROC. + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + val cvModel = cv.fit(training) + + // Prepare test documents, which are unlabeled (id, text) tuples. + val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") + )).toDF("id", "text") + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + cvModel.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala new file mode 100644 index 0000000000000..5a95344f223df --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +// $example off$ +import org.apache.spark.sql.SQLContext + +object ModelSelectionViaTrainValidationSplitExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ModelSelectionViaTrainValidationSplitExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training and test data. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") + val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + + val lr = new LinearRegression() + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) + + // Run train validation split, and choose the best set of parameters. + val model = trainValidationSplit.fit(training) + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show() + // $example off$ + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala deleted file mode 100644 index 3ae53e57dbdb8..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.recommendation.ALS -import org.apache.spark.sql.{Row, SQLContext} - -/** - * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/). - * Run with - * {{{ - * bin/run-example ml.MovieLensALS - * }}} - */ -object MovieLensALS { - - case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) - - object Rating { - def parseRating(str: String): Rating = { - val fields = str.split("::") - assert(fields.size == 4) - Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) - } - } - - case class Movie(movieId: Int, title: String, genres: Seq[String]) - - object Movie { - def parseMovie(str: String): Movie = { - val fields = str.split("::") - assert(fields.size == 3) - Movie(fields(0).toInt, fields(1), fields(2).split("|")) - } - } - - case class Params( - ratings: String = null, - movies: String = null, - maxIter: Int = 10, - regParam: Double = 0.1, - rank: Int = 10, - numBlocks: Int = 10) extends AbstractParams[Params] - - def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("MovieLensALS") { - head("MovieLensALS: an example app for ALS on MovieLens data.") - opt[String]("ratings") - .required() - .text("path to a MovieLens dataset of ratings") - .action((x, c) => c.copy(ratings = x)) - opt[String]("movies") - .required() - .text("path to a MovieLens dataset of movies") - .action((x, c) => c.copy(movies = x)) - opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}") - .action((x, c) => c.copy(rank = x)) - opt[Int]("maxIter") - .text(s"max number of iterations, default: ${defaultParams.maxIter}") - .action((x, c) => c.copy(maxIter = x)) - opt[Double]("regParam") - .text(s"regularization parameter, default: ${defaultParams.regParam}") - .action((x, c) => c.copy(regParam = x)) - opt[Int]("numBlocks") - .text(s"number of blocks, default: ${defaultParams.numBlocks}") - .action((x, c) => c.copy(numBlocks = x)) - note( - """ - |Example command line to run this app: - | - | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \ - | examples/target/scala-*/spark-examples-*.jar \ - | --rank 10 --maxIter 15 --regParam 0.1 \ - | --movies data/mllib/als/sample_movielens_movies.txt \ - | --ratings data/mllib/als/sample_movielens_ratings.txt - """.stripMargin) - } - - parser.parse(args, defaultParams).map { params => - run(params) - } getOrElse { - System.exit(1) - } - } - - def run(params: Params) { - val conf = new SparkConf().setAppName(s"MovieLensALS with $params") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - - val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache() - - val numRatings = ratings.count() - val numUsers = ratings.map(_.userId).distinct().count() - val numMovies = ratings.map(_.movieId).distinct().count() - - println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") - - val splits = ratings.randomSplit(Array(0.8, 0.2), 0L) - val training = splits(0).cache() - val test = splits(1).cache() - - val numTraining = training.count() - val numTest = test.count() - println(s"Training: $numTraining, test: $numTest.") - - ratings.unpersist(blocking = false) - - val als = new ALS() - .setUserCol("userId") - .setItemCol("movieId") - .setRank(params.rank) - .setMaxIter(params.maxIter) - .setRegParam(params.regParam) - .setNumBlocks(params.numBlocks) - - val model = als.fit(training.toDF()) - - val predictions = model.transform(test.toDF()).cache() - - // Evaluate the model. - // TODO: Create an evaluator to compute RMSE. - val mse = predictions.select("rating", "prediction").rdd - .flatMap { case Row(rating: Float, prediction: Float) => - val err = rating.toDouble - prediction - val err2 = err * err - if (err2.isNaN) { - None - } else { - Some(err2) - } - }.mean() - val rmse = math.sqrt(mse) - println(s"Test RMSE = $rmse.") - - // Inspect false positives. - // Note: We reference columns in 2 ways: - // (1) predictions("movieId") lets us specify the movieId column in the predictions - // DataFrame, rather than the movieId column in the movies DataFrame. - // (2) $"userId" specifies the userId column in the predictions DataFrame. - // We could also write predictions("userId") but do not have to since - // the movies DataFrame does not have a column "userId." - val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF() - val falsePositives = predictions.join(movies) - .where((predictions("movieId") === movies("movieId")) - && ($"rating" <= 1) && ($"prediction" >= 4)) - .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction") - val numFalsePositives = falsePositives.count() - println(s"Found $numFalsePositives false positives") - if (numFalsePositives > 0) { - println(s"Example false positives:") - falsePositives.limit(100).collect().foreach(println) - } - - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala new file mode 100644 index 0000000000000..d7d1e82f6f849 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +// $example off$ +import org.apache.spark.sql.SQLContext + +/** + * An example for Multilayer Perceptron Classification. + */ +object MultilayerPerceptronClassifierExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultilayerPerceptronClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") + // Split the data into train and test + val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) + val train = splits(0) + val test = splits(1) + // specify layers for the neural network: + // input layer of size 4 (features), two intermediate of size 5 and 4 + // and output of size 3 (classes) + val layers = Array[Int](4, 5, 4, 3) + // create the trainer and set its parameters + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100) + // train the model + val model = trainer.fit(train) + // compute precision on the test set + val result = model.transform(test) + val predictionAndLabels = result.select("prediction", "label") + val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision") + println("Precision:" + evaluator.evaluate(predictionAndLabels)) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala new file mode 100644 index 0000000000000..77b913aaa3fa0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.NGram +// $example off$ +import org.apache.spark.sql.SQLContext + +object NGramExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NGramExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) + )).toDF("label", "words") + + val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") + val ngramDataFrame = ngram.transform(wordDataFrame) + ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala new file mode 100644 index 0000000000000..5ea1270c9781c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.{NaiveBayes} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +// $example off$ +import org.apache.spark.sql.SQLContext + +object NaiveBayesExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NaiveBayesExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a NaiveBayes model. + val model = new NaiveBayes() + .fit(trainingData) + + // Select example rows to display. + val predictions = model.transform(testData) + predictions.show() + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("precision") + val precision = evaluator.evaluate(predictions) + println("Precision:" + precision) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala new file mode 100644 index 0000000000000..6b33c16c74037 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.Normalizer +// $example off$ +import org.apache.spark.sql.SQLContext + +object NormalizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Normalize each Vector using $L^1$ norm. + val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) + + val l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + // Normalize each Vector using $L^\infty$ norm. + val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) + lInfNormData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala new file mode 100644 index 0000000000000..cb9fe65a85e86 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext + +object OneHotEncoderExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("OneHotEncoderExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec") + val encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index bab31f585b0ef..0b5d31c0ff90d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -22,14 +22,15 @@ import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} import scopt.OptionParser -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} +import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.DataFrame +// $example off$ import org.apache.spark.sql.SQLContext /** @@ -111,24 +112,23 @@ object OneVsRestExample { private def run(params: Params) { val conf = new SparkConf().setAppName(s"OneVsRestExample with $params") val sc = new SparkContext(conf) - val inputData = MLUtils.loadLibSVMFile(sc, params.input) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + // $example on$ + val inputData = sqlContext.read.format("libsvm").load(params.input) // compute the train/test split: if testInput is not provided use part of input. val data = params.testInput match { - case Some(t) => { + case Some(t) => // compute the number of features in the training set. - val numFeatures = inputData.first().features.size - val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures) - Array[RDD[LabeledPoint]](inputData, testData) - } - case None => { + val numFeatures = inputData.first().getAs[Vector](1).size + val testData = sqlContext.read.option("numFeatures", numFeatures.toString) + .format("libsvm").load(t) + Array[DataFrame](inputData, testData) + case None => val f = params.fracTest inputData.randomSplit(Array(1 - f, f), seed = 12345) - } } - val Array(train, test) = data.map(_.toDF().cache()) + val Array(train, test) = data.map(_.cache()) // instantiate the base classifier val classifier = new LogisticRegression() @@ -153,7 +153,7 @@ object OneVsRestExample { // evaluate the model val predictionsAndLabels = predictions.select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) + .rdd.map(row => (row.getDouble(0), row.getDouble(1))) val metrics = new MulticlassMetrics(predictionsAndLabels) @@ -173,6 +173,7 @@ object OneVsRestExample { println("label\tfpr") println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala new file mode 100644 index 0000000000000..535652ec6c793 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext + +object PCAExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) + val pcaDF = pca.transform(df) + val result = pcaDF.select("pcaFeatures") + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala new file mode 100644 index 0000000000000..6c29063626bac --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row +// $example off$ +import org.apache.spark.sql.SQLContext + +object PipelineExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PipelineExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Prepare training documents from a list of (id, text, label) tuples. + val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) + )).toDF("id", "text", "label") + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // Fit the pipeline to training documents. + val model = pipeline.fit(training) + + // Now we can optionally save the fitted pipeline to disk + model.write.overwrite().save("/tmp/spark-logistic-regression-model") + + // We can also save this unfit pipeline to disk + pipeline.write.overwrite().save("/tmp/unfit-lr-model") + + // And load it back in during production + val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") + + // Prepare test documents, which are unlabeled (id, text) tuples. + val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") + )).toDF("id", "text") + + // Make predictions on test documents. + model.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala new file mode 100644 index 0000000000000..3014008ea0ce4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext + +object PolynomialExpansionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PolynomialExpansionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + val polyDF = polynomialExpansion.transform(df) + polyDF.select("polyFeatures").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala new file mode 100644 index 0000000000000..e64e673a485ed --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.QuantileDiscretizer +// $example off$ +import org.apache.spark.sql.SQLContext + +object QuantileDiscretizerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("QuantileDiscretizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Array((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)) + val df = sc.parallelize(data).toDF("id", "hour") + + val discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3) + + val result = discretizer.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala new file mode 100644 index 0000000000000..bec831d51c581 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.RFormula +// $example off$ +import org.apache.spark.sql.SQLContext + +object RFormulaExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RFormulaExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) + )).toDF("id", "country", "hour", "clicked") + val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") + val output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala new file mode 100644 index 0000000000000..6c9b52cf259e6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext + +object RandomForestClassifierExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a RandomForest model. + val rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setNumTrees(10) + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and forest in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] + println("Learned classification forest model:\n" + rfModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 109178f4137b2..7a00d99dfe53d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -159,7 +159,7 @@ object RandomForestExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala new file mode 100644 index 0000000000000..4d2db017f346f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +// $example off$ +import org.apache.spark.sql.SQLContext + +object RandomForestRegressorExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestRegressorExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a RandomForest model. + val rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + + // Chain indexer and forest in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, rf)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] + println("Learned regression forest model:\n" + rfModel.toDebugString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala new file mode 100644 index 0000000000000..202925acadff2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.SQLTransformer +// $example off$ +import org.apache.spark.sql.SQLContext + +object SQLTransformerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SQLTransformerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + + sqlTrans.transform(df).show() + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala new file mode 100644 index 0000000000000..e3439677e78d6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.StandardScaler +// $example off$ +import org.apache.spark.sql.SQLContext + +object StandardScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + + // Compute summary statistics by fitting the StandardScaler. + val scalerModel = scaler.fit(dataFrame) + + // Normalize each feature to have unit standard deviation. + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala new file mode 100644 index 0000000000000..8199be12c155b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.StopWordsRemover +// $example off$ +import org.apache.spark.sql.SQLContext + +object StopWordsRemoverExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StopWordsRemoverExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + + val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) + )).toDF("id", "raw") + + remover.transform(dataSet).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala new file mode 100644 index 0000000000000..3f0e870c8dc6b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.StringIndexer +// $example off$ +import org.apache.spark.sql.SQLContext + +object StringIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StringIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + ).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + + val indexed = indexer.fit(df).transform(df) + indexed.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala index 40c33e4e7d44e..28115f939082e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object TfIdfExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala new file mode 100644 index 0000000000000..c667728d6326d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext + +object TokenizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TokenizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + + val tokenized = tokenizer.transform(sentenceDataFrame) + tokenized.select("words", "label").take(3).foreach(println) + val regexTokenized = regexTokenizer.transform(sentenceDataFrame) + regexTokenized.select("words", "label").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala index 1abdf219b1c00..fbba17eba6a2f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala @@ -17,12 +17,11 @@ package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} /** * A simple example demonstrating model selection using TrainValidationSplit. @@ -39,10 +38,9 @@ object TrainValidationSplitExample { val conf = new SparkConf().setAppName("TrainValidationSplitExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // Prepare training and test data. - val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) val lr = new LinearRegression() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala new file mode 100644 index 0000000000000..768a8c0690477 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext + +object VectorAssemblerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorAssemblerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + val output = assembler.transform(dataset) + println(output.select("features", "clicked").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala new file mode 100644 index 0000000000000..3bef37ba360b9 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.VectorIndexer +// $example off$ +import org.apache.spark.sql.SQLContext + +object VectorIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) + + val indexerModel = indexer.fit(data) + + val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet + println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + + // Create new column "indexed" with categorical values transformed to indices + val indexedData = indexerModel.transform(data) + indexedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala new file mode 100644 index 0000000000000..01377d80e7e5c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +// $example off$ +import org.apache.spark.sql.SQLContext + +object VectorSlicerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorSlicerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + + val dataRDD = sc.parallelize(data) + val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + + val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + + slicer.setIndices(Array(1)).setNames(Array("f3")) + // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + + val output = slicer.transform(dataset) + println(output.select("userFeatures", "features").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala index 631ab4c8efa0d..e77aa59ba32b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Word2Vec // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object Word2VecExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ca22ddafc3c48..11e18c9f040bc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.AssociationRules import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset // $example off$ -import org.apache.spark.{SparkConf, SparkContext} - object AssociationRulesExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 1a4016f76c2ad..2282bd2b7d680 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -24,8 +24,8 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater} /** * An example app for binary classification. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala new file mode 100644 index 0000000000000..ade33fc5090f9 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object BinaryClassificationMetricsExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("BinaryClassificationMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training) + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new BinaryClassificationMetrics(predictionAndLabels) + + // Precision by threshold + val precision = metrics.precisionByThreshold + precision.foreach { case (t, p) => + println(s"Threshold: $t, Precision: $p") + } + + // Recall by threshold + val recall = metrics.recallByThreshold + recall.foreach { case (t, r) => + println(s"Threshold: $t, Recall: $r") + } + + // Precision-Recall Curve + val PRC = metrics.pr + + // F-measure + val f1Score = metrics.fMeasureByThreshold + f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 1") + } + + val beta = 0.5 + val fScore = metrics.fMeasureByThreshold(beta) + f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 0.5") + } + + // AUPRC + val auPRC = metrics.areaUnderPR + println("Area under precision-recall curve = " + auPRC) + + // Compute thresholds used in ROC and PR curves + val thresholds = precision.map(_._1) + + // ROC Curve + val roc = metrics.roc + + // AUROC + val auROC = metrics.areaUnderROC + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala new file mode 100644 index 0000000000000..53d0b8fc208ef --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +// scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.clustering.BisectingKMeans +import org.apache.spark.mllib.linalg.{Vector, Vectors} +// $example off$ + +/** + * An example demonstrating a bisecting k-means clustering in spark.mllib. + * + * Run with + * {{{ + * bin/run-example mllib.BisectingKMeansExample + * }}} + */ +object BisectingKMeansExample { + + def main(args: Array[String]) { + val sparkConf = new SparkConf().setAppName("mllib.BisectingKMeansExample") + val sc = new SparkContext(sparkConf) + + // $example on$ + // Loads and parses data + def parse(line: String): Vector = Vectors.dense(line.split(" ").map(_.toDouble)) + val data = sc.textFile("data/mllib/kmeans_data.txt").map(parse).cache() + + // Clustering the data into 6 clusters by BisectingKMeans. + val bkm = new BisectingKMeans().setK(6) + val model = bkm.run(data) + + // Show the compute cost and the cluster centers + println(s"Compute Cost: ${model.computeCost(data)}") + model.clusterCenters.zipWithIndex.foreach { case (center, idx) => + println(s"Cluster Center ${idx}: ${center}") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/ChiSqSelectorExample.scala new file mode 100644 index 0000000000000..5e400b7d715b4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/ChiSqSelectorExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.ChiSqSelector +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object ChiSqSelectorExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("ChiSqSelectorExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load some data in libsvm format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Discretize data in 16 equal bins since ChiSqSelector requires categorical features + // Even though features are doubles, the ChiSqSelector treats each unique value as a category + val discretizedData = data.map { lp => + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor })) + } + // Create ChiSqSelector that will select top 50 of 692 features + val selector = new ChiSqSelector(50) + // Create ChiSqSelector model (selecting features) + val transformer = selector.fit(discretizedData) + // Filter the top 50 features from each feature vector + val filteredData = discretizedData.map { lp => + LabeledPoint(lp.label, transformer.transform(lp.features)) + } + // $example off$ + + println("filtered data: ") + filteredData.foreach(x => println(x)) + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index 026d4ecc6d10a..e003f35ed399f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -20,10 +20,9 @@ package org.apache.spark.examples.mllib import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for summarizing multivariate data from a file. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CorrelationsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CorrelationsExample.scala new file mode 100644 index 0000000000000..1202caf534e95 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CorrelationsExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.rdd.RDD +// $example off$ + +object CorrelationsExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("CorrelationsExample") + val sc = new SparkContext(conf) + + // $example on$ + val seriesX: RDD[Double] = sc.parallelize(Array(1, 2, 3, 3, 5)) // a series + // must have the same number of partitions and cardinality as seriesX + val seriesY: RDD[Double] = sc.parallelize(Array(11, 22, 33, 33, 555)) + + // compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a + // method is not specified, Pearson's method will be used by default. + val correlation: Double = Statistics.corr(seriesX, seriesY, "pearson") + println(s"Correlation is: $correlation") + + val data: RDD[Vector] = sc.parallelize( + Seq( + Vectors.dense(1.0, 10.0, 100.0), + Vectors.dense(2.0, 20.0, 200.0), + Vectors.dense(5.0, 33.0, 366.0)) + ) // note that each Vector is a row and not a column + + // calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method + // If a method is not specified, Pearson's method will be used by default. + val correlMatrix: Matrix = Statistics.corr(data, "pearson") + println(correlMatrix.toString) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index 69988cc1b9334..5ff3d3624257b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -20,10 +20,9 @@ package org.apache.spark.examples.mllib import scopt.OptionParser -import org.apache.spark.SparkContext._ +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} -import org.apache.spark.{SparkConf, SparkContext} /** * Compute the similar columns of a matrix, using cosine similarity. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala deleted file mode 100644 index dc13f82488af7..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.mllib - -import java.io.File - -import com.google.common.io.Files -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, DataFrame} - -/** - * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with - * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] - * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. - */ -object DatasetExample { - - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] - - def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using DataFrame as a Dataset for ML.") - opt[String]("input") - .text(s"input path to dataset") - .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) - checkConfig { params => - success - } - } - - parser.parse(args, defaultParams).map { params => - run(params) - }.getOrElse { - sys.exit(1) - } - } - - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DatasetExample with $params") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // for implicit conversions - - // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDF() - println(s"Inferred schema:\n${df.schema.prettyJson}") - println(s"Converted to DataFrame with ${df.count()} records") - - // Select columns - val labelsDf: DataFrame = df.select("label") - val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresDf: DataFrame = df.select("features") - val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } - val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") - - val tmpDir = Files.createTempDir() - tmpDir.deleteOnExit() - val outputDir = new File(tmpDir, "dataset").toString - println(s"Saving to $outputDir as Parquet file.") - df.write.parquet(outputDir) - - println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.read.parquet(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") - - sc.stop() - } - -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala new file mode 100644 index 0000000000000..c6c7c6f5e2ed8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object DecisionTreeClassificationExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "gini" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala new file mode 100644 index 0000000000000..9c8baed3b8668 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object DecisionTreeRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "variance" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index cc6bce3cb7c9c..ee811d3aa1015 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity} +import org.apache.spark.mllib.tree.{impurity, DecisionTree, RandomForest} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.util.MLUtils @@ -180,7 +180,7 @@ object DecisionTreeRunner { } // For classification, re-index classes if needed. val (examples, classIndexMap, numClasses) = algo match { - case Classification => { + case Classification => // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() val sortedClasses = classCounts.keys.toList.sorted @@ -209,7 +209,6 @@ object DecisionTreeRunner { println(s"$c\t$frac\t${classCounts(c)}") } (examples, classIndexMap, numClasses) - } case Regression => (origExamples, null, 0) case _ => @@ -225,7 +224,7 @@ object DecisionTreeRunner { case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } algo match { - case Classification => { + case Classification => // classCounts: class --> # examples in class val testExamples = { if (classIndexMap.isEmpty) { @@ -235,7 +234,6 @@ object DecisionTreeRunner { } } Array(examples, testExamples) - } case Regression => Array(examples, origTestExamples) } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index 1fce4ba7efd60..90b817b23e156 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -58,6 +58,12 @@ object DenseGaussianMixture { (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } + println("The membership value of each vector to all mixture components (first <= 100):") + val membership = clusters.predictSoft(data) + membership.take(100).foreach { x => + print(" " + x.mkString(",")) + } + println() println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/ElementwiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/ElementwiseProductExample.scala new file mode 100644 index 0000000000000..1e4e3543194e2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/ElementwiseProductExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +object ElementwiseProductExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + + // $example on$ + // Create some vector data; also works for sparse vectors + val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct(transformingVector) + + // Batch transform and per-row transform give the same results: + val transformedData = transformer.transform(data) + val transformedData2 = data.map(x => transformer.transform(x)) + // $example off$ + + println("transformedData: ") + transformedData.foreach(x => println(x)) + + println("transformedData2: ") + transformedData2.foreach(x => println(x)) + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 14b930550d554..a7a3eade04a0c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -20,8 +20,8 @@ package org.apache.spark.examples.mllib import scopt.OptionParser -import org.apache.spark.mllib.fpm.FPGrowth import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.fpm.FPGrowth /** * Example for mining frequent itemsets using FP-growth. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala new file mode 100644 index 0000000000000..b1b3a79d87ae1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GaussianMixtureExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.clustering.{GaussianMixture, GaussianMixtureModel} +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +object GaussianMixtureExample { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName("GaussianMixtureExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/gmm_data.txt") + val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))).cache() + + // Cluster the data into two classes using GaussianMixture + val gmm = new GaussianMixture().setK(2).run(parsedData) + + // Save and load model + gmm.save(sc, "target/org/apache/spark/GaussianMixtureExample/GaussianMixtureModel") + val sameModel = GaussianMixtureModel.load(sc, + "target/org/apache/spark/GaussianMixtureExample/GaussianMixtureModel") + + // output parameters of max-likelihood model + for (i <- 0 until gmm.k) { + println("weight=%f\nmu=%s\nsigma=\n%s\n" format + (gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma)) + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index e16a6bf033574..b0144ef533133 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -23,10 +23,9 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy} import org.apache.spark.util.Utils - /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala new file mode 100644 index 0000000000000..0ec2e11214e89 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object GradientBoostingClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreesClassificationExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a GradientBoostedTrees model. + // The defaultParams for Classification use LogLoss by default. + val boostingStrategy = BoostingStrategy.defaultParams("Classification") + boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. + boostingStrategy.treeStrategy.numClasses = 2 + boostingStrategy.treeStrategy.maxDepth = 5 + // Empty categoricalFeaturesInfo indicates all features are continuous. + boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() + + val model = GradientBoostedTrees.train(trainingData, boostingStrategy) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification GBT model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myGradientBoostingClassificationModel") + val sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingClassificationModel") + // $example off$ + } +} +// scalastyle:on println + + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala new file mode 100644 index 0000000000000..b87ba0defe695 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object GradientBoostingRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreesRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a GradientBoostedTrees model. + // The defaultParams for Regression use SquaredError by default. + val boostingStrategy = BoostingStrategy.defaultParams("Regression") + boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. + boostingStrategy.treeStrategy.maxDepth = 5 + // Empty categoricalFeaturesInfo indicates all features are continuous. + boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() + + val model = GradientBoostedTrees.train(trainingData, boostingStrategy) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression GBT model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myGradientBoostingRegressionModel") + val sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingRegressionModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala new file mode 100644 index 0000000000000..0d391a3637c07 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.test.ChiSqTestResult +import org.apache.spark.rdd.RDD +// $example off$ + +object HypothesisTestingExample { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName("HypothesisTestingExample") + val sc = new SparkContext(conf) + + // $example on$ + // a vector composed of the frequencies of events + val vec: Vector = Vectors.dense(0.1, 0.15, 0.2, 0.3, 0.25) + + // compute the goodness of fit. If a second vector to test against is not supplied + // as a parameter, the test runs against a uniform distribution. + val goodnessOfFitTestResult = Statistics.chiSqTest(vec) + // summary of the test including the p-value, degrees of freedom, test statistic, the method + // used, and the null hypothesis. + println(s"$goodnessOfFitTestResult\n") + + // a contingency matrix. Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) + val mat: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) + + // conduct Pearson's independence test on the input contingency matrix + val independenceTestResult = Statistics.chiSqTest(mat) + // summary of the test including the p-value, degrees of freedom + println(s"$independenceTestResult\n") + + val obs: RDD[LabeledPoint] = + sc.parallelize( + Seq( + LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 0.0)), + LabeledPoint(-1.0, Vectors.dense(-1.0, 0.0, -0.5) + ) + ) + ) // (feature, label) pairs. + + // The contingency table is constructed from the raw (feature, label) pairs and used to conduct + // the independence test. Returns an array containing the ChiSquaredTestResult for every feature + // against the label. + val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) + featureTestResults.zipWithIndex.foreach { case (k, v) => + println("Column " + (v + 1).toString + ":") + println(k) + } // summary of the test + // $example off$ + + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingKolmogorovSmirnovTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingKolmogorovSmirnovTestExample.scala new file mode 100644 index 0000000000000..840874cf3c2fe --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingKolmogorovSmirnovTestExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.rdd.RDD +// $example off$ + +object HypothesisTestingKolmogorovSmirnovTestExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("HypothesisTestingKolmogorovSmirnovTestExample") + val sc = new SparkContext(conf) + + // $example on$ + val data: RDD[Double] = sc.parallelize(Seq(0.1, 0.15, 0.2, 0.3, 0.25)) // an RDD of sample data + + // run a KS test for the sample versus a standard normal distribution + val testResult = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1) + // summary of the test including the p-value, test statistic, and null hypothesis if our p-value + // indicates significance, we can reject the null hypothesis. + println(testResult) + println() + + // perform a KS test using a cumulative distribution function of our making + val myCDF = Map(0.1 -> 0.2, 0.15 -> 0.6, 0.2 -> 0.05, 0.3 -> 0.05, 0.25 -> 0.1) + val testResult2 = Statistics.kolmogorovSmirnovTest(data, myCDF) + println(testResult2) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 52ac9ae7dd2d0..c4336639d7c0b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -18,14 +18,14 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object IsotonicRegressionExample { - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("IsotonicRegressionExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala new file mode 100644 index 0000000000000..c4d71d862f375 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.clustering.{KMeans, KMeansModel} +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +object KMeansExample { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName("KMeansExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/kmeans_data.txt") + val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() + + // Cluster the data into two classes using KMeans + val numClusters = 2 + val numIterations = 20 + val clusters = KMeans.train(parsedData, numClusters, numIterations) + + // Evaluate clustering by computing Within Set Sum of Squared Errors + val WSSSE = clusters.computeCost(parsedData) + println("Within Set Sum of Squared Errors = " + WSSSE) + + // Save and load model + clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel") + val sameModel = KMeansModel.load(sc, "target/org/apache/spark/KMeansExample/KMeansModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KernelDensityEstimationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KernelDensityEstimationExample.scala new file mode 100644 index 0000000000000..cc5d159b36cc9 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KernelDensityEstimationExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.stat.KernelDensity +import org.apache.spark.rdd.RDD +// $example off$ + +object KernelDensityEstimationExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("KernelDensityEstimationExample") + val sc = new SparkContext(conf) + + // $example on$ + // an RDD of sample data + val data: RDD[Double] = sc.parallelize(Seq(1, 1, 1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9)) + + // Construct the density estimator with the sample data and a standard deviation + // for the Gaussian kernels + val kd = new KernelDensity() + .setSample(data) + .setBandwidth(3.0) + + // Find density estimates for the given values + val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) + // $example off$ + + densities.foreach(println) + + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala new file mode 100644 index 0000000000000..75a0419da5ec3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object LBFGSExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("LBFGSExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + val numFeatures = data.take(1)(0).features.size + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + + // Append 1 into the training data as intercept. + val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() + + val test = splits(1) + + // Run training algorithm to build the model + val numCorrections = 10 + val convergenceTol = 1e-4 + val maxNumIterations = 20 + val regParam = 0.1 + val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) + + val (weightsWithIntercept, loss) = LBFGS.runLBFGS( + training, + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept) + + val model = new LogisticRegressionModel( + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), + weightsWithIntercept(weightsWithIntercept.size - 1)) + + // Clear the default threshold. + model.clearThreshold() + + // Compute raw scores on the test set. + val scoreAndLabels = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + // Get evaluation metrics. + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val auROC = metrics.areaUnderROC() + + println("Loss of each step in training process") + loss.foreach(println) + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 75b0f69cf91aa..e89d555884dd0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,19 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import java.text.BreakIterator - -import scala.collection.mutable - -import scopt.OptionParser - import org.apache.log4j.{Level, Logger} +import scopt.OptionParser -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} +import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD - +import org.apache.spark.sql.{Row, SQLContext} /** * An example Latent Dirichlet Allocation (LDA) app. Run with @@ -121,7 +118,7 @@ object LDAExample { preprocess(sc, params.input, params.vocabSize, params.stopwordFile) corpus.cache() val actualCorpusSize = corpus.count() - val actualVocabSize = vocabArray.size + val actualVocabSize = vocabArray.length val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9 println() @@ -192,115 +189,46 @@ object LDAExample { vocabSize: Int, stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + // Get dataset of document texts // One document per line in each text file. If the input consists of many small files, // this can result in a large number of small partitions, which can degrade performance. // In this case, consider using coalesce() to create fewer, larger partitions. - val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) - - // Split text into words - val tokenizer = new SimpleTokenizer(sc, stopwordFile) - val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => - id -> tokenizer.getWords(text) - } - tokenized.cache() - - // Counts words: RDD[(word, wordCount)] - val wordCounts: RDD[(String, Long)] = tokenized - .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } - .reduceByKey(_ + _) - wordCounts.cache() - val fullVocabSize = wordCounts.count() - // Select vocab - // (vocab: Map[word -> id], total tokens after selecting vocab) - val (vocab: Map[String, Int], selectedTokenCount: Long) = { - val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocabSize) - } - (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) - } - - val documents = tokenized.map { case (id, tokens) => - // Filter tokens by vocabulary, and create word count vector representation of document. - val wc = new mutable.HashMap[Int, Int]() - tokens.foreach { term => - if (vocab.contains(term)) { - val termIndex = vocab(term) - wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 - } - } - val indices = wc.keys.toArray.sorted - val values = indices.map(i => wc(i).toDouble) - - val sb = Vectors.sparse(vocab.size, indices, values) - (id, sb) - } - - val vocabArray = new Array[String](vocab.size) - vocab.foreach { case (term, i) => vocabArray(i) = term } - - (documents, vocabArray, selectedTokenCount) - } -} - -/** - * Simple Tokenizer. - * - * TODO: Formalize the interface, and make this a public class in mllib.feature - */ -private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { - - private val stopwords: Set[String] = if (stopwordFile.isEmpty) { - Set.empty[String] - } else { - val stopwordText = sc.textFile(stopwordFile).collect() - stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet - } - - // Matches sequences of Unicode letters - private val allWordRegex = "^(\\p{L}*)$".r - - // Ignore words shorter than this length. - private val minWordLength = 3 - - def getWords(text: String): IndexedSeq[String] = { - - val words = new mutable.ArrayBuffer[String]() - - // Use Java BreakIterator to tokenize text into words. - val wb = BreakIterator.getWordInstance - wb.setText(text) - - // current,end index start,end of each word - var current = wb.first() - var end = wb.next() - while (end != BreakIterator.DONE) { - // Convert to lowercase - val word: String = text.substring(current, end).toLowerCase - // Remove short words and strings that aren't only letters - word match { - case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => - words += w - case _ => - } - - current = end - try { - end = wb.next() - } catch { - case e: Exception => - // Ignore remaining text in line. - // This is a known bug in BreakIterator (for some Java versions), - // which fails when it sees certain characters. - end = BreakIterator.DONE - } + val df = sc.textFile(paths.mkString(",")).toDF("docs") + val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) { + Array.empty[String] + } else { + val stopWordText = sc.textFile(stopwordFile).collect() + stopWordText.flatMap(_.stripMargin.split("\\s+")) } - words + val tokenizer = new RegexTokenizer() + .setInputCol("docs") + .setOutputCol("rawTokens") + val stopWordsRemover = new StopWordsRemover() + .setInputCol("rawTokens") + .setOutputCol("tokens") + stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords) + val countVectorizer = new CountVectorizer() + .setVocabSize(vocabSize) + .setInputCol("tokens") + .setOutputCol("features") + + val pipeline = new Pipeline() + .setStages(Array(tokenizer, stopWordsRemover, countVectorizer)) + + val model = pipeline.fit(df) + val documents = model.transform(df) + .select("features") + .rdd + .map { case Row(features: Vector) => features } + .zipWithIndex() + .map(_.swap) + + (documents, + model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary + documents.map(_._2.numActives).sum().toLong) // total token count } - } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala new file mode 100644 index 0000000000000..f2c8ec01439f1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA} +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +object LatentDirichletAllocationExample { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName("LatentDirichletAllocationExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/sample_lda_data.txt") + val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) + // Index documents with unique IDs + val corpus = parsedData.zipWithIndex.map(_.swap).cache() + + // Cluster the documents into three topics using LDA + val ldaModel = new LDA().setK(3).run(corpus) + + // Output topics. Each is a distribution over words (matching word count vectors) + println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") + val topics = ldaModel.topicsMatrix + for (topic <- Range(0, 3)) { + print("Topic " + topic + ":") + for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + println() + } + + // Save and load model. + ldaModel.save(sc, "target/org/apache/spark/LatentDirichletAllocationExample/LDAModel") + val sameModel = DistributedLDAModel.load(sc, + "target/org/apache/spark/LatentDirichletAllocationExample/LDAModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 8878061a0970b..f87611f5d4613 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -22,9 +22,9 @@ import org.apache.log4j.{Level, Logger} import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.optimization.{SimpleUpdater, SquaredL2Updater, L1Updater} /** * An example app for linear regression. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala new file mode 100644 index 0000000000000..669868787e8f0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +// $example off$ + +object LinearRegressionWithSGDExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LinearRegressionWithSGDExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/ridge-data/lpsa.data") + val parsedData = data.map { line => + val parts = line.split(',') + LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) + }.cache() + + // Building the model + val numIterations = 100 + val stepSize = 0.00000001 + val model = LinearRegressionWithSGD.train(parsedData, numIterations, stepSize) + + // Evaluate model on training examples and compute training error + val valuesAndPreds = parsedData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() + println("training Mean Squared Error = " + MSE) + + // Save and load model + model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") + val sameModel = LinearRegressionModel.load(sc, "target/tmp/scalaLinearRegressionWithSGDModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala new file mode 100644 index 0000000000000..632a2d537e5bc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object LogisticRegressionWithLBFGSExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("LogisticRegressionWithLBFGSExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0).cache() + val test = splits(1) + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(10) + .run(training) + + // Compute raw scores on the test set. + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Get evaluation metrics. + val metrics = new MulticlassMetrics(predictionAndLabels) + val precision = metrics.precision + println("Precision = " + precision) + + // Save and load model + model.save(sc, "target/tmp/scalaLogisticRegressionWithLBFGSModel") + val sameModel = LogisticRegressionModel.load(sc, + "target/tmp/scalaLogisticRegressionWithLBFGSModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 69691ae297f64..09750e53cb169 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -24,7 +24,6 @@ import org.apache.log4j.{Level, Logger} import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating} import org.apache.spark.rdd.RDD diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala new file mode 100644 index 0000000000000..c0d447bf69dd7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD +// $example off$ + +object MultiLabelMetricsExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultiLabelMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array.empty[Double], Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + + // Instantiate metrics object + val metrics = new MultilabelMetrics(scoreAndLabels) + + // Summary stats + println(s"Recall = ${metrics.recall}") + println(s"Precision = ${metrics.precision}") + println(s"F1 measure = ${metrics.f1Measure}") + println(s"Accuracy = ${metrics.accuracy}") + + // Individual label stats + metrics.labels.foreach(label => + println(s"Class $label precision = ${metrics.precision(label)}")) + metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) + metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + + // Micro stats + println(s"Micro recall = ${metrics.microRecall}") + println(s"Micro precision = ${metrics.microPrecision}") + println(s"Micro F1 measure = ${metrics.microF1Measure}") + + // Hamming loss + println(s"Hamming loss = ${metrics.hammingLoss}") + + // Subset accuracy + println(s"Subset accuracy = ${metrics.subsetAccuracy}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala new file mode 100644 index 0000000000000..4f925ede24d82 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object MulticlassMetricsExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MulticlassMetricsExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new MulticlassMetrics(predictionAndLabels) + + // Confusion matrix + println("Confusion matrix:") + println(metrics.confusionMatrix) + + // Overall Statistics + val precision = metrics.precision + val recall = metrics.recall // same as true positive rate + val f1Score = metrics.fMeasure + println("Summary Statistics") + println(s"Precision = $precision") + println(s"Recall = $recall") + println(s"F1 Score = $f1Score") + + // Precision by label + val labels = metrics.labels + labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) + } + + // Recall by label + labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) + } + + // False positive rate by label + labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) + } + + // F-measure by label + labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) + } + + // Weighted stats + println(s"Weighted precision: ${metrics.weightedPrecision}") + println(s"Weighted recall: ${metrics.weightedRecall}") + println(s"Weighted F1 score: ${metrics.weightedFMeasure}") + println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 5f839c75dd581..3c598172dadf0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -20,11 +20,10 @@ package org.apache.spark.examples.mllib import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for summarizing multivariate data from a file. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala index a7a47c2a3556a..0187ad603a654 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -18,16 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object NaiveBayesExample { - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("NaiveBayesExample") val sc = new SparkContext(conf) // $example on$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NormalizerExample.scala new file mode 100644 index 0000000000000..b3a9604c2be3e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NormalizerExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.Normalizer +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object NormalizerExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + + val normalizer1 = new Normalizer() + val normalizer2 = new Normalizer(p = Double.PositiveInfinity) + + // Each sample in data1 will be normalized using $L^2$ norm. + val data1 = data.map(x => (x.label, normalizer1.transform(x.features))) + + // Each sample in data2 will be normalized using $L^\infty$ norm. + val data2 = data.map(x => (x.label, normalizer2.transform(x.features))) + // $example off$ + + println("data1: ") + data1.foreach(x => println(x)) + + println("data2: ") + data2.foreach(x => println(x)) + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala new file mode 100644 index 0000000000000..f7a813695304f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD} +// $example off$ + +object PCAExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = sc.textFile("data/mllib/ridge-data/lpsa.data").map { line => + val parts = line.split(',') + LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) + }.cache() + + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0).cache() + val test = splits(1) + + val pca = new PCA(training.first().features.size / 2).fit(data.map(_.features)) + val training_pca = training.map(p => p.copy(features = pca.transform(p.features))) + val test_pca = test.map(p => p.copy(features = pca.transform(p.features))) + + val numIterations = 100 + val model = LinearRegressionWithSGD.train(training, numIterations) + val model_pca = LinearRegressionWithSGD.train(training_pca, numIterations) + + val valuesAndPreds = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + val valuesAndPreds_pca = test_pca.map { point => + val score = model_pca.predict(point.features) + (score, point.label) + } + + val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() + val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() + + println("Mean Squared Error = " + MSE) + println("PCA Mean Squared Error = " + MSE_pca) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala new file mode 100644 index 0000000000000..234de230eb201 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.RowMatrix +// $example off$ + +object PCAOnRowMatrixExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("PCAOnRowMatrixExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + + val dataRDD = sc.parallelize(data, 2) + + val mat: RowMatrix = new RowMatrix(dataRDD) + + // Compute the top 4 principal components. + // Principal components are stored in a local dense matrix. + val pc: Matrix = mat.computePrincipalComponents(4) + + // Project the rows to the linear space spanned by the top 4 principal components. + val projected: RowMatrix = mat.multiply(pc) + // $example off$ + val collect = projected.rows.collect() + println("Projected Row Matrix of principal component:") + collect.foreach { vector => println(vector) } + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala new file mode 100644 index 0000000000000..f7694879dfbdb --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnSourceVectorExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +// $example off$ + +object PCAOnSourceVectorExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("PCAOnSourceVectorExample") + val sc = new SparkContext(conf) + + // $example on$ + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)))) + + // Compute the top 5 principal components. + val pca = new PCA(5).fit(data.map(_.features)) + + // Project vectors to the linear space spanned by the top 5 principal + // components, keeping the label + val projected = data.map(p => p.copy(features = pca.transform(p.features))) + // $example off$ + val collect = projected.collect() + println("Projected vector of principal component:") + collect.foreach { vector => println(vector) } + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala new file mode 100644 index 0000000000000..d74d74a37fb11 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +object PMMLModelExportExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PMMLModelExportExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/kmeans_data.txt") + val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() + + // Cluster the data into two classes using KMeans + val numClusters = 2 + val numIterations = 20 + val clusters = KMeans.train(parsedData, numClusters, numIterations) + + // Export to PMML to a String in PMML format + println("PMML Model:\n" + clusters.toPMML) + + // Export the model to a local file in PMML format + clusters.toPMML("/tmp/kmeans.xml") + + // Export the model to a directory on a distributed file system in PMML format + clusters.toPMML(sc, "/tmp/kmeans") + + // Export the model to the OutputStream in PMML format + clusters.toPMML(System.out) + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 0723223954610..a81c9b383ddec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -21,9 +21,11 @@ package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ import org.apache.spark.mllib.clustering.PowerIterationClustering +// $example off$ import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext} /** * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. @@ -40,27 +42,23 @@ import org.apache.spark.{SparkConf, SparkContext} * n: Number of sampled points on innermost circle.. There are proportionally more points * within the outer/larger circles * maxIterations: Number of Power Iterations - * outerRadius: radius of the outermost of the concentric circles * }}} * * Here is a sample run and output: * - * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15 - * - * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14], - * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] + * ./bin/run-example mllib.PowerIterationClusteringExample -k 2 --n 10 --maxIterations 15 * + * Cluster assignments: 1 -> [0,1,2,3,4,5,6,7,8,9], + * 0 -> [10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] * * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ object PowerIterationClusteringExample { case class Params( - input: String = null, - k: Int = 3, - numPoints: Int = 5, - maxIterations: Int = 10, - outerRadius: Double = 3.0 + k: Int = 2, + numPoints: Int = 10, + maxIterations: Int = 15 ) extends AbstractParams[Params] def main(args: Array[String]) { @@ -69,7 +67,7 @@ object PowerIterationClusteringExample { val parser = new OptionParser[Params]("PowerIterationClusteringExample") { head("PowerIterationClusteringExample: an example PIC app using concentric circles.") opt[Int]('k', "k") - .text(s"number of circles (/clusters), default: ${defaultParams.k}") + .text(s"number of circles (clusters), default: ${defaultParams.k}") .action((x, c) => c.copy(k = x)) opt[Int]('n', "n") .text(s"number of points in smallest circle, default: ${defaultParams.numPoints}") @@ -77,9 +75,6 @@ object PowerIterationClusteringExample { opt[Int]("maxIterations") .text(s"number of iterations, default: ${defaultParams.maxIterations}") .action((x, c) => c.copy(maxIterations = x)) - opt[Double]('r', "r") - .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") - .action((x, c) => c.copy(outerRadius = x)) } parser.parse(args, defaultParams).map { params => @@ -97,22 +92,25 @@ object PowerIterationClusteringExample { Logger.getRootLogger.setLevel(Level.WARN) - val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints, params.outerRadius) + // $example on$ + val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints) val model = new PowerIterationClustering() .setK(params.k) .setMaxIterations(params.maxIterations) + .setInitializationMode("degree") .run(circlesRdd) val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id)) - val assignments = clusters.toList.sortBy { case (k, v) => v.length} + val assignments = clusters.toList.sortBy { case (k, v) => v.length } val assignmentsStr = assignments .map { case (k, v) => - s"$k -> ${v.sorted.mkString("[", ",", "]")}" - }.mkString(",") + s"$k -> ${v.sorted.mkString("[", ",", "]")}" + }.mkString(", ") val sizesStr = assignments.map { - _._2.size + _._2.length }.sorted.mkString("(", ",", ")") println(s"Cluster assignments: $assignmentsStr\ncluster sizes: $sizesStr") + // $example off$ sc.stop() } @@ -124,20 +122,17 @@ object PowerIterationClusteringExample { } } - def generateCirclesRdd(sc: SparkContext, - nCircles: Int = 3, - nPoints: Int = 30, - outerRadius: Double): RDD[(Long, Long, Double)] = { - - val radii = Array.tabulate(nCircles) { cx => outerRadius / (nCircles - cx)} - val groupSizes = Array.tabulate(nCircles) { cx => (cx + 1) * nPoints} - val points = (0 until nCircles).flatMap { cx => - generateCircle(radii(cx), groupSizes(cx)) + def generateCirclesRdd( + sc: SparkContext, + nCircles: Int, + nPoints: Int): RDD[(Long, Long, Double)] = { + val points = (1 to nCircles).flatMap { i => + generateCircle(i, i * nPoints) }.zipWithIndex val rdd = sc.parallelize(points) val distancesRdd = rdd.cartesian(rdd).flatMap { case (((x0, y0), i0), ((x1, y1), i1)) => if (i0 < i1) { - Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1), 1.0))) + Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1)))) } else { None } @@ -148,11 +143,9 @@ object PowerIterationClusteringExample { /** * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel */ - def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double): Double = { - val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma) - val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0) + def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double)): Double = { val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2) - coeff * math.exp(expCoeff * ssquares) + math.exp(-ssquares / 2.0) } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index d237232c430ca..ef86eab9e4ec5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.PrefixSpan // $example off$ -import org.apache.spark.{SparkConf, SparkContext} - object PrefixSpanExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala new file mode 100644 index 0000000000000..7805153ba7b95 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object RandomForestClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestClassificationExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val numTrees = 3 // Use more in practice. + val featureSubsetStrategy = "auto" // Let the algorithm choose. + val impurity = "gini" + val maxDepth = 4 + val maxBins = 32 + + val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification forest model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myRandomForestClassificationModel") + val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala new file mode 100644 index 0000000000000..655a277e28ae8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object RandomForestRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val numTrees = 3 // Use more in practice. + val featureSubsetStrategy = "auto" // Let the algorithm choose. + val impurity = "variance" + val maxDepth = 4 + val maxBins = 32 + + val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression forest model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myRandomForestRegressionModel") + val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestRegressionModel") + // $example off$ + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index bee85ba0f9969..7ccbb5a0640cd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -18,11 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for randomly generated RDDs. Run with * {{{ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala new file mode 100644 index 0000000000000..fdb01b86dd787 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.evaluation.{RankingMetrics, RegressionMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} +// $example off$ +import org.apache.spark.sql.SQLContext + +object RankingMetricsExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("RankingMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Read in the ratings data + val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) + }.cache() + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + val binarizedRatings = ratings.map(r => Rating(r.user, r.product, + if (r.rating > 0) 1.0 else 0.0)).cache() + + // Summarize ratings + val numRatings = ratings.count() + val numUsers = ratings.map(_.user).distinct().count() + val numMovies = ratings.map(_.product).distinct().count() + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + // Build the model + val numIterations = 10 + val rank = 10 + val lambda = 0.01 + val model = ALS.train(ratings, rank, numIterations, lambda) + + // Define a function to scale ratings from 0 to 1 + def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) + } + + // Get sorted top ten predictions for each user and then scale from [0, 1] + val userRecommended = model.recommendProductsForUsers(10).map { case (user, recs) => + (user, recs.map(scaledRating)) + } + + // Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document + // Compare with top ten most relevant documents + val userMovies = binarizedRatings.groupBy(_.user) + val relevantDocuments = userMovies.join(userRecommended).map { case (user, (actual, + predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) + } + + // Instantiate metrics object + val metrics = new RankingMetrics(relevantDocuments) + + // Precision at K + Array(1, 3, 5).foreach { k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") + } + + // Mean average precision + println(s"Mean average precision = ${metrics.meanAveragePrecision}") + + // Normalized discounted cumulative gain + Array(1, 3, 5).foreach { k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") + } + + // Get predictions for each data point + val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, + r.product), r.rating)) + val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) + val predictionsAndLabels = allPredictions.join(allRatings).map { case ((user, product), + (predicted, actual)) => + (predicted, actual) + } + + // Get the RMSE using regression metrics + val regressionMetrics = new RegressionMetrics(predictionsAndLabels) + println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${regressionMetrics.r2}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala new file mode 100644 index 0000000000000..bc946951aebf9 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel +import org.apache.spark.mllib.recommendation.Rating +// $example off$ + +object RecommendationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("CollaborativeFilteringExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/als/test.data") + val ratings = data.map(_.split(',') match { case Array(user, item, rate) => + Rating(user.toInt, item.toInt, rate.toDouble) + }) + + // Build the recommendation model using ALS + val rank = 10 + val numIterations = 10 + val model = ALS.train(ratings, rank, numIterations, 0.01) + + // Evaluate the model on rating data + val usersProducts = ratings.map { case Rating(user, product, rate) => + (user, product) + } + val predictions = + model.predict(usersProducts).map { case Rating(user, product, rate) => + ((user, product), rate) + } + val ratesAndPreds = ratings.map { case Rating(user, product, rate) => + ((user, product), rate) + }.join(predictions) + val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => + val err = (r1 - r2) + err * err + }.mean() + println("Mean Squared Error = " + MSE) + + // Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + val sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala new file mode 100644 index 0000000000000..add634c957b40 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println + +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.sql.SQLContext + +object RegressionMetricsExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RegressionMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + + // Build the model + val numIterations = 100 + val model = LinearRegressionWithSGD.train(data, numIterations) + + // Get predictions + val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) + } + + // Instantiate metrics object + val metrics = new RegressionMetrics(valuesAndPreds) + + // Squared error + println(s"MSE = ${metrics.meanSquaredError}") + println(s"RMSE = ${metrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${metrics.r2}") + + // Mean absolute error + println(s"MAE = ${metrics.meanAbsoluteError}") + + // Explained variance + println(s"Explained variance = ${metrics.explainedVariance}") + // $example off$ + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala new file mode 100644 index 0000000000000..c26580d4c1960 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.RowMatrix +// $example off$ + +object SVDExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("SVDExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + + val dataRDD = sc.parallelize(data, 2) + + val mat: RowMatrix = new RowMatrix(dataRDD) + + // Compute the top 5 singular values and corresponding singular vectors. + val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true) + val U: RowMatrix = svd.U // The U factor is a RowMatrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. + // $example off$ + val collect = U.rows.collect() + println("U factor is:") + collect.foreach { vector => println(vector) } + println(s"Singular values are: $s") + println(s"V factor is:\n$V") + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala new file mode 100644 index 0000000000000..b73fe9b2b3faa --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object SVMWithSGDExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("SVMWithSGDExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0).cache() + val test = splits(1) + + // Run training algorithm to build the model + val numIterations = 100 + val model = SVMWithSGD.train(training, numIterations) + + // Clear the default threshold. + model.clearThreshold() + + // Compute raw scores on the test set. + val scoreAndLabels = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + // Get evaluation metrics. + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val auROC = metrics.areaUnderROC() + + println("Area under ROC = " + auROC) + + // Save and load model + model.save(sc, "target/tmp/scalaSVMWithSGDModel") + val sameModel = SVMModel.load(sc, "target/tmp/scalaSVMWithSGDModel") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 6963f43e082c4..0da4005977d1a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -18,11 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.mllib.util.MLUtils import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.MLUtils /** * An example app for randomly generated and sampled RDDs. Run with @@ -79,7 +78,7 @@ object SampledRDDs { val sampledRDD = examples.sample(withReplacement = true, fraction = fraction) println(s" RDD.sample(): sample has ${sampledRDD.count()} examples") val sampledArray = examples.takeSample(withReplacement = true, num = expectedSampleSize) - println(s" RDD.takeSample(): sample has ${sampledArray.size} examples") + println(s" RDD.takeSample(): sample has ${sampledArray.length} examples") println() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b4e06afa7410f..ab15ac2c54d3b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.FPGrowth import org.apache.spark.rdd.RDD // $example off$ -import org.apache.spark.{SparkContext, SparkConf} - object SimpleFPGrowth { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StandardScalerExample.scala new file mode 100644 index 0000000000000..fc0aa1b7f0915 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StandardScalerExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.{StandardScaler, StandardScalerModel} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object StandardScalerExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + + val scaler1 = new StandardScaler().fit(data.map(x => x.features)) + val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features)) + // scaler3 is an identical model to scaler2, and will produce identical transformations + val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean) + + // data1 will be unit variance. + val data1 = data.map(x => (x.label, scaler1.transform(x.features))) + + // Without converting the features into dense vectors, transformation with zero mean will raise + // exception on sparse vector. + // data2 will be unit variance and zero mean. + val data2 = data.map(x => (x.label, scaler2.transform(Vectors.dense(x.features.toArray)))) + // $example off$ + + println("data1: ") + data1.foreach(x => println(x)) + + println("data2: ") + data2.foreach(x => println(x)) + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala new file mode 100644 index 0000000000000..16b074ef60699 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} + +object StratifiedSamplingExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("StratifiedSamplingExample") + val sc = new SparkContext(conf) + + // $example on$ + // an RDD[(K, V)] of any key value pairs + val data = sc.parallelize( + Seq((1, 'a'), (1, 'b'), (2, 'c'), (2, 'd'), (2, 'e'), (3, 'f'))) + + // specify the exact fraction desired from each key + val fractions = Map(1 -> 0.1, 2 -> 0.6, 3 -> 0.3) + + // Get an approximate sample from each stratum + val approxSample = data.sampleByKey(withReplacement = false, fractions = fractions) + // Get an exact sample from each stratum + val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions) + // $example off$ + + println("approxSample size is " + approxSample.collect().size.toString) + approxSample.collect().foreach(println) + + println("exactSample its size is " + exactSample.collect().size.toString) + exactSample.collect().foreach(println) + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala index af03724a8ac62..7888af79f87f4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeansExample.scala @@ -19,10 +19,12 @@ package org.apache.spark.examples.mllib import org.apache.spark.SparkConf +// $example on$ import org.apache.spark.mllib.clustering.StreamingKMeans import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.streaming.{Seconds, StreamingContext} +// $example off$ /** * Estimate clusters on one stream of data and make predictions @@ -58,7 +60,8 @@ object StreamingKMeansExample { System.exit(1) } - val conf = new SparkConf().setMaster("local").setAppName("StreamingKMeansExample") + // $example on$ + val conf = new SparkConf().setAppName("StreamingKMeansExample") val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse) @@ -74,6 +77,7 @@ object StreamingKMeansExample { ssc.start() ssc.awaitTermination() + // $example off$ } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index b4a5dca031abd..e5592966f13fa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -18,9 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.SparkConf import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} -import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} /** diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala new file mode 100644 index 0000000000000..0a1cd2d62d5b5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegressionExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +// $example on$ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +// $example off$ +import org.apache.spark.streaming._ + +object StreamingLinearRegressionExample { + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + System.err.println("Usage: StreamingLinearRegressionExample ") + System.exit(1) + } + + val conf = new SparkConf().setAppName("StreamingLinearRegressionExample") + val ssc = new StreamingContext(conf, Seconds(1)) + + // $example on$ + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse).cache() + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val numFeatures = 3 + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.zeros(numFeatures)) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + // $example off$ + + ssc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index b42f4cb5f9338..a8b144a197229 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.SparkConf +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD -import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} /** diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index ab29f90254d34..49f5df39443e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.SparkConf -import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.util.Utils @@ -64,8 +64,9 @@ object StreamingTestExample { dir.toString }) + // $example on$ val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { - case Array(label, value) => (label.toBoolean, value.toDouble) + case Array(label, value) => BinarySample(label.toBoolean, value.toDouble) }) val streamingTest = new StreamingTest() @@ -75,6 +76,7 @@ object StreamingTestExample { val out = streamingTest.registerStream(data) out.print() + // $example off$ // Stop processing if test becomes significant or we time out var timeoutCounter = numBatchesTimeout diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SummaryStatisticsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SummaryStatisticsExample.scala new file mode 100644 index 0000000000000..948b443c0a754 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SummaryStatisticsExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} +// $example off$ + +object SummaryStatisticsExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("SummaryStatisticsExample") + val sc = new SparkContext(conf) + + // $example on$ + val observations = sc.parallelize( + Seq( + Vectors.dense(1.0, 10.0, 100.0), + Vectors.dense(2.0, 20.0, 200.0), + Vectors.dense(3.0, 30.0, 300.0) + ) + ) + + // Compute column summary statistics. + val summary: MultivariateStatisticalSummary = Statistics.colStats(observations) + println(summary.mean) // a dense vector containing the mean value for each column + println(summary.variance) // column-wise variance + println(summary.numNonzeros) // number of nonzeros in each column + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TFIDFExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TFIDFExample.scala new file mode 100644 index 0000000000000..a5bdcd8f2ed32 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TFIDFExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.{HashingTF, IDF} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD +// $example off$ + +object TFIDFExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("TFIDFExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load documents (one per line). + val documents: RDD[Seq[String]] = sc.textFile("data/mllib/kmeans_data.txt") + .map(_.split(" ").toSeq) + + val hashingTF = new HashingTF() + val tf: RDD[Vector] = hashingTF.transform(documents) + + // While applying HashingTF only needs a single pass to the data, applying IDF needs two passes: + // First to compute the IDF vector and second to scale the term frequencies by IDF. + tf.cache() + val idf = new IDF().fit(tf) + val tfidf: RDD[Vector] = idf.transform(tf) + + // spark.mllib IDF implementation provides an option for ignoring terms which occur in less than + // a minimum number of documents. In such cases, the IDF for these terms is set to 0. + // This feature can be used by passing the minDocFreq value to the IDF constructor. + val idfIgnore = new IDF(minDocFreq = 2).fit(tf) + val tfidfIgnore: RDD[Vector] = idfIgnore.transform(tf) + // $example off$ + + println("tfidf: ") + tfidf.foreach(x => println(x)) + + println("tfidfIgnore: ") + tfidfIgnore.foreach(x => println(x)) + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Word2VecExample.scala new file mode 100644 index 0000000000000..ea794c700ae7e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Word2VecExample.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +// $example on$ +import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} +// $example off$ + +object Word2VecExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("Word2VecExample") + val sc = new SparkContext(conf) + + // $example on$ + val input = sc.textFile("data/mllib/sample_lda_data.txt").map(line => line.split(" ").toSeq) + + val word2vec = new Word2Vec() + + val model = word2vec.fit(input) + + val synonyms = model.findSynonyms("1", 5) + + for((synonym, cosineSimilarity) <- synonyms) { + println(s"$synonym $cosineSimilarity") + } + + // Save and load model + model.save(sc, "myModelPath") + val sameModel = Word2VecModel.load(sc, "myModelPath") + // $example off$ + + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 805184e740f06..cf12c98b4af6c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable { def unpackBytes(obj: Any): Array[Byte] = { val bytes: Array[Byte] = obj match { - case buf: java.nio.ByteBuffer => buf.array() + case buf: java.nio.ByteBuffer => + val arr = new Array[Byte](buf.remaining()) + buf.get(arr) + arr case arr: Array[Byte] => arr case other => throw new SparkException( s"Unknown BYTES type ${other.getClass.getName}") diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 0a25ee7ae56f4..e252ca882e534 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -20,12 +20,13 @@ package org.apache.spark.examples.pythonconverters import scala.collection.JavaConverters._ import scala.util.parsing.json.JSONObject -import org.apache.spark.api.python.Converter +import org.apache.hadoop.hbase.CellUtil +import org.apache.hadoop.hbase.KeyValue.Type import org.apache.hadoop.hbase.client.{Put, Result} import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.util.Bytes -import org.apache.hadoop.hbase.KeyValue.Type -import org.apache.hadoop.hbase.CellUtil + +import org.apache.spark.api.python.Converter /** * Implementation of [[org.apache.spark.api.python.Converter]] that converts all diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2cc56f04e5c1f..94b67cb29beb0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,8 +19,7 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SaveMode, SQLContext} // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -53,18 +52,18 @@ object RDDRelation { val rddFromSql = sqlContext.sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") - rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) + rddFromSql.rdd.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) - // Write out an RDD as a parquet file. - df.write.parquet("pair.parquet") + // Write out an RDD as a parquet file with overwrite mode. + df.write.mode(SaveMode.Overwrite).parquet("pair.parquet") - // Read in parquet file. Parquet files are self-describing so the schmema is preserved. + // Read in parquet file. Parquet files are self-describing so the schema is preserved. val parquetFile = sqlContext.read.parquet("pair.parquet") - // Queries can be run using the DSL on parequet files just like the original RDD. + // Queries can be run using the DSL on parquet files just like the original RDD. parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index bf40bd1ef13df..b654a2c8d4a40 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.sql.hive -import com.google.common.io.{ByteStreams, Files} - import java.io.File +import com.google.common.io.{ByteStreams, Files} + import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ import org.apache.spark.sql.hive.HiveContext @@ -63,7 +63,7 @@ object HiveFromSpark { val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") println("Result of RDD.map:") - val rddAsStrings = rddFromSql.map { + val rddAsStrings = rddFromSql.rdd.map { case Row(key: Int, value: String) => s"Key: $key, Value: $value" } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala deleted file mode 100644 index e9c9907198769..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import scala.collection.mutable.LinkedList -import scala.reflect.ClassTag -import scala.util.Random - -import akka.actor.{Actor, ActorRef, Props, actorRef2Scala} - -import org.apache.spark.{SparkConf, SecurityManager} -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions -import org.apache.spark.util.AkkaUtils -import org.apache.spark.streaming.receiver.ActorHelper - -case class SubscribeReceiver(receiverActor: ActorRef) -case class UnsubscribeReceiver(receiverActor: ActorRef) - -/** - * Sends the random content to every receiver subscribed with 1/2 - * second delay. - */ -class FeederActor extends Actor { - - val rand = new Random() - var receivers: LinkedList[ActorRef] = new LinkedList[ActorRef]() - - val strings: Array[String] = Array("words ", "may ", "count ") - - def makeMessage(): String = { - val x = rand.nextInt(3) - strings(x) + strings(2 - x) - } - - /* - * A thread to generate random messages - */ - new Thread() { - override def run() { - while (true) { - Thread.sleep(500) - receivers.foreach(_ ! makeMessage) - } - } - }.start() - - def receive: Receive = { - - case SubscribeReceiver(receiverActor: ActorRef) => - println("received subscribe from %s".format(receiverActor.toString)) - receivers = LinkedList(receiverActor) ++ receivers - - case UnsubscribeReceiver(receiverActor: ActorRef) => - println("received unsubscribe from %s".format(receiverActor.toString)) - receivers = receivers.dropWhile(x => x eq receiverActor) - - } -} - -/** - * A sample actor as receiver, is also simplest. This receiver actor - * goes and subscribe to a typical publisher/feeder actor and receives - * data. - * - * @see [[org.apache.spark.examples.streaming.FeederActor]] - */ -class SampleActorReceiver[T: ClassTag](urlOfPublisher: String) -extends Actor with ActorHelper { - - lazy private val remotePublisher = context.actorSelection(urlOfPublisher) - - override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self) - - def receive: PartialFunction[Any, Unit] = { - case msg => store(msg.asInstanceOf[T]) - } - - override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self) - -} - -/** - * A sample feeder actor - * - * Usage: FeederActor - * and describe the AkkaSystem that Spark Sample feeder would start on. - */ -object FeederActor { - - def main(args: Array[String]) { - if (args.length < 2){ - System.err.println("Usage: FeederActor \n") - System.exit(1) - } - val Seq(host, port) = args.toSeq - - val conf = new SparkConf - val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = conf, - securityManager = new SecurityManager(conf))._1 - val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") - - println("Feeder started as:" + feeder) - - actorSystem.awaitTermination() - } -} - -/** - * A sample word count program demonstrating the use of plugging in - * Actor as Receiver - * Usage: ActorWordCount - * and describe the AkkaSystem that Spark Sample feeder is running on. - * - * To run this example locally, you may run Feeder Actor as - * `$ bin/run-example org.apache.spark.examples.streaming.FeederActor 127.0.1.1 9999` - * and then run the example - * `$ bin/run-example org.apache.spark.examples.streaming.ActorWordCount 127.0.1.1 9999` - */ -object ActorWordCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println( - "Usage: ActorWordCount ") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Seq(host, port) = args.toSeq - val sparkConf = new SparkConf().setAppName("ActorWordCount") - // Create the context and set the batch size - val ssc = new StreamingContext(sparkConf, Seconds(2)) - - /* - * Following is the use of actorStream to plug in custom actor as receiver - * - * An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e type of data received and InputDstream - * should be same. - * - * For example: Both actorStream and SampleActorReceiver are parameterized - * to same type to ensure type safety. - */ - - val lines = ssc.actorStream[String]( - Props(new SampleActorReceiver[String]("akka.tcp://test@%s:%s/user/FeederActor".format( - host, port.toInt))), "SampleReceiver") - - // compute wordcount - lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 28e9bf520e568..1d144db9864bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -18,10 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.streaming -import java.io.{InputStreamReader, BufferedReader, InputStream} +import java.io.{BufferedReader, InputStreamReader} import java.net.Socket +import java.nio.charset.StandardCharsets -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.receiver.Receiver @@ -83,7 +85,8 @@ class CustomReceiver(host: String, port: Int) logInfo("Connecting to " + host + ":" + port) socket = new Socket(host, port) logInfo("Connected to " + host + ":" + port) - val reader = new BufferedReader(new InputStreamReader(socket.getInputStream(), "UTF-8")) + val reader = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() while(!isStopped && userInput != null) { store(userInput) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 2bdbc37e2a289..dd725d72c23ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -19,11 +19,9 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.flume._ import org.apache.spark.util.IntParam -import java.net.InetSocketAddress /** * Produces a count of events received from Flume. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index b40d17e9c2fa3..e7f9bf36e35cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -20,11 +20,11 @@ package org.apache.spark.examples.streaming import java.util.HashMap -import org.apache.kafka.clients.producer.{ProducerConfig, KafkaProducer, ProducerRecord} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka._ -import org.apache.spark.SparkConf /** * Consumes messages from one or more topics in Kafka and does wordcount. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala deleted file mode 100644 index d772ae309f40d..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence - -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.mqtt._ -import org.apache.spark.SparkConf - -/** - * A simple Mqtt publisher for demonstration purposes, repeatedly publishes - * Space separated String Message "hello mqtt demo for spark streaming" - */ -object MQTTPublisher { - - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: MQTTPublisher ") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Seq(brokerUrl, topic) = args.toSeq - - var client: MqttClient = null - - try { - val persistence = new MemoryPersistence() - client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) - - client.connect() - - val msgtopic = client.getTopic(topic) - val msgContent = "hello mqtt demo for spark streaming" - val message = new MqttMessage(msgContent.getBytes("utf-8")) - - while (true) { - try { - msgtopic.publish(message) - println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(10) - println("Queue is full, wait for to consume data from the message queue") - } - } - } catch { - case e: MqttException => println("Exception Caught: " + e) - } finally { - if (client != null) { - client.disconnect() - } - } - } -} - -/** - * A sample wordcount with MqttStream stream - * - * To work with Mqtt, Mqtt Message broker/server required. - * Mosquitto (http://mosquitto.org/) is an open source Mqtt Broker - * In ubuntu mosquitto can be installed using the command `$ sudo apt-get install mosquitto` - * Eclipse paho project provides Java library for Mqtt Client http://www.eclipse.org/paho/ - * Example Java code for Mqtt Publisher and Subscriber can be found here - * https://bitbucket.org/mkjinesh/mqttclient - * Usage: MQTTWordCount - * and describe where Mqtt publisher is running. - * - * To run this example locally, you may run publisher as - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.MQTTPublisher tcp://localhost:1883 foo` - * and run the example as - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.MQTTWordCount tcp://localhost:1883 foo` - */ -object MQTTWordCount { - - def main(args: Array[String]) { - if (args.length < 2) { - // scalastyle:off println - System.err.println( - "Usage: MQTTWordCount ") - // scalastyle:on println - System.exit(1) - } - - val Seq(brokerUrl, topic) = args.toSeq - val sparkConf = new SparkConf().setAppName("MQTTWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) - val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) - val words = lines.flatMap(x => x.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - - wordCounts.print() - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 9a57fe286d1ae..15b57fccb4076 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -19,8 +19,8 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf -import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext} /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala index 13ba9a43ec3c9..5455aed22085d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.streaming -import scala.collection.mutable.SynchronizedQueue +import scala.collection.mutable.Queue import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD @@ -34,7 +34,7 @@ object QueueStream { // Create the queue through which RDDs can be pushed to // a QueueInputDStream - val rddQueue = new SynchronizedQueue[RDD[Int]]() + val rddQueue = new Queue[RDD[Int]]() // Create the QueueInputDStream and use it do some processing val inputStream = ssc.queueStream(rddQueue) @@ -45,7 +45,9 @@ object QueueStream { // Create and push some RDDs into for (i <- 1 to 30) { - rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) + rddQueue.synchronized { + rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) + } Thread.sleep(1000) } ssc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 9916882e4f94a..bb2af9cd72e2a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -23,13 +23,55 @@ import java.nio.charset.Charset import com.google.common.io.Files -import org.apache.spark.SparkConf +import org.apache.spark.{Accumulator, SparkConf, SparkContext} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, Seconds, StreamingContext} +import org.apache.spark.streaming.{Seconds, StreamingContext, Time} import org.apache.spark.util.IntParam /** - * Counts words in text encoded with UTF8 received from the network every second. + * Use this singleton to get or register a Broadcast variable. + */ +object WordBlacklist { + + @volatile private var instance: Broadcast[Seq[String]] = null + + def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { + if (instance == null) { + synchronized { + if (instance == null) { + val wordBlacklist = Seq("a", "b", "c") + instance = sc.broadcast(wordBlacklist) + } + } + } + instance + } +} + +/** + * Use this singleton to get or register an Accumulator. + */ +object DroppedWordsCounter { + + @volatile private var instance: Accumulator[Long] = null + + def getInstance(sc: SparkContext): Accumulator[Long] = { + if (instance == null) { + synchronized { + if (instance == null) { + instance = sc.accumulator(0L, "WordsInBlacklistCounter") + } + } + } + instance + } +} + +/** + * Counts words in text encoded with UTF8 received from the network every second. This example also + * shows how to use lazily instantiated singleton instances for Accumulator and Broadcast so that + * they can be registered on driver failures. * * Usage: RecoverableNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive @@ -74,18 +116,32 @@ object RecoverableNetworkWordCount { val lines = ssc.socketTextStream(ip, port) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { - val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]") - println(counts) + wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => + // Get or register the blacklist Broadcast + val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the droppedWordsCounter Accumulator + val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) + // Use blacklist to drop words and use droppedWordsCounter to count them + val counts = rdd.filter { case (word, count) => + if (blacklist.value.contains(word)) { + droppedWordsCounter += count + false + } else { + true + } + }.collect().mkString("[", ", ", "]") + val output = "Counts at time " + time + " " + counts + println(output) + println("Dropped " + droppedWordsCounter.value + " word(s) totally") println("Appending to " + outputFile.getAbsolutePath) - Files.append(counts + "\n", outputFile, Charset.defaultCharset()) - }) + Files.append(output + "\n", outputFile, Charset.defaultCharset()) + } ssc } def main(args: Array[String]) { if (args.length != 4) { - System.err.println("You arguments were " + args.mkString("[", ", ", "]")) + System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index ed617754cbf1c..918e124065e4c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -21,10 +21,9 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, Seconds, StreamingContext} -import org.apache.spark.util.IntParam import org.apache.spark.sql.SQLContext import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, Time} /** * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the @@ -60,7 +59,7 @@ object SqlNetworkWordCount { val words = lines.flatMap(_.split(" ")) // Convert RDDs of the words DStream to DataFrame and run SQL query - words.foreachRDD((rdd: RDD[String], time: Time) => { + words.foreachRDD { (rdd: RDD[String], time: Time) => // Get the singleton instance of SQLContext val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) import sqlContext.implicits._ @@ -76,7 +75,7 @@ object SqlNetworkWordCount { sqlContext.sql("select word, count(*) as total from words group by word") println(s"========= $time =========") wordCountsDataFrame.show() - }) + } ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 02ba1c2eed0f7..2811e67009fb0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -19,7 +19,6 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf -import org.apache.spark.HashPartitioner import org.apache.spark.streaming._ /** @@ -44,24 +43,12 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.sum - - val previousCount = state.getOrElse(0) - - Some(currentCount + previousCount) - } - - val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - } - val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") - // Initial RDD input to updateStateByKey + // Initial state RDD for mapWithState operation val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) // Create a ReceiverInputDStream on target ip:port and count the @@ -70,10 +57,17 @@ object StatefulNetworkWordCount { val words = lines.flatMap(_.split(" ")) val wordDstream = words.map(x => (x, 1)) - // Update the cumulative count using updateStateByKey - // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, - new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) + // Update the cumulative count using mapWithState + // This will give a DStream made of state (which is the cumulative count of the words) + val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => { + val sum = one.getOrElse(0) + state.getOption.getOrElse(0) + val output = (word, sum) + state.update(sum) + output + } + + val stateDstream = wordDstream.mapWithState( + StateSpec.function(mappingFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala index 8396e65d0d588..b00f32fb25243 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala @@ -17,10 +17,10 @@ package org.apache.spark.examples.streaming -import org.apache.spark.Logging - import org.apache.log4j.{Level, Logger} +import org.apache.spark.internal.Logging + /** Utility functions for Spark Streaming examples. */ object StreamingExamples extends Logging { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala deleted file mode 100644 index 825c671a929b1..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import com.twitter.algebird._ -import com.twitter.algebird.CMSHasherImplicits._ - -import org.apache.spark.SparkConf -import org.apache.spark.SparkContext._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.twitter._ - -// scalastyle:off -/** - * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute - * windowed and global Top-K estimates of user IDs occurring in a Twitter stream. - *
    - * Note that since Algebird's implementation currently only supports Long inputs, - * the example operates on Long IDs. Once the implementation supports other inputs (such as String), - * the same approach could be used for computing popular topics for example. - *

    - *

    - * - * This blog post has a good overview of the Count-Min Sketch (CMS). The CMS is a data - * structure for approximate frequency estimation in data streams (e.g. Top-K elements, frequency - * of any given element, etc), that uses space sub-linear in the number of elements in the - * stream. Once elements are added to the CMS, the estimated count of an element can be computed, - * as well as "heavy-hitters" that occur more than a threshold percentage of the overall total - * count. - *

    - * Algebird's implementation is a monoid, so we can succinctly merge two CMS instances in the - * reduce operation. - */ -// scalastyle:on -object TwitterAlgebirdCMS { - def main(args: Array[String]) { - StreamingExamples.setStreamingLogLevels() - - // CMS parameters - val DELTA = 1E-3 - val EPS = 0.01 - val SEED = 1 - val PERC = 0.001 - // K highest frequency elements to take - val TOPK = 10 - - val filters = args - val sparkConf = new SparkConf().setAppName("TwitterAlgebirdCMS") - val ssc = new StreamingContext(sparkConf, Seconds(10)) - val stream = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_ONLY_SER_2) - - val users = stream.map(status => status.getUser.getId) - - // val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) - val cms = TopPctCMS.monoid[Long](EPS, DELTA, SEED, PERC) - var globalCMS = cms.zero - val mm = new MapMonoid[Long, Int]() - var globalExact = Map[Long, Int]() - - val approxTopUsers = users.mapPartitions(ids => { - ids.map(id => cms.create(id)) - }).reduce(_ ++ _) - - val exactTopUsers = users.map(id => (id, 1)) - .reduceByKey((a, b) => a + b) - - approxTopUsers.foreachRDD(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - val partialTopK = partial.heavyHitters.map(id => - (id, partial.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) - globalCMS ++= partial - val globalTopK = globalCMS.heavyHitters.map(id => - (id, globalCMS.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) - println("Approx heavy hitters at %2.2f%% threshold this batch: %s".format(PERC, - partialTopK.mkString("[", ",", "]"))) - println("Approx heavy hitters at %2.2f%% threshold overall: %s".format(PERC, - globalTopK.mkString("[", ",", "]"))) - } - }) - - exactTopUsers.foreachRDD(rdd => { - if (rdd.count() != 0) { - val partialMap = rdd.collect().toMap - val partialTopK = rdd.map( - {case (id, count) => (count, id)}) - .sortByKey(ascending = false).take(TOPK) - globalExact = mm.plus(globalExact.toMap, partialMap) - val globalTopK = globalExact.toSeq.sortBy(_._2).reverse.slice(0, TOPK) - println("Exact heavy hitters this batch: %s".format(partialTopK.mkString("[", ",", "]"))) - println("Exact heavy hitters overall: %s".format(globalTopK.mkString("[", ",", "]"))) - } - }) - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala deleted file mode 100644 index 49826ede70418..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import com.twitter.algebird.HyperLogLogMonoid -import com.twitter.algebird.HyperLogLog._ - -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.twitter._ -import org.apache.spark.SparkConf - -// scalastyle:off -/** - * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute - * a windowed and global estimate of the unique user IDs occurring in a Twitter stream. - *

    - *

    - * This - * blog post and this - * - * blog post - * have good overviews of HyperLogLog (HLL). HLL is a memory-efficient datastructure for - * estimating the cardinality of a data stream, i.e. the number of unique elements. - *

    - * Algebird's implementation is a monoid, so we can succinctly merge two HLL instances in the - * reduce operation. - */ -// scalastyle:on -object TwitterAlgebirdHLL { - def main(args: Array[String]) { - - StreamingExamples.setStreamingLogLevels() - - /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ - val BIT_SIZE = 12 - val filters = args - val sparkConf = new SparkConf().setAppName("TwitterAlgebirdHLL") - val ssc = new StreamingContext(sparkConf, Seconds(5)) - val stream = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_ONLY_SER) - - val users = stream.map(status => status.getUser.getId) - - val hll = new HyperLogLogMonoid(BIT_SIZE) - var globalHll = hll.zero - var userSet: Set[Long] = Set() - - val approxUsers = users.mapPartitions(ids => { - ids.map(id => hll(id)) - }).reduce(_ + _) - - val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) - - approxUsers.foreachRDD(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - globalHll += partial - println("Approx distinct users this batch: %d".format(partial.estimatedSize.toInt)) - println("Approx distinct users overall: %d".format(globalHll.estimatedSize.toInt)) - } - }) - - exactUsers.foreachRDD(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - userSet ++= partial - println("Exact distinct users this batch: %d".format(partial.size)) - println("Exact distinct users overall: %d".format(userSet.size)) - println("Error rate: %2.5f%%".format(((globalHll.estimatedSize / userSet.size.toDouble) - 1 - ) * 100)) - } - }) - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala deleted file mode 100644 index 49cee1b43c2dc..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.SparkContext._ -import org.apache.spark.streaming.twitter._ -import org.apache.spark.SparkConf - -/** - * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter - * stream. The stream is instantiated with credentials and optionally filters supplied by the - * command line arguments. - * - * Run this on your local machine as - * - */ -object TwitterPopularTags { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: TwitterPopularTags " + - " []") - System.exit(1) - } - - StreamingExamples.setStreamingLogLevels() - - val Array(consumerKey, consumerSecret, accessToken, accessTokenSecret) = args.take(4) - val filters = args.takeRight(args.length - 4) - - // Set the system properties so that Twitter4j library used by twitter stream - // can use them to generat OAuth credentials - System.setProperty("twitter4j.oauth.consumerKey", consumerKey) - System.setProperty("twitter4j.oauth.consumerSecret", consumerSecret) - System.setProperty("twitter4j.oauth.accessToken", accessToken) - System.setProperty("twitter4j.oauth.accessTokenSecret", accessTokenSecret) - - val sparkConf = new SparkConf().setAppName("TwitterPopularTags") - val ssc = new StreamingContext(sparkConf, Seconds(2)) - val stream = TwitterUtils.createStream(ssc, None, filters) - - val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) - - val topCounts60 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) - .map{case (topic, count) => (count, topic)} - .transform(_.sortByKey(false)) - - val topCounts10 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(10)) - .map{case (topic, count) => (count, topic)} - .transform(_.sortByKey(false)) - - - // Print popular hashtags - topCounts60.foreachRDD(rdd => { - val topList = rdd.take(10) - println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - }) - - topCounts10.foreachRDD(rdd => { - val topList = rdd.take(10) - println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - }) - - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala deleted file mode 100644 index 6ac9a72c37941..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.streaming - -import akka.actor.ActorSystem -import akka.actor.actorRef2Scala -import akka.zeromq._ -import akka.zeromq.Subscribe -import akka.util.ByteString - -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.zeromq._ - -import scala.language.implicitConversions -import org.apache.spark.SparkConf - -/** - * A simple publisher for demonstration purposes, repeatedly publishes random Messages - * every one second. - */ -object SimpleZeroMQPublisher { - - def main(args: Array[String]): Unit = { - if (args.length < 2) { - System.err.println("Usage: SimpleZeroMQPublisher ") - System.exit(1) - } - - val Seq(url, topic) = args.toSeq - val acs: ActorSystem = ActorSystem() - - val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - implicit def stringToByteString(x: String): ByteString = ByteString(x) - val messages: List[ByteString] = List("words ", "may ", "count ") - while (true) { - Thread.sleep(1000) - pubSocket ! ZMQMessage(ByteString(topic) :: messages) - } - acs.awaitTermination() - } -} - -// scalastyle:off -/** - * A sample wordcount with ZeroMQStream stream - * - * To work with zeroMQ, some native libraries have to be installed. - * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide] - * (http://www.zeromq.org/intro:get-the-software) - * - * Usage: ZeroMQWordCount - * and describe where zeroMq publisher is running. - * - * To run this example locally, you may run publisher as - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` - * and run the example as - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.ZeroMQWordCount tcp://127.0.1.1:1234 foo` - */ -// scalastyle:on -object ZeroMQWordCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: ZeroMQWordCount ") - System.exit(1) - } - StreamingExamples.setStreamingLogLevels() - val Seq(url, topic) = args.toSeq - val sparkConf = new SparkConf().setAppName("ZeroMQWordCount") - // Create the context and set the batch size - val ssc = new StreamingContext(sparkConf, Seconds(2)) - - def bytesToStringIterator(x: Seq[ByteString]): Iterator[String] = x.map(_.utf8String).iterator - - // For this stream, a zeroMQ publisher should be running. - val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - ssc.awaitTermination() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index bea7a47cb2855..0ddd065f0db2b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -18,49 +18,50 @@ // scalastyle:off println package org.apache.spark.examples.streaming.clickstream -import java.net.ServerSocket import java.io.PrintWriter -import util.Random +import java.net.ServerSocket +import java.util.Random /** Represents a page view on a website with associated dimension data. */ -class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) +class PageView(val url: String, val status: Int, val zipCode: Int, val userID: Int) extends Serializable { - override def toString() : String = { + override def toString(): String = { "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) } } object PageView extends Serializable { - def fromString(in : String) : PageView = { + def fromString(in: String): PageView = { val parts = in.split("\t") new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) } } // scalastyle:off -/** Generates streaming events to simulate page views on a website. - * - * This should be used in tandem with PageViewStream.scala. Example: - * - * To run the generator - * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` - * To process the generated stream - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` - * - */ +/** + * Generates streaming events to simulate page views on a website. + * + * This should be used in tandem with PageViewStream.scala. Example: + * + * To run the generator + * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` + * To process the generated stream + * `$ bin/run-example \ + * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` + * + */ // scalastyle:on object PageViewGenerator { - val pages = Map("http://foo.com/" -> .7, - "http://foo.com/news" -> 0.2, + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, "http://foo.com/contact" -> .1) val httpStatus = Map(200 -> .95, 404 -> .05) val userZipCode = Map(94709 -> .5, 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01) : _*) + val userID = Map((1 to 100).map(_ -> .01): _*) - def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { + def pickFromDistribution[T](inputMap: Map[T, Double]): T = { val rand = new Random().nextDouble() var total = 0.0 for ((item, prob) <- inputMap) { @@ -72,7 +73,7 @@ object PageViewGenerator { inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 } - def getNextClickEvent() : String = { + def getNextClickEvent(): String = { val id = pickFromDistribution(userID) val page = pickFromDistribution(pages) val status = pickFromDistribution(httpStatus) @@ -80,7 +81,7 @@ object PageViewGenerator { new PageView(page, status, zipCode, id).toString() } - def main(args : Array[String]) { + def main(args: Array[String]) { if (args.length != 2) { System.err.println("Usage: PageViewGenerator ") System.exit(1) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index ec7d39da8b2e9..1ba093f57b32c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -18,20 +18,21 @@ // scalastyle:off println package org.apache.spark.examples.streaming.clickstream -import org.apache.spark.SparkContext._ -import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.examples.streaming.StreamingExamples +import org.apache.spark.streaming.{Seconds, StreamingContext} + // scalastyle:off -/** Analyses a streaming dataset of web page views. This class demonstrates several types of - * operators available in Spark streaming. - * - * This should be used in tandem with PageViewStream.scala. Example: - * To run the generator - * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` - * To process the generated stream - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` - */ +/** + * Analyses a streaming dataset of web page views. This class demonstrates several types of + * operators available in Spark streaming. + * + * This should be used in tandem with PageViewStream.scala. Example: + * To run the generator + * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` + * To process the generated stream + * `$ bin/run-example \ + * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` + */ // scalastyle:on object PageViewStream { def main(args: Array[String]) { @@ -69,7 +70,7 @@ object PageViewStream { .groupByKey() val errorRatePerZipCode = statusesPerZipCode.map{ case(zip, statuses) => - val normalCount = statuses.filter(_ == 200).size + val normalCount = statuses.count(_ == 200) val errorCount = statuses.size - normalCount val errorRatio = errorCount.toFloat / statuses.size if (errorRatio > 0.05) { @@ -87,8 +88,10 @@ object PageViewStream { .map("Unique active users: " + _) // An external dataset we want to join to this stream - val userList = ssc.sparkContext.parallelize( - Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) + val userList = ssc.sparkContext.parallelize(Seq( + 1 -> "Patrick Wendell", + 2 -> "Reynold Xin", + 3 -> "Matei Zaharia")) metric match { case "pageCounts" => pageCounts.print() @@ -106,6 +109,7 @@ object PageViewStream { } ssc.start() + ssc.awaitTermination() } } // scalastyle:on println diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml new file mode 100644 index 0000000000000..17fd7d781c9ab --- /dev/null +++ b/external/docker-integration-tests/pom.xml @@ -0,0 +1,214 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + spark-docker-integration-tests_2.11 + jar + Spark Project Docker Integration Tests + http://spark.apache.org/ + + docker-integration-tests + + + + + db2 + https://app.camunda.com/nexus/content/repositories/public/ + + + + + + com.spotify + docker-client + shaded + test + + + + com.fasterxml.jackson.jaxrs + jackson-jaxrs-json-provider + + + com.fasterxml.jackson.datatype + jackson-datatype-guava + + + com.fasterxml.jackson.core + jackson-databind + + + org.glassfish.jersey.core + jersey-client + + + org.glassfish.jersey.connectors + jersey-apache-connector + + + org.glassfish.jersey.media + jersey-media-json-jackson + + + + + org.apache.httpcomponents + httpclient + 4.5 + test + + + org.apache.httpcomponents + httpcore + 4.4.1 + test + + + + com.google.guava + guava + 18.0 + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + ${project.version} + test + + + mysql + mysql-connector-java + test + + + org.postgresql + postgresql + test + + + + + + com.sun.jersey + jersey-server + 1.19 + test + + + com.sun.jersey + jersey-core + 1.19 + test + + + com.sun.jersey + jersey-servlet + 1.19 + test + + + com.sun.jersey + jersey-json + 1.19 + test + + + stax + stax-api + + + + + + + + com.ibm.db2.jcc + db2jcc4 + 10.5.0.5 + jar + + + diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala new file mode 100644 index 0000000000000..4fe1ef6697206 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Connection, Date, Timestamp} +import java.util.Properties + +import org.scalatest._ + +import org.apache.spark.tags.DockerTest + +@DockerTest +@Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker +class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "lresende/db2express-c:10.5.0.5-3.10.0" + override val env = Map( + "DB2INST1_PASSWORD" -> "rootpass", + "LICENSE" -> "accept" + ) + override val usesIpc = true + override val jdbcPort: Int = 50000 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:db2://$ip:$port/foo:user=db2inst1;password=rootpass;" + override def getStartupProcessName: Option[String] = Some("db2start") + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y VARCHAR(8))").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, " + + "e CHAR FOR BIT DATA)").executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps'") + .executeUpdate() + } + + test("Basic test") { + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) + } + + test("String types") { + val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df2.write.jdbc(jdbcUrl, "datescopy", new Properties) + df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala new file mode 100644 index 0000000000000..c36f4d5f95482 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.net.ServerSocket +import java.sql.Connection + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.spotify.docker.client._ +import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.DockerUtils + +abstract class DatabaseOnDocker { + /** + * The docker image to be pulled. + */ + val imageName: String + + /** + * Environment variables to set inside of the Docker container while launching it. + */ + val env: Map[String, String] + + /** + * Wheather or not to use ipc mode for shared memory when starting docker image + */ + val usesIpc: Boolean + + /** + * The container-internal JDBC port that the database listens on. + */ + val jdbcPort: Int + + /** + * Return a JDBC URL that connects to the database running at the given IP address and port. + */ + def getJdbcUrl(ip: String, port: Int): String + + /** + * Optional process to run when container starts + */ + def getStartupProcessName: Option[String] +} + +abstract class DockerJDBCIntegrationSuite + extends SparkFunSuite + with BeforeAndAfterAll + with Eventually + with SharedSQLContext { + + val db: DatabaseOnDocker + + private var docker: DockerClient = _ + private var containerId: String = _ + protected var jdbcUrl: String = _ + + override def beforeAll() { + super.beforeAll() + try { + docker = DefaultDockerClient.fromEnv.build() + // Check that Docker is actually up + try { + docker.ping() + } catch { + case NonFatal(e) => + log.error("Exception while connecting to Docker. Check whether Docker is running.") + throw e + } + // Ensure that the Docker image is installed: + try { + docker.inspectImage(db.imageName) + } catch { + case e: ImageNotFoundException => + log.warn(s"Docker image ${db.imageName} not found; pulling image from registry") + docker.pull(db.imageName) + } + // Configure networking (necessary for boot2docker / Docker Machine) + val externalPort: Int = { + val sock = new ServerSocket(0) + val port = sock.getLocalPort + sock.close() + port + } + val dockerIp = DockerUtils.getDockerIp() + val hostConfig: HostConfig = HostConfig.builder() + .networkMode("bridge") + .ipcMode(if (db.usesIpc) "host" else "") + .portBindings( + Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) + .build() + // Create the database container: + val containerConfigBuilder = ContainerConfig.builder() + .image(db.imageName) + .networkDisabled(false) + .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) + .hostConfig(hostConfig) + .exposedPorts(s"${db.jdbcPort}/tcp") + if(db.getStartupProcessName.isDefined) { + containerConfigBuilder + .cmd(db.getStartupProcessName.get) + } + val config = containerConfigBuilder.build() + // Create the database container: + containerId = docker.createContainer(config).id + // Start the container and wait until the database can accept JDBC connections: + docker.startContainer(containerId) + jdbcUrl = db.getJdbcUrl(dockerIp, externalPort) + eventually(timeout(60.seconds), interval(1.seconds)) { + val conn = java.sql.DriverManager.getConnection(jdbcUrl) + conn.close() + } + // Run any setup queries: + val conn: Connection = java.sql.DriverManager.getConnection(jdbcUrl) + try { + dataPreparation(conn) + } finally { + conn.close() + } + } catch { + case NonFatal(e) => + try { + afterAll() + } finally { + throw e + } + } + } + + override def afterAll() { + try { + if (docker != null) { + try { + if (containerId != null) { + docker.killContainer(containerId) + docker.removeContainer(containerId) + } + } catch { + case NonFatal(e) => + logWarning(s"Could not stop container $containerId", e) + } finally { + docker.close() + } + } + } finally { + super.afterAll() + } + } + + /** + * Prepare databases and tables for testing. + */ + def dataPreparation(connection: Connection): Unit +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala new file mode 100644 index 0000000000000..a70ed98b52d5d --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Connection, Date, Timestamp} +import java.util.Properties + +import org.apache.spark.tags.DockerTest + +@DockerTest +class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "mysql:5.7.9" + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort: Int = 3306 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + override def getStartupProcessName: Option[String] = None + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " + + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" + ).executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " + + "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() + } + + test("Basic test") { + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) + } + + test("String types") { + val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df2.write.jdbc(jdbcUrl, "datescopy", new Properties) + df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala new file mode 100644 index 0000000000000..2fc174eb1b3a1 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Connection +import java.util.Properties + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.tags.DockerTest + +/** + * This patch was tested using the Oracle docker. Created this integration suite for the same. + * The ojdbc6-11.2.0.2.0.jar was to be downloaded from the maven repository. Since there was + * no jdbc jar available in the maven repository, the jar was downloaded from oracle site + * manually and installed in the local; thus tested. So, for SparkQA test case run, the + * ojdbc jar might be manually placed in the local maven repository(com/oracle/ojdbc6/11.2.0.2.0) + * while Spark QA test run. + * + * The following would be the steps to test this + * 1. Pull oracle 11g image - docker pull wnameless/oracle-xe-11g + * 2. Start docker - sudo service docker start + * 3. Download oracle 11g driver jar and put it in maven local repo: + * (com/oracle/ojdbc6/11.2.0.2.0/ojdbc6-11.2.0.2.0.jar) + * 4. The timeout and interval parameter to be increased from 60,1 to a high value for oracle test + * in DockerJDBCIntegrationSuite.scala (Locally tested with 200,200 and executed successfully). + * 5. Run spark test - ./build/sbt "test-only org.apache.spark.sql.jdbc.OracleIntegrationSuite" + * + * All tests in this suite are ignored because of the dependency with the oracle jar from maven + * repository. + */ +@DockerTest +class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLContext { + import testImplicits._ + + override val db = new DatabaseOnDocker { + override val imageName = "wnameless/oracle-xe-11g:latest" + override val env = Map( + "ORACLE_ROOT_PASSWORD" -> "oracle" + ) + override val usesIpc = false + override val jdbcPort: Int = 1521 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:oracle:thin:system/oracle@//$ip:$port/xe" + override def getStartupProcessName: Option[String] = None + } + + override def dataPreparation(conn: Connection): Unit = { + } + + ignore("SPARK-12941: String datatypes to be mapped to Varchar in Oracle") { + // create a sample dataframe with string type + val df1 = sparkContext.parallelize(Seq(("foo"))).toDF("x") + // write the dataframe to the oracle table tbl + df1.write.jdbc(jdbcUrl, "tbl2", new Properties) + // read the table from the oracle + val dfRead = sqlContext.read.jdbc(jdbcUrl, "tbl2", new Properties) + // get the rows + val rows = dfRead.collect() + // verify the data type is inserted + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types(0).equals("class java.lang.String")) + // verify the value is the inserted correct or not + assert(rows(0).getString(0).equals("foo")) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala new file mode 100644 index 0000000000000..79dd70116ecb8 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Connection +import java.util.Properties + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.types.{ArrayType, DecimalType} +import org.apache.spark.tags.DockerTest + +@DockerTest +class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "postgres:9.4.5" + override val env = Map( + "POSTGRES_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort = 5432 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" + override def getStartupProcessName: Option[String] = None + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.setCatalog("foo") + conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate() + conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " + + "c10 integer[], c11 text[], c12 real[], c13 numeric(2,2)[], c14 enum_type)").executeUpdate() + conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " + + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', '{0.11, 0.22}', 'd1')""").executeUpdate() + } + + test("Type mapping for various types") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass) + assert(types.length == 15) + assert(classOf[String].isAssignableFrom(types(0))) + assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) + assert(classOf[java.lang.Double].isAssignableFrom(types(2))) + assert(classOf[java.lang.Long].isAssignableFrom(types(3))) + assert(classOf[java.lang.Boolean].isAssignableFrom(types(4))) + assert(classOf[Array[Byte]].isAssignableFrom(types(5))) + assert(classOf[Array[Byte]].isAssignableFrom(types(6))) + assert(classOf[java.lang.Boolean].isAssignableFrom(types(7))) + assert(classOf[String].isAssignableFrom(types(8))) + assert(classOf[String].isAssignableFrom(types(9))) + assert(classOf[Seq[Int]].isAssignableFrom(types(10))) + assert(classOf[Seq[String]].isAssignableFrom(types(11))) + assert(classOf[Seq[Double]].isAssignableFrom(types(12))) + assert(classOf[Seq[BigDecimal]].isAssignableFrom(types(13))) + assert(classOf[String].isAssignableFrom(types(14))) + assert(rows(0).getString(0).equals("hello")) + assert(rows(0).getInt(1) == 42) + assert(rows(0).getDouble(2) == 1.25) + assert(rows(0).getLong(3) == 123456789012345L) + assert(!rows(0).getBoolean(4)) + // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), + Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), + Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) + assert(rows(0).getBoolean(7)) + assert(rows(0).getString(8) == "172.16.0.42") + assert(rows(0).getString(9) == "192.168.0.0/16") + assert(rows(0).getSeq(10) == Seq(1, 2)) + assert(rows(0).getSeq(11) == Seq("a", null, "b")) + assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f)) + assert(rows(0).getSeq(13) == Seq("0.11", "0.22").map(BigDecimal(_).bigDecimal)) + assert(rows(0).getString(14) == "d1") + } + + test("Basic write test") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + // Test only that it doesn't crash. + df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) + // Test that written numeric type has same DataType as input + assert(sqlContext.read.jdbc(jdbcUrl, "public.barcopy", new Properties).schema(13).dataType == + ArrayType(DecimalType(2, 2), true)) + // Test write null values. + df.select(df.queryExecution.analyzed.output.map { a => + Column(Literal.create(null, a.dataType)).as(a.name) + }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala new file mode 100644 index 0000000000000..fda377e032350 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.net.{Inet4Address, InetAddress, NetworkInterface} + +import scala.collection.JavaConverters._ +import scala.sys.process._ +import scala.util.Try + +private[spark] object DockerUtils { + + def getDockerIp(): String = { + /** If docker-machine is setup on this box, attempts to find the ip from it. */ + def findFromDockerMachine(): Option[String] = { + sys.env.get("DOCKER_MACHINE_NAME").flatMap { name => + Try(Seq("/bin/bash", "-c", s"docker-machine ip $name 2>/dev/null").!!.trim).toOption + } + } + sys.env.get("DOCKER_IP") + .orElse(findFromDockerMachine()) + .orElse(Try(Seq("/bin/bash", "-c", "boot2docker ip 2>/dev/null").!!.trim).toOption) + .getOrElse { + // This block of code is based on Utils.findLocalInetAddress(), but is modified to blacklist + // certain interfaces. + val address = InetAddress.getLocalHost + // Address resolves to something like 127.0.1.1, which happens on Debian; try to find + // a better address using the local network interfaces + // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order + // on unix-like system. On windows, it returns in index order. + // It's more proper to pick ip address following system output order. + val blackListedIFs = Seq( + "vboxnet0", // Mac + "docker0" // Linux + ) + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq.filter { i => + !blackListedIFs.contains(i.getName) + } + val reOrderedNetworkIFs = activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) + return strippedAddress.getHostAddress + } + } + address.getHostAddress + } + } +} diff --git a/docker/README.md b/external/docker/README.md similarity index 100% rename from docker/README.md rename to external/docker/README.md diff --git a/docker/build b/external/docker/build similarity index 100% rename from docker/build rename to external/docker/build diff --git a/docker/spark-mesos/Dockerfile b/external/docker/spark-mesos/Dockerfile similarity index 100% rename from docker/spark-mesos/Dockerfile rename to external/docker/spark-mesos/Dockerfile diff --git a/docker/spark-test/README.md b/external/docker/spark-test/README.md similarity index 100% rename from docker/spark-test/README.md rename to external/docker/spark-test/README.md diff --git a/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile similarity index 98% rename from docker/spark-test/base/Dockerfile rename to external/docker/spark-test/base/Dockerfile index 7ba0de603dc7d..76f550f886ce4 100644 --- a/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* -ENV SCALA_VERSION 2.10.5 +ENV SCALA_VERSION 2.11.7 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/docker/spark-test/build b/external/docker/spark-test/build similarity index 100% rename from docker/spark-test/build rename to external/docker/spark-test/build diff --git a/docker/spark-test/master/Dockerfile b/external/docker/spark-test/master/Dockerfile similarity index 100% rename from docker/spark-test/master/Dockerfile rename to external/docker/spark-test/master/Dockerfile diff --git a/docker/spark-test/master/default_cmd b/external/docker/spark-test/master/default_cmd similarity index 100% rename from docker/spark-test/master/default_cmd rename to external/docker/spark-test/master/default_cmd diff --git a/docker/spark-test/worker/Dockerfile b/external/docker/spark-test/worker/Dockerfile similarity index 100% rename from docker/spark-test/worker/Dockerfile rename to external/docker/spark-test/worker/Dockerfile diff --git a/docker/spark-test/worker/default_cmd b/external/docker/spark-test/worker/default_cmd similarity index 100% rename from docker/spark-test/worker/default_cmd rename to external/docker/spark-test/worker/default_cmd diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index dceedcf23ed5b..ac15b93c048da 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume-assembly_2.10 + spark-streaming-flume-assembly_2.11 jar Spark Project External Flume Assembly http://spark.apache.org/ diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 75113ff753e7a..e4effe158c826 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume-sink_2.10 + spark-streaming-flume-sink_2.11 streaming-flume-sink diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala index d87b86932dd41..09d3fe91e42c8 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/Logging.scala @@ -26,20 +26,20 @@ import org.slf4j.{Logger, LoggerFactory} private[sink] trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine - @transient private var log_ : Logger = null + @transient private var _log: Logger = null // Method to get or create the logger for this object protected def log: Logger = { - if (log_ == null) { + if (_log == null) { initializeIfNecessary() var className = this.getClass.getName // Ignore trailing $'s in the class names for Scala objects if (className.endsWith("$")) { className = className.substring(0, className.length - 1) } - log_ = LoggerFactory.getLogger(className) + _log = LoggerFactory.getLogger(className) } - log_ + _log } // Log methods that take only a String @@ -101,7 +101,7 @@ private[sink] trait Logging { private def initializeLogging() { Logging.initialized = true - // Force a call into slf4j to initialize it. Avoids this happening from mutliple threads + // Force a call into slf4j to initialize it. Avoids this happening from multiple threads // and triggering this: http://mailman.qos.ch/pipermail/slf4j-dev/2010-April/002956.html log } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 719fca0938b3a..8050ec357e261 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -129,9 +129,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha * @param success Whether the batch was successful or not. */ private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { - removeAndGetProcessor(sequenceNumber).foreach(processor => { + removeAndGetProcessor(sequenceNumber).foreach { processor => processor.batchProcessed(success) - }) + } } /** diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 14dffb15fef98..41f27e937662f 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -88,23 +88,23 @@ class SparkSink extends AbstractSink with Logging with Configurable { // dependencies which are being excluded in the build. In practice, // Netty dependencies are already available on the JVM as Flume would have pulled them in. serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) - serverOpt.foreach(server => { + serverOpt.foreach { server => logInfo("Starting Avro server for sink: " + getName) server.start() - }) + } super.start() } override def stop() { logInfo("Stopping Spark Sink: " + getName) - handler.foreach(callbackHandler => { + handler.foreach { callbackHandler => callbackHandler.shutdown() - }) - serverOpt.foreach(server => { + } + serverOpt.foreach { server => logInfo("Stopping Avro Server for sink: " + getName) server.close() server.join() - }) + } blockingLatch.countDown() super.stop() } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index 7ad43b1d7b0a0..19e736f016977 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{Callable, CountDownLatch, TimeUnit} import scala.util.control.Breaks -import org.apache.flume.{Transaction, Channel} +import org.apache.flume.{Channel, Transaction} // Flume forces transactions to be thread-local (horrible, I know!) // So the sink basically spawns a new thread to pull the events out within a transaction. @@ -110,7 +110,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg("Something went wrong. Channel was " + "unable to create a transaction!") } - txOpt.foreach(tx => { + txOpt.foreach { tx => tx.begin() val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) val loop = new Breaks @@ -145,7 +145,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // At this point, the events are available, so fill them into the event batch eventBatch = new EventBatch("", seqNum, events) } - }) + } } catch { case interrupted: InterruptedException => // Don't pollute logs if the InterruptedException came from this being stopped @@ -156,9 +156,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, logWarning("Error while processing transaction.", e) eventBatch.setErrorMsg(e.getMessage) try { - txOpt.foreach(tx => { + txOpt.foreach { tx => rollbackAndClose(tx, close = true) - }) + } } finally { txOpt = None } @@ -174,7 +174,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, */ private def processAckOrNack() { batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) - txOpt.foreach(tx => { + txOpt.foreach { tx => if (batchSuccess) { try { logDebug("Committing transaction") @@ -197,7 +197,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // cause issues. This is required to ensure the TransactionProcessor instance is not leaked parent.removeAndGetProcessor(seqNum) } - }) + } } /** diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 42df8792f147f..1e3f163f95c09 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index d2654700ea729..e8ca1e716394d 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.flume.sink import java.net.InetSocketAddress +import java.nio.charset.StandardCharsets +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} @@ -36,11 +37,11 @@ import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory // Spark core main, which has too many dependencies to require here manually. // For this reason, we continue to use FunSuite and ignore the scalastyle checks // that fail if this is detected. -//scalastyle:off +// scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { -//scalastyle:on +// scalastyle:on val eventsPerBatch = 1000 val channelCapacity = 5000 @@ -184,7 +185,8 @@ class SparkSinkSuite extends FunSuite { private def putEvents(ch: MemoryChannel, count: Int): Unit = { val tx = ch.getTransaction tx.begin() - (1 to count).foreach(x => ch.put(EventBuilder.withBody(x.toString.getBytes))) + (1 to count).foreach(x => + ch.put(EventBuilder.withBody(x.toString.getBytes(StandardCharsets.UTF_8)))) tx.commit() tx.close() } diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 57f83607365d6..d650dd034d636 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-flume_2.10 + spark-streaming-flume_2.11 streaming-flume diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index 48df27b26867f..07c5286477737 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.flume -import java.io.{ObjectOutput, ObjectInput} +import java.io.{ObjectInput, ObjectOutput} import scala.collection.JavaConverters._ +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import org.apache.spark.Logging /** * A simple object that provides the implementation of readExternal and writeExternal for both diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala index b9d4e762ca05d..5f234b1f0ccca 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.streaming.flume.sink._ /** @@ -77,7 +77,7 @@ private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends R /** * Gets a batch of events from the specified client. This method does not handle any exceptions - * which will be propogated to the caller. + * which will be propagated to the caller. * @param client Client to get events from * @return [[Some]] which contains the event batch if Flume sent any events back, else [[None]] */ @@ -96,8 +96,8 @@ private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends R } /** - * Store the events in the buffer to Spark. This method will not propogate any exceptions, - * but will propogate any other errors. + * Store the events in the buffer to Spark. This method will not propagate any exceptions, + * but will propagate any other errors. * @param buffer The buffer to store * @return true if the data was stored without any exception being thrown, else false */ diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index c8780aa83bdbd..13aa817492f7b 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -17,38 +17,36 @@ package org.apache.spark.streaming.flume +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.net.InetSocketAddress -import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer import java.util.concurrent.Executors import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.apache.flume.source.avro.AvroSourceProtocol -import org.apache.flume.source.avro.AvroFlumeEvent -import org.apache.flume.source.avro.Status -import org.apache.avro.ipc.specific.SpecificResponder import org.apache.avro.ipc.NettyServer -import org.apache.spark.Logging -import org.apache.spark.util.Utils -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.receiver.Receiver - +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol, Status} import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory import org.jboss.netty.handler.codec.compression._ +import org.apache.spark.internal.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils + private[streaming] class FlumeInputDStream[T: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, host: String, port: Int, storageLevel: StorageLevel, enableDecompression: Boolean -) extends ReceiverInputDStream[SparkFlumeEvent](ssc_) { +) extends ReceiverInputDStream[SparkFlumeEvent](_ssc) { override def getReceiver(): Receiver[SparkFlumeEvent] = { new FlumeReceiver(host, port, storageLevel, enableDecompression) @@ -62,7 +60,7 @@ class FlumeInputDStream[T: ClassTag]( * which are not serializable. */ class SparkFlumeEvent() extends Externalizable { - var event : AvroFlumeEvent = new AvroFlumeEvent() + var event: AvroFlumeEvent = new AvroFlumeEvent() /* De-serialize from bytes. */ def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -77,12 +75,12 @@ class SparkFlumeEvent() extends Externalizable { val keyLength = in.readInt() val keyBuff = new Array[Byte](keyLength) in.readFully(keyBuff) - val key : String = Utils.deserialize(keyBuff) + val key: String = Utils.deserialize(keyBuff) val valLength = in.readInt() val valBuff = new Array[Byte](valLength) in.readFully(valBuff) - val value : String = Utils.deserialize(valBuff) + val value: String = Utils.deserialize(valBuff) headers.put(key, value) } @@ -93,9 +91,9 @@ class SparkFlumeEvent() extends Externalizable { /* Serialize to bytes. */ def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) + val body = event.getBody + out.writeInt(body.remaining()) + Utils.writeByteBuffer(body, out) val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) @@ -111,7 +109,7 @@ class SparkFlumeEvent() extends Externalizable { } private[streaming] object SparkFlumeEvent { - def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { + def fromAvroFlumeEvent(in: AvroFlumeEvent): SparkFlumeEvent = { val event = new SparkFlumeEvent event.event = in event @@ -120,20 +118,22 @@ private[streaming] object SparkFlumeEvent { /** A simple server that implements Flume's Avro protocol. */ private[streaming] -class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { - override def append(event : AvroFlumeEvent) : Status = { +class FlumeEventServer(receiver: FlumeReceiver) extends AvroSourceProtocol { + override def append(event: AvroFlumeEvent): Status = { receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event)) Status.OK } - override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { + override def appendBatch(events: java.util.List[AvroFlumeEvent]): Status = { events.asScala.foreach(event => receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) Status.OK } } -/** A NetworkReceiver which listens for events using the - * Flume Avro interface. */ +/** + * A NetworkReceiver which listens for events using the + * Flume Avro interface. + */ private[streaming] class FlumeReceiver( host: String, @@ -187,13 +187,14 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from - * and the Netty client and compress data going back to the client. - * - * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel - */ + /** + * A Netty Pipeline factory that will decompress incoming data from + * and the Netty client and compress data going back to the client. + * + * The compression on the return is required because Flume requires + * a successful response to indicate it can remove the event/batch + * from the configured channel + */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { def getPipeline(): ChannelPipeline = { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 6737750c3d63e..54565840fa665 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -28,12 +28,12 @@ import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.flume.sink._ +import org.apache.spark.streaming.receiver.Receiver /** * A [[ReceiverInputDStream]] that can be used to read data from several Flume agents running @@ -79,11 +79,11 @@ private[streaming] class FlumePollingReceiver( override def onStart(): Unit = { // Create the connections to each Flume agent. - addresses.foreach(host => { + addresses.foreach { host => val transceiver = new NettyTransceiver(host, channelFactory) val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) connections.add(new FlumeConnection(transceiver, client)) - }) + } for (i <- 0 until parallelism) { logInfo("Starting Flume Polling Receiver worker threads..") // Threads that pull data from Flume. diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala index 70018c86f92be..945cfa7295d1d 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -19,16 +19,17 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.{List => JList} import java.util.Collections import scala.collection.JavaConverters._ -import com.google.common.base.Charsets.UTF_8 import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.commons.lang3.RandomUtils import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -59,12 +60,12 @@ private[flume] class FlumeTestUtils { } /** Send data to the flume receiver */ - def writeInput(input: Seq[String], enableCompression: Boolean): Unit = { + def writeInput(input: JList[String], enableCompression: Boolean): Unit = { val testAddress = new InetSocketAddress("localhost", testPort) - val inputEvents = input.map { item => + val inputEvents = input.asScala.map { item => val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setBody(ByteBuffer.wrap(item.getBytes(StandardCharsets.UTF_8))) event.setHeaders(Collections.singletonMap("test", "header")) event } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index c719b80aca7ed..3e3ed712f0dbf 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.flume +import java.io.{ByteArrayOutputStream, DataOutputStream} import java.net.InetSocketAddress -import java.io.{DataOutputStream, ByteArrayOutputStream} import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -30,7 +30,6 @@ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream - object FlumeUtils { private val DEFAULT_POLLING_PARALLELISM = 5 private val DEFAULT_POLLING_BATCH_SIZE = 1000 diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index a2ab320957db3..6a4dafb8eddb4 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -17,18 +17,18 @@ package org.apache.spark.streaming.flume +import java.nio.charset.StandardCharsets +import java.util.{Collections, List => JList, Map => JMap} import java.util.concurrent._ -import java.util.{Map => JMap, Collections} import scala.collection.mutable.ArrayBuffer -import com.google.common.base.Charsets.UTF_8 import org.apache.flume.event.EventBuilder import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.conf.Configurables -import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} +import org.apache.spark.streaming.flume.sink.{SparkSink, SparkSinkConfig} /** * Share codes for Scala and Python unit tests @@ -123,9 +123,9 @@ private[flume] class PollingFlumeTestUtils { val latch = new CountDownLatch(batchCount * channels.size) sinks.foreach(_.countdownWhenBatchReceived(latch)) - channels.foreach(channel => { + channels.foreach { channel => executorCompletion.submit(new TxnSubmitter(channel)) - }) + } for (i <- 0 until channels.size) { executorCompletion.take() @@ -137,7 +137,8 @@ private[flume] class PollingFlumeTestUtils { /** * A Python-friendly method to assert the output */ - def assertOutput(outputHeaders: Seq[JMap[String, String]], outputBodies: Seq[String]): Unit = { + def assertOutput( + outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { require(outputHeaders.size == outputBodies.size) val eventSize = outputHeaders.size if (eventSize != totalEventsPerChannel * channels.size) { @@ -151,8 +152,8 @@ private[flume] class PollingFlumeTestUtils { var found = false var j = 0 while (j < eventSize && !found) { - if (eventBodyToVerify == outputBodies(j) && - eventHeaderToVerify == outputHeaders(j)) { + if (eventBodyToVerify == outputBodies.get(j) && + eventHeaderToVerify == outputHeaders.get(j)) { found = true counter += 1 } @@ -192,7 +193,8 @@ private[flume] class PollingFlumeTestUtils { val tx = channel.getTransaction tx.begin() for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + channel.put(EventBuilder.withBody( + s"${channel.getName}-$t".getBytes(StandardCharsets.UTF_8), Collections.singletonMap(s"test-$t", "header"))) t += 1 } diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java index 3b5e0c7746b2c..ada05f203b6a8 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java @@ -27,10 +27,11 @@ public class JavaFlumeStreamSuite extends LocalJavaStreamingContext { @Test public void testFlumeStream() { // tests the API, does not actually test data receiving - JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", 12345); - JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", 12345, - StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", 12345, - StorageLevel.MEMORY_AND_DISK_SER_2(), false); + JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", + 12345); + JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", + 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", + 12345, StorageLevel.MEMORY_AND_DISK_SER_2(), false); } } diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 75e3b53a093f6..fd51f8faf56b9 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala index 1a900007b696b..c97a27ca7c7aa 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -18,14 +18,14 @@ package org.apache.spark.streaming import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.reflect.ClassTag import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} import org.apache.spark.util.Utils -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - /** * This is a output stream just for the testsuites. All the output is collected into a * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. @@ -33,11 +33,11 @@ import scala.reflect.ClassTag * The buffer contains a sequence of RDD's, each containing a sequence of items */ class TestOutputStream[T: ClassTag](parent: DStream[T], - val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]()) + val output: ConcurrentLinkedQueue[Seq[T]] = new ConcurrentLinkedQueue[Seq[T]]()) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() - output += collected - }) { + output.add(collected) + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index ff2fb8eed204c..156712483d3ab 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,20 +18,21 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, TestOutputStream} import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { @@ -102,9 +103,8 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, utils.eventsPerBatch, 5) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputQueue) outputStream.register() ssc.start() @@ -115,12 +115,12 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log // The eventually is required to ensure that all data in the batch has been processed. eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenOutputBuffer = outputBuffer.flatten - val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { + val flattenOutput = outputQueue.asScala.toSeq.flatten + val headers = flattenOutput.map(_.event.getHeaders.asScala.map { case (key, value) => (key.toString, value.toString) }).map(_.asJava) - val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) - utils.assertOutput(headers, bodies) + val bodies = flattenOutput.map(e => JavaUtils.bytesToString(e.event.getBody)) + utils.assertOutput(headers.asJava, bodies.asJava) } } finally { ssc.stop() diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 5ffb60bd602f9..7bac1cc4b0ae7 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.flume +import java.util.concurrent.ConcurrentLinkedQueue + import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -30,7 +30,9 @@ import org.jboss.netty.handler.codec.compression._ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} @@ -51,19 +53,19 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w val input = (1 to 100).map { _.toString } val utils = new FlumeTestUtils try { - val outputBuffer = startContext(utils.getTestPort(), testCompression) + val outputQueue = startContext(utils.getTestPort(), testCompression) eventually(timeout(10 seconds), interval(100 milliseconds)) { - utils.writeInput(input, testCompression) + utils.writeInput(input.asJava, testCompression) } eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } + val outputEvents = outputQueue.asScala.toSeq.flatten.map { _.event } outputEvents.foreach { event => event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody)) output should be (input) } } finally { @@ -76,16 +78,15 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w /** Setup and start the streaming context */ private def startContext( - testPort: Int, testCompression: Boolean): (ArrayBuffer[Seq[SparkFlumeEvent]]) = { + testPort: Int, testCompression: Boolean): (ConcurrentLinkedQueue[Seq[SparkFlumeEvent]]) = { ssc = new StreamingContext(conf, Milliseconds(200)) val flumeStream = FlumeUtils.createStream( ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputQueue) outputStream.register() ssc.start() - outputBuffer + outputQueue } /** Class to create socket channel with compression */ diff --git a/external/java8-tests/README.md b/external/java8-tests/README.md new file mode 100644 index 0000000000000..aa87901695c20 --- /dev/null +++ b/external/java8-tests/README.md @@ -0,0 +1,22 @@ +# Java 8 Test Suites + +These tests require having Java 8 installed and are isolated from the main Spark build. +If Java 8 is not your system's default Java version, you will need to point Spark's build +to your Java location. The set-up depends a bit on the build system: + +* Sbt users can either set JAVA_HOME to the location of a Java 8 JDK or explicitly pass + `-java-home` to the sbt launch script. If a Java 8 JDK is detected sbt will automatically + include the Java 8 test project. + + `$ JAVA_HOME=/opt/jdk1.8.0/ build/sbt clean java8-tests/test + +* For Maven users, + + Maven users can also refer to their Java 8 directory using JAVA_HOME. + + `$ JAVA_HOME=/opt/jdk1.8.0/ mvn clean install -DskipTests` + `$ JAVA_HOME=/opt/jdk1.8.0/ mvn -pl :java8-tests_2.11 test` + + Note that the above command can only be run from project root directory since this module + depends on core and the test-jars of core and streaming. This means an install step is + required to make the test dependencies visible to the Java 8 sub-project. diff --git a/external/java8-tests/pom.xml b/external/java8-tests/pom.xml new file mode 100644 index 0000000000000..1ea9196e9dfe3 --- /dev/null +++ b/external/java8-tests/pom.xml @@ -0,0 +1,120 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + java8-tests_2.11 + pom + Spark Project Java 8 Tests + + + java8-tests + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + + + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + + + org.apache.maven.plugins + maven-compiler-plugin + + true + 1.8 + 1.8 + 1.8 + + + + net.alchim31.maven + scala-maven-plugin + + + -source + 1.8 + -target + 1.8 + -Xlint:all,-serial,-path + + + + + + diff --git a/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java new file mode 100644 index 0000000000000..6ac5ca9cf56af --- /dev/null +++ b/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.File; +import java.io.Serializable; +import java.util.*; + +import scala.Tuple2; + +import com.google.common.collect.Iterables; +import com.google.common.io.Files; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.java.function.*; +import org.apache.spark.util.Utils; + +/** + * Most of these tests replicate org.apache.spark.JavaAPISuite using java 8 + * lambda syntax. + */ +public class Java8APISuite implements Serializable { + static int foreachCalls = 0; + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaAPISuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void foreachWithAnonymousClass() { + foreachCalls = 0; + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreach(new VoidFunction() { + @Override + public void call(String s) { + foreachCalls++; + } + }); + Assert.assertEquals(2, foreachCalls); + } + + @Test + public void foreach() { + foreachCalls = 0; + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreach(x -> foreachCalls++); + Assert.assertEquals(2, foreachCalls); + } + + @Test + public void groupBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function isOdd = x -> x % 2 == 0; + JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds + + oddsAndEvens = rdd.groupBy(isOdd, 1); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds + } + + @Test + public void leftOuterJoin() { + JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( + new Tuple2<>(1, 1), + new Tuple2<>(1, 2), + new Tuple2<>(2, 1), + new Tuple2<>(3, 1) + )); + JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( + new Tuple2<>(1, 'x'), + new Tuple2<>(2, 'y'), + new Tuple2<>(2, 'z'), + new Tuple2<>(4, 'w') + )); + List>>> joined = + rdd1.leftOuterJoin(rdd2).collect(); + Assert.assertEquals(5, joined.size()); + Tuple2>> firstUnmatched = + rdd1.leftOuterJoin(rdd2).filter(tup -> !tup._2()._2().isPresent()).first(); + Assert.assertEquals(3, firstUnmatched._1().intValue()); + } + + @Test + public void foldReduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function2 add = (a, b) -> a + b; + + int sum = rdd.fold(0, add); + Assert.assertEquals(33, sum); + + sum = rdd.reduce(add); + Assert.assertEquals(33, sum); + } + + @Test + public void foldByKey() { + List> pairs = Arrays.asList( + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + JavaPairRDD sums = rdd.foldByKey(0, (a, b) -> a + b); + Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); + Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); + Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); + } + + @Test + public void reduceByKey() { + List> pairs = Arrays.asList( + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + JavaPairRDD counts = rdd.reduceByKey((a, b) -> a + b); + Assert.assertEquals(1, counts.lookup(1).get(0).intValue()); + Assert.assertEquals(2, counts.lookup(2).get(0).intValue()); + Assert.assertEquals(3, counts.lookup(3).get(0).intValue()); + + Map localCounts = counts.collectAsMap(); + Assert.assertEquals(1, localCounts.get(1).intValue()); + Assert.assertEquals(2, localCounts.get(2).intValue()); + Assert.assertEquals(3, localCounts.get(3).intValue()); + + localCounts = rdd.reduceByKeyLocally((a, b) -> a + b); + Assert.assertEquals(1, localCounts.get(1).intValue()); + Assert.assertEquals(2, localCounts.get(2).intValue()); + Assert.assertEquals(3, localCounts.get(3).intValue()); + } + + @Test + public void map() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaDoubleRDD doubles = rdd.mapToDouble(x -> 1.0 * x).cache(); + doubles.collect(); + JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) + .cache(); + pairs.collect(); + JavaRDD strings = rdd.map(Object::toString).cache(); + strings.collect(); + } + + @Test + public void flatMap() { + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", + "The quick brown fox jumps over the lazy dog.")); + JavaRDD words = rdd.flatMap(x -> Arrays.asList(x.split(" ")).iterator()); + + Assert.assertEquals("Hello", words.first()); + Assert.assertEquals(11, words.count()); + + JavaPairRDD pairs = rdd.flatMapToPair(s -> { + List> pairs2 = new LinkedList<>(); + for (String word : s.split(" ")) { + pairs2.add(new Tuple2<>(word, word)); + } + return pairs2.iterator(); + }); + + Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairs.first()); + Assert.assertEquals(11, pairs.count()); + + JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { + List lengths = new LinkedList<>(); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } + return lengths.iterator(); + }); + + Assert.assertEquals(5.0, doubles.first(), 0.01); + Assert.assertEquals(11, pairs.count()); + } + + @Test + public void mapsFromPairsToPairs() { + List> pairs = Arrays.asList( + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = + pairRDD.flatMapToPair(x -> Collections.singletonList(x.swap()).iterator()); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(Tuple2::swap).collect(); + } + + @Test + public void mapPartitions() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + JavaRDD partitionSums = rdd.mapPartitions(iter -> { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); + } + return Collections.singletonList(sum).iterator(); + }); + + Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + } + + @Test + public void sequenceFile() { + File tempDir = Files.createTempDir(); + tempDir.deleteOnExit(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + + // Try reading the output back as an object file + JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, Text.class) + .mapToPair(pair -> new Tuple2<>(pair._1().get(), pair._2().toString())); + Assert.assertEquals(pairs, readRDD.collect()); + Utils.deleteRecursively(tempDir); + } + + @Test + public void zip() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaDoubleRDD doubles = rdd.mapToDouble(x -> 1.0 * x); + JavaPairRDD zipped = rdd.zip(doubles); + zipped.count(); + } + + @Test + public void zipPartitions() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); + FlatMapFunction2, Iterator, Integer> sizesFn = + (Iterator i, Iterator s) -> { + int sizeI = 0; + while (i.hasNext()) { + sizeI += 1; + i.next(); + } + int sizeS = 0; + while (s.hasNext()) { + sizeS += 1; + s.next(); + } + return Arrays.asList(sizeI, sizeS).iterator(); + }; + JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); + Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(intAccum::add); + Assert.assertEquals((Integer) 25, intAccum.value()); + + Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(x -> doubleAccum.add((double) x)); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + @Override + public Float addInPlace(Float r, Float t) { + return r + t; + } + @Override + public Float addAccumulator(Float r, Float t) { + return r + t; + } + @Override + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + rdd.foreach(x -> floatAccum.add((float) x)); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + + // Test the setValue method + floatAccum.setValue(5.0f); + Assert.assertEquals((Float) 5.0f, floatAccum.value()); + } + + @Test + public void keyBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); + List> s = rdd.keyBy(Object::toString).collect(); + Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); + } + + @Test + public void mapOnPairRDD() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + JavaPairRDD rdd2 = + rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); + JavaPairRDD rdd3 = + rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); + Assert.assertEquals(Arrays.asList( + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); + } + + @Test + public void collectPartitions() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); + + JavaPairRDD rdd2 = + rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); + List[] parts = rdd1.collectPartitions(new int[]{0}); + Assert.assertEquals(Arrays.asList(1, 2), parts[0]); + + parts = rdd1.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(3, 4), parts[0]); + Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); + + Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), + rdd2.collectPartitions(new int[]{0})[0]); + + List>[] parts2 = rdd2.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); + Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), + parts2[1]); + } + + @Test + public void collectAsMapWithIntArrayValues() { + // Regression test for SPARK-1040 + JavaRDD rdd = sc.parallelize(Arrays.asList(1)); + JavaPairRDD pairRDD = + rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); + pairRDD.collect(); // Works fine + pairRDD.collectAsMap(); // Used to crash with ClassCastException + } +} diff --git a/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java b/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java new file mode 100644 index 0000000000000..23abfa397061d --- /dev/null +++ b/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.util.Arrays; + +import org.junit.Assert; +import org.junit.Test; +import scala.Tuple2; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.expressions.java.typed; + +/** + * Suite that replicates tests in JavaDatasetAggregatorSuite using lambda syntax. + */ +public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { + @Test + public void testTypedAggregationAverage() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.avg(v -> (double)(v._2() * 2))); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationCount() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.count(v -> v)); + Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumDouble() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sum(v -> (double)v._2())); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumLong() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sumLong(v -> (long)v._2())); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + } +} diff --git a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java new file mode 100644 index 0000000000000..d0fed303e659c --- /dev/null +++ b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -0,0 +1,909 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import java.io.Serializable; +import java.util.*; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.Accumulator; +import org.apache.spark.HashPartitioner; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaMapWithStateDStream; + +/** + * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8 + * lambda syntax. + */ +@SuppressWarnings("unchecked") +public class Java8APISuite extends LocalJavaStreamingContext implements Serializable { + + @Test + public void testMap() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("goodnight", "moon")); + + List> expected = Arrays.asList( + Arrays.asList(5, 5), + Arrays.asList(9, 4)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(String::length); + JavaTestUtils.attachTestOutputStream(letterCount); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red sox")); + + List> expected = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("yankees")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream filtered = stream.filter(s -> s.contains("a")); + JavaTestUtils.attachTestOutputStream(filtered); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testMapPartitions() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red sox")); + + List> expected = Arrays.asList( + Arrays.asList("GIANTSDODGERS"), + Arrays.asList("YANKEESRED SOX")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream mapped = stream.mapPartitions(in -> { + String out = ""; + while (in.hasNext()) { + out = out + in.next().toUpperCase(); + } + return Lists.newArrayList(out).iterator(); + }); + JavaTestUtils.attachTestOutputStream(mapped); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduce() { + List> inputData = Arrays.asList( + Arrays.asList(1, 2, 3), + Arrays.asList(4, 5, 6), + Arrays.asList(7, 8, 9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(15), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reduced = stream.reduce((x, y) -> x + y); + JavaTestUtils.attachTestOutputStream(reduced); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1, 2, 3), + Arrays.asList(4, 5, 6), + Arrays.asList(7, 8, 9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(21), + Arrays.asList(39), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reducedWindowed = stream.reduceByWindow((x, y) -> x + y, + (x, y) -> x - y, new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reducedWindowed); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); + + Assert.assertEquals(expected, result); + } + + @Test + public void testTransform() { + List> inputData = Arrays.asList( + Arrays.asList(1, 2, 3), + Arrays.asList(4, 5, 6), + Arrays.asList(7, 8, 9)); + + List> expected = Arrays.asList( + Arrays.asList(3, 4, 5), + Arrays.asList(6, 7, 8), + Arrays.asList(9, 10, 11)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream transformed = stream.transform(in -> in.map(i -> i + 2)); + + JavaTestUtils.attachTestOutputStream(transformed); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testVariousTransform() { + // tests whether all variations of transform can be called from Java + + List> inputData = Arrays.asList(Arrays.asList(1)); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + + List>> pairInputData = + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( + JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); + + JavaDStream transformed1 = stream.transform(in -> null); + JavaDStream transformed2 = stream.transform((x, time) -> null); + JavaPairDStream transformed3 = stream.transformToPair(x -> null); + JavaPairDStream transformed4 = stream.transformToPair((x, time) -> null); + JavaDStream pairTransformed1 = pairStream.transform(x -> null); + JavaDStream pairTransformed2 = pairStream.transform((x, time) -> null); + JavaPairDStream pairTransformed3 = pairStream.transformToPair(x -> null); + JavaPairDStream pairTransformed4 = + pairStream.transformToPair((x, time) -> null); + + } + + @Test + public void testTransformWith() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList( + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList( + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList( + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList( + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); + + + List>>> expected = Arrays.asList( + Sets.newHashSet( + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), + Sets.newHashSet( + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream> joined = + pairStream1.transformWithToPair(pairStream2,(x, y, z) -> x.join(y)); + + JavaTestUtils.attachTestOutputStream(joined); + List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List>>> unorderedResult = Lists.newArrayList(); + for (List>> res : result) { + unorderedResult.add(Sets.newHashSet(res)); + } + + Assert.assertEquals(expected, unorderedResult); + } + + + @Test + public void testVariousTransformWith() { + // tests whether all variations of transformWith can be called from Java + + List> inputData1 = Arrays.asList(Arrays.asList(1)); + List> inputData2 = Arrays.asList(Arrays.asList("x")); + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 1); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); + + List>> pairInputData1 = + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); + List>> pairInputData2 = + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( + JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( + JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); + + JavaDStream transformed1 = stream1.transformWith(stream2, (x, y, z) -> null); + JavaDStream transformed2 = stream1.transformWith(pairStream1,(x, y, z) -> null); + + JavaPairDStream transformed3 = + stream1.transformWithToPair(stream2,(x, y, z) -> null); + + JavaPairDStream transformed4 = + stream1.transformWithToPair(pairStream1,(x, y, z) -> null); + + JavaDStream pairTransformed1 = pairStream1.transformWith(stream2,(x, y, z) -> null); + + JavaDStream pairTransformed2_ = + pairStream1.transformWith(pairStream1,(x, y, z) -> null); + + JavaPairDStream pairTransformed3 = + pairStream1.transformWithToPair(stream2,(x, y, z) -> null); + + JavaPairDStream pairTransformed4 = + pairStream1.transformWithToPair(pairStream2,(x, y, z) -> null); + } + + @Test + public void testStreamingContextTransform() { + List> stream1input = Arrays.asList( + Arrays.asList(1), + Arrays.asList(2) + ); + + List> stream2input = Arrays.asList( + Arrays.asList(3), + Arrays.asList(4) + ); + + List>> pairStream1input = Arrays.asList( + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) + ); + + List>>> expected = Arrays.asList( + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) + ); + + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, stream2input, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( + JavaTestUtils.attachTestInputStream(ssc, pairStream1input, 1)); + + List> listOfDStreams1 = Arrays.>asList(stream1, stream2); + + // This is just to test whether this transform to JavaStream compiles + JavaDStream transformed1 = ssc.transform( + listOfDStreams1, (List> listOfRDDs, Time time) -> { + Assert.assertEquals(2, listOfRDDs.size()); + return null; + }); + + List> listOfDStreams2 = + Arrays.>asList(stream1, stream2, pairStream1.toJavaDStream()); + + JavaPairDStream> transformed2 = ssc.transformToPair( + listOfDStreams2, (List> listOfRDDs, Time time) -> { + Assert.assertEquals(3, listOfRDDs.size()); + JavaRDD rdd1 = (JavaRDD) listOfRDDs.get(0); + JavaRDD rdd2 = (JavaRDD) listOfRDDs.get(1); + JavaRDD> rdd3 = (JavaRDD>) listOfRDDs.get(2); + JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); + PairFunction mapToTuple = + (Integer i) -> new Tuple2<>(i, i); + return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); + }); + JavaTestUtils.attachTestOutputStream(transformed2); + List>>> result = + JavaTestUtils.runStreams(ssc, 2, 2); + Assert.assertEquals(expected, result); + } + + @Test + public void testFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("go", "giants"), + Arrays.asList("boo", "dodgers"), + Arrays.asList("athletics")); + + List> expected = Arrays.asList( + Arrays.asList("g", "o", "g", "i", "a", "n", "t", "s"), + Arrays.asList("b", "o", "o", "d", "o", "d", "g", "e", "r", "s"), + Arrays.asList("a", "t", "h", "l", "e", "t", "i", "c", "s")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream flatMapped = stream.flatMap( + s -> Lists.newArrayList(s.split("(?!^)")).iterator()); + JavaTestUtils.attachTestOutputStream(flatMapped); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sparkContext().accumulator(0); + final Accumulator accumEle = ssc.sparkContext().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(rdd -> { + accumRdd.add(1); + rdd.foreach(x -> accumEle.add(1)); + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD((rdd, time) -> { + return; + }); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + + @Test + public void testPairFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("dodgers"), + Arrays.asList("athletics")); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), + Arrays.asList( + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), + Arrays.asList( + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream flatMapped = stream.flatMapToPair(s -> { + List> out = Lists.newArrayList(); + for (String letter : s.split("(?!^)")) { + out.add(new Tuple2<>(s.length(), letter)); + } + return out.iterator(); + }); + + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + /* + * Performs an order-invariant comparison of lists representing two RDD streams. This allows + * us to account for ordering variation within individual RDD's which occurs during windowing. + */ + public static > void assertOrderInvariantEquals( + List> expected, List> actual) { + expected.forEach(list -> Collections.sort(list)); + List> sortedActual = new ArrayList<>(); + actual.forEach(list -> { + List sortedList = new ArrayList<>(list); + Collections.sort(sortedList); + sortedActual.add(sortedList); + }); + Assert.assertEquals(expected, sortedActual); + } + + @Test + public void testPairFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red sox")); + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = + stream.mapToPair(x -> new Tuple2<>(x, x.length())); + JavaPairDStream filtered = pairStream.filter(x -> x._1().contains("a")); + JavaTestUtils.attachTestOutputStream(filtered); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); + + List>> stringIntKVStream = Arrays.asList( + Arrays.asList( + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), + Arrays.asList( + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); + + @Test + public void testPairMap() { // Maps pair -> pair of different type + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), + Arrays.asList( + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream reversed = pairStream.mapToPair(x -> x.swap()); + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairMapPartitions() { // Maps pair -> pair of different type + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), + Arrays.asList( + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream reversed = pairStream.mapPartitionsToPair(in -> { + LinkedList> out = new LinkedList<>(); + while (in.hasNext()) { + Tuple2 next = in.next(); + out.add(next.swap()); + } + return out.iterator(); + }); + + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairMap2() { // Maps pair -> single + List>> inputData = stringIntKVStream; + + List> expected = Arrays.asList( + Arrays.asList(1, 3, 4, 1), + Arrays.asList(5, 5, 3, 1)); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaDStream reversed = pairStream.map(in -> in._2()); + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), + Arrays.asList( + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o")), + Arrays.asList( + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream flatMapped = pairStream.flatMapToPair(in -> { + List> out = new LinkedList<>(); + for (Character s : in._1().toCharArray()) { + out.add(new Tuple2<>(in._2(), s.toString())); + } + return out.iterator(); + }); + + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairReduceByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList( + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduced = pairStream.reduceByKey((x, y) -> x + y); + + JavaTestUtils.attachTestOutputStream(reduced); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCombineByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList( + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream combined = pairStream.combineByKey(i -> i, + (x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2)); + + JavaTestUtils.attachTestOutputStream(combined); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindow() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow((x, y) -> x + y, new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testUpdateStateByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream updated = pairStream.updateStateByKey((values, state) -> { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); + }); + + JavaTestUtils.attachTestOutputStream(updated); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindowWithInverse() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow((x, y) -> x + y, (x, y) -> x - y, new Duration(2000), + new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairTransform() { + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), + Arrays.asList( + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), + Arrays.asList( + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream sorted = pairStream.transformToPair(in -> in.sortByKey()); + + JavaTestUtils.attachTestOutputStream(sorted); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairToNormalRDDTransform() { + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), + Arrays.asList( + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); + + List> expected = Arrays.asList( + Arrays.asList(3, 1, 4, 2), + Arrays.asList(2, 3, 4, 1)); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaDStream firstParts = pairStream.transform(in -> in.map(x -> x._1())); + JavaTestUtils.attachTestOutputStream(firstParts); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream mapped = pairStream.mapValues(String::toUpperCase); + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testFlatMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", "islanders1"), + new Tuple2<>("new york", "islanders2"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream flatMapped = + pairStream.flatMapValues(in -> Arrays.asList(in + "1", in + "2")); + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + Assert.assertEquals(expected, result); + } + + /** + * This test is only for testing the APIs. It's not necessary to run it. + */ + public void testMapWithStateAPI() { + JavaPairRDD initialRDD = null; + JavaPairDStream wordsDstream = null; + + JavaMapWithStateDStream stateDstream = + wordsDstream.mapWithState( + StateSpec.function((time, key, value, state) -> { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + }).initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords = stateDstream.stateSnapshots(); + + JavaMapWithStateDStream stateDstream2 = + wordsDstream.mapWithState( + StateSpec.function((key, value, state) -> { + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + }).initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream mappedDStream = stateDstream2.stateSnapshots(); + } +} diff --git a/external/java8-tests/src/test/resources/log4j.properties b/external/java8-tests/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..3706a6e361307 --- /dev/null +++ b/external/java8-tests/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala b/external/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala new file mode 100644 index 0000000000000..fa0681db41088 --- /dev/null +++ b/external/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Test cases where JDK8-compiled Scala user code is used with Spark. + */ +class JDK8ScalaSuite extends SparkFunSuite with SharedSparkContext { + test("basic RDD closure test (SPARK-6152)") { + sc.parallelize(1 to 1000).map(x => x * x).count() + } +} diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index a9ed39ef8c9a0..62818f5e8f434 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka-assembly_2.10 + spark-streaming-kafka-assembly_2.11 jar Spark Project External Kafka Assembly http://spark.apache.org/ diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 79258c126e043..68d52e9339b3d 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka_2.10 + spark-streaming-kafka_2.11 streaming-kafka diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 8a087474d3169..fb58ed789887f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -25,7 +25,8 @@ import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.Decoder -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset @@ -58,11 +59,11 @@ class DirectKafkaInputDStream[ U <: Decoder[K]: ClassTag, T <: Decoder[V]: ClassTag, R: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, val kafkaParams: Map[String, String], val fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R - ) extends InputDStream[R](ssc_) with Logging { + ) extends InputDStream[R](_ssc) with Logging { val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) @@ -89,23 +90,32 @@ class DirectKafkaInputDStream[ private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRatePerPartition", 0) - protected def maxMessagesPerPartition: Option[Long] = { + + protected[streaming] def maxMessagesPerPartition( + offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) - val numPartitions = currentOffsets.keys.size - - val effectiveRateLimitPerPartition = estimatedRateLimit - .filter(_ > 0) - .map { limit => - if (maxRateLimitPerPartition > 0) { - Math.min(maxRateLimitPerPartition, (limit / numPartitions)) - } else { - limit / numPartitions + + // calculate a per-partition rate limit based on current lag + val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { + case Some(rate) => + val lagPerPartition = offsets.map { case (tp, offset) => + tp -> Math.max(offset - currentOffsets(tp), 0) + } + val totalLag = lagPerPartition.values.sum + + lagPerPartition.map { case (tp, lag) => + val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + tp -> (if (maxRateLimitPerPartition > 0) { + Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - }.getOrElse(maxRateLimitPerPartition) + case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + } - if (effectiveRateLimitPerPartition > 0) { + if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 - Some((secsPerBatch * effectiveRateLimitPerPartition).toLong) + Some(effectiveRateLimitPerPartition.map { + case (tp, limit) => tp -> (secsPerBatch * limit).toLong + }) } else { None } @@ -134,9 +144,12 @@ class DirectKafkaInputDStream[ // limits the maximum number of messages per partition protected def clamp( leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = { - maxMessagesPerPartition.map { mmp => - leaderOffsets.map { case (tp, lo) => - tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset)) + val offsets = leaderOffsets.mapValues(lo => lo.offset) + + maxMessagesPerPartition(offsets).map { mmp => + mmp.map { case (tp, messages) => + val lo = leaderOffsets(tp) + tp -> lo.copy(offset = Math.min(currentOffsets(tp) + messages, lo.offset)) } }.getOrElse(leaderOffsets) } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 8465432c5850f..726b5d8ec3d3b 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -17,24 +17,31 @@ package org.apache.spark.streaming.kafka -import scala.util.control.NonFatal -import scala.util.Random -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConverters._ import java.util.Properties + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.util.control.NonFatal + import kafka.api._ import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, TopicAndPartition} import kafka.consumer.{ConsumerConfig, SimpleConsumer} + import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi /** + * :: DeveloperApi :: * Convenience methods for interacting with a Kafka cluster. + * See + * A Guide To The Kafka Protocol for more details on individual api calls. * @param kafkaParams Kafka * configuration parameters. * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s), * NOT zookeeper servers, specified in host1:port1,host2:port2 form */ -private[spark] +@DeveloperApi class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig} @@ -160,7 +167,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { ): Either[Err, Map[TopicAndPartition, LeaderOffset]] = { getLeaderOffsets(topicAndPartitions, before, 1).right.map { r => r.map { kv => - // mapValues isnt serializable, see SI-7005 + // mapValues isn't serializable, see SI-7005 kv._1 -> kv._2.head } } @@ -224,7 +231,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { // this 0 here indicates api version, in this case the original ZK backed api. private def defaultConsumerApiVersion: Short = 0 - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def getConsumerOffsets( groupId: String, topicAndPartitions: Set[TopicAndPartition] @@ -243,7 +250,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { } } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def getConsumerOffsetMetadata( groupId: String, topicAndPartitions: Set[TopicAndPartition] @@ -280,7 +287,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { Left(errs) } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def setConsumerOffsets( groupId: String, offsets: Map[TopicAndPartition, Long] @@ -298,7 +305,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { setConsumerOffsetMetadata(groupId, meta, consumerApiVersion) } - /** Requires Kafka >= 0.8.1.1 */ + /** Requires Kafka >= 0.8.1.1. Defaults to the original ZooKeeper backed api version. */ def setConsumerOffsetMetadata( groupId: String, metadata: Map[TopicAndPartition, OffsetAndMetadata] @@ -356,7 +363,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { } } -private[spark] +@DeveloperApi object KafkaCluster { type Err = ArrayBuffer[Throwable] @@ -368,7 +375,6 @@ object KafkaCluster { ) } - private[spark] case class LeaderOffset(host: String, port: Int, offset: Long) /** @@ -376,19 +382,17 @@ object KafkaCluster { * Simple consumers connect directly to brokers, but need many of the same configs. * This subclass won't warn about missing ZK params, or presence of broker params. */ - private[spark] class SimpleConsumerConfig private(brokers: String, originalProps: Properties) extends ConsumerConfig(originalProps) { val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => val hpa = hp.split(":") if (hpa.size == 1) { - throw new SparkException(s"Broker not the in correct format of : [$brokers]") + throw new SparkException(s"Broker not in the correct format of : [$brokers]") } (hpa(0), hpa(1).toInt) } } - private[spark] object SimpleConsumerConfig { /** * Make a consumer config without requiring group.id or zookeeper.connect, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 38730fecf332a..3713bda41b8ee 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -22,11 +22,11 @@ import java.util.Properties import scala.collection.Map import scala.reflect.{classTag, ClassTag} -import kafka.consumer.{KafkaStream, Consumer, ConsumerConfig, ConsumerConnector} +import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ @@ -48,12 +48,12 @@ class KafkaInputDStream[ V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], useReliableReceiver: Boolean, storageLevel: StorageLevel - ) extends ReceiverInputDStream[(K, V)](ssc_) with Logging { + ) extends ReceiverInputDStream[(K, V)](_ssc) with Logging { def getReceiver(): Receiver[(K, V)] = { if (!useReliableReceiver) { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index ea5f842c6cafe..d4881b140df3c 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -20,11 +20,6 @@ package org.apache.spark.streaming.kafka import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} -import org.apache.spark.partial.{PartialResult, BoundedDouble} -import org.apache.spark.rdd.RDD -import org.apache.spark.util.NextIterator - import kafka.api.{FetchRequestBuilder, FetchResponse} import kafka.common.{ErrorMapping, TopicAndPartition} import kafka.consumer.SimpleConsumer @@ -32,6 +27,12 @@ import kafka.message.{MessageAndMetadata, MessageAndOffset} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties +import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.NextIterator + /** * A batch-oriented interface for consuming from Kafka. * Starting and ending offsets are specified in advance, @@ -79,7 +80,7 @@ class KafkaRDD[ .map(_.asInstanceOf[KafkaRDDPartition]) .filter(_.count > 0) - if (num < 1 || nonEmptyPartitions.size < 1) { + if (num < 1 || nonEmptyPartitions.isEmpty) { return new Array[R](0) } @@ -156,7 +157,7 @@ class KafkaRDD[ var requestOffset = part.fromOffset var iter: Iterator[MessageAndOffset] = null - // The idea is to use the provided preferred host, except on task retry atttempts, + // The idea is to use the provided preferred host, except on task retry attempts, // to minimize number of kafka metadata requests private def connectLeader: SimpleConsumer = { if (context.attemptNumber > 0) { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala index a660d2a00c35d..02917becf0ff9 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -19,13 +19,14 @@ package org.apache.spark.streaming.kafka import org.apache.spark.Partition -/** @param topic kafka topic name - * @param partition kafka partition id - * @param fromOffset inclusive starting offset - * @param untilOffset exclusive ending offset - * @param host preferred kafka host, i.e. the leader at the time the rdd was created - * @param port preferred kafka host's port - */ +/** + * @param topic kafka topic name + * @param partition kafka partition id + * @param fromOffset inclusive starting offset + * @param untilOffset exclusive ending offset + * @param host preferred kafka host, i.e. the leader at the time the rdd was created + * @param port preferred kafka host's port + */ private[kafka] class KafkaRDDPartition( val index: Int, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index c9fd715d3d554..d9d4240c056a5 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -20,8 +20,8 @@ package org.apache.spark.streaming.kafka import java.io.File import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.concurrent.TimeoutException import java.util.{Map => JMap, Properties} +import java.util.concurrent.TimeoutException import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -37,9 +37,10 @@ import kafka.utils.{ZKStringSerializer, ZkUtils} import org.I0Itec.zkclient.ZkClient import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.streaming.Time import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf} /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -52,7 +53,7 @@ private[kafka] class KafkaTestUtils extends Logging { // Zookeeper related configurations private val zkHost = "localhost" private var zkPort: Int = 0 - private val zkConnectionTimeout = 6000 + private val zkConnectionTimeout = 60000 private val zkSessionTimeout = 6000 private var zookeeper: EmbeddedZookeeper = _ @@ -151,13 +152,16 @@ private[kafka] class KafkaTestUtils extends Logging { } } - /** Create a Kafka topic and wait until it propagated to the whole cluster */ - def createTopic(topic: String): Unit = { - AdminUtils.createTopic(zkClient, topic, 1, 1) + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + AdminUtils.createTopic(zkClient, topic, partitions, 1) // wait until metadata is propagated - waitUntilMetadataIsPropagated(topic, 0) + (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) } } + /** Single-argument version for backwards compatibility */ + def createTopic(topic: String): Unit = createTopic(topic, 1) + /** Java-friendly function for sending messages to the Kafka broker */ def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 3128222077537..edaafb912c5c5 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,7 +17,9 @@ package org.apache.spark.streaming.kafka +import java.io.OutputStream import java.lang.{Integer => JInt, Long => JLong} +import java.nio.charset.StandardCharsets import java.util.{List => JList, Map => JMap, Set => JSet} import scala.collection.JavaConverters._ @@ -26,16 +28,18 @@ import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder} +import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.api.java._ +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{SparkContext, SparkException} object KafkaUtils { /** @@ -47,6 +51,7 @@ object KafkaUtils { * in its own thread * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( ssc: StreamingContext, @@ -70,6 +75,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel Storage level to use for storing the received objects + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, @@ -89,6 +99,7 @@ object KafkaUtils { * @param groupId The group id for this consumer * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -107,6 +118,7 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -131,6 +143,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread * @param storageLevel RDD storage level. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K, V, U <: Decoder[_], T <: Decoder[_]]( jssc: JavaStreamingContext, @@ -184,6 +201,27 @@ object KafkaUtils { } } + private[kafka] def getFromOffsets( + kc: KafkaCluster, + kafkaParams: Map[String, String], + topics: Set[String] + ): Map[TopicAndPartition, Long] = { + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + val result = for { + topicPartitions <- kc.getPartitions(topics).right + leaderOffsets <- (if (reset == Some("smallest")) { + kc.getEarliestLeaderOffsets(topicPartitions) + } else { + kc.getLatestLeaderOffsets(topicPartitions) + }).right + } yield { + leaderOffsets.map { case (tp, lo) => + (tp, lo.offset) + } + } + KafkaCluster.checkErrors(result) + } + /** * Create a RDD from Kafka using offset ranges for each topic and partition. * @@ -194,6 +232,11 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ def createRDD[ K: ClassTag, @@ -226,6 +269,12 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ def createRDD[ K: ClassTag, @@ -246,7 +295,7 @@ object KafkaUtils { // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker leaders.map { case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) - }.toMap + } } val cleanedHandler = sc.clean(messageHandler) checkOffsets(kc, offsetRanges) @@ -263,6 +312,15 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @param keyClass type of Kafka message key + * @param valueClass type of Kafka message value + * @param keyDecoderClass type of Kafka message key decoder + * @param valueDecoderClass type of Kafka message value decoder + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jsc: JavaSparkContext, @@ -296,6 +354,12 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jsc: JavaSparkContext, @@ -348,6 +412,12 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ def createDirectStream[ K: ClassTag, @@ -394,6 +464,11 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createDirectStream[ K: ClassTag, @@ -406,23 +481,9 @@ object KafkaUtils { ): InputDStream[(K, V)] = { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) val kc = new KafkaCluster(kafkaParams) - val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) - - val result = for { - topicPartitions <- kc.getPartitions(topics).right - leaderOffsets <- (if (reset == Some("smallest")) { - kc.getEarliestLeaderOffsets(topicPartitions) - } else { - kc.getLatestLeaderOffsets(topicPartitions) - }).right - } yield { - val fromOffsets = leaderOffsets.map { case (tp, lo) => - (tp, lo.offset) - } - new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( - ssc, kafkaParams, fromOffsets, messageHandler) - } - KafkaCluster.checkErrors(result) + val fromOffsets = getFromOffsets(kc, kafkaParams, topics) + new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( + ssc, kafkaParams, fromOffsets, messageHandler) } /** @@ -459,6 +520,12 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jssc: JavaStreamingContext, @@ -518,6 +585,11 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jssc: JavaStreamingContext, @@ -543,13 +615,15 @@ object KafkaUtils { /** * This is a helper class that wraps the KafkaUtils.createStream() into more * Python-friendly class and function so that it can be easily - * instantiated and called from Python's KafkaUtils (see SPARK-6027). + * instantiated and called from Python's KafkaUtils. * * The zero-arg constructor helps instantiate this class from the Class object * classOf[KafkaUtilsPythonHelper].newInstance(), and the createStream() * takes care of known parameters instead of passing them from Python */ private[kafka] class KafkaUtilsPythonHelper { + import KafkaUtilsPythonHelper._ + def createStream( jssc: JavaStreamingContext, kafkaParams: JMap[String, String], @@ -566,86 +640,92 @@ private[kafka] class KafkaUtilsPythonHelper { storageLevel) } - def createRDD( + def createRDDWithoutMessageHandler( jsc: JavaSparkContext, kafkaParams: JMap[String, String], offsetRanges: JList[OffsetRange], - leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { - val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], - (Array[Byte], Array[Byte])] { - def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = - (t1.key(), t1.message()) - } + leaders: JMap[TopicAndPartition, Broker]): JavaRDD[(Array[Byte], Array[Byte])] = { + val messageHandler = + (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message) + new JavaRDD(createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler)) + } - val jrdd = KafkaUtils.createRDD[ - Array[Byte], - Array[Byte], - DefaultDecoder, - DefaultDecoder, - (Array[Byte], Array[Byte])]( - jsc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - classOf[(Array[Byte], Array[Byte])], - kafkaParams, - offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), - leaders, - messageHandler - ) - new JavaPairRDD(jrdd.rdd) + def createRDDWithMessageHandler( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaRDD[Array[Byte]] = { + val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => + new PythonMessageAndMetadata( + mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message()) + val rdd = createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler). + mapPartitions(picklerIterator) + new JavaRDD(rdd) } - def createDirectStream( + private def createRDD[V: ClassTag]( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker], + messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): RDD[V] = { + KafkaUtils.createRDD[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V]( + jsc.sc, + kafkaParams.asScala.toMap, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders.asScala.toMap, + messageHandler + ) + } + + def createDirectStreamWithoutMessageHandler( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = { + val messageHandler = + (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message) + new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler)) + } + + def createDirectStreamWithMessageHandler( jssc: JavaStreamingContext, kafkaParams: JMap[String, String], topics: JSet[String], - fromOffsets: JMap[TopicAndPartition, JLong] - ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = { + val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => + new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message()) + val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler). + mapPartitions(picklerIterator) + new JavaDStream(stream) + } - if (!fromOffsets.isEmpty) { + private def createDirectStream[V: ClassTag]( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong], + messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = { + + val currentFromOffsets = if (!fromOffsets.isEmpty) { val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic) if (topicsFromOffsets != topics.asScala.toSet) { throw new IllegalStateException( s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") } - } - - if (fromOffsets.isEmpty) { - KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( - jssc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - kafkaParams, - topics) + Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*) } else { - val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], - (Array[Byte], Array[Byte])] { - def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = - (t1.key(), t1.message()) - } - - val jstream = KafkaUtils.createDirectStream[ - Array[Byte], - Array[Byte], - DefaultDecoder, - DefaultDecoder, - (Array[Byte], Array[Byte])]( - jssc, - classOf[Array[Byte]], - classOf[Array[Byte]], - classOf[DefaultDecoder], - classOf[DefaultDecoder], - classOf[(Array[Byte], Array[Byte])], - kafkaParams, - fromOffsets, - messageHandler) - new JavaPairInputDStream(jstream.inputDStream) + val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*)) + KafkaUtils.getFromOffsets( + kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*)) } + + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V]( + jssc.ssc, + Map(kafkaParams.asScala.toSeq: _*), + Map(currentFromOffsets.toSeq: _*), + messageHandler) } def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong @@ -669,3 +749,57 @@ private[kafka] class KafkaUtilsPythonHelper { kafkaRDD.offsetRanges.toSeq.asJava } } + +private object KafkaUtilsPythonHelper { + private var initialized = false + + def initialize(): Unit = { + SerDeUtil.initialize() + synchronized { + if (!initialized) { + new PythonMessageAndMetadataPickler().register() + initialized = true + } + } + } + + initialize() + + def picklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = { + new SerDeUtil.AutoBatchedPickler(iter) + } + + case class PythonMessageAndMetadata( + topic: String, + partition: JInt, + offset: JLong, + key: Array[Byte], + message: Array[Byte]) + + class PythonMessageAndMetadataPickler extends IObjectPickler { + private val module = "pyspark.streaming.kafka" + + def register(): Unit = { + Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this) + Pickler.registerCustomPickler(this.getClass, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler) { + if (obj == this) { + out.write(Opcodes.GLOBAL) + out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(StandardCharsets.UTF_8)) + } else { + pickler.save(this) + val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata] + out.write(Opcodes.MARK) + pickler.save(msgAndMetaData.topic) + pickler.save(msgAndMetaData.partition) + pickler.save(msgAndMetaData.offset) + pickler.save(msgAndMetaData.key) + pickler.save(msgAndMetaData.message) + out.write(Opcodes.TUPLE) + out.write(Opcodes.REDUCE) + } + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 8a5f371494511..d9b856e4697a0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import kafka.common.TopicAndPartition /** - * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the + * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the * offset ranges in RDDs generated by the direct Kafka DStream (see * [[KafkaUtils.createDirectStream()]]). * {{{ diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index 764d170934aa6..39abe3c3e29d0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -18,10 +18,10 @@ package org.apache.spark.streaming.kafka import java.util.Properties -import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap} +import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor} -import scala.collection.{Map, mutable} -import scala.reflect.{ClassTag, classTag} +import scala.collection.{mutable, Map} +import scala.reflect.{classTag, ClassTag} import kafka.common.TopicAndPartition import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} @@ -30,7 +30,8 @@ import kafka.serializer.Decoder import kafka.utils.{VerifiableProperties, ZKGroupTopicDirs, ZKStringSerializer, ZkUtils} import org.I0Itec.zkclient.ZkClient -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.ThreadUtils diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index fbdfbf7e509b3..fa6b0dbc8c219 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -130,17 +131,15 @@ public String call(MessageAndMetadata msgAndMd) { JavaDStream unifiedStream = stream1.union(stream2); final Set result = Collections.synchronizedSet(new HashSet()); - unifiedStream.foreachRDD( - new Function, Void>() { + unifiedStream.foreachRDD(new VoidFunction>() { @Override - public Void call(JavaRDD rdd) { + public void call(JavaRDD rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() ); } - return null; } } ); @@ -169,7 +168,7 @@ private static Map topicOffsetToMap(String topic, Long private String[] createTopicAndSendData(String topic) { String[] data = { topic + "-1", topic + "-2", topic + "-3"}; - kafkaTestUtils.createTopic(topic); + kafkaTestUtils.createTopic(topic, 1); kafkaTestUtils.sendMessages(topic, data); return data; } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index afcc6cfccd39a..c41b6297b0481 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -149,7 +149,7 @@ public String call(MessageAndMetadata msgAndMd) { private String[] createTopicAndSendData(String topic) { String[] data = { topic + "-1", topic + "-2", topic + "-3"}; - kafkaTestUtils.createTopic(topic); + kafkaTestUtils.createTopic(topic, 1); kafkaTestUtils.sendMessages(topic, data); return data; } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 1e69de46cd35d..868df64e8c944 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; @@ -75,7 +76,7 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - kafkaTestUtils.createTopic(topic); + kafkaTestUtils.createTopic(topic, 1); kafkaTestUtils.sendMessages(topic, sent); Map kafkaParams = new HashMap<>(); @@ -103,10 +104,9 @@ public String call(Tuple2 tuple2) { } ); - words.countByValue().foreachRDD( - new Function, Void>() { + words.countByValue().foreachRDD(new VoidFunction>() { @Override - public Void call(JavaPairRDD rdd) { + public void call(JavaPairRDD rdd) { List> ret = rdd.collect(); for (Tuple2 r : ret) { if (result.containsKey(r._1())) { @@ -115,8 +115,6 @@ public Void call(JavaPairRDD rdd) { result.put(r._1(), r._2()); } } - - return null; } } ); diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 75e3b53a093f6..fd51f8faf56b9 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 02225d5aa7cc5..f14ff6705fd97 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -18,13 +18,11 @@ package org.apache.spark.streaming.kafka import java.io.File +import java.util.Arrays import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.ConcurrentLinkedQueue -import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.rate.RateEstimator - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps @@ -34,11 +32,14 @@ import kafka.serializer.StringDecoder import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils class DirectKafkaStreamSuite @@ -102,8 +103,7 @@ class DirectKafkaStreamSuite ssc, kafkaParams, topics) } - val allReceived = - new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] + val allReceived = new ConcurrentLinkedQueue[(String, String)]() // hold a reference to the current offset ranges, so it can be used downstream var offsetRanges = Array[OffsetRange]() @@ -132,11 +132,12 @@ class DirectKafkaStreamSuite assert(partSize === rangeSize, "offset ranges are wrong") } } - stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + stream.foreachRDD { rdd => allReceived.addAll(Arrays.asList(rdd.collect(): _*)) } ssc.start() eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { assert(allReceived.size === totalSent, - "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + "didn't get expected number of messages, messages:\n" + + allReceived.asScala.mkString("\n")) } ssc.stop() } @@ -174,8 +175,8 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] - stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } + val collectedData = new ConcurrentLinkedQueue[String]() + stream.map { _._2 }.foreachRDD { rdd => collectedData.addAll(Arrays.asList(rdd.collect(): _*)) } ssc.start() val newData = Map("b" -> 10) kafkaTestUtils.sendMessages(topic, newData) @@ -220,8 +221,8 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] - stream.foreachRDD { rdd => collectedData ++= rdd.collect() } + val collectedData = new ConcurrentLinkedQueue[String]() + stream.foreachRDD { rdd => collectedData.addAll(Arrays.asList(rdd.collect(): _*)) } ssc.start() val newData = Map("b" -> 10) kafkaTestUtils.sendMessages(topic, newData) @@ -266,7 +267,7 @@ class DirectKafkaStreamSuite // This is to collect the raw data received from Kafka kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => val data = rdd.map { _._2 }.collect() - DirectKafkaStreamSuite.collectedData.appendAll(data) + DirectKafkaStreamSuite.collectedData.addAll(Arrays.asList(data: _*)) } // This is ensure all the data is eventually receiving only once @@ -336,14 +337,14 @@ class DirectKafkaStreamSuite ssc, kafkaParams, Set(topic)) } - val allReceived = - new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] + val allReceived = new ConcurrentLinkedQueue[(String, String)] - stream.foreachRDD { rdd => allReceived ++= rdd.collect() } + stream.foreachRDD { rdd => allReceived.addAll(Arrays.asList(rdd.collect(): _*)) } ssc.start() eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { assert(allReceived.size === totalSent, - "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n")) + "didn't get expected number of messages, messages:\n" + + allReceived.asScala.mkString("\n")) // Calculate all the record number collected in the StreamingListener. assert(collector.numRecordsSubmitted.get() === totalSent) @@ -353,10 +354,38 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with backpressure disabled") { + val topic = "maxMessagesPerPartition" + val kafkaStream = getDirectKafkaStream(topic, None) + + val input = Map(TopicAndPartition(topic, 0) -> 50L, TopicAndPartition(topic, 1) -> 50L) + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(TopicAndPartition(topic, 0) -> 10L, TopicAndPartition(topic, 1) -> 10L)) + } + + test("maxMessagesPerPartition with no lag") { + val topic = "maxMessagesPerPartition" + val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100)) + val kafkaStream = getDirectKafkaStream(topic, rateController) + + val input = Map(TopicAndPartition(topic, 0) -> 0L, TopicAndPartition(topic, 1) -> 0L) + assert(kafkaStream.maxMessagesPerPartition(input).isEmpty) + } + + test("maxMessagesPerPartition respects max rate") { + val topic = "maxMessagesPerPartition" + val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000)) + val kafkaStream = getDirectKafkaStream(topic, rateController) + + val input = Map(TopicAndPartition(topic, 0) -> 1000L, TopicAndPartition(topic, 1) -> 1000L) + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(TopicAndPartition(topic, 0) -> 10L, TopicAndPartition(topic, 1) -> 10L)) + } + test("using rate controller") { val topic = "backpressure" - val topicPartition = TopicAndPartition(topic, 0) - kafkaTestUtils.createTopic(topic) + val topicPartitions = Set(TopicAndPartition(topic, 0), TopicAndPartition(topic, 1)) + kafkaTestUtils.createTopic(topic, 2) val kafkaParams = Map( "metadata.broker.list" -> kafkaTestUtils.brokerAddress, "auto.offset.reset" -> "smallest" @@ -364,8 +393,8 @@ class DirectKafkaStreamSuite val batchIntervalMilliseconds = 100 val estimator = new ConstantEstimator(100) - val messageKeys = (1 to 200).map(_.toString) - val messages = messageKeys.map((_, 1)).toMap + val messages = Map("foo" -> 200) + kafkaTestUtils.sendMessages(topic, messages) val sparkConf = new SparkConf() // Safe, even with streaming, because we're using the direct API. @@ -380,43 +409,41 @@ class DirectKafkaStreamSuite val kafkaStream = withClue("Error creating direct stream") { val kc = new KafkaCluster(kafkaParams) val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) - val m = kc.getEarliestLeaderOffsets(Set(topicPartition)) + val m = kc.getEarliestLeaderOffsets(topicPartitions) .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( - ssc, kafkaParams, m, messageHandler) { + ssc, kafkaParams, m, messageHandler) { override protected[streaming] val rateController = Some(new DirectKafkaRateController(id, estimator)) } } - val collectedData = - new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]] + val collectedData = new ConcurrentLinkedQueue[Array[String]]() // Used for assertion failure messages. def dataToString: String = - collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") + collectedData.asScala.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") // This is to collect the raw data received from Kafka kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => val data = rdd.map { _._2 }.collect() - collectedData += data + collectedData.add(data) } ssc.start() // Try different rate limits. - // Send data to Kafka and wait for arrays of data to appear matching the rate. + // Wait for arrays of data to appear matching the rate. Seq(100, 50, 20).foreach { rate => collectedData.clear() // Empty this buffer on each pass. estimator.updateRate(rate) // Set a new rate. // Expect blocks of data equal to "rate", scaled by the interval length in secs. val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001) - kafkaTestUtils.sendMessages(topic, messages) eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) { // Assert that rate estimator values are used to determine maxMessagesPerPartition. // Funky "-" in message makes the complete assertion message read better. - assert(collectedData.exists(_.size == expectedSize), + assert(collectedData.asScala.exists(_.size == expectedSize), s" - No arrays of size $expectedSize for rate $rate found in $dataToString") } } @@ -431,10 +458,29 @@ class DirectKafkaStreamSuite rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges }.toSeq.sortBy { _._1 } } + + private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = { + val batchIntervalMilliseconds = 100 + + val sparkConf = new SparkConf() + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + + val earliestOffsets = Map(TopicAndPartition(topic, 0) -> 0L, TopicAndPartition(topic, 1) -> 0L) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, Map[String, String](), earliestOffsets, messageHandler) { + override protected[streaming] val rateController = mockRateController + } + } } object DirectKafkaStreamSuite { - val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] + val collectedData = new ConcurrentLinkedQueue[String]() @volatile var total = -1L class InputInfoCollector extends StreamingListener { @@ -470,3 +516,8 @@ private[streaming] class ConstantEstimator(@volatile private var rate: Long) schedulingDelay: Long): Option[Double] = Some(rate) } +private[streaming] class ConstantRateController(id: Int, estimator: RateEstimator, rate: Long) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = () + override def getLatestRate(): Long = rate +} diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index f52a738afd65b..5e539c1d790cc 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming.kafka import scala.util.Random -import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata +import kafka.serializer.StringDecoder import org.scalatest.BeforeAndAfterAll import org.apache.spark._ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 797b07f80d8ee..6a35ac14a8f6f 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -65,19 +65,20 @@ class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long] + val result = new mutable.HashMap[String, Long]() stream.map(_._2).countByValue().foreachRDD { r => - val ret = r.collect() - ret.toMap.foreach { kv => - val count = result.getOrElseUpdate(kv._1, 0) + kv._2 - result.put(kv._1, count) + r.collect().foreach { kv => + result.synchronized { + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } } } ssc.start() eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert(sent === result) + assert(result.synchronized { sent === result }) } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 80e2df62de3fe..7b9aee39ffb76 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -50,7 +50,7 @@ class ReliableKafkaStreamSuite extends SparkFunSuite private var ssc: StreamingContext = _ private var tempDirectory: File = null - override def beforeAll() : Unit = { + override def beforeAll(): Unit = { kafkaTestUtils = new KafkaTestUtils kafkaTestUtils.setup() diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml new file mode 100644 index 0000000000000..d1c38c7ca5d69 --- /dev/null +++ b/external/kinesis-asl-assembly/pom.xml @@ -0,0 +1,181 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kinesis-asl-assembly_2.11 + jar + Spark Project Kinesis Assembly + http://spark.apache.org/ + + + streaming-kinesis-asl-assembly + + + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + com.fasterxml.jackson.core + jackson-databind + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + com.sun.jersey + jersey-server + provided + + + com.sun.jersey + jersey-core + provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-ipc + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml new file mode 100644 index 0000000000000..935155eb5d362 --- /dev/null +++ b/external/kinesis-asl/pom.xml @@ -0,0 +1,87 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + + org.apache.spark + spark-streaming-kinesis-asl_2.11 + jar + Spark Kinesis Integration + + + streaming-kinesis-asl + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + + + com.amazonaws + amazon-kinesis-client + ${aws.kinesis.client.version} + + + com.amazonaws + amazon-kinesis-producer + ${aws.kinesis.producer.version} + test + + + org.mockito + mockito-core + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java similarity index 94% rename from extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java rename to external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index 06e0ff28afd95..0e43e9272d7c3 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -16,7 +16,10 @@ */ package org.apache.spark.examples.streaming; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.regex.Pattern; @@ -38,7 +41,6 @@ import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; import com.amazonaws.services.kinesis.AmazonKinesisClient; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import com.google.common.collect.Lists; /** * Consumes messages from a Amazon Kinesis streams and does wordcount. @@ -134,11 +136,12 @@ public static void main(String[] args) { JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); // Create the Kinesis DStreams - List> streamsList = new ArrayList>(numStreams); + List> streamsList = new ArrayList<>(numStreams); for (int i = 0; i < numStreams; i++) { streamsList.add( KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2()) + InitialPositionInStream.LATEST, kinesisCheckpointInterval, + StorageLevel.MEMORY_AND_DISK_2()) ); } @@ -154,8 +157,9 @@ public static void main(String[] args) { // Convert each line of Array[Byte] to String, and split into words JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { @Override - public Iterable call(byte[] line) { - return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + public Iterator call(byte[] line) { + String s = new String(line, StandardCharsets.UTF_8); + return Arrays.asList(WORD_SEPARATOR.split(s)).iterator(); } }); @@ -164,7 +168,7 @@ public Iterable call(byte[] line) { new PairFunction() { @Override public Tuple2 call(String s) { - return new Tuple2(s, 1); + return new Tuple2<>(s, 1); } } ).reduceByKey( diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py similarity index 94% rename from extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py rename to external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py index f428f64da3c42..4d7fc9a549bfb 100644 --- a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py +++ b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -34,9 +34,9 @@ $ export AWS_SECRET_KEY= # run the example - $ bin/spark-submit -jar extras/kinesis-asl/target/scala-*/\ + $ bin/spark-submit -jar external/kinesis-asl/target/scala-*/\ spark-streaming-kinesis-asl-assembly_*.jar \ - extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com There is a companion helper class called KinesisWordProducerASL which puts dummy data @@ -54,6 +54,8 @@ See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on the Kinesis Spark Streaming integration. """ +from __future__ import print_function + import sys from pyspark import SparkContext diff --git a/external/kinesis-asl/src/main/resources/log4j.properties b/external/kinesis-asl/src/main/resources/log4j.properties new file mode 100644 index 0000000000000..8118d12c5d474 --- /dev/null +++ b/external/kinesis-asl/src/main/resources/log4j.properties @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +log4j.rootCategory=WARN, console + +# File appender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Console appender +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.out +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.spark_project.jetty=WARN +log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala similarity index 98% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala rename to external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index de749626ec09c..859fe9edb44fc 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -22,14 +22,15 @@ import java.nio.ByteBuffer import scala.util.Random -import com.amazonaws.auth.{DefaultAWSCredentialsProviderChain, BasicAWSCredentials} +import com.amazonaws.auth.{BasicAWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.PutRecordRequest import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.streaming.dstream.DStream.toPairDStreamFunctions diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala similarity index 92% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala rename to external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 000897a4e7290..45dc3c388cb8d 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -23,9 +23,11 @@ import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord import com.amazonaws.services.kinesis.model._ import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition} import org.apache.spark.storage.BlockId import org.apache.spark.util.NextIterator @@ -69,26 +71,26 @@ class KinesisBackedBlockRDDPartition( */ private[kinesis] class KinesisBackedBlockRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, val regionName: String, val endpointUrl: String, - @transient blockIds: Array[BlockId], + @transient private val _blockIds: Array[BlockId], @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], - @transient isBlockIdValid: Array[Boolean] = Array.empty, + @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, val awsCredentialsOption: Option[SerializableAWSCredentials] = None - ) extends BlockRDD[T](sc, blockIds) { + ) extends BlockRDD[T](sc, _blockIds) { - require(blockIds.length == arrayOfseqNumberRanges.length, + require(_blockIds.length == arrayOfseqNumberRanges.length, "Number of blockIds is not equal to the number of sequence number ranges") override def isValid(): Boolean = true override def getPartitions: Array[Partition] = { - Array.tabulate(blockIds.length) { i => + Array.tabulate(_blockIds.length) { i => val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) - new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i)) + new KinesisBackedBlockRDDPartition(i, _blockIds(i), isValid, arrayOfseqNumberRanges(i)) } } @@ -210,7 +212,10 @@ class KinesisSequenceRangeIterator( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) } - (getRecordsResult.getRecords.iterator().asScala, getRecordsResult.getNextShardIterator) + // De-aggregate records, if KPL was used in producing the records. The KCL automatically + // handles de-aggregation during regular operation. This code path is used during recovery + val recordIterator = UserRecord.deaggregate(getRecordsResult.getRecords) + (recordIterator.iterator().asScala, getRecordsResult.getNextShardIterator) } /** diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala new file mode 100644 index 0000000000000..70b5cc7ca0e8e --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.concurrent._ + +import scala.util.control.NonFatal + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason + +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} + +/** + * This is a helper class for managing Kinesis checkpointing. + * + * @param receiver The receiver that keeps track of which sequence numbers we can checkpoint + * @param checkpointInterval How frequently we will checkpoint to DynamoDB + * @param workerId Worker Id of KCL worker for logging purposes + * @param clock In order to use ManualClocks for the purpose of testing + */ +private[kinesis] class KinesisCheckpointer( + receiver: KinesisReceiver[_], + checkpointInterval: Duration, + workerId: String, + clock: Clock = new SystemClock) extends Logging { + + // a map from shardId's to checkpointers + private val checkpointers = new ConcurrentHashMap[String, IRecordProcessorCheckpointer]() + + private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]() + + private val checkpointerThread: RecurringTimer = startCheckpointerThread() + + /** Update the checkpointer instance to the most recent one for the given shardId. */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + checkpointers.put(shardId, checkpointer) + } + + /** + * Stop tracking the specified shardId. + * + * If a checkpointer is provided, e.g. on IRecordProcessor.shutdown [[ShutdownReason.TERMINATE]], + * we will use that to make the final checkpoint. If `null` is provided, we will not make the + * checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]]. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + synchronized { + checkpointers.remove(shardId) + checkpoint(shardId, checkpointer) + } + } + + /** Perform the checkpoint. */ + private def checkpoint(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + try { + if (checkpointer != null) { + receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => + val lastSeqNum = lastCheckpointedSeqNums.get(shardId) + // Kinesis sequence numbers are monotonically increasing strings, therefore we can do + // safely do the string comparison + if (lastSeqNum == null || latestSeqNum > lastSeqNum) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint at sequence number" + + s" $latestSeqNum for shardId $shardId") + lastCheckpointedSeqNums.put(shardId, latestSeqNum) + } + } + } else { + logDebug(s"Checkpointing skipped for shardId $shardId. Checkpointer not set.") + } + } catch { + case NonFatal(e) => + logWarning(s"Failed to checkpoint shardId $shardId to DynamoDB.", e) + } + } + + /** Checkpoint the latest saved sequence numbers for all active shardId's. */ + private def checkpointAll(): Unit = synchronized { + // if this method throws an exception, then the scheduled task will not run again + try { + val shardIds = checkpointers.keys() + while (shardIds.hasMoreElements) { + val shardId = shardIds.nextElement() + checkpoint(shardId, checkpointers.get(shardId)) + } + } catch { + case NonFatal(e) => + logWarning("Failed to checkpoint to DynamoDB.", e) + } + } + + /** + * Start the checkpointer thread with the given checkpoint duration. + */ + private def startCheckpointerThread(): RecurringTimer = { + val period = checkpointInterval.milliseconds + val threadName = s"Kinesis Checkpointer - Worker $workerId" + val timer = new RecurringTimer(clock, period, _ => checkpointAll(), threadName) + timer.start() + logDebug(s"Started checkpointer thread: $threadName") + timer + } + + /** + * Shutdown the checkpointer. Should be called on the onStop of the Receiver. + */ + def shutdown(): Unit = { + // the recurring timer checkpoints for us one last time. + checkpointerThread.stop(interruptTimer = false) + checkpointers.clear() + lastCheckpointedSeqNums.clear() + logInfo("Successfully shutdown Kinesis Checkpointer.") + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala similarity index 98% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala rename to external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 72ab6357a53b0..5223c81a8e0e0 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -24,13 +24,13 @@ import com.amazonaws.services.kinesis.model.Record import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo -import org.apache.spark.streaming.{Duration, StreamingContext, Time} private[kinesis] class KinesisInputDStream[T: ClassTag]( - @transient _ssc: StreamingContext, + _ssc: StreamingContext, streamName: String, endpointUrl: String, regionName: String, diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala similarity index 81% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala rename to external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 134d627cdaffa..858368d135b6a 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -17,22 +17,22 @@ package org.apache.spark.streaming.kinesis import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.internal.Logging import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkEnv} - private[kinesis] case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) @@ -47,17 +47,18 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * https://github.com/awslabs/amazon-kinesis-client * * The way this Receiver works is as follows: - * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple - * KinesisRecordProcessor - * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is - * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. - * - When the block generator defines a block, then the recorded sequence number ranges that were - * inserted into the block are recorded separately for being used later. - * - When the block is ready to be pushed, the block is pushed and the ranges are reported as - * metadata of the block. In addition, the ranges are used to find out the latest sequence - * number for each shard that can be checkpointed through the DynamoDB. - * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence - * number for it own shard. + * + * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple + * KinesisRecordProcessor + * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is + * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. + * - When the block generator defines a block, then the recorded sequence number ranges that were + * inserted into the block are recorded separately for being used later. + * - When the block is ready to be pushed, the block is pushed and the ranges are reported as + * metadata of the block. In addition, the ranges are used to find out the latest sequence + * number for each shard that can be checkpointed through the DynamoDB. + * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence + * number for it own shard. * * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -124,14 +125,18 @@ private[kinesis] class KinesisReceiver[T]( private val seqNumRangesInCurrentBlock = new mutable.ArrayBuffer[SequenceNumberRange] /** Sequence number ranges of data added to each generated block */ - private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] - with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] + private val blockIdToSeqNumRanges = new ConcurrentHashMap[StreamBlockId, SequenceNumberRanges] + + /** + * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval. + */ + @volatile private var kinesisCheckpointer: KinesisCheckpointer = null /** * Latest sequence number ranges that have been stored successfully. * This is used for checkpointing through KCL */ - private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String] - with mutable.SynchronizedMap[String, String] + private val shardIdToLatestStoredSeqNum = new ConcurrentHashMap[String, String] + /** * This is called when the KinesisReceiver starts and must be non-blocking. * The KCL creates and manages the receiving/processing thread pool through Worker.run(). @@ -141,6 +146,7 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() + kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) // KCL config instance val awsCredProvider = resolveAWSCredentialsProvider() val kinesisClientLibConfiguration = @@ -157,8 +163,8 @@ private[kinesis] class KinesisReceiver[T]( * We're using our custom KinesisRecordProcessor in this case. */ val recordProcessorFactory = new IRecordProcessorFactory { - override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, - workerId, new KinesisCheckpointState(checkpointInterval)) + override def createProcessor: IRecordProcessor = + new KinesisRecordProcessor(receiver, workerId) } worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) @@ -179,6 +185,7 @@ private[kinesis] class KinesisReceiver[T]( workerThread.setName(s"Kinesis Receiver ${streamId}") workerThread.setDaemon(true) workerThread.start() + logInfo(s"Started receiver with workerId $workerId") } @@ -198,6 +205,10 @@ private[kinesis] class KinesisReceiver[T]( logInfo(s"Stopped receiver for workerId $workerId") } workerId = null + if (kinesisCheckpointer != null) { + kinesisCheckpointer.shutdown() + kinesisCheckpointer = null + } } /** Add records of the given shard to the current block being generated */ @@ -207,13 +218,31 @@ private[kinesis] class KinesisReceiver[T]( val metadata = SequenceNumberRange(streamName, shardId, records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) - } } /** Get the latest sequence number for the given shard that can be checkpointed through KCL */ private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = { - shardIdToLatestStoredSeqNum.get(shardId) + Option(shardIdToLatestStoredSeqNum.get(shardId)) + } + + /** + * Set the checkpointer that will be used to checkpoint sequence numbers to DynamoDB for the + * given shardId. + */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.setCheckpointer(shardId, checkpointer) + } + + /** + * Remove the checkpointer for the given shardId. The provided checkpointer will be used to + * checkpoint one last time for the given shard. If `checkpointer` is `null`, then we will not + * checkpoint. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.removeCheckpointer(shardId, checkpointer) } /** @@ -229,7 +258,7 @@ private[kinesis] class KinesisReceiver[T]( * for next block. Internally, this is synchronized with `rememberAddedRange()`. */ private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = { - blockIdToSeqNumRanges(blockId) = SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray) + blockIdToSeqNumRanges.put(blockId, SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray)) seqNumRangesInCurrentBlock.clear() logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges") } @@ -237,7 +266,7 @@ private[kinesis] class KinesisReceiver[T]( /** Store the block along with its associated ranges */ private def storeBlockWithRanges( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = { - val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId) + val rangesToReportOption = Option(blockIdToSeqNumRanges.remove(blockId)) if (rangesToReportOption.isEmpty) { stop("Error while storing block into Spark, could not find sequence number ranges " + s"for block $blockId") @@ -266,7 +295,7 @@ private[kinesis] class KinesisReceiver[T]( // Note that we are doing this sequentially because the array of sequence number ranges // is assumed to be rangesToReport.ranges.foreach { range => - shardIdToLatestStoredSeqNum(range.shardId) = range.toSeqNumber + shardIdToLatestStoredSeqNum.put(range.shardId, range.toSeqNumber) } } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala new file mode 100644 index 0000000000000..80e0cce055862 --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.List + +import scala.util.Random +import scala.util.control.NonFatal + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.Duration + +/** + * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. + * This implementation operates on the Array[Byte] from the KinesisReceiver. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. + * + * @param receiver Kinesis receiver + * @param workerId for logging purposes + */ +private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], workerId: String) + extends IRecordProcessor with Logging { + + // shardId populated during initialize() + @volatile + private var shardId: String = _ + + /** + * The Kinesis Client Library calls this method during IRecordProcessor initialization. + * + * @param shardId assigned by the KCL to this particular RecordProcessor. + */ + override def initialize(shardId: String) { + this.shardId = shardId + logInfo(s"Initialized workerId $workerId with shardId $shardId") + } + + /** + * This method is called by the KCL when a batch of records is pulled from the Kinesis stream. + * This is the record-processing bridge between the KCL's IRecordProcessor.processRecords() + * and Spark Streaming's Receiver.store(). + * + * @param batch list of records from the Kinesis stream shard + * @param checkpointer used to update Kinesis when this batch has been processed/stored + * in the DStream + */ + override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { + if (!receiver.isStopped()) { + try { + receiver.addRecords(shardId, batch) + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + receiver.setCheckpointer(shardId, checkpointer) + } catch { + case NonFatal(e) => + /* + * If there is a failure within the batch, the batch will not be checkpointed. + * This will potentially cause records since the last checkpoint to be processed + * more than once. + */ + logError(s"Exception: WorkerId $workerId encountered and exception while storing " + + s" or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ + throw e + } + } else { + /* RecordProcessor has been stopped. */ + logInfo(s"Stopped: KinesisReceiver has stopped for workerId $workerId" + + s" and shardId $shardId. No more records will be processed.") + } + } + + /** + * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: + * 1) the stream is resharding by splitting or merging adjacent shards + * (ShutdownReason.TERMINATE) + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * (ShutdownReason.ZOMBIE) + * + * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE + * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE) + */ + override def shutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason) { + logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason") + reason match { + /* + * TERMINATE Use Case. Checkpoint. + * Checkpoint to indicate that all records from the shard have been drained and processed. + * It's now OK to read from the new shards that resulted from a resharding event. + */ + case ShutdownReason.TERMINATE => + receiver.removeCheckpointer(shardId, checkpointer) + + /* + * ZOMBIE Use Case or Unknown reason. NoOp. + * No checkpoint because other workers may have taken over and already started processing + * the same records. + * This may lead to records being processed more than once. + */ + case _ => + receiver.removeCheckpointer(shardId, null) // return null so that we don't checkpoint + } + + } +} + +private[kinesis] object KinesisRecordProcessor extends Logging { + /** + * Retry the given amount of times with a random backoff time (millis) less than the + * given maxBackOffMillis + * + * @param expression expression to evaluate + * @param numRetriesLeft number of retries left + * @param maxBackOffMillis: max millis between retries + * + * @return evaluation of the given expression + * @throws Unretryable exception, unexpected exception, + * or any exception that persists after numRetriesLeft reaches 0 + */ + @annotation.tailrec + def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = { + util.Try { expression } match { + /* If the function succeeded, evaluate to x. */ + case util.Success(x) => x + /* If the function failed, either retry or throw the exception */ + case util.Failure(e) => e match { + /* Retry: Throttling or other Retryable exception has occurred */ + case _: ThrottlingException | _: KinesisClientLibDependencyException + if numRetriesLeft > 1 => + val backOffMillis = Random.nextInt(maxBackOffMillis) + Thread.sleep(backOffMillis) + logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) + retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) + /* Throw: Shutdown has been requested by the Kinesis Client Library. */ + case _: ShutdownException => + logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) + throw e + /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ + case _: InvalidStateException => + logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + + s" by the Amazon Kinesis Client Library. Table likely doesn't exist.", e) + throw e + /* Throw: Unexpected exception has occurred */ + case _ => + logError(s"Unexpected, non-retryable exception.", e) + throw e + } + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala similarity index 84% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala rename to external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 634bf94521079..0fe66254e989d 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -32,10 +33,12 @@ import com.amazonaws.services.dynamodbv2.document.DynamoDB import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** - * Shared utility methods for performing Kinesis tests that actually transfer data + * Shared utility methods for performing Kinesis tests that actually transfer data. + * + * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE! */ private[kinesis] class KinesisTestUtils extends Logging { @@ -52,7 +55,7 @@ private[kinesis] class KinesisTestUtils extends Logging { @volatile private var _streamName: String = _ - private lazy val kinesisClient = { + protected lazy val kinesisClient = { val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) client.setEndpoint(endpointUrl) client @@ -64,6 +67,14 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } + protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + throw new UnsupportedOperationException("Aggregation is not supported through this code path") + } + } + def streamName: String = { require(streamCreated, "Stream not yet created, call createStream() to create one") _streamName @@ -90,24 +101,10 @@ private[kinesis] class KinesisTestUtils extends Logging { * Push data to Kinesis stream and return a map of * shardId -> seq of (data, seq number) pushed to corresponding shard */ - def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") - val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() - - testData.foreach { num => - val str = num.toString - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(ByteBuffer.wrap(str.getBytes())) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - + val producer = getProducer(aggregate) + val shardIdToSeqNumbers = producer.sendData(streamName, testData) logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") shardIdToSeqNumbers.toMap } @@ -116,7 +113,7 @@ private[kinesis] class KinesisTestUtils extends Logging { * Expose a Python friendly API. */ def pushData(testData: java.util.List[Int]): Unit = { - pushData(testData.asScala) + pushData(testData.asScala, aggregate = false) } def deleteStream(): Unit = { @@ -233,3 +230,32 @@ private[kinesis] object KinesisTestUtils { } } } + +/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */ +private[kinesis] trait KinesisDataGenerator { + /** Sends the data to Kinesis and returns the metadata for everything that has been sent. */ + def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] +} + +private[kinesis] class SimpleDataGenerator( + client: AmazonKinesisClient) extends KinesisDataGenerator { + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes(StandardCharsets.UTF_8)) + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = client.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala similarity index 84% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala rename to external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2849fd8a82102..a0007d33d6257 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -24,9 +24,9 @@ import com.amazonaws.services.kinesis.model.Record import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, StreamingContext} import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.{Duration, StreamingContext} object KinesisUtils { /** @@ -221,50 +221,6 @@ object KinesisUtils { } } - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * Note: - * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets AWS credentials. - * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. - * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in - * [[org.apache.spark.SparkConf]]. - * - * @param ssc StreamingContext object - * @param streamName Kinesis stream name - * @param endpointUrl Endpoint url of Kinesis service - * (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - */ - @deprecated("use other forms of createStream", "1.4.0") - def createStream( - ssc: StreamingContext, - streamName: String, - endpointUrl: String, - checkpointInterval: Duration, - initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel - ): ReceiverInputDStream[Array[Byte]] = { - ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, - getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName, - checkpointInterval, storageLevel, defaultMessageHandler, None) - } - } - /** * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. @@ -452,47 +408,6 @@ object KinesisUtils { defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } - /** - * Create an input stream that pulls messages from a Kinesis stream. - * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. - * - * Note: - * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets AWS credentials. - * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. - * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in - * [[org.apache.spark.SparkConf]]. - * - * @param jssc Java StreamingContext object - * @param streamName Kinesis stream name - * @param endpointUrl Endpoint url of Kinesis service - * (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects - * StorageLevel.MEMORY_AND_DISK_2 is recommended. - */ - @deprecated("use other forms of createStream", "1.4.0") - def createStream( - jssc: JavaStreamingContext, - streamName: String, - endpointUrl: String, - checkpointInterval: Duration, - initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[Array[Byte]] = { - createStream( - jssc.ssc, streamName, endpointUrl, checkpointInterval, initialPositionInStream, storageLevel) - } - private def getRegionByEndpoint(endpointUrl: String): String = { RegionUtils.getRegionByEndpoint(endpointUrl).getName() } diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java similarity index 85% rename from extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java rename to external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java index 3f0f6793d2d21..f078973c6c285 100644 --- a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming.kinesis; +import com.amazonaws.regions.RegionUtils; import com.amazonaws.services.kinesis.model.Record; import org.junit.Test; @@ -28,19 +29,19 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import java.nio.ByteBuffer; - /** * Demonstrate the use of the KinesisUtils Java API */ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { @Test public void testKinesisStream() { - // Tests the API, does not actually test data receiving - JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()); + String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); + String dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName(); + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000), + StorageLevel.MEMORY_AND_DISK_2()); ssc.stop(); } diff --git a/external/kinesis-asl/src/test/resources/log4j.properties b/external/kinesis-asl/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..3706a6e361307 --- /dev/null +++ b/external/kinesis-asl/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala new file mode 100644 index 0000000000000..0b455e574e6fa --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} + +private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils { + override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + new KPLDataGenerator(regionName) + } + } +} + +/** A wrapper for the KinesisProducer provided in the KPL. */ +private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator { + + private lazy val producer: KPLProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KPLProducer(conf) + } + + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes(StandardCharsets.UTF_8)) + val future = producer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + Futures.addCallback(future, kinesisCallBack) + } + producer.flushSync() + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala similarity index 89% rename from extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala rename to external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 9f9e146a08d46..905c33834df16 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.streaming.kinesis -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfterEach +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} -import org.apache.spark.{SparkConf, SparkContext, SparkException} -class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { +abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) + extends KinesisFunSuite with BeforeAndAfterEach with LocalSparkContext { private val testData = 1 to 8 @@ -34,16 +35,15 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll private var shardIdToRange: Map[String, SequenceNumberRange] = null private var allRanges: Seq[SequenceNumberRange] = null - private var sc: SparkContext = null private var blockManager: BlockManager = null - override def beforeAll(): Unit = { + super.beforeAll() runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() - shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq @@ -55,19 +55,23 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll (shardId, seqNumRange) } allRanges = shardIdToRange.values.toSeq - - val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite") - sc = new SparkContext(conf) - blockManager = sc.env.blockManager } } + override def beforeEach(): Unit = { + super.beforeEach() + val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite") + sc = new SparkContext(conf) + blockManager = sc.env.blockManager + } + override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.deleteStream() - } - if (sc != null) { - sc.stop() + try { + if (testUtils != null) { + testUtils.deleteStream() + } + } finally { + super.afterAll() } } @@ -118,7 +122,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll testIsBlockValid = true) } - testIfEnabled("Test whether RDD is valid after removing blocks from block anager") { + testIfEnabled("Test whether RDD is valid after removing blocks from block manager") { testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2, testBlockRemove = true) } @@ -158,9 +162,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll testBlockRemove: Boolean = false ): Unit = { require(shardIds.size > 1, "Need at least 2 shards to test") - require(numPartitionsInBM <= shardIds.size , + require(numPartitionsInBM <= shardIds.size, "Number of partitions in BlockManager cannot be more than the Kinesis test shards available") - require(numPartitionsInKinesis <= shardIds.size , + require(numPartitionsInKinesis <= shardIds.size, "Number of partitions in Kinesis cannot be more than the Kinesis test shards available") require(numPartitionsInBM <= numPartitions, "Number of partitions in BlockManager cannot be more than that in RDD") @@ -247,3 +251,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll Array.tabulate(num) { i => new StreamBlockId(0, i) } } } + +class WithAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = true) + +class WithoutAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = false) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala new file mode 100644 index 0000000000000..e1499a8220991 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.util.concurrent.{ExecutorService, TimeoutException} + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.ManualClock + +class KinesisCheckpointerSuite extends TestSuiteBase + with MockitoSugar + with BeforeAndAfterEach + with PrivateMethodTester + with Eventually { + + private val workerId = "dummyWorkerId" + private val shardId = "dummyShardId" + private val seqNum = "123" + private val otherSeqNum = "245" + private val checkpointInterval = Duration(10) + private val someSeqNum = Some(seqNum) + private val someOtherSeqNum = Some(otherSeqNum) + + private var receiverMock: KinesisReceiver[Array[Byte]] = _ + private var checkpointerMock: IRecordProcessorCheckpointer = _ + private var kinesisCheckpointer: KinesisCheckpointer = _ + private var clock: ManualClock = _ + + private val checkpoint = PrivateMethod[Unit]('checkpoint) + + override def beforeEach(): Unit = { + receiverMock = mock[KinesisReceiver[Array[Byte]]] + checkpointerMock = mock[IRecordProcessorCheckpointer] + clock = new ManualClock() + kinesisCheckpointer = new KinesisCheckpointer(receiverMock, checkpointInterval, workerId, clock) + } + + test("checkpoint is not called twice for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("checkpoint is called after sequence number increases") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(5 * checkpointInterval.milliseconds) + + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds / 2) + + verify(checkpointerMock, never()).checkpoint(anyString()) + } + + test("should not checkpoint for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + + clock.advance(checkpointInterval.milliseconds * 5) + eventually(timeout(1 second)) { + verify(checkpointerMock, atMost(1)).checkpoint(anyString()) + } + } + + test("removing checkpointer checkpoints one last time") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock) + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("if checkpointing is going on, wait until finished before removing and checkpointing") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2) + } + }) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + // don't block test thread + val f = Future(kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock))( + ExecutionContext.global) + + intercept[TimeoutException] { + Await.ready(f, 50 millis) + } + + clock.advance(checkpointInterval.milliseconds / 2) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(2)).checkpoint(anyString()) + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala similarity index 98% rename from extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala rename to external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala index ee428f31d6ce3..1c81298a7c201 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -40,7 +40,7 @@ trait KinesisFunSuite extends SparkFunSuite { if (shouldRunTests) { body } else { - ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")() + ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")(()) } } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala new file mode 100644 index 0000000000000..deac9090e2f48 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Arrays + +import com.amazonaws.services.kinesis.clientlibrary.exceptions._ +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record +import org.mockito.Matchers._ +import org.mockito.Matchers.{eq => meq} +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.Utils + +/** + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + */ +class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter + with MockitoSugar { + + val app = "TestKinesisReceiver" + val stream = "mySparkStream" + val endpoint = "endpoint-url" + val workerId = "dummyWorkerId" + val shardId = "dummyShardId" + val seqNum = "dummySeqNum" + val checkpointInterval = Duration(10) + val someSeqNum = Some(seqNum) + + val record1 = new Record() + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) + val record2 = new Record() + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) + val batch = Arrays.asList(record1, record2) + + var receiverMock: KinesisReceiver[Array[Byte]] = _ + var checkpointerMock: IRecordProcessorCheckpointer = _ + + override def beforeFunction(): Unit = { + receiverMock = mock[KinesisReceiver[Array[Byte]]] + checkpointerMock = mock[IRecordProcessorCheckpointer] + } + + test("check serializability of SerializableAWSCredentials") { + Utils.deserialize[SerializableAWSCredentials]( + Utils.serialize(new SerializableAWSCredentials("x", "y"))) + } + + test("process records including store and set checkpointer") { + when(receiverMock.isStopped()).thenReturn(false) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).addRecords(shardId, batch) + verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock) + } + + test("shouldn't store and update checkpointer when receiver is stopped") { + when(receiverMock.isStopped()).thenReturn(true) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record])) + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) + } + + test("shouldn't update checkpointer when exception occurs during store") { + when(receiverMock.isStopped()).thenReturn(false) + when( + receiverMock.addRecords(anyString, anyListOf(classOf[Record])) + ).thenThrow(new RuntimeException()) + + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) + recordProcessor.processRecords(batch, checkpointerMock) + } + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).addRecords(shardId, batch) + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) + } + + test("shutdown should checkpoint if the reason is TERMINATE") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) + + verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock)) + } + + + test("shutdown should not checkpoint if the reason is something other than TERMINATE") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + + verify(receiverMock, times(2)).removeCheckpointer(meq(shardId), + meq[IRecordProcessorCheckpointer](null)) + } + + test("retry success on first attempt") { + val expectedIsStopped = false + when(receiverMock.isStopped()).thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(1)).isStopped() + } + + test("retry success on second attempt after a Kinesis throttling exception") { + val expectedIsStopped = false + when(receiverMock.isStopped()) + .thenThrow(new ThrottlingException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() + } + + test("retry success on second attempt after a Kinesis dependency exception") { + val expectedIsStopped = false + when(receiverMock.isStopped()) + .thenThrow(new KinesisClientLibDependencyException("error message")) + .thenReturn(expectedIsStopped) + + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + + verify(receiverMock, times(2)).isStopped() + } + + test("retry failed after a shutdown exception") { + when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message")) + + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + + verify(checkpointerMock, times(1)).checkpoint() + } + + test("retry failed after an invalid state exception") { + when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message")) + + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + + verify(checkpointerMock, times(1)).checkpoint() + } + + test("retry failed after unexpected exception") { + when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message")) + + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + + verify(checkpointerMock, times(1)).checkpoint() + } + + test("retry failed after exhausting all retries") { + val expectedErrorMessage = "final try error message" + when(checkpointerMock.checkpoint()) + .thenThrow(new ThrottlingException("error message")) + .thenThrow(new ThrottlingException(expectedErrorMessage)) + + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + exception.getMessage().shouldBe(expectedErrorMessage) + + verify(checkpointerMock, times(2)).checkpoint() + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala similarity index 79% rename from extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala rename to external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ba84e557dfcc2..0e71bf9b84332 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -25,10 +25,12 @@ import scala.util.Random import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ @@ -37,9 +39,8 @@ import org.apache.spark.streaming.kinesis.KinesisTestUtils._ import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext} -class KinesisStreamSuite extends KinesisFunSuite +abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { // This is the name that KCL will use to save metadata to DynamoDB @@ -63,7 +64,7 @@ class KinesisStreamSuite extends KinesisFunSuite sc = new SparkContext(conf) runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() } } @@ -98,14 +99,10 @@ class KinesisStreamSuite extends KinesisFunSuite } test("KinesisUtils API") { - // Tests the API, does not actually test data receiving - val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", - dummyEndpointUrl, Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + val kinesisStream1 = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) - val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey) @@ -136,8 +133,8 @@ class KinesisStreamSuite extends KinesisFunSuite // Verify that the generated KinesisBackedBlockRDD has the all the right information val blockInfos = Seq(blockInfo1, blockInfo2) val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos) - nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] - val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] + nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[_]] + val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[_]] assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) @@ -153,7 +150,9 @@ class KinesisStreamSuite extends KinesisFunSuite // Verify that KinesisBackedBlockRDD is generated even when there are no blocks val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty) - emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] + // Verify it's KinesisBackedBlockRDD[_] rather than KinesisBackedBlockRDD[Array[Byte]], because + // the type parameter will be erased at runtime + emptyRDD shouldBe a [KinesisBackedBlockRDD[_]] emptyRDD.partitions shouldBe empty // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid @@ -179,43 +178,49 @@ class KinesisStreamSuite extends KinesisFunSuite Seconds(10), StorageLevel.MEMORY_ONLY, awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) - val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => - collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + collected.synchronized { + collected ++= rdd.collect() + logInfo("Collected = " + collected.mkString(", ")) + } } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { - testUtils.pushData(testData) - assert(collected === testData.toSet, "\nData received does not match data sent") + testUtils.pushData(testData, aggregateTestData) + assert(collected.synchronized { collected === testData.toSet }, + "\nData received does not match data sent") } ssc.stop(stopSparkContext = false) } testIfEnabled("custom message handling") { val awsCredentials = KinesisTestUtils.getAWSCredentials() - def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5 + def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, addFive, awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) - stream shouldBe a [ReceiverInputDStream[Int]] + stream shouldBe a [ReceiverInputDStream[_]] - val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + val collected = new mutable.HashSet[Int] stream.foreachRDD { rdd => - collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + collected.synchronized { + collected ++= rdd.collect() + logInfo("Collected = " + collected.mkString(", ")) + } } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { - testUtils.pushData(testData) + testUtils.pushData(testData, aggregateTestData) val modData = testData.map(_ + 5) - assert(collected === modData.toSet, "\nData received does not match data sent") + assert(collected.synchronized { collected === modData.toSet }, + "\nData received does not match data sent") } ssc.stop(stopSparkContext = false) } @@ -229,7 +234,6 @@ class KinesisStreamSuite extends KinesisFunSuite val awsCredentials = KinesisTestUtils.getAWSCredentials() val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] - with mutable.SynchronizedMap[Time, (Array[SequenceNumberRanges], Seq[Int])] val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, @@ -240,13 +244,16 @@ class KinesisStreamSuite extends KinesisFunSuite kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq - collectedData(time) = (kRdd.arrayOfseqNumberRanges, data) + collectedData.synchronized { + collectedData(time) = (kRdd.arrayOfseqNumberRanges, data) + } }) ssc.remember(Minutes(60)) // remember all the batches so that they are all saved in checkpoint ssc.start() - def numBatchesWithData: Int = collectedData.count(_._2._2.nonEmpty) + def numBatchesWithData: Int = + collectedData.synchronized { collectedData.count(_._2._2.nonEmpty) } def isCheckpointPresent: Boolean = Checkpoint.getCheckpointFiles(checkpointDir).nonEmpty @@ -254,7 +261,7 @@ class KinesisStreamSuite extends KinesisFunSuite // If this times out because numBatchesWithData is empty, then its likely that foreachRDD // function failed with exceptions, and nothing got added to `collectedData` eventually(timeout(2 minutes), interval(1 seconds)) { - testUtils.pushData(1 to 5) + testUtils.pushData(1 to 5, aggregateTestData) assert(isCheckpointPresent && numBatchesWithData > 10) } ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused @@ -267,23 +274,28 @@ class KinesisStreamSuite extends KinesisFunSuite // Verify that the recomputed RDDs are KinesisBackedBlockRDDs with the same sequence ranges // and return the same data - val times = collectedData.keySet - times.foreach { time => - val (arrayOfSeqNumRanges, data) = collectedData(time) - val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] - rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] - - // Verify the recovered sequence ranges - val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] - assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size) - arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) => - assert(expected.ranges.toSeq === found.ranges.toSeq) + collectedData.synchronized { + val times = collectedData.keySet + times.foreach { time => + val (arrayOfSeqNumRanges, data) = collectedData(time) + val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] + rdd shouldBe a[KinesisBackedBlockRDD[_]] + + // Verify the recovered sequence ranges + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] + assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size) + arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) => + assert(expected.ranges.toSeq === found.ranges.toSeq) + } + + // Verify the recovered data + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data) } - - // Verify the recovered data - assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data) } ssc.stop() } - } + +class WithAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = true) + +class WithoutAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = false) diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml deleted file mode 100644 index 89713a28ca6a8..0000000000000 --- a/external/mqtt-assembly/pom.xml +++ /dev/null @@ -1,175 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-mqtt-assembly_2.10 - jar - Spark Project External MQTT Assembly - http://spark.apache.org/ - - - streaming-mqtt-assembly - - - - - org.apache.spark - spark-streaming-mqtt_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - - commons-lang - commons-lang - provided - - - com.google.protobuf - protobuf-java - provided - - - com.sun.jersey - jersey-server - provided - - - com.sun.jersey - jersey-core - provided - - - org.apache.hadoop - hadoop-client - provided - - - org.apache.avro - avro-mapred - ${avro.mapred.classifier} - provided - - - org.apache.curator - curator-recipes - provided - - - org.apache.zookeeper - zookeeper - provided - - - log4j - log4j - provided - - - net.java.dev.jets3t - jets3t - provided - - - org.scala-lang - scala-library - provided - - - org.slf4j - slf4j-api - provided - - - org.slf4j - slf4j-log4j12 - provided - - - org.xerial.snappy - snappy-java - provided - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - reference.conf - - - log4j.properties - - - - - - - - - - - diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml deleted file mode 100644 index 59fba8b826b4f..0000000000000 --- a/external/mqtt/pom.xml +++ /dev/null @@ -1,104 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-mqtt_2.10 - - streaming-mqtt - - jar - Spark Project External MQTT - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.eclipse.paho - org.eclipse.paho.client.mqttv3 - 1.0.1 - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.activemq - activemq-core - 5.7.0 - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - - - org.apache.maven.plugins - maven-assembly-plugin - - - test-jar-with-dependencies - package - - single - - - - spark-streaming-mqtt-test-${project.version} - ${project.build.directory}/scala-${scala.binary.version}/ - false - - false - - src/main/assembly/assembly.xml - - - - - - - - diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml deleted file mode 100644 index ecab5b360eb3e..0000000000000 --- a/external/mqtt/src/main/assembly/assembly.xml +++ /dev/null @@ -1,44 +0,0 @@ - - - test-jar-with-dependencies - - jar - - false - - - - ${project.build.directory}/scala-${scala.binary.version}/test-classes - / - - - - - - true - test - true - - org.apache.hadoop:*:jar - org.apache.zookeeper:*:jar - org.apache.avro:*:jar - - - - - diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala deleted file mode 100644 index 116c170489e96..0000000000000 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.mqtt - -import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken -import org.eclipse.paho.client.mqttv3.MqttCallback -import org.eclipse.paho.client.mqttv3.MqttClient -import org.eclipse.paho.client.mqttv3.MqttMessage -import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence - -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.Receiver - -/** - * Input stream that subscribe messages from a Mqtt Broker. - * Uses eclipse paho as MqttClient http://www.eclipse.org/paho/ - * @param brokerUrl Url of remote mqtt publisher - * @param topic topic name to subscribe to - * @param storageLevel RDD storage level. - */ - -private[streaming] -class MQTTInputDStream( - ssc_ : StreamingContext, - brokerUrl: String, - topic: String, - storageLevel: StorageLevel - ) extends ReceiverInputDStream[String](ssc_) { - - private[streaming] override def name: String = s"MQTT stream [$id]" - - def getReceiver(): Receiver[String] = { - new MQTTReceiver(brokerUrl, topic, storageLevel) - } -} - -private[streaming] -class MQTTReceiver( - brokerUrl: String, - topic: String, - storageLevel: StorageLevel - ) extends Receiver[String](storageLevel) { - - def onStop() { - - } - - def onStart() { - - // Set up persistence for messages - val persistence = new MemoryPersistence() - - // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance - val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) - - // Callback automatically triggers as and when new message arrives on specified topic - val callback = new MqttCallback() { - - // Handles Mqtt message - override def messageArrived(topic: String, message: MqttMessage) { - store(new String(message.getPayload(), "utf-8")) - } - - override def deliveryComplete(token: IMqttDeliveryToken) { - } - - override def connectionLost(cause: Throwable) { - restart("Connection lost ", cause) - } - } - - // Set up callback for MqttClient. This needs to happen before - // connecting or subscribing, otherwise messages may be lost - client.setCallback(callback) - - // Connect to MqttBroker - client.connect() - - // Subscribe to Mqtt topic - client.subscribe(topic) - - } -} diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala deleted file mode 100644 index 7b8d56d6faf2d..0000000000000 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.mqtt - -import scala.reflect.ClassTag - -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaDStream, JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -object MQTTUtils { - /** - * Create an input stream that receives messages pushed by a MQTT publisher. - * @param ssc StreamingContext object - * @param brokerUrl Url of remote MQTT publisher - * @param topic Topic name to subscribe to - * @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_AND_DISK_SER_2. - */ - def createStream( - ssc: StreamingContext, - brokerUrl: String, - topic: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[String] = { - new MQTTInputDStream(ssc, brokerUrl, topic, storageLevel) - } - - /** - * Create an input stream that receives messages pushed by a MQTT publisher. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param jssc JavaStreamingContext object - * @param brokerUrl Url of remote MQTT publisher - * @param topic Topic name to subscribe to - */ - def createStream( - jssc: JavaStreamingContext, - brokerUrl: String, - topic: String - ): JavaReceiverInputDStream[String] = { - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] - createStream(jssc.ssc, brokerUrl, topic) - } - - /** - * Create an input stream that receives messages pushed by a MQTT publisher. - * @param jssc JavaStreamingContext object - * @param brokerUrl Url of remote MQTT publisher - * @param topic Topic name to subscribe to - * @param storageLevel RDD storage level. - */ - def createStream( - jssc: JavaStreamingContext, - brokerUrl: String, - topic: String, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[String] = { - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]] - createStream(jssc.ssc, brokerUrl, topic, storageLevel) - } -} - -/** - * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and - * function so that it can be easily instantiated and called from Python's MQTTUtils. - */ -private[mqtt] class MQTTUtilsPythonHelper { - - def createStream( - jssc: JavaStreamingContext, - brokerUrl: String, - topic: String, - storageLevel: StorageLevel - ): JavaDStream[String] = { - MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) - } -} diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package-info.java b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package-info.java deleted file mode 100644 index 728e0d8663d01..0000000000000 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * MQTT receiver for Spark Streaming. - */ -package org.apache.spark.streaming.mqtt; \ No newline at end of file diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package.scala deleted file mode 100644 index 63d0d138183a9..0000000000000 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -/** - * MQTT receiver for Spark Streaming. - */ -package object mqtt diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java deleted file mode 100644 index cfedb5a042a35..0000000000000 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.After; -import org.junit.Before; - -public abstract class LocalJavaStreamingContext { - - protected transient JavaStreamingContext ssc; - - @Before - public void setUp() { - SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - ssc = new JavaStreamingContext(conf, new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - } -} diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/mqtt/JavaMQTTStreamSuite.java b/external/mqtt/src/test/java/org/apache/spark/streaming/mqtt/JavaMQTTStreamSuite.java deleted file mode 100644 index ce5aa1e0cdda4..0000000000000 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/mqtt/JavaMQTTStreamSuite.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.mqtt; - -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.junit.Test; - -import org.apache.spark.streaming.LocalJavaStreamingContext; - -public class JavaMQTTStreamSuite extends LocalJavaStreamingContext { - @Test - public void testMQTTStream() { - String brokerUrl = "abc"; - String topic = "def"; - - // tests the API, does not actually test data receiving - JavaReceiverInputDStream test1 = MQTTUtils.createStream(ssc, brokerUrl, topic); - JavaReceiverInputDStream test2 = MQTTUtils.createStream(ssc, brokerUrl, topic, - StorageLevel.MEMORY_AND_DISK_SER_2()); - } -} diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties deleted file mode 100644 index 75e3b53a093f6..0000000000000 --- a/external/mqtt/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN - diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala deleted file mode 100644 index a6a9249db8ed7..0000000000000 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.mqtt - -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Milliseconds, StreamingContext} - -class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { - - private val batchDuration = Milliseconds(500) - private val master = "local[2]" - private val framework = this.getClass.getSimpleName - private val topic = "def" - - private var ssc: StreamingContext = _ - private var mqttTestUtils: MQTTTestUtils = _ - - before { - ssc = new StreamingContext(master, framework, batchDuration) - mqttTestUtils = new MQTTTestUtils - mqttTestUtils.setup() - } - - after { - if (ssc != null) { - ssc.stop() - ssc = null - } - if (mqttTestUtils != null) { - mqttTestUtils.teardown() - mqttTestUtils = null - } - } - - test("mqtt input stream") { - val sendMessage = "MQTT demo for spark streaming" - val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, - StorageLevel.MEMORY_ONLY) - - @volatile var receiveMessage: List[String] = List() - receiveStream.foreachRDD { rdd => - if (rdd.collect.length > 0) { - receiveMessage = receiveMessage ::: List(rdd.first) - receiveMessage - } - } - - ssc.start() - - // Retry it because we don't know when the receiver will start. - eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - mqttTestUtils.publishData(topic, sendMessage) - assert(sendMessage.equals(receiveMessage(0))) - } - ssc.stop() - } -} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala deleted file mode 100644 index 1618e2c088b70..0000000000000 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.mqtt - -import java.net.{ServerSocket, URI} - -import scala.language.postfixOps - -import com.google.common.base.Charsets.UTF_8 -import org.apache.activemq.broker.{BrokerService, TransportConnector} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf} - -/** - * Share codes for Scala and Python unit tests - */ -private[mqtt] class MQTTTestUtils extends Logging { - - private val persistenceDir = Utils.createTempDir() - private val brokerHost = "localhost" - private val brokerPort = findFreePort() - - private var broker: BrokerService = _ - private var connector: TransportConnector = _ - - def brokerUri: String = { - s"$brokerHost:$brokerPort" - } - - def setup(): Unit = { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt://" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - def teardown(): Unit = { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - Utils.deleteRecursively(persistenceDir) - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(topic: String, data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes(UTF_8)) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - if (client != null) { - client.disconnect() - client.close() - client = null - } - } - } - -} diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml new file mode 100644 index 0000000000000..bfb92791de3d8 --- /dev/null +++ b/external/spark-ganglia-lgpl/pom.xml @@ -0,0 +1,49 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../../pom.xml + + + + org.apache.spark + spark-ganglia-lgpl_2.11 + jar + Spark Ganglia Integration + + + ganglia-lgpl + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + + io.dropwizard.metrics + metrics-ganglia + + + diff --git a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala similarity index 100% rename from extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala rename to external/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml deleted file mode 100644 index 087270de90b3f..0000000000000 --- a/external/twitter/pom.xml +++ /dev/null @@ -1,70 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-twitter_2.10 - - streaming-twitter - - jar - Spark Project External Twitter - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.twitter4j - twitter4j-stream - 4.0.4 - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala deleted file mode 100644 index 9a85a6597c27f..0000000000000 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.twitter - -import twitter4j._ -import twitter4j.auth.Authorization -import twitter4j.conf.ConfigurationBuilder -import twitter4j.auth.OAuthAuthorization - -import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.Logging -import org.apache.spark.streaming.receiver.Receiver - -/* A stream of Twitter statuses, potentially filtered by one or more keywords. -* -* @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials. -* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is -* such that this may return a sampled subset of all tweets during each interval. -* -* If no Authorization object is provided, initializes OAuth authorization using the system -* properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret. -*/ -private[streaming] -class TwitterInputDStream( - ssc_ : StreamingContext, - twitterAuth: Option[Authorization], - filters: Seq[String], - storageLevel: StorageLevel - ) extends ReceiverInputDStream[Status](ssc_) { - - private def createOAuthAuthorization(): Authorization = { - new OAuthAuthorization(new ConfigurationBuilder().build()) - } - - private val authorization = twitterAuth.getOrElse(createOAuthAuthorization()) - - override def getReceiver(): Receiver[Status] = { - new TwitterReceiver(authorization, filters, storageLevel) - } -} - -private[streaming] -class TwitterReceiver( - twitterAuth: Authorization, - filters: Seq[String], - storageLevel: StorageLevel - ) extends Receiver[Status](storageLevel) with Logging { - - @volatile private var twitterStream: TwitterStream = _ - @volatile private var stopped = false - - def onStart() { - try { - val newTwitterStream = new TwitterStreamFactory().getInstance(twitterAuth) - newTwitterStream.addListener(new StatusListener { - def onStatus(status: Status): Unit = { - store(status) - } - // Unimplemented - def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} - def onTrackLimitationNotice(i: Int) {} - def onScrubGeo(l: Long, l1: Long) {} - def onStallWarning(stallWarning: StallWarning) {} - def onException(e: Exception) { - if (!stopped) { - restart("Error receiving tweets", e) - } - } - }) - - val query = new FilterQuery - if (filters.size > 0) { - query.track(filters.mkString(",")) - newTwitterStream.filter(query) - } else { - newTwitterStream.sample() - } - setTwitterStream(newTwitterStream) - logInfo("Twitter receiver started") - stopped = false - } catch { - case e: Exception => restart("Error starting Twitter stream", e) - } - } - - def onStop() { - stopped = true - setTwitterStream(null) - logInfo("Twitter receiver stopped") - } - - private def setTwitterStream(newTwitterStream: TwitterStream) = synchronized { - if (twitterStream != null) { - twitterStream.shutdown() - } - twitterStream = newTwitterStream - } -} diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala deleted file mode 100644 index c6a9a2b73714f..0000000000000 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.twitter - -import twitter4j.Status -import twitter4j.auth.Authorization -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} - -object TwitterUtils { - /** - * Create a input stream that returns tweets received from Twitter. - * @param ssc StreamingContext object - * @param twitterAuth Twitter4J authentication, or None to use Twitter4J's default OAuth - * authorization; this uses the system properties twitter4j.oauth.consumerKey, - * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and - * twitter4j.oauth.accessTokenSecret - * @param filters Set of filter strings to get only those tweets that match them - * @param storageLevel Storage level to use for storing the received objects - */ - def createStream( - ssc: StreamingContext, - twitterAuth: Option[Authorization], - filters: Seq[String] = Nil, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[Status] = { - new TwitterInputDStream(ssc, twitterAuth, filters, storageLevel) - } - - /** - * Create a input stream that returns tweets received from Twitter using Twitter4J's default - * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, - * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and - * twitter4j.oauth.accessTokenSecret. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param jssc JavaStreamingContext object - */ - def createStream(jssc: JavaStreamingContext): JavaReceiverInputDStream[Status] = { - createStream(jssc.ssc, None) - } - - /** - * Create a input stream that returns tweets received from Twitter using Twitter4J's default - * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, - * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and - * twitter4j.oauth.accessTokenSecret. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param jssc JavaStreamingContext object - * @param filters Set of filter strings to get only those tweets that match them - */ - def createStream(jssc: JavaStreamingContext, filters: Array[String] - ): JavaReceiverInputDStream[Status] = { - createStream(jssc.ssc, None, filters) - } - - /** - * Create a input stream that returns tweets received from Twitter using Twitter4J's default - * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, - * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and - * twitter4j.oauth.accessTokenSecret. - * @param jssc JavaStreamingContext object - * @param filters Set of filter strings to get only those tweets that match them - * @param storageLevel Storage level to use for storing the received objects - */ - def createStream( - jssc: JavaStreamingContext, - filters: Array[String], - storageLevel: StorageLevel - ): JavaReceiverInputDStream[Status] = { - createStream(jssc.ssc, None, filters, storageLevel) - } - - /** - * Create a input stream that returns tweets received from Twitter. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param jssc JavaStreamingContext object - * @param twitterAuth Twitter4J Authorization - */ - def createStream(jssc: JavaStreamingContext, twitterAuth: Authorization - ): JavaReceiverInputDStream[Status] = { - createStream(jssc.ssc, Some(twitterAuth)) - } - - /** - * Create a input stream that returns tweets received from Twitter. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param jssc JavaStreamingContext object - * @param twitterAuth Twitter4J Authorization - * @param filters Set of filter strings to get only those tweets that match them - */ - def createStream( - jssc: JavaStreamingContext, - twitterAuth: Authorization, - filters: Array[String] - ): JavaReceiverInputDStream[Status] = { - createStream(jssc.ssc, Some(twitterAuth), filters) - } - - /** - * Create a input stream that returns tweets received from Twitter. - * @param jssc JavaStreamingContext object - * @param twitterAuth Twitter4J Authorization object - * @param filters Set of filter strings to get only those tweets that match them - * @param storageLevel Storage level to use for storing the received objects - */ - def createStream( - jssc: JavaStreamingContext, - twitterAuth: Authorization, - filters: Array[String], - storageLevel: StorageLevel - ): JavaReceiverInputDStream[Status] = { - createStream(jssc.ssc, Some(twitterAuth), filters, storageLevel) - } -} diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package-info.java b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package-info.java deleted file mode 100644 index 258c0950a0aa7..0000000000000 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Twitter feed receiver for spark streaming. - */ -package org.apache.spark.streaming.twitter; \ No newline at end of file diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package.scala deleted file mode 100644 index 580e37fa8f814..0000000000000 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -/** - * Twitter feed receiver for spark streaming. - */ -package object twitter diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java deleted file mode 100644 index cfedb5a042a35..0000000000000 --- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.After; -import org.junit.Before; - -public abstract class LocalJavaStreamingContext { - - protected transient JavaStreamingContext ssc; - - @Before - public void setUp() { - SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - ssc = new JavaStreamingContext(conf, new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - } -} diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java deleted file mode 100644 index 26ec8af455bcf..0000000000000 --- a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.twitter; - -import org.junit.Test; -import twitter4j.Status; -import twitter4j.auth.Authorization; -import twitter4j.auth.NullAuthorization; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.LocalJavaStreamingContext; -import org.apache.spark.streaming.api.java.JavaDStream; - -public class JavaTwitterStreamSuite extends LocalJavaStreamingContext { - @Test - public void testTwitterStream() { - String[] filters = { "filter1", "filter2" }; - Authorization auth = NullAuthorization.getInstance(); - - // tests the API, does not actually test data receiving - JavaDStream test1 = TwitterUtils.createStream(ssc); - JavaDStream test2 = TwitterUtils.createStream(ssc, filters); - JavaDStream test3 = TwitterUtils.createStream( - ssc, filters, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaDStream test4 = TwitterUtils.createStream(ssc, auth); - JavaDStream test5 = TwitterUtils.createStream(ssc, auth, filters); - JavaDStream test6 = TwitterUtils.createStream(ssc, - auth, filters, StorageLevel.MEMORY_AND_DISK_SER_2()); - } -} diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties deleted file mode 100644 index 9a3569789d2e0..0000000000000 --- a/external/twitter/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the filetarget/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN - diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala deleted file mode 100644 index d9acb568879fe..0000000000000 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.twitter - - -import org.scalatest.BeforeAndAfter -import twitter4j.Status -import twitter4j.auth.{NullAuthorization, Authorization} - -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - - val batchDuration = Seconds(1) - - private val master: String = "local[2]" - - private val framework: String = this.getClass.getSimpleName - - test("twitter input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val filters = Seq("filter1", "filter2") - val authorization: Authorization = NullAuthorization.getInstance() - - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None) - val test2: ReceiverInputDStream[Status] = - TwitterUtils.createStream(ssc, None, filters) - val test3: ReceiverInputDStream[Status] = - TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_AND_DISK_SER_2) - val test4: ReceiverInputDStream[Status] = - TwitterUtils.createStream(ssc, Some(authorization)) - val test5: ReceiverInputDStream[Status] = - TwitterUtils.createStream(ssc, Some(authorization), filters) - val test6: ReceiverInputDStream[Status] = TwitterUtils.createStream( - ssc, Some(authorization), filters, StorageLevel.MEMORY_AND_DISK_SER_2) - - // Note that actually testing the data receiving is hard as authentication keys are - // necessary for accessing Twitter live stream - ssc.stop() - } -} diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml deleted file mode 100644 index 02d6b81281576..0000000000000 --- a/external/zeromq/pom.xml +++ /dev/null @@ -1,69 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-zeromq_2.10 - - streaming-zeromq - - jar - Spark Project External ZeroMQ - http://spark.apache.org/ - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - ${akka.group} - akka-zeromq_${scala.binary.version} - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala deleted file mode 100644 index 588e6bac7b14a..0000000000000 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.zeromq - -import scala.reflect.ClassTag - -import akka.actor.Actor -import akka.util.ByteString -import akka.zeromq._ - -import org.apache.spark.Logging -import org.apache.spark.streaming.receiver.ActorHelper - -/** - * A receiver to subscribe to ZeroMQ stream. - */ -private[streaming] class ZeroMQReceiver[T: ClassTag]( - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T]) - extends Actor with ActorHelper with Logging { - - override def preStart(): Unit = { - ZeroMQExtension(context.system) - .newSocket(SocketType.Sub, Listener(self), Connect(publisherUrl), subscribe) - } - - def receive: Receive = { - - case Connecting => logInfo("connecting ...") - - case m: ZMQMessage => - logDebug("Received message for:" + m.frame(0)) - - // We ignore first frame for processing as it is the topic - val bytes = m.frames.tail - store(bytesToObjects(bytes)) - - case Closed => logInfo("received closed ") - } -} diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala deleted file mode 100644 index 4ea218eaa4de1..0000000000000 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.zeromq - -import scala.reflect.ClassTag -import scala.collection.JavaConverters._ - -import akka.actor.{Props, SupervisorStrategy} -import akka.util.ByteString -import akka.zeromq.Subscribe - -import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.receiver.ActorSupervisorStrategy - -object ZeroMQUtils { - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param ssc StreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic - * and each frame has sequence of byte thus it needs the converter - * (which might be deserializer of bytes) to translate from sequence - * of sequence of bytes, where sequence refer to a frame - * and sub sequence refer to its payload. - * @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_AND_DISK_SER_2. - */ - def createStream[T: ClassTag]( - ssc: StreamingContext, - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: Seq[ByteString] => Iterator[T], - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, - supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy - ): ReceiverInputDStream[T] = { - ssc.actorStream(Props(new ZeroMQReceiver(publisherUrl, subscribe, bytesToObjects)), - "ZeroMQReceiver", storageLevel, supervisorStrategy) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote ZeroMQ publisher - * @param subscribe Topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each - * frame has sequence of byte thus it needs the converter(which might be - * deserializer of bytes) to translate from sequence of sequence of bytes, - * where sequence refer to a frame and sub sequence refer to its payload. - * @param storageLevel Storage level to use for storing the received objects - */ - def createStream[T]( - jssc: JavaStreamingContext, - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]], - storageLevel: StorageLevel, - supervisorStrategy: SupervisorStrategy - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = - (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel, supervisorStrategy) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each - * frame has sequence of byte thus it needs the converter(which might be - * deserializer of bytes) to translate from sequence of sequence of bytes, - * where sequence refer to a frame and sub sequence refer to its payload. - * @param storageLevel RDD storage level. - */ - def createStream[T]( - jssc: JavaStreamingContext, - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]], - storageLevel: StorageLevel - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = - (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel) - } - - /** - * Create an input stream that receives messages pushed by a zeromq publisher. - * @param jssc JavaStreamingContext object - * @param publisherUrl Url of remote zeromq publisher - * @param subscribe Topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each - * frame has sequence of byte thus it needs the converter(which might - * be deserializer of bytes) to translate from sequence of sequence of - * bytes, where sequence refer to a frame and sub sequence refer to its - * payload. - */ - def createStream[T]( - jssc: JavaStreamingContext, - publisherUrl: String, - subscribe: Subscribe, - bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]] - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = - (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala - createStream[T](jssc.ssc, publisherUrl, subscribe, fn) - } -} diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package-info.java b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package-info.java deleted file mode 100644 index 587c524e2120f..0000000000000 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Zeromq receiver for spark streaming. - */ -package org.apache.spark.streaming.zeromq; \ No newline at end of file diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package.scala deleted file mode 100644 index 65e6e57f2c05d..0000000000000 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming - -/** - * Zeromq receiver for spark streaming. - */ -package object zeromq diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java deleted file mode 100644 index cfedb5a042a35..0000000000000 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming; - -import org.apache.spark.SparkConf; -import org.apache.spark.streaming.api.java.JavaStreamingContext; -import org.junit.After; -import org.junit.Before; - -public abstract class LocalJavaStreamingContext { - - protected transient JavaStreamingContext ssc; - - @Before - public void setUp() { - SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - ssc = new JavaStreamingContext(conf, new Duration(1000)); - ssc.checkpoint("checkpoint"); - } - - @After - public void tearDown() { - ssc.stop(); - ssc = null; - } -} diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java b/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java deleted file mode 100644 index 417b91eecb0ee..0000000000000 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/zeromq/JavaZeroMQStreamSuite.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.zeromq; - -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.junit.Test; -import akka.actor.SupervisorStrategy; -import akka.util.ByteString; -import akka.zeromq.Subscribe; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.LocalJavaStreamingContext; - -public class JavaZeroMQStreamSuite extends LocalJavaStreamingContext { - - @Test // tests the API, does not actually test data receiving - public void testZeroMQStream() { - String publishUrl = "abc"; - Subscribe subscribe = new Subscribe((ByteString)null); - Function> bytesToObjects = new Function>() { - @Override - public Iterable call(byte[][] bytes) throws Exception { - return null; - } - }; - - JavaReceiverInputDStream test1 = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects); - JavaReceiverInputDStream test2 = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = ZeroMQUtils.createStream( - ssc,publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2(), - SupervisorStrategy.defaultStrategy()); - } -} diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties deleted file mode 100644 index 75e3b53a093f6..0000000000000 --- a/external/zeromq/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN - diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala deleted file mode 100644 index 35d2e62c68480..0000000000000 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.zeromq - -import akka.actor.SupervisorStrategy -import akka.util.ByteString -import akka.zeromq.Subscribe - -import org.apache.spark.SparkFunSuite -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.dstream.ReceiverInputDStream - -class ZeroMQStreamSuite extends SparkFunSuite { - - val batchDuration = Seconds(1) - - private val master: String = "local[2]" - - private val framework: String = this.getClass.getSimpleName - - test("zeromq input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val publishUrl = "abc" - val subscribe = new Subscribe(null.asInstanceOf[ByteString]) - val bytesToObjects = (bytes: Seq[ByteString]) => null.asInstanceOf[Iterator[String]] - - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[String] = - ZeroMQUtils.createStream(ssc, publishUrl, subscribe, bytesToObjects) - val test2: ReceiverInputDStream[String] = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, StorageLevel.MEMORY_AND_DISK_SER_2) - val test3: ReceiverInputDStream[String] = ZeroMQUtils.createStream( - ssc, publishUrl, subscribe, bytesToObjects, - StorageLevel.MEMORY_AND_DISK_SER_2, SupervisorStrategy.defaultStrategy) - - // TODO: Actually test data receiving - ssc.stop() - } -} diff --git a/extras/README.md b/extras/README.md deleted file mode 100644 index 1b4174b7d5cff..0000000000000 --- a/extras/README.md +++ /dev/null @@ -1 +0,0 @@ -This directory contains build components not included by default in Spark's build. diff --git a/extras/java8-tests/README.md b/extras/java8-tests/README.md deleted file mode 100644 index dc9e87f2eeb92..0000000000000 --- a/extras/java8-tests/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# Java 8 Test Suites - -These tests require having Java 8 installed and are isolated from the main Spark build. -If Java 8 is not your system's default Java version, you will need to point Spark's build -to your Java location. The set-up depends a bit on the build system: - -* Sbt users can either set JAVA_HOME to the location of a Java 8 JDK or explicitly pass - `-java-home` to the sbt launch script. If a Java 8 JDK is detected sbt will automatically - include the Java 8 test project. - - `$ JAVA_HOME=/opt/jdk1.8.0/ build/sbt clean "test-only org.apache.spark.Java8APISuite"` - -* For Maven users, - - Maven users can also refer to their Java 8 directory using JAVA_HOME. However, Maven will not - automatically detect the presence of a Java 8 JDK, so a special build profile `-Pjava8-tests` - must be used. - - `$ JAVA_HOME=/opt/jdk1.8.0/ mvn clean install -DskipTests` - `$ JAVA_HOME=/opt/jdk1.8.0/ mvn test -Pjava8-tests -DwildcardSuites=org.apache.spark.Java8APISuite` - - Note that the above command can only be run from project root directory since this module - depends on core and the test-jars of core and streaming. This means an install step is - required to make the test dependencies visible to the Java 8 sub-project. diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml deleted file mode 100644 index 4ce90e75fd359..0000000000000 --- a/extras/java8-tests/pom.xml +++ /dev/null @@ -1,161 +0,0 @@ - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - java8-tests_2.10 - pom - Spark Project Java8 Tests POM - - - java8-tests - - - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - - - java8-tests - - - - - - org.apache.maven.plugins - maven-deploy-plugin - - true - - - - org.apache.maven.plugins - maven-install-plugin - - true - - - - org.apache.maven.plugins - maven-surefire-plugin - - - test - - test - - - - - - - - file:src/test/resources/log4j.properties - - - false - - **/Suite*.java - **/*Suite.java - - - - - org.apache.maven.plugins - maven-compiler-plugin - - - test-compile-first - process-test-resources - - testCompile - - - - - true - true - true - 1.8 - 1.8 - 1.8 - UTF-8 - 1024m - - - - - net.alchim31.maven - scala-maven-plugin - - - none - - - scala-compile-first - none - - - scala-test-compile-first - none - - - attach-scaladocs - none - - - - - - diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java deleted file mode 100644 index 14975265ab2ce..0000000000000 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ /dev/null @@ -1,393 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark; - -import java.io.File; -import java.io.Serializable; -import java.util.*; - -import scala.Tuple2; - -import com.google.common.collect.Iterables; -import com.google.common.base.Optional; -import com.google.common.io.Files; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapred.SequenceFileOutputFormat; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.*; -import org.apache.spark.util.Utils; - -/** - * Most of these tests replicate org.apache.spark.JavaAPISuite using java 8 - * lambda syntax. - */ -public class Java8APISuite implements Serializable { - static int foreachCalls = 0; - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaAPISuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } - - @Test - public void foreachWithAnonymousClass() { - foreachCalls = 0; - JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach(new VoidFunction() { - @Override - public void call(String s) { - foreachCalls++; - } - }); - Assert.assertEquals(2, foreachCalls); - } - - @Test - public void foreach() { - foreachCalls = 0; - JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach(x -> foreachCalls++); - Assert.assertEquals(2, foreachCalls); - } - - @Test - public void groupBy() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function isOdd = x -> x % 2 == 0; - JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens - Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds - - oddsAndEvens = rdd.groupBy(isOdd, 1); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens - Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds - } - - @Test - public void leftOuterJoin() { - JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2<>(1, 1), - new Tuple2<>(1, 2), - new Tuple2<>(2, 1), - new Tuple2<>(3, 1) - )); - JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2<>(1, 'x'), - new Tuple2<>(2, 'y'), - new Tuple2<>(2, 'z'), - new Tuple2<>(4, 'w') - )); - List>>> joined = - rdd1.leftOuterJoin(rdd2).collect(); - Assert.assertEquals(5, joined.size()); - Tuple2>> firstUnmatched = - rdd1.leftOuterJoin(rdd2).filter(tup -> !tup._2()._2().isPresent()).first(); - Assert.assertEquals(3, firstUnmatched._1().intValue()); - } - - @Test - public void foldReduce() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function2 add = (a, b) -> a + b; - - int sum = rdd.fold(0, add); - Assert.assertEquals(33, sum); - - sum = rdd.reduce(add); - Assert.assertEquals(33, sum); - } - - @Test - public void foldByKey() { - List> pairs = Arrays.asList( - new Tuple2<>(2, 1), - new Tuple2<>(2, 1), - new Tuple2<>(1, 1), - new Tuple2<>(3, 2), - new Tuple2<>(3, 1) - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - JavaPairRDD sums = rdd.foldByKey(0, (a, b) -> a + b); - Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); - Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); - Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); - } - - @Test - public void reduceByKey() { - List> pairs = Arrays.asList( - new Tuple2<>(2, 1), - new Tuple2<>(2, 1), - new Tuple2<>(1, 1), - new Tuple2<>(3, 2), - new Tuple2<>(3, 1) - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - JavaPairRDD counts = rdd.reduceByKey((a, b) -> a + b); - Assert.assertEquals(1, counts.lookup(1).get(0).intValue()); - Assert.assertEquals(2, counts.lookup(2).get(0).intValue()); - Assert.assertEquals(3, counts.lookup(3).get(0).intValue()); - - Map localCounts = counts.collectAsMap(); - Assert.assertEquals(1, localCounts.get(1).intValue()); - Assert.assertEquals(2, localCounts.get(2).intValue()); - Assert.assertEquals(3, localCounts.get(3).intValue()); - - localCounts = rdd.reduceByKeyLocally((a, b) -> a + b); - Assert.assertEquals(1, localCounts.get(1).intValue()); - Assert.assertEquals(2, localCounts.get(2).intValue()); - Assert.assertEquals(3, localCounts.get(3).intValue()); - } - - @Test - public void map() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.mapToDouble(x -> 1.0 * x).cache(); - doubles.collect(); - JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) - .cache(); - pairs.collect(); - JavaRDD strings = rdd.map(Object::toString).cache(); - strings.collect(); - } - - @Test - public void flatMap() { - JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", - "The quick brown fox jumps over the lazy dog.")); - JavaRDD words = rdd.flatMap(x -> Arrays.asList(x.split(" "))); - - Assert.assertEquals("Hello", words.first()); - Assert.assertEquals(11, words.count()); - - JavaPairRDD pairs = rdd.flatMapToPair(s -> { - List> pairs2 = new LinkedList<>(); - for (String word : s.split(" ")) { - pairs2.add(new Tuple2<>(word, word)); - } - return pairs2; - }); - - Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); - - JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { - List lengths = new LinkedList<>(); - for (String word : s.split(" ")) { - lengths.add((double) word.length()); - } - return lengths; - }); - - Assert.assertEquals(5.0, doubles.first(), 0.01); - Assert.assertEquals(11, pairs.count()); - } - - @Test - public void mapsFromPairsToPairs() { - List> pairs = Arrays.asList( - new Tuple2<>(1, "a"), - new Tuple2<>(2, "aa"), - new Tuple2<>(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD swapped = - pairRDD.flatMapToPair(x -> Collections.singletonList(x.swap())); - swapped.collect(); - - // There was never a bug here, but it's worth testing: - pairRDD.map(Tuple2::swap).collect(); - } - - @Test - public void mapPartitions() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - JavaRDD partitionSums = rdd.mapPartitions(iter -> { - int sum = 0; - while (iter.hasNext()) { - sum += iter.next(); - } - return Collections.singletonList(sum); - }); - - Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); - } - - @Test - public void sequenceFile() { - File tempDir = Files.createTempDir(); - tempDir.deleteOnExit(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2<>(1, "a"), - new Tuple2<>(2, "aa"), - new Tuple2<>(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) - .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); - - // Try reading the output back as an object file - JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, Text.class) - .mapToPair(pair -> new Tuple2<>(pair._1().get(), pair._2().toString())); - Assert.assertEquals(pairs, readRDD.collect()); - Utils.deleteRecursively(tempDir); - } - - @Test - public void zip() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.mapToDouble(x -> 1.0 * x); - JavaPairRDD zipped = rdd.zip(doubles); - zipped.count(); - } - - @Test - public void zipPartitions() { - JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); - JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); - FlatMapFunction2, Iterator, Integer> sizesFn = - (Iterator i, Iterator s) -> { - int sizeI = 0; - while (i.hasNext()) { - sizeI += 1; - i.next(); - } - int sizeS = 0; - while (s.hasNext()) { - sizeS += 1; - s.next(); - } - return Arrays.asList(sizeI, sizeS); - }; - JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); - Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); - } - - @Test - public void accumulators() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - - Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(intAccum::add); - Assert.assertEquals((Integer) 25, intAccum.value()); - - Accumulator doubleAccum = sc.doubleAccumulator(10.0); - rdd.foreach(x -> doubleAccum.add((double) x)); - Assert.assertEquals((Double) 25.0, doubleAccum.value()); - - // Try a custom accumulator type - AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { - @Override - public Float addInPlace(Float r, Float t) { - return r + t; - } - @Override - public Float addAccumulator(Float r, Float t) { - return r + t; - } - @Override - public Float zero(Float initialValue) { - return 0.0f; - } - }; - - Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); - rdd.foreach(x -> floatAccum.add((float) x)); - Assert.assertEquals((Float) 25.0f, floatAccum.value()); - - // Test the setValue method - floatAccum.setValue(5.0f); - Assert.assertEquals((Float) 5.0f, floatAccum.value()); - } - - @Test - public void keyBy() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(Object::toString).collect(); - Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); - } - - @Test - public void mapOnPairRDD() { - JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4)); - JavaPairRDD rdd2 = - rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); - JavaPairRDD rdd3 = - rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); - Assert.assertEquals(Arrays.asList( - new Tuple2<>(1, 1), - new Tuple2<>(0, 2), - new Tuple2<>(1, 3), - new Tuple2<>(0, 4)), rdd3.collect()); - } - - @Test - public void collectPartitions() { - JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); - - JavaPairRDD rdd2 = - rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); - List[] parts = rdd1.collectPartitions(new int[]{0}); - Assert.assertEquals(Arrays.asList(1, 2), parts[0]); - - parts = rdd1.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(3, 4), parts[0]); - Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); - - Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), - rdd2.collectPartitions(new int[]{0})[0]); - - List>[] parts2 = rdd2.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); - Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), - parts2[1]); - } - - @Test - public void collectAsMapWithIntArrayValues() { - // Regression test for SPARK-1040 - JavaRDD rdd = sc.parallelize(Arrays.asList(1)); - JavaPairRDD pairRDD = - rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); - pairRDD.collect(); // Works fine - pairRDD.collectAsMap(); // Used to crash with ClassCastException - } -} diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java deleted file mode 100644 index 73091cfe2c09e..0000000000000 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ /dev/null @@ -1,834 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming; - -import java.io.Serializable; -import java.util.*; - -import scala.Tuple2; - -import com.google.common.base.Optional; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import org.junit.Assert; -import org.junit.Test; - -import org.apache.spark.HashPartitioner; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaPairDStream; - -/** - * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8 - * lambda syntax. - */ -@SuppressWarnings("unchecked") -public class Java8APISuite extends LocalJavaStreamingContext implements Serializable { - - @Test - public void testMap() { - List> inputData = Arrays.asList( - Arrays.asList("hello", "world"), - Arrays.asList("goodnight", "moon")); - - List> expected = Arrays.asList( - Arrays.asList(5, 5), - Arrays.asList(9, 4)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(String::length); - JavaTestUtils.attachTestOutputStream(letterCount); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testFilter() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red sox")); - - List> expected = Arrays.asList( - Arrays.asList("giants"), - Arrays.asList("yankees")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream filtered = stream.filter(s -> s.contains("a")); - JavaTestUtils.attachTestOutputStream(filtered); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testMapPartitions() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red sox")); - - List> expected = Arrays.asList( - Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOX")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream mapped = stream.mapPartitions(in -> { - String out = ""; - while (in.hasNext()) { - out = out + in.next().toUpperCase(); - } - return Lists.newArrayList(out); - }); - JavaTestUtils.attachTestOutputStream(mapped); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduce() { - List> inputData = Arrays.asList( - Arrays.asList(1, 2, 3), - Arrays.asList(4, 5, 6), - Arrays.asList(7, 8, 9)); - - List> expected = Arrays.asList( - Arrays.asList(6), - Arrays.asList(15), - Arrays.asList(24)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reduced = stream.reduce((x, y) -> x + y); - JavaTestUtils.attachTestOutputStream(reduced); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByWindow() { - List> inputData = Arrays.asList( - Arrays.asList(1, 2, 3), - Arrays.asList(4, 5, 6), - Arrays.asList(7, 8, 9)); - - List> expected = Arrays.asList( - Arrays.asList(6), - Arrays.asList(21), - Arrays.asList(39), - Arrays.asList(24)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = stream.reduceByWindow((x, y) -> x + y, - (x, y) -> x - y, new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reducedWindowed); - List> result = JavaTestUtils.runStreams(ssc, 4, 4); - - Assert.assertEquals(expected, result); - } - - @Test - public void testTransform() { - List> inputData = Arrays.asList( - Arrays.asList(1, 2, 3), - Arrays.asList(4, 5, 6), - Arrays.asList(7, 8, 9)); - - List> expected = Arrays.asList( - Arrays.asList(3, 4, 5), - Arrays.asList(6, 7, 8), - Arrays.asList(9, 10, 11)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream transformed = stream.transform(in -> in.map(i -> i + 2)); - - JavaTestUtils.attachTestOutputStream(transformed); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testVariousTransform() { - // tests whether all variations of transform can be called from Java - - List> inputData = Arrays.asList(Arrays.asList(1)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - - List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( - JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - - JavaDStream transformed1 = stream.transform(in -> null); - JavaDStream transformed2 = stream.transform((x, time) -> null); - JavaPairDStream transformed3 = stream.transformToPair(x -> null); - JavaPairDStream transformed4 = stream.transformToPair((x, time) -> null); - JavaDStream pairTransformed1 = pairStream.transform(x -> null); - JavaDStream pairTransformed2 = pairStream.transform((x, time) -> null); - JavaPairDStream pairTransformed3 = pairStream.transformToPair(x -> null); - JavaPairDStream pairTransformed4 = - pairStream.transformToPair((x, time) -> null); - - } - - @Test - public void testTransformWith() { - List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList( - new Tuple2<>("california", "dodgers"), - new Tuple2<>("new york", "yankees")), - Arrays.asList( - new Tuple2<>("california", "sharks"), - new Tuple2<>("new york", "rangers"))); - - List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList( - new Tuple2<>("california", "giants"), - new Tuple2<>("new york", "mets")), - Arrays.asList( - new Tuple2<>("california", "ducks"), - new Tuple2<>("new york", "islanders"))); - - - List>>> expected = Arrays.asList( - Sets.newHashSet( - new Tuple2<>("california", - new Tuple2<>("dodgers", "giants")), - new Tuple2<>("new york", - new Tuple2<>("yankees", "mets"))), - Sets.newHashSet( - new Tuple2<>("california", - new Tuple2<>("sharks", "ducks")), - new Tuple2<>("new york", - new Tuple2<>("rangers", "islanders")))); - - JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream1, 1); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); - - JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( - ssc, stringStringKVStream2, 1); - JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - - JavaPairDStream> joined = - pairStream1.transformWithToPair(pairStream2,(x, y, z) -> x.join(y)); - - JavaTestUtils.attachTestOutputStream(joined); - List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); - for (List>> res : result) { - unorderedResult.add(Sets.newHashSet(res)); - } - - Assert.assertEquals(expected, unorderedResult); - } - - - @Test - public void testVariousTransformWith() { - // tests whether all variations of transformWith can be called from Java - - List> inputData1 = Arrays.asList(Arrays.asList(1)); - List> inputData2 = Arrays.asList(Arrays.asList("x")); - JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 1); - JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); - - List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); - List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( - JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); - JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( - JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - - JavaDStream transformed1 = stream1.transformWith(stream2, (x, y, z) -> null); - JavaDStream transformed2 = stream1.transformWith(pairStream1,(x, y, z) -> null); - - JavaPairDStream transformed3 = - stream1.transformWithToPair(stream2,(x, y, z) -> null); - - JavaPairDStream transformed4 = - stream1.transformWithToPair(pairStream1,(x, y, z) -> null); - - JavaDStream pairTransformed1 = pairStream1.transformWith(stream2,(x, y, z) -> null); - - JavaDStream pairTransformed2_ = - pairStream1.transformWith(pairStream1,(x, y, z) -> null); - - JavaPairDStream pairTransformed3 = - pairStream1.transformWithToPair(stream2,(x, y, z) -> null); - - JavaPairDStream pairTransformed4 = - pairStream1.transformWithToPair(pairStream2,(x, y, z) -> null); - } - - @Test - public void testStreamingContextTransform() { - List> stream1input = Arrays.asList( - Arrays.asList(1), - Arrays.asList(2) - ); - - List> stream2input = Arrays.asList( - Arrays.asList(3), - Arrays.asList(4) - ); - - List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2<>(1, "x")), - Arrays.asList(new Tuple2<>(2, "y")) - ); - - List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), - Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) - ); - - JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); - JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, stream2input, 1); - JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( - JavaTestUtils.attachTestInputStream(ssc, pairStream1input, 1)); - - List> listOfDStreams1 = Arrays.>asList(stream1, stream2); - - // This is just to test whether this transform to JavaStream compiles - JavaDStream transformed1 = ssc.transform( - listOfDStreams1, (List> listOfRDDs, Time time) -> { - Assert.assertEquals(2, listOfRDDs.size()); - return null; - }); - - List> listOfDStreams2 = - Arrays.>asList(stream1, stream2, pairStream1.toJavaDStream()); - - JavaPairDStream> transformed2 = ssc.transformToPair( - listOfDStreams2, (List> listOfRDDs, Time time) -> { - Assert.assertEquals(3, listOfRDDs.size()); - JavaRDD rdd1 = (JavaRDD) listOfRDDs.get(0); - JavaRDD rdd2 = (JavaRDD) listOfRDDs.get(1); - JavaRDD> rdd3 = (JavaRDD>) listOfRDDs.get(2); - JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); - PairFunction mapToTuple = - (Integer i) -> new Tuple2<>(i, i); - return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); - }); - JavaTestUtils.attachTestOutputStream(transformed2); - List>>> result = - JavaTestUtils.runStreams(ssc, 2, 2); - Assert.assertEquals(expected, result); - } - - @Test - public void testFlatMap() { - List> inputData = Arrays.asList( - Arrays.asList("go", "giants"), - Arrays.asList("boo", "dodgers"), - Arrays.asList("athletics")); - - List> expected = Arrays.asList( - Arrays.asList("g", "o", "g", "i", "a", "n", "t", "s"), - Arrays.asList("b", "o", "o", "d", "o", "d", "g", "e", "r", "s"), - Arrays.asList("a", "t", "h", "l", "e", "t", "i", "c", "s")); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream flatMapped = stream.flatMap(s -> Lists.newArrayList(s.split("(?!^)"))); - JavaTestUtils.attachTestOutputStream(flatMapped); - List> result = JavaTestUtils.runStreams(ssc, 3, 3); - - assertOrderInvariantEquals(expected, result); - } - - @Test - public void testPairFlatMap() { - List> inputData = Arrays.asList( - Arrays.asList("giants"), - Arrays.asList("dodgers"), - Arrays.asList("athletics")); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2<>(6, "g"), - new Tuple2<>(6, "i"), - new Tuple2<>(6, "a"), - new Tuple2<>(6, "n"), - new Tuple2<>(6, "t"), - new Tuple2<>(6, "s")), - Arrays.asList( - new Tuple2<>(7, "d"), - new Tuple2<>(7, "o"), - new Tuple2<>(7, "d"), - new Tuple2<>(7, "g"), - new Tuple2<>(7, "e"), - new Tuple2<>(7, "r"), - new Tuple2<>(7, "s")), - Arrays.asList( - new Tuple2<>(9, "a"), - new Tuple2<>(9, "t"), - new Tuple2<>(9, "h"), - new Tuple2<>(9, "l"), - new Tuple2<>(9, "e"), - new Tuple2<>(9, "t"), - new Tuple2<>(9, "i"), - new Tuple2<>(9, "c"), - new Tuple2<>(9, "s"))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream flatMapped = stream.flatMapToPair(s -> { - List> out = Lists.newArrayList(); - for (String letter : s.split("(?!^)")) { - out.add(new Tuple2<>(s.length(), letter)); - } - return out; - }); - - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - /* - * Performs an order-invariant comparison of lists representing two RDD streams. This allows - * us to account for ordering variation within individual RDD's which occurs during windowing. - */ - public static > void assertOrderInvariantEquals( - List> expected, List> actual) { - expected.forEach((List list) -> Collections.sort(list)); - actual.forEach((List list) -> Collections.sort(list)); - Assert.assertEquals(expected, actual); - } - - @Test - public void testPairFilter() { - List> inputData = Arrays.asList( - Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red sox")); - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2<>("giants", 6)), - Arrays.asList(new Tuple2<>("yankees", 7))); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = - stream.mapToPair(x -> new Tuple2<>(x, x.length())); - JavaPairDStream filtered = pairStream.filter(x -> x._1().contains("a")); - JavaTestUtils.attachTestOutputStream(filtered); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2<>("california", "dodgers"), - new Tuple2<>("california", "giants"), - new Tuple2<>("new york", "yankees"), - new Tuple2<>("new york", "mets")), - Arrays.asList(new Tuple2<>("california", "sharks"), - new Tuple2<>("california", "ducks"), - new Tuple2<>("new york", "rangers"), - new Tuple2<>("new york", "islanders"))); - - List>> stringIntKVStream = Arrays.asList( - Arrays.asList( - new Tuple2<>("california", 1), - new Tuple2<>("california", 3), - new Tuple2<>("new york", 4), - new Tuple2<>("new york", 1)), - Arrays.asList( - new Tuple2<>("california", 5), - new Tuple2<>("california", 5), - new Tuple2<>("new york", 3), - new Tuple2<>("new york", 1))); - - @Test - public void testPairMap() { // Maps pair -> pair of different type - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2<>(1, "california"), - new Tuple2<>(3, "california"), - new Tuple2<>(4, "new york"), - new Tuple2<>(1, "new york")), - Arrays.asList( - new Tuple2<>(5, "california"), - new Tuple2<>(5, "california"), - new Tuple2<>(3, "new york"), - new Tuple2<>(1, "new york"))); - - JavaDStream> stream = - JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.mapToPair(x -> x.swap()); - JavaTestUtils.attachTestOutputStream(reversed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairMapPartitions() { // Maps pair -> pair of different type - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2<>(1, "california"), - new Tuple2<>(3, "california"), - new Tuple2<>(4, "new york"), - new Tuple2<>(1, "new york")), - Arrays.asList( - new Tuple2<>(5, "california"), - new Tuple2<>(5, "california"), - new Tuple2<>(3, "new york"), - new Tuple2<>(1, "new york"))); - - JavaDStream> stream = - JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.mapPartitionsToPair(in -> { - LinkedList> out = new LinkedList<>(); - while (in.hasNext()) { - Tuple2 next = in.next(); - out.add(next.swap()); - } - return out; - }); - - JavaTestUtils.attachTestOutputStream(reversed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairMap2() { // Maps pair -> single - List>> inputData = stringIntKVStream; - - List> expected = Arrays.asList( - Arrays.asList(1, 3, 4, 1), - Arrays.asList(5, 5, 3, 1)); - - JavaDStream> stream = - JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaDStream reversed = pairStream.map(in -> in._2()); - JavaTestUtils.attachTestOutputStream(reversed); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair - List>> inputData = Arrays.asList( - Arrays.asList( - new Tuple2<>("hi", 1), - new Tuple2<>("ho", 2)), - Arrays.asList( - new Tuple2<>("hi", 1), - new Tuple2<>("ho", 2))); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2<>(1, "h"), - new Tuple2<>(1, "i"), - new Tuple2<>(2, "h"), - new Tuple2<>(2, "o")), - Arrays.asList( - new Tuple2<>(1, "h"), - new Tuple2<>(1, "i"), - new Tuple2<>(2, "h"), - new Tuple2<>(2, "o"))); - - JavaDStream> stream = - JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream flatMapped = pairStream.flatMapToPair(in -> { - List> out = new LinkedList<>(); - for (Character s : in._1().toCharArray()) { - out.add(new Tuple2<>(in._2(), s.toString())); - } - return out; - }); - - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairReduceByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2<>("california", 4), - new Tuple2<>("new york", 5)), - Arrays.asList( - new Tuple2<>("california", 10), - new Tuple2<>("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduced = pairStream.reduceByKey((x, y) -> x + y); - - JavaTestUtils.attachTestOutputStream(reduced); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testCombineByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2<>("california", 4), - new Tuple2<>("new york", 5)), - Arrays.asList( - new Tuple2<>("california", 10), - new Tuple2<>("new york", 4))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream combined = pairStream.combineByKey(i -> i, - (x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2)); - - JavaTestUtils.attachTestOutputStream(combined); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByKeyAndWindow() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2<>("california", 4), - new Tuple2<>("new york", 5)), - Arrays.asList(new Tuple2<>("california", 14), - new Tuple2<>("new york", 9)), - Arrays.asList(new Tuple2<>("california", 10), - new Tuple2<>("new york", 4))); - - JavaDStream> stream = - JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow((x, y) -> x + y, new Duration(2000), new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testUpdateStateByKey() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2<>("california", 4), - new Tuple2<>("new york", 5)), - Arrays.asList(new Tuple2<>("california", 14), - new Tuple2<>("new york", 9)), - Arrays.asList(new Tuple2<>("california", 14), - new Tuple2<>("new york", 9))); - - JavaDStream> stream = - JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream updated = pairStream.updateStateByKey((values, state) -> { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v : values) { - out = out + v; - } - return Optional.of(out); - }); - - JavaTestUtils.attachTestOutputStream(updated); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testReduceByKeyAndWindowWithInverse() { - List>> inputData = stringIntKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2<>("california", 4), - new Tuple2<>("new york", 5)), - Arrays.asList(new Tuple2<>("california", 14), - new Tuple2<>("new york", 9)), - Arrays.asList(new Tuple2<>("california", 10), - new Tuple2<>("new york", 4))); - - JavaDStream> stream = - JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow((x, y) -> x + y, (x, y) -> x - y, new Duration(2000), - new Duration(1000)); - JavaTestUtils.attachTestOutputStream(reduceWindowed); - List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairTransform() { - List>> inputData = Arrays.asList( - Arrays.asList( - new Tuple2<>(3, 5), - new Tuple2<>(1, 5), - new Tuple2<>(4, 5), - new Tuple2<>(2, 5)), - Arrays.asList( - new Tuple2<>(2, 5), - new Tuple2<>(3, 5), - new Tuple2<>(4, 5), - new Tuple2<>(1, 5))); - - List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2<>(1, 5), - new Tuple2<>(2, 5), - new Tuple2<>(3, 5), - new Tuple2<>(4, 5)), - Arrays.asList( - new Tuple2<>(1, 5), - new Tuple2<>(2, 5), - new Tuple2<>(3, 5), - new Tuple2<>(4, 5))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream sorted = pairStream.transformToPair(in -> in.sortByKey()); - - JavaTestUtils.attachTestOutputStream(sorted); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testPairToNormalRDDTransform() { - List>> inputData = Arrays.asList( - Arrays.asList( - new Tuple2<>(3, 5), - new Tuple2<>(1, 5), - new Tuple2<>(4, 5), - new Tuple2<>(2, 5)), - Arrays.asList( - new Tuple2<>(2, 5), - new Tuple2<>(3, 5), - new Tuple2<>(4, 5), - new Tuple2<>(1, 5))); - - List> expected = Arrays.asList( - Arrays.asList(3, 1, 4, 2), - Arrays.asList(2, 3, 4, 1)); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaDStream firstParts = pairStream.transform(in -> in.map(x -> x._1())); - JavaTestUtils.attachTestOutputStream(firstParts); - List> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testMapValues() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2<>("california", "DODGERS"), - new Tuple2<>("california", "GIANTS"), - new Tuple2<>("new york", "YANKEES"), - new Tuple2<>("new york", "METS")), - Arrays.asList(new Tuple2<>("california", "SHARKS"), - new Tuple2<>("california", "DUCKS"), - new Tuple2<>("new york", "RANGERS"), - new Tuple2<>("new york", "ISLANDERS"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream mapped = pairStream.mapValues(String::toUpperCase); - JavaTestUtils.attachTestOutputStream(mapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); - } - - @Test - public void testFlatMapValues() { - List>> inputData = stringStringKVStream; - - List>> expected = Arrays.asList( - Arrays.asList(new Tuple2<>("california", "dodgers1"), - new Tuple2<>("california", "dodgers2"), - new Tuple2<>("california", "giants1"), - new Tuple2<>("california", "giants2"), - new Tuple2<>("new york", "yankees1"), - new Tuple2<>("new york", "yankees2"), - new Tuple2<>("new york", "mets1"), - new Tuple2<>("new york", "mets2")), - Arrays.asList(new Tuple2<>("california", "sharks1"), - new Tuple2<>("california", "sharks2"), - new Tuple2<>("california", "ducks1"), - new Tuple2<>("california", "ducks2"), - new Tuple2<>("new york", "rangers1"), - new Tuple2<>("new york", "rangers2"), - new Tuple2<>("new york", "islanders1"), - new Tuple2<>("new york", "islanders2"))); - - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream flatMapped = - pairStream.flatMapValues(in -> Arrays.asList(in + "1", in + "2")); - JavaTestUtils.attachTestOutputStream(flatMapped); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); - Assert.assertEquals(expected, result); - } - -} diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties deleted file mode 100644 index eb3b1999eb996..0000000000000 --- a/extras/java8-tests/src/test/resources/log4j.properties +++ /dev/null @@ -1,28 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml deleted file mode 100644 index 61ba4787fbf90..0000000000000 --- a/extras/kinesis-asl-assembly/pom.xml +++ /dev/null @@ -1,181 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-streaming-kinesis-asl-assembly_2.10 - jar - Spark Project Kinesis Assembly - http://spark.apache.org/ - - - streaming-kinesis-asl-assembly - - - - - org.apache.spark - spark-streaming-kinesis-asl_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - provided - - - - com.fasterxml.jackson.core - jackson-databind - provided - - - commons-lang - commons-lang - provided - - - com.google.protobuf - protobuf-java - provided - - - com.sun.jersey - jersey-server - provided - - - com.sun.jersey - jersey-core - provided - - - log4j - log4j - provided - - - net.java.dev.jets3t - jets3t - provided - - - org.apache.hadoop - hadoop-client - provided - - - org.apache.avro - avro-ipc - provided - - - org.apache.avro - avro-mapred - ${avro.mapred.classifier} - provided - - - org.apache.curator - curator-recipes - provided - - - org.apache.zookeeper - zookeeper - provided - - - org.slf4j - slf4j-api - provided - - - org.slf4j - slf4j-log4j12 - provided - - - org.xerial.snappy - snappy-java - provided - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - reference.conf - - - log4j.properties - - - - - - - - - - - - diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml deleted file mode 100644 index ef72d97eae69d..0000000000000 --- a/extras/kinesis-asl/pom.xml +++ /dev/null @@ -1,86 +0,0 @@ - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - - org.apache.spark - spark-streaming-kinesis-asl_2.10 - jar - Spark Kinesis Integration - - - streaming-kinesis-asl - - - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test - - - com.amazonaws - amazon-kinesis-client - ${aws.kinesis.client.version} - - - com.amazonaws - aws-java-sdk - ${aws.java.sdk.version} - - - org.mockito - mockito-core - test - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties deleted file mode 100644 index 6cdc9286c5d76..0000000000000 --- a/extras/kinesis-asl/src/main/resources/log4j.properties +++ /dev/null @@ -1,37 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -log4j.rootCategory=WARN, console - -# File appender -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=false -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n - -# Console appender -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.target=System.out -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - -# Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR -log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO -log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala deleted file mode 100644 index 83a4537559512..0000000000000 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import org.apache.spark.Logging -import org.apache.spark.streaming.Duration -import org.apache.spark.util.{Clock, ManualClock, SystemClock} - -/** - * This is a helper class for managing checkpoint clocks. - * - * @param checkpointInterval - * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) - */ -private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, - currentClock: Clock = new SystemClock()) - extends Logging { - - /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ - val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) - - /** - * Check if it's time to checkpoint based on the current time and the derived time - * for the next checkpoint - * - * @return true if it's time to checkpoint - */ - def shouldCheckpoint(): Boolean = { - new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() - } - - /** - * Advance the checkpoint clock by the checkpoint interval. - */ - def advanceCheckpoint(): Unit = { - checkpointClock.advance(checkpointInterval.milliseconds) - } -} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala deleted file mode 100644 index 1d5178790ec4c..0000000000000 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import java.util.List - -import scala.util.Random -import scala.util.control.NonFatal - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason -import com.amazonaws.services.kinesis.model.Record - -import org.apache.spark.Logging - -/** - * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. - * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each - * shard in the Kinesis stream upon startup. This is normally done in separate threads, - * but the KCLs within the KinesisReceivers will balance themselves out if you create - * multiple Receivers. - * - * @param receiver Kinesis receiver - * @param workerId for logging purposes - * @param checkpointState represents the checkpoint state including the next checkpoint time. - * It's injected here for mocking purposes. - */ -private[kinesis] class KinesisRecordProcessor[T]( - receiver: KinesisReceiver[T], - workerId: String, - checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { - - // shardId to be populated during initialize() - @volatile - private var shardId: String = _ - - /** - * The Kinesis Client Library calls this method during IRecordProcessor initialization. - * - * @param shardId assigned by the KCL to this particular RecordProcessor. - */ - override def initialize(shardId: String) { - this.shardId = shardId - logInfo(s"Initialized workerId $workerId with shardId $shardId") - } - - /** - * This method is called by the KCL when a batch of records is pulled from the Kinesis stream. - * This is the record-processing bridge between the KCL's IRecordProcessor.processRecords() - * and Spark Streaming's Receiver.store(). - * - * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored - * in the DStream - */ - override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { - if (!receiver.isStopped()) { - try { - receiver.addRecords(shardId, batch) - logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") - - /* - * - * Checkpoint the sequence number of the last record successfully stored. - * Note that in this current implementation, the checkpointing occurs only when after - * checkpointIntervalMillis from the last checkpoint, AND when there is new record - * to process. This leads to the checkpointing lagging behind what records have been - * stored by the receiver. Ofcourse, this can lead records processed more than once, - * under failures and restarts. - * - * TODO: Instead of checkpointing here, run a separate timer task to perform - * checkpointing so that it checkpoints in a timely manner independent of whether - * new records are available or not. - */ - if (checkpointState.shouldCheckpoint()) { - receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => - /* Perform the checkpoint */ - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) - - /* Update the next checkpoint time */ - checkpointState.advanceCheckpoint() - - logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + - s" records for shardId $shardId") - logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") - } - } - } catch { - case NonFatal(e) => { - /* - * If there is a failure within the batch, the batch will not be checkpointed. - * This will potentially cause records since the last checkpoint to be processed - * more than once. - */ - logError(s"Exception: WorkerId $workerId encountered and exception while storing " + - " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) - - /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ - throw e - } - } - } else { - /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: KinesisReceiver has stopped for workerId $workerId" + - s" and shardId $shardId. No more records will be processed.") - } - } - - /** - * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards - * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason - * (ShutdownReason.ZOMBIE) - * - * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE - * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE) - */ - override def shutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason) { - logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason") - reason match { - /* - * TERMINATE Use Case. Checkpoint. - * Checkpoint to indicate that all records from the shard have been drained and processed. - * It's now OK to read from the new shards that resulted from a resharding event. - */ - case ShutdownReason.TERMINATE => - val latestSeqNumToCheckpointOption = receiver.getLatestSeqNumToCheckpoint(shardId) - if (latestSeqNumToCheckpointOption.nonEmpty) { - KinesisRecordProcessor.retryRandom( - checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 100) - } - - /* - * ZOMBIE Use Case. NoOp. - * No checkpoint because other workers may have taken over and already started processing - * the same records. - * This may lead to records being processed more than once. - */ - case ShutdownReason.ZOMBIE => - - /* Unknown reason. NoOp */ - case _ => - } - } -} - -private[kinesis] object KinesisRecordProcessor extends Logging { - /** - * Retry the given amount of times with a random backoff time (millis) less than the - * given maxBackOffMillis - * - * @param expression expression to evalute - * @param numRetriesLeft number of retries left - * @param maxBackOffMillis: max millis between retries - * - * @return evaluation of the given expression - * @throws Unretryable exception, unexpected exception, - * or any exception that persists after numRetriesLeft reaches 0 - */ - @annotation.tailrec - def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = { - util.Try { expression } match { - /* If the function succeeded, evaluate to x. */ - case util.Success(x) => x - /* If the function failed, either retry or throw the exception */ - case util.Failure(e) => e match { - /* Retry: Throttling or other Retryable exception has occurred */ - case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1 - => { - val backOffMillis = Random.nextInt(maxBackOffMillis) - Thread.sleep(backOffMillis) - logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) - retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) - } - /* Throw: Shutdown has been requested by the Kinesis Client Library. */ - case _: ShutdownException => { - logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) - throw e - } - /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ - case _: InvalidStateException => { - logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + - s" by the Amazon Kinesis Client Library. Table likely doesn't exist.", e) - throw e - } - /* Throw: Unexpected exception has occurred */ - case _ => { - logError(s"Unexpected, non-retryable exception.", e) - throw e - } - } - } - } -} diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties deleted file mode 100644 index edbecdae92096..0000000000000 --- a/extras/kinesis-asl/src/test/resources/log4j.properties +++ /dev/null @@ -1,27 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala deleted file mode 100644 index 17ab444704f44..0000000000000 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ /dev/null @@ -1,268 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import java.nio.ByteBuffer -import java.nio.charset.StandardCharsets -import java.util.Arrays - -import com.amazonaws.services.kinesis.clientlibrary.exceptions._ -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason -import com.amazonaws.services.kinesis.model.Record -import org.mockito.Matchers._ -import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar -import org.scalatest.{BeforeAndAfter, Matchers} - -import org.apache.spark.streaming.{Milliseconds, TestSuiteBase} -import org.apache.spark.util.{Clock, ManualClock, Utils} - -/** - * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor - */ -class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter - with MockitoSugar { - - val app = "TestKinesisReceiver" - val stream = "mySparkStream" - val endpoint = "endpoint-url" - val workerId = "dummyWorkerId" - val shardId = "dummyShardId" - val seqNum = "dummySeqNum" - val someSeqNum = Some(seqNum) - - val record1 = new Record() - record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) - val record2 = new Record() - record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) - val batch = Arrays.asList(record1, record2) - - var receiverMock: KinesisReceiver[Array[Byte]] = _ - var checkpointerMock: IRecordProcessorCheckpointer = _ - var checkpointClockMock: ManualClock = _ - var checkpointStateMock: KinesisCheckpointState = _ - var currentClockMock: Clock = _ - - override def beforeFunction(): Unit = { - receiverMock = mock[KinesisReceiver[Array[Byte]]] - checkpointerMock = mock[IRecordProcessorCheckpointer] - checkpointClockMock = mock[ManualClock] - checkpointStateMock = mock[KinesisCheckpointState] - currentClockMock = mock[Clock] - } - - override def afterFunction(): Unit = { - super.afterFunction() - // Since this suite was originally written using EasyMock, add this to preserve the old - // mocking semantics (see SPARK-5735 for more details) - verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, - checkpointStateMock, currentClockMock) - } - - test("check serializability of SerializableAWSCredentials") { - Utils.deserialize[SerializableAWSCredentials]( - Utils.serialize(new SerializableAWSCredentials("x", "y"))) - } - - test("process records including store and checkpoint") { - when(receiverMock.isStopped()).thenReturn(false) - when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) - - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) - recordProcessor.initialize(shardId) - recordProcessor.processRecords(batch, checkpointerMock) - - verify(receiverMock, times(1)).isStopped() - verify(receiverMock, times(1)).addRecords(shardId, batch) - verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) - verify(checkpointStateMock, times(1)).shouldCheckpoint() - verify(checkpointerMock, times(1)).checkpoint(anyString) - verify(checkpointStateMock, times(1)).advanceCheckpoint() - } - - test("shouldn't store and checkpoint when receiver is stopped") { - when(receiverMock.isStopped()).thenReturn(true) - - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) - recordProcessor.processRecords(batch, checkpointerMock) - - verify(receiverMock, times(1)).isStopped() - verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record])) - verify(checkpointerMock, never).checkpoint(anyString) - } - - test("shouldn't checkpoint when exception occurs during store") { - when(receiverMock.isStopped()).thenReturn(false) - when( - receiverMock.addRecords(anyString, anyListOf(classOf[Record])) - ).thenThrow(new RuntimeException()) - - intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) - recordProcessor.initialize(shardId) - recordProcessor.processRecords(batch, checkpointerMock) - } - - verify(receiverMock, times(1)).isStopped() - verify(receiverMock, times(1)).addRecords(shardId, batch) - verify(checkpointerMock, never).checkpoint(anyString) - } - - test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should checkpoint if we have exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should add to time when advancing checkpoint") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis)) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("shutdown should checkpoint if the reason is TERMINATE") { - when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) - recordProcessor.initialize(shardId) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) - - verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) - verify(checkpointerMock, times(1)).checkpoint(anyString) - } - - test("shutdown should not checkpoint if the reason is something other than TERMINATE") { - when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) - recordProcessor.initialize(shardId) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) - recordProcessor.shutdown(checkpointerMock, null) - - verify(checkpointerMock, never).checkpoint(anyString) - } - - test("retry success on first attempt") { - val expectedIsStopped = false - when(receiverMock.isStopped()).thenReturn(expectedIsStopped) - - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - - verify(receiverMock, times(1)).isStopped() - } - - test("retry success on second attempt after a Kinesis throttling exception") { - val expectedIsStopped = false - when(receiverMock.isStopped()) - .thenThrow(new ThrottlingException("error message")) - .thenReturn(expectedIsStopped) - - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - - verify(receiverMock, times(2)).isStopped() - } - - test("retry success on second attempt after a Kinesis dependency exception") { - val expectedIsStopped = false - when(receiverMock.isStopped()) - .thenThrow(new KinesisClientLibDependencyException("error message")) - .thenReturn(expectedIsStopped) - - val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) - assert(actualVal == expectedIsStopped) - - verify(receiverMock, times(2)).isStopped() - } - - test("retry failed after a shutdown exception") { - when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message")) - - intercept[ShutdownException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - - verify(checkpointerMock, times(1)).checkpoint() - } - - test("retry failed after an invalid state exception") { - when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message")) - - intercept[InvalidStateException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - - verify(checkpointerMock, times(1)).checkpoint() - } - - test("retry failed after unexpected exception") { - when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message")) - - intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - - verify(checkpointerMock, times(1)).checkpoint() - } - - test("retry failed after exhausing all retries") { - val expectedErrorMessage = "final try error message" - when(checkpointerMock.checkpoint()) - .thenThrow(new ThrottlingException("error message")) - .thenThrow(new ThrottlingException(expectedErrorMessage)) - - val exception = intercept[RuntimeException] { - KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) - } - exception.getMessage().shouldBe(expectedErrorMessage) - - verify(checkpointerMock, times(2)).checkpoint() - } -} diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml deleted file mode 100644 index 87a4f05a05961..0000000000000 --- a/extras/spark-ganglia-lgpl/pom.xml +++ /dev/null @@ -1,49 +0,0 @@ - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - - org.apache.spark - spark-ganglia-lgpl_2.10 - jar - Spark Ganglia Integration - - - ganglia-lgpl - - - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - - - - io.dropwizard.metrics - metrics-ganglia - - - diff --git a/graphx/pom.xml b/graphx/pom.xml index 987b831021a54..1813f383cdcba 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-graphx_2.10 + spark-graphx_2.11 graphx @@ -47,6 +47,10 @@ test-jar test + + org.apache.xbean + xbean-asm5-shaded + com.google.guava guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index ee7302a1edbf6..45526bf062fab 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -24,12 +24,11 @@ import org.apache.spark.Dependency import org.apache.spark.Partition import org.apache.spark.SparkContext import org.apache.spark.TaskContext -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.graphx.impl.EdgePartitionBuilder import org.apache.spark.graphx.impl.EdgeRDDImpl +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * `EdgeRDD[ED, VD]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 869caa340f52b..5485e30f5a2c9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -297,7 +297,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab /** * Restricts the graph to only the vertices and edges satisfying the predicates. The resulting - * subgraph satisifies + * subgraph satisfies * * {{{ * V' = {v : for all v in V where vpred(v)} @@ -340,55 +340,6 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab */ def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] - /** - * Aggregates values from the neighboring edges and vertices of each vertex. The user supplied - * `mapFunc` function is invoked on each edge of the graph, generating 0 or more "messages" to be - * "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of - * the map phase destined to each vertex. - * - * This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead. - * - * @tparam A the type of "message" to be sent to each vertex - * - * @param mapFunc the user defined map function which returns 0 or - * more messages to neighboring vertices - * - * @param reduceFunc the user defined reduce function which should - * be commutative and associative and is used to combine the output - * of the map phase - * - * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if - * desired. This is done by specifying a set of "active" vertices and an edge direction. The - * `sendMsg` function will then run only on edges connected to active vertices by edges in the - * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with - * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges - * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be - * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` - * will be run on edges with *both* vertices in the active set. The active set must have the - * same index as the graph's vertices. - * - * @example We can use this function to compute the in-degree of each - * vertex - * {{{ - * val rawGraph: Graph[(),()] = Graph.textFile("twittergraph") - * val inDeg: RDD[(VertexId, Int)] = - * mapReduceTriplets[Int](et => Iterator((et.dst.id, 1)), _ + _) - * }}} - * - * @note By expressing computation at the edge level we achieve - * maximum parallelism. This is one of the core functions in the - * Graph API in that enables neighborhood level computation. For - * example this function can be used to count neighbors satisfying a - * predicate or implement PageRank. - * - */ - @deprecated("use aggregateMessages", "1.2.0") - def mapReduceTriplets[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) - : VertexRDD[A] - /** * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala deleted file mode 100644 index 563c948957ecf..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx - -import com.esotericsoftware.kryo.Kryo - -import org.apache.spark.serializer.KryoRegistrator -import org.apache.spark.util.BoundedPriorityQueue -import org.apache.spark.util.collection.BitSet - -import org.apache.spark.graphx.impl._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -import org.apache.spark.util.collection.OpenHashSet - -/** - * Registers GraphX classes with Kryo for improved performance. - */ -@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0") -class GraphKryoRegistrator extends KryoRegistrator { - - def registerClasses(kryo: Kryo) { - kryo.register(classOf[Edge[Object]]) - kryo.register(classOf[(VertexId, Object)]) - kryo.register(classOf[EdgePartition[Object, Object]]) - kryo.register(classOf[BitSet]) - kryo.register(classOf[VertexIdToIndexMap]) - kryo.register(classOf[VertexAttributeBlock[Object]]) - kryo.register(classOf[PartitionStrategy]) - kryo.register(classOf[BoundedPriorityQueue[Object]]) - kryo.register(classOf[EdgeDirection]) - kryo.register(classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]]) - kryo.register(classOf[OpenHashSet[Int]]) - kryo.register(classOf[OpenHashSet[Long]]) - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 21187be7678a6..f678e5f1238fb 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -17,9 +17,10 @@ package org.apache.spark.graphx -import org.apache.spark.storage.StorageLevel -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.graphx.impl.{EdgePartitionBuilder, GraphImpl} +import org.apache.spark.internal.Logging +import org.apache.spark.storage.StorageLevel /** * Provides utilities for loading [[Graph]]s from files. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 9451ff1e5c0e2..868658dfe55e5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -21,10 +21,8 @@ import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.SparkException -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - import org.apache.spark.graphx.lib._ +import org.apache.spark.rdd.RDD /** * Contains additional functionality for [[Graph]]. All operations are expressed in terms of the @@ -185,6 +183,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali } } + /** + * Remove self edges. + * + * @return a graph with all self edges removed + */ + def removeSelfEdges(): Graph[VD, ED] = { + graph.subgraph(epred = e => e.srcId != e.dstId) + } + /** * Join the vertices with an RDD and then apply a function from the * vertex and RDD entry to a new vertex value. The input table @@ -229,11 +236,11 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @param preprocess a function to compute new vertex and edge data before filtering * @param epred edge pred to filter on after preprocess, see more details under * [[org.apache.spark.graphx.Graph#subgraph]] - * @param vpred vertex pred to filter on after prerocess, see more details under + * @param vpred vertex pred to filter on after preprocess, see more details under * [[org.apache.spark.graphx.Graph#subgraph]] * @tparam VD2 vertex type the vpred operates on * @tparam ED2 edge type the epred operates on - * @return a subgraph of the orginal graph, with its data unchanged + * @return a subgraph of the original graph, with its data unchanged * * @example This function can be used to filter the graph based on some property, without * changing the vertex and edge values in your program. For example, we could remove the vertices @@ -269,10 +276,10 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali if (Random.nextDouble() < probability) { Some(vidVvals._1) } else { None } } - if (selectedVertices.count > 1) { + if (selectedVertices.count > 0) { found = true val collectedVertices = selectedVertices.collect() - retVal = collectedVertices(Random.nextInt(collectedVertices.size)) + retVal = collectedVertices(Random.nextInt(collectedVertices.length)) } } retVal @@ -282,7 +289,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * Convert bi-directional edges into uni-directional ones. * Some graph algorithms (e.g., TriangleCount) assume that an input graph * has its edges in canonical direction. - * This function rewrites the vertex ids of edges so that srcIds are bigger + * This function rewrites the vertex ids of edges so that srcIds are smaller * than dstIds, and merges the duplicated edges. * * @param mergeFunc the user defined reduce function which should @@ -380,7 +387,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergenceWithOptions]] */ def personalizedPageRank(src: VertexId, tol: Double, - resetProb: Double = 0.15) : Graph[Double, Double] = { + resetProb: Double = 0.15): Graph[Double, Double] = { PageRank.runUntilConvergenceWithOptions(graph, tol, resetProb, Some(src)) } @@ -393,7 +400,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @see [[org.apache.spark.graphx.lib.PageRank$#runWithOptions]] */ def staticPersonalizedPageRank(src: VertexId, numIter: Int, - resetProb: Double = 0.15) : Graph[Double, Double] = { + resetProb: Double = 0.15): Graph[Double, Double] = { PageRank.runWithOptions(graph, numIter, resetProb, Some(src)) } @@ -417,6 +424,16 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali ConnectedComponents.run(graph) } + /** + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]] + */ + def connectedComponents(maxIterations: Int): Graph[VertexId, ED] = { + ConnectedComponents.run(graph, maxIterations) + } + /** * Compute the number of triangles passing through each vertex. * diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala index 2cb07937eaa2a..ef0b943fc3c38 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala @@ -17,15 +17,16 @@ package org.apache.spark.graphx -import org.apache.spark.SparkConf +import scala.reflect.ClassTag +import org.apache.spark.SparkConf import org.apache.spark.graphx.impl._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -import org.apache.spark.util.collection.{OpenHashSet, BitSet} import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.collection.{BitSet, OpenHashSet} object GraphXUtils { + /** * Registers classes that GraphX uses with Kryo. */ @@ -44,4 +45,28 @@ object GraphXUtils { classOf[OpenHashSet[Int]], classOf[OpenHashSet[Long]])) } + + /** + * A proxy method to map the obsolete API to the new one. + */ + private[graphx] def mapReduceTriplets[VD: ClassTag, ED: ClassTag, A: ClassTag]( + g: Graph[VD, ED], + mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], + reduceFunc: (A, A) => A, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { + def sendMsg(ctx: EdgeContext[VD, ED, A]) { + mapFunc(ctx.toEdgeTriplet).foreach { kv => + val id = kv._1 + val msg = kv._2 + if (id == ctx.srcId) { + ctx.sendToSrc(msg) + } else { + assert(id == ctx.dstId) + ctx.sendToDst(msg) + } + } + } + g.aggregateMessagesWithActiveSet( + sendMsg, reduceFunc, TripletFields.All, activeSetOpt) + } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 2ca60d51f8331..646462b4a8350 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -18,8 +18,8 @@ package org.apache.spark.graphx import scala.reflect.ClassTag -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * Implements a Pregel-like bulk-synchronous message-passing API. @@ -119,9 +119,12 @@ object Pregel extends Logging { mergeMsg: (A, A) => A) : Graph[VD, ED] = { + require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + + s" but got ${maxIterations}") + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() // compute the messages - var messages = g.mapReduceTriplets(sendMsg, mergeMsg) + var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg) var activeMessages = messages.count() // Loop var prevG: Graph[VD, ED] = null @@ -135,8 +138,8 @@ object Pregel extends Logging { // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. - messages = g.mapReduceTriplets( - sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + messages = GraphXUtils.mapReduceTriplets( + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages // and the vertices of g). @@ -151,7 +154,7 @@ object Pregel extends Logging { // count the iteration i += 1 } - + messages.unpersist(blocking = false) g } // end of apply diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 1ef7a78fbcd00..35577d9e2fc6f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -20,14 +20,12 @@ package org.apache.spark.graphx import scala.reflect.ClassTag import org.apache.spark._ -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd._ -import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock import org.apache.spark.graphx.impl.VertexRDDImpl +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by @@ -277,7 +275,7 @@ object VertexRDD { def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.length)) } val vertexPartitions = vPartitioned.mapPartitions( iter => Iterator(ShippableVertexPartition(iter)), @@ -318,7 +316,7 @@ object VertexRDD { ): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.length)) } val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get) val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) { @@ -359,7 +357,7 @@ object VertexRDD { Function.tupled(RoutingTablePartition.edgePartitionToMsgs))) .setName("VertexRDD.createRoutingTables - vid2pid (aggregation)") - val numEdgePartitions = edges.partitions.size + val numEdgePartitions = edges.partitions.length vid2pid.partitionBy(vertexPartitioner).mapPartitions( iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)), preservesPartitioning = true) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index ab021a252eb8a..26349f4d88a19 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -17,7 +17,7 @@ package org.apache.spark.graphx.impl -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap @@ -151,9 +151,9 @@ class EdgePartition[ * applied to each edge */ def map[ED2: ClassTag](f: Edge[ED] => ED2): EdgePartition[ED2, VD] = { - val newData = new Array[ED2](data.size) + val newData = new Array[ED2](data.length) val edge = new Edge[ED]() - val size = data.size + val size = data.length var i = 0 while (i < size) { edge.srcId = srcIds(i) @@ -179,13 +179,13 @@ class EdgePartition[ */ def map[ED2: ClassTag](iter: Iterator[ED2]): EdgePartition[ED2, VD] = { // Faster than iter.toArray, because the expected size is known. - val newData = new Array[ED2](data.size) + val newData = new Array[ED2](data.length) var i = 0 while (iter.hasNext) { newData(i) = iter.next() i += 1 } - assert(newData.size == i) + assert(newData.length == i) this.withData(newData) } @@ -311,7 +311,7 @@ class EdgePartition[ * * @return size of the partition */ - val size: Int = localSrcIds.size + val size: Int = localSrcIds.length /** The number of unique source vertices in the partition. */ def indexSize: Int = index.size diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 906d42328fcb9..da3db3c4dca04 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -import org.apache.spark.util.collection.{SortDataFormat, Sorter, PrimitiveVector} +import org.apache.spark.util.collection.{PrimitiveVector, SortDataFormat, Sorter} /** Constructs an EdgePartition from scratch. */ private[graphx] @@ -38,9 +38,9 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla val edgeArray = edges.trim().array new Sorter(Edge.edgeArraySortDataFormat[ED]) .sort(edgeArray, 0, edgeArray.length, Edge.lexicographicOrdering) - val localSrcIds = new Array[Int](edgeArray.size) - val localDstIds = new Array[Int](edgeArray.size) - val data = new Array[ED](edgeArray.size) + val localSrcIds = new Array[Int](edgeArray.length) + val localDstIds = new Array[Int](edgeArray.length) + val data = new Array[ED](edgeArray.length) val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] val global2local = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] val local2global = new PrimitiveVector[VertexId] @@ -52,7 +52,7 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla var currSrcId: VertexId = edgeArray(0).srcId var currLocalId = -1 var i = 0 - while (i < edgeArray.size) { + while (i < edgeArray.length) { val srcId = edgeArray(i).srcId val dstId = edgeArray(i).dstId localSrcIds(i) = global2local.changeValue(srcId, @@ -98,9 +98,9 @@ class ExistingEdgePartitionBuilder[ val edgeArray = edges.trim().array new Sorter(EdgeWithLocalIds.edgeArraySortDataFormat[ED]) .sort(edgeArray, 0, edgeArray.length, EdgeWithLocalIds.lexicographicOrdering) - val localSrcIds = new Array[Int](edgeArray.size) - val localDstIds = new Array[Int](edgeArray.size) - val data = new Array[ED](edgeArray.size) + val localSrcIds = new Array[Int](edgeArray.length) + val localDstIds = new Array[Int](edgeArray.length) + val data = new Array[ED](edgeArray.length) val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index @@ -108,7 +108,7 @@ class ExistingEdgePartitionBuilder[ index.update(edgeArray(0).srcId, 0) var currSrcId: VertexId = edgeArray(0).srcId var i = 0 - while (i < edgeArray.size) { + while (i < edgeArray.length) { localSrcIds(i) = edgeArray(i).localSrcId localDstIds(i) = edgeArray(i).localDstId data(i) = edgeArray(i).attr diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index c88b2f65a86cd..98e082cc44e1a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,12 +19,11 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, HashPartitioner} +import org.apache.spark.{HashPartitioner, OneToOneDependency} +import org.apache.spark.graphx._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.graphx._ - class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( @transient override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) @@ -46,7 +45,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( * partitioner that allows co-partitioning with `partitionsRDD`. */ override val partitioner = - partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size))) + partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.length))) override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index da95314440d86..e18831382d4d5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -20,13 +20,10 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} import org.apache.spark.HashPartitioner -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDD, ShuffledRDD} -import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.util.BytecodeUtils - +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * An implementation of [[org.apache.spark.graphx.Graph]] to support computation on graphs. @@ -94,7 +91,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } override def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] = { - partitionBy(partitionStrategy, edges.partitions.size) + partitionBy(partitionStrategy, edges.partitions.length) } override def partitionBy( @@ -188,31 +185,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( // Lower level transformation methods // /////////////////////////////////////////////////////////////////////////////////////////////// - override def mapReduceTriplets[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { - - def sendMsg(ctx: EdgeContext[VD, ED, A]) { - mapFunc(ctx.toEdgeTriplet).foreach { kv => - val id = kv._1 - val msg = kv._2 - if (id == ctx.srcId) { - ctx.sendToSrc(msg) - } else { - assert(id == ctx.dstId) - ctx.sendToDst(msg) - } - } - } - - val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") - val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true) - - aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt) - } - override def aggregateMessagesWithActiveSet[A: ClassTag]( sendMsg: EdgeContext[VD, ED, A] => Unit, mergeMsg: (A, A) => A, @@ -292,7 +264,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } } - /** Test whether the closure accesses the the attribute with name `attrName`. */ + /** Test whether the closure accesses the attribute with name `attrName`. */ private def accessesVertexAttr(closure: AnyRef, attrName: String): Boolean = { try { BytecodeUtils.invokedMethod(closure, classOf[EdgeTriplet[VD, ED]], attrName) @@ -378,7 +350,8 @@ object GraphImpl { edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { val edgesCached = edges.withTargetStorageLevel(edgeStorageLevel).cache() - val vertices = VertexRDD.fromEdges(edgesCached, edgesCached.partitions.size, defaultVertexAttr) + val vertices = + VertexRDD.fromEdges(edgesCached, edgesCached.partitions.length, defaultVertexAttr) .withTargetStorageLevel(vertexStorageLevel) fromExistingRDDs(vertices, edgesCached) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 1df86449fa0c2..d2194d85bf525 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -17,12 +17,10 @@ package org.apache.spark.graphx.impl -import scala.reflect.{classTag, ClassTag} - -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD +import scala.reflect.ClassTag import org.apache.spark.graphx._ +import org.apache.spark.rdd.RDD /** * Manages shipping vertex attributes to the edge partitions of an @@ -42,8 +40,8 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * shipping level. */ def withEdges[VD2: ClassTag, ED2: ClassTag]( - edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { - new ReplicatedVertexView(edges_, hasSrcId, hasDstId) + _edges: EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { + new ReplicatedVertexView(_edges, hasSrcId, hasDstId) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 4f1260a5a67b2..6453bbeae9f10 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -17,17 +17,9 @@ package org.apache.spark.graphx.impl -import scala.reflect.ClassTag - -import org.apache.spark.Partitioner -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.util.collection.{BitSet, PrimitiveVector} - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage +import org.apache.spark.util.collection.{BitSet, PrimitiveVector} private[graphx] object RoutingTablePartition { @@ -110,10 +102,10 @@ private[graphx] class RoutingTablePartition( private val routingTable: Array[(Array[VertexId], BitSet, BitSet)]) extends Serializable { /** The maximum number of edge partitions this `RoutingTablePartition` is built to join with. */ - val numEdgePartitions: Int = routingTable.size + val numEdgePartitions: Int = routingTable.length /** Returns the number of vertices that will be sent to the specified edge partition. */ - def partitionSize(pid: PartitionID): Int = routingTable(pid)._1.size + def partitionSize(pid: PartitionID): Int = routingTable(pid)._1.length /** Returns an iterator over all vertex ids stored in this `RoutingTablePartition`. */ def iterator: Iterator[VertexId] = routingTable.iterator.flatMap(_._1.iterator) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index aa320088f2088..a4e293d74a012 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -19,17 +19,16 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import org.apache.spark.util.collection.{BitSet, PrimitiveVector} - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.{BitSet, PrimitiveVector} /** Stores vertex attributes to ship to an edge partition. */ private[graphx] class VertexAttributeBlock[VD: ClassTag](val vids: Array[VertexId], val attrs: Array[VD]) extends Serializable { def iterator: Iterator[(VertexId, VD)] = - (0 until vids.size).iterator.map { i => (vids(i), attrs(i)) } + (0 until vids.length).iterator.map { i => (vids(i), attrs(i)) } } private[graphx] @@ -50,7 +49,7 @@ object ShippableVertexPartition { /** * Construct a `ShippableVertexPartition` from the given vertices with the specified routing * table, filling in missing vertices mentioned in the routing table using `defaultVal`, - * and merging duplicate vertex atrribute with mergeFunc. + * and merging duplicate vertex attribute with mergeFunc. */ def apply[VD: ClassTag]( iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD, @@ -103,8 +102,8 @@ class ShippableVertexPartition[VD: ClassTag]( extends VertexPartitionBase[VD] { /** Return a new ShippableVertexPartition with the specified routing table. */ - def withRoutingTable(routingTable_ : RoutingTablePartition): ShippableVertexPartition[VD] = { - new ShippableVertexPartition(index, values, mask, routingTable_) + def withRoutingTable(_routingTable: RoutingTablePartition): ShippableVertexPartition[VD] = { + new ShippableVertexPartition(index, values, mask, _routingTable) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala index fbe53acfc32aa..b4100bade0734 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala @@ -19,10 +19,8 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet private[graphx] object VertexPartition { /** Construct a `VertexPartition` from the given vertices. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala index 5ad6390a56c4f..8d608c99b1a1d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala @@ -20,10 +20,9 @@ package org.apache.spark.graphx.impl import scala.language.higherKinds import scala.reflect.ClassTag -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet private[graphx] object VertexPartitionBase { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index b90f9fa327052..31373a53cf933 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -21,11 +21,10 @@ import scala.language.higherKinds import scala.language.implicitConversions import scala.reflect.ClassTag -import org.apache.spark.Logging -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.internal.Logging +import org.apache.spark.util.collection.BitSet /** * An class containing additional operations for subclasses of VertexPartitionBase that provide @@ -33,7 +32,7 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap * example, [[VertexPartition.VertexPartitionOpsConstructor]]). */ private[graphx] abstract class VertexPartitionBaseOps - [VD: ClassTag, Self[X] <: VertexPartitionBase[X] : VertexPartitionBaseOpsConstructor] + [VD: ClassTag, Self[X] <: VertexPartitionBase[X]: VertexPartitionBaseOpsConstructor] (self: Self[VD]) extends Serializable with Logging { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 7f4e7e9d79d6b..d314522de9916 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -20,12 +20,10 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import org.apache.spark._ -import org.apache.spark.SparkContext._ +import org.apache.spark.graphx._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.graphx._ - class VertexRDDImpl[VD] private[graphx] ( @transient val partitionsRDD: RDD[ShippableVertexPartition[VD]], val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index 859f896039047..4e9b13162e5ca 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -29,13 +29,16 @@ object ConnectedComponents { * * @tparam VD the vertex attribute type (discarded in the computation) * @tparam ED the edge attribute type (preserved in the computation) - * * @param graph the graph for which to compute the connected components - * + * @param maxIterations the maximum number of iterations to run for * @return a graph with vertex attributes containing the smallest vertex in each * connected component */ - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], + maxIterations: Int): Graph[VertexId, ED] = { + require(maxIterations > 0, s"Maximum of iterations must be greater than 0," + + s" but got ${maxIterations}") + val ccGraph = graph.mapVertices { case (vid, _) => vid } def sendMessage(edge: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, VertexId)] = { if (edge.srcAttr < edge.dstAttr) { @@ -47,9 +50,26 @@ object ConnectedComponents { } } val initialMessage = Long.MaxValue - Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)( + val pregelGraph = Pregel(ccGraph, initialMessage, + maxIterations, EdgeDirection.Either)( vprog = (id, attr, msg) => math.min(attr, msg), sendMsg = sendMessage, mergeMsg = (a, b) => math.min(a, b)) + ccGraph.unpersist() + pregelGraph } // end of connectedComponents + + /** + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @tparam VD the vertex attribute type (discarded in the computation) + * @tparam ED the edge attribute type (preserved in the computation) + * @param graph the graph for which to compute the connected components + * @return a graph with vertex attributes containing the smallest vertex in each + * connected component + */ + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { + run(graph, Int.MaxValue) + } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index a3ad6bed1c998..fc7547a2c7c27 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.lib import scala.reflect.ClassTag + import org.apache.spark.graphx._ /** Label Propagation algorithm. */ @@ -42,6 +43,8 @@ object LabelPropagation { * @return a graph with vertex attributes containing the label of community affiliation */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { + require(maxSteps > 0, s"Maximum of steps must be greater than 0, but got ${maxSteps}") + val lpaGraph = graph.mapVertices { case (vid, _) => vid } def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 52b237fc15093..0a1622bca0f4b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -17,11 +17,11 @@ package org.apache.spark.graphx.lib -import scala.reflect.ClassTag import scala.language.postfixOps +import scala.reflect.ClassTag -import org.apache.spark.Logging import org.apache.spark.graphx._ +import org.apache.spark.internal.Logging /** * PageRank algorithm implementation. There are two implementations of PageRank implemented. @@ -54,7 +54,7 @@ import org.apache.spark.graphx._ * }}} * * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of - * neighbors whick link to `i` and `outDeg[j]` is the out degree of vertex `j`. + * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`. * * Note that this is not the "normalized" PageRank and as a consequence pages that have no * inlinks will have a PageRank of alpha. @@ -104,6 +104,11 @@ object PageRank extends Logging { graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, srcId: Option[VertexId] = None): Graph[Double, Double] = { + require(numIter > 0, s"Number of iterations must be greater than 0," + + s" but got ${numIter}") + require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" + + s" to [0, 1], but got ${resetProb}") + val personalized = srcId isDefined val src: VertexId = srcId.getOrElse(-1L) @@ -138,7 +143,7 @@ object PageRank extends Logging { // edge partitions. prevRankGraph = rankGraph val rPrb = if (personalized) { - (src: VertexId , id: VertexId) => resetProb * delta(src, id) + (src: VertexId, id: VertexId) => resetProb * delta(src, id) } else { (src: VertexId, id: VertexId) => resetProb } @@ -197,6 +202,10 @@ object PageRank extends Logging { graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15, srcId: Option[VertexId] = None): Graph[Double, Double] = { + require(tol >= 0, s"Tolerance must be no less than 0, but got ${tol}") + require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" + + s" to [0, 1], but got ${resetProb}") + val personalized = srcId.isDefined val src: VertexId = srcId.getOrElse(-1L) @@ -209,7 +218,7 @@ object PageRank extends Logging { } // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr ) - // Set the vertex attributes to (initalPR, delta = 0) + // Set the vertex attributes to (initialPR, delta = 0) .mapVertices { (id, attr) => if (id == src) (resetProb, Double.NegativeInfinity) else (0.0, 0.0) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 9cb24ed080e1c..bb2ffab0f60f8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -21,8 +21,8 @@ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.rdd._ import org.apache.spark.graphx._ +import org.apache.spark.rdd._ /** Implementation of SVD++ algorithm. */ object SVDPlusPlus { @@ -39,17 +39,6 @@ object SVDPlusPlus { var gamma7: Double) extends Serializable - /** - * This method is now replaced by the updated version of `run()` and returns exactly - * the same result. - */ - @deprecated("Call run()", "1.4.0") - def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf) - : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = - { - run(edges, conf) - } - /** * Implement SVD++ based on "Factorization Meets the Neighborhood: * a Multifaceted Collaborative Filtering Model", @@ -67,6 +56,11 @@ object SVDPlusPlus { def run(edges: RDD[Edge[Double]], conf: Conf) : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) = { + require(conf.maxIters > 0, s"Maximum of iterations must be greater than 0," + + s" but got ${conf.maxIters}") + require(conf.maxVal > conf.minVal, s"MaxVal must be greater than MinVal," + + s" but got {maxVal: ${conf.maxVal}, minVal: ${conf.minVal}}") + // Generate default vertex attribute def defaultF(rank: Int): (Array[Double], Array[Double], Double, Double) = { // TODO: use a fixed random seed diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 179f2843818e0..f0c6bcb93445c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -17,9 +17,10 @@ package org.apache.spark.graphx.lib -import org.apache.spark.graphx._ import scala.reflect.ClassTag +import org.apache.spark.graphx._ + /** * Computes shortest paths to the given set of landmark vertices, returning a graph where each * vertex attribute is a map containing the shortest-path distance to each reachable landmark. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala index 8dd958033b338..1fa92b0195410 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala @@ -36,6 +36,8 @@ object StronglyConnectedComponents { * @return a graph with vertex attributes containing the smallest vertex id in each SCC */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexId, ED] = { + require(numIter > 0, s"Number of iterations must be greater than 0," + + s" but got ${numIter}") // the graph we update with final SCC ids, and the graph we return at the end var sccGraph = graph.mapVertices { case (vid, _) => vid } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index a5d598053f9ca..34e9e22c3a35a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -27,27 +27,49 @@ import org.apache.spark.graphx._ * The algorithm is relatively straightforward and can be computed in three steps: * *

      - *
    • Compute the set of neighbors for each vertex - *
    • For each edge compute the intersection of the sets and send the count to both vertices. - *
    • Compute the sum at each vertex and divide by two since each triangle is counted twice. + *
    • Compute the set of neighbors for each vertex
    • + *
    • For each edge compute the intersection of the sets and send the count to both vertices.
    • + *
    • Compute the sum at each vertex and divide by two since each triangle is counted twice.
    • *
    * - * Note that the input graph should have its edges in canonical direction - * (i.e. the `sourceId` less than `destId`). Also the graph must have been partitioned - * using [[org.apache.spark.graphx.Graph#partitionBy]]. + * There are two implementations. The default `TriangleCount.run` implementation first removes + * self cycles and canonicalizes the graph to ensure that the following conditions hold: + *
      + *
    • There are no self edges
    • + *
    • All edges are oriented src > dst
    • + *
    • There are no duplicate edges
    • + *
    + * However, the canonicalization procedure is costly as it requires repartitioning the graph. + * If the input data is already in "canonical form" with self cycles removed then the + * `TriangleCount.runPreCanonicalized` should be used instead. + * + * {{{ + * val canonicalGraph = graph.mapEdges(e => 1).removeSelfEdges().canonicalizeEdges() + * val counts = TriangleCount.runPreCanonicalized(canonicalGraph).vertices + * }}} + * */ object TriangleCount { def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { - // Remove redundant edges - val g = graph.groupEdges((a, b) => a).cache() + // Transform the edge data something cheap to shuffle and then canonicalize + val canonicalGraph = graph.mapEdges(e => true).removeSelfEdges().convertToCanonicalEdges() + // Get the triangle counts + val counters = runPreCanonicalized(canonicalGraph).vertices + // Join them bath with the original graph + graph.outerJoinVertices(counters) { (vid, _, optCounter: Option[Int]) => + optCounter.getOrElse(0) + } + } + + def runPreCanonicalized[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { // Construct set representations of the neighborhoods val nbrSets: VertexRDD[VertexSet] = - g.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) => - val set = new VertexSet(4) + graph.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) => + val set = new VertexSet(nbrs.length) var i = 0 - while (i < nbrs.size) { + while (i < nbrs.length) { // prevent self cycle if (nbrs(i) != vid) { set.add(nbrs(i)) @@ -56,14 +78,14 @@ object TriangleCount { } set } + // join the sets with the graph - val setGraph: Graph[VertexSet, ED] = g.outerJoinVertices(nbrSets) { + val setGraph: Graph[VertexSet, ED] = graph.outerJoinVertices(nbrSets) { (vid, _, optSet) => optSet.getOrElse(null) } + // Edge function computes intersection of smaller vertex with larger vertex def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) { - assert(ctx.srcAttr != null) - assert(ctx.dstAttr != null) val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) { (ctx.srcAttr, ctx.dstAttr) } else { @@ -80,15 +102,15 @@ object TriangleCount { ctx.sendToSrc(counter) ctx.sendToDst(counter) } + // compute the intersection along edges val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _) // Merge counters with the graph and divide by two since each triangle is counted twice - g.outerJoinVertices(counters) { - (vid, _, optCounter: Option[Int]) => - val dblCount = optCounter.getOrElse(0) - // double count should be even (divisible by two) - assert((dblCount & 1) == 0) - dblCount / 2 + graph.outerJoinVertices(counters) { (_, _, optCounter: Option[Int]) => + val dblCount = optCounter.getOrElse(0) + // This algorithm double counts each triangle so the final count should be even + require(dblCount % 2 == 0, "Triangle count resulted in an invalid number of triangles.") + dblCount / 2 } - } // end of TriangleCount + } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala index 6aab28ff05355..dde25b96594be 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala @@ -30,7 +30,7 @@ package object graphx { */ type VertexId = Long - /** Integer identifer of a graph partition. Must be less than 2^30. */ + /** Integer identifier of a graph partition. Must be less than 2^30. */ // TODO: Consider using Char. type PartitionID = Int diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 74a7de18d4161..d76e84ed8c9ed 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,11 +22,10 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.spark.util.Utils - -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm5.Opcodes._ +import org.apache.spark.util.Utils /** * Includes an utility function to test whether a function accesses a specific attribute @@ -93,7 +92,7 @@ private[graphx] object BytecodeUtils { /** * Given the class name, return whether we should look into the class or not. This is used to - * skip examing a large quantity of Java or Scala classes that we know for sure wouldn't access + * skip examining a large quantity of Java or Scala classes that we know for sure wouldn't access * the closures. Note that the class name is expected in ASM style (i.e. use "/" instead of "."). */ private def skipClass(className: String): Boolean = { @@ -107,18 +106,19 @@ private[graphx] object BytecodeUtils { * MethodInvocationFinder("spark/graph/Foo", "test") * its methodsInvoked variable will contain the set of methods invoked directly by * Foo.test(). Interface invocations are not returned as part of the result set because we cannot - * determine the actual metod invoked by inspecting the bytecode. + * determine the actual method invoked by inspecting the bytecode. */ private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { val methodsInvoked = new HashSet[(Class[_], String)] override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name == methodName) { - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 989e226305265..80c6b6838faf5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -18,19 +18,13 @@ package org.apache.spark.graphx.util import scala.annotation.tailrec -import scala.math._ import scala.reflect.ClassTag import scala.util._ import org.apache.spark._ -import org.apache.spark.serializer._ -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.Graph -import org.apache.spark.graphx.Edge -import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD /** A collection of graph generating functions. */ object GraphGenerators extends Logging { @@ -169,7 +163,7 @@ object GraphGenerators extends Logging { } /** - * This method recursively subdivides the the adjacency matrix into quadrants + * This method recursively subdivides the adjacency matrix into quadrants * until it picks a single cell. The naming conventions in this paper match * those of the R-MAT paper. There are a power of 2 number of nodes in the graph. * The adjacency matrix looks like: diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index e2754ea699da9..972237da1cb28 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -17,10 +17,10 @@ package org.apache.spark.graphx.util.collection -import org.apache.spark.util.collection.OpenHashSet - import scala.reflect._ +import org.apache.spark.util.collection.OpenHashSet + /** * A fast hash map implementation for primitive, non-null keys. This hash map supports * insertions and updates, but not deletions. This map is about an order of magnitude diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index eb3b1999eb996..3706a6e361307 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 094a63472eaab..4d6b899c83a04 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite class EdgeSuite extends SparkFunSuite { test ("compare") { - // decending order + // descending order val testEdges: Array[Edge[Int]] = Array( Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), Edge(0x2345L, 0x1234L, 1), diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala new file mode 100644 index 0000000000000..e55b05fa996ad --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +import java.io.File +import java.io.FileOutputStream +import java.io.OutputStreamWriter +import java.nio.charset.StandardCharsets + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils + +class GraphLoaderSuite extends SparkFunSuite with LocalSparkContext { + + test("GraphLoader.edgeListFile") { + withSpark { sc => + val tmpDir = Utils.createTempDir() + val graphFile = new File(tmpDir.getAbsolutePath, "graph.txt") + val writer = new OutputStreamWriter(new FileOutputStream(graphFile), StandardCharsets.UTF_8) + for (i <- (1 until 101)) writer.write(s"$i 0\n") + writer.close() + try { + val graph = GraphLoader.edgeListFile(sc, tmpDir.getAbsolutePath) + val neighborAttrSums = graph.aggregateMessages[Int]( + ctx => ctx.sendToDst(ctx.srcAttr), + _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, 100))) + } finally { + Utils.deleteRecursively(tmpDir) + } + } + } +} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 57a8b95dd12e9..32981719499d9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.graphx import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ -import org.apache.spark.graphx.impl.EdgePartition -import org.apache.spark.rdd._ class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { @@ -55,6 +53,21 @@ class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { } } + test("removeSelfEdges") { + withSpark { sc => + val edgeArray = Array((1 -> 2), (2 -> 3), (3 -> 3), (4 -> 3), (1 -> 1)) + .map { + case (a, b) => (a.toLong, b.toLong) + } + val correctEdges = edgeArray.filter { case (a, b) => a != b }.toSet + val graph = Graph.fromEdgeTuples(sc.parallelize(edgeArray), 1) + val canonicalizedEdges = graph.removeSelfEdges().edges.map(e => (e.srcId, e.dstId)) + .collect + assert(canonicalizedEdges.toSet.size === canonicalizedEdges.size) + assert(canonicalizedEdges.toSet === correctEdges) + } + } + test ("filter") { withSpark { sc => val n = 5 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 1f5e27d5508b8..96aa262a395c8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -221,14 +221,15 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { val vertices: RDD[(VertexId, Int)] = sc.parallelize(Array((1L, 1), (2L, 2))) val edges: RDD[Edge[Int]] = sc.parallelize(Array(Edge(1L, 2L, 0))) val graph = Graph(vertices, edges).reverse - val result = graph.mapReduceTriplets[Int](et => Iterator((et.dstId, et.srcAttr)), _ + _) + val result = GraphXUtils.mapReduceTriplets[Int, Int, Int]( + graph, et => Iterator((et.dstId, et.srcAttr)), _ + _) assert(result.collect().toSet === Set((1L, 2))) } } test("subgraph") { withSpark { sc => - // Create a star graph of 10 veritces. + // Create a star graph of 10 vertices. val n = 10 val star = starGraph(sc, n) // Take only vertices whose vids are even @@ -281,49 +282,6 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { } } - test("mapReduceTriplets") { - withSpark { sc => - val n = 5 - val star = starGraph(sc, n).mapVertices { (_, _) => 0 }.cache() - val starDeg = star.joinVertices(star.degrees){ (vid, oldV, deg) => deg } - val neighborDegreeSums = starDeg.mapReduceTriplets( - edge => Iterator((edge.srcId, edge.dstAttr), (edge.dstId, edge.srcAttr)), - (a: Int, b: Int) => a + b) - assert(neighborDegreeSums.collect().toSet === (0 to n).map(x => (x, n)).toSet) - - // activeSetOpt - val allPairs = for (x <- 1 to n; y <- 1 to n) yield (x: VertexId, y: VertexId) - val complete = Graph.fromEdgeTuples(sc.parallelize(allPairs, 3), 0) - val vids = complete.mapVertices((vid, attr) => vid).cache() - val active = vids.vertices.filter { case (vid, attr) => attr % 2 == 0 } - val numEvenNeighbors = vids.mapReduceTriplets(et => { - // Map function should only run on edges with destination in the active set - if (et.dstId % 2 != 0) { - throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId)) - } - Iterator((et.srcId, 1)) - }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect().toSet - assert(numEvenNeighbors === (1 to n).map(x => (x: VertexId, n / 2)).toSet) - - // outerJoinVertices followed by mapReduceTriplets(activeSetOpt) - val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexId, (x + 1) % n: VertexId)), 3) - val ring = Graph.fromEdgeTuples(ringEdges, 0) .mapVertices((vid, attr) => vid).cache() - val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache() - val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => - newOpt.getOrElse(old) - } - val numOddNeighbors = changedGraph.mapReduceTriplets(et => { - // Map function should only run on edges with source in the active set - if (et.srcId % 2 != 1) { - throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId)) - } - Iterator((et.dstId, 1)) - }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect().toSet - assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexId, 1)).toSet) - - } - } - test("aggregateMessages") { withSpark { sc => val n = 5 @@ -347,7 +305,8 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } - val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets( + val neighborDegreeSums = GraphXUtils.mapReduceTriplets[Int, Int, Int]( + reverseStarDegrees, et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), (a: Int, b: Int) => a + b).collect().toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) @@ -420,7 +379,8 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)), numEdgePartitions) val graph = Graph.fromEdgeTuples(edges, 1) - val neighborAttrSums = graph.mapReduceTriplets[Int]( + val neighborAttrSums = GraphXUtils.mapReduceTriplets[Int, Int, Int]( + graph, et => Iterator((et.dstId, et.srcAttr)), _ + _) assert(neighborAttrSums.collect().toSet === Set((0: VertexId, n))) } finally { @@ -428,4 +388,29 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { } } + test("unpersist graph RDD") { + withSpark { sc => + val vert = sc.parallelize(List((1L, "a"), (2L, "b"), (3L, "c")), 1) + val edges = sc.parallelize(List(Edge[Long](1L, 2L), Edge[Long](1L, 3L)), 1) + val g0 = Graph(vert, edges) + val g = g0.partitionBy(PartitionStrategy.EdgePartition2D, 2) + val cc = g.connectedComponents() + assert(sc.getPersistentRDDs.nonEmpty) + cc.unpersist() + g.unpersist() + g0.unpersist() + vert.unpersist() + edges.unpersist() + assert(sc.getPersistentRDDs.isEmpty) + } + } + + test("SPARK-14219: pickRandomVertex") { + withSpark { sc => + val vert = sc.parallelize(List((1L, "a")), 1) + val edges = sc.parallelize(List(Edge[Long](1L, 1L)), 1) + val g0 = Graph(vert, edges) + assert(g0.pickRandomVertex() === 1L) + } + } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 8afa2d403b53f..90a9ac613ef9d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.graphx -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.rdd._ +import org.apache.spark.SparkFunSuite class PregelSuite extends SparkFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index f1aa685a79c98..0bb9e0a3ea180 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -32,7 +32,7 @@ class VertexRDDSuite extends SparkFunSuite with LocalSparkContext { val n = 100 val verts = vertices(sc, n) val evens = verts.filter(q => ((q._2 % 2) == 0)) - assert(evens.count === (0 to n).filter(_ % 2 == 0).size) + assert(evens.count === (0 to n).count(_ % 2 == 0)) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 7435647c6d9ee..e4678b3578d9d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -18,14 +18,12 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.graphx._ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.graphx._ - class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index 1203f8959f506..0fb8451fdcab1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.graphx.impl import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.graphx._ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.graphx._ - class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index c965a6eb8df13..1b81423563372 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.graphx.lib -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.SparkContext._ +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index d7eaa70ce6407..994395bbffa56 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -17,12 +17,8 @@ package org.apache.spark.graphx.lib -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.SparkContext._ +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ -import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index d6b03208180db..2c57e8927e4d6 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,11 +17,8 @@ package org.apache.spark.graphx.lib -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.SparkContext._ +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.GraphGenerators -import org.apache.spark.rdd._ class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index c47552cf3a3bd..f19c3acdc85cf 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -26,7 +26,7 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => - val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2) + val rawEdges = sc.parallelize(Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ), 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() val verts = triangleCount.vertices @@ -64,9 +64,9 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { val verts = triangleCount.vertices verts.collect().foreach { case (vid, count) => if (vid == 0) { - assert(count === 4) - } else { assert(count === 2) + } else { + assert(count === 1) } } } @@ -75,7 +75,8 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle with duplicate edges") { withSpark { sc => val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ - Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2) + Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ + Array(1L -> 0L, 1L -> 1L), 2) val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache() val triangleCount = graph.triangleCount() val verts = triangleCount.vertices diff --git a/launcher/pom.xml b/launcher/pom.xml index 5739bfc16958f..ef731948826ef 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-launcher_2.10 + spark-launcher_2.11 jar Spark Project Launcher http://spark.apache.org/ diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 3ee6bd92e47fc..c7488082ca899 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -19,18 +19,16 @@ import java.io.BufferedReader; import java.io.File; -import java.io.FileFilter; import java.io.FileInputStream; import java.io.InputStreamReader; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Properties; -import java.util.jar.JarFile; import java.util.regex.Pattern; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -59,13 +57,13 @@ abstract class AbstractCommandBuilder { // properties files multiple times. private Map effectiveConfig; - public AbstractCommandBuilder() { - this.appArgs = new ArrayList(); - this.childEnv = new HashMap(); - this.conf = new HashMap(); - this.files = new ArrayList(); - this.jars = new ArrayList(); - this.pyFiles = new ArrayList(); + AbstractCommandBuilder() { + this.appArgs = new ArrayList<>(); + this.childEnv = new HashMap<>(); + this.conf = new HashMap<>(); + this.files = new ArrayList<>(); + this.jars = new ArrayList<>(); + this.pyFiles = new ArrayList<>(); } /** @@ -76,7 +74,8 @@ public AbstractCommandBuilder() { * SparkLauncher constructor that takes an environment), and may be modified to * include other variables needed by the process to be executed. */ - abstract List buildCommand(Map env) throws IOException; + abstract List buildCommand(Map env) + throws IOException, IllegalArgumentException; /** * Builds a list of arguments to run java. @@ -89,7 +88,7 @@ public AbstractCommandBuilder() { * class. */ List buildJavaCommand(String extraClassPath) throws IOException { - List cmd = new ArrayList(); + List cmd = new ArrayList<>(); String envJavaHome; if (javaHome != null) { @@ -104,7 +103,7 @@ List buildJavaCommand(String extraClassPath) throws IOException { File javaOpts = new File(join(File.separator, getConfDir(), "java-opts")); if (javaOpts.isFile()) { BufferedReader br = new BufferedReader(new InputStreamReader( - new FileInputStream(javaOpts), "UTF-8")); + new FileInputStream(javaOpts), StandardCharsets.UTF_8)); try { String line; while ((line = br.readLine()) != null) { @@ -136,7 +135,7 @@ void addOptionString(List cmd, String options) { List buildClassPath(String appClassPath) throws IOException { String sparkHome = getSparkHome(); - List cp = new ArrayList(); + List cp = new ArrayList<>(); addToClassPath(cp, getenv("SPARK_CLASSPATH")); addToClassPath(cp, appClassPath); @@ -146,9 +145,26 @@ List buildClassPath(String appClassPath) throws IOException { boolean isTesting = "1".equals(getenv("SPARK_TESTING")); if (prependClasses || isTesting) { String scala = getScalaVersion(); - List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", - "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher"); + List projects = Arrays.asList( + "common/network-common", + "common/network-shuffle", + "common/network-yarn", + "common/sketch", + "common/tags", + "common/unsafe", + "core", + "examples", + "graphx", + "launcher", + "mllib", + "repl", + "sql/catalyst", + "sql/core", + "sql/hive", + "sql/hive-thriftserver", + "streaming", + "yarn" + ); if (prependClasses) { if (!isTesting) { System.err.println( @@ -172,40 +188,13 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - // We can't rely on the ENV_SPARK_ASSEMBLY variable to be set. Certain situations, such as - // when running unit tests, or user code that embeds Spark and creates a SparkContext - // with a local or local-cluster master, will cause this code to be called from an - // environment where that env variable is not guaranteed to exist. - // - // For the testing case, we rely on the test code to set and propagate the test classpath - // appropriately. - // - // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. - // That duplicates some of the code in the shell scripts that look for the assembly, though. - String assembly = getenv(ENV_SPARK_ASSEMBLY); - if (assembly == null && !isTesting) { - assembly = findAssembly(); - } - addToClassPath(cp, assembly); - - // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only - // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate - // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive - File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); - } else { - libdir = new File(sparkHome, "lib_managed/jars"); - } - - if (libdir.isDirectory()) { - for (File jar : libdir.listFiles()) { - if (jar.getName().startsWith("datanucleus-")) { - addToClassPath(cp, jar.getAbsolutePath()); - } - } - } else { - checkState(isTesting, "Library directory '%s' does not exist.", libdir.getAbsolutePath()); + // Add Spark jars to the classpath. For the testing case, we rely on the test code to set and + // propagate the test classpath appropriately. For normal invocation, look for the jars + // directory under SPARK_HOME. + boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING")); + String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql); + if (jarsDir != null) { + addToClassPath(cp, join(File.separator, jarsDir, "*")); } addToClassPath(cp, getenv("HADOOP_CONF_DIR")); @@ -302,7 +291,7 @@ private Properties loadPropertiesFile() throws IOException { FileInputStream fd = null; try { fd = new FileInputStream(propsFile); - props.load(new InputStreamReader(fd, "UTF-8")); + props.load(new InputStreamReader(fd, StandardCharsets.UTF_8)); for (Map.Entry e : props.entrySet()) { e.setValue(e.getValue().toString().trim()); } @@ -320,30 +309,6 @@ private Properties loadPropertiesFile() throws IOException { return props; } - private String findAssembly() { - String sparkHome = getSparkHome(); - File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - } else { - libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); - } - - final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); - FileFilter filter = new FileFilter() { - @Override - public boolean accept(File file) { - return file.isFile() && re.matcher(file.getName()).matches(); - } - }; - File[] assemblies = libdir.listFiles(filter); - checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); - checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); - return assemblies[0].getAbsolutePath(); - } - private String getConfDir() { String confDir = getenv("SPARK_CONF_DIR"); return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf"); diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index de50f14fbdc87..1bfda289dec39 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.IOException; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ThreadFactory; @@ -102,8 +103,20 @@ public synchronized void kill() { disconnect(); } if (childProc != null) { - childProc.destroy(); - childProc = null; + try { + childProc.exitValue(); + } catch (IllegalThreadStateException e) { + // Child is still alive. Try to use Java 8's "destroyForcibly()" if available, + // fall back to the old API if it's not there. + try { + Method destroy = childProc.getClass().getMethod("destroyForcibly"); + destroy.invoke(childProc); + } catch (Exception inner) { + childProc.destroy(); + } + } finally { + childProc = null; + } } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index d30c2ec5f87bb..91586aad7b709 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -30,12 +30,11 @@ class CommandBuilderUtils { static final String DEFAULT_MEM = "1g"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; - static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; /** The set of known JVM vendors. */ - static enum JavaVendor { + enum JavaVendor { Oracle, IBM, OpenJDK, Unknown - }; + } /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { @@ -147,7 +146,7 @@ static void mergeEnvPathList(Map userEnv, String envKey, String * Output: [ "ab cd", "efgh", "i \" j" ] */ static List parseOptionString(String s) { - List opts = new ArrayList(); + List opts = new ArrayList<>(); StringBuilder opt = new StringBuilder(); boolean inOpt = false; boolean inSingleQuote = false; @@ -322,11 +321,9 @@ static void addPermGenSizeOpt(List cmd) { if (getJavaVendor() == JavaVendor.IBM) { return; } - String[] version = System.getProperty("java.version").split("\\."); - if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { + if (javaMajorVersion(System.getProperty("java.version")) > 7) { return; } - for (String arg : cmd) { if (arg.startsWith("-XX:MaxPermSize=")) { return; @@ -336,4 +333,45 @@ static void addPermGenSizeOpt(List cmd) { cmd.add("-XX:MaxPermSize=256m"); } + /** + * Get the major version of the java version string supplied. This method + * accepts any JEP-223-compliant strings (9-ea, 9+100), as well as legacy + * version strings such as 1.7.0_79 + */ + static int javaMajorVersion(String javaVersion) { + String[] version = javaVersion.split("[+.\\-]+"); + int major = Integer.parseInt(version[0]); + // if major > 1, we're using the JEP-223 version string, e.g., 9-ea, 9+120 + // otherwise the second number is the major version + if (major > 1) { + return major; + } else { + return Integer.parseInt(version[1]); + } + } + + /** + * Find the location of the Spark jars dir, depending on whether we're looking at a build + * or a distribution directory. + */ + static String findJarsDir(String sparkHome, String scalaVersion, boolean failIfNotFound) { + // TODO: change to the correct directory once the assembly build is changed. + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "jars"); + checkState(!failIfNotFound || libdir.isDirectory(), + "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + } else { + libdir = new File(sparkHome, String.format("assembly/target/scala-%s/jars", scalaVersion)); + if (!libdir.isDirectory()) { + checkState(!failIfNotFound, + "Library directory '%s' does not exist; make sure Spark is built.", + libdir.getAbsolutePath()); + libdir = null; + } + } + return libdir != null ? libdir.getAbsolutePath() : null; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java index 50f136497ec1a..042f11cd9e434 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherProtocol.java @@ -17,13 +17,7 @@ package org.apache.spark.launcher; -import java.io.Closeable; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; -import java.net.Socket; -import java.util.Map; /** * Message definitions for the launcher communication protocol. These messages must remain diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index d099ee9aa9dae..69fbf4387bdfb 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -129,7 +129,7 @@ private LauncherServer() throws IOException { server.setReuseAddress(true); server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); - this.clients = new ArrayList(); + this.clients = new ArrayList<>(); this.threadIds = new AtomicLong(); this.factory = new NamedThreadFactory(THREAD_NAME_FMT); this.pending = new ConcurrentHashMap<>(); @@ -293,9 +293,7 @@ private class ServerConnection extends LauncherConnection { protected void handle(Message msg) throws IOException { try { if (msg instanceof Hello) { - synchronized (timeout) { - timeout.cancel(); - } + timeout.cancel(); timeout = null; Hello hello = (Hello) msg; ChildProcAppHandle handle = pending.remove(hello.secret); diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index a4e3acc674f36..1e34bb8c73279 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -50,7 +50,7 @@ class Main { public static void main(String[] argsArray) throws Exception { checkArgument(argsArray.length > 0, "Not enough arguments: missing class name."); - List args = new ArrayList(Arrays.asList(argsArray)); + List args = new ArrayList<>(Arrays.asList(argsArray)); String className = args.remove(0); boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); @@ -70,7 +70,7 @@ public static void main(String[] argsArray) throws Exception { // Ignore parsing exceptions. } - List help = new ArrayList(); + List help = new ArrayList<>(); if (parser.className != null) { help.add(parser.CLASS); help.add(parser.className); @@ -82,7 +82,7 @@ public static void main(String[] argsArray) throws Exception { builder = new SparkClassCommandBuilder(className, args); } - Map env = new HashMap(); + Map env = new HashMap<>(); List cmd = builder.buildCommand(env); if (printLaunchCommand) { System.err.println("Spark Command: " + join(" ", cmd)); @@ -130,7 +130,7 @@ private static List prepareBashCommand(List cmd, Map newCmd = new ArrayList(); + List newCmd = new ArrayList<>(); newCmd.add("env"); for (Map.Entry e : childEnv.entrySet()) { @@ -151,7 +151,7 @@ private static class MainClassOptionParser extends SparkSubmitOptionParser { @Override protected boolean handle(String opt, String value) { - if (opt == CLASS) { + if (CLASS.equals(opt)) { className = value; } return false; diff --git a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java index 6e7120167d605..c7959aee9f888 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java +++ b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java @@ -21,6 +21,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.concurrent.ThreadFactory; import java.util.logging.Level; import java.util.logging.Logger; @@ -42,7 +43,7 @@ class OutputRedirector { OutputRedirector(InputStream in, String loggerName, ThreadFactory tf) { this.active = true; - this.reader = new BufferedReader(new InputStreamReader(in)); + this.reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8)); this.thread = tf.newThread(new Runnable() { @Override public void run() { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java index 13dd9f1739fb6..625d02632114a 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java @@ -32,7 +32,7 @@ public interface SparkAppHandle { * * @since 1.6.0 */ - public enum State { + enum State { /** The application has not reported back yet. */ UNKNOWN(false), /** The application has connected to the handle. */ @@ -89,6 +89,9 @@ public boolean isFinal() { * Tries to kill the underlying application. Implies {@link #disconnect()}. This will not send * a {@link #stop()} message to the application, so it's recommended that users first try to * stop the application cleanly and only resort to this method if that fails. + *

    + * Note that if the application is running as a child process, this method fail to kill the + * process when using Java 7. This may happen if, for example, the application is deadlocked. */ void kill(); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 931a24cfd4b1d..82b593a3f797d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -17,12 +17,10 @@ package org.apache.spark.launcher; -import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.regex.Pattern; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -43,13 +41,14 @@ class SparkClassCommandBuilder extends AbstractCommandBuilder { } @Override - public List buildCommand(Map env) throws IOException { - List javaOptsKeys = new ArrayList(); + public List buildCommand(Map env) + throws IOException, IllegalArgumentException { + List javaOptsKeys = new ArrayList<>(); String memKey = null; String extraClassPath = null; - // Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + - // SPARK_DAEMON_MEMORY. + // Master, Worker, HistoryServer, ExternalShuffleService, MesosClusterDispatcher use + // SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. if (className.equals("org.apache.spark.deploy.master.Master")) { javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); javaOptsKeys.add("SPARK_MASTER_OPTS"); @@ -69,43 +68,31 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.mesos.MesosClusterDispatcher")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService") || className.equals("org.apache.spark.deploy.mesos.MesosExternalShuffleService")) { javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); memKey = "SPARK_DAEMON_MEMORY"; - } else if (className.startsWith("org.apache.spark.tools.")) { - String sparkHome = getSparkHome(); - File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", - "scala-" + getScalaVersion())); - checkState(toolsDir.isDirectory(), "Cannot find tools build directory."); - - Pattern re = Pattern.compile("spark-tools_.*\\.jar"); - for (File f : toolsDir.listFiles()) { - if (re.matcher(f.getName()).matches()) { - extraClassPath = f.getAbsolutePath(); - break; - } - } - - checkState(extraClassPath != null, - "Failed to find Spark Tools Jar in %s.\n" + - "You need to run \"build/sbt tools/package\" before running %s.", - toolsDir.getAbsolutePath(), className); - - javaOptsKeys.add("SPARK_JAVA_OPTS"); } else { javaOptsKeys.add("SPARK_JAVA_OPTS"); memKey = "SPARK_DRIVER_MEMORY"; } List cmd = buildJavaCommand(extraClassPath); + for (String key : javaOptsKeys) { - addOptionString(cmd, System.getenv(key)); + String envValue = System.getenv(key); + if (!isEmpty(envValue) && envValue.contains("Xmx")) { + String msg = String.format("%s is not allowed to specify max heap(Xmx) memory settings " + + "(was %s). Use the corresponding configuration instead.", key, envValue); + throw new IllegalArgumentException(msg); + } + addOptionString(cmd, envValue); } String mem = firstNonEmpty(memKey != null ? System.getenv(memKey) : null, DEFAULT_MEM); - cmd.add("-Xms" + mem); cmd.add("-Xmx" + mem); addPermGenSizeOpt(cmd); cmd.add(className); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index dd1c93af6ca4c..a083f05a2a9f7 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -40,6 +40,9 @@ public class SparkLauncher { /** The Spark master. */ public static final String SPARK_MASTER = "spark.master"; + /** The Spark deploy mode. */ + public static final String DEPLOY_MODE = "spark.submit.deployMode"; + /** Configuration key for the driver memory. */ public static final String DRIVER_MEMORY = "spark.driver.memory"; /** Configuration key for the driver class path. */ @@ -72,7 +75,7 @@ public class SparkLauncher { /** Used internally to create unique logger names. */ private static final AtomicInteger COUNTER = new AtomicInteger(); - static final Map launcherConfig = new HashMap(); + static final Map launcherConfig = new HashMap<>(); /** * Set a configuration value for the launcher library. These config values do not affect the @@ -425,7 +428,7 @@ public SparkAppHandle startApplication(SparkAppHandle.Listener... listeners) thr } private ProcessBuilder createBuilder() { - List cmd = new ArrayList(); + List cmd = new ArrayList<>(); String script = isWindows() ? "spark-submit.cmd" : "spark-submit"; cmd.add(join(File.separator, builder.getSparkHome(), "bin", script)); cmd.addAll(builder.buildSparkSubmitArgs()); @@ -434,7 +437,7 @@ private ProcessBuilder createBuilder() { // preserved, otherwise the batch interpreter will mess up the arguments. Batch scripts are // weird. if (isWindows()) { - List winCmd = new ArrayList(); + List winCmd = new ArrayList<>(); for (String arg : cmd) { winCmd.add(quoteForBatchScript(arg)); } @@ -474,6 +477,6 @@ protected void handleExtraArgs(List extra) { // No op. } - }; + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 39b46e0db8cc2..6941ca903cd0a 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -30,7 +30,8 @@ * driver-side options and special parsing behavior needed for the special-casing certain internal * Spark applications. *

    - * This class has also some special features to aid launching pyspark. + * This class has also some special features to aid launching shells (pyspark and sparkR) and also + * examples. */ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { @@ -62,12 +63,23 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ static final String SPARKR_SHELL_RESOURCE = "sparkr-shell"; + /** + * Name of app resource used to identify examples. When running examples, args[0] should be + * this name. The app resource will identify the example class to run. + */ + static final String RUN_EXAMPLE = "run-example"; + + /** + * Prefix for example class names. + */ + static final String EXAMPLE_CLASS_PREFIX = "org.apache.spark.examples."; + /** * This map must match the class names for available special classes, since this modifies the way * command line parsing works. This maps the class name to the resource to use when calling * spark-submit. */ - private static final Map specialClasses = new HashMap(); + private static final Map specialClasses = new HashMap<>(); static { specialClasses.put("org.apache.spark.repl.Main", "spark-shell"); specialClasses.put("org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver", @@ -77,7 +89,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } final List sparkArgs; - private final boolean printHelp; + private final boolean printInfo; + private final boolean isExample; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed @@ -87,12 +100,15 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { private boolean allowsMixedArguments; SparkSubmitCommandBuilder() { - this.sparkArgs = new ArrayList(); - this.printHelp = false; + this.sparkArgs = new ArrayList<>(); + this.printInfo = false; + this.isExample = false; } SparkSubmitCommandBuilder(List args) { - this.sparkArgs = new ArrayList(); + this.allowsMixedArguments = false; + + boolean isExample = false; List submitArgs = args; if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { this.allowsMixedArguments = true; @@ -102,20 +118,25 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = true; appResource = SPARKR_SHELL_RESOURCE; submitArgs = args.subList(1, args.size()); - } else { - this.allowsMixedArguments = false; + } else if (args.size() > 0 && args.get(0).equals(RUN_EXAMPLE)) { + isExample = true; + submitArgs = args.subList(1, args.size()); } + this.sparkArgs = new ArrayList<>(); + this.isExample = isExample; + OptionParser parser = new OptionParser(); parser.parse(submitArgs); - this.printHelp = parser.helpRequested; + this.printInfo = parser.infoRequested; } @Override - public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { + public List buildCommand(Map env) + throws IOException, IllegalArgumentException { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printInfo) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printInfo) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -123,7 +144,7 @@ public List buildCommand(Map env) throws IOException { } List buildSparkSubmitArgs() { - List args = new ArrayList(); + List args = new ArrayList<>(); SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); if (verbose) { @@ -155,6 +176,10 @@ List buildSparkSubmitArgs() { args.add(propertiesFile); } + if (isExample) { + jars.addAll(findExamplesJars()); + } + if (!jars.isEmpty()) { args.add(parser.JARS); args.add(join(",", jars)); @@ -170,6 +195,9 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } + if (!printInfo) { + checkArgument(!isExample || mainClass != null, "Missing example class name."); + } if (mainClass != null) { args.add(parser.CLASS); args.add(mainClass); @@ -184,7 +212,8 @@ List buildSparkSubmitArgs() { return args; } - private List buildSparkSubmitCommand(Map env) throws IOException { + private List buildSparkSubmitCommand(Map env) + throws IOException, IllegalArgumentException { // Load the properties file and check whether spark-submit will be running the app's driver // or just launching a cluster app. When running the driver, the JVM's argument will be // modified to cover the driver's configuration. @@ -200,6 +229,16 @@ private List buildSparkSubmitCommand(Map env) throws IOE addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS")); addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS")); + // We don't want the client to specify Xmx. These have to be set by their corresponding + // memory flag --driver-memory or configuration entry spark.driver.memory + String driverExtraJavaOptions = config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS); + if (!isEmpty(driverExtraJavaOptions) && driverExtraJavaOptions.contains("Xmx")) { + String msg = String.format("Not allowed to specify max heap(Xmx) memory settings through " + + "java options (was %s). Use the corresponding --driver-memory or " + + "spark.driver.memory configuration instead.", driverExtraJavaOptions); + throw new IllegalArgumentException(msg); + } + if (isClientMode) { // Figuring out where the memory value come from is a little tricky due to precedence. // Precedence is observed in the following order: @@ -213,9 +252,8 @@ private List buildSparkSubmitCommand(Map env) throws IOE isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; String memory = firstNonEmpty(tsMemory, config.get(SparkLauncher.DRIVER_MEMORY), System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); - cmd.add("-Xms" + memory); cmd.add("-Xmx" + memory); - addOptionString(cmd, config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)); + addOptionString(cmd, driverExtraJavaOptions); mergeEnvPathList(env, getLibPathEnvName(), config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); } @@ -231,11 +269,9 @@ private List buildPySparkShellCommand(Map env) throws IO // the pyspark command line, then run it using spark-submit. if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".py")) { System.err.println( - "WARNING: Running python applications through 'pyspark' is deprecated as of Spark 1.0.\n" + + "Running python applications through 'pyspark' is not supported as of Spark 2.0.\n" + "Use ./bin/spark-submit "); - appResource = appArgs.get(0); - appArgs.remove(0); - return buildCommand(env); + System.exit(-1); } checkArgument(appArgs.isEmpty(), "pyspark does not support any application options."); @@ -246,7 +282,7 @@ private List buildPySparkShellCommand(Map env) throws IO // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, // followed by PYSPARK_DRIVER_PYTHON_OPTS. - List pyargs = new ArrayList(); + List pyargs = new ArrayList<>(); pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python")); String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS"); if (!isEmpty(pyOpts)) { @@ -258,9 +294,10 @@ private List buildPySparkShellCommand(Map env) throws IO private List buildSparkRCommand(Map env) throws IOException { if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) { - appResource = appArgs.get(0); - appArgs.remove(0); - return buildCommand(env); + System.err.println( + "Running R applications through 'sparkR' is not supported as of Spark 2.0.\n" + + "Use ./bin/spark-submit "); + System.exit(-1); } // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS // env variable. @@ -271,7 +308,7 @@ private List buildSparkRCommand(Map env) throws IOExcept env.put("R_PROFILE_USER", join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R")); - List args = new ArrayList(); + List args = new ArrayList<>(); args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R")); return args; } @@ -294,10 +331,11 @@ private void constructEnvVarArgs( private boolean isClientMode(Map userProps) { String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER)); - // Default master is "local[*]", so assume client mode in that case. + String userDeployMode = firstNonEmpty(deployMode, userProps.get(SparkLauncher.DEPLOY_MODE)); + // Default master is "local[*]", so assume client mode in that case return userMaster == null || - "client".equals(deployMode) || - (!userMaster.equals("yarn-cluster") && deployMode == null); + "client".equals(userDeployMode) || + (!userMaster.equals("yarn-cluster") && userDeployMode == null); } /** @@ -308,10 +346,34 @@ private boolean isThriftServer(String mainClass) { mainClass.equals("org.apache.spark.sql.hive.thriftserver.HiveThriftServer2")); } + private List findExamplesJars() { + boolean isTesting = "1".equals(getenv("SPARK_TESTING")); + List examplesJars = new ArrayList<>(); + String sparkHome = getSparkHome(); + + File jarsDir; + if (new File(sparkHome, "RELEASE").isFile()) { + jarsDir = new File(sparkHome, "examples/jars"); + } else { + jarsDir = new File(sparkHome, + String.format("examples/target/scala-%s/jars", getScalaVersion())); + } + + boolean foundDir = jarsDir.isDirectory(); + checkState(isTesting || foundDir, "Examples jars directory '%s' does not exist.", + jarsDir.getAbsolutePath()); + + if (foundDir) { + for (File f: jarsDir.listFiles()) { + examplesJars.add(f.getAbsolutePath()); + } + } + return examplesJars; + } private class OptionParser extends SparkSubmitOptionParser { - boolean helpRequested = false; + boolean infoRequested = false; @Override protected boolean handle(String opt, String value) { @@ -344,7 +406,10 @@ protected boolean handle(String opt, String value) { appResource = specialClasses.get(value); } } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { - helpRequested = true; + infoRequested = true; + sparkArgs.add(opt); + } else if (opt.equals(VERSION)) { + infoRequested = true; sparkArgs.add(opt); } else { sparkArgs.add(opt); @@ -364,6 +429,14 @@ protected boolean handleUnknown(String opt) { if (allowsMixedArguments) { appArgs.add(opt); return true; + } else if (isExample) { + String className = opt; + if (!className.startsWith(EXAMPLE_CLASS_PREFIX)) { + className = EXAMPLE_CLASS_PREFIX + className; + } + mainClass = className; + appResource = "spark-internal"; + return false; } else { checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); sparkArgs.add(opt); @@ -373,8 +446,10 @@ protected boolean handleUnknown(String opt) { @Override protected void handleExtraArgs(List extra) { - for (String arg : extra) { - sparkArgs.add(arg); + if (isExample) { + appArgs.addAll(extra); + } else { + sparkArgs.addAll(extra); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index bc513ec9b3d10..4fafc43ef293b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -87,6 +87,18 @@ public void testPythonArgQuoting() { assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c")); } + @Test + public void testJavaMajorVersion() { + assertEquals(6, javaMajorVersion("1.6.0_50")); + assertEquals(7, javaMajorVersion("1.7.0_79")); + assertEquals(8, javaMajorVersion("1.8.0_66")); + assertEquals(9, javaMajorVersion("9-ea")); + assertEquals(9, javaMajorVersion("9+100")); + assertEquals(9, javaMajorVersion("9")); + assertEquals(9, javaMajorVersion("9.1.0")); + assertEquals(10, javaMajorVersion("10")); + } + private void testOpt(String opts, List expected) { assertEquals(String.format("test string failed to parse: [[ %s ]]", opts), expected, parseOptionString(opts)); diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index dc8fbb58d880b..bfe1fcc87fe35 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,11 +23,11 @@ import java.net.Socket; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import org.junit.Test; import static org.junit.Assert.*; -import static org.mockito.Mockito.*; import static org.apache.spark.launcher.LauncherProtocol.*; @@ -69,48 +69,35 @@ public void testCommunication() throws Exception { Socket s = new Socket(InetAddress.getLoopbackAddress(), LauncherServer.getServerInstance().getPort()); - final Object waitLock = new Object(); + final Semaphore semaphore = new Semaphore(0); handle.addListener(new SparkAppHandle.Listener() { @Override public void stateChanged(SparkAppHandle handle) { - wakeUp(); + semaphore.release(); } - @Override public void infoChanged(SparkAppHandle handle) { - wakeUp(); - } - - private void wakeUp() { - synchronized (waitLock) { - waitLock.notifyAll(); - } + semaphore.release(); } }); client = new TestClient(s); - synchronized (waitLock) { - client.send(new Hello(handle.getSecret(), "1.4.0")); - waitLock.wait(TimeUnit.SECONDS.toMillis(10)); - } + client.send(new Hello(handle.getSecret(), "1.4.0")); + assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); // Make sure the server matched the client to the handle. assertNotNull(handle.getConnection()); - synchronized (waitLock) { - client.send(new SetAppId("app-id")); - waitLock.wait(TimeUnit.SECONDS.toMillis(10)); - } + client.send(new SetAppId("app-id")); + assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); assertEquals("app-id", handle.getAppId()); - synchronized (waitLock) { - client.send(new SetState(SparkAppHandle.State.RUNNING)); - waitLock.wait(TimeUnit.SECONDS.toMillis(10)); - } + client.send(new SetState(SparkAppHandle.State.RUNNING)); + assertTrue(semaphore.tryAcquire(1, TimeUnit.SECONDS)); assertEquals(SparkAppHandle.State.RUNNING, handle.getState()); handle.stop(); - Message stopMsg = client.inbound.poll(10, TimeUnit.SECONDS); + Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS); assertTrue(stopMsg instanceof Stop); } finally { kill(handle); @@ -188,7 +175,7 @@ private static class TestClient extends LauncherConnection { TestClient(Socket s) throws IOException { super(s); - this.inbound = new LinkedBlockingQueue(); + this.inbound = new LinkedBlockingQueue<>(); this.clientThread = new Thread(this); clientThread.setName("TestClient"); clientThread.setDaemon(true); diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 6aad47adbcc82..c7e8b2e03a9fa 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -73,13 +73,12 @@ public void testCliParser() throws Exception { "spark.randomOption=foo", parser.CONF, SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH + "=/driverLibPath"); - Map env = new HashMap(); + Map env = new HashMap<>(); List cmd = buildCommand(sparkSubmitArgs, env); assertTrue(findInStringList(env.get(CommandBuilderUtils.getLibPathEnvName()), File.pathSeparator, "/driverLibPath")); assertTrue(findInStringList(findArgValue(cmd, "-cp"), File.pathSeparator, "/driverCp")); - assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms42g")); assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx42g")); assertTrue("Command should contain user-defined conf.", Collections.indexOfSubList(cmd, Arrays.asList(parser.CONF, "spark.randomOption=foo")) > 0); @@ -125,7 +124,7 @@ public void testPySparkLauncher() throws Exception { "--master=foo", "--deploy-mode=bar"); - Map env = new HashMap(); + Map env = new HashMap<>(); List cmd = buildCommand(sparkSubmitArgs, env); assertEquals("python", cmd.get(cmd.size() - 1)); assertEquals( @@ -142,7 +141,7 @@ public void testPySparkFallback() throws Exception { "script.py", "arg1"); - Map env = new HashMap(); + Map env = new HashMap<>(); List cmd = buildCommand(sparkSubmitArgs, env); assertEquals("foo", findArgValue(cmd, "--master")); @@ -151,6 +150,24 @@ public void testPySparkFallback() throws Exception { assertEquals("arg1", cmd.get(cmd.size() - 1)); } + @Test + public void testExamplesRunner() throws Exception { + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.RUN_EXAMPLE, + parser.MASTER + "=foo", + parser.DEPLOY_MODE + "=bar", + "SparkPi", + "42"); + + Map env = new HashMap<>(); + List cmd = buildCommand(sparkSubmitArgs, env); + assertEquals("foo", findArgValue(cmd, parser.MASTER)); + assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE)); + assertEquals(SparkSubmitCommandBuilder.EXAMPLE_CLASS_PREFIX + "SparkPi", + findArgValue(cmd, parser.CLASS)); + assertEquals("42", cmd.get(cmd.size() - 1)); + } + private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) throws Exception { String deployMode = isDriver ? "client" : "cluster"; @@ -178,18 +195,17 @@ private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) th + "/launcher/src/test/resources"); } - Map env = new HashMap(); + Map env = new HashMap<>(); List cmd = launcher.buildCommand(env); // Checks below are different for driver and non-driver mode. if (isDriver) { - assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms1g")); assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx1g")); } else { boolean found = false; for (String arg : cmd) { - if (arg.startsWith("-Xms") || arg.startsWith("-Xmx")) { + if (arg.startsWith("-Xmx")) { found = true; break; } @@ -199,11 +215,7 @@ private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) th for (String arg : cmd) { if (arg.startsWith("-XX:MaxPermSize=")) { - if (isDriver) { - assertEquals("-XX:MaxPermSize=256m", arg); - } else { - assertEquals("-XX:MaxPermSize=256m", arg); - } + assertEquals("-XX:MaxPermSize=256m", arg); } } @@ -258,7 +270,7 @@ private boolean contains(String needle, String[] haystack) { } private Map parseConf(List cmd, SparkSubmitOptionParser parser) { - Map conf = new HashMap(); + Map conf = new HashMap<>(); for (int i = 0; i < cmd.size(); i++) { if (cmd.get(i).equals(parser.CONF)) { String[] val = cmd.get(i + 1).split("=", 2); @@ -286,7 +298,6 @@ private boolean findInStringList(String list, String sep, String needle) { private SparkSubmitCommandBuilder newCommandBuilder(List args) { SparkSubmitCommandBuilder builder = new SparkSubmitCommandBuilder(args); builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); - builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_ASSEMBLY, "dummy"); return builder; } diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index c64b1565e1469..744c456cb29c1 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -30,5 +30,4 @@ log4j.appender.childproc.layout=org.apache.log4j.PatternLayout log4j.appender.childproc.layout.ConversionPattern=%t: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/licenses/LICENSE-jblas.txt b/licenses/LICENSE-jblas.txt deleted file mode 100644 index 5629dafb65b39..0000000000000 --- a/licenses/LICENSE-jblas.txt +++ /dev/null @@ -1,31 +0,0 @@ -Copyright (c) 2009, Mikio L. Braun and contributors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - - * Neither the name of the Technische Universität Berlin nor the - names of its contributors may be used to endorse or promote - products derived from this software without specific prior - written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/make-distribution.sh b/make-distribution.sh deleted file mode 100755 index e1c2afdbc6d87..0000000000000 --- a/make-distribution.sh +++ /dev/null @@ -1,266 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# -# Script to create a binary distribution for easy deploys of Spark. -# The distribution directory defaults to dist/ but can be overridden below. -# The distribution contains fat (assembly) jars that include the Scala library, -# so it is completely self contained. -# It does not contain source or *.class files. - -set -o pipefail -set -e -set -x - -# Figure out where the Spark framework is installed -SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" -DISTDIR="$SPARK_HOME/dist" - -SPARK_TACHYON=false -TACHYON_VERSION="0.8.1" -TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" -TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" - -MAKE_TGZ=false -NAME=none -MVN="$SPARK_HOME/build/mvn" - -function exit_with_usage { - echo "make-distribution.sh - tool for making binary distributions of Spark" - echo "" - echo "usage:" - cl_options="[--name] [--tgz] [--mvn ] [--with-tachyon]" - echo "./make-distribution.sh $cl_options " - echo "See Spark's \"Building Spark\" doc for correct Maven options." - echo "" - exit 1 -} - -# Parse arguments -while (( "$#" )); do - case $1 in - --hadoop) - echo "Error: '--hadoop' is no longer supported:" - echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." - echo "Error: Related profiles include hadoop-1, hadoop-2.2, hadoop-2.3 and hadoop-2.4." - exit_with_usage - ;; - --with-yarn) - echo "Error: '--with-yarn' is no longer supported, use Maven option -Pyarn" - exit_with_usage - ;; - --with-hive) - echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" - exit_with_usage - ;; - --skip-java-test) - SKIP_JAVA_TEST=true - ;; - --with-tachyon) - SPARK_TACHYON=true - ;; - --tgz) - MAKE_TGZ=true - ;; - --mvn) - MVN="$2" - shift - ;; - --name) - NAME="$2" - shift - ;; - --help) - exit_with_usage - ;; - *) - break - ;; - esac - shift -done - -if [ -z "$JAVA_HOME" ]; then - # Fall back on JAVA_HOME from rpm, if found - if [ $(command -v rpm) ]; then - RPM_JAVA_HOME="$(rpm -E %java_home 2>/dev/null)" - if [ "$RPM_JAVA_HOME" != "%java_home" ]; then - JAVA_HOME="$RPM_JAVA_HOME" - echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" - fi - fi -fi - -if [ -z "$JAVA_HOME" ]; then - echo "Error: JAVA_HOME is not set, cannot proceed." - exit -1 -fi - -if [ $(command -v git) ]; then - GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) - if [ ! -z "$GITREV" ]; then - GITREVSTRING=" (git revision $GITREV)" - fi - unset GITREV -fi - - -if [ ! "$(command -v "$MVN")" ] ; then - echo -e "Could not locate Maven command: '$MVN'." - echo -e "Specify the Maven command with the --mvn flag" - exit -1; -fi - -VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) -SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ - | grep -v "INFO"\ - | tail -n 1) -SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ - | grep -v "INFO"\ - | tail -n 1) -SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ - | grep -v "INFO"\ - | fgrep --count "hive";\ - # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ - # because we use "set -o pipefail" - echo -n) - -if [ "$NAME" == "none" ]; then - NAME=$SPARK_HADOOP_VERSION -fi - -echo "Spark version is $VERSION" - -if [ "$MAKE_TGZ" == "true" ]; then - echo "Making spark-$VERSION-bin-$NAME.tgz" -else - echo "Making distribution for Spark $VERSION in $DISTDIR..." -fi - -if [ "$SPARK_TACHYON" == "true" ]; then - echo "Tachyon Enabled" -else - echo "Tachyon Disabled" -fi - -# Build uber fat JAR -cd "$SPARK_HOME" - -export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" - -# Store the command as an array because $MVN variable might have spaces in it. -# Normal quoting tricks don't work. -# See: http://mywiki.wooledge.org/BashFAQ/050 -BUILD_COMMAND=("$MVN" clean package -DskipTests $@) - -# Actually build the jar -echo -e "\nBuilding with..." -echo -e "\$ ${BUILD_COMMAND[@]}\n" - -"${BUILD_COMMAND[@]}" - -# Make directories -rm -rf "$DISTDIR" -mkdir -p "$DISTDIR/lib" -echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE" -echo "Build flags: $@" >> "$DISTDIR/RELEASE" - -# Copy jars -cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" -cp "$SPARK_HOME"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" -# This will fail if the -Pyarn profile is not provided -# In this case, silence the error and ignore the return code of this command -cp "$SPARK_HOME"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : - -# Copy example sources (needed for python and SQL) -mkdir -p "$DISTDIR/examples/src/main" -cp -r "$SPARK_HOME"/examples/src/main "$DISTDIR/examples/src/" - -if [ "$SPARK_HIVE" == "1" ]; then - cp "$SPARK_HOME"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" -fi - -# Copy license and ASF files -cp "$SPARK_HOME/LICENSE" "$DISTDIR" -cp -r "$SPARK_HOME/licenses" "$DISTDIR" -cp "$SPARK_HOME/NOTICE" "$DISTDIR" - -if [ -e "$SPARK_HOME"/CHANGES.txt ]; then - cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" -fi - -# Copy data files -cp -r "$SPARK_HOME/data" "$DISTDIR" - -# Copy other things -mkdir "$DISTDIR"/conf -cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf -cp "$SPARK_HOME/README.md" "$DISTDIR" -cp -r "$SPARK_HOME/bin" "$DISTDIR" -cp -r "$SPARK_HOME/python" "$DISTDIR" -cp -r "$SPARK_HOME/sbin" "$DISTDIR" -cp -r "$SPARK_HOME/ec2" "$DISTDIR" -# Copy SparkR if it exists -if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then - mkdir -p "$DISTDIR"/R/lib - cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib -fi - -# Download and copy in tachyon, if requested -if [ "$SPARK_TACHYON" == "true" ]; then - TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` - - pushd "$TMPD" > /dev/null - echo "Fetching tachyon tgz" - - TACHYON_DL="${TACHYON_TGZ}.part" - if [ $(command -v curl) ]; then - curl --silent -k -L "${TACHYON_URL}" > "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" - elif [ $(command -v wget) ]; then - wget --quiet "${TACHYON_URL}" -O "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" - else - printf "You do not have curl or wget installed. please install Tachyon manually.\n" - exit -1 - fi - - tar xzf "${TACHYON_TGZ}" - cp "tachyon-${TACHYON_VERSION}/assembly/target/tachyon-assemblies-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" - mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" - cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/servers/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" - - if [[ `uname -a` == Darwin* ]]; then - # need to run sed differently on osx - nl=$'\n'; sed -i "" -e "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\\$nl export TACHYON_JAR=\$TACHYON_HOME/../lib/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" - else - sed -i "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\n export TACHYON_JAR=\$TACHYON_HOME/../lib/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" - fi - - popd > /dev/null - rm -rf "$TMPD" -fi - -if [ "$MAKE_TGZ" == "true" ]; then - TARDIR_NAME=spark-$VERSION-bin-$NAME - TARDIR="$SPARK_HOME/$TARDIR_NAME" - rm -rf "$TARDIR" - cp -r "$DISTDIR" "$TARDIR" - tar czf "spark-$VERSION-bin-$NAME.tgz" -C "$SPARK_HOME" "$TARDIR_NAME" - rm -rf "$TARDIR" -fi diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml new file mode 100644 index 0000000000000..68f15dd905028 --- /dev/null +++ b/mllib-local/pom.xml @@ -0,0 +1,74 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-mllib-local_2.11 + + mllib-local + + jar + Spark Project ML Local Library + http://spark.apache.org/ + + + + org.scalanlp + breeze_${scala.binary.version} + + + org.apache.commons + commons-math3 + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.mockito + mockito-core + test + + + + + netlib-lgpl + + + com.github.fommil.netlib + all + ${netlib.java.version} + pom + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala new file mode 100644 index 0000000000000..6b3268cdfa25c --- /dev/null +++ b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +// This is a private class testing if the new build works. To be removed soon. +private[ml] object DummyTesting { + private[ml] def add10(input: Double): Double = input + 10 +} diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala new file mode 100644 index 0000000000000..51b7c2409ff22 --- /dev/null +++ b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +// This is testing if the new build works. To be removed soon. +class DummyTestingSuite extends FunSuite { // scalastyle:ignore funsuite + + test("This is testing if the new build works.") { + assert(DummyTesting.add10(15) === 25) + } +} diff --git a/mllib/pom.xml b/mllib/pom.xml index 70139121d8c78..24d8274e2222f 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-mllib_2.10 + spark-mllib_2.11 mllib @@ -63,27 +63,20 @@ ${project.version} - org.jblas - jblas - ${jblas.version} + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.version} + test-jar test org.scalanlp breeze_${scala.binary.version} - 0.11.2 - - - - junit - junit - - - org.apache.commons - commons-math3 - - org.apache.commons @@ -109,7 +102,7 @@ org.jpmml pmml-model - 1.1.15 + 1.2.7 com.sun.xml.fastinfoset diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 57e416591de69..1247882d6c1bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, ParamPair} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * Estimator's embedded ParamMap. * @return fitted model */ + @Since("2.0.0") @varargs - def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { + def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { val map = new ParamMap() .put(firstParamPair) .put(otherParamPairs: _*) @@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: DataFrame, paramMap: ParamMap): M = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMap: ParamMap): M = { copy(paramMap).fit(dataset) } /** * Fits a model to the input data. */ - def fit(dataset: DataFrame): M + @Since("2.0.0") + def fit(dataset: Dataset[_]): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ - def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a3e59401c5cfb..82066726a0694 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,11 +22,16 @@ import java.{util => ju} import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -81,30 +86,31 @@ abstract class PipelineStage extends Params with Logging { * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ +@Since("1.2.0") @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] { +class Pipeline @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Estimator[PipelineModel] with MLWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("pipeline")) /** * param for pipeline stages * @group param */ + @Since("1.2.0") val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") /** @group setParam */ + @Since("1.2.0") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } // Below, we clone stages so that modifications to the list of stages will not change // the Param value in the Pipeline. /** @group getParam */ + @Since("1.2.0") def getStages: Array[PipelineStage] = $(stages).clone() - override def validateParams(): Unit = { - super.validateParams() - $(stages).foreach(_.validateParams()) - } - /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model. @@ -117,7 +123,8 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { * @param dataset input dataset * @return fitted pipeline */ - override def fit(dataset: DataFrame): PipelineModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PipelineModel = { transformSchema(dataset.schema, logging = true) val theStages = $(stages) // Search for the last estimator. @@ -140,7 +147,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { t case _ => throw new IllegalArgumentException( - s"Do not support stage $stage of type ${stage.getClass}") + s"Does not support stage $stage of type ${stage.getClass}") } if (index < indexOfLastEstimator) { curDataset = transformer.transform(curDataset) @@ -154,50 +161,188 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { new PipelineModel(uid, transformers.toArray).setParent(this) } + @Since("1.4.0") override def copy(extra: ParamMap): Pipeline = { val map = extractParamMap(extra) val newStages = map(stages).map(_.copy(extra)) new Pipeline().setStages(newStages) } + @Since("1.2.0") override def transformSchema(schema: StructType): StructType = { val theStages = $(stages) require(theStages.toSet.size == theStages.length, "Cannot have duplicate components in a pipeline.") theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + + @Since("1.6.0") + override def write: MLWriter = new Pipeline.PipelineWriter(this) +} + +@Since("1.6.0") +object Pipeline extends MLReadable[Pipeline] { + + @Since("1.6.0") + override def read: MLReader[Pipeline] = new PipelineReader + + @Since("1.6.0") + override def load(path: String): Pipeline = super.load(path) + + private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter { + + SharedReadWrite.validateStages(instance.getStages) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) + } + + private class PipelineReader extends MLReader[Pipeline] { + + /** Checked against metadata when loading model */ + private val className = classOf[Pipeline].getName + + override def load(path: String): Pipeline = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + new Pipeline(uid).setStages(stages) + } + } + + /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ + private[ml] object SharedReadWrite { + + import org.json4s.JsonDSL._ + + /** Check that all stages are Writable */ + def validateStages(stages: Array[PipelineStage]): Unit = { + stages.foreach { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + + s" ${other.uid} of type ${other.getClass}") + } + } + + /** + * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * - save metadata to path/metadata + * - save stages to stages/IDX_UID + */ + def saveImpl( + instance: Params, + stages: Array[PipelineStage], + sc: SparkContext, + path: String): Unit = { + val stageUids = stages.map(_.uid) + val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams)) + + // Save stages + val stagesDir = new Path(path, "stages").toString + stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) => + stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) + } + } + + /** + * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]] + * @return (UID, list of stages) + */ + def load( + expectedClassName: String, + sc: SparkContext, + path: String): (String, Array[PipelineStage]) = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val stagesDir = new Path(path, "stages").toString + val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray + val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => + val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) + DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc) + } + (metadata.uid, stages) + } + + /** Get path for saving the given stage. */ + def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = { + val stageIdxDigits = numStages.toString.length + val idxFormat = s"%0${stageIdxDigits}d" + val stageDir = idxFormat.format(stageIdx) + "_" + stageUid + new Path(stagesDir, stageDir).toString + } + } } /** * :: Experimental :: * Represents a fitted pipeline. */ +@Since("1.2.0") @Experimental class PipelineModel private[ml] ( - override val uid: String, - val stages: Array[Transformer]) - extends Model[PipelineModel] with Logging { + @Since("1.4.0") override val uid: String, + @Since("1.4.0") val stages: Array[Transformer]) + extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { this(uid, stages.asScala.toArray) } - override def validateParams(): Unit = { - super.validateParams() - stages.foreach(_.validateParams()) - } - - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) + stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur)) } + @Since("1.2.0") override def transformSchema(schema: StructType): StructType = { stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) } + @Since("1.4.0") override def copy(extra: ParamMap): PipelineModel = { new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) +} + +@Since("1.6.0") +object PipelineModel extends MLReadable[PipelineModel] { + + import Pipeline.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[PipelineModel] = new PipelineModelReader + + @Since("1.6.0") + override def load(path: String): PipelineModel = super.load(path) + + private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { + + SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) + + override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance, + instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) + } + + private class PipelineModelReader extends MLReader[PipelineModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[PipelineModel].getName + + override def load(path: String): PipelineModel = { + val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) + val transformers = stages map { + case stage: Transformer => stage + case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" + + s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}") + } + new PipelineModel(uid, transformers) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e0dcd427fae24..81140d1f7b21f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -24,9 +24,9 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} /** * (private[ml]) Trait for parameters for prediction (regression and classification). @@ -36,6 +36,7 @@ private[ml] trait PredictorParams extends Params /** * Validates and transforms the input schema with the provided param map. + * * @param schema input schema * @param fitting whether this is in fitting * @param featuresDataType SQL DataType for FeaturesType. @@ -49,8 +50,7 @@ private[ml] trait PredictorParams extends Params // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) if (fitting) { - // TODO: Allow other numeric types - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) } SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } @@ -83,7 +83,7 @@ abstract class Predictor[ /** @group setParam */ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] - override def fit(dataset: DataFrame): M = { + override def fit(dataset: Dataset[_]): M = { // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) @@ -100,7 +100,7 @@ abstract class Predictor[ * @param dataset Training dataset * @return Fitted model */ - protected def train(dataset: DataFrame): M + protected def train(dataset: Dataset[_]): M /** * Returns the SQL DataType corresponding to the FeaturesType type parameter. @@ -120,9 +120,10 @@ abstract class Predictor[ * Extract [[labelCol]] and [[featuresCol]] from the given dataset, * and put it in an RDD with strong types. */ - protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { - dataset.select($(labelCol), $(featuresCol)) - .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } + protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { + dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } } } @@ -170,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, * @param dataset input dataset * @return transformed dataset with [[predictionCol]] of type [[Double]] */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { transformImpl(dataset) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") - dataset + dataset.toDF } } - protected def transformImpl(dataset: DataFrame): DataFrame = { + protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 3c7bcf7590e6d..a3a2b55adc25d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -41,9 +41,10 @@ abstract class Transformer extends PipelineStage { * @param otherParamPairs other param pairs, overwrite embedded params * @return transformed dataset */ + @Since("2.0.0") @varargs def transform( - dataset: DataFrame, + dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() @@ -58,14 +59,16 @@ abstract class Transformer extends PipelineStage { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + @Since("2.0.0") + def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = { this.copy(paramMap).transform(dataset) } /** * Transforms the input dataset. */ - def transform(dataset: DataFrame): DataFrame + @Since("2.0.0") + def transform(dataset: Dataset[_]): DataFrame override def copy(extra: ParamMap): Transformer } @@ -113,10 +116,10 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - dataset.withColumn($(outputCol), - callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) + val transformUDF = udf(this.createTransformFunc, outputDataType) + dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } override def copy(extra: ParamMap): T = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index b5258ff348477..a5b84116e6eae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.ann -import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, - sum => Bsum} -import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} +import java.util.Random + +import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.optimization._ @@ -32,20 +32,46 @@ import org.apache.spark.util.random.XORShiftRandom * */ private[ann] trait Layer extends Serializable { + /** - * Returns the instance of the layer based on weights provided - * @param weights vector with layer weights - * @param position position of weights in the vector - * @return the layer model + * Number of weights that is used to allocate memory for the weights vector + */ + val weightSize: Int + + /** + * Returns the output size given the input size (not counting the stack size). + * Output size is used to allocate memory for the output. + * + * @param inputSize input size + * @return output size */ - def getInstance(weights: Vector, position: Int): LayerModel + def getOutputSize(inputSize: Int): Int + /** + * If true, the memory is not allocated for the output of this layer. + * The memory allocated to the previous layer is used to write the output of this layer. + * Developer can set this to true if computing delta of a previous layer + * does not involve its output, so the current layer can write there. + * This also mean that both layers have the same number of outputs. + */ + val inPlace: Boolean + + /** + * Returns the instance of the layer based on weights provided. + * Size of weights must be equal to weightSize + * + * @param initialWeights vector with layer weights + * @return the layer model + */ + def createModel(initialWeights: BDV[Double]): LayerModel /** * Returns the instance of the layer with random generated weights - * @param seed seed + * + * @param weights vector for weights initialization, must be equal to weightSize + * @param random random number generator * @return the layer model */ - def getInstance(seed: Long): LayerModel + def initModel(weights: BDV[Double], random: Random): LayerModel } /** @@ -54,92 +80,102 @@ private[ann] trait Layer extends Serializable { * Can return weights in Vector format. */ private[ann] trait LayerModel extends Serializable { - /** - * number of weights - */ - val size: Int + val weights: BDV[Double] /** * Evaluates the data (process the data through the layer) + * Output is allocated based on the size provided by the + * LayerModel implementation and the stack (batch) size + * Developer is responsible for checking the size of output + * when writing to it + * * @param data data - * @return processed data + * @param output output (modified in place) */ - def eval(data: BDM[Double]): BDM[Double] + def eval(data: BDM[Double], output: BDM[Double]): Unit /** * Computes the delta for back propagation - * @param nextDelta delta of the next layer - * @param input input data - * @return delta + * Delta is allocated based on the size provided by the + * LayerModel implementation and the stack (batch) size + * Developer is responsible for checking the size of + * prevDelta when writing to it + * + * @param delta delta of this layer + * @param output output of this layer + * @param prevDelta the previous delta (modified in place) */ - def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] + def computePrevDelta(delta: BDM[Double], output: BDM[Double], prevDelta: BDM[Double]): Unit /** * Computes the gradient + * cumGrad is a wrapper on the part of the weight vector + * size of cumGrad is based on weightSize provided by + * implementation of LayerModel + * * @param delta delta for this layer * @param input input data - * @return gradient + * @param cumGrad cumulative gradient (modified in place) */ - def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] - - /** - * Returns weights for the layer in a single vector - * @return layer weights - */ - def weights(): Vector + def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit } /** * Layer properties of affine transformations, that is y=A*x+b + * * @param numIn number of inputs * @param numOut number of outputs */ private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer { - override def getInstance(weights: Vector, position: Int): LayerModel = { - AffineLayerModel(this, weights, position) - } + override val weightSize = numIn * numOut + numOut - override def getInstance(seed: Long = 11L): LayerModel = { - AffineLayerModel(this, seed) - } + override def getOutputSize(inputSize: Int): Int = numOut + + override val inPlace = false + + override def createModel(weights: BDV[Double]): LayerModel = new AffineLayerModel(weights, this) + + override def initModel(weights: BDV[Double], random: Random): LayerModel = + AffineLayerModel(this, weights, random) } /** - * Model of Affine layer y=A*x+b - * @param w weights (matrix A) - * @param b bias (vector b) + * Model of Affine layer + * + * @param weights weights + * @param layer layer properties */ -private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel { - val size = w.size + b.length - val gwb = new Array[Double](size) - private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb) - private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size) - private var z: BDM[Double] = null - private var d: BDM[Double] = null +private[ann] class AffineLayerModel private[ann] ( + val weights: BDV[Double], + val layer: AffineLayer) extends LayerModel { + val w = new BDM[Double](layer.numOut, layer.numIn, weights.data, weights.offset) + val b = + new BDV[Double](weights.data, weights.offset + (layer.numOut * layer.numIn), 1, layer.numOut) + private var ones: BDV[Double] = null - override def eval(data: BDM[Double]): BDM[Double] = { - if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols) - z(::, *) := b - BreezeUtil.dgemm(1.0, w, data, 1.0, z) - z + override def eval(data: BDM[Double], output: BDM[Double]): Unit = { + output(::, *) := b + BreezeUtil.dgemm(1.0, w, data, 1.0, output) } - override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { - if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols) - BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d) - d + override def computePrevDelta( + delta: BDM[Double], + output: BDM[Double], + prevDelta: BDM[Double]): Unit = { + BreezeUtil.dgemm(1.0, w.t, delta, 0.0, prevDelta) } - override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = { - BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw) + override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = { + // compute gradient of weights + val cumGradientOfWeights = new BDM[Double](w.rows, w.cols, cumGrad.data, cumGrad.offset) + BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 1.0, cumGradientOfWeights) if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols) - BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb) - gwb + // compute gradient of bias + val cumGradientOfBias = new BDV[Double](cumGrad.data, cumGrad.offset + w.size, 1, b.length) + BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 1.0, cumGradientOfBias) } - - override def weights(): Vector = AffineLayerModel.roll(w, b) } /** @@ -149,73 +185,40 @@ private[ann] object AffineLayerModel { /** * Creates a model of Affine layer + * * @param layer layer properties - * @param weights vector with weights - * @param position position of weights in the vector - * @return model of Affine layer - */ - def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = { - val (w, b) = unroll(weights, position, layer.numIn, layer.numOut) - new AffineLayerModel(w, b) - } - - /** - * Creates a model of Affine layer - * @param layer layer properties - * @param seed seed + * @param weights vector for weights initialization + * @param random random number generator * @return model of Affine layer */ - def apply(layer: AffineLayer, seed: Long): AffineLayerModel = { - val (w, b) = randomWeights(layer.numIn, layer.numOut, seed) - new AffineLayerModel(w, b) - } - - /** - * Unrolls the weights from the vector - * @param weights vector with weights - * @param position position of weights for this layer - * @param numIn number of layer inputs - * @param numOut number of layer outputs - * @return matrix A and vector b - */ - def unroll( - weights: Vector, - position: Int, - numIn: Int, - numOut: Int): (BDM[Double], BDV[Double]) = { - val weightsCopy = weights.toArray - // TODO: the array is not copied to BDMs, make sure this is OK! - val a = new BDM[Double](numOut, numIn, weightsCopy, position) - val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut) - (a, b) - } - - /** - * Roll the layer weights into a vector - * @param a matrix A - * @param b vector b - * @return vector of weights - */ - def roll(a: BDM[Double], b: BDV[Double]): Vector = { - val result = new Array[Double](a.size + b.length) - // TODO: make sure that we need to copy! - System.arraycopy(a.toArray, 0, result, 0, a.size) - System.arraycopy(b.toArray, 0, result, a.size, b.length) - Vectors.dense(result) + def apply(layer: AffineLayer, weights: BDV[Double], random: Random): AffineLayerModel = { + randomWeights(layer.numIn, layer.numOut, weights, random) + new AffineLayerModel(weights, layer) } /** - * Generate random weights for the layer - * @param numIn number of inputs + * Initialize weights randomly in the interval + * Uses [Bottou-88] heuristic [-a/sqrt(in); a/sqrt(in)] + * where a is chosen in a such way that the weight variance corresponds + * to the points to the maximal curvature of the activation function + * (which is approximately 2.38 for a standard sigmoid) + * + * @param numIn number of inputs * @param numOut number of outputs - * @param seed seed - * @return (matrix A, vector b) + * @param weights vector for weights initialization + * @param random random number generator */ - def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = { - val rand: XORShiftRandom = new XORShiftRandom(seed) - val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn } - val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn } - (weights, bias) + def randomWeights( + numIn: Int, + numOut: Int, + weights: BDV[Double], + random: Random): Unit = { + var i = 0 + val sqrtIn = math.sqrt(numIn) + while (i < weights.length) { + weights(i) = (random.nextDouble * 4.8 - 2.4) / sqrtIn + i += 1 + } } } @@ -226,44 +229,21 @@ private[ann] trait ActivationFunction extends Serializable { /** * Implements a function - * @param x input data - * @param y output data */ - def eval(x: BDM[Double], y: BDM[Double]): Unit + def eval: Double => Double /** * Implements a derivative of a function (needed for the back propagation) - * @param x input data - * @param y output data */ - def derivative(x: BDM[Double], y: BDM[Double]): Unit - - /** - * Implements a cross entropy error of a function. - * Needed if the functional layer that contains this function is the output layer - * of the network. - * @param target target output - * @param output computed output - * @param result intermediate result - * @return cross-entropy - */ - def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double - - /** - * Implements a mean squared error of a function - * @param target target output - * @param output computed output - * @param result intermediate result - * @return mean squared error - */ - def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double + def derivative: Double => Double } /** - * Implements in-place application of functions + * Implements in-place application of functions in the arrays */ -private[ann] object ActivationFunction { +private[ann] object ApplyInPlace { + // TODO: use Breeze UFunc def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = { var i = 0 while (i < x.rows) { @@ -276,6 +256,7 @@ private[ann] object ActivationFunction { } } + // TODO: use Breeze UFunc def apply( x1: BDM[Double], x2: BDM[Double], @@ -293,180 +274,87 @@ private[ann] object ActivationFunction { } } -/** - * Implements SoftMax activation function - */ -private[ann] class SoftmaxFunction extends ActivationFunction { - override def eval(x: BDM[Double], y: BDM[Double]): Unit = { - var j = 0 - // find max value to make sure later that exponent is computable - while (j < x.cols) { - var i = 0 - var max = Double.MinValue - while (i < x.rows) { - if (x(i, j) > max) { - max = x(i, j) - } - i += 1 - } - var sum = 0.0 - i = 0 - while (i < x.rows) { - val res = Math.exp(x(i, j) - max) - y(i, j) = res - sum += res - i += 1 - } - i = 0 - while (i < x.rows) { - y(i, j) /= sum - i += 1 - } - j += 1 - } - } - - override def crossEntropy( - output: BDM[Double], - target: BDM[Double], - result: BDM[Double]): Double = { - def m(o: Double, t: Double): Double = o - t - ActivationFunction(output, target, result, m) - -Bsum( target :* Blog(output)) / output.cols - } - - override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { - def sd(z: Double): Double = (1 - z) * z - ActivationFunction(x, y, sd) - } - - override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { - throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.") - } -} - /** * Implements Sigmoid activation function */ private[ann] class SigmoidFunction extends ActivationFunction { - override def eval(x: BDM[Double], y: BDM[Double]): Unit = { - def s(z: Double): Double = Bsigmoid(z) - ActivationFunction(x, y, s) - } - - override def crossEntropy( - output: BDM[Double], - target: BDM[Double], - result: BDM[Double]): Double = { - def m(o: Double, t: Double): Double = o - t - ActivationFunction(output, target, result, m) - -Bsum(target :* Blog(output)) / output.cols - } - override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { - def sd(z: Double): Double = (1 - z) * z - ActivationFunction(x, y, sd) - } + override def eval: (Double) => Double = x => 1.0 / (1 + math.exp(-x)) - override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { - // TODO: make it readable - def m(o: Double, t: Double): Double = (o - t) - ActivationFunction(output, target, result, m) - val e = Bsum(result :* result) / 2 / output.cols - def m2(x: Double, o: Double) = x * (o - o * o) - ActivationFunction(result, output, result, m2) - e - } + override def derivative: (Double) => Double = z => (1 - z) * z } /** * Functional layer properties, y = f(x) + * * @param activationFunction activation function */ private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer { - override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L) - override def getInstance(seed: Long): LayerModel = - FunctionalLayerModel(this) + override val weightSize = 0 + + override def getOutputSize(inputSize: Int): Int = inputSize + + override val inPlace = true + + override def createModel(weights: BDV[Double]): LayerModel = new FunctionalLayerModel(this) + + override def initModel(weights: BDV[Double], random: Random): LayerModel = + createModel(weights) } /** * Functional layer model. Holds no weights. - * @param activationFunction activation function + * + * @param layer functiona layer */ -private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction) +private[ann] class FunctionalLayerModel private[ann] (val layer: FunctionalLayer) extends LayerModel { - val size = 0 - // matrices for in-place computations - // outputs - private var f: BDM[Double] = null - // delta - private var d: BDM[Double] = null - // matrix for error computation - private var e: BDM[Double] = null - // delta gradient - private lazy val dg = new Array[Double](0) - override def eval(data: BDM[Double]): BDM[Double] = { - if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols) - activationFunction.eval(data, f) - f - } + // empty weights + val weights = new BDV[Double](0) - override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { - if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols) - activationFunction.derivative(input, d) - d :*= nextDelta - d + override def eval(data: BDM[Double], output: BDM[Double]): Unit = { + ApplyInPlace(data, output, layer.activationFunction.eval) } - override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg - - override def weights(): Vector = Vectors.dense(new Array[Double](0)) - - def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { - if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) - val error = activationFunction.crossEntropy(output, target, e) - (e, error) + override def computePrevDelta( + nextDelta: BDM[Double], + input: BDM[Double], + delta: BDM[Double]): Unit = { + ApplyInPlace(input, delta, layer.activationFunction.derivative) + delta :*= nextDelta } - def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { - if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) - val error = activationFunction.squared(output, target, e) - (e, error) - } - - def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { - // TODO: allow user pick error - activationFunction match { - case sigmoid: SigmoidFunction => squared(output, target) - case softmax: SoftmaxFunction => crossEntropy(output, target) - } - } -} - -/** - * Fabric of functional layer models - */ -private[ann] object FunctionalLayerModel { - def apply(layer: FunctionalLayer): FunctionalLayerModel = - new FunctionalLayerModel(layer.activationFunction) + override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {} } /** * Trait for the artificial neural network (ANN) topology properties */ -private[ann] trait Topology extends Serializable{ - def getInstance(weights: Vector): TopologyModel - def getInstance(seed: Long): TopologyModel +private[ann] trait Topology extends Serializable { + def model(weights: Vector): TopologyModel + def model(seed: Long): TopologyModel } /** * Trait for ANN topology model */ -private[ann] trait TopologyModel extends Serializable{ +private[ann] trait TopologyModel extends Serializable { + + val weights: Vector + /** + * Array of layers + */ + val layers: Array[Layer] + + /** + * Array of layer models + */ + val layerModels: Array[LayerModel] /** * Forward propagation + * * @param data input data * @return array of outputs for each of the layers */ @@ -474,6 +362,7 @@ private[ann] trait TopologyModel extends Serializable{ /** * Prediction of the model + * * @param data input data * @return prediction */ @@ -481,6 +370,7 @@ private[ann] trait TopologyModel extends Serializable{ /** * Computes gradient for the network + * * @param data input data * @param target target output * @param cumGradient cumulative gradient @@ -489,22 +379,17 @@ private[ann] trait TopologyModel extends Serializable{ */ def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, blockSize: Int): Double - - /** - * Returns the weights of the ANN - * @return weights - */ - def weights(): Vector } /** * Feed forward ANN + * * @param layers */ private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { - override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) + override def model(weights: Vector): TopologyModel = FeedForwardModel(this, weights) - override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) + override def model(seed: Long): TopologyModel = FeedForwardModel(this, seed) } /** @@ -513,6 +398,7 @@ private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends private[ml] object FeedForwardTopology { /** * Creates a feed forward topology from the array of layers + * * @param layers array of layers * @return feed forward topology */ @@ -522,18 +408,26 @@ private[ml] object FeedForwardTopology { /** * Creates a multi-layer perceptron + * * @param layerSizes sizes of layers including input and output size - * @param softmax wether to use SoftMax or Sigmoid function for an output layer. + * @param softmaxOnTop wether to use SoftMax or Sigmoid function for an output layer. * Softmax is default * @return multilayer perceptron topology */ - def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = { + def multiLayerPerceptron( + layerSizes: Array[Int], + softmaxOnTop: Boolean = true): FeedForwardTopology = { val layers = new Array[Layer]((layerSizes.length - 1) * 2) - for(i <- 0 until layerSizes.length - 1){ + for (i <- 0 until layerSizes.length - 1) { layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1)) layers(i * 2 + 1) = - if (softmax && i == layerSizes.length - 2) { - new FunctionalLayer(new SoftmaxFunction()) + if (i == layerSizes.length - 2) { + if (softmaxOnTop) { + new SoftmaxLayerWithCrossEntropyLoss() + } else { + // TODO: squared error is more natural but converges slower + new SigmoidLayerWithSquaredError() + } } else { new FunctionalLayer(new SigmoidFunction()) } @@ -545,17 +439,45 @@ private[ml] object FeedForwardTopology { /** * Model of Feed Forward Neural Network. * Implements forward, gradient computation and can return weights in vector format. - * @param layerModels models of layers - * @param topology topology of the network + * + * @param weights network weights + * @param topology network topology */ private[ml] class FeedForwardModel private( - val layerModels: Array[LayerModel], + val weights: Vector, val topology: FeedForwardTopology) extends TopologyModel { + + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + private var offset = 0 + for (i <- 0 until layers.length) { + layerModels(i) = layers(i).createModel( + new BDV[Double](weights.toArray, offset, 1, layers(i).weightSize)) + offset += layers(i).weightSize + } + private var outputs: Array[BDM[Double]] = null + private var deltas: Array[BDM[Double]] = null + override def forward(data: BDM[Double]): Array[BDM[Double]] = { - val outputs = new Array[BDM[Double]](layerModels.length) - outputs(0) = layerModels(0).eval(data) + // Initialize output arrays for all layers. Special treatment for InPlace + val currentBatchSize = data.cols + // TODO: allocate outputs as one big array and then create BDMs from it + if (outputs == null || outputs(0).cols != currentBatchSize) { + outputs = new Array[BDM[Double]](layers.length) + var inputSize = data.rows + for (i <- 0 until layers.length) { + if (layers(i).inPlace) { + outputs(i) = outputs(i - 1) + } else { + val outputSize = layers(i).getOutputSize(inputSize) + outputs(i) = new BDM[Double](outputSize, currentBatchSize) + inputSize = outputSize + } + } + } + layerModels(0).eval(data, outputs(0)) for (i <- 1 until layerModels.length) { - outputs(i) = layerModels(i).eval(outputs(i-1)) + layerModels(i).eval(outputs(i - 1), outputs(i)) } outputs } @@ -566,54 +488,36 @@ private[ml] class FeedForwardModel private( cumGradient: Vector, realBatchSize: Int): Double = { val outputs = forward(data) - val deltas = new Array[BDM[Double]](layerModels.length) + val currentBatchSize = data.cols + // TODO: allocate deltas as one big array and then create BDMs from it + if (deltas == null || deltas(0).cols != currentBatchSize) { + deltas = new Array[BDM[Double]](layerModels.length) + var inputSize = data.rows + for (i <- 0 until layerModels.length - 1) { + val outputSize = layers(i).getOutputSize(inputSize) + deltas(i) = new BDM[Double](outputSize, currentBatchSize) + inputSize = outputSize + } + } val L = layerModels.length - 1 - val (newE, newError) = layerModels.last match { - case flm: FunctionalLayerModel => flm.error(outputs.last, target) + // TODO: explain why delta of top layer is null (because it might contain loss+layer) + val loss = layerModels.last match { + case levelWithError: LossFunction => levelWithError.loss(outputs.last, target, deltas(L - 1)) case _ => - throw new UnsupportedOperationException("Non-functional layer not supported at the top") + throw new UnsupportedOperationException("Top layer is required to have objective.") } - deltas(L) = new BDM[Double](0, 0) - deltas(L - 1) = newE for (i <- (L - 2) to (0, -1)) { - deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1)) - } - val grads = new Array[Array[Double]](layerModels.length) - for (i <- 0 until layerModels.length) { - val input = if (i==0) data else outputs(i - 1) - grads(i) = layerModels(i).grad(deltas(i), input) + layerModels(i + 1).computePrevDelta(deltas(i + 1), outputs(i + 1), deltas(i)) } - // update cumGradient val cumGradientArray = cumGradient.toArray var offset = 0 - // TODO: extract roll - for (i <- 0 until grads.length) { - val gradArray = grads(i) - var k = 0 - while (k < gradArray.length) { - cumGradientArray(offset + k) += gradArray(k) - k += 1 - } - offset += gradArray.length - } - newError - } - - // TODO: do we really need to copy the weights? they should be read-only - override def weights(): Vector = { - // TODO: extract roll - var size = 0 - for (i <- 0 until layerModels.length) { - size += layerModels(i).size - } - val array = new Array[Double](size) - var offset = 0 for (i <- 0 until layerModels.length) { - val layerWeights = layerModels(i).weights().toArray - System.arraycopy(layerWeights, 0, array, offset, layerWeights.length) - offset += layerWeights.length + val input = if (i == 0) data else outputs(i - 1) + layerModels(i).grad(deltas(i), input, + new BDV[Double](cumGradientArray, offset, 1, layers(i).weightSize)) + offset += layers(i).weightSize } - Vectors.dense(array) + loss } override def predict(data: Vector): Vector = { @@ -630,23 +534,19 @@ private[ann] object FeedForwardModel { /** * Creates a model from a topology and weights + * * @param topology topology * @param weights weights * @return model */ def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { - val layers = topology.layers - val layerModels = new Array[LayerModel](layers.length) - var offset = 0 - for (i <- 0 until layers.length) { - layerModels(i) = layers(i).getInstance(weights, offset) - offset += layerModels(i).size - } - new FeedForwardModel(layerModels, topology) + // TODO: check that weights size is equal to sum of layers sizes + new FeedForwardModel(weights, topology) } /** * Creates a model given a topology and seed + * * @param topology topology * @param seed seed for generating the weights * @return model @@ -654,17 +554,25 @@ private[ann] object FeedForwardModel { def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { val layers = topology.layers val layerModels = new Array[LayerModel](layers.length) + var totalSize = 0 + for (i <- 0 until topology.layers.length) { + totalSize += topology.layers(i).weightSize + } + val weights = BDV.zeros[Double](totalSize) var offset = 0 - for(i <- 0 until layers.length){ - layerModels(i) = layers(i).getInstance(seed) - offset += layerModels(i).size + val random = new XORShiftRandom(seed) + for (i <- 0 until layers.length) { + layerModels(i) = layers(i). + initModel(new BDV[Double](weights.data, offset, 1, layers(i).weightSize), random) + offset += layers(i).weightSize } - new FeedForwardModel(layerModels, topology) + new FeedForwardModel(Vectors.fromBreeze(weights), topology) } } /** * Neural network gradient. Does nothing but calling Model's gradient + * * @param topology topology * @param dataStacker data stacker */ @@ -682,7 +590,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext weights: Vector, cumGradient: Vector): Double = { val (input, target, realBatchSize) = dataStacker.unstack(data) - val model = topology.getInstance(weights) + val model = topology.model(weights) model.computeGradient(input, target, cumGradient, realBatchSize) } } @@ -692,6 +600,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext * through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks * or matrices of inputs and outputs and then stack them in one vector. * This can be used for further batch computations after unstacking. + * * @param stackSize stack size * @param inputSize size of the input vectors * @param outputSize size of the output vectors @@ -701,6 +610,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) /** * Stacks the data + * * @param data RDD of vector pairs * @return RDD of double (always zero) and vector that contains the stacked vectors */ @@ -733,6 +643,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) /** * Unstack the stacked vectors into matrices for batch operations + * * @param data stacked vector * @return pair of matrices holding input and output data and the real stack size */ @@ -765,6 +676,7 @@ private[ann] class ANNUpdater extends Updater { /** * MLlib-style trainer class that trains a network given the data and topology + * * @param topology topology of ANN * @param inputSize input size * @param outputSize output size @@ -774,36 +686,50 @@ private[ml] class FeedForwardTrainer( val inputSize: Int, val outputSize: Int) extends Serializable { - // TODO: what if we need to pass random seed? - private var _weights = topology.getInstance(11L).weights() + private var _seed = this.getClass.getName.hashCode.toLong + private var _weights: Vector = null private var _stackSize = 128 private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) private var _gradient: Gradient = new ANNGradient(topology, dataStacker) private var _updater: Updater = new ANNUpdater() private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100) + /** + * Returns seed + */ + def getSeed: Long = _seed + + /** + * Sets seed + */ + def setSeed(value: Long): this.type = { + _seed = value + this + } + /** * Returns weights - * @return weights */ def getWeights: Vector = _weights /** * Sets weights + * * @param value weights * @return trainer */ - def setWeights(value: Vector): FeedForwardTrainer = { + def setWeights(value: Vector): this.type = { _weights = value this } /** * Sets the stack size + * * @param value stack size * @return trainer */ - def setStackSize(value: Int): FeedForwardTrainer = { + def setStackSize(value: Int): this.type = { _stackSize = value dataStacker = new DataStacker(value, inputSize, outputSize) this @@ -811,6 +737,7 @@ private[ml] class FeedForwardTrainer( /** * Sets the SGD optimizer + * * @return SGD optimizer */ def SGDOptimizer: GradientDescent = { @@ -821,6 +748,7 @@ private[ml] class FeedForwardTrainer( /** * Sets the LBFGS optimizer + * * @return LBGS optimizer */ def LBFGSOptimizer: LBFGS = { @@ -831,10 +759,11 @@ private[ml] class FeedForwardTrainer( /** * Sets the updater + * * @param value updater * @return trainer */ - def setUpdater(value: Updater): FeedForwardTrainer = { + def setUpdater(value: Updater): this.type = { _updater = value updateUpdater(value) this @@ -842,10 +771,11 @@ private[ml] class FeedForwardTrainer( /** * Sets the gradient + * * @param value gradient * @return trainer */ - def setGradient(value: Gradient): FeedForwardTrainer = { + def setGradient(value: Gradient): this.type = { _gradient = value updateGradient(value) this @@ -871,12 +801,20 @@ private[ml] class FeedForwardTrainer( /** * Trains the ANN + * * @param data RDD of input and output vector pairs * @return model */ def train(data: RDD[(Vector, Vector)]): TopologyModel = { - val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights) - topology.getInstance(newWeights) + val w = if (getWeights == null) { + // TODO: will make a copy if vector is a subvector of BDV (see Vectors code) + topology.model(_seed).weights + } else { + getWeights + } + // TODO: deprecate standard optimizer because it needs Vector + val newWeights = optimizer.optimize(dataStacker.stack(data), w) + topology.model(newWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala new file mode 100644 index 0000000000000..32d78e9b226eb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import java.util.Random + +import breeze.linalg.{sum => Bsum, DenseMatrix => BDM, DenseVector => BDV} +import breeze.numerics.{log => brzlog} + +/** + * Trait for loss function + */ +private[ann] trait LossFunction { + /** + * Returns the value of loss function. + * Computes loss based on target and output. + * Writes delta (error) to delta in place. + * Delta is allocated based on the outputSize + * of model implementation. + * + * @param output actual output + * @param target target output + * @param delta delta (updated in place) + * @return loss + */ + def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double +} + +private[ann] class SigmoidLayerWithSquaredError extends Layer { + override val weightSize = 0 + override val inPlace = true + + override def getOutputSize(inputSize: Int): Int = inputSize + override def createModel(weights: BDV[Double]): LayerModel = + new SigmoidLayerModelWithSquaredError() + override def initModel(weights: BDV[Double], random: Random): LayerModel = + new SigmoidLayerModelWithSquaredError() +} + +private[ann] class SigmoidLayerModelWithSquaredError + extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction { + override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { + ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) + val error = Bsum(delta :* delta) / 2 / output.cols + ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o)) + error + } +} + +private[ann] class SoftmaxLayerWithCrossEntropyLoss extends Layer { + override val weightSize = 0 + override val inPlace = true + + override def getOutputSize(inputSize: Int): Int = inputSize + override def createModel(weights: BDV[Double]): LayerModel = + new SoftmaxLayerModelWithCrossEntropyLoss() + override def initModel(weights: BDV[Double], random: Random): LayerModel = + new SoftmaxLayerModelWithCrossEntropyLoss() +} + +private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with LossFunction { + + // loss layer models do not have weights + val weights = new BDV[Double](0) + + override def eval(data: BDM[Double], output: BDM[Double]): Unit = { + var j = 0 + // find max value to make sure later that exponent is computable + while (j < data.cols) { + var i = 0 + var max = Double.MinValue + while (i < data.rows) { + if (data(i, j) > max) { + max = data(i, j) + } + i += 1 + } + var sum = 0.0 + i = 0 + while (i < data.rows) { + val res = math.exp(data(i, j) - max) + output(i, j) = res + sum += res + i += 1 + } + i = 0 + while (i < data.rows) { + output(i, j) /= sum + i += 1 + } + j += 1 + } + } + override def computePrevDelta( + nextDelta: BDM[Double], + input: BDM[Double], + delta: BDM[Double]): Unit = { + /* loss layer model computes delta in loss function */ + } + + override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = { + /* loss layer model does not have weights */ + } + + override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { + ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) + -Bsum( target :* brzlog(output)) / output.cols + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index a7c10333c0d53..27554acdf3c26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, NumericType, StructField} /** * :: DeveloperApi :: @@ -481,7 +481,7 @@ object NominalAttribute extends AttributeFactory { * A binary attribute. * @param name optional name * @param index optional index - * @param values optionla values. If set, its size must be 2. + * @param values optional values. If set, its size must be 2. */ @DeveloperApi class BinaryAttribute private[ml] ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java index e3474f3c1d3ff..464ed125695d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package-info.java @@ -20,7 +20,7 @@ /** *

    ML attributes

    * - * The ML pipeline API uses {@link org.apache.spark.sql.DataFrame}s as ML datasets. + * The ML pipeline API uses {@link org.apache.spark.sql.Dataset}s as ML datasets. * Each dataset consists of typed columns, e.g., string, double, vector, etc. * However, knowing only the column type may not be sufficient to handle the data properly. * For instance, a double column with values 0.0, 1.0, 2.0, ... may represent some label indices, diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala index 7ac21d7d563f2..f6964054db839 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml -import org.apache.spark.sql.DataFrame import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.sql.DataFrame /** * ==ML attributes== diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 45df557a89908..473e801794c06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,15 +18,14 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * (private[spark]) Params for classification. */ @@ -93,7 +92,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Output selected columns only. @@ -124,7 +123,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index b0157f7ce24ec..300ae4339c3c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,17 +17,23 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} + /** * :: Experimental :: @@ -36,33 +42,47 @@ import org.apache.spark.sql.DataFrame * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental -final class DecisionTreeClassifier(override val uid: String) +final class DecisionTreeClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] - with DecisionTreeParams with TreeClassifierParams { + with DecisionTreeClassifierParams with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("dtc")) // Override parameter setters from parent trait for Java API compatibility. + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) - override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { + @Since("1.6.0") + override def setSeed(value: Long): this.type = super.setSeed(value) + + override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -75,6 +95,14 @@ final class DecisionTreeClassifier(override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures, numClasses) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = $(seed), parentUID = Some(uid)) + trees.head.asInstanceOf[DecisionTreeClassificationModel] + } + + /** (private[ml]) Train a decision tree on an RDD */ + private[ml] def train(data: RDD[LabeledPoint], + oldStrategy: OldStrategy): DecisionTreeClassificationModel = { + val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0L, parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeClassificationModel] } @@ -87,13 +115,19 @@ final class DecisionTreeClassifier(override val uid: String) subsamplingRate = 1.0) } + @Since("1.4.1") override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental -object DecisionTreeClassifier { +object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] { /** Accessor for supported impurities: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + @Since("2.0.0") + override def load(path: String): DecisionTreeClassifier = super.load(path) } /** @@ -102,14 +136,15 @@ object DecisionTreeClassifier { * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental final class DecisionTreeClassificationModel private[ml] ( - override val uid: String, - override val rootNode: Node, - override val numFeatures: Int, - override val numClasses: Int) + @Since("1.4.0")override val uid: String, + @Since("1.4.0")override val rootNode: Node, + @Since("1.6.0")override val numFeatures: Int, + @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] - with DecisionTreeModel with Serializable { + with DecisionTreeModel with DecisionTreeClassifierParams with MLWritable with Serializable { require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") @@ -140,25 +175,91 @@ final class DecisionTreeClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra) .setParent(parent) } + @Since("1.4.0") override def toString: String = { s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } - /** (private[ml]) Convert to a model in the old API */ - private[ml] def toOld: OldDecisionTreeModel = { + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * + * Note: Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestClassifier]] + * to determine feature importance instead. + */ + @Since("2.0.0") + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) + + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ + override private[spark] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) } + + @Since("2.0.0") + override def write: MLWriter = + new DecisionTreeClassificationModel.DecisionTreeClassificationModelWriter(this) } -private[ml] object DecisionTreeClassificationModel { +@Since("2.0.0") +object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[DecisionTreeClassificationModel] = + new DecisionTreeClassificationModelReader + + @Since("2.0.0") + override def load(path: String): DecisionTreeClassificationModel = super.load(path) + + private[DecisionTreeClassificationModel] + class DecisionTreeClassificationModelWriter(instance: DecisionTreeClassificationModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numClasses" -> instance.numClasses) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val (nodeData, _) = NodeData.build(instance.rootNode, 0) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(nodeData).write.parquet(dataPath) + } + } + + private class DecisionTreeClassificationModelReader + extends MLReader[DecisionTreeClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[DecisionTreeClassificationModel].getName + + override def load(path: String): DecisionTreeClassificationModel = { + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val root = loadTreeNodes(path, metadata, sqlContext) + val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeClassifier, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 74aef94bf7675..39a698af153b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -18,24 +18,25 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.GradientBoostedTrees +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: @@ -43,37 +44,58 @@ import org.apache.spark.sql.types.DoubleType * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. + * + * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - We expect to implement TreeBoost in the future: + * [https://issues.apache.org/jira/browse/SPARK-4240] */ +@Since("1.4.0") @Experimental -final class GBTClassifier(override val uid: String) +final class GBTClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] - with GBTParams with TreeClassifierParams with Logging { + with GBTClassifierParams with DefaultParamsWritable with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** * The impurity setting is ignored for GBT models. * Individual trees are built using impurity "Variance." */ + @Since("1.4.0") override def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this @@ -81,51 +103,27 @@ final class GBTClassifier(override val uid: String) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) - override def setSeed(value: Long): this.type = { - logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") - super.setSeed(value) - } + @Since("1.4.0") + override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from GBTParams: + @Since("1.4.0") override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTClassifier: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "logistic" - * (default = logistic) - * @group param - */ - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "logistic") + // Parameters from GBTClassifierParams: /** @group setParam */ + @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "logistic" => OldLogLoss - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") - } - } - - override protected def train(dataset: DataFrame): GBTClassificationModel = { + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -140,19 +138,25 @@ final class GBTClassifier(override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val oldGBT = new OldGBT(boostingStrategy) - val oldModel = oldGBT.run(oldDataset) - GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, + $(seed)) + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } + @Since("1.4.1") override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental -object GBTClassifier { - // The losses below should be lowercase. +object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { + /** Accessor for supported loss settings: logistic */ - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + @Since("1.4.0") + final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTClassifier = super.load(path) } /** @@ -164,16 +168,18 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ +@Since("1.6.0") @Experimental final class GBTClassificationModel private[ml]( - override val uid: String, + @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - override val numFeatures: Int) + @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] - with TreeEnsembleModel with Serializable { + with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { - require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.") + require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") @@ -182,14 +188,17 @@ final class GBTClassificationModel private[ml]( * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ + @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = this(uid, _trees, _treeWeights, -1) - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + @Since("1.4.0") + override def trees: Array[DecisionTreeRegressionModel] = _trees + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -205,25 +214,93 @@ final class GBTClassificationModel private[ml]( if (prediction > 0.0) 1.0 else 0.0 } + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + + @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { s"GBTClassificationModel (uid=$uid) with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * Each feature's importance is the average of its importance across all trees in the ensemble + * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + * and follows the implementation from scikit-learn. + * + * @see [[DecisionTreeClassificationModel.featureImportances]] + */ + @Since("2.0.0") + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } -private[ml] object GBTClassificationModel { +@Since("2.0.0") +object GBTClassificationModel extends MLReadable[GBTClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader + + @Since("2.0.0") + override def load(path: String): GBTClassificationModel = super.load(path) + + private[GBTClassificationModel] + class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTClassificationModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTClassificationModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], @@ -235,6 +312,6 @@ private[ml] object GBTClassificationModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") - new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) + new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f5fca686df144..c2b440059b1fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -21,21 +21,24 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel /** @@ -153,11 +156,14 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * Currently, this class only supports binary classification. It will support multiclass * in the future. */ +@Since("1.2.0") @Experimental -class LogisticRegression(override val uid: String) +class LogisticRegression @Since("1.2.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("logreg")) /** @@ -165,6 +171,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0. * @group setParam */ + @Since("1.2.0") def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) @@ -175,6 +182,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0 which is an L2 penalty. * @group setParam */ + @Since("1.4.0") def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) setDefault(elasticNetParam -> 0.0) @@ -183,6 +191,7 @@ class LogisticRegression(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.2.0") def setMaxIter(value: Int): this.type = set(maxIter, value) setDefault(maxIter -> 100) @@ -192,6 +201,7 @@ class LogisticRegression(override val uid: String) * Default is 1E-6. * @group setParam */ + @Since("1.4.0") def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) @@ -200,6 +210,7 @@ class LogisticRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.4.0") def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -212,11 +223,14 @@ class LogisticRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.5.0") def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) + @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + @Since("1.5.0") override def getThreshold: Double = super.getThreshold /** @@ -225,24 +239,44 @@ class LogisticRegression(override val uid: String) * Default is empty, so all instances have weight one. * @group setParam */ + @Since("1.6.0") def setWeightCol(value: String): this.type = set(weightCol, value) setDefault(weightCol -> "") + @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds - override protected def train(dataset: DataFrame): LogisticRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + private var optInitialModel: Option[LogisticRegressionModel] = None + + /** @group setParam */ + private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { + this.optInitialModel = Some(model) + this + } + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + train(dataset, handlePersistence) + } + + protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean): + LogisticRegressionModel = { + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = + dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + val instr = Instrumentation.create(this, instances) + instr.logParams(regParam, elasticNetParam, standardization, threshold, + maxIter, tol, fitIntercept) + val (summarizer, labelSummarizer) = { val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), instance: Instance) => @@ -261,152 +295,204 @@ class LogisticRegression(override val uid: String) val numClasses = histogram.length val numFeatures = summarizer.mean.size - if (numInvalid != 0) { - val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + - s"Found $numInvalid invalid labels." - logError(msg) - throw new SparkException(msg) - } + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) - if (numClasses > 2) { - val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " + - s"binary classification. Found $numClasses in the input dataset." - logError(msg) - throw new SparkException(msg) - } + val (coefficients, intercept, objectiveHistory) = { + if (numInvalid != 0) { + val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + + s"Found $numInvalid invalid labels." + logError(msg) + throw new SparkException(msg) + } - val featuresMean = summarizer.mean.toArray - val featuresStd = summarizer.variance.toArray.map(math.sqrt) + if (numClasses > 2) { + val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " + + s"binary classification. Found $numClasses in the input dataset." + logError(msg) + throw new SparkException(msg) + } else if ($(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) { + logWarning(s"All labels are one and fitIntercept=true, so the coefficients will be " + + s"zeros and the intercept will be positive infinity; as a result, " + + s"training is not needed.") + (Vectors.sparse(numFeatures, Seq()), Double.PositiveInfinity, Array.empty[Double]) + } else if ($(fitIntercept) && numClasses == 1) { + logWarning(s"All labels are zero and fitIntercept=true, so the coefficients will be " + + s"zeros and the intercept will be negative infinity; as a result, " + + s"training is not needed.") + (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double]) + } else { + if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) { + logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " + + s"so the algorithm may not converge.") + } else if (!$(fitIntercept) && numClasses == 1) { + logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " + + s"so the algorithm may not converge.") + } - val regParamL1 = $(elasticNetParam) * $(regParam) - val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam) + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) - val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), $(standardization), - featuresStd, featuresMean, regParamL2) + val regParamL1 = $(elasticNetParam) * $(regParam) + val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam) - val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { - new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - } else { - def regParamL1Fun = (index: Int) => { - // Remove the L1 penalization on the intercept - if (index == numFeatures) { - 0.0 + val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), + $(standardization), featuresStd, featuresMean, regParamL2) + + val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { - if ($(standardization)) { - regParamL1 - } else { - // If `standardization` is false, we still standardize the data - // to improve the rate of convergence; as a result, we have to - // perform this reverse standardization by penalizing each component - // differently to get effectively the same objective function when - // the training dataset is not standardized. - if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0 + val standardizationParam = $(standardization) + def regParamL1Fun = (index: Int) => { + // Remove the L1 penalization on the intercept + if (index == numFeatures) { + 0.0 + } else { + if (standardizationParam) { + regParamL1 + } else { + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. + if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0 + } + } } + new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - } - new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) - } - val initialCoefficientsWithIntercept = - Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) - - if ($(fitIntercept)) { - /* - For binary logistic regression, when we initialize the coefficients as zeros, - it will converge faster if we initialize the intercept such that - it follows the distribution of the labels. - - {{{ - P(0) = 1 / (1 + \exp(b)), and - P(1) = \exp(b) / (1 + \exp(b)) - }}}, hence - {{{ - b = \log{P(1) / P(0)} = \log{count_1 / count_0} - }}} - */ - initialCoefficientsWithIntercept.toArray(numFeatures) - = math.log(histogram(1) / histogram(0)) - } + val initialCoefficientsWithIntercept = + Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) - val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficientsWithIntercept.toBreeze.toDenseVector) + if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) { + val vec = optInitialModel.get.coefficients + logWarning( + s"Initial coefficients provided $vec did not match the expected size $numFeatures") + } - val (coefficients, intercept, objectiveHistory) = { - /* - Note that in Logistic Regression, the objective history (loss + regularization) - is log-likelihood which is invariance under feature standardization. As a result, - the objective history from optimizer is the same as the one in the original space. - */ - val arrayBuilder = mutable.ArrayBuilder.make[Double] - var state: optimizer.State = null - while (states.hasNext) { - state = states.next() - arrayBuilder += state.adjustedValue - } + if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) { + val initialCoefficientsWithInterceptArray = initialCoefficientsWithIntercept.toArray + optInitialModel.get.coefficients.foreachActive { case (index, value) => + initialCoefficientsWithInterceptArray(index) = value + } + if ($(fitIntercept)) { + initialCoefficientsWithInterceptArray(numFeatures) == optInitialModel.get.intercept + } + } else if ($(fitIntercept)) { + /* + For binary logistic regression, when we initialize the coefficients as zeros, + it will converge faster if we initialize the intercept such that + it follows the distribution of the labels. + + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} + */ + initialCoefficientsWithIntercept.toArray(numFeatures) = math.log( + histogram(1) / histogram(0)) + } - if (state == null) { - val msg = s"${optimizer.getClass.getName} failed." - logError(msg) - throw new SparkException(msg) - } + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialCoefficientsWithIntercept.toBreeze.toDenseVector) + + /* + Note that in Logistic Regression, the objective history (loss + regularization) + is log-likelihood which is invariance under feature standardization. As a result, + the objective history from optimizer is the same as the one in the original space. + */ + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } - /* - The coefficients are trained in the scaled space; we're converting them back to - the original space. - Note that the intercept in scaled space and original space is the same; - as a result, no scaling is needed. - */ - val rawCoefficients = state.x.toArray.clone() - var i = 0 - while (i < numFeatures) { - rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } - i += 1 - } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + logError(msg) + throw new SparkException(msg) + } - if ($(fitIntercept)) { - (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, - arrayBuilder.result()) - } else { - (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) + /* + The coefficients are trained in the scaled space; we're converting them back to + the original space. + Note that the intercept in scaled space and original space is the same; + as a result, no scaling is needed. + */ + val rawCoefficients = state.x.toArray.clone() + var i = 0 + while (i < numFeatures) { + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + i += 1 + } + + if ($(fitIntercept)) { + (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, + arrayBuilder.result()) + } else { + (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) + } } } if (handlePersistence) instances.unpersist() val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) + val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol() val logRegSummary = new BinaryLogisticRegressionTrainingSummary( - model.transform(dataset), - $(probabilityCol), + summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(logRegSummary) + val m = model.setSummary(logRegSummary) + instr.logSuccess(m) + m } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } +@Since("1.6.0") +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + @Since("1.6.0") + override def load(path: String): LogisticRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[LogisticRegression]]. */ +@Since("1.4.0") @Experimental -class LogisticRegressionModel private[ml] ( - override val uid: String, - val coefficients: Vector, - val intercept: Double) +class LogisticRegressionModel private[spark] ( + @Since("1.4.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.3.0") val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams { + with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients + @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + @Since("1.5.0") override def getThreshold: Double = super.getThreshold + @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds /** Margin (rawPrediction) for class label 1. For binary classification only. */ @@ -420,8 +506,10 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-m)) } + @Since("1.6.0") override val numFeatures: Int = coefficients.size + @Since("1.3.0") override val numClasses: Int = 2 private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None @@ -430,12 +518,24 @@ class LogisticRegressionModel private[ml] ( * Gets summary of model on training set. An exception is * thrown if `trainingSummary == None`. */ - def summary: LogisticRegressionTrainingSummary = trainingSummary match { - case Some(summ) => summ - case None => - throw new SparkException( - "No training summary available for this LogisticRegressionModel", - new NullPointerException()) + @Since("1.5.0") + def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse { + throw new SparkException("No training summary available for this LogisticRegressionModel") + } + + /** + * If the probability column is set returns the current model and probability column, + * otherwise generates a new column and sets it as the probability column on a new copy + * of the current model. + */ + private[classification] def findSummaryModelAndProbabilityCol(): + (LogisticRegressionModel, String) = { + $(probabilityCol) match { + case "" => + val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString + (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) + case p => (this, p) + } } private[classification] def setSummary( @@ -445,16 +545,19 @@ class LogisticRegressionModel private[ml] ( } /** Indicates whether a training summary exists for this model instance. */ + @Since("1.5.0") def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary( - this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) + @Since("2.0.0") + def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = { + // Handle possible missing or invalid prediction columns + val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() + new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol)) } /** @@ -487,6 +590,7 @@ class LogisticRegressionModel private[ml] ( Vectors.dense(-m, m) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) @@ -510,8 +614,77 @@ class LogisticRegressionModel private[ml] ( // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (probability(1) > getThreshold) 1 else 0 } + + /** + * Returns a [[MLWriter]] instance for this ML instance. + * + * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } + +@Since("1.6.0") +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LogisticRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ + private[LogisticRegressionModel] + class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + extends MLWriter with Logging { + + private case class Data( + numClasses: Int, + numFeatures: Int, + intercept: Double, + coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: numClasses, numFeatures, intercept, coefficients + val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, + instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LogisticRegressionModelReader + extends MLReader[LogisticRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[LogisticRegressionModel].getName + + override def load(path: String): LogisticRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("numClasses", "numFeatures", "intercept", "coefficients").head() + // We will need numClasses, numFeatures in the future for multinomial logreg support. + // val numClasses = data.getInt(0) + // val numFeatures = data.getInt(1) + val intercept = data.getDouble(2) + val coefficients = data.getAs[Vector](3) + val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + + /** * MultiClassSummarizer computes the number of distinct labels and corresponding counts, * and validates the data to see if the labels used for k class multi-label classification @@ -533,7 +706,7 @@ private[classification] class MultiClassSummarizer extends Serializable { * @return This MultilabelSummarizer */ def add(label: Double, weight: Double = 1.0): this.type = { - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -610,13 +783,13 @@ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary */ sealed trait LogisticRegressionSummary extends Serializable { - /** Dataframe outputted by the model's `transform` method. */ + /** Dataframe output by the model's `transform` method. */ def predictions: DataFrame - /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ + /** Field in "predictions" which gives the probability of each class as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the true label of each instance. */ + /** Field in "predictions" which gives the true label of each instance (if available). */ def labelCol: String /** Field in "predictions" which gives the features of each instance as a vector. */ @@ -627,20 +800,22 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * :: Experimental :: * Logistic regression training results. - * @param predictions dataframe outputted by the model's `transform` method. - * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each instance as a vector. + * + * @param predictions dataframe output by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the probability of + * each class as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental +@Since("1.5.0") class BinaryLogisticRegressionTrainingSummary private[classification] ( - predictions: DataFrame, - probabilityCol: String, - labelCol: String, - featuresCol: String, - val objectiveHistory: Array[Double]) + @Since("1.5.0") predictions: DataFrame, + @Since("1.5.0") probabilityCol: String, + @Since("1.5.0") labelCol: String, + @Since("1.6.0") featuresCol: String, + @Since("1.5.0") val objectiveHistory: Array[Double]) extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) with LogisticRegressionTrainingSummary { @@ -649,18 +824,21 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( /** * :: Experimental :: * Binary Logistic regression results for a given model. - * @param predictions dataframe outputted by the model's `transform` method. - * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each instance. + * + * @param predictions dataframe output by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the probability of + * each class as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ @Experimental +@Since("1.5.0") class BinaryLogisticRegressionSummary private[classification] ( - @transient override val predictions: DataFrame, - override val probabilityCol: String, - override val labelCol: String, - override val featuresCol: String) extends LogisticRegressionSummary { + @Since("1.5.0") @transient override val predictions: DataFrame, + @Since("1.5.0") override val probabilityCol: String, + @Since("1.5.0") override val labelCol: String, + @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary { + private val sqlContext = predictions.sqlContext import sqlContext.implicits._ @@ -671,7 +849,7 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. @transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).map { + predictions.select(probabilityCol, labelCol).rdd.map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) @@ -680,24 +858,40 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the receiver operating characteristic (ROC) curve, * which is an Dataframe having two fields (FPR, TPR) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ + @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** * Computes the area under the receiver operating characteristic (ROC) curve. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** * Returns the precision-recall curve, which is an Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") } @@ -706,7 +900,11 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, precision) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { binaryMetrics.precisionByThreshold().toDF("threshold", "precision") } @@ -715,7 +913,11 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, recall) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { binaryMetrics.recallByThreshold().toDF("threshold", "recall") } @@ -769,7 +971,7 @@ private class LogisticAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new instance." + s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index cd7462596dd9e..9ff5252e4ff37 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -19,30 +19,32 @@ package org.apache.spark.ml.classification import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} -import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} -import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** Params for Multilayer Perceptron. */ private[ml] trait MultilayerPerceptronParams extends PredictorParams - with HasSeed with HasMaxIter with HasTol { + with HasSeed with HasMaxIter with HasTol with HasStepSize { /** * Layer sizes including input size and output size. * Default: Array(1, 1) - * @group param + * + * @group param */ final val layers: IntArrayParam = new IntArrayParam(this, "layers", "Sizes of layers from input layer to output layer" + " E.g., Array(780, 100, 10) means 780 inputs, " + "one hidden layer with 100 neurons and output layer of 10 neurons.", - // TODO: how to check ALSO that all elements are greater than 0? - ParamValidators.arrayLengthGt(1) + (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1 ) /** @group getParam */ @@ -54,7 +56,8 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams * a partition then it is adjusted to the size of this data. * Recommended size is between 10 and 1000. * Default: 128 - * @group expertParam + * + * @group expertParam */ final val blockSize: IntParam = new IntParam(this, "blockSize", "Block size for stacking input data in matrices. Data is stacked within partitions." + @@ -65,7 +68,33 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams /** @group getParam */ final def getBlockSize: Int = $(blockSize) - setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) + /** + * Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. + * l-bfgs is the default one. + * + * @group expertParam + */ + final val solver: Param[String] = new Param[String](this, "solver", + " Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. " + + " l-bfgs is the default one.", + ParamValidators.inArray[String](Array("gd", "l-bfgs"))) + + /** @group getParam */ + final def getOptimizer: String = $(solver) + + /** + * Model weights. Can be returned either after training or after explicit setting + * + * @group expertParam + */ + final val weights: Param[Vector] = new Param[Vector](this, "weights", + " Sets the weights of the model ") + + /** @group getParam */ + final def getWeights: Vector = $(weights) + + + setDefault(maxIter -> 100, tol -> 1e-4, blockSize -> 128, solver -> "l-bfgs", stepSize -> 0.03) } /** Label to vector converter. */ @@ -106,40 +135,60 @@ private object LabelConverter { * Number of outputs has to be equal to the total number of labels. * */ +@Since("1.5.0") @Experimental -class MultilayerPerceptronClassifier(override val uid: String) +class MultilayerPerceptronClassifier @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] - with MultilayerPerceptronParams { + with MultilayerPerceptronParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("mlpc")) /** @group setParam */ + @Since("1.5.0") def setLayers(value: Array[Int]): this.type = set(layers, value) /** @group setParam */ + @Since("1.5.0") def setBlockSize(value: Int): this.type = set(blockSize, value) /** * Set the maximum number of iterations. * Default is 100. - * @group setParam + * + * @group setParam */ + @Since("1.5.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** * Set the convergence tolerance of iterations. * Smaller value will lead to higher accuracy with the cost of more iterations. * Default is 1E-4. - * @group setParam + * + * @group setParam */ + @Since("1.5.0") def setTol(value: Double): this.type = set(tol, value) /** - * Set the seed for weights initialization. - * @group setParam + * Set the seed for weights initialization if weights are not set + * + * @group setParam */ + @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + /** + * Sets the model weights. + * + * @group expertParam + */ + @Since("2.0.0") + def setWeights(value: Vector): this.type = set(weights, value) + + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) /** @@ -150,40 +199,58 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { + override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) - val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) - FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) - FeedForwardTrainer.setStackSize($(blockSize)) - val mlpModel = FeedForwardTrainer.train(data) - new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights()) + val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + if (isDefined(weights)) { + trainer.setWeights($(weights)) + } else { + trainer.setSeed($(seed)) + } + trainer.LBFGSOptimizer + .setConvergenceTol($(tol)) + .setNumIterations($(maxIter)) + trainer.setStackSize($(blockSize)) + val mlpModel = trainer.train(data) + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) } } +@Since("2.0.0") +object MultilayerPerceptronClassifier + extends DefaultParamsReadable[MultilayerPerceptronClassifier] { + + @Since("2.0.0") + override def load(path: String): MultilayerPerceptronClassifier = super.load(path) +} + /** * :: Experimental :: * Classification model based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. - * @param uid uid + * + * @param uid uid * @param layers array of layer sizes including input and output layers * @param weights vector of initial weights for the model that consists of the weights of layers * @return prediction model */ +@Since("1.5.0") @Experimental class MultilayerPerceptronClassificationModel private[ml] ( - override val uid: String, - val layers: Array[Int], - val weights: Vector) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val layers: Array[Int], + @Since("1.5.0") val weights: Vector) extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] - with Serializable { + with Serializable with MLWritable { + @Since("1.6.0") override val numFeatures: Int = layers.head - private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).model(weights) /** * Returns layers in a Java List. @@ -200,7 +267,61 @@ class MultilayerPerceptronClassificationModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } + + @Since("2.0.0") + override def write: MLWriter = + new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this) +} + +@Since("2.0.0") +object MultilayerPerceptronClassificationModel + extends MLReadable[MultilayerPerceptronClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[MultilayerPerceptronClassificationModel] = + new MultilayerPerceptronClassificationModelReader + + @Since("2.0.0") + override def load(path: String): MultilayerPerceptronClassificationModel = super.load(path) + + /** [[MLWriter]] instance for [[MultilayerPerceptronClassificationModel]] */ + private[MultilayerPerceptronClassificationModel] + class MultilayerPerceptronClassificationModelWriter( + instance: MultilayerPerceptronClassificationModel) extends MLWriter { + + private case class Data(layers: Array[Int], weights: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: layers, weights + val data = Data(instance.layers, instance.weights) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MultilayerPerceptronClassificationModelReader + extends MLReader[MultilayerPerceptronClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[MultilayerPerceptronClassificationModel].getName + + override def load(path: String): MultilayerPerceptronClassificationModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head() + val layers = data.getAs[Seq[Int]](0).toArray + val weights = data.getAs[Vector](1) + val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index a14dcecbaf5b9..267d63b51eb6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -17,16 +17,19 @@ package org.apache.spark.ml.classification +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** * Params for Naive Bayes Classifiers. @@ -69,11 +72,14 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). * The input feature values must be nonnegative. */ +@Since("1.5.0") @Experimental -class NaiveBayes(override val uid: String) +class NaiveBayes @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] - with NaiveBayesParams { + with NaiveBayesParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) /** @@ -81,6 +87,7 @@ class NaiveBayes(override val uid: String) * Default is 1.0. * @group setParam */ + @Since("1.5.0") def setSmoothing(value: Double): this.type = set(smoothing, value) setDefault(smoothing -> 1.0) @@ -90,18 +97,27 @@ class NaiveBayes(override val uid: String) * Default is "multinomial" * @group setParam */ + @Since("1.5.0") def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) - override protected def train(dataset: DataFrame): NaiveBayesModel = { + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } +@Since("1.6.0") +object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + + @Since("1.6.0") + override def load(path: String): NaiveBayes = super.load(path) +} + /** * :: Experimental :: * Model produced by [[NaiveBayes]] @@ -109,12 +125,14 @@ class NaiveBayes(override val uid: String) * @param theta log of class conditional probabilities, whose dimension is C (number of classes) * by D (number of features) */ +@Since("1.5.0") @Experimental class NaiveBayesModel private[ml] ( - override val uid: String, - val pi: Vector, - val theta: Matrix) - extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val pi: Vector, + @Since("1.5.0") val theta: Matrix) + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] + with NaiveBayesParams with MLWritable { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -127,7 +145,7 @@ class NaiveBayesModel private[ml] ( case Multinomial => (None, None) case Bernoulli => val negTheta = theta.map(value => math.log(1.0 - math.exp(value))) - val ones = new DenseVector(Array.fill(theta.numCols){1.0}) + val ones = new DenseVector(Array.fill(theta.numCols) {1.0}) val thetaMinusNegTheta = theta.map { value => value - math.log(1.0 - math.exp(value)) } @@ -137,8 +155,10 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } + @Since("1.6.0") override val numFeatures: Int = theta.numCols + @Since("1.5.0") override val numClasses: Int = pi.size private def multinomialCalculation(features: Vector) = { @@ -195,20 +215,25 @@ class NaiveBayesModel private[ml] ( } } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } + @Since("1.5.0") override def toString: String = { s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } + @Since("1.6.0") + override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this) } -private[ml] object NaiveBayesModel { +@Since("1.6.0") +object NaiveBayesModel extends MLReadable[NaiveBayesModel] { /** Convert a model from the old API */ - def fromOld( + private[ml] def fromOld( oldModel: OldNaiveBayesModel, parent: NaiveBayes): NaiveBayesModel = { val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") @@ -218,4 +243,44 @@ private[ml] object NaiveBayesModel { oldModel.theta.flatten, true) new NaiveBayesModel(uid, pi, theta) } + + @Since("1.6.0") + override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader + + @Since("1.6.0") + override def load(path: String): NaiveBayesModel = super.load(path) + + /** [[MLWriter]] instance for [[NaiveBayesModel]] */ + private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter { + + private case class Data(pi: Vector, theta: Matrix) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: pi, theta + val data = Data(instance.pi, instance.theta) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[NaiveBayesModel].getName + + override def load(path: String): NaiveBayesModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) + val model = new NaiveBayesModel(metadata.uid, pi, theta) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index debc164bf2432..4de1b877b0194 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,22 +21,24 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JObject, _} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -/** - * Params for [[OneVsRest]]. - */ -private[ml] trait OneVsRestParams extends PredictorParams { - +private[ml] trait ClassifierTypeTrait { // scalastyle:off structural.type type ClassifierType = Classifier[F, E, M] forSome { type F @@ -44,6 +46,12 @@ private[ml] trait OneVsRestParams extends PredictorParams { type E <: Classifier[F, E, M] } // scalastyle:on structural.type +} + +/** + * Params for [[OneVsRest]]. + */ +private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait { /** * param for the base binary classifier that we reduce multiclass classification into. @@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams { def getClassifier: ClassifierType = $(classifier) } +private[ml] object OneVsRestParams extends ClassifierTypeTrait { + + def validateParams(instance: OneVsRestParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("OneVsRest write will fail " + + s" because it contains $name which does not implement MLWritable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + + instance match { + case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model")) + case _ => // no need to check OneVsRest here + } + + checkElement(instance.getClassifier, "classifier") + } + + def saveImpl( + path: String, + instance: OneVsRestParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + + val params = instance.extractParamMap().toSeq + val jsonParams = render(params + .filter { case ParamPair(p, v) => p.name != "classifier" } + .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) } + .toList) + + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val classifierPath = new Path(path, "classifier").toString + instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath) + } + + def loadImpl( + path: String, + sc: SparkContext, + expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + val classifierPath = new Path(path, "classifier").toString + val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc) + (metadata, estimator) + } +} + /** * :: Experimental :: * Model produced by [[OneVsRest]]. @@ -70,18 +127,21 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). */ +@Since("1.4.0") @Experimental final class OneVsRestModel private[ml] ( - override val uid: String, - labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_, _]]) - extends Model[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") override val uid: String, + private[ml] val labelMetadata: Metadata, + @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) + extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -110,13 +170,13 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } - val transformedDataset = model.transform(df).select(columns : _*) + val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column - updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName) + updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName) } if (handlePersistence) { @@ -134,11 +194,62 @@ final class OneVsRestModel private[ml] ( .drop(accColName) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) copyValues(copied, extra).setParent(parent) } + + @Since("2.0.0") + override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this) +} + +@Since("2.0.0") +object OneVsRestModel extends MLReadable[OneVsRestModel] { + + @Since("2.0.0") + override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader + + @Since("2.0.0") + override def load(path: String): OneVsRestModel = super.load(path) + + /** [[MLWriter]] instance for [[OneVsRestModel]] */ + private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter { + + OneVsRestParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~ + ("numClasses" -> instance.models.length) + OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson)) + instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) => + val modelPath = new Path(path, s"model_$idx").toString + model.save(modelPath) + } + } + } + + private class OneVsRestModelReader extends MLReader[OneVsRestModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[OneVsRestModel].getName + + override def load(path: String): OneVsRestModel = { + implicit val format = DefaultFormats + val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) + val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String]) + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val models = Range(0, numClasses).toArray.map { idx => + val modelPath = new Path(path, s"model_$idx").toString + DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) + } + val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) + DefaultParamsReader.getAndSetParams(ovrModel, metadata) + ovrModel.set("classifier", classifier) + ovrModel + } + } } /** @@ -150,35 +261,46 @@ final class OneVsRestModel private[ml] ( * Each example is scored against all k models and the model with highest score * is picked to label the example. */ +@Since("1.4.0") @Experimental -final class OneVsRest(override val uid: String) - extends Estimator[OneVsRestModel] with OneVsRestParams { +final class OneVsRest @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("oneVsRest")) /** @group setParam */ + @Since("1.4.0") def setClassifier(value: Classifier[_, _, _]): this.type = { set(classifier, value.asInstanceOf[ClassifierType]) } /** @group setParam */ + @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } - override def fit(dataset: DataFrame): OneVsRestModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): OneVsRestModel = { + transformSchema(dataset.schema) + // determine number of classes either from metadata if provided, or via computation. val labelSchema = dataset.schema($(labelCol)) val computeNumClasses: () => Int = () => { - val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head() + val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head() // classes are assumed to be numbered from 0,...,maxLabelIndex maxLabelIndex.toInt + 1 } @@ -222,6 +344,7 @@ final class OneVsRest(override val uid: String) copyValues(model) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRest = { val copied = defaultCopy(extra).asInstanceOf[OneVsRest] if (isDefined(classifier)) { @@ -229,4 +352,40 @@ final class OneVsRest(override val uid: String) } copied } + + @Since("2.0.0") + override def write: MLWriter = new OneVsRest.OneVsRestWriter(this) +} + +@Since("2.0.0") +object OneVsRest extends MLReadable[OneVsRest] { + + @Since("2.0.0") + override def read: MLReader[OneVsRest] = new OneVsRestReader + + @Since("2.0.0") + override def load(path: String): OneVsRest = super.load(path) + + /** [[MLWriter]] instance for [[OneVsRest]] */ + private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter { + + OneVsRestParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + OneVsRestParams.saveImpl(path, instance, sc) + } + } + + private class OneVsRestReader extends MLReader[OneVsRest] { + + /** Checked against metadata when loading model */ + private val className = classOf[OneVsRest].getName + + override def load(path: String): OneVsRest = { + val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) + val ovr = new OneVsRest(metadata.uid) + DefaultParamsReader.getAndSetParams(ovr, metadata) + ovr.setClassifier(classifier) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index fdd1851ae5508..d00fee12b08c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,8 +20,8 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors} -import org.apache.spark.sql.DataFrame +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[ * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[ this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bae329692a68d..dfa711b2436cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -17,17 +17,21 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.tree.impl.RandomForest +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.RandomForest +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -38,48 +42,63 @@ import org.apache.spark.sql.functions._ * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ +@Since("1.4.0") @Experimental -final class RandomForestClassifier(override val uid: String) +final class RandomForestClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] - with RandomForestParams with TreeClassifierParams { + with RandomForestClassifierParams with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("rfc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from RandomForestParams: + @Since("1.4.0") override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + @Since("1.4.0") override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestClassificationModel = { + override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -99,17 +118,24 @@ final class RandomForestClassifier(override val uid: String) new RandomForestClassificationModel(trees, numFeatures, numClasses) } + @Since("1.4.1") override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental -object RandomForestClassifier { +object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] { /** Accessor for supported impurity settings: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies + + @Since("2.0.0") + override def load(path: String): RandomForestClassifier = super.load(path) } /** @@ -117,22 +143,26 @@ object RandomForestClassifier { * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. + * * @param _trees Decision trees in the ensemble. - * Warning: These have null parents. + * Warning: These have null parents. */ +@Since("1.4.0") @Experimental final class RandomForestClassificationModel private[ml] ( - override val uid: String, + @Since("1.5.0") override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - override val numFeatures: Int, - override val numClasses: Int) + @Since("1.6.0") override val numFeatures: Int, + @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] - with TreeEnsembleModel with Serializable { + with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel] + with MLWritable with Serializable { - require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.") /** * Construct a random forest classification model, with all trees weighted equally. + * * @param trees Component trees */ private[ml] def this( @@ -141,14 +171,16 @@ final class RandomForestClassificationModel private[ml] ( numClasses: Int) = this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + @Since("1.4.0") + override def trees: Array[DecisionTreeClassificationModel] = _trees // Note: We may add support for weights (based on tree performance) later on. - private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -186,42 +218,106 @@ final class RandomForestClassificationModel private[ml] ( } } + /** + * Number of trees in ensemble + * + * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0 + */ + // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams + @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0") + val numTrees: Int = trees.length + + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) .setParent(parent) } + @Since("1.4.0") override def toString: String = { - s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" + s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees" } /** * Estimate of the importance of each feature. * - * This generalizes the idea of "Gini" importance to other losses, - * following the explanation of Gini importance from "Random Forests" documentation - * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * Each feature's importance is the average of its importance across all trees in the ensemble + * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + * and follows the implementation from scikit-learn. * - * This feature importance is calculated as follows: - * - Average over trees: - * - importance(feature j) = sum (over nodes which split on feature j) of the gain, - * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree based on total number of training instances used - * to build tree. - * - Normalize feature importance vector to sum to 1. + * @see [[DecisionTreeClassificationModel.featureImportances]] */ - lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + @Since("1.5.0") + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) } + + @Since("2.0.0") + override def write: MLWriter = + new RandomForestClassificationModel.RandomForestClassificationModelWriter(this) } -private[ml] object RandomForestClassificationModel { +@Since("2.0.0") +object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[RandomForestClassificationModel] = + new RandomForestClassificationModelReader + + @Since("2.0.0") + override def load(path: String): RandomForestClassificationModel = super.load(path) + + private[RandomForestClassificationModel] + class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Note: numTrees is not currently used, but could be nice to store for fast querying. + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numClasses" -> instance.numClasses, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class RandomForestClassificationModelReader + extends MLReader[RandomForestClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RandomForestClassificationModel].getName + private val treeClassName = classOf[DecisionTreeClassificationModel].getName + + override def load(path: String): RandomForestClassificationModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeClassificationModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldRandomForestModel, parent: RandomForestClassifier, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala new file mode 100644 index 0000000000000..6cc9117da3fea --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering. + {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} + + +/** + * Common params for BisectingKMeans and BisectingKMeansModel + */ +private[clustering] trait BisectingKMeansParams extends Params + with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.0.0") + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + @Since("2.0.0") + def getK: Int = $(k) + + /** @group expertParam */ + @Since("2.0.0") + final val minDivisibleClusterSize = new DoubleParam( + this, + "minDivisibleClusterSize", + "the minimum number of points (if >= 1.0) or the minimum proportion", + (value: Double) => value > 0) + + /** @group expertGetParam */ + @Since("2.0.0") + def getMinDivisibleClusterSize: Double = $(minDivisibleClusterSize) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Model fitted by BisectingKMeans. + * + * @param parentModel a model trained by spark.mllib.clustering.BisectingKMeans. + */ +@Since("2.0.0") +@Experimental +class BisectingKMeansModel private[ml] ( + @Since("2.0.0") override val uid: String, + private val parentModel: MLlibBisectingKMeansModel + ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable { + + @Since("2.0.0") + override def copy(extra: ParamMap): BisectingKMeansModel = { + val copied = new BisectingKMeansModel(uid, parentModel) + copyValues(copied, extra) + } + + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + val predictUDF = udf((vector: Vector) => predict(vector)) + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + @Since("2.0.0") + def clusterCenters: Array[Vector] = parentModel.clusterCenters + + /** + * Computes the sum of squared distances between the input points and their corresponding cluster + * centers. + */ + @Since("2.0.0") + def computeCost(dataset: Dataset[_]): Double = { + SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + parentModel.computeCost(data) + } + + @Since("2.0.0") + override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) +} + +object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { + @Since("2.0.0") + override def read: MLReader[BisectingKMeansModel] = new BisectingKMeansModelReader + + @Since("2.0.0") + override def load(path: String): BisectingKMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[BisectingKMeansModel]] */ + private[BisectingKMeansModel] + class BisectingKMeansModelWriter(instance: BisectingKMeansModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + val dataPath = new Path(path, "data").toString + instance.parentModel.save(sc, dataPath) + } + } + + private class BisectingKMeansModelReader extends MLReader[BisectingKMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[BisectingKMeansModel].getName + + override def load(path: String): BisectingKMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) + val model = new BisectingKMeansModel(metadata.uid, mllibModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +/** + * :: Experimental :: + * + * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" + * by Steinbach, Karypis, and Kumar, with modification to fit Spark. + * The algorithm starts from a single cluster that contains all points. + * Iteratively it finds divisible clusters on the bottom level and bisects each of them using + * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + * The bisecting steps of clusters on the same level are grouped together to increase parallelism. + * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, + * larger clusters get higher priority. + * + * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf + * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, + * KDD Workshop on Text Mining, 2000.]] + */ +@Since("2.0.0") +@Experimental +class BisectingKMeans @Since("2.0.0") ( + @Since("2.0.0") override val uid: String) + extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with DefaultParamsWritable { + + setDefault( + k -> 4, + maxIter -> 20, + minDivisibleClusterSize -> 1.0) + + @Since("2.0.0") + override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra) + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("bisecting-kmeans")) + + /** @group setParam */ + @Since("2.0.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** @group expertSetParam */ + @Since("2.0.0") + def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) + + @Since("2.0.0") + override def fit(dataset: Dataset[_]): BisectingKMeansModel = { + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + + val bkm = new MLlibBisectingKMeans() + .setK($(k)) + .setMaxIterations($(maxIter)) + .setMinDivisibleClusterSize($(minDivisibleClusterSize)) + .setSeed($(seed)) + val parentModel = bkm.run(rdd) + val model = new BisectingKMeansModel(uid, parentModel) + copyValues(model.setParent(this)) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + + +@Since("2.0.0") +object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { + + @Since("2.0.0") + override def load(path: String): BisectingKMeans = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala new file mode 100644 index 0000000000000..ead8ad7806290 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} + + +/** + * Common params for GaussianMixture and GaussianMixtureModel + */ +private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.0.0") + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + @Since("2.0.0") + def getK: Int = $(k) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + } +} + +/** + * :: Experimental :: + * Model fitted by GaussianMixture. + * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture. + */ +@Since("2.0.0") +@Experimental +class GaussianMixtureModel private[ml] ( + @Since("2.0.0") override val uid: String, + private val parentModel: MLlibGMModel) + extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + + @Since("2.0.0") + override def copy(extra: ParamMap): GaussianMixtureModel = { + val copied = new GaussianMixtureModel(uid, parentModel) + copyValues(copied, extra).setParent(this.parent) + } + + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + val predUDF = udf((vector: Vector) => predict(vector)) + val probUDF = udf((vector: Vector) => predictProbability(vector)) + dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) + .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + private[clustering] def predictProbability(features: Vector): Vector = { + Vectors.dense(parentModel.predictSoft(features)) + } + + @Since("2.0.0") + def weights: Array[Double] = parentModel.weights + + @Since("2.0.0") + def gaussians: Array[MultivariateGaussian] = parentModel.gaussians + + @Since("2.0.0") + override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) + + private var trainingSummary: Option[GaussianMixtureSummary] = None + + private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Return true if there exists summary of model. + */ + @Since("2.0.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.0.0") + def summary: GaussianMixtureSummary = trainingSummary.getOrElse { + throw new RuntimeException( + s"No training summary available for the ${this.getClass.getSimpleName}") + } +} + +@Since("2.0.0") +object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { + + @Since("2.0.0") + override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader + + @Since("2.0.0") + override def load(path: String): GaussianMixtureModel = super.load(path) + + /** [[MLWriter]] instance for [[GaussianMixtureModel]] */ + private[GaussianMixtureModel] class GaussianMixtureModelWriter( + instance: GaussianMixtureModel) extends MLWriter { + + private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: weights and gaussians + val weights = instance.weights + val gaussians = instance.gaussians + val mus = gaussians.map(_.mu) + val sigmas = gaussians.map(_.sigma) + val data = Data(weights, mus, sigmas) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class GaussianMixtureModelReader extends MLReader[GaussianMixtureModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GaussianMixtureModel].getName + + override def load(path: String): GaussianMixtureModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head() + val weights = row.getSeq[Double](0).toArray + val mus = row.getSeq[Vector](1).toArray + val sigmas = row.getSeq[Matrix](2).toArray + require(mus.length == sigmas.length, "Length of Mu and Sigma array must match") + require(mus.length == weights.length, "Length of weight and Gaussian array must match") + + val gaussians = (mus zip sigmas).map { + case (mu, sigma) => + new MultivariateGaussian(mu, sigma) + } + val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +/** + * :: Experimental :: + * GaussianMixture clustering. + */ +@Since("2.0.0") +@Experimental +class GaussianMixture @Since("2.0.0") ( + @Since("2.0.0") override val uid: String) + extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 100, + tol -> 0.01) + + @Since("2.0.0") + override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra) + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("GaussianMixture")) + + /** @group setParam */ + @Since("2.0.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setProbabilityCol(value: String): this.type = set(probabilityCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.0.0") + override def fit(dataset: Dataset[_]): GaussianMixtureModel = { + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + + val algo = new MLlibGM() + .setK($(k)) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setConvergenceTol($(tol)) + val parentModel = algo.run(rdd) + val model = copyValues(new GaussianMixtureModel(uid, parentModel).setParent(this)) + val summary = new GaussianMixtureSummary(model.transform(dataset), + $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) + model.setSummary(summary) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + +@Since("2.0.0") +object GaussianMixture extends DefaultParamsReadable[GaussianMixture] { + + @Since("2.0.0") + override def load(path: String): GaussianMixture = super.load(path) +} + +/** + * :: Experimental :: + * Summary of GaussianMixture. + * + * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param probabilityCol Name for column of predicted probability of each cluster in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.0.0") +@Experimental +class GaussianMixtureSummary private[clustering] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val probabilityCol: String, + @Since("2.0.0") val featuresCol: String, + @Since("2.0.0") val k: Int) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.0.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Probability of each cluster. + */ + @Since("2.0.0") + @transient lazy val probability: DataFrame = predictions.select(probabilityCol) + + /** + * Size of (number of data points in) each cluster. + */ + @Since("2.0.0") + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 509be63002396..b324196842a41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,17 +17,19 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} -import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} -import org.apache.spark.sql.{DataFrame, Row} - /** * Common params for KMeans and KMeansModel @@ -94,7 +96,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Experimental class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + private val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with MLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -102,8 +105,8 @@ class KMeansModel private[ml] ( copyValues(copied, extra) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -123,12 +126,81 @@ class KMeansModel private[ml] ( * model on the given data. */ // TODO: Replace the temp fix when we have proper evaluators defined for clustering. - @Since("1.6.0") - def computeCost(dataset: DataFrame): Double = { + @Since("2.0.0") + def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) } + + @Since("1.6.0") + override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + + private var trainingSummary: Option[KMeansSummary] = None + + private[clustering] def setSummary(summary: KMeansSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Return true if there exists summary of model. + */ + @Since("2.0.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.0.0") + def summary: KMeansSummary = trainingSummary.getOrElse { + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}") + } +} + +@Since("1.6.0") +object KMeansModel extends MLReadable[KMeansModel] { + + @Since("1.6.0") + override def read: MLReader[KMeansModel] = new KMeansModelReader + + @Since("1.6.0") + override def load(path: String): KMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[KMeansModel]] */ + private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + + private case class Data(clusterCenters: Array[Vector]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data = Data(instance.clusterCenters) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class KMeansModelReader extends MLReader[KMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[KMeansModel].getName + + override def load(path: String): KMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() + val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -141,7 +213,7 @@ class KMeansModel private[ml] ( @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams { + extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { setDefault( k -> 2, @@ -188,9 +260,9 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.5.0") - override def fit(dataset: DataFrame): KMeansModel = { - val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } + @Since("2.0.0") + override def fit(dataset: Dataset[_]): KMeansModel = { + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibKMeans() .setK($(k)) @@ -200,8 +272,10 @@ class KMeans @Since("1.5.0") ( .setSeed($(seed)) .setEpsilon($(tol)) val parentModel = algo.run(rdd) - val model = new KMeansModel(uid, parentModel) - copyValues(model) + val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) + val summary = new KMeansSummary( + model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.setSummary(summary) } @Since("1.5.0") @@ -210,3 +284,46 @@ class KMeans @Since("1.5.0") ( } } +@Since("1.6.0") +object KMeans extends DefaultParamsReadable[KMeans] { + + @Since("1.6.0") + override def load(path: String): KMeans = super.load(path) +} + +/** + * :: Experimental :: + * Summary of KMeans. + * + * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.0.0") +@Experimental +class KMeansSummary private[clustering] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val featuresCol: String, + @Since("2.0.0") val k: Int) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.0.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Size of (number of data points in) each cluster. + */ + @Since("2.0.0") + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala new file mode 100644 index 0000000000000..c57ceba4a9977 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -0,0 +1,888 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.internal.Logging +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, + EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, + LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, + OnlineLDAOptimizer => OldOnlineLDAOptimizer} +import org.apache.spark.mllib.impl.PeriodicCheckpointer +import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} +import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} +import org.apache.spark.sql.types.StructType + + +private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter + with HasSeed with HasCheckpointInterval { + + /** + * Param for the number of topics (clusters) to infer. Must be > 1. Default: 10. + * + * @group param + */ + @Since("1.6.0") + final val k = new IntParam(this, "k", "number of topics (clusters) to infer", + ParamValidators.gt(1)) + + /** @group getParam */ + @Since("1.6.0") + def getK: Int = $(k) + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing + * (more regularization). + * + * If not set by the user, then docConcentration is set automatically. If set to + * singleton vector [alpha], then alpha is replicated to a vector of length k in fitting. + * Otherwise, the [[docConcentration]] vector must be length k. + * (default = automatic) + * + * Optimizer-specific parameter settings: + * - EM + * - Currently only supports symmetric distributions, so all values in the vector should be + * the same. + * - Values should be > 1.0 + * - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows + * from Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Values should be >= 0 + * - default = uniformly (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. + * @group param + */ + @Since("1.6.0") + final val docConcentration = new DoubleArrayParam(this, "docConcentration", + "Concentration parameter (commonly named \"alpha\") for the prior placed on documents'" + + " distributions over topics (\"theta\").", (alpha: Array[Double]) => alpha.forall(_ >= 0.0)) + + /** @group getParam */ + @Since("1.6.0") + def getDocConcentration: Array[Double] = $(docConcentration) + + /** Get docConcentration used by spark.mllib LDA */ + protected def getOldDocConcentration: Vector = { + if (isSet(docConcentration)) { + Vectors.dense(getDocConcentration) + } else { + Vectors.dense(-1.0) + } + } + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * + * If not set by the user, then topicConcentration is set automatically. + * (default = automatic) + * + * Optimizer-specific parameter settings: + * - EM + * - Value should be > 1.0 + * - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows + * Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Value should be >= 0 + * - default = (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. + * @group param + */ + @Since("1.6.0") + final val topicConcentration = new DoubleParam(this, "topicConcentration", + "Concentration parameter (commonly named \"beta\" or \"eta\") for the prior placed on topic'" + + " distributions over terms.", ParamValidators.gtEq(0)) + + /** @group getParam */ + @Since("1.6.0") + def getTopicConcentration: Double = $(topicConcentration) + + /** Get topicConcentration used by spark.mllib LDA */ + protected def getOldTopicConcentration: Double = { + if (isSet(topicConcentration)) { + getTopicConcentration + } else { + -1.0 + } + } + + /** Supported values for Param [[optimizer]]. */ + @Since("1.6.0") + final val supportedOptimizers: Array[String] = Array("online", "em") + + /** + * Optimizer or inference algorithm used to estimate the LDA model. + * Currently supported (case-insensitive): + * - "online": Online Variational Bayes (default) + * - "em": Expectation-Maximization + * + * For details, see the following papers: + * - Online LDA: + * Hoffman, Blei and Bach. "Online Learning for Latent Dirichlet Allocation." + * Neural Information Processing Systems, 2010. + * [[http://www.cs.columbia.edu/~blei/papers/HoffmanBleiBach2010b.pdf]] + * - EM: + * Asuncion et al. "On Smoothing and Inference for Topic Models." + * Uncertainty in Artificial Intelligence, 2009. + * [[http://arxiv.org/pdf/1205.2662.pdf]] + * + * @group param + */ + @Since("1.6.0") + final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + + " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), + (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) + + /** @group getParam */ + @Since("1.6.0") + def getOptimizer: String = $(optimizer) + + /** + * Output column with estimates of the topic mixture distribution for each document (often called + * "theta" in the literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. + * + * @group param + */ + @Since("1.6.0") + final val topicDistributionCol = new Param[String](this, "topicDistributionCol", "Output column" + + " with estimates of the topic mixture distribution for each document (often called \"theta\"" + + " in the literature). Returns a vector of zeros for an empty document.") + + setDefault(topicDistributionCol -> "topicDistribution") + + /** @group getParam */ + @Since("1.6.0") + def getTopicDistributionCol: String = $(topicDistributionCol) + + /** + * For Online optimizer only: [[optimizer]] = "online". + * + * A (positive) learning parameter that downweights early iterations. Larger values make early + * iterations count less. + * This is called "tau0" in the Online LDA paper (Hoffman et al., 2010) + * Default: 1024, following Hoffman et al. + * + * @group expertParam + */ + @Since("1.6.0") + final val learningOffset = new DoubleParam(this, "learningOffset", "(For online optimizer)" + + " A (positive) learning parameter that downweights early iterations. Larger values make early" + + " iterations count less.", + ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("1.6.0") + def getLearningOffset: Double = $(learningOffset) + + /** + * For Online optimizer only: [[optimizer]] = "online". + * + * Learning rate, set as an exponential decay rate. + * This should be between (0.5, 1.0] to guarantee asymptotic convergence. + * This is called "kappa" in the Online LDA paper (Hoffman et al., 2010). + * Default: 0.51, based on Hoffman et al. + * + * @group expertParam + */ + @Since("1.6.0") + final val learningDecay = new DoubleParam(this, "learningDecay", "(For online optimizer)" + + " Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to" + + " guarantee asymptotic convergence.", ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("1.6.0") + def getLearningDecay: Double = $(learningDecay) + + /** + * For Online optimizer only: [[optimizer]] = "online". + * + * Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, + * in range (0, 1]. + * + * Note that this should be adjusted in synch with [[LDA.maxIter]] + * so the entire corpus is used. Specifically, set both so that + * maxIterations * miniBatchFraction >= 1. + * + * Note: This is the same as the `miniBatchFraction` parameter in + * [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]]. + * + * Default: 0.05, i.e., 5% of total documents. + * + * @group param + */ + @Since("1.6.0") + final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "(For online optimizer)" + + " Fraction of the corpus to be sampled and used in each iteration of mini-batch" + + " gradient descent, in range (0, 1].", + ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true)) + + /** @group getParam */ + @Since("1.6.0") + def getSubsamplingRate: Double = $(subsamplingRate) + + /** + * For Online optimizer only (currently): [[optimizer]] = "online". + * + * Indicates whether the docConcentration (Dirichlet parameter for + * document-topic distribution) will be optimized during training. + * Setting this to true will make the model more expressive and fit the training data better. + * Default: false + * + * @group expertParam + */ + @Since("1.6.0") + final val optimizeDocConcentration = new BooleanParam(this, "optimizeDocConcentration", + "(For online optimizer only, currently) Indicates whether the docConcentration" + + " (Dirichlet parameter for document-topic distribution) will be optimized during training.") + + /** @group expertGetParam */ + @Since("1.6.0") + def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration) + + /** + * For EM optimizer only: [[optimizer]] = "em". + * + * If using checkpointing, this indicates whether to keep the last + * checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can + * cause failures if a data partition is lost, so set this bit with care. + * Note that checkpoints will be cleaned up via reference counting, regardless. + * + * See [[DistributedLDAModel.getCheckpointFiles]] for getting remaining checkpoints and + * [[DistributedLDAModel.deleteCheckpointFiles]] for removing remaining checkpoints. + * + * Default: true + * + * @group expertParam + */ + @Since("2.0.0") + final val keepLastCheckpoint = new BooleanParam(this, "keepLastCheckpoint", + "(For EM optimizer) If using checkpointing, this indicates whether to keep the last" + + " checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can" + + " cause failures if a data partition is lost, so set this bit with care.") + + /** @group expertGetParam */ + @Since("2.0.0") + def getKeepLastCheckpoint: Boolean = $(keepLastCheckpoint) + + /** + * Validates and transforms the input schema. + * + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + if (isSet(docConcentration)) { + if (getDocConcentration.length != 1) { + require(getDocConcentration.length == getK, s"LDA docConcentration was of length" + + s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" + + s" length either 1 (scalar) or k (num topics).") + } + getOptimizer match { + case "online" => + require(getDocConcentration.forall(_ >= 0), + "For Online LDA optimizer, docConcentration values must be >= 0. Found values: " + + getDocConcentration.mkString(",")) + case "em" => + require(getDocConcentration.forall(_ >= 0), + "For EM optimizer, docConcentration values must be >= 1. Found values: " + + getDocConcentration.mkString(",")) + } + } + if (isSet(topicConcentration)) { + getOptimizer match { + case "online" => + require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" + + s" must be >= 0. Found value: $getTopicConcentration") + case "em" => + require(getTopicConcentration >= 0, s"For EM optimizer, topicConcentration" + + s" must be >= 1. Found value: $getTopicConcentration") + } + } + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) + } + + private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match { + case "online" => + new OldOnlineLDAOptimizer() + .setTau0($(learningOffset)) + .setKappa($(learningDecay)) + .setMiniBatchFraction($(subsamplingRate)) + .setOptimizeDocConcentration($(optimizeDocConcentration)) + case "em" => + new OldEMLDAOptimizer() + .setKeepLastCheckpoint($(keepLastCheckpoint)) + } +} + + +/** + * :: Experimental :: + * Model fitted by [[LDA]]. + * + * @param vocabSize Vocabulary size (number of terms or terms in the vocabulary) + * @param sqlContext Used to construct local DataFrames for returning query results + */ +@Since("1.6.0") +@Experimental +sealed abstract class LDAModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val vocabSize: Int, + @Since("1.6.0") @transient protected val sqlContext: SQLContext) + extends Model[LDAModel] with LDAParams with Logging with MLWritable { + + // NOTE to developers: + // This abstraction should contain all important functionality for basic LDA usage. + // Specializations of this class can contain expert-only functionality. + + /** + * Underlying spark.mllib model. + * If this model was produced by Online LDA, then this is the only model representation. + * If this model was produced by EM, then this local representation may be built lazily. + */ + @Since("1.6.0") + protected def oldLocalModel: OldLocalLDAModel + + /** Returns underlying spark.mllib model, which may be local or distributed */ + @Since("1.6.0") + protected def getModel: OldLDAModel + + /** + * The features for LDA should be a [[Vector]] representing the word counts in a document. + * The vector should be of length vocabSize, with counts for each term (word). + * + * @group setParam + */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Transforms the input dataset. + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + */ + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + if ($(topicDistributionCol).nonEmpty) { + val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext)) + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF + } else { + logWarning("LDAModel.transform was called without any output columns. Set an output column" + + " such as topicDistributionCol to produce results.") + dataset.toDF + } + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + /** + * Value for [[docConcentration]] estimated from data. + * If Online LDA was used and [[optimizeDocConcentration]] was set to false, + * then this returns the fixed (given) value for the [[docConcentration]] parameter. + */ + @Since("1.6.0") + def estimatedDocConcentration: Vector = getModel.docConcentration + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + * + * WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by + * the Expectation-Maximization ("em") [[optimizer]], then this method could involve + * collecting a large amount of data to the driver (on the order of vocabSize x k). + */ + @Since("1.6.0") + def topicsMatrix: Matrix = oldLocalModel.topicsMatrix + + /** Indicates whether this instance is of type [[DistributedLDAModel]] */ + @Since("1.6.0") + def isDistributed: Boolean + + /** + * Calculates a lower bound on the log likelihood of the entire corpus. + * + * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + * + * @param dataset test corpus to use for calculating log likelihood + * @return variational lower bound on the log likelihood of the entire corpus + */ + @Since("2.0.0") + def logLikelihood(dataset: Dataset[_]): Double = { + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + oldLocalModel.logLikelihood(oldDataset) + } + + /** + * Calculate an upper bound bound on perplexity. (Lower is better.) + * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + * + * @param dataset test corpus to use for calculating perplexity + * @return Variational upper bound on log perplexity per token. + */ + @Since("2.0.0") + def logPerplexity(dataset: Dataset[_]): Double = { + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + oldLocalModel.logPerplexity(oldDataset) + } + + /** + * Return the topics described by their top-weighted terms. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * Default value of 10. + * @return Local DataFrame with one topic per Row, with columns: + * - "topic": IntegerType: topic index + * - "termIndices": ArrayType(IntegerType): term indices, sorted in order of decreasing + * term importance + * - "termWeights": ArrayType(DoubleType): corresponding sorted term weights + */ + @Since("1.6.0") + def describeTopics(maxTermsPerTopic: Int): DataFrame = { + val topics = getModel.describeTopics(maxTermsPerTopic).zipWithIndex.map { + case ((termIndices, termWeights), topic) => + (topic, termIndices.toSeq, termWeights.toSeq) + } + sqlContext.createDataFrame(topics).toDF("topic", "termIndices", "termWeights") + } + + @Since("1.6.0") + def describeTopics(): DataFrame = describeTopics(10) +} + + +/** + * :: Experimental :: + * + * Local (non-distributed) model fitted by [[LDA]]. + * + * This model stores the inferred topics only; it does not store info about the training dataset. + */ +@Since("1.6.0") +@Experimental +class LocalLDAModel private[ml] ( + uid: String, + vocabSize: Int, + @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel, + sqlContext: SQLContext) + extends LDAModel(uid, vocabSize, sqlContext) { + + @Since("1.6.0") + override def copy(extra: ParamMap): LocalLDAModel = { + val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) + copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel] + } + + override protected def getModel: OldLDAModel = oldLocalModel + + @Since("1.6.0") + override def isDistributed: Boolean = false + + @Since("1.6.0") + override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this) +} + + +@Since("1.6.0") +object LocalLDAModel extends MLReadable[LocalLDAModel] { + + private[LocalLDAModel] + class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter { + + private case class Data( + vocabSize: Int, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val oldModel = instance.oldLocalModel + val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, + oldModel.topicConcentration, oldModel.gammaShape) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LocalLDAModelReader extends MLReader[LocalLDAModel] { + + private val className = classOf[LocalLDAModel].getName + + override def load(path: String): LocalLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", + "gammaShape") + .head() + val vocabSize = data.getAs[Int](0) + val topicsMatrix = data.getAs[Matrix](1) + val docConcentration = data.getAs[Vector](2) + val topicConcentration = data.getAs[Double](3) + val gammaShape = data.getAs[Double](4) + val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) + val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader + + @Since("1.6.0") + override def load(path: String): LocalLDAModel = super.load(path) +} + + +/** + * :: Experimental :: + * + * Distributed model fitted by [[LDA]]. + * This type of model is currently only produced by Expectation-Maximization (EM). + * + * This model stores the inferred topics, the full training dataset, and the topic distribution + * for each training document. + * + * @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val, but keeping + * [[copy()]] cheap. + */ +@Since("1.6.0") +@Experimental +class DistributedLDAModel private[ml] ( + uid: String, + vocabSize: Int, + private val oldDistributedModel: OldDistributedLDAModel, + sqlContext: SQLContext, + private var oldLocalModelOption: Option[OldLocalLDAModel]) + extends LDAModel(uid, vocabSize, sqlContext) { + + override protected def oldLocalModel: OldLocalLDAModel = { + if (oldLocalModelOption.isEmpty) { + oldLocalModelOption = Some(oldDistributedModel.toLocal) + } + oldLocalModelOption.get + } + + override protected def getModel: OldLDAModel = oldDistributedModel + + /** + * Convert this distributed model to a local representation. This discards info about the + * training dataset. + * + * WARNING: This involves collecting a large [[topicsMatrix]] to the driver. + */ + @Since("1.6.0") + def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) + + @Since("1.6.0") + override def copy(extra: ParamMap): DistributedLDAModel = { + val copied = + new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption) + copyValues(copied, extra).setParent(parent) + copied + } + + @Since("1.6.0") + override def isDistributed: Boolean = true + + /** + * Log likelihood of the observed tokens in the training set, + * given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) + * + * Notes: + * - This excludes the prior; for that, use [[logPrior]]. + * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the + * hyperparameters. + * - This is computed from the topic distributions computed during training. If you call + * [[logLikelihood()]] on the same training dataset, the topic distributions will be computed + * again, possibly giving different results. + */ + @Since("1.6.0") + lazy val trainingLogLikelihood: Double = oldDistributedModel.logLikelihood + + /** + * Log probability of the current parameter estimate: + * log P(topics, topic distributions for docs | Dirichlet hyperparameters) + */ + @Since("1.6.0") + lazy val logPrior: Double = oldDistributedModel.logPrior + + private var _checkpointFiles: Array[String] = oldDistributedModel.checkpointFiles + + /** + * If using checkpointing and [[LDA.keepLastCheckpoint]] is set to true, then there may be + * saved checkpoint files. This method is provided so that users can manage those files. + * + * Note that removing the checkpoints can cause failures if a partition is lost and is needed + * by certain [[DistributedLDAModel]] methods. Reference counting will clean up the checkpoints + * when this model and derivative data go out of scope. + * + * @return Checkpoint files from training + */ + @DeveloperApi + @Since("2.0.0") + def getCheckpointFiles: Array[String] = _checkpointFiles + + /** + * Remove any remaining checkpoint files from training. + * + * @see [[getCheckpointFiles]] + */ + @DeveloperApi + @Since("2.0.0") + def deleteCheckpointFiles(): Unit = { + val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs)) + _checkpointFiles = Array.empty[String] + } + + @Since("1.6.0") + override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this) +} + + +@Since("1.6.0") +object DistributedLDAModel extends MLReadable[DistributedLDAModel] { + + private[DistributedLDAModel] + class DistributedWriter(instance: DistributedLDAModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val modelPath = new Path(path, "oldModel").toString + instance.oldDistributedModel.save(sc, modelPath) + } + } + + private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] { + + private val className = classOf[DistributedLDAModel].getName + + override def load(path: String): DistributedLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val modelPath = new Path(path, "oldModel").toString + val oldModel = OldDistributedLDAModel.load(sc, modelPath) + val model = new DistributedLDAModel( + metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader + + @Since("1.6.0") + override def load(path: String): DistributedLDAModel = super.load(path) +} + + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + * + * Terminology: + * - "term" = "word": an element of the vocabulary + * - "token": instance of a term appearing in a document + * - "topic": multinomial distribution over terms representing some concept + * - "document": one piece of text, corresponding to one row in the input data + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * + * Input data (featuresCol): + * LDA is given a collection of documents as input data, via the featuresCol parameter. + * Each document is specified as a [[Vector]] of length vocabSize, where each entry is the + * count for the corresponding term (word) in the document. Feature transformers such as + * [[org.apache.spark.ml.feature.Tokenizer]] and [[org.apache.spark.ml.feature.CountVectorizer]] + * can be useful for converting text to word count vectors. + * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] + */ +@Since("1.6.0") +@Experimental +class LDA @Since("1.6.0") ( + @Since("1.6.0") override val uid: String) + extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("lda")) + + setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10, + learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05, + optimizeDocConcentration -> true, keepLastCheckpoint -> true) + + /** + * The features for LDA should be a [[Vector]] representing the word counts in a document. + * The vector should be of length vocabSize, with counts for each term (word). + * + * @group setParam + */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("1.6.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + @Since("1.6.0") + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.6.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("1.6.0") + def setDocConcentration(value: Array[Double]): this.type = set(docConcentration, value) + + /** @group setParam */ + @Since("1.6.0") + def setDocConcentration(value: Double): this.type = set(docConcentration, Array(value)) + + /** @group setParam */ + @Since("1.6.0") + def setTopicConcentration(value: Double): this.type = set(topicConcentration, value) + + /** @group setParam */ + @Since("1.6.0") + def setOptimizer(value: String): this.type = set(optimizer, value) + + /** @group setParam */ + @Since("1.6.0") + def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setLearningOffset(value: Double): this.type = set(learningOffset, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setLearningDecay(value: Double): this.type = set(learningDecay, value) + + /** @group setParam */ + @Since("1.6.0") + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value) + + /** @group expertSetParam */ + @Since("2.0.0") + def setKeepLastCheckpoint(value: Boolean): this.type = set(keepLastCheckpoint, value) + + @Since("1.6.0") + override def copy(extra: ParamMap): LDA = defaultCopy(extra) + + @Since("2.0.0") + override def fit(dataset: Dataset[_]): LDAModel = { + transformSchema(dataset.schema, logging = true) + val oldLDA = new OldLDA() + .setK($(k)) + .setDocConcentration(getOldDocConcentration) + .setTopicConcentration(getOldTopicConcentration) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setCheckpointInterval($(checkpointInterval)) + .setOptimizer(getOldOptimizer) + // TODO: persist here, or in old LDA? + val oldData = LDA.getOldDataset(dataset, $(featuresCol)) + val oldModel = oldLDA.run(oldData) + val newModel = oldModel match { + case m: OldLocalLDAModel => + new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext) + case m: OldDistributedLDAModel => + new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None) + } + copyValues(newModel).setParent(this) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + + +private[clustering] object LDA extends DefaultParamsReadable[LDA] { + + /** Get dataset for spark.mllib LDA */ + def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = { + dataset + .withColumn("docId", monotonicallyIncreasingId()) + .select("docId", featuresCol) + .rdd + .map { case Row(docId: Long, features: Vector) => + (docId, features) + } + } + + @Since("1.6.0") + override def load(path: String): LDA = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1fe3abaca81c3..bde8c275fda43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,27 +20,28 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: * Evaluator for binary classification, which expects two input columns: rawPrediction and label. + * The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) + * or of type vector (length-2 vector of raw predictions, scores, or label probabilities). */ @Since("1.2.0") @Experimental class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol { + extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) /** - * param for metric name in evaluation - * Default: areaUnderROC + * param for metric name in evaluation (supports `"areaUnderROC"` (default), `"areaUnderPR"`) * @group param */ @Since("1.2.0") @@ -62,31 +63,23 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.5.0") def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) - /** - * @group setParam - * @deprecated use [[setRawPredictionCol()]] instead - */ - @deprecated("use setRawPredictionCol instead", "1.5.0") - @Since("1.2.0") - def setScoreCol(value: String): this.type = set(rawPredictionCol, value) - /** @group setParam */ @Since("1.2.0") def setLabelCol(value: String): this.type = set(labelCol, value) setDefault(metricName -> "areaUnderROC") - @Since("1.2.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema - SchemaUtils.checkColumnType(schema, $(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT)) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. - val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)) - .map { case Row(rawPrediction: Vector, label: Double) => - (rawPrediction(1), label) - } + val scoreAndLabels = dataset.select($(rawPredictionCol), $(labelCol)).rdd.map { + case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) + case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) + } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { case "areaUnderROC" => metrics.areaUnderROC() @@ -105,3 +98,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): BinaryClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 0f22cca3a78d1..5f765c071b9cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, Params} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -36,8 +36,8 @@ abstract class Evaluator extends Params { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - @Since("1.5.0") - def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { + @Since("2.0.0") + def evaluate(dataset: Dataset[_], paramMap: ParamMap): Double = { this.copy(paramMap).evaluate(dataset) } @@ -46,8 +46,8 @@ abstract class Evaluator extends Params { * @param dataset a dataset that contains labels/observations and predictions. * @return metric */ - @Since("1.5.0") - def evaluate(dataset: DataFrame): Double + @Since("2.0.0") + def evaluate(dataset: Dataset[_]): Double /** * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index df5f04ca5a8d9..3acfc221c95ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.DoubleType /** @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -68,15 +68,15 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid setDefault(metricName -> "f1") - @Since("1.5.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) - val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) - .map { case Row(prediction: Double, label: Double) => - (prediction, label) + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)).rdd.map { + case Row(prediction: Double, label: Double) => + (prediction, label) } val metrics = new MulticlassMetrics(predictionAndLabels) val metric = $(metricName) match { @@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object MulticlassClassificationEvaluator + extends DefaultParamsReadable[MulticlassClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): MulticlassClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index ba012f444d3e0..ed04b67bcc93b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -20,9 +20,9 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType} @@ -33,17 +33,18 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) /** - * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * Param for metric name in evaluation. Supports: + * - `"rmse"` (default): root mean squared error + * - `"mse"`: mean squared error + * - `"r2"`: R^2^ metric + * - `"mae"`: mean absolute error * - * Because we will maximize evaluation value (ref: `CrossValidator`), - * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), - * we take and output the negative of this metric. * @group param */ @Since("1.4.0") @@ -70,17 +71,23 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui setDefault(metricName -> "rmse") - @Since("1.4.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema + val predictionColName = $(predictionCol) val predictionType = schema($(predictionCol)).dataType - require(predictionType == FloatType || predictionType == DoubleType) + require(predictionType == FloatType || predictionType == DoubleType, + s"Prediction column $predictionColName must be of type float or double, " + + s" but not $predictionType") + val labelColName = $(labelCol) val labelType = schema($(labelCol)).dataType - require(labelType == FloatType || labelType == DoubleType) + require(labelType == FloatType || labelType == DoubleType, + s"Label column $labelColName must be of type float or double, but not $labelType") val predictionAndLabels = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) - .map { case Row(prediction: Double, label: Double) => + .rdd. + map { case Row(prediction: Double, label: Double) => (prediction, label) } val metrics = new RegressionMetrics(predictionAndLabels) @@ -104,3 +111,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.5.0") override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] { + + @Since("1.6.0") + override def load(path: String): RegressionEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index edad754436455..898ac2cc8941b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,15 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types._ /** * :: Experimental :: @@ -33,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("binarizer")) @@ -61,29 +64,63 @@ final class Binarizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { - transformSchema(dataset.schema, logging = true) + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + val outputSchema = transformSchema(dataset.schema, logging = true) + val schema = dataset.schema + val inputType = schema($(inputCol)).dataType val td = $(threshold) - val binarizer = udf { in: Double => if (in > td) 1.0 else 0.0 } - val outputColName = $(outputCol) - val metadata = BinaryAttribute.defaultAttr.withName(outputColName).toMetadata() - dataset.select(col("*"), - binarizer(col($(inputCol))).as(outputColName, metadata)) + + val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } + val binarizerVector = udf { (data: Vector) => + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + + data.foreachActive { (index, value) => + if (value > td) { + indices += index + values += 1.0 + } + } + + Vectors.sparse(data.size, indices.result(), values.result()).compressed + } + + val metadata = outputSchema($(outputCol)).metadata + + inputType match { + case DoubleType => + dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata)) + case _: VectorUDT => + dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata)) + } } override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) - - val inputFields = schema.fields + val inputType = schema($(inputCol)).dataType val outputColName = $(outputCol) - require(inputFields.forall(_.name != outputColName), - s"Output column $outputColName already exists.") - - val attr = BinaryAttribute.defaultAttr.withName(outputColName) - val outputFields = inputFields :+ attr.toStructField() - StructType(outputFields) + val outCol: StructField = inputType match { + case DoubleType => + BinaryAttribute.defaultAttr.withName(outputColName).toStructField() + case _: VectorUDT => + new StructField(outputColName, new VectorUDT, true) + case other => + throw new IllegalArgumentException(s"Data type $other is not supported.") + } + + if (schema.fieldNames.contains(outputColName)) { + throw new IllegalArgumentException(s"Output column $outputColName already exists.") + } + StructType(schema.fields :+ outCol) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) } + +@Since("1.6.0") +object Binarizer extends DefaultParamsReadable[Binarizer] { + + @Since("1.6.0") + override def load(path: String): Binarizer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 6fdf25b015b0b..10e622ace6d5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -68,7 +68,8 @@ final class Bucketizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val bucketizer = udf { feature: Double => Bucketizer.binarySearchForBuckets($(splits), feature) @@ -95,9 +96,10 @@ final class Bucketizer(override val uid: String) } } -private[feature] object Bucketizer { +object Bucketizer extends DefaultParamsReadable[Bucketizer] { + /** We require splits to be of length >= 3 and to be in strictly increasing order. */ - def checkSplits(splits: Array[Double]): Boolean = { + private[feature] def checkSplits(splits: Array[Double]): Boolean = { if (splits.length < 3) { false } else { @@ -115,7 +117,7 @@ private[feature] object Bucketizer { * Binary searching in several buckets to place each data point. * @throws SparkException if a feature is < splits.head or > splits.last */ - def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { + private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { if (feature == splits.last) { splits.length - 2 } else { @@ -134,4 +136,7 @@ private[feature] object Bucketizer { } } } + + @Since("1.6.0") + override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 5e4061fba5494..cfecae7e0b152 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.{AttributeGroup, _} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint @@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params */ @Experimental final class ChiSqSelector(override val uid: String) - extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("chiSqSelector")) @@ -76,9 +77,10 @@ final class ChiSqSelector(override val uid: String) /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def fit(dataset: DataFrame): ChiSqSelectorModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(labelCol), $(featuresCol)).map { + val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } @@ -95,6 +97,13 @@ final class ChiSqSelector(override val uid: String) override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra) } +@Since("1.6.0") +object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] { + + @Since("1.6.0") + override def load(path: String): ChiSqSelector = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[ChiSqSelector]]. @@ -103,7 +112,12 @@ final class ChiSqSelector(override val uid: String) final class ChiSqSelectorModel private[ml] ( override val uid: String, private val chiSqSelector: feature.ChiSqSelectorModel) - extends Model[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable { + + import ChiSqSelectorModel._ + + /** list of indices to select (filter). Must be ordered asc */ + val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures /** @group setParam */ def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -114,7 +128,8 @@ final class ChiSqSelectorModel private[ml] ( /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) val newField = transformedSchema.last val selector = udf { chiSqSelector.transform _ } @@ -147,4 +162,46 @@ final class ChiSqSelectorModel private[ml] ( val copied = new ChiSqSelectorModel(uid, chiSqSelector) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new ChiSqSelectorModelWriter(this) +} + +@Since("1.6.0") +object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { + + private[ChiSqSelectorModel] + class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter { + + private case class Data(selectedFeatures: Seq[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.selectedFeatures.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] { + + private val className = classOf[ChiSqSelectorModel].getName + + override def load(path: String): ChiSqSelectorModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head() + val selectedFeatures = data.getAs[Seq[Int]](0).toArray + val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) + val model = new ChiSqSelectorModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader + + @Since("1.6.0") + override def load(path: String): ChiSqSelectorModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 49028e4b85064..922670a41b6b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -16,17 +16,19 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.DataFrame import org.apache.spark.util.collection.OpenHashMap /** @@ -68,7 +70,8 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } @@ -97,6 +100,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getMinTF: Double = $(minTF) + + /** + * Binary toggle to control the output vector values. + * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for + * discrete probabilistic models that model binary events rather than integer counts. + * Default: false + * @group param + */ + val binary: BooleanParam = + new BooleanParam(this, "binary", "If True, all non zero counts are set to 1.") + + /** @group getParam */ + def getBinary: Boolean = $(binary) + + setDefault(binary -> false) } /** @@ -105,7 +123,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("cntVec")) @@ -124,12 +142,16 @@ class CountVectorizer(override val uid: String) /** @group setParam */ def setMinTF(value: Double): this.type = set(minTF, value) + /** @group setParam */ + def setBinary(value: Boolean): this.type = set(binary, value) + setDefault(vocabSize -> (1 << 18), minDF -> 1) - override def fit(dataset: DataFrame): CountVectorizerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) - val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val minDf = if ($(minDF) >= 1.0) { $(minDF) } else { @@ -149,16 +171,10 @@ class CountVectorizer(override val uid: String) (word, count) }.cache() val fullVocabSize = wordCounts.count() - val vocab: Array[String] = { - val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocSize) - } - tmpSortedWC.map(_._1) - } + + val vocab = wordCounts + .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2)) + .map(_._1) require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) @@ -171,6 +187,13 @@ class CountVectorizer(override val uid: String) override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) } +@Since("1.6.0") +object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { + + @Since("1.6.0") + override def load(path: String): CountVectorizer = super.load(path) +} + /** * :: Experimental :: * Converts a text document to a sparse vector of token counts. @@ -178,7 +201,9 @@ class CountVectorizer(override val uid: String) */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams { + extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { + + import CountVectorizerModel._ def this(vocabulary: Array[String]) = { this(Identifiable.randomUID("cntVecModel"), vocabulary) @@ -194,10 +219,15 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin /** @group setParam */ def setMinTF(value: Double): this.type = set(minTF, value) + /** @group setParam */ + def setBinary(value: Boolean): this.type = set(binary, value) + /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) if (broadcastDict.isEmpty) { val dict = vocabulary.zipWithIndex.toMap broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict)) @@ -214,12 +244,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } tokenCount += 1 } - val effectiveMinTF = if (minTf >= 1.0) { - minTf + val effectiveMinTF = if (minTf >= 1.0) minTf else tokenCount * minTf + val effectiveCounts = if ($(binary)) { + termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq } else { - tokenCount * minTf + termCounts.filter(_._2 >= effectiveMinTF).toSeq } - Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq) + + Vectors.sparse(dictBr.value.size, effectiveCounts) } dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) } @@ -232,4 +264,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) copyValues(copied, extra) } + + @Since("1.6.0") + override def write: MLWriter = new CountVectorizerModelWriter(this) +} + +@Since("1.6.0") +object CountVectorizerModel extends MLReadable[CountVectorizerModel] { + + private[CountVectorizerModel] + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { + + private case class Data(vocabulary: Seq[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.vocabulary) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { + + private val className = classOf[CountVectorizerModel].getName + + override def load(path: String): CountVectorizerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabulary") + .head() + val vocabulary = data.getAs[Seq[String]](0).toArray + val model = new CountVectorizerModel(metadata.uid, vocabulary) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader + + @Since("1.6.0") + override def load(path: String): CountVectorizerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 228347635c92b..a6f878151de73 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.feature import edu.emory.mathcs.jtransforms.dct._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.BooleanParam -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.types.DataType /** @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] { + extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("dct")) @@ -70,3 +70,10 @@ class DCT(override val uid: String) override protected def outputDataType: DataType = new VectorUDT } + +@Since("1.6.0") +object DCT extends DefaultParamsReadable[DCT] { + + @Since("1.6.0") + override def load(path: String): DCT = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index a359cb8f37ec3..1b0a9a12e83bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, Param} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.param.Param +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -33,14 +33,14 @@ import org.apache.spark.sql.types.DataType */ @Experimental class ElementwiseProduct(override val uid: String) - extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { + extends UnaryTransformer[Vector, Vector, ElementwiseProduct] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("elemProd")) /** - * the vector to multiply with input vectors - * @group param - */ + * the vector to multiply with input vectors + * @group param + */ val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ @@ -57,3 +57,10 @@ class ElementwiseProduct(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() } + +@Since("2.0.0") +object ElementwiseProduct extends DefaultParamsReadable[ElementwiseProduct] { + + @Since("2.0.0") + override def load(path: String): ElementwiseProduct = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 319d23e46cef4..467ad7307462a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StructType} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{ArrayType, StructType} * Maps a sequence of terms to their term frequencies using the hashing trick. */ @Experimental -class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { +class HashingTF(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -51,7 +52,18 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)", ParamValidators.gt(0)) - setDefault(numFeatures -> (1 << 18)) + /** + * Binary toggle to control term frequency counts. + * If true, all non-zero counts are set to 1. This is useful for discrete probabilistic + * models that model binary events rather than integer counts. + * (default = false) + * @group param + */ + val binary = new BooleanParam(this, "binary", "If true, all non zero counts are set to 1. " + + "This is useful for discrete probabilistic models that model binary events rather " + + "than integer counts") + + setDefault(numFeatures -> (1 << 18), binary -> false) /** @group getParam */ def getNumFeatures: Int = $(numFeatures) @@ -59,9 +71,16 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) - override def transform(dataset: DataFrame): DataFrame = { + /** @group getParam */ + def getBinary: Boolean = $(binary) + + /** @group setParam */ + def setBinary(value: Boolean): this.type = set(binary, value) + + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) - val hashingTF = new feature.HashingTF($(numFeatures)) + val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary)) val t = udf { terms: Seq[_] => hashingTF.transform(terms) } val metadata = outputSchema($(outputCol)).metadata dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) @@ -77,3 +96,10 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) } + +@Since("1.6.0") +object HashingTF extends DefaultParamsReadable[HashingTF] { + + @Since("1.6.0") + override def load(path: String): HashingTF = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 4c36df75d8aa0..5075b78c9856a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -60,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idf")) @@ -73,9 +76,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - override def fit(dataset: DataFrame): IDFModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IDFModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) copyValues(new IDFModel(uid, idf).setParent(this)) } @@ -87,6 +91,13 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def copy(extra: ParamMap): IDF = defaultCopy(extra) } +@Since("1.6.0") +object IDF extends DefaultParamsReadable[IDF] { + + @Since("1.6.0") + override def load(path: String): IDF = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[IDF]]. @@ -95,7 +106,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase { + extends Model[IDFModel] with IDFBase with MLWritable { + + import IDFModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -103,7 +116,8 @@ class IDFModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val idf = udf { vec: Vector => idfModel.transform(vec) } dataset.withColumn($(outputCol), idf(col($(inputCol)))) @@ -117,4 +131,50 @@ class IDFModel private[ml] ( val copied = new IDFModel(uid, idfModel) copyValues(copied, extra).setParent(parent) } + + /** Returns the IDF vector. */ + @Since("1.6.0") + def idf: Vector = idfModel.idf + + @Since("1.6.0") + override def write: MLWriter = new IDFModelWriter(this) +} + +@Since("1.6.0") +object IDFModel extends MLReadable[IDFModel] { + + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { + + private case class Data(idf: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.idf) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IDFModelReader extends MLReader[IDFModel] { + + private val className = classOf[IDFModel].getName + + override def load(path: String): IDFModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("idf") + .head() + val idf = data.getAs[Vector](0) + val model = new IDFModel(metadata.uid, new feature.IDFModel(idf)) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[IDFModel] = new IDFModelReader + + @Since("1.6.0") + override def load(path: String): IDFModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 37f7862476cfe..9ca34e9ae22f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -42,26 +42,34 @@ import org.apache.spark.sql.types._ * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. */ +@Since("1.6.0") @Experimental -class Interaction(override val uid: String) extends Transformer - with HasInputCols with HasOutputCol { +class Interaction @Since("1.6.0") (override val uid: String) extends Transformer + with HasInputCols with HasOutputCol with DefaultParamsWritable { + @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) /** @group setParam */ + @Since("1.6.0") def setInputCols(values: Array[String]): this.type = set(inputCols, values) /** @group setParam */ + @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) // optimistic schema; does not contain any ML attributes + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - validateParams() + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) } - override def transform(dataset: DataFrame): DataFrame = { - validateParams() + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputFeatures = $(inputCols).map(c => dataset.schema(c)) val featureEncoders = getFeatureEncoders(inputFeatures) val featureAttrs = getFeatureAttrs(inputFeatures) @@ -208,14 +216,16 @@ class Interaction(override val uid: String) extends Transformer } } + @Since("1.6.0") override def copy(extra: ParamMap): Interaction = defaultCopy(extra) - override def validateParams(): Unit = { - require(get(inputCols).isDefined, "Input cols must be defined first.") - require(get(outputCol).isDefined, "Output col must be defined first.") - require($(inputCols).length > 0, "Input cols must have non-zero length.") - require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") - } +} + +@Since("1.6.0") +object Interaction extends DefaultParamsReadable[Interaction] { + + @Since("1.6.0") + override def load(path: String): Interaction = super.load(path) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala new file mode 100644 index 0000000000000..e9df600c8a991 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[MaxAbsScaler]] and [[MaxAbsScalerModel]]. + */ +private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with HasOutputCol { + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} + +/** + * :: Experimental :: + * Rescale each feature individually to range [-1, 1] by dividing through the largest maximum + * absolute value in each feature. It does not shift/center the data, and thus does not destroy + * any sparsity. + */ +@Experimental +class MaxAbsScaler @Since("2.0.0") (override val uid: String) + extends Estimator[MaxAbsScalerModel] with MaxAbsScalerParams with DefaultParamsWritable { + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("maxAbsScal")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MaxAbsScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val summary = Statistics.colStats(input) + val minVals = summary.min.toArray + val maxVals = summary.max.toArray + val n = minVals.length + val maxAbs = Array.tabulate(n) { i => math.max(math.abs(minVals(i)), math.abs(maxVals(i))) } + + copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs)).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MaxAbsScaler = defaultCopy(extra) +} + +@Since("1.6.0") +object MaxAbsScaler extends DefaultParamsReadable[MaxAbsScaler] { + + @Since("1.6.0") + override def load(path: String): MaxAbsScaler = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by [[MaxAbsScaler]]. + * + */ +@Experimental +class MaxAbsScalerModel private[ml] ( + override val uid: String, + val maxAbs: Vector) + extends Model[MaxAbsScalerModel] with MaxAbsScalerParams with MLWritable { + + import MaxAbsScalerModel._ + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + // TODO: this looks hack, we may have to handle sparse and dense vectors separately. + val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) + val reScale = udf { (vector: Vector) => + val brz = vector.toBreeze / maxAbsUnzero.toBreeze + Vectors.fromBreeze(brz) + } + dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MaxAbsScalerModel = { + val copied = new MaxAbsScalerModel(uid, maxAbs) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def write: MLWriter = new MaxAbsScalerModelWriter(this) +} + +@Since("1.6.0") +object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { + + private[MaxAbsScalerModel] + class MaxAbsScalerModelWriter(instance: MaxAbsScalerModel) extends MLWriter { + + private case class Data(maxAbs: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.maxAbs) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MaxAbsScalerModelReader extends MLReader[MaxAbsScalerModel] { + + private val className = classOf[MaxAbsScalerModel].getName + + override def load(path: String): MaxAbsScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(maxAbs: Vector) = sqlContext.read.parquet(dataPath) + .select("maxAbs") + .head() + val model = new MaxAbsScalerModel(metadata.uid, maxAbs) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[MaxAbsScalerModel] = new MaxAbsScalerModelReader + + @Since("1.6.0") + override def load(path: String): MaxAbsScalerModel = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 1b494ec8b1727..125becbb8a5b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,12 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} -import org.apache.spark.ml.util.Identifiable +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -57,6 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") @@ -66,9 +69,6 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H StructType(outputFields) } - override def validateParams(): Unit = { - require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") - } } /** @@ -85,7 +85,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -103,9 +103,10 @@ class MinMaxScaler(override val uid: String) /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def fit(dataset: DataFrame): MinMaxScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) } @@ -117,6 +118,13 @@ class MinMaxScaler(override val uid: String) override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) } +@Since("1.6.0") +object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { + + @Since("1.6.0") + override def load(path: String): MinMaxScaler = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[MinMaxScaler]]. @@ -131,7 +139,9 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { + + import MinMaxScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -145,7 +155,8 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray @@ -154,7 +165,7 @@ class MinMaxScalerModel private[ml] ( // 0 in sparse vector will probably be rescaled to non-zero val values = vector.toArray - val size = values.size + val size = values.length var i = 0 while (i < size) { val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5 @@ -175,4 +186,46 @@ class MinMaxScalerModel private[ml] ( val copied = new MinMaxScalerModel(uid, originalMin, originalMax) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new MinMaxScalerModelWriter(this) +} + +@Since("1.6.0") +object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { + + private[MinMaxScalerModel] + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { + + private case class Data(originalMin: Vector, originalMax: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.originalMin, instance.originalMax) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { + + private val className = classOf[MinMaxScalerModel].getName + + override def load(path: String): MinMaxScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) + .select("originalMin", "originalMax") + .head() + val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader + + @Since("1.6.0") + override def load(path: String): MinMaxScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 8de10eb51f923..f8bc7e3f0c031 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("ngram")) @@ -67,3 +67,10 @@ class NGram(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, false) } + +@Since("1.6.0") +object NGram extends DefaultParamsReadable[NGram] { + + @Since("1.6.0") + override def load(path: String): NGram = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 8282e5ffa17f7..a603b3f833202 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -30,7 +30,8 @@ import org.apache.spark.sql.types.DataType * Normalize a vector to have unit norm using the given p-norm. */ @Experimental -class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { +class Normalizer(override val uid: String) + extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("normalizer")) @@ -56,3 +57,10 @@ class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vect override protected def outputDataType: DataType = new VectorUDT() } + +@Since("1.6.0") +object Normalizer extends DefaultParamsReadable[Normalizer] { + + @Since("1.6.0") + override def load(path: String): Normalizer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 9c60d4084ec46..99357793dbaeb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} /** * :: Experimental :: @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol { + with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("oneHot")) @@ -69,7 +69,8 @@ class OneHotEncoder(override val uid: String) extends Transformer val inputColName = $(inputCol) val outputColName = $(outputCol) - SchemaUtils.checkColumnType(schema, inputColName, DoubleType) + require(schema(inputColName).dataType.isInstanceOf[NumericType], + s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") val inputFields = schema.fields require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") @@ -120,7 +121,8 @@ class OneHotEncoder(override val uid: String) extends Transformer StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // schema transformation val inputColName = $(inputCol) val outputColName = $(outputCol) @@ -129,10 +131,12 @@ class OneHotEncoder(override val uid: String) extends Transformer transformSchema(dataset.schema)(outputColName)) if (outputAttrGroup.size < 0) { // If the number of attributes is unknown, we check the values from the input column. - val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) .aggregate(0.0)( (m, x) => { - assert(x >=0.0 && x == x.toInt, + assert(x <= Int.MaxValue, + s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x") + assert(x >= 0.0 && x == x.toInt, s"Values from column $inputColName must be indices, but got $x.") math.max(m, x) }, @@ -166,3 +170,10 @@ class OneHotEncoder(override val uid: String) extends Transformer override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) } + +@Since("1.6.0") +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { + + @Since("1.6.0") + override def load(path: String): OneHotEncoder = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 539084704b653..9cf722e121697 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -17,13 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC * PCA trains a model to project vectors to a low-dimensional space using PCA. */ @Experimental -class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("pca")) @@ -65,12 +68,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams /** * Computes a [[PCAModel]] that contains the principal components of the input vectors. */ - override def fit(dataset: DataFrame): PCAModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PCAModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) - copyValues(new PCAModel(uid, pcaModel).setParent(this)) + copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -86,15 +90,29 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams override def copy(extra: ParamMap): PCA = defaultCopy(extra) } +@Since("1.6.0") +object PCA extends DefaultParamsReadable[PCA] { + + @Since("1.6.0") + override def load(path: String): PCA = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[PCA]]. + * + * @param pc A principal components Matrix. Each column is one principal component. + * @param explainedVariance A vector of proportions of variance explained by + * each principal component. */ @Experimental class PCAModel private[ml] ( override val uid: String, - pcaModel: feature.PCAModel) - extends Model[PCAModel] with PCAParams { + val pc: DenseMatrix, + val explainedVariance: DenseVector) + extends Model[PCAModel] with PCAParams with MLWritable { + + import PCAModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -107,8 +125,10 @@ class PCAModel private[ml] ( * NOTE: Vectors to be transformed must be the same length * as the source vectors given to [[PCA.fit()]]. */ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val pcaModel = new feature.PCAModel($(k), pc, explainedVariance) val pcaOp = udf { pcaModel.transform _ } dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) } @@ -124,7 +144,72 @@ class PCAModel private[ml] ( } override def copy(extra: ParamMap): PCAModel = { - val copied = new PCAModel(uid, pcaModel) + val copied = new PCAModel(uid, pc, explainedVariance) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new PCAModelWriter(this) +} + +@Since("1.6.0") +object PCAModel extends MLReadable[PCAModel] { + + private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { + + private case class Data(pc: DenseMatrix, explainedVariance: DenseVector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.pc, instance.explainedVariance) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class PCAModelReader extends MLReader[PCAModel] { + + private val className = classOf[PCAModel].getName + + /** + * Loads a [[PCAModel]] from data located at the input path. Note that the model includes an + * `explainedVariance` member that is not recorded by Spark 1.6 and earlier. A model + * can be loaded from such older data but will have an empty vector for + * `explainedVariance`. + * + * @param path path to serialized model data + * @return a [[PCAModel]] + */ + override def load(path: String): PCAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + // explainedVariance field is not present in Spark <= 1.6 + val versionRegex = "([0-9]+)\\.([0-9]+).*".r + val hasExplainedVariance = metadata.sparkVersion match { + case versionRegex(major, minor) => + (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)) + case _ => false + } + + val dataPath = new Path(path, "data").toString + val model = if (hasExplainedVariance) { + val Row(pc: DenseMatrix, explainedVariance: DenseVector) = + sqlContext.read.parquet(dataPath) + .select("pc", "explainedVariance") + .head() + new PCAModel(metadata.uid, pc, explainedVariance) + } else { + val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath).select("pc").head() + new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) + } + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[PCAModel] = new PCAModelReader + + @Since("1.6.0") + override def load(path: String): PCAModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index d85e468562d4a..0a9b9719c15d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("poly")) @@ -46,7 +46,7 @@ class PolynomialExpansion(override val uid: String) * @group param */ val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)", - ParamValidators.gt(1)) + ParamValidators.gtEq(1)) setDefault(degree -> 2) @@ -77,7 +77,8 @@ class PolynomialExpansion(override val uid: String) * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the * current index and increment it properly for sparse input. */ -private[feature] object PolynomialExpansion { +@Since("1.6.0") +object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -169,11 +170,14 @@ private[feature] object PolynomialExpansion { new SparseVector(polySize - 1, polyIndices.result(), polyValues.result()) } - def expand(v: Vector, degree: Int): Vector = { + private[feature] def expand(v: Vector, degree: Int): Vector = { v match { case dv: DenseVector => expand(dv, degree) case sv: SparseVector => expand(sv, degree) case _ => throw new IllegalArgumentException } } + + @Since("1.6.0") + override def load(path: String): PolynomialExpansion = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 46b836da9cfde..5c7993af645af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -19,24 +19,25 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} import org.apache.spark.ml.util._ +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.{DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.random.XORShiftRandom /** * Params for [[QuantileDiscretizer]]. */ -private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait QuantileDiscretizerBase extends Params + with HasInputCol with HasOutputCol with HasSeed { /** - * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must + * Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. * default: 2 * @group param @@ -48,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w /** @group getParam */ def getNumBuckets: Int = getOrDefault(numBuckets) + + /** + * Relative error (see documentation for + * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description) + * Must be a number in [0, 1]. + * default: 0.001 + * @group param + */ + val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " + + "for approxQuantile", + ParamValidators.inRange(0.0, 1.0)) + setDefault(relativeError -> 0.001) + + /** @group getParam */ + def getRelativeError: Double = getOrDefault(relativeError) } /** @@ -55,15 +71,17 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, - * covering all real values. This attempts to find numBuckets partitions based on a sample of data, - * but it may find fewer depending on the data sample values. + * covering all real values. */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) + /** @group setParam */ + def setRelativeError(value: Double): this.type = set(relativeError, value) + /** @group setParam */ def setNumBuckets(value: Int): this.type = set(numBuckets, value) @@ -73,6 +91,9 @@ final class QuantileDiscretizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) val inputFields = schema.fields @@ -83,94 +104,23 @@ final class QuantileDiscretizer(override val uid: String) StructType(outputFields) } - override def fit(dataset: DataFrame): Bucketizer = { - val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets)) - .map { case Row(feature: Double) => feature } - val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) - val splits = QuantileDiscretizer.getSplits(candidates) + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Bucketizer = { + val splits = dataset.stat.approxQuantile($(inputCol), + (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) + splits(0) = Double.NegativeInfinity + splits(splits.length - 1) = Double.PositiveInfinity + val bucketizer = new Bucketizer(uid).setSplits(splits) - copyValues(bucketizer) + copyValues(bucketizer.setParent(this)) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) } -private[feature] object QuantileDiscretizer extends Logging { - /** - * Sampling from the given dataset to collect quantile statistics. - */ - def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { - val totalSamples = dataset.count() - require(totalSamples > 0, - "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") - val requiredSamples = math.max(numBins * numBins, 10000) - val fraction = math.min(requiredSamples / dataset.count(), 1.0) - dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() - } +@Since("1.6.0") +object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - /** - * Compute split points with respect to the sample distribution. - */ - def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = { - val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) => - m + ((x, m.getOrElse(x, 0) + 1)) - } - val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1)) - val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { - valueCounts.dropRight(1).map(_._1) - } else { - val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1)) - val splitsBuilder = mutable.ArrayBuilder.make[Double] - var index = 1 - // currentCount: sum of counts of values that have been visited - var currentCount = valueCounts(0)._2 - // targetCount: target value for `currentCount`. If `currentCount` is closest value to - // `targetCount`, then current value is a split threshold. After finding a split threshold, - // `targetCount` is added by stride. - var targetCount = stride - while (index < valueCounts.length) { - val previousCount = currentCount - currentCount += valueCounts(index)._2 - val previousGap = math.abs(previousCount - targetCount) - val currentGap = math.abs(currentCount - targetCount) - // If adding count of current value to currentCount makes the gap between currentCount and - // targetCount smaller, previous value is a split threshold. - if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 - targetCount += stride - } - index += 1 - } - splitsBuilder.result() - } - } - - /** - * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as - * needed, and adding a default split value of 0 if no good candidates are found. - */ - def getSplits(candidates: Array[Double]): Array[Double] = { - val effectiveValues = if (candidates.size != 0) { - if (candidates.head == Double.NegativeInfinity - && candidates.last == Double.PositiveInfinity) { - candidates.drop(1).dropRight(1) - } else if (candidates.head == Double.NegativeInfinity) { - candidates.drop(1) - } else if (candidates.last == Double.PositiveInfinity) { - candidates.dropRight(1) - } else { - candidates - } - } else { - candidates - } - - if (effectiveValues.size == 0) { - Array(Double.NegativeInfinity, 0, Double.PositiveInfinity) - } else { - Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity) - } - } + @Since("1.6.0") + override def load(path: String): QuantileDiscretizer = super.load(path) } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5c43a41bee3b4..3ac6c776699a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -20,14 +20,16 @@ package org.apache.spark.ml.feature import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorUDT -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types._ /** @@ -45,9 +47,31 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { * Implements the transforms required for fitting a dataset against an R model formula. Currently * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * + * The basic operators are: + * - `~` separate target and terms + * - `+` concat terms, "+ 0" means removing intercept + * - `-` remove a term, "- 1" means removing intercept + * - `:` interaction (multiplication for numeric values, or binarized categorical values) + * - `.` all columns except target + * + * Suppose `a` and `b` are double columns, we use the following simple examples + * to illustrate the effect of `RFormula`: + * - `y ~ a + b` means model `y ~ w0 + w1 * a + w2 * b` where `w0` is the intercept and `w1, w2` + * are coefficients. + * - `y ~ a + b + a:b - 1` means model `y ~ w1 * a + w2 * b + w3 * a * b` where `w1, w2, w3` + * are coefficients. + * + * RFormula produces a vector column of features and a double or string column of label. + * Like when formulas are used in R for linear regression, string input columns will be one-hot + * encoded, and numeric columns will be cast to doubles. + * If the label column is of type string, it will be first transformed to double with + * `StringIndexer`. If the label column does not exist in the DataFrame, the output label column + * will be created from the specified response variable in the formula. */ @Experimental -class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { +class RFormula(override val uid: String) + extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("rFormula")) @@ -79,7 +103,8 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R RFormulaParser.parse($(formula)).hasIntercept } - override def fit(dataset: DataFrame): RFormulaModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): RFormulaModel = { require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) @@ -101,6 +126,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R encoderStages += new StringIndexer() .setInputCol(term) .setOutputCol(indexCol) + prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) case _ => (term, term) @@ -159,6 +185,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)" } +@Since("2.0.0") +object RFormula extends DefaultParamsReadable[RFormula] { + + @Since("2.0.0") + override def load(path: String): RFormula = super.load(path) +} + /** * :: Experimental :: * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. @@ -168,11 +201,12 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R @Experimental class RFormulaModel private[feature]( override val uid: String, - resolvedFormula: ResolvedRFormula, - pipelineModel: PipelineModel) - extends Model[RFormulaModel] with RFormulaBase { + private[ml] val resolvedFormula: ResolvedRFormula, + private[ml] val pipelineModel: PipelineModel) + extends Model[RFormulaModel] with RFormulaBase with MLWritable { - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { checkCanTransform(dataset.schema) transformLabel(pipelineModel.transform(dataset)) } @@ -198,12 +232,12 @@ class RFormulaModel private[feature]( override def copy(extra: ParamMap): RFormulaModel = copyValues( new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)" + override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" - private def transformLabel(dataset: DataFrame): DataFrame = { + private def transformLabel(dataset: Dataset[_]): DataFrame = { val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { - dataset + dataset.toDF } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => @@ -214,7 +248,7 @@ class RFormulaModel private[feature]( } else { // Ignore the label field. This is a hack so that this transformer can also work on test // datasets in a Pipeline. - dataset + dataset.toDF } } @@ -225,18 +259,75 @@ class RFormulaModel private[feature]( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, "Label column already exists and is not of type DoubleType.") } + + @Since("2.0.0") + override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this) +} + +@Since("2.0.0") +object RFormulaModel extends MLReadable[RFormulaModel] { + + @Since("2.0.0") + override def read: MLReader[RFormulaModel] = new RFormulaModelReader + + @Since("2.0.0") + override def load(path: String): RFormulaModel = super.load(path) + + /** [[MLWriter]] instance for [[RFormulaModel]] */ + private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: resolvedFormula + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(instance.resolvedFormula)) + .repartition(1).write.parquet(dataPath) + // Save pipeline model + val pmPath = new Path(path, "pipelineModel").toString + instance.pipelineModel.save(pmPath) + } + } + + private class RFormulaModelReader extends MLReader[RFormulaModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RFormulaModel].getName + + override def load(path: String): RFormulaModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head() + val label = data.getString(0) + val terms = data.getAs[Seq[Seq[String]]](1) + val hasIntercept = data.getBoolean(2) + val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept) + + val pmPath = new Path(path, "pipelineModel").toString + val pipelineModel = PipelineModel.load(pmPath) + + val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** * Utility transformer for removing temporary columns from a DataFrame. * TODO(ekl) make this a public transformer */ -private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { - override val uid = Identifiable.randomUID("columnPruner") +private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String]) + extends Transformer with MLWritable { + + def this(columnsToPrune: Set[String]) = + this(Identifiable.randomUID("columnPruner"), columnsToPrune) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) - dataset.select(columnsToKeep.map(dataset.col) : _*) + dataset.select(columnsToKeep.map(dataset.col): _*) } override def transformSchema(schema: StructType): StructType = { @@ -244,6 +335,48 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { } override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) + + override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this) +} + +private object ColumnPruner extends MLReadable[ColumnPruner] { + + override def read: MLReader[ColumnPruner] = new ColumnPrunerReader + + override def load(path: String): ColumnPruner = super.load(path) + + /** [[MLWriter]] instance for [[ColumnPruner]] */ + private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter { + + private case class Data(columnsToPrune: Seq[String]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: columnsToPrune + val data = Data(instance.columnsToPrune.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class ColumnPrunerReader extends MLReader[ColumnPruner] { + + /** Checked against metadata when loading model */ + private val className = classOf[ColumnPruner].getName + + override def load(path: String): ColumnPruner = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head() + val columnsToPrune = data.getAs[Seq[String]](0).toSet + val pruner = new ColumnPruner(metadata.uid, columnsToPrune) + + DefaultParamsReader.getAndSetParams(pruner, metadata) + pruner + } + } } /** @@ -257,25 +390,23 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { * by the value in the map. */ private class VectorAttributeRewriter( - vectorCol: String, - prefixesToRewrite: Map[String, String]) - extends Transformer { + override val uid: String, + val vectorCol: String, + val prefixesToRewrite: Map[String, String]) + extends Transformer with MLWritable { - override val uid = Identifiable.randomUID("vectorAttrRewriter") + def this(vectorCol: String, prefixesToRewrite: Map[String, String]) = + this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val metadata = { val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) val attrs = group.attributes.get.map { attr => if (attr.name.isDefined) { - val name = attr.name.get - val replacement = prefixesToRewrite.filter { case (k, _) => name.startsWith(k) } - if (replacement.nonEmpty) { - val (k, v) = replacement.headOption.get - attr.withName(v + name.stripPrefix(k)) - } else { - attr + val name = prefixesToRewrite.foldLeft(attr.name.get) { case (curName, (from, to)) => + curName.replace(from, to) } + attr.withName(name) } else { attr } @@ -284,7 +415,7 @@ private class VectorAttributeRewriter( } val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col) val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata) - dataset.select((otherCols :+ rewrittenCol): _*) + dataset.select(otherCols :+ rewrittenCol : _*) } override def transformSchema(schema: StructType): StructType = { @@ -294,4 +425,48 @@ private class VectorAttributeRewriter( } override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra) + + override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this) +} + +private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] { + + override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader + + override def load(path: String): VectorAttributeRewriter = super.load(path) + + /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */ + private[VectorAttributeRewriter] + class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter { + + private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: vectorCol, prefixesToRewrite + val data = Data(instance.vectorCol, instance.prefixesToRewrite) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] { + + /** Checked against metadata when loading model */ + private val className = classOf[VectorAttributeRewriter].getName + + override def load(path: String): VectorAttributeRewriter = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head() + val vectorCol = data.getString(0) + val prefixesToRewrite = data.getAs[Map[String, String]](1) + val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) + + DefaultParamsReader.getAndSetParams(rewriter, metadata) + rewriter + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 95e4305638730..2002d15745d9f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -18,55 +18,80 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.types.StructType /** * :: Experimental :: - * Implements the transforms which are defined by SQL statement. - * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * Implements the transformations which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...' * where '__THIS__' represents the underlying table of the input dataset. + * The select clause specifies the fields, constants, and expressions to display in + * the output, it can be any select clause that Spark SQL supports. Users can also + * use Spark SQL built-in function and UDFs to operate on these selected columns. + * For example, [[SQLTransformer]] supports statements like: + * - SELECT a, a + b AS a_b FROM __THIS__ + * - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5 + * - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b */ @Experimental -class SQLTransformer (override val uid: String) extends Transformer { +@Since("1.6.0") +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer + with DefaultParamsWritable { + @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) /** * SQL statement parameter. The statement is provided in string form. * @group param */ + @Since("1.6.0") final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") /** @group setParam */ + @Since("1.6.0") def setStatement(value: String): this.type = set(statement, value) /** @group getParam */ + @Since("1.6.0") def getStatement: String = $(statement) private val tableIdentifier: String = "__THIS__" - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - val outputDF = dataset.sqlContext.sql(realStatement) - outputDF + dataset.sqlContext.sql(realStatement) } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) - dummyDF.registerTempTable(tableIdentifier) - val outputSchema = sqlContext.sql($(statement)).schema + val tableName = Identifiable.randomUID(uid) + val realStatement = $(statement).replace(tableIdentifier, tableName) + dummyDF.registerTempTable(tableName) + val outputSchema = sqlContext.sql(realStatement).schema + sqlContext.dropTempTable(tableName) outputSchema } + @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) } + +@Since("1.6.0") +object SQLTransformer extends DefaultParamsReadable[SQLTransformer] { + + @Since("1.6.0") + override def load(path: String): SQLTransformer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index f6d0b0c0e9e75..118a6e3e6ad44 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -34,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType} private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { /** - * Centers the data with mean before scaling. + * Whether to center the data with mean before scaling. * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ - val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + val withMean: BooleanParam = new BooleanParam(this, "withMean", + "Whether to center data with mean") + + /** @group getParam */ + def getWithMean: Boolean = $(withMean) /** - * Scales the data to unit standard deviation. + * Whether to scale the data to unit standard deviation. * Default: true * @group param */ - val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") + val withStd: BooleanParam = new BooleanParam(this, "withStd", + "Whether to scale the data to unit standard deviation") + + /** @group getParam */ + def getWithStd: Boolean = $(withStd) + + setDefault(withMean -> false, withStd -> true) } /** @@ -57,12 +69,10 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams { + with StandardScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stdScal")) - setDefault(withMean -> false, withStd -> true) - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -75,12 +85,13 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - override def fit(dataset: DataFrame): StandardScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StandardScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) + copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -96,21 +107,28 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) } +@Since("1.6.0") +object StandardScaler extends DefaultParamsReadable[StandardScaler] { + + @Since("1.6.0") + override def load(path: String): StandardScaler = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[StandardScaler]]. + * + * @param std Standard deviation of the StandardScalerModel + * @param mean Mean of the StandardScalerModel */ @Experimental class StandardScalerModel private[ml] ( override val uid: String, - scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams { - - /** Standard deviation of the StandardScalerModel */ - val std: Vector = scaler.std + val std: Vector, + val mean: Vector) + extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { - /** Mean of the StandardScalerModel */ - val mean: Vector = scaler.mean + import StandardScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -118,8 +136,10 @@ class StandardScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } dataset.withColumn($(outputCol), scale(col($(inputCol)))) } @@ -135,7 +155,49 @@ class StandardScalerModel private[ml] ( } override def copy(extra: ParamMap): StandardScalerModel = { - val copied = new StandardScalerModel(uid, scaler) + val copied = new StandardScalerModel(uid, std, mean) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new StandardScalerModelWriter(this) +} + +@Since("1.6.0") +object StandardScalerModel extends MLReadable[StandardScalerModel] { + + private[StandardScalerModel] + class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { + + private case class Data(std: Vector, mean: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.std, instance.mean) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StandardScalerModelReader extends MLReader[StandardScalerModel] { + + private val className = classOf[StandardScalerModel].getName + + override def load(path: String): StandardScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath) + .select("std", "mean") + .head() + val model = new StandardScalerModel(metadata.uid, std, mean) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader + + @Since("1.6.0") + override def load(path: String): StandardScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 2a79582625e9a..b96bc48566fa7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, StringType, StructType} /** * stop words list @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stopWords")) @@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String) setDefault(stopWords -> StopWords.English, caseSensitive -> false) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val t = if ($(caseSensitive)) { val stopWordsSet = $(stopWords).toSet @@ -148,10 +149,15 @@ class StopWordsRemover(override val uid: String) val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") - val outputFields = schema.fields :+ - StructField($(outputCol), inputType, schema($(inputCol)).nullable) - StructType(outputFields) + SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) } + +@Since("1.6.0") +object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { + + @Since("1.6.0") + override def load(path: String): StopWordsRemover = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 486274cd75a14..7e0d374f02723 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,15 +17,16 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase { + with StringIndexerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("strIdx")) @@ -79,8 +80,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): StringIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) + .rdd .map(_.getString(0)) .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray @@ -94,6 +97,13 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) } +@Since("1.6.0") +object StringIndexer extends DefaultParamsReadable[StringIndexer] { + + @Since("1.6.0") + override def load(path: String): StringIndexer = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[StringIndexer]]. @@ -107,7 +117,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod @Experimental class StringIndexerModel ( override val uid: String, - val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) + extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { + + import StringIndexerModel._ def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) @@ -132,12 +145,14 @@ class StringIndexerModel ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { if (!dataset.schema.fieldNames.contains($(inputCol))) { logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + "Skip StringIndexerModel.") - return dataset + return dataset.toDF } + validateAndTransformSchema(dataset.schema) val indexer = udf { label: String => if (labelToIndex.contains(label)) { @@ -148,15 +163,14 @@ class StringIndexerModel ( } val metadata = NominalAttribute.defaultAttr - .withName($(inputCol)).withValues(labels).toMetadata() + .withName($(outputCol)).withValues(labels).toMetadata() // If we are skipping invalid records, filter them out. - val filteredDataset = (getHandleInvalid) match { - case "skip" => { + val filteredDataset = getHandleInvalid match { + case "skip" => val filterer = udf { label: String => labelToIndex.contains(label) } dataset.where(filterer(dataset($(inputCol)))) - } case _ => dataset } filteredDataset.select(col("*"), @@ -176,6 +190,49 @@ class StringIndexerModel ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: StringIndexModelWriter = new StringIndexModelWriter(this) +} + +@Since("1.6.0") +object StringIndexerModel extends MLReadable[StringIndexerModel] { + + private[StringIndexerModel] + class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { + + private case class Data(labels: Array[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.labels) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StringIndexerModelReader extends MLReader[StringIndexerModel] { + + private val className = classOf[StringIndexerModel].getName + + override def load(path: String): StringIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("labels") + .head() + val labels = data.getAs[Seq[String]](0).toArray + val model = new StringIndexerModel(metadata.uid, labels) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader + + @Since("1.6.0") + override def load(path: String): StringIndexerModel = super.load(path) } /** @@ -188,9 +245,8 @@ class StringIndexerModel ( * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class IndexToString private[ml] ( - override val uid: String) extends Transformer - with HasInputCol with HasOutputCol { +class IndexToString private[ml] (override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -232,7 +288,8 @@ class IndexToString private[ml] ( StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata val values = if ($(labels).isEmpty) { @@ -258,3 +315,10 @@ class IndexToString private[ml] ( defaultCopy(extra) } } + +@Since("1.6.0") +object IndexToString extends DefaultParamsReadable[IndexToString] { + + @Since("1.6.0") + override def load(path: String): IndexToString = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 248288ca73e99..8456a0e915804 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -30,7 +30,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} * @see [[RegexTokenizer]] */ @Experimental -class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { +class Tokenizer(override val uid: String) + extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("tok")) @@ -47,6 +48,13 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } +@Since("1.6.0") +object Tokenizer extends DefaultParamsReadable[Tokenizer] { + + @Since("1.6.0") + override def load(path: String): Tokenizer = super.load(path) +} + /** * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split @@ -56,7 +64,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("regexTok")) @@ -100,10 +108,25 @@ class RegexTokenizer(override val uid: String) /** @group getParam */ def getPattern: String = $(pattern) - setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+") + /** + * Indicates whether to convert all characters to lowercase before tokenizing. + * Default: true + * @group param + */ + final val toLowercase: BooleanParam = new BooleanParam(this, "toLowercase", + "whether to convert all characters to lowercase before tokenizing.") - override protected def createTransformFunc: String => Seq[String] = { str => + /** @group setParam */ + def setToLowercase(value: Boolean): this.type = set(toLowercase, value) + + /** @group getParam */ + def getToLowercase: Boolean = $(toLowercase) + + setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true) + + override protected def createTransformFunc: String => Seq[String] = { originStr => val re = $(pattern).r + val str = if ($(toLowercase)) originStr.toLowerCase() else originStr val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq val minLength = $(minTokenLength) tokens.filter(_.length >= minLength) @@ -117,3 +140,10 @@ class RegexTokenizer(override val uid: String) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } + +@Since("1.6.0") +object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] { + + @Since("1.6.0") + override def load(path: String): RegexTokenizer = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 086917fa680f8..4d3e46e488c67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol { + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Schema transformation. val schema = dataset.schema - lazy val first = dataset.first() + lazy val first = dataset.toDF.first() val attrs = $(inputCols).flatMap { c => val field = schema(c) val index = schema.fieldIndex(c) @@ -70,20 +71,22 @@ class VectorAssembler(override val uid: String) val group = AttributeGroup.fromStructField(field) if (group.attributes.isDefined) { // If attributes are defined, copy them with updated names. - group.attributes.get.map { attr => + group.attributes.get.zipWithIndex.map { case (attr, i) => if (attr.name.isDefined) { // TODO: Define a rigorous naming scheme. attr.withName(c + "_" + attr.name.get) } else { - attr + attr.withName(c + "_" + i) } } } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.fill(numAttrs)(NumericAttribute.defaultAttr) + Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) } + case otherType => + throw new SparkException(s"VectorAssembler does not support the $otherType type") } } val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() @@ -100,7 +103,7 @@ class VectorAssembler(override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata)) + dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { @@ -122,7 +125,11 @@ class VectorAssembler(override val uid: String) override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) } -private object VectorAssembler { +@Since("1.6.0") +object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + + @Since("1.6.0") + override def load(path: String): VectorAssembler = super.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 52e0599e38d83..68b699d569c7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,14 +22,16 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet @@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ @Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] - with VectorIndexerParams { + with VectorIndexerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecIdx")) @@ -106,12 +108,13 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): VectorIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): VectorIndexerModel = { transformSchema(dataset.schema, logging = true) val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") val numFeatures = firstRow(0).getAs[Vector](0).size - val vectorDataset = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val vectorDataset = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val maxCats = $(maxCategories) val categoryStats: VectorIndexer.CategoryStats = vectorDataset.mapPartitions { iter => val localCatStats = new VectorIndexer.CategoryStats(numFeatures, maxCats) @@ -136,7 +139,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } -private object VectorIndexer { +@Since("1.6.0") +object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + + @Since("1.6.0") + override def load(path: String): VectorIndexer = super.load(path) /** * Helper class for tracking unique values for each feature. @@ -146,7 +153,7 @@ private object VectorIndexer { * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. * @param maxCategories This class caps the number of unique values collected at maxCategories. */ - class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) extends Serializable { /** featureValueSets[feature index] = set of unique values */ @@ -252,7 +259,9 @@ class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) - extends Model[VectorIndexerModel] with VectorIndexerParams { + extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable { + + import VectorIndexerModel._ /** Java-friendly version of [[categoryMaps]] */ def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { @@ -337,7 +346,8 @@ class VectorIndexerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } @@ -408,4 +418,48 @@ class VectorIndexerModel private[ml] ( val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new VectorIndexerModelWriter(this) +} + +@Since("1.6.0") +object VectorIndexerModel extends MLReadable[VectorIndexerModel] { + + private[VectorIndexerModel] + class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter { + + private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.numFeatures, instance.categoryMaps) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] { + + private val className = classOf[VectorIndexerModel].getName + + override def load(path: String): VectorIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("numFeatures", "categoryMaps") + .head() + val numFeatures = data.getAs[Int](0) + val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) + val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader + + @Since("1.6.0") + override def load(path: String): VectorIndexerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index fb3387d4aa9be..7a9468b87b73e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -17,14 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StructType @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -89,12 +89,8 @@ final class VectorSlicer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def validateParams(): Unit = { - require($(indices).length > 0 || $(names).length > 0, - s"VectorSlicer requires that at least one feature be selected.") - } - - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Validity checks transformSchema(dataset.schema) val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) @@ -139,6 +135,8 @@ final class VectorSlicer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + require($(indices).length > 0 || $(names).length > 0, + s"VectorSlicer requires that at least one feature be selected.") SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) if (schema.fieldNames.contains($(outputCol))) { @@ -153,10 +151,11 @@ final class VectorSlicer(override val uid: String) override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) } -private[feature] object VectorSlicer { +@Since("1.6.0") +object VectorSlicer extends DefaultParamsReadable[VectorSlicer] { /** Return true if given feature indices are valid */ - def validIndices(indices: Array[Int]): Boolean = { + private[feature] def validIndices(indices: Array[Int]): Boolean = { if (indices.isEmpty) { true } else { @@ -165,7 +164,10 @@ private[feature] object VectorSlicer { } /** Return true if given feature names are valid */ - def validNames(names: Array[String]): Boolean = { + private[feature] def validNames(names: Array[String]): Boolean = { names.forall(_.nonEmpty) && names.length == names.distinct.length } + + @Since("1.6.0") + override def load(path: String): VectorSlicer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 9edab3af913ca..a72692960f928 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,18 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types._ /** @@ -49,6 +49,17 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getVectorSize: Int = $(vectorSize) + /** + * The window size (context words from [-window, window]) default 5. + * @group expertParam + */ + final val windowSize = new IntParam( + this, "windowSize", "the window size (context words from [-window, window])") + setDefault(windowSize -> 5) + + /** @group expertGetParam */ + def getWindowSize: Int = $(windowSize) + /** * Number of partitions for sentences of words. * Default: 1 @@ -92,7 +103,8 @@ private[feature] trait Word2VecBase extends Params * natural language processing or machine learning process. */ @Experimental -final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("w2v")) @@ -105,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setVectorSize(value: Int): this.type = set(vectorSize, value) + /** @group expertSetParam */ + def setWindowSize(value: Int): this.type = set(windowSize, value) + /** @group setParam */ def setStepSize(value: Double): this.type = set(stepSize, value) @@ -120,9 +135,10 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setMinCount(value: Int): this.type = set(minCount, value) - override def fit(dataset: DataFrame): Word2VecModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0)) + val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() .setLearningRate($(stepSize)) .setMinCount($(minCount)) @@ -130,6 +146,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setNumPartitions($(numPartitions)) .setSeed($(seed)) .setVectorSize($(vectorSize)) + .setWindowSize($(windowSize)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } @@ -141,6 +158,13 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } +@Since("1.6.0") +object Word2Vec extends DefaultParamsReadable[Word2Vec] { + + @Since("1.6.0") + override def load(path: String): Word2Vec = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[Word2Vec]]. @@ -148,9 +172,10 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] @Experimental class Word2VecModel private[ml] ( override val uid: String, - wordVectors: feature.Word2VecModel) - extends Model[Word2VecModel] with Word2VecBase { + @transient private val wordVectors: feature.Word2VecModel) + extends Model[Word2VecModel] with Word2VecBase with MLWritable { + import Word2VecModel._ /** * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and @@ -195,24 +220,26 @@ class Word2VecModel private[ml] ( * Transform a sentence column to a vector column to represent the whole sentence. The transform * is performed by averaging all word vectors it contains. */ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) + val vectors = wordVectors.getVectors + .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) + .map(identity) // mapValues doesn't return a serializable map (SI-7005) + val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors) + val d = $(vectorSize) val word2Vec = udf { sentence: Seq[String] => if (sentence.size == 0) { - Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double]) + Vectors.sparse(d, Array.empty[Int], Array.empty[Double]) } else { - val cum = Vectors.zeros($(vectorSize)) - val model = bWordVectors.value.getVectors - for (word <- sentence) { - if (model.contains(word)) { - axpy(1.0, bWordVectors.value.transform(word), cum) - } else { - // pass words which not belong to model + val sum = Vectors.zeros(d) + sentence.foreach { word => + bVectors.value.get(word).foreach { v => + BLAS.axpy(1.0, v, sum) } } - scal(1.0 / sentence.size, cum) - cum + BLAS.scal(1.0 / sentence.size, sum) + sum } } dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) @@ -226,4 +253,49 @@ class Word2VecModel private[ml] ( val copied = new Word2VecModel(uid, wordVectors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new Word2VecModelWriter(this) +} + +@Since("1.6.0") +object Word2VecModel extends MLReadable[Word2VecModel] { + + private[Word2VecModel] + class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { + + private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class Word2VecModelReader extends MLReader[Word2VecModel] { + + private val className = classOf[Word2VecModel].getName + + override def load(path: String): Word2VecModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("wordIndex", "wordVectors") + .head() + val wordIndex = data.getAs[Map[String, Int]](0) + val wordVectors = data.getAs[Seq[Float]](1).toArray + val oldModel = new feature.Word2VecModel(wordIndex, wordVectors) + val model = new Word2VecModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[Word2VecModel] = new Word2VecModelReader + + @Since("1.6.0") + override def load(path: String): Word2VecModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java index c22d2e0cd2d90..dcff4245d1d26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java @@ -22,8 +22,8 @@ * The `ml.feature` package provides common feature transformers that help convert raw data or * features into more suitable forms for model fitting. * Most feature transformers are implemented as {@link org.apache.spark.ml.Transformer}s, which - * transforms one {@link org.apache.spark.sql.DataFrame} into another, e.g., - * {@link org.apache.spark.feature.HashingTF}. + * transforms one {@link org.apache.spark.sql.Dataset} into another, e.g., + * {@link org.apache.spark.ml.feature.HashingTF}. * Some feature transformers are implemented as {@link org.apache.spark.ml.Estimator}}s, because the * transformation requires some aggregated information of the dataset, e.g., document * frequencies in {@link org.apache.spark.ml.feature.IDF}. @@ -31,7 +31,7 @@ * obtain the model first, e.g., {@link org.apache.spark.ml.feature.IDFModel}, in order to apply * transformation. * The transformation is usually done by appending new columns to the input - * {@link org.apache.spark.sql.DataFrame}, so all input columns are carried over. + * {@link org.apache.spark.sql.Dataset}, so all input columns are carried over. * * We try to make each transformer minimal, so it becomes flexible to assemble feature * transformation pipelines. @@ -46,7 +46,7 @@ * import org.apache.spark.api.java.JavaRDD; * import static org.apache.spark.sql.types.DataTypes.*; * import org.apache.spark.sql.types.StructType; - * import org.apache.spark.sql.DataFrame; + * import org.apache.spark.sql.Dataset; * import org.apache.spark.sql.RowFactory; * import org.apache.spark.sql.Row; * @@ -66,7 +66,7 @@ * RowFactory.create(0, "Hi I heard about Spark", 3.0), * RowFactory.create(1, "I wish Java could use case classes", 4.0), * RowFactory.create(2, "Logistic regression models are neat", 4.0))); - * DataFrame df = jsql.createDataFrame(rowRDD, schema); + * Dataset dataset = jsql.createDataFrame(rowRDD, schema); * // define feature transformers * RegexTokenizer tok = new RegexTokenizer() * .setInputCol("text") @@ -88,10 +88,10 @@ * // assemble and fit the feature transformation pipeline * Pipeline pipeline = new Pipeline() * .setStages(new PipelineStage[] {tok, sw, tf, idf, assembler}); - * PipelineModel model = pipeline.fit(df); + * PipelineModel model = pipeline.fit(dataset); * * // save transformed features with raw data - * model.transform(df) + * model.transform(dataset) * .select("id", "text", "rating", "features") * .write().format("parquet").save("/output/path"); *
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala new file mode 100644 index 0000000000000..a2b52835e177a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg._ +import org.apache.spark.rdd.RDD + +/** + * Model fitted by [[IterativelyReweightedLeastSquares]]. + * @param coefficients model coefficients + * @param intercept model intercept + * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration + * @param numIterations number of iterations + */ +private[ml] class IterativelyReweightedLeastSquaresModel( + val coefficients: DenseVector, + val intercept: Double, + val diagInvAtWA: DenseVector, + val numIterations: Int) extends Serializable + +/** + * Implements the method of iteratively reweighted least squares (IRLS) which is used to solve + * certain optimization problems by an iterative method. In each step of the iterations, it + * involves solving a weighted lease squares (WLS) problem by [[WeightedLeastSquares]]. + * It can be used to find maximum likelihood estimates of a generalized linear model (GLM), + * find M-estimator in robust regression and other optimization problems. + * + * @param initialModel the initial guess model. + * @param reweightFunc the reweight function which is used to update offsets and weights + * at each iteration. + * @param fitIntercept whether to fit intercept. + * @param regParam L2 regularization parameter used by WLS. + * @param maxIter maximum number of iterations. + * @param tol the convergence tolerance. + * + * @see [[http://www.jstor.org/stable/2345503 P. J. Green, Iteratively Reweighted Least Squares + * for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives, + * Journal of the Royal Statistical Society. Series B, 1984.]] + */ +private[ml] class IterativelyReweightedLeastSquares( + val initialModel: WeightedLeastSquaresModel, + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double), + val fitIntercept: Boolean, + val regParam: Double, + val maxIter: Int, + val tol: Double) extends Logging with Serializable { + + def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = { + + var converged = false + var iter = 0 + + var model: WeightedLeastSquaresModel = initialModel + var oldModel: WeightedLeastSquaresModel = null + + while (iter < maxIter && !converged) { + + oldModel = model + + // Update offsets and weights using reweightFunc + val newInstances = instances.map { instance => + val (newOffset, newWeight) = reweightFunc(instance, oldModel) + Instance(newOffset, newWeight, instance.features) + } + + // Estimate new model + model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false, + standardizeLabel = false).fit(newInstances) + + // Check convergence + val oldCoefficients = oldModel.coefficients + val coefficients = model.coefficients + BLAS.axpy(-1.0, coefficients, oldCoefficients) + val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) => + math.max(math.abs(x), math.abs(y)) + } + val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept)) + + if (maxTol < tol) { + converged = true + logInfo(s"IRLS converged in $iter iterations.") + } + + logInfo(s"Iteration $iter : relative tolerance = $maxTol") + iter = iter + 1 + + if (iter == maxIter) { + logInfo(s"IRLS reached the max number of iterations: $maxIter.") + } + + } + + new IterativelyReweightedLeastSquaresModel( + model.coefficients, model.intercept, model.diagInvAtWA, iter) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 8617722ae542f..7d21302f962bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.optim -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD @@ -31,7 +31,12 @@ import org.apache.spark.rdd.RDD private[ml] class WeightedLeastSquaresModel( val coefficients: DenseVector, val intercept: Double, - val diagInvAtWA: DenseVector) extends Serializable + val diagInvAtWA: DenseVector) extends Serializable { + + def predict(features: Vector): Double = { + BLAS.dot(coefficients, features) + intercept + } +} /** * Weighted least squares solver via normal equation. @@ -86,6 +91,24 @@ private[ml] class WeightedLeastSquares( val aaBar = summary.aaBar val aaValues = aaBar.values + if (bStd == 0) { + if (fitIntercept) { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + val coefficients = new DenseVector(Array.ofDim(k-1)) + val intercept = bBar + val diagInvAtWA = new DenseVector(Array(0D)) + return new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) + } else { + require(!(regParam > 0.0 && standardizeLabel), + "The standard deviation of the label is zero. " + + "Model cannot be regularized with standardization=true") + logWarning(s"The standard deviation of the label is zero. " + + "Consider setting fitIntercept=true.") + } + } + // add regularization to diagonals var i = 0 var j = 2 @@ -94,8 +117,7 @@ private[ml] class WeightedLeastSquares( if (standardizeFeatures) { lambda *= aVar(j - 2) } - if (standardizeLabel) { - // TODO: handle the case when bStd = 0 + if (standardizeLabel && bStd != 0) { lambda /= bStd } aaValues(i) += lambda @@ -134,6 +156,12 @@ private[ml] class WeightedLeastSquares( private[ml] object WeightedLeastSquares { + /** + * In order to take the normal equation approach efficiently, [[WeightedLeastSquares]] + * only supports the number of features is no more than 4096. + */ + val MAX_NUM_FEATURES: Int = 4096 + /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ @@ -152,8 +180,8 @@ private[ml] object WeightedLeastSquares { private var aaSum: DenseVector = _ private def init(k: Int): Unit = { - require(k <= 4096, "In order to take the normal equation approach efficiently, " + - s"we set the max number of features to 4096 but got $k.") + require(k <= MAX_NUM_FEATURES, "In order to take the normal equation approach efficiently, " + + s"we set the max number of features to $MAX_NUM_FEATURES but got $k.") this.k = k triK = k * (k + 1) / 2 count = 0L diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 8361406f87299..c368aadd23669 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.param import java.lang.reflect.Modifier +import java.util.{List => JList} import java.util.NoSuchElementException import scala.annotation.varargs @@ -27,8 +28,9 @@ import scala.collection.JavaConverters._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: @@ -57,9 +59,8 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali /** * Assert that the given value is valid for this parameter. * - * Note: Parameter checks involving interactions between multiple parameters should be - * implemented in [[Params.validateParams()]]. Checks for input/output columns should be - * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]]. + * Note: Parameter checks involving interactions between multiple parameters and input/output + * columns should be implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]]. * * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters * should be specified via [[ParamPair]]. @@ -81,33 +82,30 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali def w(value: T): ParamPair[T] = this -> value /** Creates a param pair with the given value (for Scala). */ + // scalastyle:off def ->(value: T): ParamPair[T] = ParamPair(this, value) + // scalastyle:on /** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */ def jsonEncode(value: T): String = { value match { case x: String => compact(render(JString(x))) + case v: Vector => + v.toJson case _ => throw new NotImplementedError( - "The default jsonEncode only supports string. " + + "The default jsonEncode only supports string and vector. " + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") } } /** Decodes a param value from JSON. */ - def jsonDecode(json: String): T = { - parse(json) match { - case JString(x) => - x.asInstanceOf[T] - case _ => - throw new NotImplementedError( - "The default jsonDecode only supports string. " + - s"${this.getClass.getName} must override jsonDecode to support its value type.") - } - } + def jsonDecode(json: String): T = Param.jsonDecode[T](json) + + private[this] val stringRepresentation = s"${parent}__$name" - override final def toString: String = s"${parent}__$name" + override final def toString: String = stringRepresentation override final def hashCode: Int = toString.## @@ -119,6 +117,26 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } +private[ml] object Param { + + /** Decodes a param value from JSON. */ + def jsonDecode[T](json: String): T = { + parse(json) match { + case JString(x) => + x.asInstanceOf[T] + case JObject(v) => + val keys = v.map(_._1) + assert(keys.contains("type") && keys.contains("values"), + s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.") + Vectors.fromJson(json).asInstanceOf[T] + case _ => + throw new NotImplementedError( + "The default jsonDecode only supports string and vector. " + + s"${this.getClass.getName} must override jsonDecode to support its value type.") + } + } +} + /** * :: DeveloperApi :: * Factory methods for common validation functions for [[Param.isValid]]. @@ -494,8 +512,11 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In * :: Experimental :: * A param and its value. */ +@Since("1.2.0") @Experimental -case class ParamPair[T](param: Param[T], value: T) { +case class ParamPair[T] @Since("1.2.0") ( + @Since("1.2.0") param: Param[T], + @Since("1.2.0") value: T) { // This is *the* place Param.validate is called. Whenever a parameter is specified, we should // always construct a ParamPair so that validate is called. param.validate(value) @@ -534,7 +555,9 @@ trait Params extends Identifiable with Serializable { * Parameter value checks which do not depend on other parameters are handled by * [[Param.validate()]]. This method does not handle input/output column parameters; * those are checked during schema validation. + * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema */ + @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0") def validateParams(): Unit = { // Do nothing by default. Override to handle Param interactions. } @@ -592,7 +615,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - protected final def set[T](param: Param[T], value: T): this.type = { + final def set[T](param: Param[T], value: T): this.type = { set(param -> value) } @@ -776,6 +799,7 @@ abstract class JavaParams extends Params * :: Experimental :: * A param to value map. */ +@Since("1.2.0") @Experimental final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { @@ -789,17 +813,20 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Creates an empty param map. */ + @Since("1.2.0") def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). */ + @Since("1.2.0") def put[T](param: Param[T], value: T): this.type = put(param -> value) /** * Puts a list of param pairs (overwrites if the input params exists). */ @varargs + @Since("1.2.0") def put(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => map(p.param.asInstanceOf[Param[Any]]) = p.value @@ -807,9 +834,15 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) this } + /** Put param pairs with a [[java.util.List]] of values for Python. */ + private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = { + put(paramPairs.asScala: _*) + } + /** * Optionally returns the value associated with a param. */ + @Since("1.2.0") def get[T](param: Param[T]): Option[T] = { map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] } @@ -817,6 +850,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Returns the value associated with a param or a default value. */ + @Since("1.4.0") def getOrElse[T](param: Param[T], default: T): T = { get(param).getOrElse(default) } @@ -825,6 +859,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Gets the value of the input param or its default value if it does not exist. * Raises a NoSuchElementException if there is no value associated with the input param. */ + @Since("1.2.0") def apply[T](param: Param[T]): T = { get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") @@ -834,6 +869,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Checks whether a parameter is explicitly specified. */ + @Since("1.2.0") def contains(param: Param[_]): Boolean = { map.contains(param.asInstanceOf[Param[Any]]) } @@ -841,6 +877,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Removes a key from this map and returns its value associated previously as an option. */ + @Since("1.4.0") def remove[T](param: Param[T]): Option[T] = { map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] } @@ -848,16 +885,23 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Filters this param map for the given parent. */ + @Since("1.2.0") def filter(parent: Params): ParamMap = { - val filtered = map.filterKeys(_.parent == parent) - new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + // Don't use filterKeys because mutable.Map#filterKeys + // returns the instance of collections.Map, not mutable.Map. + // Otherwise, we get ClassCastException. + // Not using filterKeys also avoid SI-6654 + val filtered = map.filter { case (k, _) => k.parent == parent.uid } + new ParamMap(filtered) } /** * Creates a copy of this param map. */ + @Since("1.2.0") def copy: ParamMap = new ParamMap(map.clone()) + @Since("1.2.0") override def toString: String = { map.toSeq.sortBy(_._1.name).map { case (param, value) => s"\t${param.parent}-${param.name}: $value" @@ -868,6 +912,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Returns a new param map that contains parameters in this map and the given map, * where the latter overwrites this if there exist conflicts. */ + @Since("1.2.0") def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. new ParamMap(this.map ++ other.map) @@ -876,6 +921,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Adds all parameters from the input param map into this param map. */ + @Since("1.2.0") def ++=(other: ParamMap): this.type = { // TODO: Provide a better method name for Java users. this.map ++= other.map @@ -885,30 +931,40 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Converts this param map to a sequence of param pairs. */ + @Since("1.2.0") def toSeq: Seq[ParamPair[_]] = { map.toSeq.map { case (param, value) => ParamPair(param, value) } } + /** Java-friendly method for Python API */ + private[ml] def toList: java.util.List[ParamPair[_]] = { + this.toSeq.asJava + } + /** * Number of param pairs in this map. */ + @Since("1.3.0") def size: Int = map.size } +@Since("1.2.0") @Experimental object ParamMap { /** * Returns an empty param map. */ + @Since("1.2.0") def empty: ParamMap = new ParamMap() /** * Constructs a param map by specifying its entries. */ @varargs + @Since("1.2.0") def apply(paramPairs: ParamPair[_]*): ParamMap = { new ParamMap().put(paramPairs: _*) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c7bca1243092c..1d03a5b4f4048 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -44,6 +44,7 @@ private[shared] object SharedParamsCodeGen { " probabilities. Note: Not all models output well-calibrated probability estimates!" + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), + ParamDesc[String]("varianceCol", "Column name for the biased sample variance of prediction"), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), @@ -51,7 +52,7 @@ private[shared] object SharedParamsCodeGen { " to adjust the probability of predicting each class." + " Array must have length equal to the number of classes, with values >= 0." + " The class with largest value p/t is predicted, where p is the original probability" + - " of that class and t is the class' threshold.", + " of that class and t is the class' threshold", isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), @@ -61,8 +62,8 @@ private[shared] object SharedParamsCodeGen { "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + - "will filter out rows with bad values), or error (which will throw an errror). More " + - "options may be added later.", + "will filter out rows with bad values), or error (which will throw an error). More " + + "options may be added later", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model", Some("true")), @@ -71,11 +72,11 @@ private[shared] object SharedParamsCodeGen { " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), - ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization"), ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0."), + "all instance weights as 1.0"), ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " + - "empty, default value is 'auto'.", Some("\"auto\""))) + "empty, default value is 'auto'", Some("\"auto\""))) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index cb2a060a34dd6..64d6af2766ca9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -138,6 +138,21 @@ private[ml] trait HasProbabilityCol extends Params { final def getProbabilityCol: String = $(probabilityCol) } +/** + * Trait for shared param varianceCol. + */ +private[ml] trait HasVarianceCol extends Params { + + /** + * Param for Column name for the biased sample variance of prediction. + * @group param + */ + final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the biased sample variance of prediction") + + /** @group getParam */ + final def getVarianceCol: String = $(varianceCol) +} + /** * Trait for shared param threshold (default: 0.5). */ @@ -161,10 +176,10 @@ private[ml] trait HasThreshold extends Params { private[ml] trait HasThresholds extends Params { /** - * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.. + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. * @group param */ - final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0)) + final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0)) /** @group getParam */ def getThresholds: Array[Double] = $(thresholds) @@ -255,10 +270,10 @@ private[ml] trait HasFitIntercept extends Params { private[ml] trait HasHandleInvalid extends Params { /** - * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later. * @group param */ - final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators.inArray(Array("skip", "error"))) /** @group getParam */ final def getHandleInvalid: String = $(handleInvalid) @@ -334,10 +349,10 @@ private[ml] trait HasTol extends Params { private[ml] trait HasStepSize extends Params { /** - * Param for Step size to be used for each iteration of optimization.. + * Param for Step size to be used for each iteration of optimization. * @group param */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.") + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization") /** @group getParam */ final def getStepSize: Double = $(stepSize) @@ -349,10 +364,10 @@ private[ml] trait HasStepSize extends Params { private[ml] trait HasWeightCol extends Params { /** - * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.. + * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. * @group param */ - final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0") /** @group getParam */ final def getWeightCol: String = $(weightCol) @@ -364,10 +379,10 @@ private[ml] trait HasWeightCol extends Params { private[ml] trait HasSolver extends Params { /** - * Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.. + * Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. * @group param */ - final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'") setDefault(solver, "auto") diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala new file mode 100644 index 0000000000000..783546862689a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.spark.SparkException +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel} +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class AFTSurvivalRegressionWrapper private ( + pipeline: PipelineModel, + features: Array[String]) { + + private val aftModel: AFTSurvivalRegressionModel = + pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel] + + lazy val rCoefficients: Array[Double] = if (aftModel.getFitIntercept) { + Array(aftModel.intercept) ++ aftModel.coefficients.toArray ++ Array(math.log(aftModel.scale)) + } else { + aftModel.coefficients.toArray ++ Array(math.log(aftModel.scale)) + } + + lazy val rFeatures: Array[String] = if (aftModel.getFitIntercept) { + Array("(Intercept)") ++ features ++ Array("Log(scale)") + } else { + features ++ Array("Log(scale)") + } + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(aftModel.getFeaturesCol) + } +} + +private[r] object AFTSurvivalRegressionWrapper { + + private def formulaRewrite(formula: String): (String, String) = { + var rewritedFormula: String = null + var censorCol: String = null + + val regex = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r + try { + val regex(label, censor, features) = formula + // TODO: Support dot operator. + if (features.contains(".")) { + throw new UnsupportedOperationException( + "Terms of survreg formula can not support dot operator.") + } + rewritedFormula = label.trim + "~" + features.trim + censorCol = censor.trim + } catch { + case e: MatchError => + throw new SparkException(s"Could not parse formula: $formula") + } + + (rewritedFormula, censorCol) + } + + + def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = { + + val (rewritedFormula, censorCol) = formulaRewrite(formula) + + val rFormula = new RFormula().setFormula(rewritedFormula) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + val aft = new AFTSurvivalRegression() + .setCensorCol(censorCol) + .setFitIntercept(rFormula.hasIntercept) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, aft)) + .fit(data) + + new AFTSurvivalRegressionWrapper(pipeline, features) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala new file mode 100644 index 0000000000000..475a3083854e8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.regression._ +import org.apache.spark.sql._ + +private[r] class GeneralizedLinearRegressionWrapper private ( + pipeline: PipelineModel, + val features: Array[String]) { + + private val glm: GeneralizedLinearRegressionModel = + pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] + + lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) { + Array(glm.intercept) ++ glm.coefficients.toArray + } else { + glm.coefficients.toArray + } + + lazy val rFeatures: Array[String] = if (glm.getFitIntercept) { + Array("(Intercept)") ++ features + } else { + features + } + + def transform(dataset: DataFrame): DataFrame = { + pipeline.transform(dataset).drop(glm.getFeaturesCol) + } +} + +private[r] object GeneralizedLinearRegressionWrapper { + + def fit( + formula: String, + data: DataFrame, + family: String, + link: String, + epsilon: Double, + maxit: Int): GeneralizedLinearRegressionWrapper = { + val rFormula = new RFormula() + .setFormula(formula) + val rFormulaModel = rFormula.fit(data) + // get labels and feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + // assemble and fit the pipeline + val glm = new GeneralizedLinearRegression() + .setFamily(family) + .setLink(link) + .setFitIntercept(rFormula.hasIntercept) + .setTol(epsilon) + .setMaxIter(maxit) + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, glm)) + .fit(data) + new GeneralizedLinearRegressionWrapper(pipeline, features) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala new file mode 100644 index 0000000000000..9e2b81ee20147 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.clustering.{KMeans, KMeansModel} +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class KMeansWrapper private ( + pipeline: PipelineModel) { + + private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel] + + lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray) + + private lazy val attrs = AttributeGroup.fromStructField( + kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol)) + + lazy val features: Array[String] = attrs.attributes.get.map(_.name.get) + + lazy val k: Int = kMeansModel.getK + + lazy val size: Array[Long] = kMeansModel.summary.clusterSizes + + lazy val cluster: DataFrame = kMeansModel.summary.cluster + + def fitted(method: String): DataFrame = { + if (method == "centers") { + kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol) + } else if (method == "classes") { + kMeansModel.summary.cluster + } else { + throw new UnsupportedOperationException( + s"Method (centers or classes) required but $method found.") + } + } + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol) + } + +} + +private[r] object KMeansWrapper { + + def fit( + data: DataFrame, + k: Double, + maxIter: Double, + initMode: String, + columns: Array[String]): KMeansWrapper = { + + val assembler = new VectorAssembler() + .setInputCols(columns) + .setOutputCol("features") + + val kMeans = new KMeans() + .setK(k.toInt) + .setMaxIter(maxIter.toInt) + .setInitMode(initMode) + + val pipeline = new Pipeline() + .setStages(Array(assembler, kMeans)) + .fit(data) + + new KMeansWrapper(pipeline) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala new file mode 100644 index 0000000000000..b17207e99bb85 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} +import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class NaiveBayesWrapper private ( + pipeline: PipelineModel, + val labels: Array[String], + val features: Array[String]) { + + import NaiveBayesWrapper._ + + private val naiveBayesModel: NaiveBayesModel = pipeline.stages(1).asInstanceOf[NaiveBayesModel] + + lazy val apriori: Array[Double] = naiveBayesModel.pi.toArray.map(math.exp) + + lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp) + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(naiveBayesModel.getFeaturesCol) + } +} + +private[r] object NaiveBayesWrapper { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + + def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = { + val rFormula = new RFormula() + .setFormula(formula) + .fit(data) + // get labels and feature names from output schema + val schema = rFormula.transform(data).schema + val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + // assemble and fit the pipeline + val naiveBayes = new NaiveBayes() + .setSmoothing(laplace) + .setModelType("bernoulli") + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + val pipeline = new Pipeline() + .setStages(Array(rFormula, naiveBayes, idxToStr)) + .fit(data) + new NaiveBayesWrapper(pipeline, labels, features) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala deleted file mode 100644 index 5be2f86936211..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.api.r - -import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} -import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} -import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.sql.DataFrame - -private[r] object SparkRWrappers { - def fitRModelFormula( - value: String, - df: DataFrame, - family: String, - lambda: Double, - alpha: Double, - standardize: Boolean, - solver: String): PipelineModel = { - val formula = new RFormula().setFormula(value) - val estimator = family match { - case "gaussian" => new LinearRegression() - .setRegParam(lambda) - .setElasticNetParam(alpha) - .setFitIntercept(formula.hasIntercept) - .setStandardization(standardize) - .setSolver(solver) - case "binomial" => new LogisticRegression() - .setRegParam(lambda) - .setElasticNetParam(alpha) - .setFitIntercept(formula.hasIntercept) - .setStandardization(standardize) - } - val pipeline = new Pipeline().setStages(Array(formula, estimator)) - pipeline.fit(df) - } - - def getModelCoefficients(model: PipelineModel): Array[Double] = { - model.stages.last match { - case m: LinearRegressionModel => - Array(m.intercept) ++ m.coefficients.toArray - case m: LogisticRegressionModel => - Array(m.intercept) ++ m.coefficients.toArray - } - } - - def getModelFeatures(model: PipelineModel): Array[String] = { - model.stages.last match { - case m: LinearRegressionModel => - val attrs = AttributeGroup.fromStructField( - m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - case m: LogisticRegressionModel => - val attrs = AttributeGroup.fromStructField( - m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 535f266b9a944..36dce015908eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,17 +27,20 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ -import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.Partitioner +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -177,23 +180,28 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ @Experimental +@Since("1.3.0") class ALSModel private[ml] ( - override val uid: String, - val rank: Int, + @Since("1.4.0") override val uid: String, + @Since("1.4.0") val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams { + extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ + @Since("1.4.0") def setUserCol(value: String): this.type = set(userCol, value) /** @group setParam */ + @Since("1.4.0") def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ + @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => @@ -210,18 +218,65 @@ class ALSModel private[ml] ( predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } + @Since("1.3.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } + @Since("1.5.0") override def copy(extra: ParamMap): ALSModel = { val copied = new ALSModel(uid, rank, userFactors, itemFactors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new ALSModel.ALSModelWriter(this) } +@Since("1.6.0") +object ALSModel extends MLReadable[ALSModel] { + + @Since("1.6.0") + override def read: MLReader[ALSModel] = new ALSModelReader + + @Since("1.6.0") + override def load(path: String): ALSModel = super.load(path) + + private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata = "rank" -> instance.rank + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val userPath = new Path(path, "userFactors").toString + instance.userFactors.write.format("parquet").save(userPath) + val itemPath = new Path(path, "itemFactors").toString + instance.itemFactors.write.format("parquet").save(itemPath) + } + } + + private class ALSModelReader extends MLReader[ALSModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[ALSModel].getName + + override def load(path: String): ALSModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + implicit val format = DefaultFormats + val rank = (metadata.metadata \ "rank").extract[Int] + val userPath = new Path(path, "userFactors").toString + val userFactors = sqlContext.read.format("parquet").load(userPath) + val itemPath = new Path(path, "itemFactors").toString + val itemFactors = sqlContext.read.format("parquet").load(itemPath) + + val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} /** * :: Experimental :: @@ -254,69 +309,89 @@ class ALSModel private[ml] ( * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { +@Since("1.3.0") +class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams + with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating + @Since("1.4.0") def this() = this(Identifiable.randomUID("als")) /** @group setParam */ + @Since("1.3.0") def setRank(value: Int): this.type = set(rank, value) /** @group setParam */ + @Since("1.3.0") def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value) /** @group setParam */ + @Since("1.3.0") def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value) /** @group setParam */ + @Since("1.3.0") def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value) /** @group setParam */ + @Since("1.3.0") def setAlpha(value: Double): this.type = set(alpha, value) /** @group setParam */ + @Since("1.3.0") def setUserCol(value: String): this.type = set(userCol, value) /** @group setParam */ + @Since("1.3.0") def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ + @Since("1.3.0") def setRatingCol(value: String): this.type = set(ratingCol, value) /** @group setParam */ + @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.3.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ + @Since("1.3.0") def setRegParam(value: Double): this.type = set(regParam, value) /** @group setParam */ + @Since("1.3.0") def setNonnegative(value: Boolean): this.type = set(nonnegative, value) /** @group setParam */ + @Since("1.4.0") def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) /** @group setParam */ + @Since("1.3.0") def setSeed(value: Long): this.type = set(seed, value) /** * Sets both numUserBlocks and numItemBlocks to the specific value. * @group setParam */ + @Since("1.3.0") def setNumBlocks(value: Int): this.type = { setNumUserBlocks(value) setNumItemBlocks(value) this } - override def fit(dataset: DataFrame): ALSModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ALSModel = { import dataset.sqlContext.implicits._ val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) + .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } @@ -331,13 +406,16 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { copyValues(model) } + @Since("1.3.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): ALS = defaultCopy(extra) } + /** * :: DeveloperApi :: * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is @@ -347,7 +425,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { * than 2 billion. */ @DeveloperApi -object ALS extends Logging { +object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: @@ -356,6 +434,9 @@ object ALS extends Logging { @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + @Since("1.6.0") + override def load(path: String): ALS = super.load(path) + /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { /** Solves a least squares problem with regularization (possibly with other constraints). */ @@ -415,7 +496,7 @@ object ALS extends Logging { } /** - * Solves a nonnegative least squares problem with L2 regularizatin: + * Solves a nonnegative least squares problem with L2 regularization: * * min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^ * subject to x >= 0 @@ -1219,8 +1300,8 @@ object ALS extends Logging { } /** - * Partitioner used by ALS. We requires that getPartition is a projection. That is, for any key k, - * we have getPartition(getPartition(k)) = getPartition(k). Since the the default HashPartitioner + * Partitioner used by ALS. We require that getPartition is a projection. That is, for any key k, + * we have getPartition(getPartition(k)) = getPartition(k). Since the default HashPartitioner * satisfies this requirement, we simply use a type alias here. */ private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index b7d095872ffa5..89ba6ab5d2772 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -21,17 +21,19 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} +import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -102,7 +104,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) } if (hasQuantilesCol) { SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) @@ -120,7 +122,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params @Experimental @Since("1.6.0") class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) - extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams + with DefaultParamsWritable with Logging { @Since("1.6.0") def this() = this(Identifiable.randomUID("aftSurvReg")) @@ -181,24 +184,35 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. */ - protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { - dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { - case Row(features: Vector, label: Double, censor: Double) => - AFTPoint(features, label, censor) - } + protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = { + dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) + .rdd.map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } } - @Since("1.6.0") - override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val costFun = new AFTCostFun(instances, $(fitIntercept)) + val featuresSummarizer = { + val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features) + val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => { + c1.merge(c2) + } + instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + + val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + val numFeatures = featuresStd.size /* The parameters vector has three parts: the first element: Double, log(sigma), the log of scale parameter @@ -227,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (handlePersistence) instances.unpersist() - val coefficients = Vectors.dense(parameters.slice(2, parameters.length)) + val rawCoefficients = parameters.slice(2, parameters.length) + var i = 0 + while (i < numFeatures) { + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + i += 1 + } + val coefficients = Vectors.dense(rawCoefficients) val intercept = parameters(1) val scale = math.exp(parameters(0)) val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) @@ -243,6 +263,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) } +@Since("1.6.0") +object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] { + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[AFTSurvivalRegression]]. @@ -254,7 +281,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") val coefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) - extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable { /** @group setParam */ @Since("1.6.0") @@ -289,8 +316,8 @@ class AFTSurvivalRegressionModel private[ml] ( math.exp(BLAS.dot(coefficients, features) + intercept) } - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} @@ -312,6 +339,58 @@ class AFTSurvivalRegressionModel private[ml] ( copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) .setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = + new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this) +} + +@Since("1.6.0") +object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */ + private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter ( + instance: AFTSurvivalRegressionModel + ) extends MLWriter with Logging { + + private case class Data(coefficients: Vector, intercept: Double, scale: Double) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: coefficients, intercept, scale + val data = Data(instance.coefficients, instance.intercept, instance.scale) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[AFTSurvivalRegressionModel].getName + + override def load(path: String): AFTSurvivalRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("coefficients", "intercept", "scale").head() + val coefficients = data.getAs[Vector](0) + val intercept = data.getDouble(1) + val scale = data.getDouble(2) + val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -372,27 +451,36 @@ class AFTSurvivalRegressionModel private[ml] ( * @param parameters including three part: The log of scale parameter, the intercept and * regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. + * @param featuresStd The standard deviation values of the features. */ -private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) - extends Serializable { - - // beta is the intercept and regression coefficients to the covariates - private val beta = parameters.slice(1, parameters.length) +private class AFTAggregator( + parameters: BDV[Double], + fitIntercept: Boolean, + featuresStd: Array[Double]) extends Serializable { + + // the regression coefficients to the covariates + private val coefficients = parameters.slice(2, parameters.length) + private val intercept = parameters(1) // sigma is the scale parameter of the AFT model private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 - private var gradientBetaSum = BDV.zeros[Double](beta.length) - private var gradientLogSigmaSum = 0.0 + // Here we optimize loss function over log(sigma), intercept and coefficients + private val gradientSumArray = Array.ofDim[Double](parameters.length) def count: Long = totalCnt + def loss: Double = { + require(totalCnt > 0.0, s"The number of instances should be " + + s"greater than 0.0, but got $totalCnt.") + lossSum / totalCnt + } + def gradient: BDV[Double] = { + require(totalCnt > 0.0, s"The number of instances should be " + + s"greater than 0.0, but got $totalCnt.") + new BDV(gradientSumArray.map(_ / totalCnt.toDouble)) + } - def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt - - // Here we optimize loss function over beta and log(sigma) - def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), - gradientBetaSum/totalCnt.toDouble) /** * Add a new training data to this AFTAggregator, and update the loss and gradient @@ -402,26 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * @return This AFTAggregator object. */ def add(data: AFTPoint): this.type = { - - // TODO: Don't create a new xi vector each time. - val xi = if (fitIntercept) { - Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze - } else { - Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze - } + val xi = data.features val ti = data.label val delta = data.censor - val epsilon = (math.log(ti) - beta.dot(xi)) / sigma - lossSum += math.log(sigma) * delta - lossSum += (math.exp(epsilon) - delta * epsilon) + val margin = { + var sum = 0.0 + xi.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + sum += coefficients(index) * (value / featuresStd(index)) + } + } + sum + intercept + } + val epsilon = (math.log(ti) - margin) / sigma - // Sanity check (should never occur): - assert(!lossSum.isInfinity, - s"AFTAggregator loss sum is infinity. Error for unknown reason.") + lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon) - gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma - gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + val multiplier = (delta - math.exp(epsilon)) / sigma + + gradientSumArray(0) += delta + multiplier * sigma * epsilon + gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 } + xi.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + gradientSumArray(index + 2) += multiplier * (value / featuresStd(index)) + } + } totalCnt += 1 this @@ -440,8 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) totalCnt += other.totalCnt lossSum += other.lossSum - gradientBetaSum += other.gradientBetaSum - gradientLogSigmaSum += other.gradientLogSigmaSum + var i = 0 + val len = this.gradientSumArray.length + while (i < len) { + this.gradientSumArray(i) += other.gradientSumArray(i) + i += 1 + } } this } @@ -452,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * It returns the loss and gradient at a particular point (parameters). * It's used in Breeze's convex optimization routines. */ -private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) - extends DiffFunction[BDV[Double]] { +private class AFTCostFun( + data: RDD[AFTPoint], + fitIntercept: Boolean, + featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { - val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))( + val aftAggregator = data.treeAggregate( + new AFTAggregator(parameters, fitIntercept, featuresStd))( seqOp = (c, v) => (c, v) match { case (aggregator, instance) => aggregator.add(instance) }, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 04420fc6e8251..c04c416aaf19c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -17,18 +17,25 @@ package org.apache.spark.ml.regression +import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.functions._ + /** * :: Experimental :: @@ -40,7 +47,7 @@ import org.apache.spark.sql.DataFrame @Experimental final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams with TreeRegressorParams { + with DecisionTreeRegressorParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) @@ -71,13 +78,26 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) - override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { + override def setSeed(value: Long): this.type = super.setSeed(value) + + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + + override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) + trees.head.asInstanceOf[DecisionTreeRegressionModel] + } + + /** (private[ml]) Train a decision tree on an RDD */ + private[ml] def train(data: RDD[LabeledPoint], + oldStrategy: OldStrategy): DecisionTreeRegressionModel = { + val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] } @@ -93,9 +113,12 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") @Experimental -object DecisionTreeRegressor { +object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] { /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + + @Since("2.0.0") + override def load(path: String): DecisionTreeRegressor = super.load(path) } /** @@ -111,10 +134,13 @@ final class DecisionTreeRegressionModel private[ml] ( override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] - with DecisionTreeModel with Serializable { + with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { + + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) require(rootNode != null, - "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.") /** * Construct a decision tree regression model. @@ -127,6 +153,30 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).prediction } + /** We need to update this function if we ever add other impurity measures. */ + protected def predictVariance(features: Vector): Double = { + rootNode.predictImpl(features).impurityStats.calculate() + } + + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + transformImpl(dataset) + } + + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { + val predictUDF = udf { (features: Vector) => predict(features) } + val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } + var output = dataset.toDF + if ($(predictionCol).nonEmpty) { + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + } + output + } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressionModel = { copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) @@ -137,16 +187,78 @@ final class DecisionTreeRegressionModel private[ml] ( s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes" } - /** Convert to a model in the old API */ - private[ml] def toOld: OldDecisionTreeModel = { + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * + * Note: Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestRegressor]] + * to determine feature importance instead. + */ + @Since("2.0.0") + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) + + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ + override private[spark] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) } + + @Since("2.0.0") + override def write: MLWriter = + new DecisionTreeRegressionModel.DecisionTreeRegressionModelWriter(this) } -private[ml] object DecisionTreeRegressionModel { +@Since("2.0.0") +object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[DecisionTreeRegressionModel] = + new DecisionTreeRegressionModelReader + + @Since("2.0.0") + override def load(path: String): DecisionTreeRegressionModel = super.load(path) + + private[DecisionTreeRegressionModel] + class DecisionTreeRegressionModelWriter(instance: DecisionTreeRegressionModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val (nodeData, _) = NodeData.build(instance.rootNode, 0) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(nodeData).write.parquet(dataPath) + } + } + + private class DecisionTreeRegressionModelReader + extends MLReader[DecisionTreeRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): DecisionTreeRegressionModel = { + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val root = loadTreeNodes(path, metadata, sqlContext) + val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeRegressor, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 07144cc7cfbd7..741724d7a1045 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -18,35 +18,48 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ -import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.GradientBoostedTrees +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for regression. * It supports both continuous and categorical features. + * + * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - When the loss is SquaredError, these methods give the same result, but they could differ + * for other loss functions. + * - We expect to implement TreeBoost in the future: + * [https://issues.apache.org/jira/browse/SPARK-4240] */ @Since("1.4.0") @Experimental final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] - with GBTParams with TreeRegressorParams with Logging { + with GBTRegressorParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtr")) @@ -91,10 +104,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) @Since("1.4.0") - override def setSeed(value: Long): this.type = { - logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") - super.setSeed(value) - } + override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from GBTParams: @Since("1.4.0") @@ -103,50 +113,21 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTRegressor: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "squared" (L2) and "absolute" (L1) - * (default = squared) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "squared") + // Parameters from GBTRegressorParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "squared" => OldSquaredError - case "absolute" => OldAbsoluteError - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") - } - } - - override protected def train(dataset: DataFrame): GBTRegressionModel = { + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val oldGBT = new OldGBT(boostingStrategy) - val oldModel = oldGBT.run(oldDataset) - GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) + val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, + $(seed)) + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } @Since("1.4.0") @@ -155,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") @Experimental -object GBTRegressor { - // The losses below should be lowercase. +object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTRegressor = super.load(path) } /** @@ -179,9 +163,10 @@ final class GBTRegressionModel private[ml]( private val _treeWeights: Array[Double], override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] - with TreeEnsembleModel with Serializable { + with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { - require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.") + require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") @@ -195,12 +180,12 @@ final class GBTRegressionModel private[ml]( this(uid, _trees, _treeWeights, -1) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -215,6 +200,9 @@ final class GBTRegressionModel private[ml]( blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): GBTRegressionModel = { copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures), @@ -226,16 +214,81 @@ final class GBTRegressionModel private[ml]( s"GBTRegressionModel (uid=$uid) with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * Each feature's importance is the average of its importance across all trees in the ensemble + * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + * and follows the implementation from scikit-learn. + * + * @see [[DecisionTreeRegressionModel.featureImportances]] + */ + @Since("2.0.0") + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } -private[ml] object GBTRegressionModel { +@Since("2.0.0") +object GBTRegressionModel extends MLReadable[GBTRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GBTRegressionModel = super.load(path) + + private[GBTRegressionModel] + class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + + require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, categoricalFeatures: Map[Int, Int], @@ -247,6 +300,6 @@ private[ml] object GBTRegressionModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") - new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) + new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala new file mode 100644 index 0000000000000..e92a3e7fa1f0c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -0,0 +1,983 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import breeze.stats.{distributions => dist} +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.optim._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + +/** + * Params for Generalized Linear Regression. + */ +private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams + with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol + with HasSolver with Logging { + + /** + * Param for the name of family which is a description of the error distribution + * to be used in the model. + * Supported options: "gaussian", "binomial", "poisson" and "gamma". + * Default is "gaussian". + * + * @group param + */ + @Since("2.0.0") + final val family: Param[String] = new Param(this, "family", + "The name of family which is a description of the error distribution to be used in the " + + "model. Supported options: gaussian(default), binomial, poisson and gamma.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getFamily: String = $(family) + + /** + * Param for the name of link function which provides the relationship + * between the linear predictor and the mean of the distribution function. + * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt". + * + * @group param + */ + @Since("2.0.0") + final val link: Param[String] = new Param(this, "link", "The name of link function " + + "which provides the relationship between the linear predictor and the mean of the " + + "distribution function. Supported options: identity, log, inverse, logit, probit, " + + "cloglog and sqrt.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getLink: String = $(link) + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + override def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + if ($(solver) == "irls") { + setDefault(maxIter -> 25) + } + if (isDefined(link)) { + require(supportedFamilyAndLinkPairs.contains( + Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + + s"with ${$(family)} family does not support ${$(link)} link function.") + } + super.validateAndTransformSchema(schema, fitting, featuresDataType) + } +} + +/** + * :: Experimental :: + * + * Fit a Generalized Linear Model ([[https://en.wikipedia.org/wiki/Generalized_linear_model]]) + * specified by giving a symbolic description of the linear predictor (link function) and + * a description of the error distribution (family). + * It supports "gaussian", "binomial", "poisson" and "gamma" as family. + * Valid link functions for each family is listed below. The first link function of each family + * is the default one. + * - "gaussian" -> "identity", "log", "inverse" + * - "binomial" -> "logit", "probit", "cloglog" + * - "poisson" -> "log", "identity", "sqrt" + * - "gamma" -> "inverse", "identity", "log" + */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String) + extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging { + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("glm")) + + /** + * Sets the value of param [[family]]. + * Default is "gaussian". + * @group setParam + */ + @Since("2.0.0") + def setFamily(value: String): this.type = set(family, value) + setDefault(family -> Gaussian.name) + + /** + * Sets the value of param [[link]]. + * @group setParam + */ + @Since("2.0.0") + def setLink(value: String): this.type = set(link, value) + + /** + * Sets if we should fit the intercept. + * Default is true. + * @group setParam + */ + @Since("2.0.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + + /** + * Sets the maximum number of iterations. + * Default is 25 if the solver algorithm is "irls". + * @group setParam + */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Sets the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Sets the regularization parameter for L2 regularization. + * The regularization term is + * {{{ + * 0.5 * regParam * L2norm(coefficients)^2 + * }}} + * Default is 0.0. + * @group setParam + */ + @Since("2.0.0") + def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is empty, so all instances have weight one. + * @group setParam + */ + @Since("2.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + + /** + * Sets the solver algorithm used for optimization. + * Currently only support "irls" which is also the default solver. + * @group setParam + */ + @Since("2.0.0") + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> "irls") + + override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { + val familyObj = Family.fromName($(family)) + val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd + .map { case Row(features: Vector) => + features.size + }.first() + if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { + val msg = "Currently, GeneralizedLinearRegression only supports number of features" + + s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset." + throw new SparkException(msg) + } + + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = + dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + + if (familyObj == Gaussian && linkObj == Identity) { + // TODO: Make standardizeFeatures and standardizeLabel configurable. + val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + standardizeFeatures = true, standardizeLabel = true) + val wlsModel = optimizer.fit(instances) + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) + .setParent(this)) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new GeneralizedLinearRegressionSummary( + summaryModel.transform(dataset), + predictionColName, + model, + wlsModel.diagInvAtWA.toArray, + 1, + getSolver) + return model.setSummary(trainingSummary) + } + + // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, + $(fitIntercept), $(regParam), $(maxIter), $(tol)) + val irlsModel = optimizer.fit(instances) + + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) + .setParent(this)) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new GeneralizedLinearRegressionSummary( + summaryModel.transform(dataset), + predictionColName, + model, + irlsModel.diagInvAtWA.toArray, + irlsModel.numIterations, + getSolver) + + model.setSummary(trainingSummary) + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegression = defaultCopy(extra) +} + +@Since("2.0.0") +object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression] { + + @Since("2.0.0") + override def load(path: String): GeneralizedLinearRegression = super.load(path) + + /** Set of family and link pairs that GeneralizedLinearRegression supports. */ + private[ml] lazy val supportedFamilyAndLinkPairs = Set( + Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, + Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, + Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, + Gamma -> Inverse, Gamma -> Identity, Gamma -> Log + ) + + /** Set of family names that GeneralizedLinearRegression supports. */ + private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + + /** Set of link names that GeneralizedLinearRegression supports. */ + private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + + private[ml] val epsilon: Double = 1E-16 + + /** + * Wrapper of family and link combination used in the model. + */ + private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { + + /** Linear predictor based on given mu. */ + def predict(mu: Double): Double = link.link(family.project(mu)) + + /** Fitted value based on linear predictor eta. */ + def fitted(eta: Double): Double = family.project(link.unlink(eta)) + + /** + * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. + */ + def initialize( + instances: RDD[Instance], + fitIntercept: Boolean, + regParam: Double): WeightedLeastSquaresModel = { + val newInstances = instances.map { instance => + val mu = family.initialize(instance.label, instance.weight) + val eta = predict(mu) + Instance(eta, instance.weight, instance.features) + } + // TODO: Make standardizeFeatures and standardizeLabel configurable. + val initialModel = new WeightedLeastSquares(fitIntercept, regParam, + standardizeFeatures = true, standardizeLabel = true) + .fit(newInstances) + initialModel + } + + /** + * The reweight function used to update offsets and weights + * at each iteration of [[IterativelyReweightedLeastSquares]]. + */ + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: Instance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + val mu = fitted(eta) + val offset = eta + (instance.label - mu) * link.deriv(mu) + val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (offset, weight) + } + } + } + + /** + * A description of the error distribution to be used in the model. + * @param name the name of the family. + */ + private[ml] abstract class Family(val name: String) extends Serializable { + + /** The default link instance of this family. */ + val defaultLink: Link + + /** Initialize the starting value for mu. */ + def initialize(y: Double, weight: Double): Double + + /** The variance of the endogenous variable's mean, given the value mu. */ + def variance(mu: Double): Double + + /** Deviance of (y, mu) pair. */ + def deviance(y: Double, mu: Double, weight: Double): Double + + /** + * Akaike's 'An Information Criterion'(AIC) value of the family for a given dataset. + * @param predictions an RDD of (y, mu, weight) of instances in evaluation dataset + * @param deviance the deviance for the fitted model in evaluation dataset + * @param numInstances number of instances in evaluation dataset + * @param weightSum weights sum of instances in evaluation dataset + */ + def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double + + /** Trim the fitted value so that it will be in valid range. */ + def project(mu: Double): Double = mu + } + + private[ml] object Family { + + /** + * Gets the [[Family]] object from its name. + * @param name family name: "gaussian", "binomial", "poisson" or "gamma". + */ + def fromName(name: String): Family = { + name match { + case Gaussian.name => Gaussian + case Binomial.name => Binomial + case Poisson.name => Poisson + case Gamma.name => Gamma + } + } + } + + /** + * Gaussian exponential family distribution. + * The default link for the Gaussian family is the identity link. + */ + private[ml] object Gaussian extends Family("gaussian") { + + val defaultLink: Link = Identity + + override def initialize(y: Double, weight: Double): Double = y + + override def variance(mu: Double): Double = 1.0 + + override def deviance(y: Double, mu: Double, weight: Double): Double = { + weight * (y - mu) * (y - mu) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + val wt = predictions.map(x => math.log(x._3)).sum() + numInstances * (math.log(deviance / numInstances * 2.0 * math.Pi) + 1.0) + 2.0 - wt + } + + override def project(mu: Double): Double = { + if (mu.isNegInfinity) { + Double.MinValue + } else if (mu.isPosInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Binomial exponential family distribution. + * The default link for the Binomial family is the logit link. + */ + private[ml] object Binomial extends Family("binomial") { + + val defaultLink: Link = Logit + + override def initialize(y: Double, weight: Double): Double = { + val mu = (weight * y + 0.5) / (weight + 1.0) + require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family" + + s"should be in range (0, 1), but got $mu") + mu + } + + override def variance(mu: Double): Double = mu * (1.0 - mu) + + override def deviance(y: Double, mu: Double, weight: Double): Double = { + val my = 1.0 - y + 2.0 * weight * (y * math.log(math.max(y, 1.0) / mu) + + my * math.log(math.max(my, 1.0) / (1.0 - mu))) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) => + weight * dist.Binomial(1, mu).logProbabilityOf(math.round(y).toInt) + }.sum() + } + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu > 1.0 - epsilon) { + 1.0 - epsilon + } else { + mu + } + } + } + + /** + * Poisson exponential family distribution. + * The default link for the Poisson family is the log link. + */ + private[ml] object Poisson extends Family("poisson") { + + val defaultLink: Link = Log + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Poisson family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = mu + + override def deviance(y: Double, mu: Double, weight: Double): Double = { + 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) => + weight * dist.Poisson(mu).logProbabilityOf(y.toInt) + }.sum() + } + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Gamma exponential family distribution. + * The default link for the Gamma family is the inverse link. + */ + private[ml] object Gamma extends Family("gamma") { + + val defaultLink: Link = Inverse + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Gamma family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = mu * mu + + override def deviance(y: Double, mu: Double, weight: Double): Double = { + -2.0 * weight * (math.log(y / mu) - (y - mu)/mu) + } + + override def aic( + predictions: RDD[(Double, Double, Double)], + deviance: Double, + numInstances: Double, + weightSum: Double): Double = { + val disp = deviance / weightSum + -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) => + weight * dist.Gamma(1.0 / disp, mu * disp).logPdf(y) + }.sum() + 2.0 + } + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * A description of the link function to be used in the model. + * The link function provides the relationship between the linear predictor + * and the mean of the distribution function. + * @param name the name of link function. + */ + private[ml] abstract class Link(val name: String) extends Serializable { + + /** The link function. */ + def link(mu: Double): Double + + /** Derivative of the link function. */ + def deriv(mu: Double): Double + + /** The inverse link function. */ + def unlink(eta: Double): Double + } + + private[ml] object Link { + + /** + * Gets the [[Link]] object from its name. + * @param name link name: "identity", "logit", "log", + * "inverse", "probit", "cloglog" or "sqrt". + */ + def fromName(name: String): Link = { + name match { + case Identity.name => Identity + case Logit.name => Logit + case Log.name => Log + case Inverse.name => Inverse + case Probit.name => Probit + case CLogLog.name => CLogLog + case Sqrt.name => Sqrt + } + } + } + + private[ml] object Identity extends Link("identity") { + + override def link(mu: Double): Double = mu + + override def deriv(mu: Double): Double = 1.0 + + override def unlink(eta: Double): Double = eta + } + + private[ml] object Logit extends Link("logit") { + + override def link(mu: Double): Double = math.log(mu / (1.0 - mu)) + + override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) + } + + private[ml] object Log extends Link("log") { + + override def link(mu: Double): Double = math.log(mu) + + override def deriv(mu: Double): Double = 1.0 / mu + + override def unlink(eta: Double): Double = math.exp(eta) + } + + private[ml] object Inverse extends Link("inverse") { + + override def link(mu: Double): Double = 1.0 / mu + + override def deriv(mu: Double): Double = -1.0 * math.pow(mu, -2.0) + + override def unlink(eta: Double): Double = 1.0 / eta + } + + private[ml] object Probit extends Link("probit") { + + override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) + + override def deriv(mu: Double): Double = { + 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu)) + } + + override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) + } + + private[ml] object CLogLog extends Link("cloglog") { + + override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu)) + + override def deriv(mu: Double): Double = 1.0 / ((mu - 1.0) * math.log(1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) + } + + private[ml] object Sqrt extends Link("sqrt") { + + override def link(mu: Double): Double = math.sqrt(mu) + + override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu)) + + override def unlink(eta: Double): Double = eta * eta + } +} + +/** + * :: Experimental :: + * Model produced by [[GeneralizedLinearRegression]]. + */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegressionModel private[ml] ( + @Since("2.0.0") override val uid: String, + @Since("2.0.0") val coefficients: Vector, + @Since("2.0.0") val intercept: Double) + extends RegressionModel[Vector, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase with MLWritable { + + import GeneralizedLinearRegression._ + + lazy val familyObj = Family.fromName($(family)) + lazy val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + override protected def predict(features: Vector): Double = { + val eta = BLAS.dot(features, coefficients) + intercept + familyAndLink.fitted(eta) + } + + private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None + + /** + * Gets R-like summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.0.0") + def summary: GeneralizedLinearRegressionSummary = trainingSummary.getOrElse { + throw new SparkException( + "No training summary available for this GeneralizedLinearRegressionModel") + } + + private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * If the prediction column is set returns the current model and prediction column, + * otherwise generates a new column and sets it as the prediction column on a new copy + * of the current model. + */ + private[regression] def findSummaryModelAndPredictionCol() + : (GeneralizedLinearRegressionModel, String) = { + $(predictionCol) match { + case "" => + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString + (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) + case p => (this, p) + } + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { + copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) + .setParent(parent) + } + + @Since("2.0.0") + override def write: MLWriter = + new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this) +} + +@Since("2.0.0") +object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GeneralizedLinearRegressionModel] = + new GeneralizedLinearRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GeneralizedLinearRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[GeneralizedLinearRegressionModel]] */ + private[GeneralizedLinearRegressionModel] + class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel) + extends MLWriter with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class GeneralizedLinearRegressionModelReader + extends MLReader[GeneralizedLinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GeneralizedLinearRegressionModel].getName + + override def load(path: String): GeneralizedLinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + + val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +/** + * :: Experimental :: + * Summarizing Generalized Linear regression Fits. + * + * @param predictions predictions output by the model's `transform` method + * @param predictionCol field in "predictions" which gives the prediction value of each instance + * @param model the model that should be summarized + * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration + * @param numIterations number of iterations + * @param solver the solver algorithm used for model training + */ +@Since("2.0.0") +@Experimental +class GeneralizedLinearRegressionSummary private[regression] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val model: GeneralizedLinearRegressionModel, + private val diagInvAtWA: Array[Double], + @Since("2.0.0") val numIterations: Int, + @Since("2.0.0") val solver: String) extends Serializable { + + import GeneralizedLinearRegression._ + + private lazy val family = Family.fromName(model.getFamily) + private lazy val link = if (model.isDefined(model.getParam("link"))) { + Link.fromName(model.getLink) + } else { + family.defaultLink + } + + /** Number of instances in DataFrame predictions */ + private lazy val numInstances: Long = predictions.count() + + /** The numeric rank of the fitted linear model */ + @Since("2.0.0") + lazy val rank: Long = if (model.getFitIntercept) { + model.coefficients.size + 1 + } else { + model.coefficients.size + } + + /** Degrees of freedom */ + @Since("2.0.0") + lazy val degreesOfFreedom: Long = { + numInstances - rank + } + + /** The residual degrees of freedom */ + @Since("2.0.0") + lazy val residualDegreeOfFreedom: Long = degreesOfFreedom + + /** The residual degrees of freedom for the null model */ + @Since("2.0.0") + lazy val residualDegreeOfFreedomNull: Long = if (model.getFitIntercept) { + numInstances - 1 + } else { + numInstances + } + + private lazy val devianceResiduals: DataFrame = { + val drUDF = udf { (y: Double, mu: Double, weight: Double) => + val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0)) + if (y > mu) r else -1.0 * r + } + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + predictions.select( + drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals")) + } + + private lazy val pearsonResiduals: DataFrame = { + val prUDF = udf { mu: Double => family.variance(mu) } + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + predictions.select(col(model.getLabelCol).minus(col(predictionCol)) + .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals")) + } + + private lazy val workingResiduals: DataFrame = { + val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) } + predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals")) + } + + private lazy val responseResiduals: DataFrame = { + predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals")) + } + + /** + * Get the default residuals(deviance residuals) of the fitted model. + */ + @Since("2.0.0") + def residuals(): DataFrame = devianceResiduals + + /** + * Get the residuals of the fitted model by type. + * @param residualsType The type of residuals which should be returned. + * Supported options: deviance, pearson, working and response. + */ + @Since("2.0.0") + def residuals(residualsType: String): DataFrame = { + residualsType match { + case "deviance" => devianceResiduals + case "pearson" => pearsonResiduals + case "working" => workingResiduals + case "response" => responseResiduals + case other => throw new UnsupportedOperationException( + s"The residuals type $other is not supported by Generalized Linear Regression.") + } + } + + /** + * The deviance for the null model. + */ + @Since("2.0.0") + lazy val nullDeviance: Double = { + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + val wtdmu: Double = if (model.getFitIntercept) { + val agg = predictions.agg(sum(w.multiply(col(model.getLabelCol))), sum(w)).first() + agg.getDouble(0) / agg.getDouble(1) + } else { + link.unlink(0.0) + } + predictions.select(col(model.getLabelCol), w).rdd.map { + case Row(y: Double, weight: Double) => + family.deviance(y, wtdmu, weight) + }.sum() + } + + /** + * The deviance for the fitted model. + */ + @Since("2.0.0") + lazy val deviance: Double = { + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { + case Row(label: Double, pred: Double, weight: Double) => + family.deviance(label, pred, weight) + }.sum() + } + + /** + * The dispersion of the fitted model. + * It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise + * estimated by the residual Pearson's Chi-Squared statistic(which is defined as + * sum of the squares of the Pearson residuals) divided by the residual degrees of freedom. + */ + @Since("2.0.0") + lazy val dispersion: Double = if ( + model.getFamily == Binomial.name || model.getFamily == Poisson.name) { + 1.0 + } else { + val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0) + rss / degreesOfFreedom + } + + /** Akaike's "An Information Criterion"(AIC) for the fitted model. */ + @Since("2.0.0") + lazy val aic: Double = { + val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) + val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) + val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { + case Row(label: Double, pred: Double, weight: Double) => + (label, pred, weight) + } + family.aic(t, deviance, numInstances, weightSum) + 2 * rank + } + + /** + * Standard error of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + */ + @Since("2.0.0") + lazy val coefficientStandardErrors: Array[Double] = { + diagInvAtWA.map(_ * dispersion).map(math.sqrt) + } + + /** + * T-statistic of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + */ + @Since("2.0.0") + lazy val tValues: Array[Double] = { + val estimate = if (model.getFitIntercept) { + Array.concat(model.coefficients.toArray, Array(model.intercept)) + } else { + model.coefficients.toArray + } + estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } + } + + /** + * Two-sided p-value of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + */ + @Since("2.0.0") + lazy val pValues: Array[Double] = { + if (model.getFamily == Binomial.name || model.getFamily == Poisson.name) { + tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } + } else { + tValues.map { x => 2.0 * (1.0 - dist.StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index a1fe01b047108..7a78ecbdf16de 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -17,16 +17,20 @@ package org.apache.spark.ml.regression -import org.apache.spark.Logging +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -73,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * Extracts (label, feature, weight) from input dataset. */ protected[ml] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + dataset: Dataset[_]): RDD[(Double, Double, Double)] = { val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { val idx = $(featureIndex) val extract = udf { v: Vector => v(idx) } @@ -86,9 +90,9 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { lit(1.0) } - dataset.select(col($(labelCol)), f, w) - .map { case Row(label: Double, feature: Double, weight: Double) => - (label, feature, weight) + dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map { + case Row(label: Double, feature: Double, weight: Double) => + (label, feature, weight) } } @@ -102,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures schema: StructType, fitting: Boolean): StructType = { if (fitting) { - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) if (hasWeightCol) { SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) } else { @@ -127,7 +131,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures @Since("1.5.0") @Experimental class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("isoReg")) @@ -159,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) - @Since("1.5.0") - override def fit(dataset: DataFrame): IsotonicRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) @@ -179,6 +184,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri } } +@Since("1.6.0") +object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] { + + @Since("1.6.0") + override def load(path: String): IsotonicRegression = super.load(path) +} + /** * :: Experimental :: * Model fitted by IsotonicRegression. @@ -194,7 +206,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri class IsotonicRegressionModel private[ml] ( override val uid: String, private val oldModel: MLlibIsotonicRegressionModel) - extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable { /** @group setParam */ @Since("1.5.0") @@ -224,8 +236,8 @@ class IsotonicRegressionModel private[ml] ( copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => udf { feature: Double => oldModel.predict(feature) } @@ -240,4 +252,61 @@ class IsotonicRegressionModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false) } + + @Since("1.6.0") + override def write: MLWriter = + new IsotonicRegressionModelWriter(this) +} + +@Since("1.6.0") +object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader + + @Since("1.6.0") + override def load(path: String): IsotonicRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[IsotonicRegressionModel]] */ + private[IsotonicRegressionModel] class IsotonicRegressionModelWriter ( + instance: IsotonicRegressionModel + ) extends MLWriter with Logging { + + private case class Data( + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: boundaries, predictions, isotonic + val data = Data( + instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[IsotonicRegressionModel].getName + + override def load(path: String): IsotonicRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("boundaries", "predictions", "isotonic").head() + val boundaries = data.getAs[Seq[Double]](0).toArray + val predictions = data.getAs[Seq[Double]](1).toArray + val isotonic = data.getBoolean(2) + val model = new IsotonicRegressionModel( + metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 913140e581983..71e02730c7210 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,22 +22,25 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.optim.WeightedLeastSquares -import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel /** @@ -55,7 +58,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams * The specific squared error loss function used is: * L = 1/2n ||A coefficients - y||^2^ * - * This support multiple types of regularization: + * This supports multiple types of regularization: * - none (a.k.a. ordinary least squares) * - L2 (ridge regression) * - L1 (Lasso) @@ -65,7 +68,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -144,6 +147,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Set the solver algorithm used for optimization. * In case of linear regression, this can be "l-bfgs", "normal" and "auto". + * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. "normal" denotes using Normal Equation as an analytical + * solution to the linear regression problem. * The default value is "auto" which means that the solver algorithm is * selected automatically. * @group setParam @@ -152,21 +158,21 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "auto") - override protected def train(dataset: DataFrame): LinearRegressionModel = { + override protected def train(dataset: Dataset[_]): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. - val numFeatures = dataset.select(col($(featuresCol))).limit(1).map { + val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { case Row(features: Vector) => features.size }.first() val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) || - $(solver) == "normal") { + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. (SPARK-10668) val instances: RDD[Instance] = dataset.select( - col($(labelCol)), w, col($(featuresCol))).map { + col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -184,18 +190,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), summaryModel, model.diagInvAtWA.toArray, - $(featuresCol), Array(0D)) return lrModel.setSummary(trainingSummary) } - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances: RDD[Instance] = + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -215,33 +222,49 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val yMean = ySummarizer.mean(0) - val yStd = math.sqrt(ySummarizer.variance(0)) - - // If the yStd is zero, then the intercept is yMean with zero coefficient; - // as a result, training is not needed. - if (yStd == 0.0) { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + - s"training is not needed.") - if (handlePersistence) instances.unpersist() - val coefficients = Vectors.sparse(numFeatures, Seq()) - val intercept = yMean - - val model = new LinearRegressionModel(uid, coefficients, intercept) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - model, - Array(0D), - $(featuresCol), - Array(0D)) - return copyValues(model.setSummary(trainingSummary)) + val rawYStd = math.sqrt(ySummarizer.variance(0)) + if (rawYStd == 0.0) { + if ($(fitIntercept) || yMean==0.0) { + // If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with + // zero coefficient; as a result, training is not needed. + // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of + // the fitIntercept + if (yMean == 0.0) { + logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + + s"and the intercept will all be zero; as a result, training is not needed.") + } else { + logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + + s"zeros and the intercept will be the mean of the label; as a result, " + + s"training is not needed.") + } + if (handlePersistence) instances.unpersist() + val coefficients = Vectors.sparse(numFeatures, Seq()) + val intercept = yMean + + val model = new LinearRegressionModel(uid, coefficients, intercept) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + $(featuresCol), + model, + Array(0D), + Array(0D)) + return copyValues(model.setSummary(trainingSummary)) + } else { + require($(regParam) == 0.0, "The standard deviation of the label is zero. " + + "Model cannot be regularized.") + logWarning(s"The standard deviation of the label is zero. " + + "Consider setting fitIntercept=true.") + } } + // if y is constant (rawYStd is zero), then y cannot be scaled. In this case + // setting yStd=1.0 ensures that y is not scaled anymore in l-bfgs algorithm. + val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean) val featuresMean = featuresSummarizer.mean.toArray val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) @@ -257,8 +280,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { + val standardizationParam = $(standardization) def effectiveL1RegFun = (index: Int) => { - if ($(standardization)) { + if (standardizationParam) { effectiveL1RegParam } else { // If `standardization` is false, we still standardize the data @@ -332,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), model, Array(0D), - $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -343,6 +367,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } +@Since("1.6.0") +object LinearRegression extends DefaultParamsReadable[LinearRegression] { + + @Since("1.6.0") + override def load(path: String): LinearRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[LinearRegression]]. @@ -354,7 +385,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with MLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -368,12 +399,8 @@ class LinearRegressionModel private[ml] ( * thrown if `trainingSummary == None`. */ @Since("1.5.0") - def summary: LinearRegressionTrainingSummary = trainingSummary match { - case Some(summ) => summ - case None => - throw new SparkException( - "No training summary available for this LinearRegressionModel", - new NullPointerException()) + def summary: LinearRegressionTrainingSummary = trainingSummary.getOrElse { + throw new SparkException("No training summary available for this LinearRegressionModel") } private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { @@ -386,15 +413,15 @@ class LinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + @Since("2.0.0") + def evaluate(dataset: Dataset[_]): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), this, Array(0D)) + $(labelCol), $(featuresCol), summaryModel, Array(0D)) } /** @@ -405,7 +432,7 @@ class LinearRegressionModel private[ml] ( private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = { $(predictionCol) match { case "" => - val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) case p => (this, p) } @@ -422,13 +449,71 @@ class LinearRegressionModel private[ml] ( if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } + + /** + * Returns a [[MLWriter]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) +} + +@Since("1.6.0") +object LinearRegressionModel extends MLReadable[LinearRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LinearRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[LinearRegressionModel]] */ + private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends MLWriter with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[LinearRegressionModel].getName + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the - * training coefficients except for the objective trace. - * @param predictions predictions outputted by the model's `transform` method. + * training weights except for the objective trace. + * + * @param predictions predictions output by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Since("1.5.0") @@ -437,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + featuresCol: String, model: LinearRegressionModel, diagInvAtWA: Array[Double], - val featuresCol: String, val objectiveHistory: Array[Double]) - extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) { + extends LinearRegressionSummary( + predictions, + predictionCol, + labelCol, + featuresCol, + model, + diagInvAtWA) { - /** Number of training iterations until termination */ + /** + * Number of training iterations until termination + * + * This value is only available when using the "l-bfgs" solver. + * @see [[LinearRegression.solver]] + */ @Since("1.5.0") val totalIterations = objectiveHistory.length @@ -452,7 +548,12 @@ class LinearRegressionTrainingSummary private[regression] ( /** * :: Experimental :: * Linear regression results evaluated on a dataset. - * @param predictions predictions outputted by the model's `transform` method. + * + * @param predictions predictions output by the model's `transform` method. + * @param predictionCol Field in "predictions" which gives the predicted value of the label at + * each instance. + * @param labelCol Field in "predictions" which gives the true label of each instance. + * @param featuresCol Field in "predictions" which gives the features of each instance as a vector. */ @Since("1.5.0") @Experimental @@ -460,18 +561,24 @@ class LinearRegressionSummary private[regression] ( @transient val predictions: DataFrame, val predictionCol: String, val labelCol: String, + val featuresCol: String, val model: LinearRegressionModel, - val diagInvAtWA: Array[Double]) extends Serializable { + private val diagInvAtWA: Array[Double]) extends Serializable { @transient private val metrics = new RegressionMetrics( predictions - .select(predictionCol, labelCol) - .map { case Row(pred: Double, label: Double) => (pred, label) } ) + .select(col(predictionCol), col(labelCol).cast(DoubleType)) + .rdd + .map { case Row(pred: Double, label: Double) => (pred, label) }, + !model.getFitIntercept) /** * Returns the explained variance regression score. * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -479,6 +586,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -486,6 +596,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -493,6 +606,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -500,6 +616,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 @@ -536,6 +655,12 @@ class LinearRegressionSummary private[regression] ( /** * Standard error of estimated coefficients and intercept. + * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + * + * @see [[LinearRegression.solver]] */ lazy val coefficientStandardErrors: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -551,12 +676,18 @@ class LinearRegressionSummary private[regression] ( col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) } val sigma2 = rss / degreesOfFreedom - diagInvAtWA.map(_ * sigma2).map(math.sqrt(_)) + diagInvAtWA.map(_ * sigma2).map(math.sqrt) } } /** * T-statistic of estimated coefficients and intercept. + * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + * + * @see [[LinearRegression.solver]] */ lazy val tValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -574,6 +705,12 @@ class LinearRegressionSummary private[regression] ( /** * Two-sided p-value of estimated coefficients and intercept. + * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + * + * @see [[LinearRegression.solver]] */ lazy val pValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -724,7 +861,7 @@ private class LeastSquaresAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new sample." + s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 71e40b513ee0a..4c4ff278d4eb7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,18 +17,22 @@ package org.apache.spark.ml.regression +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._ @Experimental final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] - with RandomForestParams with TreeRegressorParams { + with RandomForestRegressorParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("rfr")) @@ -89,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestRegressionModel = { + override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") @Experimental -object RandomForestRegressor { +object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -117,12 +121,17 @@ object RandomForestRegressor { @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies + + @Since("2.0.0") + override def load(path: String): RandomForestRegressor = super.load(path) + } /** * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. + * * @param _trees Decision trees in the ensemble. * @param numFeatures Number of features used by this model */ @@ -133,27 +142,29 @@ final class RandomForestRegressionModel private[ml] ( private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] - with TreeEnsembleModel with Serializable { + with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { - require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") /** * Construct a random forest regression model, with all trees weighted equally. + * * @param trees Component trees */ private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = this(Identifiable.randomUID("rfr"), trees, numFeatures) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees // Note: We may add support for weights (based on tree performance) later on. - private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -165,9 +176,17 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. - _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees } + /** + * Number of trees in ensemble + * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0 + */ + // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams + @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0") + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressionModel = { copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) @@ -175,36 +194,83 @@ final class RandomForestRegressionModel private[ml] ( @Since("1.4.0") override def toString: String = { - s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" + s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees" } /** * Estimate of the importance of each feature. * - * This generalizes the idea of "Gini" importance to other losses, - * following the explanation of Gini importance from "Random Forests" documentation - * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * Each feature's importance is the average of its importance across all trees in the ensemble + * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + * and follows the implementation from scikit-learn. * - * This feature importance is calculated as follows: - * - Average over trees: - * - importance(feature j) = sum (over nodes which split on feature j) of the gain, - * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree based on total number of training instances used - * to build tree. - * - Normalize feature importance vector to sum to 1. + * @see [[DecisionTreeRegressionModel.featureImportances]] */ - lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + @Since("1.5.0") + lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) } + + @Since("2.0.0") + override def write: MLWriter = + new RandomForestRegressionModel.RandomForestRegressionModelWriter(this) } -private[ml] object RandomForestRegressionModel { +@Since("2.0.0") +object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader + + @Since("2.0.0") + override def load(path: String): RandomForestRegressionModel = super.load(path) + + private[RandomForestRegressionModel] + class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RandomForestRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): RandomForestRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldRandomForestModel, parent: RandomForestRegressor, categoricalFeatures: Map[Int, Int], @@ -215,6 +281,7 @@ private[ml] object RandomForestRegressionModel { // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent.uid, newTrees, numFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr") + new RandomForestRegressionModel(uid, newTrees, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala index c72ef29680329..be356575ca09a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -18,19 +18,16 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} /** - * :: DeveloperApi :: - * * Single-label regression * * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] * @tparam Learner Concrete Estimator type * @tparam M Concrete Model type */ -@DeveloperApi private[spark] abstract class Regressor[ FeaturesType, Learner <: Regressor[FeaturesType, Learner, M], diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 1f627777fc68d..2f1f2523fd11e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -17,54 +17,62 @@ package org.apache.spark.ml.source.libsvm -import com.google.common.base.Objects +import java.io.IOException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.Logging import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFileLinesReader, PartitionedFile} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration -/** - * LibSVMRelation provides the DataFrame constructed from LibSVM format data. - * @param path File path of LibSVM format - * @param numFeatures The number of features - * @param vectorType The type of vector. It can be 'sparse' or 'dense' - * @param sqlContext The Spark SQLContext - */ -private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) - (@transient val sqlContext: SQLContext) - extends BaseRelation with TableScan with Logging with Serializable { - - override def schema: StructType = StructType( - StructField("label", DoubleType, nullable = false) :: - StructField("features", new VectorUDT(), nullable = false) :: Nil - ) - - override def buildScan(): RDD[Row] = { - val sc = sqlContext.sparkContext - val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) - val sparse = vectorType == "sparse" - baseRdd.map { pt => - val features = if (sparse) pt.features.toSparse else pt.features.toDense - Row(pt.label, features) - } +private[libsvm] class LibSVMOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private[this] val buffer = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) } - override def hashCode(): Int = { - Objects.hashCode(path, Double.box(numFeatures), vectorType) + override def write(row: Row): Unit = { + val label = row.get(0) + val vector = row.get(1).asInstanceOf[Vector] + val sb = new StringBuilder(label.toString) + vector.foreachActive { case (i, v) => + sb += ' ' + sb ++= s"${i + 1}:$v" + } + buffer.set(sb.mkString) + recordWriter.write(NullWritable.get(), buffer) } - override def equals(other: Any): Boolean = other match { - case that: LibSVMRelation => - path == that.path && - numFeatures == that.numFeatures && - vectorType == that.vectorType - case _ => - false + override def close(): Unit = { + recordWriter.close(context) } } @@ -82,7 +90,7 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val * .load("data/mllib/sample_libsvm_data.txt") * * // Java - * DataFrame df = sqlContext.read.format("libsvm") + * DataFrame df = sqlContext.read().format("libsvm") * .option("numFeatures, "780") * .load("data/mllib/sample_libsvm_data.txt"); * }}} @@ -99,18 +107,119 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] */ @Since("1.6.0") -class DefaultSource extends RelationProvider with DataSourceRegister { +class DefaultSource extends FileFormat with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" - @Since("1.6.0") - override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) - : BaseRelation = { - val path = parameters.getOrElse("path", - throw new IllegalArgumentException("'path' must be specified")) - val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt - val vectorType = parameters.getOrElse("vectorType", "sparse") - new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) + override def toString: String = "LibSVM" + + private def verifySchema(dataSchema: StructType): Unit = { + if (dataSchema.size != 2 || + (!dataSchema(0).dataType.sameType(DataTypes.DoubleType) + || !dataSchema(1).dataType.sameType(new VectorUDT()))) { + throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema") + } + } + + override def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + Some( + StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil)) + } + + override def prepareRead( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Map[String, String] = { + def computeNumFeatures(): Int = { + val dataFiles = files.filterNot(_.getPath.getName startsWith "_") + val path = if (dataFiles.length == 1) { + dataFiles.head.getPath.toUri.toString + } else if (dataFiles.isEmpty) { + throw new IOException("No input path specified for libsvm data") + } else { + throw new IOException("Multiple input paths are not supported for libsvm data.") + } + + val sc = sqlContext.sparkContext + val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism) + MLUtils.computeNumFeatures(parsed) + } + + val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse { + computeNumFeatures() + } + + new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString)) + } + + override def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") } + new LibSVMOutputWriter(path, dataSchema, context) + } + } + } + + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { + verifySchema(dataSchema) + val numFeatures = options("numFeatures").toInt + assert(numFeatures > 0) + + val sparse = options.getOrElse("vectorType", "sparse") == "sparse" + + val broadcastedConf = sqlContext.sparkContext.broadcast( + new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration)) + ) + + (file: PartitionedFile) => { + val points = + new HadoopFileLinesReader(file, broadcastedConf.value.value) + .map(_.toString.trim) + .filterNot(line => line.isEmpty || line.startsWith("#")) + .map { line => + val (label, indices, values) = MLUtils.parseLibSVMRecord(line) + LabeledPoint(label, Vectors.sparse(numFeatures, indices, values)) + } + + val converter = RowEncoder(requiredSchema) + + val unsafeRowIterator = points.map { pt => + val features = if (sparse) pt.features.toSparse else pt.features.toDense + converter.toRow(Row(pt.label, features)) + } + + def toAttribute(f: StructField): AttributeReference = + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() + + // Appends partition values + val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute) + val joinedRow = new JoinedRow() + val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput) + + unsafeRowIterator.map { dataRow => + appendPartitionColumns(joinedRow(dataRow, file.partitionValues)) + } + } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d89682611e3f5..b5cb378829eba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -20,8 +20,8 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict, ImpurityStats} +import org.apache.spark.mllib.tree.model.{ImpurityStats, + InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * :: DeveloperApi :: @@ -78,6 +78,9 @@ sealed abstract class Node extends Serializable { * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). */ private[ml] def maxSplitFeatureIndex(): Int + + /** Returns a deep copy of the subtree rooted at this node. */ + private[tree] def deepCopy(): Node } private[ml] object Node { @@ -137,6 +140,10 @@ final class LeafNode private[ml] ( } override private[ml] def maxSplitFeatureIndex(): Int = -1 + + override private[tree] def deepCopy(): Node = { + new LeafNode(prediction, impurity, impurityStats) + } } /** @@ -203,6 +210,11 @@ final class InternalNode private[ml] ( math.max(split.featureIndex, math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) } + + override private[tree] def deepCopy(): Node = { + new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(), + split, impurityStats) + } } private object InternalNode { @@ -286,11 +298,12 @@ private[tree] class LearningNode( * * @param binnedFeatures Binned feature vector for data point. * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. - * Note: This is the global node index, i.e., the index used in the tree. - * This index is different from the index used during training a particular - * group of nodes on one call to [[findBestSplits()]]. + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * group of nodes on one call to + * [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]]. */ def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = { if (this.isLeaf || this.split.isEmpty) { @@ -386,9 +399,9 @@ private[tree] object LearningNode { var levelsToGo = indexToLevel(nodeIndex) while (levelsToGo > 0) { if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { - tmpNode = tmpNode.leftChild.asInstanceOf[LearningNode] + tmpNode = tmpNode.leftChild.get } else { - tmpNode = tmpNode.rightChild.asInstanceOf[LearningNode] + tmpNode = tmpNode.rightChild.get } levelsToGo -= 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 78199cc2df582..9d895b8faca7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} import org.apache.spark.mllib.tree.model.{Split => OldSplit} @@ -76,7 +76,7 @@ private[tree] object Split { final class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], - private val numCategories: Int) + @Since("2.0.0") val numCategories: Int) extends Split { require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala index 572815df0bc4a..4e372702f0c65 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import org.apache.commons.math3.distribution.PoissonDistribution diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala similarity index 77% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 7985ed4b4c0fa..61091bb803e49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import org.apache.spark.mllib.tree.impurity._ @@ -73,25 +73,34 @@ private[spark] class DTStatsAggregator( * Flat array of elements. * Index for start of stats for a (feature, bin) is: * index = featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, - * the left child stats have binIndex in [0, numBins(featureIndex) / 2)) - * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex)) */ private val allStats: Array[Double] = new Array[Double](allStatsSize) + /** + * Array of parent node sufficient stats. + * + * Note: this is necessary because stats for the parent node are not available + * on the first iteration of tree learning. + */ + private val parentStats: Array[Double] = new Array[Double](statsSize) /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). - * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset + * + * @param featureOffset This is a pre-computed (node, feature) offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (node, feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. */ def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } + /** + * Get an [[ImpurityCalculator]] for the parent node. + */ + def getParentImpurityCalculator(): ImpurityCalculator = { + impurityAggregator.getCalculator(parentStats, 0) + } + /** * Update the stats for a given (feature, bin) for ordered features, using the given label. */ @@ -100,14 +109,19 @@ private[spark] class DTStatsAggregator( impurityAggregator.update(allStats, i, label, instanceWeight) } + /** + * Update the parent node stats using the given label. + */ + def updateParent(label: Double, instanceWeight: Double): Unit = { + impurityAggregator.update(parentStats, 0, label, instanceWeight) + } + /** * Faster version of [[update]]. * Update the stats for a given (feature, bin), using the given label. - * @param featureOffset For ordered features, this is a pre-computed feature offset + * + * @param featureOffset This is a pre-computed feature offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. */ def featureUpdate( featureOffset: Int, @@ -124,22 +138,11 @@ private[spark] class DTStatsAggregator( */ def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) - /** - * Pre-compute feature offset for use with [[featureUpdate]]. - * For unordered features only. - */ - def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { - val baseOffset = featureOffsets(featureIndex) - (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) - } - /** * For a given feature, merge the stats for two bins. - * @param featureOffset For ordered features, this is a pre-computed feature offset + * + * @param featureOffset This is a pre-computed feature offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. * @param binIndex The other bin is merged into this bin. * @param otherBinIndex This bin is not modified. */ @@ -162,6 +165,17 @@ private[spark] class DTStatsAggregator( allStats(i) += other.allStats(i) i += 1 } + + require(statsSize == other.statsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length parent " + + s"stats vectors. This aggregator's parent stats are length $statsSize, " + + s"but the other is ${other.statsSize}.") + var j = 0 + while (j < statsSize) { + parentStats(j) += other.parentStats(j) + j += 1 + } + this } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala similarity index 90% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 21ee49c45788c..5f7c40f6071f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -15,11 +15,13 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import scala.collection.mutable +import scala.util.Try -import org.apache.spark.Logging +import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.RandomForestParams import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -67,11 +69,11 @@ private[spark] class DecisionTreeMetadata( /** * Number of splits for the given feature. - * For unordered features, there are 2 bins per split. + * For unordered features, there is 1 bin per split. * For ordered features, there is 1 more bin than split. */ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { - numBins(featureIndex) >> 1 + numBins(featureIndex) } else { numBins(featureIndex) - 1 } @@ -183,11 +185,23 @@ private[spark] object DecisionTreeMetadata extends Logging { } case _ => featureSubsetStrategy } + val numFeaturesPerNode: Int = _featureSubsetStrategy match { case "all" => numFeatures case "sqrt" => math.sqrt(numFeatures).ceil.toInt case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) case "onethird" => (numFeatures / 3.0).ceil.toInt + case _ => + Try(_featureSubsetStrategy.toInt).filter(_ > 0).toOption match { + case Some(value) => math.min(value, numFeatures) + case None => + Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match { + case Some(value) => math.ceil(value * numFeatures).toInt + case _ => throw new IllegalArgumentException(s"Supported values:" + + s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" (0.0-1.0], [1-n].") + } + } } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, @@ -212,6 +226,6 @@ private[spark] object DecisionTreeMetadata extends Logging { * there are math.pow(2, arity - 1) - 1 such splits. * Each split has 2 corresponding bins. */ - def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1) + def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala new file mode 100644 index 0000000000000..b6334762c7a7f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} +import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +private[spark] object GradientBoostedTrees extends Logging { + + /** + * Method to train a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param seed Random seed. + * @return tuple of ensemble models and weights: + * (array of decision tree models, array of model weights) + */ + def run( + input: RDD[LabeledPoint], + boostingStrategy: OldBoostingStrategy, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case OldAlgo.Regression => + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) + case OldAlgo.Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, + seed) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.") + } + } + + /** + * Method to validate a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param validationInput Validation dataset. + * This dataset should be different from the training dataset, + * but it should follow the same distribution. + * E.g., these two datasets could be created from an original dataset + * by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @param seed Random seed. + * @return tuple of ensemble models and weights: + * (array of decision tree models, array of model weights) + */ + def runWithValidation( + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint], + boostingStrategy: OldBoostingStrategy, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case OldAlgo.Regression => + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) + case OldAlgo.Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val remappedValidationInput = validationInput.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, + validate = true, seed) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + + /** + * Compute the initial predictions and errors for a dataset for the first + * iteration of gradient boosting. + * @param data: training data. + * @param initTreeWeight: learning rate assigned to the first tree. + * @param initTree: first DecisionTreeModel. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to every sample. + */ + def computeInitialPredictionAndError( + data: RDD[LabeledPoint], + initTreeWeight: Double, + initTree: DecisionTreeRegressionModel, + loss: OldLoss): RDD[(Double, Double)] = { + data.map { lp => + val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight) + val error = loss.computeError(pred, lp.label) + (pred, error) + } + } + + /** + * Update a zipped predictionError RDD + * (as obtained with computeInitialPredictionAndError) + * @param data: training data. + * @param predictionAndError: predictionError RDD + * @param treeWeight: Learning rate. + * @param tree: Tree using which the prediction and error should be updated. + * @param loss: evaluation metric. + * @return a RDD with each element being a zip of the prediction and error + * corresponding to each sample. + */ + def updatePredictionError( + data: RDD[LabeledPoint], + predictionAndError: RDD[(Double, Double)], + treeWeight: Double, + tree: DecisionTreeRegressionModel, + loss: OldLoss): RDD[(Double, Double)] = { + + val newPredError = data.zip(predictionAndError).mapPartitions { iter => + iter.map { case (lp, (pred, error)) => + val newPred = updatePrediction(lp.features, pred, tree, treeWeight) + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) + } + } + newPredError + } + + /** + * Add prediction from a new boosting iteration to an existing prediction. + * + * @param features Vector of features representing a single data point. + * @param prediction The existing prediction. + * @param tree New Decision Tree model. + * @param weight Tree weight. + * @return Updated prediction. + */ + def updatePrediction( + features: Vector, + prediction: Double, + tree: DecisionTreeRegressionModel, + weight: Double): Double = { + prediction + tree.rootNode.predictImpl(features).prediction * weight + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param trees Boosted Decision Tree models + * @param treeWeights Learning rates at each boosting iteration. + * @param loss evaluation metric. + * @return Measure of model error on data + */ + def computeError( + data: RDD[LabeledPoint], + trees: Array[DecisionTreeRegressionModel], + treeWeights: Array[Double], + loss: OldLoss): Double = { + data.map { lp => + val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) => + updatePrediction(lp.features, acc, model, weight) + } + loss.computeError(predicted, lp.label) + }.mean() + } + + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param trees Boosted Decision Tree models + * @param treeWeights Learning rates at each boosting iteration. + * @param loss evaluation metric. + * @param algo algorithm for the ensemble, either Classification or Regression + * @return an array with index i having the losses or errors for the ensemble + * containing the first i+1 trees + */ + def evaluateEachIteration( + data: RDD[LabeledPoint], + trees: Array[DecisionTreeRegressionModel], + treeWeights: Array[Double], + loss: OldLoss, + algo: OldAlgo.Value): Array[Double] = { + + val sc = data.sparkContext + val remappedData = algo match { + case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + case _ => data + } + + val numIterations = trees.length + val evaluationArray = Array.fill(numIterations)(0.0) + val localTreeWeights = treeWeights + + var predictionAndError = computeInitialPredictionAndError( + remappedData, localTreeWeights(0), trees(0), loss) + + evaluationArray(0) = predictionAndError.values.mean() + + val broadcastTrees = sc.broadcast(trees) + (1 until numIterations).foreach { nTree => + predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => + val currentTree = broadcastTrees.value(nTree) + val currentTreeWeight = localTreeWeights(nTree) + iter.map { case (point, (pred, error)) => + val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight) + val newError = loss.computeError(newPred, point.label) + (newPred, newError) + } + } + evaluationArray(nTree) = predictionAndError.values.mean() + } + + broadcastTrees.unpersist() + evaluationArray + } + + /** + * Internal method for performing regression using trees as base learners. + * @param input training dataset + * @param validationInput validation dataset, ignored if validate is set to false. + * @param boostingStrategy boosting parameters + * @param validate whether or not to use the validation dataset. + * @param seed Random seed. + * @return tuple of ensemble models and weights: + * (array of decision tree models, array of model weights) + */ + def boost( + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint], + boostingStrategy: OldBoostingStrategy, + validate: Boolean, + seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = { + val timer = new TimeTracker() + timer.start("total") + timer.start("init") + + boostingStrategy.assertValid() + + // Initialize gradient boosting parameters + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeRegressionModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) + val loss = boostingStrategy.loss + val learningRate = boostingStrategy.learningRate + // Prepare strategy for individual trees, which use regression with variance impurity. + val treeStrategy = boostingStrategy.treeStrategy.copy + val validationTol = boostingStrategy.validationTol + treeStrategy.algo = OldAlgo.Regression + treeStrategy.impurity = OldVariance + treeStrategy.assertValid() + + // Cache input + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + true + } else { + false + } + + // Prepare periodic checkpointers + val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + + timer.stop("init") + + logDebug("##########") + logDebug("Building tree 0") + logDebug("##########") + + // Initialize tree + timer.start("building tree 0") + val firstTree = new DecisionTreeRegressor().setSeed(seed) + val firstTreeModel = firstTree.train(input, treeStrategy) + val firstTreeWeight = 1.0 + baseLearners(0) = firstTreeModel + baseLearnerWeights(0) = firstTreeWeight + + var predError: RDD[(Double, Double)] = + computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + predErrorCheckpointer.update(predError) + logDebug("error of gbt = " + predError.values.mean()) + + // Note: A model of type regression is used since we require raw prediction + timer.stop("building tree 0") + + var validatePredError: RDD[(Double, Double)] = + computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + if (validate) validatePredErrorCheckpointer.update(validatePredError) + var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 + var bestM = 1 + + var m = 1 + var doneLearning = false + while (m < numIterations && !doneLearning) { + // Update data with pseudo-residuals + val data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + + timer.start(s"building tree $m") + logDebug("###################################################") + logDebug("Gradient boosting tree iteration " + m) + logDebug("###################################################") + val dt = new DecisionTreeRegressor().setSeed(seed + m) + val model = dt.train(data, treeStrategy) + timer.stop(s"building tree $m") + // Update partial model + baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. + baseLearnerWeights(m) = learningRate + + predError = updatePredictionError( + input, predError, baseLearnerWeights(m), baseLearners(m), loss) + predErrorCheckpointer.update(predError) + logDebug("error of gbt = " + predError.values.mean()) + + if (validate) { + // Stop training early if + // 1. Reduction in error is less than the validationTol or + // 2. If the error increases, that is if the model is overfit. + // We want the model returned corresponding to the best validation error. + + validatePredError = updatePredictionError( + validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + validatePredErrorCheckpointer.update(validatePredError) + val currentValidateError = validatePredError.values.mean() + if (bestValidateError - currentValidateError < validationTol * Math.max( + currentValidateError, 0.01)) { + doneLearning = true + } else if (currentValidateError < bestValidateError) { + bestValidateError = currentValidateError + bestM = m + 1 + } + } + m += 1 + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + predErrorCheckpointer.deleteAllCheckpoints() + validatePredErrorCheckpointer.deleteAllCheckpoints() + if (persistedInput) input.unpersist() + + if (validate) { + (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM)) + } else { + (baseLearners, baseLearnerWeights) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 1ee01131d6334..9d697a36b67da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -21,12 +21,11 @@ import java.io.IOException import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.{LearningNode, Split} -import org.apache.spark.mllib.tree.impl.BaggedPoint import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -171,7 +170,6 @@ private[spark] class NodeIdCache( } } -@DeveloperApi private[spark] object NodeIdCache { /** * Initialize the node Id cache with initial node Id values. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4a3b12d1440b8..7b1fd089f2943 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -22,24 +22,61 @@ import java.io.IOException import scala.collection.mutable import scala.util.Random -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, - TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} -private[ml] object RandomForest extends Logging { +/** + * ALGORITHM + * + * This is a sketch of the algorithm to help new developers. + * + * The algorithm partitions data by instances (rows). + * On each iteration, the algorithm splits a set of nodes. In order to choose the best split + * for a given node, sufficient statistics are collected from the distributed data. + * For each node, the statistics are collected to some worker node, and that worker selects + * the best split. + * + * This setup requires discretization of continuous features. This binning is done in the + * findSplits() method during initialization, after which each continuous feature becomes + * an ordered discretized feature with at most maxBins possible values. + * + * The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes + * lie at the periphery of the tree being trained. If multiple trees are being trained at once, + * then this queue contains nodes from all of them. Each iteration works roughly as follows: + * On the master node: + * - Some number of nodes are pulled off of the queue (based on the amount of memory + * required for their sufficient statistics). + * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate + * features are chosen for each node. See method selectNodesToSplit(). + * On worker nodes, via method findBestSplits(): + * - The worker makes one pass over its subset of instances. + * - For each (tree, node, feature, split) tuple, the worker collects statistics about + * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected + * from the queue for this iteration. The set of features considered can also be limited + * based on featureSubsetStrategy. + * - For each node, the statistics for that node are aggregated to a particular worker + * via reduceByKey(). The designated worker chooses the best (feature, split) pair, + * or chooses to stop splitting if the stopping criteria are met. + * On the master node: + * - The master collects all decisions about splitting nodes and updates the model. + * - The updated model is passed to the workers on the next iteration. + * This process continues until the node queue is empty. + * + * Most of the methods in this implementation support the statistics aggregation, which is + * the heaviest part of the computation. In general, this implementation is bound by either + * the cost of statistics computation on workers or by communicating the sufficient statistics. + */ +private[spark] object RandomForest extends Logging { /** * Train a random forest. @@ -73,9 +110,9 @@ private[ml] object RandomForest extends Logging { // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - timer.start("findSplitsBins") + timer.start("findSplits") val splits = findSplits(retaggedInput, metadata, seed) - timer.stop("findSplitsBins") + timer.stop("findSplits") logDebug("numBins: feature: number of bins") logDebug(Range(0, metadata.numFeatures).map { featureIndex => s"\t$featureIndex\t${metadata.numBins(featureIndex)}" @@ -100,22 +137,6 @@ private[ml] object RandomForest extends Logging { // TODO: Calculate memory usage more precisely. val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val maxMemoryPerNode = { - val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. - Some(metadata.numBins.zipWithIndex.sortBy(- _._1) - .take(metadata.numFeaturesPerNode).map(_._2)) - } else { - None - } - RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L - } - require(maxMemoryPerNode <= maxMemoryUsage, - s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," + - " which is too small for the given features." + - s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") - - timer.stop("init") /* * The main idea here is to perform group-wise training of the decision tree nodes thus @@ -146,6 +167,8 @@ private[ml] object RandomForest extends Logging { val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + timer.stop("init") + while (nodeQueue.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. @@ -244,8 +267,7 @@ private[ml] object RandomForest extends Logging { if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) - val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightFeatureOffsets(featureIndexIdx) + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) // Update the left or right bin for each split. val numSplits = agg.metadata.numSplits(featureIndex) val featureSplits = splits(featureIndex) @@ -253,8 +275,6 @@ private[ml] object RandomForest extends Logging { while (splitIndex < numSplits) { if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) - } else { - agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } splitIndex += 1 } @@ -308,7 +328,7 @@ private[ml] object RandomForest extends Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] + * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata * @param topNodes Root node for each tree. Used for matching instances with nodes. * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree @@ -394,6 +414,7 @@ private[ml] object RandomForest extends Logging { mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, metadata.unorderedFeatures, instanceWeight, featuresForNode) } + agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) } } @@ -474,13 +495,13 @@ private[ml] object RandomForest extends Logging { val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => // Construct a nodeStatsAggregators array to hold node aggregate stats, // each node will have a nodeStatsAggregator val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) + val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures => + nodeToFeatures(nodeIndex) } new DTStatsAggregator(metadata, featuresForNode) } @@ -650,7 +671,7 @@ private[ml] object RandomForest extends Logging { * @param binAggregates Bin statistics. * @return tuple for best split: (Split, information gain, prediction at node) */ - private def binsToBestSplit( + private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], @@ -658,7 +679,7 @@ private[ml] object RandomForest extends Logging { // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var gainAndImpurityStats: ImpurityStats = if (level ==0) { + var gainAndImpurityStats: ImpurityStats = if (level == 0) { null } else { node.stats @@ -697,13 +718,12 @@ private[ml] object RandomForest extends Logging { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = - binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + .subtract(leftChildStats) gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, leftChildStats, rightChildStats, binAggregates.metadata) (splitIndex, gainAndImpurityStats) @@ -720,32 +740,30 @@ private[ml] object RandomForest extends Logging { * * centroidForCategories is a list: (category, centroid) */ - val centroidForCategories = if (binAggregates.metadata.isMulticlass) { - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { + val centroidForCategories = Range(0, numCategories).map { case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + if (binAggregates.metadata.isMulticlass) { + // multiclass classification + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // binary classification + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) } else { - Double.MaxValue - } - (featureValue, centroid) - } - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // the bins are ordered by the centroid of their corresponding labels. - Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { + // regression + // For categorical variables in regression and binary classification, + // the bins are ordered by the prediction. categoryStats.predict - } else { - Double.MaxValue } - (featureValue, centroid) + } else { + Double.MaxValue } + (featureValue, centroid) } logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) @@ -793,7 +811,7 @@ private[ml] object RandomForest extends Logging { } /** - * Returns splits and bins for decision tree calculation. + * Returns splits for decision tree calculation. * Continuous and categorical features are handled differently. * * Continuous features: @@ -816,24 +834,21 @@ private[ml] object RandomForest extends Logging { * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param metadata Learning and dataset metadata * @param seed random seed - * @return A tuple of (splits, bins). - * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] - * of size (numFeatures, numSplits). - * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] - * of size (numFeatures, numBins). + * @return Splits, an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numSplits) */ protected[tree] def findSplits( input: RDD[LabeledPoint], metadata: DecisionTreeMetadata, - seed : Long): Array[Array[Split]] = { + seed: Long): Array[Array[Split]] = { logDebug("isMulticlass = " + metadata.isMulticlass) val numFeatures = metadata.numFeatures // Sample the input only if there are continuous features. - val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous) - val sampledInput = if (hasContinuousFeatures) { + val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) + val sampledInput = if (continuousFeatures.nonEmpty) { // Calculate the number of samples for approximate quantile calculation. val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) val fraction = if (requiredSamples < metadata.numExamples) { @@ -842,58 +857,56 @@ private[ml] object RandomForest extends Logging { 1.0 } logDebug("fraction of data used for calculating quantiles = " + fraction) - input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() + input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) } else { - new Array[LabeledPoint](0) + input.sparkContext.emptyRDD[LabeledPoint] } - val splits = new Array[Array[Split]](numFeatures) - - // Find all splits. - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isContinuous(featureIndex)) { - val featureSamples = sampledInput.map(_.features(featureIndex)) - val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex) + findSplitsBySorting(sampledInput, metadata, continuousFeatures) + } - val numSplits = featureSplits.length - logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits") - splits(featureIndex) = new Array[Split](numSplits) + private def findSplitsBySorting( + input: RDD[LabeledPoint], + metadata: DecisionTreeMetadata, + continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = { + + val continuousSplits: scala.collection.Map[Int, Array[Split]] = { + // reduce the parallelism for split computations when there are less + // continuous features than input partitions. this prevents tasks from + // being spun up that will definitely do no work. + val numPartitions = math.min(continuousFeatures.length, input.partitions.length) + + input + .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) + .groupByKey(numPartitions) + .map { case (idx, samples) => + val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) + val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) + logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") + (idx, splits) + }.collectAsMap() + } - var splitIndex = 0 - while (splitIndex < numSplits) { - val threshold = featureSplits(splitIndex) - splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold) - splitIndex += 1 - } - } else { - // Categorical feature - if (metadata.isUnordered(featureIndex)) { - val numSplits = metadata.numSplits(featureIndex) - val featureArity = metadata.featureArity(featureIndex) - // TODO: Use an implicit representation mapping each category to a subset of indices. - // I.e., track indices such that we can calculate the set of bins for which - // feature value x splits to the left. - // Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations - splits(featureIndex) = new Array[Split](numSplits) - var splitIndex = 0 - while (splitIndex < numSplits) { - val categories: List[Double] = - extractMultiClassCategories(splitIndex + 1, featureArity) - splits(featureIndex)(splitIndex) = - new CategoricalSplit(featureIndex, categories.toArray, featureArity) - splitIndex += 1 - } - } else { - // Ordered features - // Bins correspond to feature values, so we do not need to compute splits or bins - // beforehand. Splits are constructed as needed during training. - splits(featureIndex) = new Array[Split](0) + val numFeatures = metadata.numFeatures + val splits: Array[Array[Split]] = Array.tabulate(numFeatures) { + case i if metadata.isContinuous(i) => + val split = continuousSplits(i) + metadata.setNumSplits(i, split.length) + split + + case i if metadata.isCategorical(i) && metadata.isUnordered(i) => + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + val featureArity = metadata.featureArity(i) + Array.tabulate[Split](metadata.numSplits(i)) { splitIndex => + val categories = extractMultiClassCategories(splitIndex + 1, featureArity) + new CategoricalSplit(i, categories.toArray, featureArity) } - } - featureIndex += 1 + + case i if metadata.isCategorical(i) => + // Ordered features + // Splits are constructed as needed during training. + Array.empty[Split] } splits } @@ -935,7 +948,7 @@ private[ml] object RandomForest extends Logging { * @return array of splits */ private[tree] def findSplitsForContinuousFeature( - featureSamples: Array[Double], + featureSamples: Iterable[Double], metadata: DecisionTreeMetadata, featureIndex: Int): Array[Double] = { require(metadata.isContinuous(featureIndex), @@ -945,8 +958,9 @@ private[ml] object RandomForest extends Logging { val numSplits = metadata.numSplits(featureIndex) // get count for each distinct value - val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) => - m + ((x, m.getOrElse(x, 0) + 1)) + val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { + case ((m, cnt), x) => + (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) } // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray @@ -957,7 +971,7 @@ private[ml] object RandomForest extends Logging { valueCounts.map(_._1) } else { // stride between splits - val stride: Double = featureSamples.length.toDouble / (numSplits + 1) + val stride: Double = numSamples.toDouble / (numSplits + 1) logDebug("stride = " + stride) // iterate `valueCount` to find splits @@ -993,8 +1007,6 @@ private[ml] object RandomForest extends Logging { assert(splits.length > 0, s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + " Please remove this feature and then try again.") - // set number of splits accordingly - metadata.setNumSplits(featureIndex, splits.length) splits } @@ -1032,7 +1044,9 @@ private[ml] object RandomForest extends Logging { new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() var memUsage: Long = 0L var numNodesInGroup = 0 - while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) { + // If maxMemoryInMB is set very small, we want to still try to split 1 node, + // so we allow one iteration if memUsage == 0. + while (nodeQueue.nonEmpty && (memUsage < maxMemoryUsage || memUsage == 0)) { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { @@ -1043,7 +1057,7 @@ private[ml] object RandomForest extends Logging { } // Check if enough memory remains to add this node to the group. val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L - if (memUsage + nodeMemUsage <= maxMemoryUsage) { + if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) { nodeQueue.dequeue() mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += node @@ -1054,6 +1068,12 @@ private[ml] object RandomForest extends Logging { numNodesInGroup += 1 memUsage += nodeMemUsage } + if (memUsage > maxMemoryUsage) { + // If maxMemoryUsage is 0, we should still allow splitting 1 node. + logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" + + s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" + + s" $numNodesInGroup nodes in this iteration.") + } // Convert mutable maps to immutable ones. val nodesForGroup: Map[Int, Array[LearningNode]] = mutableNodesForGroup.mapValues(_.toArray).toMap @@ -1081,94 +1101,4 @@ private[ml] object RandomForest extends Logging { } } - /** - * Given a Random Forest model, compute the importance of each feature. - * This generalizes the idea of "Gini" importance to other losses, - * following the explanation of Gini importance from "Random Forests" documentation - * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. - * - * This feature importance is calculated as follows: - * - Average over trees: - * - importance(feature j) = sum (over nodes which split on feature j) of the gain, - * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree based on total number of training instances used - * to build tree. - * - Normalize feature importance vector to sum to 1. - * - * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for - * independently trained trees. - * @param trees Unweighted forest of trees - * @param numFeatures Number of features in model (even if not all are explicitly used by - * the model). - * If -1, then numFeatures is set based on the max feature index in all trees. - * @return Feature importance values, of length numFeatures. - */ - private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = { - val totalImportances = new OpenHashMap[Int, Double]() - trees.foreach { tree => - // Aggregate feature importance vector for this tree - val importances = new OpenHashMap[Int, Double]() - computeFeatureImportance(tree.rootNode, importances) - // Normalize importance vector for this tree, and add it to total. - // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? - val treeNorm = importances.map(_._2).sum - if (treeNorm != 0) { - importances.foreach { case (idx, impt) => - val normImpt = impt / treeNorm - totalImportances.changeValue(idx, normImpt, _ + normImpt) - } - } - } - // Normalize importances - normalizeMapValues(totalImportances) - // Construct vector - val d = if (numFeatures != -1) { - numFeatures - } else { - // Find max feature index used in trees - val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max - maxFeatureIndex + 1 - } - if (d == 0) { - assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" + - s" importance: No splits in forest, but some non-zero importances.") - } - val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip - Vectors.sparse(d, indices.toArray, values.toArray) - } - - /** - * Recursive method for computing feature importances for one tree. - * This walks down the tree, adding to the importance of 1 feature at each node. - * @param node Current node in recursion - * @param importances Aggregate feature importances, modified by this method - */ - private[impl] def computeFeatureImportance( - node: Node, - importances: OpenHashMap[Int, Double]): Unit = { - node match { - case n: InternalNode => - val feature = n.split.featureIndex - val scaledGain = n.gain * n.impurityStats.count - importances.changeValue(feature, scaledGain, _ + scaledGain) - computeFeatureImportance(n.leftChild, importances) - computeFeatureImportance(n.rightChild, importances) - case n: LeafNode => - // do nothing - } - } - - /** - * Normalize the values of this map to sum to 1, in place. - * If all values are 0, this method does nothing. - * @param map Map with non-negative values. - */ - private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = { - val total = map.map(_._2).sum - if (total != 0) { - val keys = map.iterator.map(_._1).toArray - keys.foreach { key => map.changeValue(key, 0.0, _ / total) } - } - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala similarity index 98% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala index 70afaa162b2e7..4cc250aa462e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import scala.collection.mutable.{HashMap => MutableHashMap} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala index 9fa27e5e1f721..3a2bf3c725730 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.ml.tree.{ContinuousSplit, Split} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index b77191156f68f..f38e1ec7c09a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,14 +17,29 @@ package org.apache.spark.ml.tree -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.param.{Param, Params} +import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData +import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter} +import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Dataset, SQLContext} +import org.apache.spark.util.collection.OpenHashMap /** * Abstraction for Decision Tree models. * * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 */ -private[ml] trait DecisionTreeModel { +private[spark] trait DecisionTreeModel { /** Root of the decision tree */ def rootNode: Node @@ -56,23 +71,34 @@ private[ml] trait DecisionTreeModel { /** * Trace down the tree, and return the largest feature index used in any split. + * * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). */ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() + + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ + private[spark] def toOld: OldDecisionTreeModel } /** * Abstraction for models which are ensembles of decision trees * * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + * + * @tparam M Type of tree model in this ensemble */ -private[ml] trait TreeEnsembleModel { +private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] { // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of // DecisionTreeModel. /** Trees in this ensemble. Warning: These have null parent Estimators. */ - def trees: Array[DecisionTreeModel] + def trees: Array[M] + + /** + * Number of trees in ensemble + */ + val getNumTrees: Int = trees.length /** Weights for each tree, zippable with [[trees]] */ def treeWeights: Array[Double] @@ -84,7 +110,7 @@ private[ml] trait TreeEnsembleModel { /** Summary of the model */ override def toString: String = { // Implementing classes should generally override this method to be more descriptive. - s"TreeEnsembleModel with $numTrees trees" + s"TreeEnsembleModel with ${trees.length} trees" } /** Full description of model */ @@ -95,9 +121,364 @@ private[ml] trait TreeEnsembleModel { }.fold("")(_ + _) } - /** Number of trees in ensemble */ - val numTrees: Int = trees.length - /** Total number of nodes, summed over all trees in the ensemble. */ lazy val totalNumNodes: Int = trees.map(_.numNodes).sum } + +private[ml] object TreeEnsembleModel { + + /** + * Given a tree ensemble model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * For collections of trees, including boosting and bagging, Hastie et al. + * propose to use the average of single tree importances across all trees in the ensemble. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * - Normalize feature importance vector to sum to 1. + * + * References: + * - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001. + * + * @param trees Unweighted collection of trees + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = { + val totalImportances = new OpenHashMap[Int, Double]() + trees.foreach { tree => + // Aggregate feature importance vector for this tree + val importances = new OpenHashMap[Int, Double]() + computeFeatureImportance(tree.rootNode, importances) + // Normalize importance vector for this tree, and add it to total. + // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? + val treeNorm = importances.map(_._2).sum + if (treeNorm != 0) { + importances.foreach { case (idx, impt) => + val normImpt = impt / treeNorm + totalImportances.changeValue(idx, normImpt, _ + normImpt) + } + } + } + // Normalize importances + normalizeMapValues(totalImportances) + // Construct vector + val d = if (numFeatures != -1) { + numFeatures + } else { + // Find max feature index used in trees + val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max + maxFeatureIndex + 1 + } + if (d == 0) { + assert(totalImportances.size == 0, s"Unknown error in computing feature" + + s" importance: No splits found, but some non-zero importances.") + } + val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip + Vectors.sparse(d, indices.toArray, values.toArray) + } + + /** + * Given a Decision Tree model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * + * @param tree Decision tree to compute importances for. + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = { + featureImportances(Array(tree), numFeatures) + } + + /** + * Recursive method for computing feature importances for one tree. + * This walks down the tree, adding to the importance of 1 feature at each node. + * + * @param node Current node in recursion + * @param importances Aggregate feature importances, modified by this method + */ + def computeFeatureImportance( + node: Node, + importances: OpenHashMap[Int, Double]): Unit = { + node match { + case n: InternalNode => + val feature = n.split.featureIndex + val scaledGain = n.gain * n.impurityStats.count + importances.changeValue(feature, scaledGain, _ + scaledGain) + computeFeatureImportance(n.leftChild, importances) + computeFeatureImportance(n.rightChild, importances) + case n: LeafNode => + // do nothing + } + } + + /** + * Normalize the values of this map to sum to 1, in place. + * If all values are 0, this method does nothing. + * + * @param map Map with non-negative values. + */ + def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = { + val total = map.map(_._2).sum + if (total != 0) { + val keys = map.iterator.map(_._1).toArray + keys.foreach { key => map.changeValue(key, 0.0, _ / total) } + } + } +} + +/** Helper classes for tree model persistence */ +private[ml] object DecisionTreeModelReadWrite { + + /** + * Info for a [[org.apache.spark.ml.tree.Split]] + * + * @param featureIndex Index of feature split on + * @param leftCategoriesOrThreshold For categorical feature, set of leftCategories. + * For continuous feature, threshold. + * @param numCategories For categorical feature, number of categories. + * For continuous feature, -1. + */ + case class SplitData( + featureIndex: Int, + leftCategoriesOrThreshold: Array[Double], + numCategories: Int) { + + def getSplit: Split = { + if (numCategories != -1) { + new CategoricalSplit(featureIndex, leftCategoriesOrThreshold, numCategories) + } else { + assert(leftCategoriesOrThreshold.length == 1, s"DecisionTree split data expected" + + s" 1 threshold for ContinuousSplit, but found thresholds: " + + leftCategoriesOrThreshold.mkString(", ")) + new ContinuousSplit(featureIndex, leftCategoriesOrThreshold(0)) + } + } + } + + object SplitData { + def apply(split: Split): SplitData = split match { + case s: CategoricalSplit => + SplitData(s.featureIndex, s.leftCategories, s.numCategories) + case s: ContinuousSplit => + SplitData(s.featureIndex, Array(s.threshold), -1) + } + } + + /** + * Info for a [[Node]] + * + * @param id Index used for tree reconstruction. Indices follow a pre-order traversal. + * @param impurityStats Stats array. Impurity type is stored in metadata. + * @param gain Gain, or arbitrary value if leaf node. + * @param leftChild Left child index, or arbitrary value if leaf node. + * @param rightChild Right child index, or arbitrary value if leaf node. + * @param split Split info, or arbitrary value if leaf node. + */ + case class NodeData( + id: Int, + prediction: Double, + impurity: Double, + impurityStats: Array[Double], + gain: Double, + leftChild: Int, + rightChild: Int, + split: SplitData) + + object NodeData { + /** + * Create [[NodeData]] instances for this node and all children. + * + * @param id Current ID. IDs are assigned via a pre-order traversal. + * @return (sequence of nodes in pre-order traversal order, largest ID in subtree) + * The nodes are returned in pre-order traversal (root first) so that it is easy to + * get the ID of the subtree's root node. + */ + def build(node: Node, id: Int): (Seq[NodeData], Int) = node match { + case n: InternalNode => + val (leftNodeData, leftIdx) = build(n.leftChild, id + 1) + val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1) + val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats, + n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split)) + (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx) + case _: LeafNode => + (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, + -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), + id) + } + } + + /** + * Load a decision tree from a file. + * @return Root node of reconstructed tree + */ + def loadTreeNodes( + path: String, + metadata: DefaultParamsReader.Metadata, + sqlContext: SQLContext): Node = { + import sqlContext.implicits._ + implicit val format = DefaultFormats + + // Get impurity to construct ImpurityCalculator for each node + val impurityType: String = { + val impurityJson: JValue = metadata.getParamValue("impurity") + Param.jsonDecode[String](compact(render(impurityJson))) + } + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).as[NodeData] + buildTreeFromNodes(data.collect(), impurityType) + } + + /** + * Given all data for all nodes in a tree, rebuild the tree. + * @param data Unsorted node data + * @param impurityType Impurity type for this tree + * @return Root node of reconstructed tree + */ + def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { + // Load all nodes, sorted by ID. + val nodes = data.sortBy(_.id) + // Sanity checks; could remove + assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," + + s" but found ${nodes.head.id}") + assert(nodes.last.id == nodes.length - 1, s"Decision Tree load failed. Expected largest" + + s" node ID to be ${nodes.length - 1}, but found ${nodes.last.id}") + // We fill `finalNodes` in reverse order. Since node IDs are assigned via a pre-order + // traversal, this guarantees that child nodes will be built before parent nodes. + val finalNodes = new Array[Node](nodes.length) + nodes.reverseIterator.foreach { case n: NodeData => + val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats) + val node = if (n.leftChild != -1) { + val leftChild = finalNodes(n.leftChild) + val rightChild = finalNodes(n.rightChild) + new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, + n.split.getSplit, impurityStats) + } else { + new LeafNode(n.prediction, n.impurity, impurityStats) + } + finalNodes(n.id) = node + } + // Return the root node + finalNodes.head + } +} + +private[ml] object EnsembleModelReadWrite { + + /** + * Helper method for saving a tree ensemble to disk. + * + * @param instance Tree ensemble model + * @param path Path to which to save the ensemble model. + * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees. + */ + def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]]( + instance: M, + path: String, + sql: SQLContext, + extraMetadata: JObject): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) + val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map { + case (tree, treeID) => + (treeID, + DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext), + instance.treeWeights(treeID)) + } + val treesMetadataPath = new Path(path, "treesMetadata").toString + sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights") + .write.parquet(treesMetadataPath) + val dataPath = new Path(path, "data").toString + val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap { + case (tree, treeID) => EnsembleNodeData.build(tree, treeID) + } + sql.createDataFrame(nodeDataRDD).write.parquet(dataPath) + } + + /** + * Helper method for loading a tree ensemble from disk. + * This reconstructs all trees, returning the root nodes. + * @param path Path given to [[saveImpl()]] + * @param className Class name for ensemble model type + * @param treeClassName Class name for tree model type in the ensemble + * @return (ensemble metadata, array over trees of (tree metadata, root node)), + * where the root node is linked with all descendents + * @see [[saveImpl()]] for how the model was saved + */ + def loadImpl( + path: String, + sql: SQLContext, + className: String, + treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { + import sql.implicits._ + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) + + // Get impurity to construct ImpurityCalculator for each node + val impurityType: String = { + val impurityJson: JValue = metadata.getParamValue("impurity") + Param.jsonDecode[String](compact(render(impurityJson))) + } + + val treesMetadataPath = new Path(path, "treesMetadata").toString + val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath) + .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map { + case (treeID: Int, json: String, weights: Double) => + treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights) + } + + val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect() + val treesMetadata = treesMetadataWeights.map(_._1) + val treesWeights = treesMetadataWeights.map(_._2) + + val dataPath = new Path(path, "data").toString + val nodeData: Dataset[EnsembleNodeData] = + sql.read.parquet(dataPath).as[EnsembleNodeData] + val rootNodesRDD: RDD[(Int, Node)] = + nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { + case (treeID: Int, nodeData: Iterable[NodeData]) => + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + } + val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() + (metadata, treesMetadata.zip(rootNodes), treesWeights) + } + + /** + * Info for one [[Node]] in a tree ensemble + * + * @param treeID Tree index + * @param nodeData Data for this node + */ + case class EnsembleNodeData( + treeID: Int, + nodeData: NodeData) + + object EnsembleNodeData { + /** + * Create [[EnsembleNodeData]] instances for the given tree. + * + * @return Sequence of nodes for this tree + */ + def build(tree: DecisionTreeModel, treeID: Int): Seq[EnsembleNodeData] = { + val (nodeData: Seq[NodeData], _) = NodeData.build(tree.rootNode, 0) + nodeData.map(nd => EnsembleNodeData(treeID, nd)) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 281ba6eeffa92..d7559f8950c3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,19 +17,24 @@ package org.apache.spark.ml.tree +import scala.util.Try + import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** * Parameters for Decision Tree-based algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval { +private[ml] trait DecisionTreeParams extends PredictorParams + with HasCheckpointInterval with HasSeed { /** * Maximum depth of the tree (>= 0). @@ -75,7 +80,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI "Minimum information gain for a split to be considered at a tree node.") /** - * Maximum memory in MB allocated to histogram aggregation. + * Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be + * split per iteration, and its aggregates may exceed this size. * (default = 256 MB) * @group expertParam */ @@ -123,6 +129,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -211,6 +220,9 @@ private[ml] object TreeClassifierParams { final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) } +private[ml] trait DecisionTreeClassifierParams + extends DecisionTreeParams with TreeClassifierParams + /** * Parameters for Decision Tree-based regression algorithms. */ @@ -252,12 +264,28 @@ private[ml] object TreeRegressorParams { final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) } +private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams + with TreeRegressorParams with HasVarianceCol { + + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + SchemaUtils.appendColumn(newSchema, $(varianceCol), DoubleType) + } else { + newSchema + } + } +} + /** * Parameters for Decision Tree-based ensemble algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -276,9 +304,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) - /** @group setParam */ - def setSeed(value: Long): this.type = set(seed, value) - /** * Create a Strategy instance to use with the old API. * NOTE: The caller should set impurity and seed. @@ -292,22 +317,8 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { } } -/** - * Parameters for Random Forest algorithms. - * - * Note: Marked as private and DeveloperApi since this may be made public in the future. - */ -private[ml] trait RandomForestParams extends TreeEnsembleParams { - - /** - * Number of trees to train (>= 1). - * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). SPARK-7130 - * (default = 20) - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) +/** Used for [[RandomForestParams]] */ +private[ml] trait HasFeatureSubsetStrategy extends Params { /** * The number of features to consider for splits at each tree node. @@ -320,6 +331,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { * - "onethird": use 1/3 of the features * - "sqrt": use sqrt(number of features) * - "log2": use log2(number of features) + * - "n": when n is in the range (0, 1.0], use n * number of features. When n + * is in the range (1, number of features), use n features. * (default = "auto") * * These various settings are based on the following references: @@ -335,31 +348,72 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." + - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + + s", (0.0-1.0], [1-n].", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) + RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) + || Try(value.toInt).filter(_ > 0).isSuccess + || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) - setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") + setDefault(featureSubsetStrategy -> "auto") /** @group setParam */ - def setNumTrees(value: Int): this.type = set(numTrees, value) + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getNumTrees: Int = $(numTrees) + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase +} + +/** + * Used for [[RandomForestParams]]. + * This is separated out from [[RandomForestParams]] because of an issue with the + * `numTrees` method conflicting with this Param in the Estimator. + */ +private[ml] trait HasNumTrees extends Params { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + setDefault(numTrees -> 20) /** @group setParam */ - def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase + final def getNumTrees: Int = $(numTrees) } -private[ml] object RandomForestParams { +/** + * Parameters for Random Forest algorithms. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with HasNumTrees + +private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) } +private[ml] trait RandomForestClassifierParams + extends RandomForestParams with TreeClassifierParams + +private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with TreeClassifierParams + +private[ml] trait RandomForestRegressorParams + extends RandomForestParams with TreeRegressorParams + +private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with TreeRegressorParams + /** * Parameters for Gradient-Boosted Tree algorithms. * @@ -409,3 +463,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS /** Get old Gradient Boosting Loss type */ private[ml] def getOldLossType: OldLoss } + +private[ml] object GBTClassifierParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "logistic") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } +} + +private[ml] object GBTRegressorParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "squared") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 77d9948ed86b9..de563d4fad7df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,25 +17,33 @@ package org.apache.spark.ml.tuning +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + import com.github.fommil.netlib.F2jBLAS +import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends ValidatorParams { +private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { /** * Param for number of folds for cross validation. Must be >= 2. * Default: 3 + * * @group param */ val numFolds: IntParam = new IntParam(this, "numFolds", @@ -51,27 +59,39 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { * :: Experimental :: * K-fold cross validation. */ +@Since("1.2.0") @Experimental -class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with Logging { +class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) + extends Estimator[CrossValidatorModel] + with CrossValidatorParams with MLWritable with Logging { + @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) private val f2jBLAS = new F2jBLAS /** @group setParam */ + @Since("1.2.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.2.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.2.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.2.0") def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: DataFrame): CrossValidatorModel = { + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -80,7 +100,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0) + val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() @@ -109,18 +129,10 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } - override def transformSchema(schema: StructType): StructType = { - $(estimator).transformSchema(schema) - } - - override def validateParams(): Unit = { - super.validateParams() - val est = $(estimator) - for (paramMap <- $(estimatorParamMaps)) { - est.copy(paramMap).validateParams() - } - } + @Since("1.4.0") + override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema) + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidator = { val copied = defaultCopy(extra).asInstanceOf[CrossValidator] if (copied.isDefined(estimator)) { @@ -131,6 +143,49 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } copied } + + // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types. + // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]]. + // However, this case should be unusual. + @Since("1.6.0") + override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this) +} + +@Since("1.6.0") +object CrossValidator extends MLReadable[CrossValidator] { + + @Since("1.6.0") + override def read: MLReader[CrossValidator] = new CrossValidatorReader + + @Since("1.6.0") + override def load(path: String): CrossValidator = super.load(path) + + private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { + + ValidatorParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + ValidatorParams.saveImpl(path, instance, sc) + } + + private class CrossValidatorReader extends MLReader[CrossValidator] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidator].getName + + override def load(path: String): CrossValidator = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps) = + ValidatorParams.loadImpl(path, sc, className) + val numFolds = (metadata.params \ "numFolds").extract[Int] + new CrossValidator(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setNumFolds(numFolds) + } + } } /** @@ -139,28 +194,33 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM * * @param bestModel The best model selected from k-fold cross validation. * @param avgMetrics Average cross-validation metrics for each paramMap in - * [[estimatorParamMaps]], in the corresponding order. + * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ +@Since("1.2.0") @Experimental class CrossValidatorModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val avgMetrics: Array[Double]) - extends Model[CrossValidatorModel] with CrossValidatorParams { + @Since("1.4.0") override val uid: String, + @Since("1.2.0") val bestModel: Model[_], + @Since("1.5.0") val avgMetrics: Array[Double]) + extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { - override def validateParams(): Unit = { - bestModel.validateParams() + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { + this(uid, bestModel, avgMetrics.asScala.toArray) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidatorModel = { val copied = new CrossValidatorModel( uid, @@ -168,4 +228,53 @@ class CrossValidatorModel private[ml] ( avgMetrics.clone()) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) +} + +@Since("1.6.0") +object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + + @Since("1.6.0") + override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader + + @Since("1.6.0") + override def load(path: String): CrossValidatorModel = super.load(path) + + private[CrossValidatorModel] + class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + + ValidatorParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidatorModel].getName + + override def load(path: String): CrossValidatorModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps) = + ValidatorParams.loadImpl(path, sc, className) + val numFolds = (metadata.params \ "numFolds").extract[Int] + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + cv.set(cv.estimator, estimator) + .set(cv.evaluator, evaluator) + .set(cv.estimatorParamMaps, estimatorParamMaps) + .set(cv.numFolds, numFolds) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index 98a8f0330ca45..b836d2a2340e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,21 +20,23 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ /** * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ +@Since("1.2.0") @Experimental -class ParamGridBuilder { +class ParamGridBuilder @Since("1.2.0") { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") def baseOn(paramMap: ParamMap): this.type = { baseOn(paramMap.toSeq: _*) this @@ -43,6 +45,7 @@ class ParamGridBuilder { /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") @varargs def baseOn(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => @@ -54,6 +57,7 @@ class ParamGridBuilder { /** * Adds a param with multiple values (overwrites if the input param exists). */ + @Since("1.2.0") def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { paramGrid.put(param, values) this @@ -64,6 +68,7 @@ class ParamGridBuilder { /** * Adds a double param with multiple values. */ + @Since("1.2.0") def addGrid(param: DoubleParam, values: Array[Double]): this.type = { addGrid[Double](param, values) } @@ -71,6 +76,7 @@ class ParamGridBuilder { /** * Adds a int param with multiple values. */ + @Since("1.2.0") def addGrid(param: IntParam, values: Array[Int]): this.type = { addGrid[Int](param, values) } @@ -78,6 +84,7 @@ class ParamGridBuilder { /** * Adds a float param with multiple values. */ + @Since("1.2.0") def addGrid(param: FloatParam, values: Array[Float]): this.type = { addGrid[Float](param, values) } @@ -85,6 +92,7 @@ class ParamGridBuilder { /** * Adds a long param with multiple values. */ + @Since("1.2.0") def addGrid(param: LongParam, values: Array[Long]): this.type = { addGrid[Long](param, values) } @@ -92,6 +100,7 @@ class ParamGridBuilder { /** * Adds a boolean param with true and false. */ + @Since("1.2.0") def addGrid(param: BooleanParam): this.type = { addGrid[Boolean](param, Array(true, false)) } @@ -99,6 +108,7 @@ class ParamGridBuilder { /** * Builds and returns all combinations of parameters specified by the param grid. */ + @Since("1.2.0") def build(): Array[ParamMap] = { var paramMaps = Array(new ParamMap) paramGrid.foreach { case (param, values) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 73a14b8310157..12d6905510c02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,22 +17,32 @@ package org.apache.spark.ml.tuning -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.evaluation.Evaluator +import java.util.{List => JList} + +import scala.collection.JavaConverters._ +import scala.language.existentials + +import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. */ -private[ml] trait TrainValidationSplitParams extends ValidatorParams { +private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed { /** * Param for ratio between train and validation data. Must be between 0 and 1. * Default: 0.75 + * * @group param */ val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio", @@ -51,25 +61,37 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { * and uses evaluation metric on the validation set to select the best model. * Similar to [[CrossValidator]], but only splits the set once. */ +@Since("1.5.0") @Experimental -class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] - with TrainValidationSplitParams with Logging { +class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) + extends Estimator[TrainValidationSplitModel] + with TrainValidationSplitParams with MLWritable with Logging { + @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) /** @group setParam */ + @Since("1.5.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.5.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.5.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) - override def fit(dataset: DataFrame): TrainValidationSplitModel = { + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.0.0") + override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -79,10 +101,10 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali val numModels = epm.length val metrics = new Array[Double](epm.length) - val Array(training, validation) = - dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) - val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() - val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() + val Array(trainingDataset, validationDataset) = + dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + trainingDataset.cache() + validationDataset.cache() // multi-model training logDebug(s"Train split with multiple sets of parameters.") @@ -108,18 +130,10 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) } - override def transformSchema(schema: StructType): StructType = { - $(estimator).transformSchema(schema) - } - - override def validateParams(): Unit = { - super.validateParams() - val est = $(estimator) - for (paramMap <- $(estimatorParamMaps)) { - est.copy(paramMap).validateParams() - } - } + @Since("1.5.0") + override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema) + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplit = { val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit] if (copied.isDefined(estimator)) { @@ -130,6 +144,47 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali } copied } + + @Since("2.0.0") + override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this) +} + +@Since("2.0.0") +object TrainValidationSplit extends MLReadable[TrainValidationSplit] { + + @Since("2.0.0") + override def read: MLReader[TrainValidationSplit] = new TrainValidationSplitReader + + @Since("2.0.0") + override def load(path: String): TrainValidationSplit = super.load(path) + + private[TrainValidationSplit] class TrainValidationSplitWriter(instance: TrainValidationSplit) + extends MLWriter { + + ValidatorParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + ValidatorParams.saveImpl(path, instance, sc) + } + + private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] { + + /** Checked against metadata when loading model */ + private val className = classOf[TrainValidationSplit].getName + + override def load(path: String): TrainValidationSplit = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps) = + ValidatorParams.loadImpl(path, sc, className) + val trainRatio = (metadata.params \ "trainRatio").extract[Double] + new TrainValidationSplit(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setTrainRatio(trainRatio) + } + } } /** @@ -140,26 +195,31 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali * @param bestModel Estimator determined best model. * @param validationMetrics Evaluated validation metrics. */ +@Since("1.5.0") @Experimental class TrainValidationSplitModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val validationMetrics: Array[Double]) - extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val bestModel: Model[_], + @Since("1.5.0") val validationMetrics: Array[Double]) + extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { - override def validateParams(): Unit = { - bestModel.validateParams() + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = { + this(uid, bestModel, validationMetrics.asScala.toArray) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplitModel = { val copied = new TrainValidationSplitModel ( uid, @@ -167,4 +227,53 @@ class TrainValidationSplitModel private[ml] ( validationMetrics.clone()) copyValues(copied, extra) } + + @Since("2.0.0") + override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) +} + +@Since("2.0.0") +object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { + + @Since("2.0.0") + override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader + + @Since("2.0.0") + override def load(path: String): TrainValidationSplitModel = super.load(path) + + private[TrainValidationSplitModel] + class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter { + + ValidatorParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq + ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class TrainValidationSplitModelReader extends MLReader[TrainValidationSplitModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[TrainValidationSplitModel].getName + + override def load(path: String): TrainValidationSplitModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps) = + ValidatorParams.loadImpl(path, sc, className) + val trainRatio = (metadata.params \ "trainRatio").extract[Double] + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray + val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + tvs.set(tvs.estimator, estimator) + .set(tvs.evaluator, evaluator) + .set(tvs.estimatorParamMaps, estimatorParamMaps) + .set(tvs.trainRatio, trainRatio) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 8897ab0825acd..7a4e106aeb999 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -17,20 +17,27 @@ package org.apache.spark.ml.tuning -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.Estimator +import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, _} +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.param.{ParamMap, Param, Params} +import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} +import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, + MLWritable} +import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.sql.types.StructType /** - * :: DeveloperApi :: * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]]. */ -@DeveloperApi private[ml] trait ValidatorParams extends Params { /** * param for the estimator to be validated + * * @group param */ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") @@ -40,6 +47,7 @@ private[ml] trait ValidatorParams extends Params { /** * param for estimator param maps + * * @group param */ val estimatorParamMaps: Param[Array[ParamMap]] = @@ -50,6 +58,7 @@ private[ml] trait ValidatorParams extends Params { /** * param for the evaluator used to select hyper-parameters that maximize the validated metric + * * @group param */ val evaluator: Param[Evaluator] = new Param(this, "evaluator", @@ -57,4 +66,119 @@ private[ml] trait ValidatorParams extends Params { /** @group getParam */ def getEvaluator: Evaluator = $(evaluator) + + protected def transformSchemaImpl(schema: StructType): StructType = { + require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps") + val firstEstimatorParamMap = $(estimatorParamMaps).head + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps).tail) { + est.copy(paramMap).transformSchema(schema) + } + est.copy(firstEstimatorParamMap).transformSchema(schema) + } +} + +private[ml] object ValidatorParams { + /** + * Check that [[ValidatorParams.evaluator]] and [[ValidatorParams.estimator]] are Writable. + * This does not check [[ValidatorParams.estimatorParamMaps]]. + */ + def validateParams(instance: ValidatorParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException(instance.getClass.getName + " write will fail " + + s" because it contains $name which does not implement Writable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + checkElement(instance.getEvaluator, "evaluator") + checkElement(instance.getEstimator, "estimator") + // Check to make sure all Params apply to this estimator. Throw an error if any do not. + // Extraneous Params would cause problems when loading the estimatorParamMaps. + val uidToInstance: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(instance) + instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => + pMap.toSeq.foreach { case ParamPair(p, v) => + require(uidToInstance.contains(p.parent), s"ValidatorParams save requires all Params in" + + s" estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its" + + s" Evaluator. An extraneous Param was found: $p") + } + } + } + + /** + * Generic implementation of save for [[ValidatorParams]] types. + * This handles all [[ValidatorParams]] fields and saves [[Param]] values, but the implementing + * class needs to handle model data. + */ + def saveImpl( + path: String, + instance: ValidatorParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + import org.json4s.JsonDSL._ + + val estimatorParamMapsJson = compact(render( + instance.getEstimatorParamMaps.map { case paramMap => + paramMap.toSeq.map { case ParamPair(p, v) => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + } + }.toSeq + )) + + val validatorSpecificParams = instance match { + case cv: CrossValidatorParams => + List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds))) + case tvs: TrainValidationSplitParams => + List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio))) + case _ => + // This should not happen. + throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " + + instance.getClass.getCanonicalName) + } + + val jsonParams = validatorSpecificParams ++ List( + "estimatorParamMaps" -> parse(estimatorParamMapsJson)) + + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val evaluatorPath = new Path(path, "evaluator").toString + instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) + val estimatorPath = new Path(path, "estimator").toString + instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) + } + + /** + * Generic implementation of load for [[ValidatorParams]] types. + * This handles all [[ValidatorParams]] fields, but the implementing + * class needs to handle model data and special [[Param]] values. + */ + def loadImpl[M <: Model[M]]( + path: String, + sc: SparkContext, + expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val evaluatorPath = new Path(path, "evaluator").toString + val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) + val estimatorPath = new Path(path, "estimator").toString + val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) + + val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator) + + val estimatorParamMaps: Array[ParamMap] = + (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { + pMap => + val paramPairs = pMap.map { case pInfo: Map[String, String] => + val est = uidToParams(pInfo("parent")) + val param = est.getParam(pInfo("name")) + val value = param.jsonDecode(pInfo("value")) + param -> value + } + ParamMap(paramPairs: _*) + }.toArray + + (metadata, estimator, evaluator, estimatorParamMaps) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala new file mode 100644 index 0000000000000..7e57cefc4449f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.util.concurrent.atomic.AtomicLong + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.Param +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset + +/** + * A small wrapper that defines a training session for an estimator, and some methods to log + * useful information during this session. + * + * A new instance is expected to be created within fit(). + * + * @param estimator the estimator that is being fit + * @param dataset the training dataset + * @tparam E the type of the estimator + */ +private[ml] class Instrumentation[E <: Estimator[_]] private ( + estimator: E, dataset: RDD[_]) extends Logging { + + private val id = Instrumentation.counter.incrementAndGet() + private val prefix = { + val className = estimator.getClass.getSimpleName + s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " + } + + init() + + private def init(): Unit = { + log(s"training: numPartitions=${dataset.partitions.length}" + + s" storageLevel=${dataset.getStorageLevel}") + } + + /** + * Logs a message with a prefix that uniquely identifies the training session. + */ + def log(msg: String): Unit = { + logInfo(prefix + msg) + } + + /** + * Logs the value of the given parameters for the estimator being used in this session. + */ + def logParams(params: Param[_]*): Unit = { + val pairs: Seq[(String, JValue)] = for { + p <- params + value <- estimator.get(p) + } yield { + val cast = p.asInstanceOf[Param[Any]] + p.name -> parse(cast.jsonEncode(value)) + } + log(compact(render(map2jvalue(pairs.toMap)))) + } + + def logNumFeatures(num: Long): Unit = { + log(compact(render("numFeatures" -> num))) + } + + def logNumClasses(num: Long): Unit = { + log(compact(render("numClasses" -> num))) + } + + /** + * Logs the successful completion of the training session and the value of the learned model. + */ + def logSuccess(model: Model[_]): Unit = { + log(s"training finished") + } +} + +/** + * Some common methods for logging information about a training session. + */ +private[ml] object Instrumentation { + private val counter = new AtomicLong(0) + + /** + * Creates an instrumentation object for a training session. + */ + def create[E <: Estimator[_]]( + estimator: E, dataset: Dataset[_]): Instrumentation[E] = { + create[E](estimator, dataset.rdd) + } + + /** + * Creates an instrumentation object for a training session. + */ + def create[E <: Estimator[_]]( + estimator: E, dataset: RDD[_]): Instrumentation[E] = { + new Instrumentation[E](estimator, dataset) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala new file mode 100644 index 0000000000000..7dec07ea14976 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.IOException + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.internal.Logging +import org.apache.spark.ml._ +import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} +import org.apache.spark.ml.feature.RFormulaModel +import org.apache.spark.ml.param.{ParamPair, Params} +import org.apache.spark.ml.tuning.ValidatorParams +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +/** + * Trait for [[MLWriter]] and [[MLReader]]. + */ +private[util] sealed trait BaseReadWrite { + private var optionSQLContext: Option[SQLContext] = None + + /** + * Sets the SQL context to use for saving/loading. + */ + @Since("1.6.0") + def context(sqlContext: SQLContext): this.type = { + optionSQLContext = Option(sqlContext) + this + } + + /** + * Returns the user-specified SQL context or the default. + */ + protected final def sqlContext: SQLContext = { + if (optionSQLContext.isEmpty) { + optionSQLContext = Some(SQLContext.getOrCreate(SparkContext.getOrCreate())) + } + optionSQLContext.get + } + + /** Returns the [[SparkContext]] underlying [[sqlContext]] */ + protected final def sc: SparkContext = sqlContext.sparkContext +} + +/** + * Abstract class for utility classes that can save ML instances. + */ +@Experimental +@Since("1.6.0") +abstract class MLWriter extends BaseReadWrite with Logging { + + protected var shouldOverwrite: Boolean = false + + /** + * Saves the ML instances to the input path. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit = { + val hadoopConf = sc.hadoopConfiguration + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(hadoopConf) + val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + if (fs.exists(qualifiedOutputPath)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(qualifiedOutputPath, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + saveImpl(path) + } + + /** + * [[save()]] handles overwriting and then calls this method. Subclasses should override this + * method to implement the actual saving of the instance. + */ + @Since("1.6.0") + protected def saveImpl(path: String): Unit + + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for classes that provide [[MLWriter]]. + */ +@Since("1.6.0") +trait MLWritable { + + /** + * Returns an [[MLWriter]] instance for this ML instance. + */ + @Since("1.6.0") + def write: MLWriter + + /** + * Saves this ML instance to the input path, a shortcut of `write.save(path)`. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit = write.save(path) +} + +private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => + + override def write: MLWriter = new DefaultParamsWriter(this) +} + +/** + * Abstract class for utility classes that can load ML instances. + * + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +abstract class MLReader[T] extends BaseReadWrite { + + /** + * Loads the ML component from the input path. + */ + @Since("1.6.0") + def load(path: String): T + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for objects that provide [[MLReader]]. + * + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +trait MLReadable[T] { + + /** + * Returns an [[MLReader]] instance for this class. + */ + @Since("1.6.0") + def read: MLReader[T] + + /** + * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + * + * Note: Implementing classes should override this to be Java-friendly. + */ + @Since("1.6.0") + def load(path: String): T = read.load(path) +} + +private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader +} + +/** + * Default [[MLWriter]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * + * @param instance object to save + */ +private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + } +} + +private[ml] object DefaultParamsWriter { + + /** + * Saves metadata + Params to: path + "/metadata" + * - class + * - timestamp + * - sparkVersion + * - uid + * - paramMap + * - (optionally, extra metadata) + * + * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + * @param paramMap If given, this is saved in the "paramMap" field. + * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using + * [[org.apache.spark.ml.param.Param.jsonEncode()]]. + */ + def saveMetadata( + instance: Params, + path: String, + sc: SparkContext, + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): Unit = { + val metadataPath = new Path(path, "metadata").toString + val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + } + + /** + * Helper for [[saveMetadata()]] which extracts the JSON to save. + * This is useful for ensemble models which need to save metadata for many sub-models. + * + * @see [[saveMetadata()]] for details on what this includes. + */ + def getMetadataToSave( + instance: Params, + sc: SparkContext, + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): String = { + val uid = instance.uid + val cls = instance.getClass.getName + val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList)) + val basicMetadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadata = extraMetadata match { + case Some(jObject) => + basicMetadata ~ jObject + case None => + basicMetadata + } + val metadataJson: String = compact(render(metadata)) + metadataJson + } +} + +/** + * Default [[MLReader]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * + * @tparam T ML instance type + * TODO: Consider adding check for correct class name. + */ +private[ml] class DefaultParamsReader[T] extends MLReader[T] { + + override def load(path: String): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + val instance = + cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] + DefaultParamsReader.getAndSetParams(instance, metadata) + instance.asInstanceOf[T] + } +} + +private[ml] object DefaultParamsReader { + + /** + * All info from metadata file. + * + * @param params paramMap, as a [[JValue]] + * @param metadata All metadata, including the other fields + * @param metadataJson Full metadata file String (for debugging) + */ + case class Metadata( + className: String, + uid: String, + timestamp: Long, + sparkVersion: String, + params: JValue, + metadata: JValue, + metadataJson: String) { + + /** + * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. + * This can be useful for getting a Param value before an instance of [[Params]] + * is available. + */ + def getParamValue(paramName: String): JValue = { + implicit val format = DefaultFormats + params match { + case JObject(pairs) => + val values = pairs.filter { case (pName, jsonValue) => + pName == paramName + }.map(_._2) + assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + values.head + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + } + + /** + * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]] + * + * @param expectedClassName If non empty, this is checked against the loaded metadata. + * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata + */ + def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { + val metadataPath = new Path(path, "metadata").toString + val metadataStr = sc.textFile(metadataPath, 1).first() + parseMetadata(metadataStr, expectedClassName) + } + + /** + * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]]. + * This is a helper function for [[loadMetadata()]]. + * + * @param metadataStr JSON string of metadata + * @param expectedClassName If non empty, this is checked against the loaded metadata. + * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata + */ + def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = { + val metadata = parse(metadataStr) + + implicit val format = DefaultFormats + val className = (metadata \ "class").extract[String] + val uid = (metadata \ "uid").extract[String] + val timestamp = (metadata \ "timestamp").extract[Long] + val sparkVersion = (metadata \ "sparkVersion").extract[String] + val params = metadata \ "paramMap" + if (expectedClassName.nonEmpty) { + require(className == expectedClassName, s"Error loading metadata: Expected class name" + + s" $expectedClassName but found class name $className") + } + + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * TODO: Move to [[Metadata]] method + */ + def getAndSetParams(instance: Params, metadata: Metadata): Unit = { + implicit val format = DefaultFormats + metadata.params match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") + } + } + + /** + * Load a [[Params]] instance from the given path, and return it. + * This assumes the instance implements [[MLReadable]]. + */ + def loadParamsInstance[T](path: String, sc: SparkContext): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) + } +} + +/** + * Default Meta-Algorithm read and write implementation. + */ +private[ml] object MetaAlgorithmReadWrite { + /** + * Examine the given estimator (which may be a compound estimator) and extract a mapping + * from UIDs to corresponding [[Params]] instances. + */ + def getUidMap(instance: Params): Map[String, Params] = { + val uidList = getUidMapImpl(instance) + val uidMap = uidList.toMap + if (uidList.size != uidMap.size) { + throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" + + s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.") + } + uidMap + } + + private def getUidMapImpl(instance: Params): List[(String, Params)] = { + val subStages: Array[Params] = instance match { + case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] + case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] + case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) + case ovr: OneVsRest => Array(ovr.getClassifier) + case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models + case rformModel: RFormulaModel => Array(rformModel.pipelineModel) + case _: Params => Array() + } + val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + List((instance.uid, instance)) ++ subStageMaps + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 76f651488aef9..334410c9620de 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} /** @@ -43,6 +43,37 @@ private[spark] object SchemaUtils { s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } + /** + * Check whether the given schema contains a column of one of the require data types. + * @param colName column name + * @param dataTypes required column data types + */ + def checkColumnTypes( + schema: StructType, + colName: String, + dataTypes: Seq[DataType], + msg: String = ""): Unit = { + val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(dataTypes.exists(actualDataType.equals), + s"Column $colName must be of type equal to one of the following types: " + + s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + } + + /** + * Check whether the given schema contains a column of the numeric data type. + * @param colName column name + */ + def checkNumericType( + schema: StructType, + colName: String, + msg: String = ""): Unit = { + val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + + s"NumericType but was actually of type $actualDataType.$message") + } + /** * Appends a new column to the input schema. This fails if the given output column already exists. * @param schema input schema @@ -54,12 +85,10 @@ private[spark] object SchemaUtils { def appendColumn( schema: StructType, colName: String, - dataType: DataType): StructType = { + dataType: DataType, + nullable: Boolean = false): StructType = { if (colName.isEmpty) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Column $colName already exists.") - val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) - StructType(outputFields) + appendColumn(schema, StructField(colName, dataType, nullable)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala index ee933f4cfcafd..e6d1dceebed4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/FPGrowthModelWrapper.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.api.python -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.mllib.fpm.FPGrowthModel import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 0ec88ef77d695..364d5eea08ce4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -17,36 +17,31 @@ package org.apache.spark.mllib.api.python -import java.util.{List => JList} - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} import org.apache.spark.mllib.clustering.GaussianMixtureModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * Wrapper around GaussianMixtureModel to provide helper methods in Python - */ + * Wrapper around GaussianMixtureModel to provide helper methods in Python + */ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { val weights: Vector = Vectors.dense(model.weights) val k: Int = weights.size /** - * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian - */ - val gaussians: JList[Object] = { - val modelGaussians = model.gaussians - var i = 0 - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - while (i < k) { - mu += modelGaussians(i).mu - sigma += modelGaussians(i).sigma - i += 1 + * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian + */ + val gaussians: Array[Byte] = { + val modelGaussians = model.gaussians.map { gaussian => + Array[Any](gaussian.mu, gaussian.sigma) } - List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava) + } + + def predictSoft(point: Vector): Vector = { + Vectors.dense(model.predictSoft(point)) } def save(sc: SparkContext, path: String): Unit = model.save(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala new file mode 100644 index 0000000000000..63282eee6e656 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.api.python + +import scala.collection.JavaConverters + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.clustering.LDAModel +import org.apache.spark.mllib.linalg.Matrix + +/** + * Wrapper around LDAModel to provide helper methods in Python + */ +private[python] class LDAModelWrapper(model: LDAModel) { + + def topicsMatrix(): Matrix = model.topicsMatrix + + def vocabSize(): Int = model.vocabSize + + def describeTopics(): Array[Byte] = describeTopics(this.model.vocabSize) + + def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { + val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => + val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava + val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava + Array[Any](jTerms, jTermWeights) + } + SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava) + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala index bc6041b221732..6530870b83a11 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.api.python -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.clustering.PowerIterationClusteringModel +import org.apache.spark.rdd.RDD /** * A Wrapper of PowerIterationClusteringModel to provide helper method for Python diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 40c41806cdfea..1a58779055f44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,16 +19,15 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream import java.nio.{ByteBuffer, ByteOrder} +import java.nio.charset.StandardCharsets import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.language.existentials import scala.reflect.ClassTag import net.razorvine.pickle._ -import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ @@ -42,18 +41,18 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.stat.{ + KernelDensity, MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult} -import org.apache.spark.mllib.stat.{ - KernelDensity, MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} -import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, + RandomForestModel} +import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.storage.StorageLevel @@ -120,6 +119,23 @@ private[python] class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib BisectingKMeans.run() + */ + def trainBisectingKMeans( + data: JavaRDD[Vector], + k: Int, + maxIterations: Int, + minDivisibleClusterSize: Double, + seed: Long): BisectingKMeansModel = { + new BisectingKMeans() + .setK(k) + .setMaxIterations(maxIterations) + .setMinDivisibleClusterSize(minDivisibleClusterSize) + .setSeed(seed) + .run(data) + } + /** * Java stub for Python mllib LinearRegressionWithSGD.train() */ @@ -341,7 +357,7 @@ private[python] class PythonMLLibAPI extends Serializable { val kMeansAlg = new KMeans() .setK(k) .setMaxIterations(maxIterations) - .setRuns(runs) + .internalSetRuns(runs) .setInitializationMode(initializationMode) .setInitializationSteps(initializationSteps) .setEpsilon(epsilon) @@ -413,7 +429,7 @@ private[python] class PythonMLLibAPI extends Serializable { val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) val sigma = si.map(_.asInstanceOf[DenseMatrix]) - val gaussians = Array.tabulate(weight.length){ + val gaussians = Array.tabulate(weight.length) { i => new MultivariateGaussian(mean(i), sigma(i)) } val model = new GaussianMixtureModel(weight, gaussians) @@ -517,7 +533,7 @@ private[python] class PythonMLLibAPI extends Serializable { topicConcentration: Double, seed: java.lang.Long, checkpointInterval: Int, - optimizer: String): LDAModel = { + optimizer: String): LDAModelWrapper = { val algo = new LDA() .setK(k) .setMaxIterations(maxIterations) @@ -535,7 +551,16 @@ private[python] class PythonMLLibAPI extends Serializable { case _ => throw new IllegalArgumentException("input values contains invalid type value.") } } - algo.run(documents) + val model = algo.run(documents) + new LDAModelWrapper(model) + } + + /** + * Load a LDA model + */ + def loadLDAModel(jsc: JavaSparkContext, path: String): LDAModelWrapper = { + val model = DistributedLDAModel.load(jsc.sc, path) + new LDAModelWrapper(model) } @@ -671,39 +696,6 @@ private[python] class PythonMLLibAPI extends Serializable { } } - private[python] class Word2VecModelWrapper(model: Word2VecModel) { - def transform(word: String): Vector = { - model.transform(word) - } - - /** - * Transforms an RDD of words to its vector representation - * @param rdd an RDD of words - * @return an RDD of vector representations of words - */ - def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { - rdd.rdd.map(model.transform) - } - - def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) - } - - def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) - val similarity = Vectors.dense(result.map(_._2)) - val words = result.map(_._1) - List(words, similarity).map(_.asInstanceOf[Object]).asJava - } - - def getVectors: JMap[String, JList[Float]] = { - model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava - } - - def save(sc: SparkContext, path: String): Unit = model.save(sc, path) - } - /** * Java stub for Python mllib DecisionTree.train(). * This stub returns a handle to the Java object instead of the content of the Java object. @@ -1060,7 +1052,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Java stub for the constructor of Python mllib RankingMetrics */ def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = { - new RankingMetrics(predictionAndLabels.map( + new RankingMetrics(predictionAndLabels.rdd.map( r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any]))) } @@ -1134,7 +1126,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Wrapper around RowMatrix constructor. */ def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = { - new RowMatrix(rows.rdd, numRows, numCols) + new RowMatrix(rows.rdd.retag(classOf[Vector]), numRows, numCols) } /** @@ -1143,7 +1135,7 @@ private[python] class PythonMLLibAPI extends Serializable { def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = { // We use DataFrames for serialization of IndexedRows from Python, // so map each Row in the DataFrame back to an IndexedRow. - val indexedRows = rows.map { + val indexedRows = rows.rdd.map { case Row(index: Long, vector: Vector) => IndexedRow(index, vector) } new IndexedRowMatrix(indexedRows, numRows, numCols) @@ -1155,7 +1147,7 @@ private[python] class PythonMLLibAPI extends Serializable { def createCoordinateMatrix(rows: DataFrame, numRows: Long, numCols: Long): CoordinateMatrix = { // We use DataFrames for serialization of MatrixEntry entries from // Python, so map each Row in the DataFrame back to a MatrixEntry. - val entries = rows.map { + val entries = rows.rdd.map { case Row(i: Long, j: Long, value: Double) => MatrixEntry(i, j, value) } new CoordinateMatrix(entries, numRows, numCols) @@ -1169,7 +1161,7 @@ private[python] class PythonMLLibAPI extends Serializable { // We use DataFrames for serialization of sub-matrix blocks from // Python, so map each Row in the DataFrame back to a // ((blockRowIndex, blockColIndex), sub-matrix) tuple. - val blockTuples = blocks.map { + val blockTuples = blocks.rdd.map { case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) => ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix) } @@ -1182,7 +1174,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = { // We use DataFrames for serialization of IndexedRows to Python, // so return a DataFrame. - val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext) + val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext) sqlContext.createDataFrame(indexedRowMatrix.rows) } @@ -1192,7 +1184,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = { // We use DataFrames for serialization of MatrixEntry entries to // Python, so return a DataFrame. - val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext) + val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext) sqlContext.createDataFrame(coordinateMatrix.entries) } @@ -1202,7 +1194,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { // We use DataFrames for serialization of sub-matrix blocks to // Python, so return a DataFrame. - val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext) + val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext) sqlContext.createDataFrame(blockMatrix.blocks) } } @@ -1213,7 +1205,6 @@ private[python] class PythonMLLibAPI extends Serializable { private[spark] object SerDe extends Serializable { val PYSPARK_PACKAGE = "pyspark.mllib" - val LATIN1 = "ISO-8859-1" /** * Base class used for pickle @@ -1235,7 +1226,7 @@ private[spark] object SerDe extends Serializable { def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { if (obj == this) { out.write(Opcodes.GLOBAL) - out.write((module + "\n" + name + "\n").getBytes) + out.write((module + "\n" + name + "\n").getBytes(StandardCharsets.UTF_8)) } else { pickler.save(this) // it will be memorized by Pickler saveState(obj, out, pickler) @@ -1261,7 +1252,8 @@ private[spark] object SerDe extends Serializable { if (obj.getClass.isArray) { obj.asInstanceOf[Array[Byte]] } else { - obj.asInstanceOf[String].getBytes(LATIN1) + // This must be ISO 8859-1 / Latin 1, not UTF-8, to interoperate correctly + obj.asInstanceOf[String].getBytes(StandardCharsets.ISO_8859_1) } } @@ -1305,7 +1297,7 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] - val bytes = new Array[Byte](8 * m.values.size) + val bytes = new Array[Byte](8 * m.values.length) val order = ByteOrder.nativeOrder() val isTransposed = if (m.isTransposed) 1 else 0 ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) @@ -1397,7 +1389,7 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { val v: SparseVector = obj.asInstanceOf[SparseVector] - val n = v.indices.size + val n = v.indices.length val indiceBytes = new Array[Byte](4 * n) val order = ByteOrder.nativeOrder() ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices) @@ -1462,9 +1454,19 @@ private[spark] object SerDe extends Serializable { if (args.length != 3) { throw new PickleException("should be 3") } - new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], + new Rating(ratingsIdCheckLong(args(0)), ratingsIdCheckLong(args(1)), args(2).asInstanceOf[Double]) } + + private def ratingsIdCheckLong(obj: Object): Int = { + try { + obj.asInstanceOf[Int] + } catch { + case ex: ClassCastException => + throw new PickleException(s"Ratings id ${obj.toString} exceeds " + + s"max integer value of ${Int.MaxValue}", ex) + } + } } var initialized = false @@ -1488,7 +1490,11 @@ private[spark] object SerDe extends Serializable { initialize() def dumps(obj: AnyRef): Array[Byte] = { - new Pickler().dumps(obj) + obj match { + // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834. + case array: Array[_] => new Pickler().dumps(array.toSeq.asJava) + case _ => new Pickler().dumps(obj) + } } def loads(bytes: Array[Byte]): AnyRef = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala new file mode 100644 index 0000000000000..05273c34347e8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.feature.Word2VecModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * Wrapper around Word2VecModel to provide helper methods in Python + */ +private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + /** + * Transforms an RDD of words to its vector representation + * @param rdd an RDD of words + * @return an RDD of vector representations of words + */ + def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { + rdd.rdd.map(model.transform) + } + + def findSynonyms(word: String, num: Int): JList[Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): JList[Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + List(words, similarity).map(_.asInstanceOf[Object]).asJava + } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2d52abc122bf2..f10570e662e07 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -19,15 +19,17 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext import org.apache.spark.annotation.Since +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.classification.impl.GLMClassificationModel +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot -import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD - +import org.apache.spark.sql.SQLContext +import org.apache.spark.storage.StorageLevel /** * Classification model trained using Multinomial/Binary Logistic Regression. @@ -332,6 +334,13 @@ object LogisticRegressionWithSGD { * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. + * + * Earlier implementations of LogisticRegressionWithLBFGS applies a regularization + * penalty to all elements including the intercept. If this is called with one of + * standard updaters (L1Updater, or SquaredL2Updater) this is translated + * into a call to ml.LogisticRegression, otherwise this will use the existing mllib + * GeneralizedLinearAlgorithm trainer, resulting in a regularization penalty to the + * intercept. */ @Since("1.1.0") class LogisticRegressionWithLBFGS @@ -374,4 +383,76 @@ class LogisticRegressionWithLBFGS new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) } } + + /** + * Run Logistic Regression with the configured parameters on an input RDD + * of LabeledPoint entries. + * + * If a known updater is used calls the ml implementation, to avoid + * applying a regularization penalty to the intercept, otherwise + * defaults to the mllib implementation. If more than two classes + * or feature scaling is disabled, always uses mllib implementation. + * If using ml implementation, uses ml code to generate initial weights. + */ + override def run(input: RDD[LabeledPoint]): LogisticRegressionModel = { + run(input, generateInitialWeights(input), userSuppliedWeights = false) + } + + /** + * Run Logistic Regression with the configured parameters on an input RDD + * of LabeledPoint entries starting from the initial weights provided. + * + * If a known updater is used calls the ml implementation, to avoid + * applying a regularization penalty to the intercept, otherwise + * defaults to the mllib implementation. If more than two classes + * or feature scaling is disabled, always uses mllib implementation. + * Uses user provided weights. + * + * In the ml LogisticRegression implementation, the number of corrections + * used in the LBFGS update can not be configured. So `optimizer.setNumCorrections()` + * will have no effect if we fall into that route. + */ + override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = { + run(input, initialWeights, userSuppliedWeights = true) + } + + private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean): + LogisticRegressionModel = { + // ml's Logistic regression only supports binary classification currently. + if (numOfLinearPredictor == 1) { + def runWithMlLogisitcRegression(elasticNetParam: Double) = { + // Prepare the ml LogisticRegression based on our settings + val lr = new org.apache.spark.ml.classification.LogisticRegression() + lr.setRegParam(optimizer.getRegParam()) + lr.setElasticNetParam(elasticNetParam) + lr.setStandardization(useFeatureScaling) + if (userSuppliedWeights) { + val uid = Identifiable.randomUID("logreg-static") + lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel( + uid, initialWeights, 1.0)) + } + lr.setFitIntercept(addIntercept) + lr.setMaxIter(optimizer.getNumIterations()) + lr.setTol(optimizer.getConvergenceTol()) + // Convert our input into a DataFrame + val sqlContext = new SQLContext(input.context) + import sqlContext.implicits._ + val df = input.toDF() + // Determine if we should cache the DF + val handlePersistence = input.getStorageLevel == StorageLevel.NONE + // Train our model + val mlLogisticRegresionModel = lr.train(df, handlePersistence) + // convert the model + val weights = Vectors.dense(mlLogisticRegresionModel.coefficients.toArray) + createModel(weights, mlLogisticRegresionModel.intercept) + } + optimizer.getUpdater() match { + case x: SquaredL2Updater => runWithMlLogisitcRegression(0.0) + case x: L1Updater => runWithMlLogisitcRegression(1.0) + case _ => super.run(input, initialWeights) + } + } else { + super.run(input, initialWeights) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a956084ae06e8..eb3ee41f7cf4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -24,8 +24,9 @@ import scala.collection.JavaConverters._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.Since +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} @@ -74,7 +75,7 @@ class NaiveBayesModel private[spark] ( case Multinomial => (None, None) case Bernoulli => val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value))) - val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0}) + val ones = new DenseVector(Array.fill(thetaMatrix.numCols) {1.0}) val thetaMinusNegTheta = thetaMatrix.map { value => value - math.log(1.0 - math.exp(value)) } @@ -192,7 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -208,7 +209,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -239,7 +240,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -254,7 +255,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -325,6 +326,8 @@ class NaiveBayes private ( /** Set the smoothing parameter. Default: 1.0. */ @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { + require(lambda >= 0, + s"Smoothing parameter must be nonnegative but got ${lambda}") this.lambda = lambda this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index fe09f6b75d28b..4308ae04ee84d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -74,10 +74,10 @@ private[classification] object GLMClassificationModel { */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) - assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + assert(dataArray.length == 1, s"Unable to load $modelClass data from: $datapath") val data = dataArray(0) assert(data.size == 3, s"Unable to load $modelClass data from: $datapath") val (weights, intercept) = data match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala new file mode 100644 index 0000000000000..e4bd0dc25ee54 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" + * by Steinbach, Karypis, and Kumar, with modification to fit Spark. + * The algorithm starts from a single cluster that contains all points. + * Iteratively it finds divisible clusters on the bottom level and bisects each of them using + * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + * The bisecting steps of clusters on the same level are grouped together to increase parallelism. + * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, + * larger clusters get higher priority. + * + * @param k the desired number of leaf clusters (default: 4). The actual number could be smaller if + * there are no divisible leaf clusters. + * @param maxIterations the max number of k-means iterations to split clusters (default: 20) + * @param minDivisibleClusterSize the minimum number of points (if >= 1.0) or the minimum proportion + * of points (if < 1.0) of a divisible cluster (default: 1) + * @param seed a random seed (default: hash value of the class name) + * + * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf + * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, + * KDD Workshop on Text Mining, 2000.]] + */ +@Since("1.6.0") +@Experimental +class BisectingKMeans private ( + private var k: Int, + private var maxIterations: Int, + private var minDivisibleClusterSize: Double, + private var seed: Long) extends Logging { + + import BisectingKMeans._ + + /** + * Constructs with the default configuration + */ + @Since("1.6.0") + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + + /** + * Sets the desired number of leaf clusters (default: 4). + * The actual number could be smaller if there are no divisible leaf clusters. + */ + @Since("1.6.0") + def setK(k: Int): this.type = { + require(k > 0, s"k must be positive but got $k.") + this.k = k + this + } + + /** + * Gets the desired number of leaf clusters. + */ + @Since("1.6.0") + def getK: Int = this.k + + /** + * Sets the max number of k-means iterations to split clusters (default: 20). + */ + @Since("1.6.0") + def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations > 0, s"maxIterations must be positive but got $maxIterations.") + this.maxIterations = maxIterations + this + } + + /** + * Gets the max number of k-means iterations to split clusters. + */ + @Since("1.6.0") + def getMaxIterations: Int = this.maxIterations + + /** + * Sets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster (default: 1). + */ + @Since("1.6.0") + def setMinDivisibleClusterSize(minDivisibleClusterSize: Double): this.type = { + require(minDivisibleClusterSize > 0.0, + s"minDivisibleClusterSize must be positive but got $minDivisibleClusterSize.") + this.minDivisibleClusterSize = minDivisibleClusterSize + this + } + + /** + * Gets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster. + */ + @Since("1.6.0") + def getMinDivisibleClusterSize: Double = minDivisibleClusterSize + + /** + * Sets the random seed (default: hash value of the class name). + */ + @Since("1.6.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Gets the random seed. + */ + @Since("1.6.0") + def getSeed: Long = this.seed + + /** + * Runs the bisecting k-means algorithm. + * @param input RDD of vectors + * @return model for the bisecting kmeans + */ + @Since("1.6.0") + def run(input: RDD[Vector]): BisectingKMeansModel = { + if (input.getStorageLevel == StorageLevel.NONE) { + logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if" + + " its parent RDDs are also not cached.") + } + val d = input.map(_.size).first() + logInfo(s"Feature dimension: $d.") + // Compute and cache vector norms for fast distance computation. + val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) + val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } + var assignments = vectors.map(v => (ROOT_INDEX, v)) + var activeClusters = summarize(d, assignments) + val rootSummary = activeClusters(ROOT_INDEX) + val n = rootSummary.size + logInfo(s"Number of points: $n.") + logInfo(s"Initial cost: ${rootSummary.cost}.") + val minSize = if (minDivisibleClusterSize >= 1.0) { + math.ceil(minDivisibleClusterSize).toLong + } else { + math.ceil(minDivisibleClusterSize * n).toLong + } + logInfo(s"The minimum number of points of a divisible cluster is $minSize.") + var inactiveClusters = mutable.Seq.empty[(Long, ClusterSummary)] + val random = new Random(seed) + var numLeafClustersNeeded = k - 1 + var level = 1 + while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) { + // Divisible clusters are sufficiently large and have non-trivial cost. + var divisibleClusters = activeClusters.filter { case (_, summary) => + (summary.size >= minSize) && (summary.cost > MLUtils.EPSILON * summary.size) + } + // If we don't need all divisible clusters, take the larger ones. + if (divisibleClusters.size > numLeafClustersNeeded) { + divisibleClusters = divisibleClusters.toSeq.sortBy { case (_, summary) => + -summary.size + }.take(numLeafClustersNeeded) + .toMap + } + if (divisibleClusters.nonEmpty) { + val divisibleIndices = divisibleClusters.keys.toSet + logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") + var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => + val (left, right) = splitCenter(summary.center, random) + Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) + }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map + var newClusters: Map[Long, ClusterSummary] = null + var newAssignments: RDD[(Long, VectorWithNorm)] = null + for (iter <- 0 until maxIterations) { + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + .filter { case (index, _) => + divisibleIndices.contains(parentIndex(index)) + } + newClusters = summarize(d, newAssignments) + newClusterCenters = newClusters.mapValues(_.center).map(identity) + } + // TODO: Unpersist old indices. + val indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + .persist(StorageLevel.MEMORY_AND_DISK) + assignments = indices.zip(vectors) + inactiveClusters ++= activeClusters + activeClusters = newClusters + numLeafClustersNeeded -= divisibleClusters.size + } else { + logInfo(s"None active and divisible clusters left on level $level. Stop iterations.") + inactiveClusters ++= activeClusters + activeClusters = Map.empty + } + level += 1 + } + val clusters = activeClusters ++ inactiveClusters + val root = buildTree(clusters) + new BisectingKMeansModel(root) + } + + /** + * Java-friendly version of [[run()]]. + */ + def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd) +} + +private object BisectingKMeans extends Serializable { + + /** The index of the root node of a tree. */ + private val ROOT_INDEX: Long = 1 + + private val MAX_DIVISIBLE_CLUSTER_INDEX: Long = Long.MaxValue / 2 + + private val LEVEL_LIMIT = math.log10(Long.MaxValue) / math.log10(2) + + /** Returns the left child index of the given node index. */ + private def leftChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index.") + 2 * index + } + + /** Returns the right child index of the given node index. */ + private def rightChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index + 1.") + 2 * index + 1 + } + + /** Returns the parent index of the given node index, or 0 if the input is 1 (root). */ + private def parentIndex(index: Long): Long = { + index / 2 + } + + /** + * Summarizes data by each cluster as Map. + * @param d feature dimension + * @param assignments pairs of point and its cluster index + * @return a map from cluster indices to corresponding cluster summaries + */ + private def summarize( + d: Int, + assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + seqOp = (agg, v) => agg.add(v), + combOp = (agg1, agg2) => agg1.merge(agg2) + ).mapValues(_.summary) + .collect().toMap + } + + /** + * Cluster summary aggregator. + * @param d feature dimension + */ + private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private var n: Long = 0L + private val sum: Vector = Vectors.zeros(d) + private var sumSq: Double = 0.0 + + /** Adds a point. */ + def add(v: VectorWithNorm): this.type = { + n += 1L + // TODO: use a numerically stable approach to estimate cost + sumSq += v.norm * v.norm + BLAS.axpy(1.0, v.vector, sum) + this + } + + /** Merges another aggregator. */ + def merge(other: ClusterSummaryAggregator): this.type = { + n += other.n + sumSq += other.sumSq + BLAS.axpy(1.0, other.sum, sum) + this + } + + /** Returns the summary. */ + def summary: ClusterSummary = { + val mean = sum.copy + if (n > 0L) { + BLAS.scal(1.0 / n, mean) + } + val center = new VectorWithNorm(mean) + val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) + new ClusterSummary(n, center, cost) + } + } + + /** + * Bisects a cluster center. + * + * @param center current cluster center + * @param random a random number generator + * @return initial centers + */ + private def splitCenter( + center: VectorWithNorm, + random: Random): (VectorWithNorm, VectorWithNorm) = { + val d = center.vector.size + val norm = center.norm + val level = 1e-4 * norm + val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) + val left = center.vector.copy + BLAS.axpy(-level, noise, left) + val right = center.vector.copy + BLAS.axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * Updates assignments. + * @param assignments current assignments + * @param divisibleIndices divisible cluster indices + * @param newClusterCenters new cluster centers + * @return new assignments + */ + private def updateAssignments( + assignments: RDD[(Long, VectorWithNorm)], + divisibleIndices: Set[Long], + newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + assignments.map { case (index, v) => + if (divisibleIndices.contains(index)) { + val children = Seq(leftChildIndex(index), rightChildIndex(index)) + val selected = children.minBy { child => + KMeans.fastSquaredDistance(newClusterCenters(child), v) + } + (selected, v) + } else { + (index, v) + } + } + } + + /** + * Builds a clustering tree by re-indexing internal and leaf clusters. + * @param clusters a map from cluster indices to corresponding cluster summaries + * @return the root node of the clustering tree + */ + private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + var leafIndex = 0 + var internalIndex = -1 + + /** + * Builds a subtree from this given node index. + */ + def buildSubTree(rawIndex: Long): ClusteringTreeNode = { + val cluster = clusters(rawIndex) + val size = cluster.size + val center = cluster.center + val cost = cluster.cost + val isInternal = clusters.contains(leftChildIndex(rawIndex)) + if (isInternal) { + val index = internalIndex + internalIndex -= 1 + val leftIndex = leftChildIndex(rawIndex) + val rightIndex = rightChildIndex(rawIndex) + val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex => + KMeans.fastSquaredDistance(center, clusters(childIndex).center) + }.max) + val left = buildSubTree(leftIndex) + val right = buildSubTree(rightIndex) + new ClusteringTreeNode(index, size, center, cost, height, Array(left, right)) + } else { + val index = leafIndex + leafIndex += 1 + val height = 0.0 + new ClusteringTreeNode(index, size, center, cost, height, Array.empty) + } + } + + buildSubTree(ROOT_INDEX) + } + + /** + * Summary of a cluster. + * + * @param size the number of points within this cluster + * @param center the center of the points within this cluster + * @param cost the sum of squared distances to the center + */ + private case class ClusterSummary(size: Long, center: VectorWithNorm, cost: Double) +} + +/** + * Represents a node in a clustering tree. + * + * @param index node index, negative for internal nodes and non-negative for leaf nodes + * @param size size of the cluster + * @param centerWithNorm cluster center with norm + * @param cost cost of the cluster, i.e., the sum of squared distances to the center + * @param height height of the node in the dendrogram. Currently this is defined as the max distance + * from the center to the centers of the children's, but subject to change. + * @param children children nodes + */ +@Since("1.6.0") +@Experimental +private[clustering] class ClusteringTreeNode private[clustering] ( + val index: Int, + val size: Long, + private[clustering] val centerWithNorm: VectorWithNorm, + val cost: Double, + val height: Double, + val children: Array[ClusteringTreeNode]) extends Serializable { + + /** Whether this is a leaf node. */ + val isLeaf: Boolean = children.isEmpty + + require((isLeaf && index >= 0) || (!isLeaf && index < 0)) + + /** Cluster center. */ + def center: Vector = centerWithNorm.vector + + /** Predicts the leaf cluster node index that the input point belongs to. */ + def predict(point: Vector): Int = { + val (index, _) = predict(new VectorWithNorm(point)) + index + } + + /** Returns the full prediction path from root to leaf. */ + def predictPath(point: Vector): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point)).toArray + } + + /** Returns the full prediction path from root to leaf. */ + private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + if (isLeaf) { + this :: Nil + } else { + val selected = children.minBy { child => + KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + } + selected :: selected.predictPath(pointWithNorm) + } + } + + /** + * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + */ + def computeCost(point: Vector): Double = { + val (_, cost) = predict(new VectorWithNorm(point)) + cost + } + + /** + * Predicts the cluster index and the cost of the input point. + */ + private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { + predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + } + + /** + * Predicts the cluster index and the cost of the input point. + * @param pointWithNorm input point + * @param cost the cost to the current center + * @return (predicted leaf cluster index, cost) + */ + @tailrec + private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + if (isLeaf) { + (index, cost) + } else { + val (selectedChild, minCost) = children.map { child => + (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + }.minBy(_._2) + selectedChild.predict(pointWithNorm, minCost) + } + } + + /** + * Returns all leaf nodes from this node. + */ + def leafNodes: Array[ClusteringTreeNode] = { + if (isLeaf) { + Array(this) + } else { + children.flatMap(_.leafNodes) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala new file mode 100644 index 0000000000000..c3b5b8b7900f5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.json4s._ +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonDSL._ + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.{Loader, Saveable} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} + +/** + * Clustering model produced by [[BisectingKMeans]]. + * The prediction is done level-by-level from the root node to a leaf node, and at each node among + * its children the closest to the input point is selected. + * + * @param root the root node of the clustering tree + */ +@Since("1.6.0") +@Experimental +class BisectingKMeansModel private[clustering] ( + private[clustering] val root: ClusteringTreeNode + ) extends Serializable with Saveable with Logging { + + /** + * Leaf cluster centers. + */ + @Since("1.6.0") + def clusterCenters: Array[Vector] = root.leafNodes.map(_.center) + + /** + * Number of leaf clusters. + */ + lazy val k: Int = clusterCenters.length + + /** + * Predicts the index of the cluster that the input point belongs to. + */ + @Since("1.6.0") + def predict(point: Vector): Int = { + root.predict(point) + } + + /** + * Predicts the indices of the clusters that the input points belong to. + */ + @Since("1.6.0") + def predict(points: RDD[Vector]): RDD[Int] = { + points.map { p => root.predict(p) } + } + + /** + * Java-friendly version of [[predict()]]. + */ + @Since("1.6.0") + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + + /** + * Computes the squared distance between the input point and the cluster center it belongs to. + */ + @Since("1.6.0") + def computeCost(point: Vector): Double = { + root.computeCost(point) + } + + /** + * Computes the sum of squared distances between the input points and their corresponding cluster + * centers. + */ + @Since("1.6.0") + def computeCost(data: RDD[Vector]): Double = { + data.map(root.computeCost).sum() + } + + /** + * Java-friendly version of [[computeCost()]]. + */ + @Since("1.6.0") + def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) + + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("2.0.0") +object BisectingKMeansModel extends Loader[BisectingKMeansModel] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): BisectingKMeansModel = { + val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val rootId = (metadata \ "rootId").extract[Int] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (classNameV1_0, "1.0") => + val model = SaveLoadV1_0.load(sc, path, rootId) + model + case _ => throw new Exception( + s"BisectingKMeansModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $formatVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private case class Data(index: Int, size: Long, center: Vector, norm: Double, cost: Double, + height: Double, children: Seq[Int]) + + private object Data { + def apply(r: Row): Data = Data(r.getInt(0), r.getLong(1), r.getAs[Vector](2), r.getDouble(3), + r.getDouble(4), r.getDouble(5), r.getSeq[Int](6)) + } + + private[clustering] object SaveLoadV1_0 { + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + val dataRDD = sc.parallelize(data).toDF() + dataRDD.write.parquet(Loader.dataPath(path)) + } + + private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { + if (node.children.isEmpty) { + Array(node) + } else { + node.children.flatMap(getNodes(_)) ++ Array(node) + } + } + + def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { + val sqlContext = SQLContext.getOrCreate(sc) + val rows = sqlContext.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode) + } + + private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { + val root = nodes.get(rootId).get + if (root.children.isEmpty) { + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, new Array[ClusteringTreeNode](0)) + } else { + val children = root.children.map(c => buildTree(c, nodes)) + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, children.toArray) + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 7b203e2f40815..f04c87259c941 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -45,10 +45,10 @@ import org.apache.spark.util.Utils * This is due to high-dimensional data (a) making it difficult to cluster at all (based * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. * - * @param k The number of independent Gaussians in the mixture model - * @param convergenceTol The maximum change in log-likelihood at which convergence - * is considered to have occurred. - * @param maxIterations The maximum number of iterations to perform + * @param k Number of independent Gaussians in the mixture model. + * @param convergenceTol Maximum change in log-likelihood at which convergence + * is considered to have occurred. + * @param maxIterations Maximum number of iterations allowed. */ @Since("1.3.0") class GaussianMixture private ( @@ -78,11 +78,9 @@ class GaussianMixture private ( */ @Since("1.3.0") def setInitialModel(model: GaussianMixtureModel): this.type = { - if (model.k == k) { - initialModel = Some(model) - } else { - throw new IllegalArgumentException("mismatched cluster count (model.k != k)") - } + require(model.k == k, + s"Mismatched cluster count (model.k ${model.k} != k ${k})") + initialModel = Some(model) this } @@ -97,6 +95,8 @@ class GaussianMixture private ( */ @Since("1.3.0") def setK(k: Int): this.type = { + require(k > 0, + s"Number of Gaussians must be positive but got ${k}") this.k = k this } @@ -108,16 +108,18 @@ class GaussianMixture private ( def getK: Int = k /** - * Set the maximum number of iterations to run. Default: 100 + * Set the maximum number of iterations allowed. Default: 100 */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations >= 0, + s"Maximum of iterations must be nonnegative but got ${maxIterations}") this.maxIterations = maxIterations this } /** - * Return the maximum number of iterations to run + * Return the maximum number of iterations allowed */ @Since("1.3.0") def getMaxIterations: Int = maxIterations @@ -128,6 +130,8 @@ class GaussianMixture private ( */ @Since("1.3.0") def setConvergenceTol(convergenceTol: Double): this.type = { + require(convergenceTol >= 0.0, + s"Convergence tolerance must be nonnegative but got ${convergenceTol}") this.convergenceTol = convergenceTol this } @@ -177,13 +181,12 @@ class GaussianMixture private ( val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - case None => { + case None => val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) - } } var llh = Double.MinValue // current log-likelihood diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 2115f7d99c182..f87613cc72f9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -26,11 +25,11 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} /** * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points @@ -76,7 +75,7 @@ class GaussianMixtureModel @Since("1.3.0") ( */ @Since("1.5.0") def predict(point: Vector): Int = { - val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + val r = predictSoft(point) r.indexOf(r.max) } @@ -145,7 +144,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -162,7 +161,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) @@ -178,13 +177,13 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { } @Since("1.4.0") - override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { + override def load(sc: SparkContext, path: String): GaussianMixtureModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats val k = (metadata \ "k").extract[Int] val classNameV1_0 = SaveLoadV1_0.classNameV1_0 (loadedClassName, version) match { - case (classNameV1_0, "1.0") => { + case (classNameV1_0, "1.0") => val model = SaveLoadV1_0.load(sc, path) require(model.weights.length == k, s"GaussianMixtureModel requires weights of length $k " + @@ -193,7 +192,6 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { s"GaussianMixtureModel requires gaussians of length $k" + s"got gaussians of length ${model.gaussians.length}") model - } case _ => throw new Exception( s"GaussianMixtureModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). Supported:\n" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 2895db7c9061b..8ff0b83e8b49f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer -import org.apache.spark.Logging -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils @@ -65,21 +65,25 @@ class KMeans private ( */ @Since("0.8.0") def setK(k: Int): this.type = { + require(k > 0, + s"Number of clusters must be positive but got ${k}") this.k = k this } /** - * Maximum number of iterations to run. + * Maximum number of iterations allowed. */ @Since("1.4.0") def getMaxIterations: Int = maxIterations /** - * Set maximum number of iterations to run. Default: 20. + * Set maximum number of iterations allowed. Default: 20. */ @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations >= 0, + s"Maximum of iterations must be nonnegative but got ${maxIterations}") this.maxIterations = maxIterations this } @@ -107,7 +111,7 @@ class KMeans private ( * Number of runs of the algorithm to execute in parallel. */ @Since("1.4.0") - @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") + @deprecated("Support for runs is deprecated. This param will have no effect in 2.0.0.", "1.6.0") def getRuns: Int = runs /** @@ -117,11 +121,20 @@ class KMeans private ( * return the best clustering found over any run. Default: 1. */ @Since("0.8.0") - @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") + @deprecated("Support for runs is deprecated. This param will have no effect in 2.0.0.", "1.6.0") def setRuns(runs: Int): this.type = { + internalSetRuns(runs) + } + + // Internal version of setRuns for Python API, this should be removed at the same time as setRuns + // this is done to avoid deprecation warnings in our build. + private[mllib] def internalSetRuns(runs: Int): this.type = { if (runs <= 0) { throw new IllegalArgumentException("Number of runs must be positive") } + if (runs != 1) { + logWarning("Setting number of runs is deprecated and will have no effect in 2.0.0") + } this.runs = runs this } @@ -138,9 +151,8 @@ class KMeans private ( */ @Since("0.8.0") def setInitializationSteps(initializationSteps: Int): this.type = { - if (initializationSteps <= 0) { - throw new IllegalArgumentException("Number of initialization steps must be positive") - } + require(initializationSteps > 0, + s"Number of initialization steps must be positive but got ${initializationSteps}") this.initializationSteps = initializationSteps this } @@ -157,6 +169,8 @@ class KMeans private ( */ @Since("0.8.0") def setEpsilon(epsilon: Double): this.type = { + require(epsilon >= 0, + s"Distance threshold must be nonnegative but got ${epsilon}") this.epsilon = epsilon this } @@ -239,16 +253,14 @@ class KMeans private ( } val centers = initialModel match { - case Some(kMeansCenters) => { + case Some(kMeansCenters) => Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) - } - case None => { + case None => if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { initKMeansParallel(data) } - } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + @@ -301,6 +313,8 @@ class KMeans private ( contribs.iterator }.reduceByKey(mergeContribs).collectAsMap() + bcActiveCenters.unpersist(blocking = false) + // Update the cluster centers and costs for each active run for ((run, i) <- activeRuns.zipWithIndex) { var changed = false @@ -374,6 +388,8 @@ class KMeans private ( // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq + // Could be empty if data is empty; fail with a better message early: + require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data") val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) /** Merges new centers to centers. */ @@ -419,14 +435,17 @@ class KMeans private ( s0 } ) + + bcNewCenters.unpersist(blocking = false) preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) pointsWithCosts.flatMap { case (p, c) => val rs = (0 until runs).filter { r => rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) } - if (rs.length > 0) Some(p, rs) else None + if (rs.length > 0) Some((p, rs)) else None } }.collect() mergeNewCenters() @@ -448,6 +467,9 @@ class KMeans private ( ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() + + bcCenters.unpersist(blocking = false) + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray @@ -474,12 +496,15 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Vector]` - * @param k number of clusters - * @param maxIterations max number of iterations - * @param runs number of parallel runs, defaults to 1. The best model is returned. - * @param initializationMode initialization model, either "random" or "k-means||" (default). - * @param seed random seed value for cluster initialization + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param runs Number of runs to execute in parallel. The best model according to the cost + * function will be returned. (default: 1) + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") + * @param seed Random seed for cluster initialization. Default is to generate seed based + * on system time. */ @Since("1.3.0") def train( @@ -491,7 +516,7 @@ object KMeans { seed: Long): KMeansModel = { new KMeans().setK(k) .setMaxIterations(maxIterations) - .setRuns(runs) + .internalSetRuns(runs) .setInitializationMode(initializationMode) .setSeed(seed) .run(data) @@ -500,11 +525,13 @@ object KMeans { /** * Trains a k-means model using the given set of parameters. * - * @param data training points stored as `RDD[Vector]` - * @param k number of clusters - * @param maxIterations max number of iterations - * @param runs number of parallel runs, defaults to 1. The best model is returned. - * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @param data Training points as an `RDD` of `Vector` types. + * @param k Number of clusters to create. + * @param maxIterations Maximum number of iterations allowed. + * @param runs Number of runs to execute in parallel. The best model according to the cost + * function will be returned. (default: 1) + * @param initializationMode The initialization algorithm. This can either be "random" or + * "k-means||". (default: "k-means||") */ @Since("0.8.0") def train( @@ -515,7 +542,7 @@ object KMeans { initializationMode: String): KMeansModel = { new KMeans().setK(k) .setMaxIterations(maxIterations) - .setRuns(runs) + .internalSetRuns(runs) .setInitializationMode(initializationMode) .run(data) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a741584982725..439e4f8672242 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -23,15 +23,14 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, SQLContext} /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. @@ -124,7 +123,7 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) @@ -137,15 +136,15 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] val centroids = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Cluster](centroids.schema) - val localCentroids = centroids.map(Cluster.apply).collect() - assert(k == localCentroids.size) + val localCentroids = centroids.rdd.map(Cluster.apply).collect() + assert(k == localCentroids.length) new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index eb802a365ed6e..d999b9be8e8ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -61,14 +61,13 @@ class LDA private ( ldaOptimizer = new EMLDAOptimizer) /** - * Number of topics to infer. I.e., the number of soft cluster centers. - * + * Number of topics to infer, i.e., the number of soft cluster centers. */ @Since("1.3.0") def getK: Int = k /** - * Number of topics to infer. I.e., the number of soft cluster centers. + * Set the number of topics to infer, i.e., the number of soft cluster centers. * (default = 10) */ @Since("1.3.0") @@ -131,7 +130,8 @@ class LDA private ( */ @Since("1.5.0") def setDocConcentration(docConcentration: Vector): this.type = { - require(docConcentration.size > 0, "docConcentration must have > 0 elements") + require(docConcentration.size == 1 || docConcentration.size == k, + s"Size of docConcentration must be 1 or ${k} but got ${docConcentration.size}") this.docConcentration = docConcentration this } @@ -222,29 +222,31 @@ class LDA private ( def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** - * Maximum number of iterations for learning. + * Maximum number of iterations allowed. */ @Since("1.3.0") def getMaxIterations: Int = maxIterations /** - * Maximum number of iterations for learning. + * Set the maximum number of iterations allowed. * (default = 20) */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations >= 0, + s"Maximum of iterations must be nonnegative but got ${maxIterations}") this.maxIterations = maxIterations this } /** - * Random seed + * Random seed for cluster initialization. */ @Since("1.3.0") def getSeed: Long = seed /** - * Random seed + * Set the random seed for cluster initialization. */ @Since("1.3.0") def setSeed(seed: Long): this.type = { @@ -259,15 +261,18 @@ class LDA private ( def getCheckpointInterval: Int = checkpointInterval /** - * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * Parameter for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that + * the cache will get checkpointed every 10 iterations. Checkpointing helps with recovery * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be * important when LDA is run for many iterations. If the checkpoint directory is not set in - * [[org.apache.spark.SparkContext]], this setting is ignored. + * [[org.apache.spark.SparkContext]], this setting is ignored. (default = 10) * * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ @Since("1.3.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { + require(checkpointInterval == -1 || checkpointInterval > 0, + s"Period between checkpoints must be -1 or positive but got ${checkpointInterval}") this.checkpointInterval = checkpointInterval this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 31d8a9fdea1c6..27b4004927aaa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} +import breeze.linalg.{argmax, argtopk, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -183,16 +183,15 @@ abstract class LDAModel private[clustering] extends Saveable { /** * Local LDA model. * This model stores only the inferred topics. - * It may be used for computing topics for new documents, but it may give less accurate answers - * than the [[DistributedLDAModel]]. + * * @param topics Inferred topics (vocabSize x k matrix). */ @Since("1.3.0") -class LocalLDAModel private[clustering] ( +class LocalLDAModel private[spark] ( @Since("1.3.0") val topics: Matrix, @Since("1.5.0") override val docConcentration: Vector, @Since("1.5.0") override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double = 100) + override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { @Since("1.3.0") @@ -353,7 +352,7 @@ class LocalLDAModel private[clustering] ( documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { - (id, Vectors.zeros(k)) + (id, Vectors.zeros(k)) } else { val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( termCounts, @@ -366,6 +365,54 @@ class LocalLDAModel private[clustering] ( } } + /** Get a method usable as a UDF for [[topicDistributions()]] */ + private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbetaBc = sc.broadcast(expElogbeta) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + + (termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + Vectors.zeros(k) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbetaBc.value, + docConcentrationBrz, + gammaShape, + k) + Vectors.dense(normalize(gamma, 1.0).toArray) + } + } + + /** + * Predicts the topic mixture distribution for a document (often called "theta" in the + * literature). Returns a vector of zeros for an empty document. + * + * Note this means to allow quick query for single document. For batch documents, please refer + * to [[topicDistributions()]] to avoid overhead. + * + * @param document document to predict topic mixture distributions for + * @return topic mixture distribution for the document + */ + @Since("2.0.0") + def topicDistribution(document: Vector): Vector = { + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + if (document.numNonzeros == 0) { + Vectors.zeros(this.k) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + document, + expElogbeta, + this.docConcentration.toBreeze, + gammaShape, + this.k) + Vectors.dense(normalize(gamma, 1.0).toArray) + } + } + /** * Java-friendly version of [[topicDistributions]] */ @@ -477,8 +524,6 @@ object LocalLDAModel extends Loader[LocalLDAModel] { /** * Distributed LDA model. * This model stores the inferred topics, the full training dataset, and the topic distributions. - * When computing topics for new documents, it may give more accurate answers - * than the [[LocalLDAModel]]. */ @Since("1.3.0") class DistributedLDAModel private[clustering] ( @@ -489,7 +534,8 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") override val docConcentration: Vector, @Since("1.5.0") override val topicConcentration: Double, private[spark] val iterationTimes: Array[Double], - override protected[clustering] val gammaShape: Double = 100) + override protected[clustering] val gammaShape: Double = DistributedLDAModel.defaultGammaShape, + private[spark] val checkpointFiles: Array[String] = Array.empty[String]) extends LDAModel { import LDA._ @@ -761,11 +807,9 @@ class DistributedLDAModel private[clustering] ( override protected def formatVersion = "1.0" - /** - * Java-friendly version of [[topicDistributions]] - */ @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { + // Note: This intentionally does not save checkpointFiles. DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, iterationTimes, gammaShape) @@ -777,6 +821,12 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") object DistributedLDAModel extends Loader[DistributedLDAModel] { + /** + * The [[DistributedLDAModel]] constructor's default arguments assume gammaShape = 100 + * to ensure equivalence in LDAModel.toLocal conversion. + */ + private[clustering] val defaultGammaShape: Double = 100 + private object SaveLoadV1_0 { val thisFormatVersion = "1.0" @@ -851,11 +901,11 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { Loader.checkSchema[EdgeData](edgeDataFrame.schema) val globalTopicTotals: LDA.TopicCounts = dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector - val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map { + val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.rdd.map { case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector) } - val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map { + val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.rdd.map { case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop) } val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 17c0609800e90..6418f0d3b32e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum} -import breeze.numerics.{trigamma, abs, exp} +import breeze.linalg.{all, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} +import breeze.numerics.{abs, exp, trigamma} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -81,9 +80,29 @@ final class EMLDAOptimizer extends LDAOptimizer { import LDA._ + // Adjustable parameters + private var keepLastCheckpoint: Boolean = true + + /** + * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up). + */ + @Since("2.0.0") + def getKeepLastCheckpoint: Boolean = this.keepLastCheckpoint + /** - * The following fields will only be initialized through the initialize() method + * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up). + * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with + * care. Note that checkpoints will be cleaned up via reference counting, regardless. + * + * Default: true */ + @Since("2.0.0") + def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = { + this.keepLastCheckpoint = keepLastCheckpoint + this + } + + // The following fields will only be initialized through the initialize() method private[clustering] var graph: Graph[TopicCounts, TokenCount] = null private[clustering] var k: Int = 0 private[clustering] var vocabSize: Int = 0 @@ -95,7 +114,9 @@ final class EMLDAOptimizer extends LDAOptimizer { /** * Compute bipartite term/doc graph. */ - override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { + override private[clustering] def initialize( + docs: RDD[(Long, Vector)], + lda: LDA): EMLDAOptimizer = { // EMLDAOptimizer currently only supports symmetric document-topic priors val docConcentration = lda.getDocConcentration @@ -186,7 +207,7 @@ final class EMLDAOptimizer extends LDAOptimizer { graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) .mapValues(_._2) // Update the vertex descriptors with the new counts. - val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + val newGraph = Graph(docTopicDistributions, graph.edges) graph = newGraph graphCheckpointer.update(newGraph) globalTopicTotals = computeGlobalTopicTotals() @@ -207,12 +228,18 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") - this.graphCheckpointer.deleteAllCheckpoints() + val checkpointFiles: Array[String] = if (keepLastCheckpoint) { + this.graphCheckpointer.deleteAllCheckpointsButLast() + this.graphCheckpointer.getAllCheckpointFiles + } else { + this.graphCheckpointer.deleteAllCheckpoints() + Array.empty[String] + } // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in - // LDAModel.toLocal conversion + // LDAModel.toLocal conversion. new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, - iterationTimes) + iterationTimes, DistributedLDAModel.defaultGammaShape, checkpointFiles) } } @@ -450,10 +477,11 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } Iterator((stat, gammaPart)) } - val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _) + val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))( + _ += _, _ += _) expElogbetaBc.unpersist() val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( - stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*) + stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) val batchResult = statsSum :* expElogbeta.t // Note that this is an optimization to avoid batch.count diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index a9ba7b60bad08..647d37bd822c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum} +import breeze.linalg.{max, sum, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics._ /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index b2f140e1b1352..adf20dc4b8b16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.clustering import scala.util.Random -import org.apache.spark.Logging -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} +import org.apache.spark.mllib.linalg.Vectors /** * An utility object to run K-means locally. This is private to the ML package because it's used diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 7cd9b08fa8e0e..2e257ff9b7def 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,20 +17,20 @@ package org.apache.spark.mllib.clustering -import org.json4s.JsonDSL._ import org.json4s._ +import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.{Logging, SparkContext, SparkException} /** * Model produced by [[PowerIterationClustering]]. @@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( @@ -84,7 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) @@ -94,7 +94,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode val assignments = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) - val assignmentsRDD = assignments.map { + val assignmentsRDD = assignments.rdd.map { case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) } @@ -111,7 +111,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. - * @param initMode Initialization mode. + * @param initMode Set the initialization mode. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use normalized sum similarities. + * Default: random. * * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ @@ -135,6 +137,8 @@ class PowerIterationClustering private[clustering] ( */ @Since("1.3.0") def setK(k: Int): this.type = { + require(k > 0, + s"Number of clusters must be positive but got ${k}") this.k = k this } @@ -144,6 +148,8 @@ class PowerIterationClustering private[clustering] ( */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations >= 0, + s"Maximum of iterations must be nonnegative but got ${maxIterations}") this.maxIterations = maxIterations this } @@ -262,10 +268,12 @@ object PowerIterationClustering extends Logging { }, mergeMsg = _ + _, TripletFields.EdgeOnly) - GraphImpl.fromExistingRDDs(vD, graph.edges) + Graph(vD, graph.edges) .mapTriplets( e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), - TripletFields.Src) + new TripletFields(/* useSrc */ true, + /* useDst */ false, + /* useEdge */ true)) } /** @@ -291,10 +299,12 @@ object PowerIterationClustering extends Logging { }, mergeMsg = _ + _, TripletFields.EdgeOnly) - GraphImpl.fromExistingRDDs(vD, gA.edges) + Graph(vD, gA.edges) .mapTriplets( e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), - TripletFields.Src) + new TripletFields(/* useSrc */ true, + /* useDst */ false, + /* useEdge */ true)) } /** @@ -315,7 +325,7 @@ object PowerIterationClustering extends Logging { }, preservesPartitioning = true).cache() val sum = r.values.map(math.abs).sum() val v0 = r.mapValues(x => x / sum) - GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + Graph(VertexRDD(v0), g.edges) } /** @@ -330,7 +340,7 @@ object PowerIterationClustering extends Logging { def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = { val sum = g.vertices.values.sum() val v0 = g.vertices.mapValues(_ / sum) - GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + Graph(VertexRDD(v0), g.edges) } /** @@ -355,7 +365,9 @@ object PowerIterationClustering extends Logging { val v = curG.aggregateMessages[Double]( sendMsg = ctx => ctx.sendToSrc(ctx.attr * ctx.dstAttr), mergeMsg = _ + _, - TripletFields.Dst).cache() + new TripletFields(/* useSrc */ false, + /* useDst */ true, + /* useEdge */ true)).cache() // normalize v val norm = v.values.map(math.abs).sum() logInfo(s"$msgPrefix: norm(v) = $norm.") @@ -368,7 +380,7 @@ object PowerIterationClustering extends Logging { diffDelta = math.abs(delta - prevDelta) logInfo(s"$msgPrefix: diff(delta) = $diffDelta.") // update v - curG = GraphImpl.fromExistingRDDs(VertexRDD(v1), g.edges) + curG = Graph(VertexRDD(v1), g.edges) prevDelta = delta } curG.vertices @@ -385,7 +397,6 @@ object PowerIterationClustering extends Logging { val points = v.mapValues(x => Vectors.dense(x)).cache() val model = new KMeans() .setK(k) - .setRuns(5) .setSeed(0L) .run(points.values) points.mapValues(p => model.predict(p)).cache() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 80843719f50b4..24e1cff0dcc6b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -19,12 +19,12 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag -import org.apache.spark.Logging import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaSparkContext._ +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -178,24 +178,32 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setK(k: Int): this.type = { + require(k > 0, + s"Number of clusters must be positive but got ${k}") this.k = k this } /** - * Set the decay factor directly (for forgetful algorithms). + * Set the forgetfulness of the previous centroids. */ @Since("1.2.0") def setDecayFactor(a: Double): this.type = { + require(a >= 0, + s"Decay factor must be nonnegative but got ${a}") this.decayFactor = a this } /** - * Set the half life and time unit ("batches" or "points") for forgetful algorithms. + * Set the half life and time unit ("batches" or "points"). If points, then the decay factor + * is raised to the power of number of new points and if batches, then decay factor will be + * used as is. */ @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { + require(halfLife > 0, + s"Half life must be positive but got ${halfLife}") if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) } @@ -210,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + require(centers.size == weights.size, + "Number of initial centers must be equal to number of weights") + require(centers.size == k, + s"Number of initial centers must be ${k} but got ${centers.size}") + require(weights.forall(_ >= 0), + s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]") model = new StreamingKMeansModel(centers, weights) this } @@ -223,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + require(dim > 0, + s"Number of dimensions must be positive but got ${dim}") + require(weight >= 0, + s"Weight for each center must be nonnegative but got ${weight}") val random = new XORShiftRandom(seed) val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) val weights = Array.fill(k)(weight) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 078fbfbe4f0e1..f0779491e6374 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD /** * Computes the area under the curve (AUC) using the trapezoidal rule. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 12cf22095720a..0a7a45b4f4e94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.annotation.Since -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.mllib.evaluation.binary._ import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.sql.DataFrame @@ -58,7 +58,7 @@ class BinaryClassificationMetrics @Since("1.3.0") ( * @param scoreAndLabels a DataFrame with two double columns: score and label */ private[mllib] def this(scoreAndLabels: DataFrame) = - this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) + this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Unpersist intermediate RDDs used in the computation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index c5104960cfcb6..5dde2bdb17f3a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -38,7 +38,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param predictionAndLabels a DataFrame with two double columns: prediction and label */ private[mllib] def this(predictionAndLabels: DataFrame) = - this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() private lazy val labelCount: Long = labelCountByClass.values.sum @@ -66,7 +66,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl */ @Since("1.1.0") def confusionMatrix: Matrix = { - val n = labels.size + val n = labels.length val values = Array.ofDim[Double](n * n) var i = 0 while (i < n) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index c100b3c9ec14a..77bd0aa30dda1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import org.apache.spark.sql.DataFrame /** @@ -35,7 +34,9 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] * @param predictionAndLabels a DataFrame with two double array columns: prediction and label */ private[mllib] def this(predictionAndLabels: DataFrame) = - this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray))) + this(predictionAndLabels.rdd.map { r => + (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray) + }) private lazy val numDocs: Long = predictionAndLabels.count() @@ -56,8 +57,8 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] */ @Since("1.2.0") lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) => - labels.intersect(predictions).size.toDouble / - (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs + labels.intersect(predictions).length.toDouble / + (labels.length + predictions.length - labels.intersect(predictions).length)}.sum / numDocs /** @@ -65,7 +66,7 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] */ @Since("1.2.0") lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) => - labels.size + predictions.size - 2 * labels.intersect(predictions).size + labels.length + predictions.length - 2 * labels.intersect(predictions).length }.sum / (numDocs * numLabels) /** @@ -73,8 +74,8 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] */ @Since("1.2.0") lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) => - if (predictions.size > 0) { - predictions.intersect(labels).size.toDouble / predictions.size + if (predictions.length > 0) { + predictions.intersect(labels).length.toDouble / predictions.length } else { 0 } @@ -85,7 +86,7 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] */ @Since("1.2.0") lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) => - labels.intersect(predictions).size.toDouble / labels.size + labels.intersect(predictions).length.toDouble / labels.length }.sum / numDocs /** @@ -93,7 +94,7 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] */ @Since("1.2.0") lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) => - 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size) + 2.0 * predictions.intersect(labels).length / (predictions.length + labels.length) }.sum / numDocs private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) => @@ -151,7 +152,7 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] */ @Since("1.2.0") lazy val microPrecision: Double = { - val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} + val sumFp = fpPerClass.foldLeft(0L) { case(cum, (_, fp)) => cum + fp} sumTp.toDouble / (sumTp + sumFp) } @@ -161,7 +162,7 @@ class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double] */ @Since("1.2.0") lazy val microRecall: Double = { - val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} + val sumFn = fnPerClass.foldLeft(0.0) { case(cum, (_, fn)) => cum + fn} sumTp.toDouble / (sumTp + sumFn) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index cc01936dd34b2..c45742cebbfe2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -22,9 +22,9 @@ import java.{lang => jl} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.apache.spark.Logging import org.apache.spark.annotation.Since -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD /** @@ -83,7 +83,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] /** * Returns the mean average precision (MAP) of all the queries. * If a query has an empty ground truth set, the average precision will be zero and a log - * warining is generated. + * warning is generated. */ lazy val meanAveragePrecision: Double = { predictionAndLabels.map { case (pred, lab) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 1d8f4fe340fb4..ef45c9fd9e5cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -18,20 +18,27 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs. + * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param throughOrigin True if the regression is through the origin. For example, in linear + * regression, it will be true without fitting intercept. */ @Since("1.2.0") -class RegressionMetrics @Since("1.2.0") ( - predictionAndObservations: RDD[(Double, Double)]) extends Logging { +class RegressionMetrics @Since("2.0.0") ( + predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) + extends Logging { + + @Since("1.2.0") + def this(predictionAndObservations: RDD[(Double, Double)]) = + this(predictionAndObservations, false) /** * An auxiliary constructor taking a DataFrame. @@ -39,7 +46,7 @@ class RegressionMetrics @Since("1.2.0") ( * prediction and observation */ private[mllib] def this(predictionAndObservations: DataFrame) = - this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1)))) + this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. @@ -53,6 +60,8 @@ class RegressionMetrics @Since("1.2.0") ( ) summary } + + private lazy val SSy = math.pow(summary.normL2(0), 2) private lazy val SSerr = math.pow(summary.normL2(1), 2) private lazy val SStot = summary.variance(0) * (summary.count - 1) private lazy val SSreg = { @@ -102,9 +111,16 @@ class RegressionMetrics @Since("1.2.0") ( /** * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * In case of regression through the origin, the definition of R^2^ is to be modified. + * @see J. G. Eisenhauer, Regression through the Origin. Teaching Statistics 25, 76-80 (2003) + * [[https://online.stat.psu.edu/~ajw13/stat501/SpecialTopics/Reg_thru_origin.pdf]] */ @Since("1.2.0") def r2: Double = { - 1 - SSerr / SStot + if (throughOrigin) { + 1 - SSerr / SSy + } else { + 1 - SSerr / SStot + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index d4d022afde051..4f0e13feae086 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} /** * Chi Squared selector model. @@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -150,7 +150,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { def load(sc: SparkContext, path: String): ChiSqSelectorModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) - val features = dataArray.map { + val features = dataArray.rdd.map { case Row(feature: Int) => (feature) }.collect() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index c93ed64183ad6..47c9e850a011d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -36,11 +36,23 @@ import org.apache.spark.util.Utils @Since("1.1.0") class HashingTF(val numFeatures: Int) extends Serializable { + private var binary = false + /** */ @Since("1.1.0") def this() = this(1 << 20) + /** + * If true, term frequency vector will be binary such that non-zero term counts will be set to 1 + * (default: false) + */ + @Since("2.0.0") + def setBinary(value: Boolean): this.type = { + binary = value + this + } + /** * Returns the index of the input term. */ @@ -53,9 +65,10 @@ class HashingTF(val numFeatures: Int) extends Serializable { @Since("1.1.0") def transform(document: Iterable[_]): Vector = { val termFrequencies = mutable.HashMap.empty[Int, Double] + val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0 document.foreach { term => val i = indexOf(term) - termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0) + termFrequencies.put(i, setTF(i)) } Vectors.sparse(numFeatures, termFrequencies.toSeq) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index cffa9fba05c8a..9457c6e9e35f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -88,7 +88,7 @@ private object IDF { } doc match { case SparseVector(size, indices, values) => - val nnz = indices.size + val nnz = indices.length var k = 0 while (k < nnz) { if (values(k) > 0) { @@ -97,7 +97,7 @@ private object IDF { k += 1 } case DenseVector(values) => - val n = values.size + val n = values.length var j = 0 while (j < n) { if (values(j) > 0.0) { @@ -211,7 +211,7 @@ private object IDFModel { val n = v.size v match { case SparseVector(size, indices, values) => - val nnz = indices.size + val nnz = indices.length val newValues = new Array[Double](nnz) var k = 0 while (k < nnz) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index af0c8e1d8a9d2..99fcb36f27e3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -55,7 +55,7 @@ class Normalizer @Since("1.1.0") (p: Double) extends VectorTransformer { vector match { case DenseVector(vs) => val values = vs.clone() - val size = values.size + val size = values.length var i = 0 while (i < size) { values(i) /= norm @@ -64,7 +64,7 @@ class Normalizer @Since("1.1.0") (p: Double) extends VectorTransformer { Vectors.dense(values) case SparseVector(size, ids, vs) => val values = vs.clone() - val nnz = values.size + val nnz = values.length var i = 0 while (i < nnz) { values(i) /= norm diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index ecb3c1e6c1c83..30c403e547bee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.distributed.RowMatrix @@ -30,7 +30,8 @@ import org.apache.spark.rdd.RDD */ @Since("1.4.0") class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { - require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") + require(k > 0, + s"Number of principal components must be positive but got ${k}") /** * Computes a [[PCAModel]] that contains the principal components of the input vectors. @@ -43,7 +44,8 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { s"source vector size is ${sources.first().size} must be greater than k=$k") val mat = new RowMatrix(sources) - val pc = mat.computePrincipalComponents(k) match { + val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) + val densePC = pc match { case dm: DenseMatrix => dm case sm: SparseMatrix => @@ -58,7 +60,13 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}") } - new PCAModel(k, pc) + val denseExplainedVariance = explainedVariance match { + case dv: DenseVector => + dv + case sv: SparseVector => + sv.toDense + } + new PCAModel(k, densePC, denseExplainedVariance) } /** @@ -77,7 +85,8 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { @Since("1.4.0") class PCAModel private[spark] ( @Since("1.4.0") val k: Int, - @Since("1.4.0") val pc: DenseMatrix) extends VectorTransformer { + @Since("1.4.0") val pc: DenseMatrix, + @Since("1.6.0") val explainedVariance: DenseVector) extends VectorTransformer { /** * Transform a vector by computed Principal Components. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 6fe573c528943..5c35e1b91c9bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.feature -import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -132,7 +132,7 @@ class StandardScalerModel @Since("1.3.0") ( vector match { case DenseVector(vs) => val values = vs.clone() - val size = values.size + val size = values.length if (withStd) { var i = 0 while (i < size) { @@ -153,7 +153,7 @@ class StandardScalerModel @Since("1.3.0") ( vector match { case DenseVector(vs) => val values = vs.clone() - val size = values.size + val size = values.length var i = 0 while(i < size) { values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0) @@ -164,7 +164,7 @@ class StandardScalerModel @Since("1.3.0") ( // For sparse vector, the `index` array inside sparse vector object will not be changed, // so we can re-use it to save memory. val values = vs.clone() - val nnz = values.size + val nnz = values.length var i = 0 while (i < nnz) { values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index 5778fd1d09254..ca7385128d79a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -47,7 +47,7 @@ trait VectorTransformer extends Serializable { */ @Since("1.1.0") def transform(data: RDD[Vector]): RDD[Vector] = { - // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. + // Later in #1498 , all RDD objects are sent via broadcasting instead of RPC. // So it should be no longer necessary to explicitly broadcast `this` object. data.map(x => this.transform(x)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f3e4d346e358a..5b079fce3a83d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -21,24 +21,22 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ +import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.sql.SQLContext /** * Entry in vocabulary @@ -77,12 +75,28 @@ class Word2Vec extends Serializable with Logging { private var numIterations = 1 private var seed = Utils.random.nextLong() private var minCount = 5 + private var maxSentenceLength = 1000 + + /** + * Sets the maximum length (in words) of each sentence in the input data. + * Any sentence longer than this threshold will be divided into chunks of + * up to `maxSentenceLength` size (default: 1000) + */ + @Since("2.0.0") + def setMaxSentenceLength(maxSentenceLength: Int): this.type = { + require(maxSentenceLength > 0, + s"Maximum length of sentences must be positive but got ${maxSentenceLength}") + this.maxSentenceLength = maxSentenceLength + this + } /** * Sets vector size (default: 100). */ @Since("1.1.0") def setVectorSize(vectorSize: Int): this.type = { + require(vectorSize > 0, + s"vector size must be positive but got ${vectorSize}") this.vectorSize = vectorSize this } @@ -92,6 +106,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.1.0") def setLearningRate(learningRate: Double): this.type = { + require(learningRate > 0, + s"Initial learning rate must be positive but got ${learningRate}") this.learningRate = learningRate this } @@ -101,7 +117,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.1.0") def setNumPartitions(numPartitions: Int): this.type = { - require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") + require(numPartitions > 0, + s"Number of partitions must be positive but got ${numPartitions}") this.numPartitions = numPartitions this } @@ -112,6 +129,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { + require(numIterations >= 0, + s"Number of iterations must be nonnegative but got ${numIterations}") this.numIterations = numIterations this } @@ -125,12 +144,25 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets the window of words (default: 5) + */ + @Since("1.6.0") + def setWindowSize(window: Int): this.type = { + require(window > 0, + s"Window of words must be positive but got ${window}") + this.window = window + this + } + /** * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). */ @Since("1.3.0") def setMinCount(minCount: Int): this.type = { + require(minCount >= 0, + s"Minimum number of times must be nonnegative but got ${minCount}") this.minCount = minCount this } @@ -138,26 +170,27 @@ class Word2Vec extends Serializable with Logging { private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 - private val MAX_SENTENCE_LENGTH = 1000 /** context words from [-window, window] */ - private val window = 5 + private var window = 5 - private var trainWordsCount = 0 + private var trainWordsCount = 0L private var vocabSize = 0 - private var vocab: Array[VocabWord] = null - private var vocabHash = mutable.HashMap.empty[String, Int] + @transient private var vocab: Array[VocabWord] = null + @transient private var vocabHash = mutable.HashMap.empty[String, Int] + + private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = { + val words = dataset.flatMap(x => x) - private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) + .filter(_._2 >= minCount) .map(x => VocabWord( x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) - .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) @@ -171,7 +204,7 @@ class Word2Vec extends Serializable with Logging { trainWordsCount += vocab(a).cn a += 1 } - logInfo("trainWordsCount = " + trainWordsCount) + logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount") } private def createExpTable(): Array[Float] = { @@ -264,15 +297,14 @@ class Word2Vec extends Serializable with Logging { /** * Computes the vector representation of each word in vocabulary. - * @param dataset an RDD of words + * @param dataset an RDD of sentences, + * each sentence is expressed as an iterable collection of words * @return a Word2VecModel */ @Since("1.1.0") def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { - val words = dataset.flatMap(x => x) - - learnVocab(words) + learnVocab(dataset) createBinaryTree() @@ -281,47 +313,40 @@ class Word2Vec extends Serializable with Logging { val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - - val sentences: RDD[Array[Int]] = words.mapPartitions { iter => - new Iterator[Array[Int]] { - def hasNext: Boolean = iter.hasNext - - def next(): Array[Int] = { - val sentence = ArrayBuilder.make[Int] - var sentenceLength = 0 - while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) { - val word = bcVocabHash.value.get(iter.next()) - word match { - case Some(w) => - sentence += w - sentenceLength += 1 - case None => - } - } - sentence.result() - } + // each partition is a collection of sentences, + // will be translated into arrays of Index integer + val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter => + // Each sentence will map to 0 or more Array[Int] + sentenceIter.flatMap { sentence => + // Sentence of words, some of which map to a word index + val wordIndexes = sentence.flatMap(bcVocabHash.value.get) + // break wordIndexes into trunks of maxSentenceLength when has more + wordIndexes.grouped(maxSentenceLength).map(_.toArray) } } val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + if (vocabSize.toLong * vectorSize >= Int.MaxValue) { throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + - "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.") } val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) var alpha = learningRate + for (k <- 1 to numIterations) { + val bcSyn0Global = sc.broadcast(syn0Global) + val bcSyn1Global = sc.broadcast(syn1Global) val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { + val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount @@ -333,9 +358,9 @@ class Word2Vec extends Serializable with Logging { if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } - wc += sentence.size + wc += sentence.length var pos = 0 - while (pos < sentence.size) { + while (pos < sentence.length) { val word = sentence(pos) val b = random.nextInt(window) // Train Skip-gram @@ -343,7 +368,7 @@ class Word2Vec extends Serializable with Logging { while (a < window * 2 + 1 - b) { if (a != window) { val c = pos - window + a - if (c >= 0 && c < sentence.size) { + if (c >= 0 && c < sentence.length) { val lastWord = sentence(c) val l1 = lastWord * vectorSize val neu1e = new Array[Float](vectorSize) @@ -405,6 +430,8 @@ class Word2Vec extends Serializable with Logging { } i += 1 } + bcSyn0Global.unpersist(false) + bcSyn1Global.unpersist(false) } newSentences.unpersist() @@ -432,9 +459,9 @@ class Word2Vec extends Serializable with Logging { * (i * vectorSize, i * vectorSize + vectorSize) */ @Since("1.1.0") -class Word2VecModel private[mllib] ( - private val wordIndex: Map[String, Int], - private val wordVectors: Array[Float]) extends Serializable with Saveable { +class Word2VecModel private[spark] ( + private[spark] val wordIndex: Map[String, Int], + private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable { private val numWords = wordIndex.size // vectorSize: Dimension of each word's vector. @@ -464,15 +491,6 @@ class Word2VecModel private[mllib] ( this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) } - private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { - require(v1.length == v2.length, "Vectors should have the same length") - val n = v1.length - val norm1 = blas.snrm2(n, v1, 1) - val norm2 = blas.snrm2(n, v2, 1) - if (norm1 == 0 || norm2 == 0) return 0.0 - blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2 - } - override protected def formatVersion = "1.0" @Since("1.4.0") @@ -529,16 +547,27 @@ class Word2VecModel private[mllib] ( // Need not divide with the norm of the given vector since it is constant. val cosVec = cosineVec.map(_.toDouble) var ind = 0 + val vecNorm = blas.snrm2(vectorSize, fVector, 1) while (ind < numWords) { - cosVec(ind) /= wordVecNorms(ind) + val norm = wordVecNorms(ind) + if (norm == 0.0) { + cosVec(ind) = 0.0 + } else { + cosVec(ind) /= norm + } ind += 1 } - wordList.zip(cosVec) + var topResults = wordList.zip(cosVec) .toSeq - .sortBy(- _._2) + .sortBy(-_._2) .take(num + 1) .tail - .toArray + if (vecNorm != 0.0f) { + topResults = topResults.map { case (word, cosVal) => + (word, cosVal / vecNorm) + } + } + topResults.toArray } /** @@ -550,6 +579,7 @@ class Word2VecModel private[mllib] ( (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) } } + } @Since("1.4.0") @@ -561,7 +591,7 @@ object Word2VecModel extends Loader[Word2VecModel] { private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = { require(model.nonEmpty, "Word2VecMap should be non-empty") - val (vectorSize, numWords) = (model.head._2.size, model.size) + val (vectorSize, numWords) = (model.head._2.length, model.size) val wordList = model.keys.toArray val wordVectors = new Array[Float](vectorSize * numWords) var i = 0 @@ -582,7 +612,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def load(sc: SparkContext, path: String): Word2VecModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) @@ -594,18 +624,26 @@ object Word2VecModel extends Loader[Word2VecModel] { def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ - val vectorSize = model.values.head.size + val vectorSize = model.values.head.length val numWords = model.size - val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ - ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + val metadata = compact(render( + ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + // We want to partition the model in partitions of size 32MB + val partitionSize = (1L << 25) + // We calculate the approximate size of the model + // We only calculate the array size, not considering + // the string size, the formula is: + // floatSize * numWords * vectorSize + val approxSize = 4L * numWords * vectorSize + val nPartitions = ((approxSize / partitionSize) + 1).toInt val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) + sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path)) } } @@ -620,7 +658,7 @@ object Word2VecModel extends Loader[Word2VecModel] { (loadedClassName, loadedVersion) match { case (classNameV1_0, "1.0") => val model = SaveLoadV1_0.load(sc, path) - val vectorSize = model.getVectors.values.head.size + val vectorSize = model.getVectors.values.head.length val numWords = model.getVectors.size require(expectedVectorSize == vectorSize, s"Word2VecModel requires each word to be mapped to a vector of size " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 07eb750b06a3b..9a63cc29dacb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.fpm import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.internal.Logging import org.apache.spark.mllib.fpm.AssociationRules.Rule import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.rdd.RDD @@ -50,7 +50,8 @@ class AssociationRules private[fpm] ( */ @Since("1.5.0") def setMinConfidence(minConfidence: Double): this.type = { - require(minConfidence >= 0.0 && minConfidence <= 1.0) + require(minConfidence >= 0.0 && minConfidence <= 1.0, + s"Minimal confidence must be in range [0, 1] but got ${minConfidence}") this.minConfidence = minConfidence this } @@ -58,7 +59,7 @@ class AssociationRules private[fpm] ( /** * Computes the association rules with confidence above [[minConfidence]]. * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] - * @return a [[Set[Rule[Item]]] containing the assocation rules. + * @return a [[Set[Rule[Item]]] containing the association rules. * */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 70ef1ed30c71a..4f4996f3be617 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -20,16 +20,26 @@ package org.apache.spark.mllib.fpm import java.{util => ju} import java.lang.{Iterable => JavaIterable} -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ -import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.{HashPartitioner, Partitioner, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.internal.Logging import org.apache.spark.mllib.fpm.FPGrowth._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel /** @@ -39,7 +49,8 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.3.0") class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( - @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) + extends Saveable with Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. * @param confidence minimal confidence of the rules produced @@ -49,6 +60,89 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( val associationRules = new AssociationRules(confidence) associationRules.run(freqItemsets) } + + /** + * Save this model to the given path. + * It only works for Item datatypes supported by DataFrames. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[FPGrowthModel.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + FPGrowthModel.SaveLoadV1_0.save(this, path) + } + + override protected val formatVersion: String = "1.0" +} + +@Since("2.0.0") +object FPGrowthModel extends Loader[FPGrowthModel[_]] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): FPGrowthModel[_] = { + FPGrowthModel.SaveLoadV1_0.load(sc, path) + } + + private[fpm] object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private val thisClassName = "org.apache.spark.mllib.fpm.FPGrowthModel" + + def save(model: FPGrowthModel[_], path: String): Unit = { + val sc = model.freqItemsets.sparkContext + val sqlContext = SQLContext.getOrCreate(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Get the type of item class + val sample = model.freqItemsets.first().items(0) + val className = sample.getClass.getCanonicalName + val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className) + val tpe = classSymbol.selfType + + val itemType = ScalaReflection.schemaFor(tpe).dataType + val fields = Array(StructField("items", ArrayType(itemType)), + StructField("freq", LongType)) + val schema = StructType(fields) + val rowDataRDD = model.freqItemsets.map { x => + Row(x.items, x.freq) + } + sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): FPGrowthModel[_] = { + implicit val formats = DefaultFormats + val sqlContext = SQLContext.getOrCreate(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path)) + val sample = freqItemsets.select("items").head().get(0) + loadImpl(freqItemsets, sample) + } + + def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { + val freqItemsetsRDD = freqItemsets.select("items", "freq").rdd.map { x => + val items = x.getAs[Seq[Item]](0).toArray + val freq = x.getLong(1) + new FreqItemset(items, freq) + } + new FPGrowthModel(freqItemsetsRDD) + } + } } /** @@ -59,7 +153,7 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( * [[http://dx.doi.org/10.1145/335191.335372 Han et al., Mining frequent patterns without candidate * generation]]. * - * @param minSupport the minimal support level of the frequent pattern, any pattern appears + * @param minSupport the minimal support level of the frequent pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output * @param numPartitions number of partitions used by parallel FP-growth * @@ -86,6 +180,8 @@ class FPGrowth private ( */ @Since("1.3.0") def setMinSupport(minSupport: Double): this.type = { + require(minSupport >= 0.0 && minSupport <= 1.0, + s"Minimal support level must be in range [0, 1] but got ${minSupport}") this.minSupport = minSupport this } @@ -96,6 +192,8 @@ class FPGrowth private ( */ @Since("1.3.0") def setNumPartitions(numPartitions: Int): this.type = { + require(numPartitions > 0, + s"Number of partitions must be positive but got ${numPartitions}") this.numPartitions = numPartitions this } @@ -139,7 +237,7 @@ class FPGrowth private ( partitioner: Partitioner): Array[Item] = { data.flatMap { t => val uniq = t.toSet - if (t.size != uniq.size) { + if (t.length != uniq.size) { throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.") } t diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 3ea10779a1837..659f875a6dc98 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.fpm import scala.collection.mutable -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * Calculate all patterns of a projected database in local mode. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 97916daa2e9ad..4344ab1bade9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -20,15 +20,25 @@ package org.apache.spark.mllib.fpm import java.{lang => jl, util => ju} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ + +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel /** @@ -38,9 +48,9 @@ import org.apache.spark.storage.StorageLevel * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]). * - * @param minSupport the minimal support level of the sequential pattern, any pattern appears - * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern appears + * @param minSupport the minimal support level of the sequential pattern, any pattern that appears + * more than (minSupport * size-of-the-dataset) times will be output + * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears * less than maxPatternLength will be output * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local @@ -541,7 +551,7 @@ object PrefixSpan extends Logging { } /** - * Represents a frequence sequence. + * Represents a frequent sequence. * @param sequence a sequence of itemsets stored as an Array of Arrays * @param freq frequency * @tparam Item item type @@ -566,4 +576,88 @@ object PrefixSpan extends Logging { @Since("1.5.0") class PrefixSpanModel[Item] @Since("1.5.0") ( @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) - extends Serializable + extends Saveable with Serializable { + + /** + * Save this model to the given path. + * It only works for Item datatypes supported by DataFrames. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[PrefixSpanModel.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + PrefixSpanModel.SaveLoadV1_0.save(this, path) + } + + override protected val formatVersion: String = "1.0" +} + +@Since("2.0.0") +object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { + PrefixSpanModel.SaveLoadV1_0.load(sc, path) + } + + private[fpm] object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel" + + def save(model: PrefixSpanModel[_], path: String): Unit = { + val sc = model.freqSequences.sparkContext + val sqlContext = SQLContext.getOrCreate(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Get the type of item class + val sample = model.freqSequences.first().sequence(0)(0) + val className = sample.getClass.getCanonicalName + val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className) + val tpe = classSymbol.selfType + + val itemType = ScalaReflection.schemaFor(tpe).dataType + val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))), + StructField("freq", LongType)) + val schema = StructType(fields) + val rowDataRDD = model.freqSequences.map { x => + Row(x.sequence, x.freq) + } + sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { + implicit val formats = DefaultFormats + val sqlContext = SQLContext.getOrCreate(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val freqSequences = sqlContext.read.parquet(Loader.dataPath(path)) + val sample = freqSequences.select("sequence").head().get(0) + loadImpl(freqSequences, sample) + } + + def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = { + val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x => + val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray + val freq = x.getLong(1) + new PrefixSpan.FreqSequence(sequence, freq) + } + new PrefixSpanModel(freqSequencesRDD) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index 72d3aabc9b1f4..5c12c9305b99c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -19,9 +19,10 @@ package org.apache.spark.mllib.impl import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel @@ -51,7 +52,8 @@ import org.apache.spark.storage.StorageLevel * - This class removes checkpoint files once later Datasets have been checkpointed. * However, references to the older Datasets will still return isCheckpointed = true. * - * @param checkpointInterval Datasets will be checkpointed at this interval + * @param checkpointInterval Datasets will be checkpointed at this interval. + * If this interval was set as -1, then checkpointing will be disabled. * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ @@ -88,7 +90,8 @@ private[mllib] abstract class PeriodicCheckpointer[T]( updateCount += 1 // Handle checkpointing (after persisting) - if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0 + && sc.getCheckpointDir.nonEmpty) { // Add new checkpoint before removing old checkpoints. checkpoint(newData) checkpointQueue.enqueue(newData) @@ -132,6 +135,24 @@ private[mllib] abstract class PeriodicCheckpointer[T]( } } + /** + * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint. + * Note that there may not be any checkpoints at all. + */ + def deleteAllCheckpointsButLast(): Unit = { + while (checkpointQueue.size > 1) { + removeCheckpointFile() + } + } + + /** + * Get all current checkpoint files. + * This is useful in combination with [[deleteAllCheckpointsButLast()]]. + */ + def getAllCheckpointFiles: Array[String] = { + checkpointQueue.flatMap(getCheckpointFiles).toArray + } + /** * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. * This prints a warning but does not fail if the files cannot be removed. @@ -140,15 +161,20 @@ private[mllib] abstract class PeriodicCheckpointer[T]( val old = checkpointQueue.dequeue() // Since the old checkpoint is not deleted by Spark, we manually delete it. val fs = FileSystem.get(sc.hadoopConfiguration) - getCheckpointFiles(old).foreach { checkpointFile => - try { - fs.delete(new Path(checkpointFile), true) - } catch { - case e: Exception => - logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + - checkpointFile) - } - } + getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs)) } +} + +private[spark] object PeriodicCheckpointer extends Logging { + /** Delete a checkpoint file, and log a warning if deletion fails. */ + def removeCheckpointFile(path: String, fs: FileSystem): Unit = { + try { + fs.delete(new Path(path), true) + } catch { + case e: Exception => + logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + + path) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 11a059536c50c..20db6084d0e0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -69,7 +69,8 @@ import org.apache.spark.storage.StorageLevel * // checkpointed: graph4 * }}} * - * @param checkpointInterval Graphs will be checkpointed at this interval + * @param checkpointInterval Graphs will be checkpointed at this interval. + * If this interval was set as -1, then checkpointing will be disabled. * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala index f31ed2aa90a64..145dc22b7428e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala @@ -74,7 +74,7 @@ import org.apache.spark.storage.StorageLevel * * TODO: Move this out of MLlib? */ -private[mllib] class PeriodicRDDCheckpointer[T]( +private[spark] class PeriodicRDDCheckpointer[T]( checkpointInterval: Int, sc: SparkContext) extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index df9f4ae145b88..19cc942aba133 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.linalg import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * BLAS routines for MLlib's vectors and matrices. @@ -75,7 +75,7 @@ private[spark] object BLAS extends Serializable with Logging { val xValues = x.values val xIndices = x.indices val yValues = y.values - val nnz = xIndices.size + val nnz = xIndices.length if (a == 1.0) { var k = 0 @@ -135,7 +135,7 @@ private[spark] object BLAS extends Serializable with Logging { val xValues = x.values val xIndices = x.indices val yValues = y.values - val nnz = xIndices.size + val nnz = xIndices.length var sum = 0.0 var k = 0 @@ -154,8 +154,8 @@ private[spark] object BLAS extends Serializable with Logging { val xIndices = x.indices val yValues = y.values val yIndices = y.indices - val nnzx = xIndices.size - val nnzy = yIndices.size + val nnzx = xIndices.length + val nnzy = yIndices.length var kx = 0 var ky = 0 @@ -188,7 +188,7 @@ private[spark] object BLAS extends Serializable with Logging { val sxIndices = sx.indices val sxValues = sx.values val dyValues = dy.values - val nnz = sxIndices.size + val nnz = sxIndices.length var i = 0 var k = 0 @@ -420,7 +420,7 @@ private[spark] object BLAS extends Serializable with Logging { val AcolPtrs = A.colPtrs // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (A.isTransposed){ + if (A.isTransposed) { var colCounterForB = 0 if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala index 0cd371e9cce34..e4494792bb390 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -33,11 +33,11 @@ private[spark] object CholeskyDecomposition { * @return the solution array */ def solve(A: Array[Double], bx: Array[Double]): Array[Double] = { - val k = bx.size + val k = bx.length val info = new intW(0) lapack.dppsv("U", k, 1, A, bx, k, info) val code = info.`val` - assert(code == 0, s"lapack.dpotrs returned $code.") + assert(code == 0, s"lapack.dppsv returned $code.") bx } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 863abe86d38d7..bb94745f078e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import com.github.fommil.netlib.ARPACK -import org.netlib.util.{intW, doubleW} +import org.netlib.util.{doubleW, intW} /** * Compute eigen-decomposition. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 8879dcf75c9bf..8c09b69b3c751 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -19,13 +19,14 @@ package org.apache.spark.mllib.linalg import java.util.{Arrays, Random} -import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHashSet, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet} import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} +import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -58,6 +59,20 @@ sealed trait Matrix extends Serializable { newArray } + /** + * Returns an iterator of column vectors. + * This operation could be expensive, depending on the underlying storage. + */ + @Since("2.0.0") + def colIter: Iterator[Vector] + + /** + * Returns an iterator of row vectors. + * This operation could be expensive, depending on the underlying storage. + */ + @Since("2.0.0") + def rowIter: Iterator[Vector] = this.transpose.colIter + /** Converts to a breeze matrix. */ private[mllib] def toBreeze: BM[Double] @@ -108,14 +123,18 @@ sealed trait Matrix extends Serializable { @Since("1.4.0") def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) - /** Map the values of this matrix using a function. Generates a new matrix. Performs the - * function on only the backing array. For example, an operation such as addition or - * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ + /** + * Map the values of this matrix using a function. Generates a new matrix. Performs the + * function on only the backing array. For example, an operation such as addition or + * subtraction will only be performed on the non-zero values in a `SparseMatrix`. + */ private[spark] def map(f: Double => Double): Matrix - /** Update all the values of this matrix using the function f. Performed in-place on the - * backing array. For example, an operation such as addition or subtraction will only be - * performed on the non-zero values in a `SparseMatrix`. */ + /** + * Update all the values of this matrix using the function f. Performed in-place on the + * backing array. For example, an operation such as addition or subtraction will only be + * performed on the non-zero values in a `SparseMatrix`. + */ private[mllib] def update(f: Double => Double): Matrix /** @@ -141,7 +160,6 @@ sealed trait Matrix extends Serializable { def numActives: Int } -@DeveloperApi private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def sqlType: StructType = { @@ -162,7 +180,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { )) } - override def serialize(obj: Any): InternalRow = { + override def serialize(obj: Matrix): InternalRow = { val row = new GenericMutableRow(7) obj match { case sm: SparseMatrix => @@ -279,7 +297,7 @@ class DenseMatrix @Since("1.3.0") ( } override def hashCode: Int = { - com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray) + com.google.common.base.Objects.hashCode(numRows: Integer, numCols: Integer, toArray) } private[mllib] def toBreeze: BM[Double] = { @@ -386,6 +404,21 @@ class DenseMatrix @Since("1.3.0") ( } new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) } + + @Since("2.0.0") + override def colIter: Iterator[Vector] = { + if (isTransposed) { + Iterator.tabulate(numCols) { j => + val col = new Array[Double](numRows) + blas.dcopy(numRows, values, j, numCols, col, 0, 1) + new DenseVector(col) + } + } else { + Iterator.tabulate(numCols) { j => + new DenseVector(values.slice(j * numRows, (j + 1) * numRows)) + } + } + } } /** @@ -584,7 +617,7 @@ class SparseMatrix @Since("1.3.0") ( private[mllib] def update(i: Int, j: Int, v: Double): Unit = { val ind = index(i, j) - if (ind == -1) { + if (ind < 0) { throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. Only non-zero elements in Sparse Matrices can be updated.") } else { @@ -656,6 +689,38 @@ class SparseMatrix @Since("1.3.0") ( @Since("1.5.0") override def numActives: Int = values.length + @Since("2.0.0") + override def colIter: Iterator[Vector] = { + if (isTransposed) { + val indicesArray = Array.fill(numCols)(MArrayBuilder.make[Int]) + val valuesArray = Array.fill(numCols)(MArrayBuilder.make[Double]) + var i = 0 + while (i < numRows) { + var k = colPtrs(i) + val rowEnd = colPtrs(i + 1) + while (k < rowEnd) { + val j = rowIndices(k) + indicesArray(j) += i + valuesArray(j) += values(k) + k += 1 + } + i += 1 + } + Iterator.tabulate(numCols) { j => + val ii = indicesArray(j).result() + val vv = valuesArray(j).result() + new SparseVector(numRows, ii, vv) + } + } else { + Iterator.tabulate(numCols) { j => + val colStart = colPtrs(j) + val colEnd = colPtrs(j + 1) + val ii = rowIndices.slice(colStart, colEnd) + val vv = values.slice(colStart, colEnd) + new SparseVector(numRows, ii, vv) + } + } + } } /** @@ -879,8 +944,16 @@ object Matrices { case dm: BDM[Double] => new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => + // Spark-11507. work around breeze issue 479. + val mat = if (sm.colPtrs.last != sm.data.length) { + val matCopy = sm.copy + matCopy.compact() + matCopy + } else { + sm + } // There is no isTranspose flag for sparse matrices in Breeze - new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) + new SparseMatrix(mat.rows, mat.cols, mat.colPtrs, mat.rowIndices, mat.data) case _ => throw new UnsupportedOperationException( s"Do not support conversion from type ${breeze.getClass.getName}.") @@ -987,7 +1060,7 @@ object Matrices { def horzcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) - } else if (matrices.size == 1) { + } else if (matrices.length == 1) { return matrices(0) } val numRows = matrices(0).numRows @@ -1046,7 +1119,7 @@ object Matrices { def vertcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) - } else if (matrices.size == 1) { + } else if (matrices.length == 1) { return matrices(0) } val numCols = matrices(0).numCols diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index bd9badc03c345..5812cdde2c427 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,13 +17,16 @@ package org.apache.spark.mllib.linalg -import java.util import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} +import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render} import org.apache.spark.SparkException import org.apache.spark.annotation.{AlphaComponent, Since} @@ -171,13 +174,19 @@ sealed trait Vector extends Serializable { */ @Since("1.5.0") def argmax: Int + + /** + * Converts the vector to a JSON string. + */ + @Since("1.6.0") + def toJson: String } /** * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.DataFrame]]. + * via [[org.apache.spark.sql.Dataset]]. */ @AlphaComponent class VectorUDT extends UserDefinedType[Vector] { @@ -194,7 +203,7 @@ class VectorUDT extends UserDefinedType[Vector] { StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) } - override def serialize(obj: Any): InternalRow = { + override def serialize(obj: Vector): InternalRow = { obj match { case SparseVector(size, indices, values) => val row = new GenericMutableRow(4) @@ -339,6 +348,27 @@ object Vectors { parseNumeric(NumericParser.parse(s)) } + /** + * Parses the JSON representation of a vector into a [[Vector]]. + */ + @Since("1.6.0") + def fromJson(json: String): Vector = { + implicit val formats = DefaultFormats + val jValue = parseJson(json) + (jValue \ "type").extract[Int] match { + case 0 => // sparse + val size = (jValue \ "size").extract[Int] + val indices = (jValue \ "indices").extract[Seq[Int]].toArray + val values = (jValue \ "values").extract[Seq[Double]].toArray + sparse(size, indices, values) + case 1 => // dense + val values = (jValue \ "values").extract[Seq[Double]].toArray + dense(values) + case _ => + throw new IllegalArgumentException(s"Cannot parse $json into a vector.") + } + } + private[mllib] def parseNumeric(any: Any): Vector = { any match { case values: Array[Double] => @@ -650,6 +680,12 @@ class DenseVector @Since("1.0.0") ( maxIdx } } + + @Since("1.6.0") + override def toJson: String = { + val jValue = ("type" -> 1) ~ ("values" -> values.toSeq) + compact(render(jValue)) + } } @Since("1.3.0") @@ -837,6 +873,15 @@ class SparseVector @Since("1.0.0") ( }.unzip new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray) } + + @Since("1.6.0") + override def toJson: String = { + val jValue = ("type" -> 0) ~ + ("size" -> size) ~ + ("indices" -> indices.toSeq) ~ + ("values" -> values.toSeq) + compact(render(jValue)) + } } @Since("1.3.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 09527dcf5d9e5..580d7a98fb362 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -19,11 +19,12 @@ package org.apache.spark.mllib.linalg.distributed import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM, SparseVector => BSV, Vector => BV} -import org.apache.spark.{Logging, Partitioner, SparkException} +import org.apache.spark.{Partitioner, SparkException} import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -176,7 +177,7 @@ class BlockMatrix @Since("1.3.0") ( val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt private[mllib] def createPartitioner(): GridPartitioner = - GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size) + GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.length) private lazy val blockInfo = blocks.mapValues(block => (block.numRows, block.numCols)).cache() @@ -263,13 +264,35 @@ class BlockMatrix @Since("1.3.0") ( new CoordinateMatrix(entryRDD, numRows(), numCols()) } + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ @Since("1.3.0") def toIndexedRowMatrix(): IndexedRowMatrix = { - require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + - s"numCols: ${numCols()}") - // TODO: This implementation may be optimized - toCoordinateMatrix().toIndexedRowMatrix() + val cols = numCols().toInt + + require(cols < Int.MaxValue, s"The number of columns should be less than Int.MaxValue ($cols).") + + val rows = blocks.flatMap { case ((blockRowIdx, blockColIdx), mat) => + mat.rowIter.zipWithIndex.map { + case (vector, rowIdx) => + blockRowIdx * rowsPerBlock + rowIdx -> (blockColIdx, vector.toBreeze) + } + }.groupByKey().map { case (rowIdx, vectors) => + val numberNonZeroPerRow = vectors.map(_._2.activeSize).sum.toDouble / cols.toDouble + + val wholeVector = if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz + BSV.zeros[Double](cols) + } else { + BDV.zeros[Double](cols) + } + + vectors.foreach { case (blockColIdx: Int, vec: BV[Double]) => + val offset = colsPerBlock * blockColIdx + wholeVector(offset until offset + colsPerBlock) := vec + } + new IndexedRow(rowIdx, Vectors.fromBreeze(wholeVector)) + } + new IndexedRowMatrix(rows) } /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ @@ -317,40 +340,72 @@ class BlockMatrix @Since("1.3.0") ( } /** - * Adds two block matrices together. The matrices must have the same size and matching - * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are - * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even - * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will - * also be a [[DenseMatrix]]. + * For given matrices `this` and `other` of compatible dimensions and compatible block dimensions, + * it applies a binary function on their corresponding blocks. + * + * @param other The second BlockMatrix argument for the operator specified by `binMap` + * @param binMap A function taking two breeze matrices and returning a breeze matrix + * @return A [[BlockMatrix]] whose blocks are the results of a specified binary map on blocks + * of `this` and `other`. + * Note: `blockMap` ONLY works for `add` and `subtract` methods and it does not support + * operators such as (a, b) => -a + b + * TODO: Make the use of zero matrices more storage efficient. */ - @Since("1.3.0") - def add(other: BlockMatrix): BlockMatrix = { + private[mllib] def blockMap( + other: BlockMatrix, + binMap: (BM[Double], BM[Double]) => BM[Double]): BlockMatrix = { require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") require(numCols() == other.numCols(), "Both matrices must have the same number of columns. " + s"A.numCols: ${numCols()}, B.numCols: ${other.numCols()}") if (rowsPerBlock == other.rowsPerBlock && colsPerBlock == other.colsPerBlock) { - val addedBlocks = blocks.cogroup(other.blocks, createPartitioner()) + val newBlocks = blocks.cogroup(other.blocks, createPartitioner()) .map { case ((blockRowIndex, blockColIndex), (a, b)) => if (a.size > 1 || b.size > 1) { throw new SparkException("There are multiple MatrixBlocks with indices: " + s"($blockRowIndex, $blockColIndex). Please remove them.") } if (a.isEmpty) { - new MatrixBlock((blockRowIndex, blockColIndex), b.head) + val zeroBlock = BM.zeros[Double](b.head.numRows, b.head.numCols) + val result = binMap(zeroBlock, b.head.toBreeze) + new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) } else if (b.isEmpty) { new MatrixBlock((blockRowIndex, blockColIndex), a.head) } else { - val result = a.head.toBreeze + b.head.toBreeze + val result = binMap(a.head.toBreeze, b.head.toBreeze) new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) } } - new BlockMatrix(addedBlocks, rowsPerBlock, colsPerBlock, numRows(), numCols()) + new BlockMatrix(newBlocks, rowsPerBlock, colsPerBlock, numRows(), numCols()) } else { - throw new SparkException("Cannot add matrices with different block dimensions") + throw new SparkException("Cannot perform on matrices with different block dimensions") } } + /** + * Adds the given block matrix `other` to `this` block matrix: `this + other`. + * The matrices must have the same size and matching `rowsPerBlock` and `colsPerBlock` + * values. If one of the blocks that are being added are instances of [[SparseMatrix]], + * the resulting sub matrix will also be a [[SparseMatrix]], even if it is being added + * to a [[DenseMatrix]]. If two dense matrices are added, the output will also be a + * [[DenseMatrix]]. + */ + @Since("1.3.0") + def add(other: BlockMatrix): BlockMatrix = + blockMap(other, (x: BM[Double], y: BM[Double]) => x + y) + + /** + * Subtracts the given block matrix `other` from `this` block matrix: `this - other`. + * The matrices must have the same size and matching `rowsPerBlock` and `colsPerBlock` + * values. If one of the blocks that are being subtracted are instances of [[SparseMatrix]], + * the resulting sub matrix will also be a [[SparseMatrix]], even if it is being subtracted + * from a [[DenseMatrix]]. If two dense matrices are subtracted, the output will also be a + * [[DenseMatrix]]. + */ + @Since("2.0.0") + def subtract(other: BlockMatrix): BlockMatrix = + blockMap(other, (x: BM[Double], y: BM[Double]) => x - y) + /** Block (i,j) --> Set of destination partitions */ private type BlockDestinations = Map[(Int, Int), Set[Int]] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 8a70f34e70f6a..97b03b340f20e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} +import org.apache.spark.rdd.RDD /** * Represents an entry in an distributed matrix. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 976299124cedd..06b9c4ac67bb0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.rdd.RDD /** * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. @@ -120,9 +120,9 @@ class IndexedRowMatrix @Since("1.0.0") ( val rowIndex = row.index row.vector match { case SparseVector(size, indices, values) => - Iterator.tabulate(indices.size)(i => MatrixEntry(rowIndex, indices(i), values(i))) + Iterator.tabulate(indices.length)(i => MatrixEntry(rowIndex, indices(i), values(i))) case DenseVector(values) => - Iterator.tabulate(values.size)(i => MatrixEntry(rowIndex, i, values(i))) + Iterator.tabulate(values.length)(i => MatrixEntry(rowIndex, i, values(i))) } } new CoordinateMatrix(entries, numRows(), numCols()) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 52c0f19c645d9..f6183a5eaadc0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -21,17 +21,17 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd, MatrixSingularException, inv} +import breeze.linalg.{axpy => brzAxpy, inv, svd => brzSvd, DenseMatrix => BDM, DenseVector => BDV, + MatrixSingularException, SparseVector => BSV} import breeze.numerics.{sqrt => brzSqrt} -import org.apache.spark.Logging import org.apache.spark.annotation.Since +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD -import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.XORShiftRandom /** * Represents a row-oriented distributed Matrix with no meaningful row indices. @@ -368,7 +368,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the top k principal components. + * Computes the top k principal components and a vector of proportions of + * variance explained by each principal component. * Rows correspond to observations and columns correspond to variables. * The principal components are stored a local matrix of size n-by-k. * Each column corresponds for one principal component, @@ -379,24 +380,42 @@ class RowMatrix @Since("1.0.0") ( * Note that this cannot be computed on matrices with more than 65535 columns. * * @param k number of top principal components. - * @return a matrix of size n-by-k, whose columns are principal components + * @return a matrix of size n-by-k, whose columns are principal components, and + * a vector of values which indicate how much variance each principal component + * explains */ - @Since("1.0.0") - def computePrincipalComponents(k: Int): Matrix = { + @Since("1.6.0") + def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = { val n = numCols().toInt require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]") val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]] - val brzSvd.SVD(u: BDM[Double], _, _) = brzSvd(Cov) + val brzSvd.SVD(u: BDM[Double], s: BDV[Double], _) = brzSvd(Cov) + + val eigenSum = s.data.sum + val explainedVariance = s.data.map(_ / eigenSum) if (k == n) { - Matrices.dense(n, k, u.data) + (Matrices.dense(n, k, u.data), Vectors.dense(explainedVariance)) } else { - Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)) + (Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)), + Vectors.dense(Arrays.copyOfRange(explainedVariance, 0, k))) } } + /** + * Computes the top k principal components only. + * + * @param k number of top principal components. + * @return a matrix of size n-by-k, whose columns are principal components + * @see computePrincipalComponentsAndExplainedVariance + */ + @Since("1.0.0") + def computePrincipalComponents(k: Int): Matrix = { + computePrincipalComponentsAndExplainedVariance(k)._1 + } + /** * Computes column-wise summary statistics. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 3b663b5defb03..a67ea836e5681 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -19,12 +19,12 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, norm} +import breeze.linalg.{norm, DenseVector => BDV} -import org.apache.spark.annotation.{Experimental, DeveloperApi} -import org.apache.spark.Logging +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Vectors, Vector} /** @@ -46,6 +46,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * In subsequent steps, the step size will decrease with stepSize/sqrt(t) */ def setStepSize(step: Double): this.type = { + require(step > 0, + s"Initial step size must be positive but got ${step}") this.stepSize = step this } @@ -57,6 +59,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va */ @Experimental def setMiniBatchFraction(fraction: Double): this.type = { + require(fraction > 0 && fraction <= 1.0, + s"Fraction for mini-batch SGD must be in range (0, 1] but got ${fraction}") this.miniBatchFraction = fraction this } @@ -65,6 +69,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the number of iterations for SGD. Default 100. */ def setNumIterations(iters: Int): this.type = { + require(iters >= 0, + s"Number of iterations must be nonnegative but got ${iters}") this.numIterations = iters this } @@ -73,6 +79,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the regularization parameter. Default 0.0. */ def setRegParam(regParam: Double): this.type = { + require(regParam >= 0, + s"Regularization parameter must be nonnegative but got ${regParam}") this.regParam = regParam this } @@ -81,15 +89,18 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the convergence tolerance. Default 0.001 * convergenceTol is a condition which decides iteration termination. * The end of iteration is decided based on below logic. - * - If the norm of the new solution vector is >1, the diff of solution vectors - * is compared to relative tolerance which means normalizing by the norm of - * the new solution vector. - * - If the norm of the new solution vector is <=1, the diff of solution vectors - * is compared to absolute tolerance which is not normalizing. + * + * - If the norm of the new solution vector is >1, the diff of solution vectors + * is compared to relative tolerance which means normalizing by the norm of + * the new solution vector. + * - If the norm of the new solution vector is <=1, the diff of solution vectors + * is compared to absolute tolerance which is not normalizing. + * * Must be between 0.0 and 1.0 inclusively. */ def setConvergenceTol(tolerance: Double): this.type = { - require(0.0 <= tolerance && tolerance <= 1.0) + require(tolerance >= 0.0 && tolerance <= 1.0, + s"Convergence tolerance must be in range [0, 1] but got ${tolerance}") this.convergenceTol = tolerance this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index efedc112d380e..74e2cad76c8f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -18,13 +18,12 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} -import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.rdd.RDD @@ -41,7 +40,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var numCorrections = 10 - private var convergenceTol = 1E-4 + private var convergenceTol = 1E-6 private var maxNumIterations = 100 private var regParam = 0.0 @@ -53,47 +52,66 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Restriction: numCorrections > 0 */ def setNumCorrections(corrections: Int): this.type = { - assert(corrections > 0) + require(corrections > 0, + s"Number of corrections must be positive but got ${corrections}") this.numCorrections = corrections this } /** - * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. + * Set the convergence tolerance of iterations for L-BFGS. Default 1E-6. * Smaller value will lead to higher accuracy with the cost of more iterations. * This value must be nonnegative. Lower convergence values are less tolerant * and therefore generally cause more iterations to be run. */ def setConvergenceTol(tolerance: Double): this.type = { + require(tolerance >= 0, + s"Convergence tolerance must be nonnegative but got ${tolerance}") this.convergenceTol = tolerance this } - /** - * Set the maximal number of iterations for L-BFGS. Default 100. - * @deprecated use [[LBFGS#setNumIterations]] instead + /* + * Get the convergence tolerance of iterations. */ - @deprecated("use setNumIterations instead", "1.1.0") - def setMaxNumIterations(iters: Int): this.type = { - this.setNumIterations(iters) + private[mllib] def getConvergenceTol(): Double = { + this.convergenceTol } /** * Set the maximal number of iterations for L-BFGS. Default 100. */ def setNumIterations(iters: Int): this.type = { + require(iters >= 0, + s"Maximum of iterations must be nonnegative but got ${iters}") this.maxNumIterations = iters this } + /** + * Get the maximum number of iterations for L-BFGS. Defaults to 100. + */ + private[mllib] def getNumIterations(): Int = { + this.maxNumIterations + } + /** * Set the regularization parameter. Default 0.0. */ def setRegParam(regParam: Double): this.type = { + require(regParam >= 0, + s"Regularization parameter must be nonnegative but got ${regParam}") this.regParam = regParam this } + /** + * Get the regularization parameter. + */ + private[mllib] def getRegParam(): Double = { + this.regParam + } + /** * Set the gradient function (of the loss function of one single data example) * to be used for L-BFGS. @@ -113,6 +131,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Returns the updater, limited to internal use. + */ + private[mllib] def getUpdater(): Updater = { + updater + } + override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { val (weights, _) = LBFGS.runLBFGS( data, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index 7f6d94571b5ef..d8e56720967d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.optimization -import org.apache.spark.rdd.RDD - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 9f463e0cafb6f..03c01e0553d78 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.optimization import scala.math._ -import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV} +import breeze.linalg.{axpy => brzAxpy, norm => brzNorm, Vector => BV} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 622b53a252ac5..a8c32f72bfdeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel */ private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, - description : String, - normalizationMethod : RegressionNormalizationMethodType, + model: GeneralizedLinearModel, + description: String, + normalizationMethod: RegressionNormalizationMethodType, threshold: Double) extends PMMLModelExport { @@ -45,7 +45,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( val fields = new SArray[FieldName](model.weights.size) val dataDictionary = new DataDictionary val miningSchema = new MiningSchema - val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + val regressionTableYES = new RegressionTable(model.intercept).setTargetCategory("1") var interceptNO = threshold if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { if (threshold <= 0) { @@ -56,35 +56,35 @@ private[mllib] class BinaryClassificationPMMLModelExport( interceptNO = -math.log(1 / threshold - 1) } } - val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") + val regressionTableNO = new RegressionTable(interceptNO).setTargetCategory("0") val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.CLASSIFICATION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withNormalizationMethod(normalizationMethod) - .withRegressionTables(regressionTableYES, regressionTableNO) + .setFunctionName(MiningFunctionType.CLASSIFICATION) + .setMiningSchema(miningSchema) + .setModelName(description) + .setNormalizationMethod(normalizationMethod) + .addRegressionTables(regressionTableYES, regressionTableNO) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // add target field val targetField = FieldName.create("target") dataDictionary - .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + .addDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index 1874786af0002..4d951d2973a6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -45,31 +45,31 @@ private[mllib] class GeneralizedLinearPMMLModelExport( val miningSchema = new MiningSchema val regressionTable = new RegressionTable(model.intercept) val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.REGRESSION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withRegressionTables(regressionTable) + .setFunctionName(MiningFunctionType.REGRESSION) + .setMiningSchema(miningSchema) + .setModelName(description) + .addRegressionTables(regressionTable) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTable.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // for completeness add target field val targetField = FieldName.create("target") - dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index 069e7afc9fca0..255c6140e5410 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -26,14 +26,14 @@ import org.apache.spark.mllib.clustering.KMeansModel /** * PMML Model Export for KMeansModel class */ -private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ +private[mllib] class KMeansPMMLModelExport(model: KMeansModel) extends PMMLModelExport{ populateKMeansPMML(model) /** * Export the input KMeansModel model to PMML format. */ - private def populateKMeansPMML(model : KMeansModel): Unit = { + private def populateKMeansPMML(model: KMeansModel): Unit = { pmml.getHeader.setDescription("k-means clustering") if (model.clusterCenters.length > 0) { @@ -42,42 +42,42 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode val dataDictionary = new DataDictionary val miningSchema = new MiningSchema val comparisonMeasure = new ComparisonMeasure() - .withKind(ComparisonMeasure.Kind.DISTANCE) - .withMeasure(new SquaredEuclidean()) + .setKind(ComparisonMeasure.Kind.DISTANCE) + .setMeasure(new SquaredEuclidean()) val clusteringModel = new ClusteringModel() - .withModelName("k-means") - .withMiningSchema(miningSchema) - .withComparisonMeasure(comparisonMeasure) - .withFunctionName(MiningFunctionType.CLUSTERING) - .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) - .withNumberOfClusters(model.clusterCenters.length) + .setModelName("k-means") + .setMiningSchema(miningSchema) + .setComparisonMeasure(comparisonMeasure) + .setFunctionName(MiningFunctionType.CLUSTERING) + .setModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .setNumberOfClusters(model.clusterCenters.length) for (i <- 0 until clusterCenter.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - clusteringModel.withClusteringFields( - new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + clusteringModel.addClusteringFields( + new ClusteringField(fields(i)).setCompareFunction(CompareFunctionType.ABS_DIFF)) } - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) - for (i <- 0 until model.clusterCenters.length) { + for (i <- model.clusterCenters.indices) { val cluster = new Cluster() - .withName("cluster_" + i) - .withArray(new org.dmg.pmml.Array() - .withType(Array.Type.REAL) - .withN(clusterCenter.size) - .withValue(model.clusterCenters(i).toArray.mkString(" "))) + .setName("cluster_" + i) + .setArray(new org.dmg.pmml.Array() + .setType(Array.Type.REAL) + .setN(clusterCenter.size) + .setValue(model.clusterCenters(i).toArray.mkString(" "))) // we don't have the size of the single cluster but only the centroids (withValue) // .withSize(value) - clusteringModel.withClusters(cluster) + clusteringModel.addClusters(cluster) } pmml.setDataDictionary(dataDictionary) - pmml.withModels(clusteringModel) + pmml.addModels(clusteringModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index c5fdecd3ca17f..426bb818c9266 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -30,18 +30,14 @@ private[mllib] trait PMMLModelExport { * Holder of the exported model in PMML format */ @BeanProperty - val pmml: PMML = new PMML - - setHeader(pmml) - - private def setHeader(pmml: PMML): Unit = { + val pmml: PMML = { val version = getClass.getPackage.getImplementationVersion - val app = new Application().withName("Apache Spark MLlib").withVersion(version) + val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) val header = new Header() - .withApplication(app) - .withTimestamp(timestamp) - pmml.setHeader(header) + .setApplication(app) + .setTimestamp(timestamp) + new PMML("4.2", header, null) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9eab7efc160da..fa04f8eb5e796 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.random import org.apache.commons.math3.distribution._ -import org.apache.spark.annotation.{Since, DeveloperApi} -import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.util.random.{Pseudorandom, XORShiftRandom} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 78172843be56e..e8a937ffcb96f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -37,39 +37,20 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. */ - def sliding(windowSize: Int): RDD[Array[T]] = { + def sliding(windowSize: Int, step: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") - if (windowSize == 1) { + if (windowSize == 1 && step == 1) { self.map(Array(_)) } else { - new SlidingRDD[T](self, windowSize) + new SlidingRDD[T](self, windowSize, step) } } /** - * Reduces the elements of this RDD in a multi-level tree pattern. - * - * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#treeReduce]] - * @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead. + * [[sliding(Int, Int)*]] with step = 1. */ - @deprecated("Use RDD.treeReduce instead.", "1.3.0") - def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth) + def sliding(windowSize: Int): RDD[Array[T]] = sliding(windowSize, 1) - /** - * Aggregates the elements of this RDD in a multi-level tree pattern. - * - * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#treeAggregate]] - * @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead. - */ - @deprecated("Use RDD.treeAggregate instead.", "1.3.0") - def treeAggregate[U: ClassTag](zeroValue: U)( - seqOp: (U, T) => U, - combOp: (U, U) => U, - depth: Int = 2): U = { - self.treeAggregate(zeroValue)(seqOp, combOp, depth) - } } @DeveloperApi diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index f8cea7ecea6bf..92bc66949ae80 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -17,15 +17,15 @@ package org.apache.spark.mllib.rdd +import scala.reflect.ClassTag +import scala.util.Random + import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.random.RandomDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import scala.reflect.ClassTag -import scala.util.Random - private[mllib] class RandomRDDPartition[T](override val index: Int, val size: Int, val generator: RandomDataGenerator[T], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index 1facf83d806d0..adb5e51947f6d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -20,17 +20,17 @@ package org.apache.spark.mllib.rdd import scala.collection.mutable import scala.reflect.ClassTag -import org.apache.spark.{TaskContext, Partition} +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD private[mllib] -class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]) +class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int) extends Partition with Serializable { override val index: Int = idx } /** - * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Represents an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. It needs to @@ -40,19 +40,24 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] * * @param parent the parent RDD * @param windowSize the window size, must be greater than 1 + * @param step step size for windows * - * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]] + * @see [[org.apache.spark.mllib.rdd.RDDFunctions.sliding(Int, Int)*]] + * @see [[scala.collection.IterableLike.sliding(Int, Int)*]] */ private[mllib] -class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) +class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int) extends RDD[Array[T]](parent) { - require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") + require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1), + "Window size and step must be greater than 0, " + + s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.") override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) - .sliding(windowSize) + .drop(part.offset) + .sliding(windowSize, step) .withPartial(false) .map(_.toArray) } @@ -62,40 +67,42 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int override def getPartitions: Array[Partition] = { val parentPartitions = parent.partitions - val n = parentPartitions.size + val n = parentPartitions.length if (n == 0) { Array.empty } else if (n == 1) { - Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty)) + Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0)) } else { - val n1 = n - 1 val w1 = windowSize - 1 - // Get the first w1 items of each partition, starting from the second partition. - val nextHeads = - parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n) - val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() + // Get partition sizes and first w1 elements. + val (sizes, heads) = parent.mapPartitions { iter => + val w1Array = iter.take(w1).toArray + Iterator.single((w1Array.length + iter.length, w1Array)) + }.collect().unzip + val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]] var i = 0 + var cumSize = 0 var partitionIndex = 0 - while (i < n1) { - var j = i - val tail = mutable.ListBuffer[T]() - // Keep appending to the current tail until appended a head of size w1. - while (j < n1 && nextHeads(j).size < w1) { - tail ++= nextHeads(j) - j += 1 + while (i < n) { + val mod = cumSize % step + val offset = if (mod == 0) 0 else step - mod + val size = sizes(i) + if (offset < size) { + val tail = mutable.ListBuffer.empty[T] + // Keep appending to the current tail until it has w1 elements. + var j = i + 1 + while (j < n && tail.length < w1) { + tail ++= heads(j).take(w1 - tail.length) + j += 1 + } + if (sizes(i) + tail.length >= offset + windowSize) { + partitions += + new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset) + partitionIndex += 1 + } } - if (j < n1) { - tail ++= nextHeads(j) - j += 1 - } - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail) - partitionIndex += 1 - // Skip appended heads. - i = j - } - // If the head of last partition has size w1, we also need to add this partition. - if (nextHeads.last.size == w1) { - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty) + cumSize += size + i += 1 } partitions.toArray } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 33aaf853e599d..467cb83cd1662 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -17,9 +17,9 @@ package org.apache.spark.mllib.recommendation -import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -97,6 +97,8 @@ class ALS private ( */ @Since("0.8.0") def setBlocks(numBlocks: Int): this.type = { + require(numBlocks == -1 || numBlocks > 0, + s"Number of blocks must be -1 or positive but got ${numBlocks}") this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks this @@ -107,6 +109,8 @@ class ALS private ( */ @Since("1.1.0") def setUserBlocks(numUserBlocks: Int): this.type = { + require(numUserBlocks == -1 || numUserBlocks > 0, + s"Number of blocks must be -1 or positive but got ${numUserBlocks}") this.numUserBlocks = numUserBlocks this } @@ -116,6 +120,8 @@ class ALS private ( */ @Since("1.1.0") def setProductBlocks(numProductBlocks: Int): this.type = { + require(numProductBlocks == -1 || numProductBlocks > 0, + s"Number of product blocks must be -1 or positive but got ${numProductBlocks}") this.numProductBlocks = numProductBlocks this } @@ -123,6 +129,8 @@ class ALS private ( /** Set the rank of the feature matrices computed (number of features). Default: 10. */ @Since("0.8.0") def setRank(rank: Int): this.type = { + require(rank > 0, + s"Rank of the feature matrices must be positive but got ${rank}") this.rank = rank this } @@ -130,6 +138,8 @@ class ALS private ( /** Set the number of iterations to run. Default: 10. */ @Since("0.8.0") def setIterations(iterations: Int): this.type = { + require(iterations >= 0, + s"Number of iterations must be nonnegative but got ${iterations}") this.iterations = iterations this } @@ -137,6 +147,8 @@ class ALS private ( /** Set the regularization parameter, lambda. Default: 0.01. */ @Since("0.8.0") def setLambda(lambda: Double): this.type = { + require(lambda >= 0.0, + s"Regularization parameter must be nonnegative but got ${lambda}") this.lambda = lambda this } @@ -218,7 +230,7 @@ class ALS private ( } /** - * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. + * Run ALS with the configured parameters on an input RDD of [[Rating]] objects. * Returns a MatrixFactorizationModel with feature vectors for each user and product. */ @Since("0.8.0") @@ -226,12 +238,12 @@ class ALS private ( val sc = ratings.context val numUserBlocks = if (this.numUserBlocks == -1) { - math.max(sc.defaultParallelism, ratings.partitions.size / 2) + math.max(sc.defaultParallelism, ratings.partitions.length / 2) } else { this.numUserBlocks } val numProductBlocks = if (this.numProductBlocks == -1) { - math.max(sc.defaultParallelism, ratings.partitions.size / 2) + math.max(sc.defaultParallelism, ratings.partitions.length / 2) } else { this.numProductBlocks } @@ -279,18 +291,17 @@ class ALS private ( @Since("0.8.0") object ALS { /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. This is done using a level of - * parallelism given by `blocks`. + * Train a matrix factorization model given an RDD of ratings by users for a subset of products. + * The ratings matrix is approximated as the product of two lower-rank matrices of a given rank + * (number of features). To solve for these features, ALS is run iteratively with a configurable + * level of parallelism. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param blocks level of parallelism to split computation into - * @param seed random seed + * @param seed random seed for initial matrix factorization model */ @Since("0.9.1") def train( @@ -305,16 +316,15 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. This is done using a level of - * parallelism given by `blocks`. + * Train a matrix factorization model given an RDD of ratings by users for a subset of products. + * The ratings matrix is approximated as the product of two lower-rank matrices of a given rank + * (number of features). To solve for these features, ALS is run iteratively with a configurable + * level of parallelism. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param blocks level of parallelism to split computation into */ @Since("0.8.0") @@ -329,16 +339,15 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. The level of parallelism is determined - * automatically based on the number of partitions in `ratings`. + * Train a matrix factorization model given an RDD of ratings by users for a subset of products. + * The ratings matrix is approximated as the product of two lower-rank matrices of a given rank + * (number of features). To solve for these features, ALS is run iteratively with a level of + * parallelism automatically based on the number of partitions in `ratings`. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter */ @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) @@ -347,15 +356,14 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of ratings given by users to some products, - * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - * product of two lower-rank matrices of a given rank (number of features). To solve for these - * features, we run a given number of iterations of ALS. The level of parallelism is determined - * automatically based on the number of partitions in `ratings`. + * Train a matrix factorization model given an RDD of ratings by users for a subset of products. + * The ratings matrix is approximated as the product of two lower-rank matrices of a given rank + * (number of features). To solve for these features, ALS is run iteratively with a level of + * parallelism automatically based on the number of partitions in `ratings`. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) + * @param iterations number of iterations of ALS */ @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int) @@ -372,11 +380,11 @@ object ALS { * * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param blocks level of parallelism to split computation into * @param alpha confidence parameter - * @param seed random seed + * @param seed random seed for initial matrix factorization model */ @Since("0.8.1") def trainImplicit( @@ -392,16 +400,15 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of 'implicit preferences' given by users - * to some products, in the form of (userID, productID, preference) pairs. We approximate the - * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). - * To solve for these features, we run a given number of iterations of ALS. This is done using - * a level of parallelism given by `blocks`. + * Train a matrix factorization model given an RDD of 'implicit preferences' of users for a + * subset of products. The ratings matrix is approximated as the product of two lower-rank + * matrices of a given rank (number of features). To solve for these features, ALS is run + * iteratively with a configurable level of parallelism. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param blocks level of parallelism to split computation into * @param alpha confidence parameter */ @@ -418,16 +425,16 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of 'implicit preferences' given by users to - * some products, in the form of (userID, productID, preference) pairs. We approximate the - * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). - * To solve for these features, we run a given number of iterations of ALS. The level of - * parallelism is determined automatically based on the number of partitions in `ratings`. + * Train a matrix factorization model given an RDD of 'implicit preferences' of users for a + * subset of products. The ratings matrix is approximated as the product of two lower-rank + * matrices of a given rank (number of features). To solve for these features, ALS is run + * iteratively with a level of parallelism determined automatically based on the number of + * partitions in `ratings`. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) - * @param lambda regularization factor (recommended: 0.01) + * @param iterations number of iterations of ALS + * @param lambda regularization parameter * @param alpha confidence parameter */ @Since("0.8.1") @@ -437,16 +444,15 @@ object ALS { } /** - * Train a matrix factorization model given an RDD of 'implicit preferences' ratings given by - * users to some products, in the form of (userID, productID, rating) pairs. We approximate the - * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). - * To solve for these features, we run a given number of iterations of ALS. The level of - * parallelism is determined automatically based on the number of partitions in `ratings`. - * Model parameters `alpha` and `lambda` are set to reasonable default values + * Train a matrix factorization model given an RDD of 'implicit preferences' of users for a + * subset of products. The ratings matrix is approximated as the product of two lower-rank + * matrices of a given rank (number of features). To solve for these features, ALS is run + * iteratively with a level of parallelism determined automatically based on the number of + * partitions in `ratings`. * - * @param ratings RDD of (userID, productID, rating) pairs + * @param ratings RDD of [[Rating]] objects with userID, productID, and rating * @param rank number of features to use - * @param iterations number of iterations of ALS (recommended: 10-20) + * @param iterations number of iterations of ALS */ @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 46562eb2ad0f7..6f780b0da71f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -29,9 +29,10 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ import org.apache.spark.mllib.util.{Loader, Saveable} @@ -206,7 +207,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( } /** - * Recommends topK products for all users. + * Recommends top products for all users. * * @param num how many products to return for every user. * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of @@ -224,7 +225,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( /** - * Recommends topK users for all products. + * Recommends top users for all products. * * @param num how many users to return for every product. * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array @@ -353,7 +354,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) @@ -364,18 +365,18 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val rank = (metadata \ "rank").extract[Int] - val userFeatures = sqlContext.read.parquet(userPath(path)) - .map { case Row(id: Int, features: Seq[_]) => + val userFeatures = sqlContext.read.parquet(userPath(path)).rdd.map { + case Row(id: Int, features: Seq[_]) => + (id, features.asInstanceOf[Seq[Double]].toArray) + } + val productFeatures = sqlContext.read.parquet(productPath(path)).rdd.map { + case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) - } - val productFeatures = sqlContext.read.parquet(productPath(path)) - .map { case Row(id: Int, features: Seq[_]) => - (id, features.asInstanceOf[Seq[Double]].toArray) } new MatrixFactorizationModel(rank, userFeatures, productFeatures) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 8f657bfb9c730..4d5aaf14111c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.regression +import org.apache.spark.SparkException import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.internal.Logging import org.apache.spark.mllib.feature.StandardScaler -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.optimization._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLUtils._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** @@ -140,7 +141,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * translated back to resulting model weights, so it's transparent to users. * Note: This technique is used in both libsvm and glmnet packages. Default false. */ - private var useFeatureScaling = false + private[mllib] var useFeatureScaling = false /** * The dimension of training features. @@ -196,12 +197,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } /** - * Run the algorithm with the configured parameters on an input - * RDD of LabeledPoint entries. - * + * Generate the initial weights when the user does not supply them */ - @Since("0.8.0") - def run(input: RDD[LabeledPoint]): M = { + protected def generateInitialWeights(input: RDD[LabeledPoint]): Vector = { if (numFeatures < 0) { numFeatures = input.map(_.features.size).first() } @@ -217,16 +215,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always * have the intercept as part of weights to have consistent design. */ - val initialWeights = { - if (numOfLinearPredictor == 1) { - Vectors.zeros(numFeatures) - } else if (addIntercept) { - Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) - } else { - Vectors.zeros(numFeatures * numOfLinearPredictor) - } + if (numOfLinearPredictor == 1) { + Vectors.zeros(numFeatures) + } else if (addIntercept) { + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) + } else { + Vectors.zeros(numFeatures * numOfLinearPredictor) } - run(input, initialWeights) + } + + /** + * Run the algorithm with the configured parameters on an input + * RDD of LabeledPoint entries. + * + */ + @Since("0.8.0") + def run(input: RDD[LabeledPoint]): M = { + run(input, generateInitialWeights(input)) } /** @@ -346,7 +351,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] val partialWeightsArray = scaler.transform( Vectors.dense(weightsArray.slice(start, end))).toArray - System.arraycopy(partialWeightsArray, 0, weightsArray, start, partialWeightsArray.size) + System.arraycopy(partialWeightsArray, 0, weightsArray, start, partialWeightsArray.length) i += 1 } weights = Vectors.dense(weightsArray) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index ec78ea24539b5..abdd7981970fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -136,7 +136,7 @@ class IsotonicRegressionModel @Since("1.3.0") ( // higher than all values, in between two values or exact match. if (insertIndex == 0) { predictions.head - } else if (insertIndex == boundaries.length){ + } else if (insertIndex == boundaries.length) { predictions.last } else if (foundIndex < 0) { linearInterpolation( @@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index c284ad2325374..45540f0c5c4ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index a9aba173fa0e3..d55e5dfdaaf53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 4996ace5df85d..e754e74492755 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD /** @@ -89,6 +89,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] { class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, + private var regParam: Double, private var miniBatchFraction: Double) extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { @@ -98,6 +99,7 @@ class LinearRegressionWithSGD private[mllib] ( override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) + .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) /** @@ -105,7 +107,7 @@ class LinearRegressionWithSGD private[mllib] ( * numIterations: 100, miniBatchFraction: 1.0}. */ @Since("0.8.0") - def this() = this(1.0, 100, 1.0) + def this() = this(1.0, 100, 0.0, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { new LinearRegressionModel(weights, intercept) @@ -141,7 +143,7 @@ object LinearRegressionWithSGD { stepSize: Double, miniBatchFraction: Double, initialWeights: Vector): LinearRegressionModel = { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) .run(input, initialWeights) } @@ -163,7 +165,7 @@ object LinearRegressionWithSGD { numIterations: Int, stepSize: Double, miniBatchFraction: Double): LinearRegressionModel = { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input) + new LinearRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(input) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index 73948b2d9851a..46deb545af3f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.regression import scala.reflect.ClassTag -import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vector import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index fe2a46b9eecc7..84764963b5f36 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -43,6 +43,7 @@ import org.apache.spark.mllib.linalg.Vector class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, + private var regParam: Double, private var miniBatchFraction: Double) extends StreamingLinearAlgorithm[LinearRegressionModel, LinearRegressionWithSGD] with Serializable { @@ -54,10 +55,10 @@ class StreamingLinearRegressionWithSGD private[mllib] ( * (see `StreamingLinearAlgorithm`) */ @Since("1.1.0") - def this() = this(0.1, 50, 1.0) + def this() = this(0.1, 50, 0.0, 1.0) @Since("1.1.0") - val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction) protected var model: Option[LinearRegressionModel] = None @@ -70,6 +71,15 @@ class StreamingLinearRegressionWithSGD private[mllib] ( this } + /** + * Set the regularization parameter. Default: 0.0. + */ + @Since("2.0.0") + def setRegParam(regParam: Double): this.type = { + this.algorithm.optimizer.setRegParam(regParam) + this + } + /** * Set the number of iterations of gradient descent to run per update. Default: 50. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 317d3a5702636..a6e1767fe236a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -71,10 +71,10 @@ private[regression] object GLMRegressionModel { */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) - assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") + assert(dataArray.length == 1, s"Unable to load $modelClass data from: $datapath") val data = dataArray(0) assert(data.size == 2, s"Unable to load $modelClass data from: $datapath") data match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 201333c3690df..98404be2603c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index bcb33a7a04677..f3159f7e724cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs import org.apache.spark.annotation.Since -import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} -import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala index 8a821d1b23bab..f131f6948ab1e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat.correlation import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 4a6c677f06d28..b760347bcb6fb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.stat.correlation import scala.collection.mutable.ArrayBuffer -import org.apache.spark.Logging -import org.apache.spark.SparkContext._ +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 0724af93088c2..6c6e9fb7c6b3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.stat.distribution -import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} +import breeze.linalg.{diag, eigSym, max, DenseMatrix => DBM, DenseVector => DBV, Vector => BV} import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLUtils /** @@ -61,15 +61,17 @@ class MultivariateGaussian @Since("1.3.0") ( */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - /** Returns density of this multivariate Gaussian at given point, x - */ + /** + * Returns density of this multivariate Gaussian at given point, x + */ @Since("1.3.0") def pdf(x: Vector): Double = { pdf(x.toBreeze) } - /** Returns the log-density of this multivariate Gaussian at given point, x - */ + /** + * Returns the log-density of this multivariate Gaussian at given point, x + */ @Since("1.3.0") def logpdf(x: Vector): Double = { logpdf(x.toBreeze) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 23c8d7c7c8075..76ca6a8abd032 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.stat.test +import scala.collection.mutable + import breeze.linalg.{DenseMatrix => BDM} import org.apache.commons.math3.distribution.ChiSquaredDistribution -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import scala.collection.mutable - /** * Conduct the chi-squared test for the input RDDs using the specified method. * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted @@ -109,7 +110,9 @@ private[stat] object ChiSqTest extends Logging { } i += 1 distinctLabels += label - features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) => + val brzFeatures = features.toBreeze + (startCol until endCol).map { col => + val feature = brzFeatures(col) allDistinctFeatures(col) += feature (col, feature, label) } @@ -122,7 +125,7 @@ private[stat] object ChiSqTest extends Logging { pairCounts.keys.filter(_._1 == startCol).map(_._3).toArray.distinct.zipWithIndex.toMap } val numLabels = labels.size - pairCounts.keys.groupBy(_._1).map { case (col, keys) => + pairCounts.keys.groupBy(_._1).foreach { case (col, keys) => val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap val numRows = features.size val contingency = new BDM(numRows, numLabels, new Array[Double](numRows * numLabels)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala index 2b3ed6df486c9..9748fbf2c97b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -22,7 +22,7 @@ import scala.annotation.varargs import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution} import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => CommonMathKolmogorovSmirnovTest} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD /** @@ -166,7 +166,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging { : KolmogorovSmirnovTestResult = { val distObj = distName match { - case "norm" => { + case "norm" => if (params.nonEmpty) { // parameters are passed, then can only be 2 require(params.length == 2, "Normal distribution requires mean and standard " + @@ -178,7 +178,6 @@ private[stat] object KolmogorovSmirnovTest extends Logging { "initialized to standard normal (i.e. N(0, 1))") new NormalDistribution(0, 1) } - } case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" + s" convenience method. Current options are:['norm'].") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 75c6a51d09571..4c382d7c2b791 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -17,12 +17,30 @@ package org.apache.spark.mllib.stat.test -import org.apache.spark.Logging +import scala.beans.BeanInfo + import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.rdd.RDD +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.api.java.JavaDStream import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter +/** + * Class that represents the group and value of a sample. + * + * @param isExperiment if the sample is of the experiment group. + * @param value numeric value of the observation. + */ +@Since("1.6.0") +@BeanInfo +case class BinarySample @Since("1.6.0") ( + @Since("1.6.0") isExperiment: Boolean, + @Since("1.6.0") value: Double) { + override def toString: String = { + s"($isExperiment, $value)" + } +} + /** * :: Experimental :: * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The @@ -83,13 +101,13 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { /** * Register a [[DStream]] of values for significance testing. * - * @param data stream of (key,value) pairs where the key denotes group membership (true = - * experiment, false = control) and the value is the numerical metric to test for - * significance + * @param data stream of BinarySample(key,value) pairs where the key denotes group membership + * (true = experiment, false = control) and the value is the numerical metric to + * test for significance * @return stream of significance testing results */ @Since("1.6.0") - def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + def registerStream(data: DStream[BinarySample]): DStream[StreamingTestResult] = { val dataAfterPeacePeriod = dropPeacePeriod(data) val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) val pairedSummaries = pairSummaries(summarizedData) @@ -97,9 +115,22 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { testMethod.doTest(pairedSummaries) } + /** + * Register a [[JavaDStream]] of values for significance testing. + * + * @param data stream of BinarySample(isExperiment,value) pairs where the isExperiment denotes + * group (true = experiment, false = control) and the value is the numerical metric + * to test for significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: JavaDStream[BinarySample]): JavaDStream[StreamingTestResult] = { + JavaDStream.fromDStream(registerStream(data.dstream)) + } + /** Drop all batches inside the peace period. */ private[stat] def dropPeacePeriod( - data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data: DStream[BinarySample]): DStream[BinarySample] = { data.transform { (rdd, time) => if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { rdd @@ -111,9 +142,10 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { /** Compute summary statistics over each key and the specified test window size. */ private[stat] def summarizeByKeyAndWindow( - data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + data: DStream[BinarySample]): DStream[(Boolean, StatCounter)] = { + val categoryValuePair = data.map(sample => (sample.isExperiment, sample.value)) if (this.windowSize == 0) { - data.updateStateByKey[StatCounter]( + categoryValuePair.updateStateByKey[StatCounter]( (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { val newSummary = oldSummary.getOrElse(new StatCounter()) newSummary.merge(newValues) @@ -121,7 +153,7 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { }) } else { val windowDuration = data.slideDuration * this.windowSize - data + categoryValuePair .groupByKeyAndWindow(windowDuration) .mapValues { values => val summary = new StatCounter() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala index a7eaed51b4d55..ff27f28459e26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -26,7 +26,7 @@ import com.twitter.chill.MeatLocker import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues import org.apache.commons.math3.stat.inference.TTest -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter @@ -152,8 +152,8 @@ private[stat] object StudentTTest extends StreamingTestMethod with Logging { private[stat] object StreamingTestMethod { // Note: after new `StreamingTestMethod`s are implemented, please update this map. private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( - "welch"->WelchTTest, - "student"->StudentTTest) + "welch" -> WelchTTest, + "student" -> StudentTTest) def getTestMethodFromName(method: String): StreamingTestMethod = TEST_NAME_TO_OBJECT.get(method) match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index af1f7e74c004d..21810a3b11aa6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -18,45 +18,52 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ -import scala.collection.mutable -import org.apache.spark.Logging import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo -import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl._ +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.random.XORShiftRandom + /** * A class which implements a decision tree learning algorithm for classification and regression. * It supports both continuous and categorical features. + * * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, + * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. + * @param seed Random seed. */ @Since("1.0.0") -class DecisionTree @Since("1.0.0") (private val strategy: Strategy) +class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int) extends Serializable with Logging { + /** + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of decision tree (classification or regression), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + */ + @Since("1.0.0") + def this(strategy: Strategy) = this(strategy, seed = 0) + strategy.assertValid() /** * Method to train a decision tree model over an RDD - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return DecisionTreeModel that can be used for prediction + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Note: random seed will not be used since numTrees = 1. - val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) + val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = seed) val rfModel = rf.run(input) rfModel.trees(0) } @@ -77,9 +84,9 @@ object DecisionTree extends Serializable with Logging { * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algorithm (classification, regression, etc.), feature type (continuous, + * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @return DecisionTreeModel that can be used for prediction + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { @@ -97,11 +104,11 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. + * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.0.0") def train( @@ -124,12 +131,12 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClasses number of classes for classification. Default value of 2. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. + * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @param numClasses Number of classes for classification. Default value of 2. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.2.0") def train( @@ -153,17 +160,17 @@ object DecisionTree extends Serializable with Logging { * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClasses number of classes for classification. Default value of 2. - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - * @return DecisionTreeModel that can be used for prediction + * @param algo Type of decision tree, either classification or regression. + * @param impurity Criterion used for information gain calculation. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @param numClasses Number of classes for classification. Default value of 2. + * @param maxBins Maximum number of bins used for splitting features. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.0.0") def train( @@ -185,18 +192,18 @@ object DecisionTree extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses number of classes for classification. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numClasses Number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 5) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 32) - * @return DecisionTreeModel that can be used for prediction + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 5) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 32) + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.1.0") def trainClassifier( @@ -232,17 +239,17 @@ object DecisionTree extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels are real numbers. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param impurity Criterion used for information gain calculation. - * Supported values: "variance". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 5) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 32) - * @return DecisionTreeModel that can be used for prediction + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 5) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 32) + * @return DecisionTreeModel that can be used for prediction. */ @Since("1.1.0") def trainRegressor( @@ -269,909 +276,4 @@ object DecisionTree extends Serializable with Logging { categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, impurity, maxDepth, maxBins) } - - /** - * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a leaf - * or unsplit node; that node's index is returned. - * - * @param node Node in tree from which to classify the given data point. - * @param binnedFeatures Binned feature vector for data point. - * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param unorderedFeatures Set of indices of unordered features. - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. - * Note: This is the global node index, i.e., the index used in the tree. - * This index is different from the index used during training a particular - * group of nodes on one call to [[findBestSplits()]]. - */ - private def predictNodeIndex( - node: Node, - binnedFeatures: Array[Int], - bins: Array[Array[Bin]], - unorderedFeatures: Set[Int]): Int = { - if (node.isLeaf || node.split.isEmpty) { - // Node is either leaf, or has not yet been split. - node.id - } else { - val featureIndex = node.split.get.feature - val splitLeft = node.split.get.featureType match { - case Continuous => { - val binIndex = binnedFeatures(featureIndex) - val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold - // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] - // We do not need to check lowSplit since bins are separated by splits. - featureValueUpperBound <= node.split.get.threshold - } - case Categorical => { - val featureValue = binnedFeatures(featureIndex) - node.split.get.categories.contains(featureValue) - } - case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") - } - if (node.leftNode.isEmpty || node.rightNode.isEmpty) { - // Return index from next layer of nodes to train - if (splitLeft) { - Node.leftChildIndex(node.id) - } else { - Node.rightChildIndex(node.id) - } - } else { - if (splitLeft) { - predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures) - } else { - predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures) - } - } - } - } - - /** - * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. - * - * For ordered features, a single bin is updated. - * For unordered features, bins correspond to subsets of categories; either the left or right bin - * for each subset is updated. - * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param splits possible splits indexed (numFeatures)(numSplits) - * @param unorderedFeatures Set of indices of unordered features. - * @param instanceWeight Weight (importance) of instance in dataset. - */ - private def mixedBinSeqOp( - agg: DTStatsAggregator, - treePoint: TreePoint, - splits: Array[Array[Split]], - unorderedFeatures: Set[Int], - instanceWeight: Double, - featuresForNode: Option[Array[Int]]): Unit = { - val numFeaturesPerNode = if (featuresForNode.nonEmpty) { - // Use subsampled features - featuresForNode.get.size - } else { - // Use all features - agg.metadata.numFeatures - } - // Iterate over features. - var featureIndexIdx = 0 - while (featureIndexIdx < numFeaturesPerNode) { - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } - if (unorderedFeatures.contains(featureIndex)) { - // Unordered feature - val featureValue = treePoint.binnedFeatures(featureIndex) - val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightFeatureOffsets(featureIndexIdx) - // Update the left or right bin for each split. - val numSplits = agg.metadata.numSplits(featureIndex) - var splitIndex = 0 - while (splitIndex < numSplits) { - if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) { - agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, - instanceWeight) - } else { - agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, - instanceWeight) - } - splitIndex += 1 - } - } else { - // Ordered feature - val binIndex = treePoint.binnedFeatures(featureIndex) - agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) - } - featureIndexIdx += 1 - } - } - - /** - * Helper for binSeqOp, for regression and for classification with only ordered features. - * - * For each feature, the sufficient statistics of one bin are updated. - * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param instanceWeight Weight (importance) of instance in dataset. - */ - private def orderedBinSeqOp( - agg: DTStatsAggregator, - treePoint: TreePoint, - instanceWeight: Double, - featuresForNode: Option[Array[Int]]): Unit = { - val label = treePoint.label - - // Iterate over features. - if (featuresForNode.nonEmpty) { - // Use subsampled features - var featureIndexIdx = 0 - while (featureIndexIdx < featuresForNode.get.size) { - val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) - agg.update(featureIndexIdx, binIndex, label, instanceWeight) - featureIndexIdx += 1 - } - } else { - // Use all features - val numFeatures = agg.metadata.numFeatures - var featureIndex = 0 - while (featureIndex < numFeatures) { - val binIndex = treePoint.binnedFeatures(featureIndex) - agg.update(featureIndex, binIndex, label, instanceWeight) - featureIndex += 1 - } - } - } - - /** - * Given a group of nodes, this finds the best split for each node. - * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param metadata Learning and dataset metadata - * @param topNodes Root node for each tree. Used for matching instances with nodes. - * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree - * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, - * where nodeIndexInfo stores the index in the group and the - * feature subsets (if using feature subsets). - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). - * Updated with new non-leaf nodes which are created. - * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where - * each value in the array is the data point's node Id - * for a corresponding tree. This is used to prevent the need - * to pass the entire tree to the executors during - * the node stat aggregation phase. - */ - private[tree] def findBestSplits( - input: RDD[BaggedPoint[TreePoint]], - metadata: DecisionTreeMetadata, - topNodes: Array[Node], - nodesForGroup: Map[Int, Array[Node]], - treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], - splits: Array[Array[Split]], - bins: Array[Array[Bin]], - nodeQueue: mutable.Queue[(Int, Node)], - timer: TimeTracker = new TimeTracker, - nodeIdCache: Option[NodeIdCache] = None): Unit = { - - /* - * The high-level descriptions of the best split optimizations are noted here. - * - * *Group-wise training* - * We perform bin calculations for groups of nodes to reduce the number of - * passes over the data. Each iteration requires more computation and storage, - * but saves several iterations over the data. - * - * *Bin-wise computation* - * We use a bin-wise best split computation strategy instead of a straightforward best split - * computation strategy. Instead of analyzing each sample for contribution to the left/right - * child node impurity of every split, we first categorize each feature of a sample into a - * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates - * to calculate information gain for each split. - * - * *Aggregation over partitions* - * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know - * the number of splits in advance. Thus, we store the aggregates (at the appropriate - * indices) in a single array for all bins and rely upon the RDD aggregate method to - * drastically reduce the communication overhead. - */ - - // numNodes: Number of nodes in this group - val numNodes = nodesForGroup.values.map(_.size).sum - logDebug("numNodes = " + numNodes) - logDebug("numFeatures = " + metadata.numFeatures) - logDebug("numClasses = " + metadata.numClasses) - logDebug("isMulticlass = " + metadata.isMulticlass) - logDebug("isMulticlassWithCategoricalFeatures = " + - metadata.isMulticlassWithCategoricalFeatures) - logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) - - /** - * Performs a sequential aggregation over a partition for a particular tree and node. - * - * For each feature, the aggregate sufficient statistics are updated for the relevant - * bins. - * - * @param treeIndex Index of the tree that we want to perform aggregation for. - * @param nodeInfo The node info for the tree node. - * @param agg Array storing aggregate calculation, with a set of sufficient statistics - * for each (node, feature, bin). - * @param baggedPoint Data point being aggregated. - */ - def nodeBinSeqOp( - treeIndex: Int, - nodeInfo: RandomForest.NodeIndexInfo, - agg: Array[DTStatsAggregator], - baggedPoint: BaggedPoint[TreePoint]): Unit = { - if (nodeInfo != null) { - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val featuresForNode = nodeInfo.featureSubset - val instanceWeight = baggedPoint.subsampleWeights(treeIndex) - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) - } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, - metadata.unorderedFeatures, instanceWeight, featuresForNode) - } - } - } - - /** - * Performs a sequential aggregation over a partition. - * - * Each data point contributes to one node. For each feature, - * the aggregate sufficient statistics are updated for the relevant bins. - * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). - * @param baggedPoint Data point being aggregated. - * @return agg - */ - def binSeqOp( - agg: Array[DTStatsAggregator], - baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, - bins, metadata.unorderedFeatures) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) - } - - agg - } - - /** - * Do the same thing as binSeqOp, but with nodeIdCache. - */ - def binSeqOpWithNodeIdCache( - agg: Array[DTStatsAggregator], - dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val baggedPoint = dataPoint._1 - val nodeIdCache = dataPoint._2 - val nodeIndex = nodeIdCache(treeIndex) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) - } - - agg - } - - /** - * Get node index in group --> features indices map, - * which is a short cut to find feature indices for a node given node index in group - * @param treeToNodeToIndexInfo - * @return - */ - def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) - : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) { - None - } else { - val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]() - treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => - nodeIdToNodeInfo.values.foreach { nodeIndexInfo => - assert(nodeIndexInfo.featureSubset.isDefined) - mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get - } - } - Some(mutableNodeToFeatures.toMap) - } - - // array of nodes to train indexed by node index in group - val nodes = new Array[Node](numNodes) - nodesForGroup.foreach { case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node - } - } - - // Calculate best splits for all nodes in the group - timer.start("chooseSplits") - - // In each partition, iterate all instances and compute aggregate stats for each node, - // yield an (nodeIndex, nodeAggregateStats) pair for each node. - // After a `reduceByKey` operation, - // stats of a node will be shuffled to a particular partition and be combined together, - // then best splits for nodes are found there. - // Finally, only best Splits for nodes are collected to driver to construct decision tree. - val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) - val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - - val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { - input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => - // Construct a nodeStatsAggregators array to hold node aggregate stats, - // each node will have a nodeStatsAggregator - val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) - } - new DTStatsAggregator(metadata, featuresForNode) - } - - // iterator all instances in current partition and update aggregate stats - points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) - - // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, - // which can be combined with other partition using `reduceByKey` - nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator - } - } else { - input.mapPartitions { points => - // Construct a nodeStatsAggregators array to hold node aggregate stats, - // each node will have a nodeStatsAggregator - val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) - } - new DTStatsAggregator(metadata, featuresForNode) - } - - // iterator all instances in current partition and update aggregate stats - points.foreach(binSeqOp(nodeStatsAggregators, _)) - - // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, - // which can be combined with other partition using `reduceByKey` - nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator - } - } - - val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)) - .map { case (nodeIndex, aggStats) => - val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures => - nodeToFeatures(nodeIndex) - } - - // find best split for each node - val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats, predict)) - }.collectAsMap() - - timer.stop("chooseSplits") - - val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { - Array.fill[mutable.Map[Int, NodeIndexUpdater]]( - metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) - } else { - null - } - - // Iterate over all nodes in this group. - nodesForGroup.foreach { case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - val nodeIndex = node.id - val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: InformationGainStats, predict: Predict) = - nodeToBestSplits(aggNodeIndex) - logDebug("best split = " + split) - - // Extract info for this node. Create children if not leaf. - val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) - assert(node.id == nodeIndex) - node.predict = predict - node.isLeaf = isLeaf - node.stats = Some(stats) - node.impurity = stats.impurity - logDebug("Node = " + node) - - if (!isLeaf) { - node.split = Some(split) - val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth - val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) - val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) - node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), - stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) - node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), - stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) - - if (nodeIdCache.nonEmpty) { - val nodeIndexUpdater = NodeIndexUpdater( - split = split, - nodeIndex = nodeIndex) - nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) - } - - // enqueue left child and right child if they are not leaves - if (!leftChildIsLeaf) { - nodeQueue.enqueue((treeIndex, node.leftNode.get)) - } - if (!rightChildIsLeaf) { - nodeQueue.enqueue((treeIndex, node.rightNode.get)) - } - - logDebug("leftChildIndex = " + node.leftNode.get.id + - ", impurity = " + stats.leftImpurity) - logDebug("rightChildIndex = " + node.rightNode.get.id + - ", impurity = " + stats.rightImpurity) - } - } - } - - if (nodeIdCache.nonEmpty) { - // Update the cache if needed. - nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins) - } - } - - /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for split - */ - private def calculateGainForSplit( - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata, - impurity: Double): InformationGainStats = { - val leftCount = leftImpurityCalculator.count - val rightCount = rightImpurityCalculator.count - - // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats. - if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { - return InformationGainStats.invalidInformationGainStats - } - - val totalCount = leftCount + rightCount - - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 - val rightImpurity = rightImpurityCalculator.calculate() - - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / totalCount.toDouble - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - // if information gain doesn't satisfy minimum information gain, - // then this split is invalid, return invalid information gain stats. - if (gain < metadata.minInfoGain) { - return InformationGainStats.invalidInformationGainStats - } - - // calculate left and right predict - val leftPredict = calculatePredict(leftImpurityCalculator) - val rightPredict = calculatePredict(rightImpurityCalculator) - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, - leftPredict, rightPredict) - } - - private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { - val predict = impurityCalculator.predict - val prob = impurityCalculator.prob(predict) - new Predict(predict, prob) - } - - /** - * Calculate predict value for current node, given stats of any split. - * Note that this function is called only once for each node. - * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node - */ - private def calculatePredictImpurity( - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - val predict = calculatePredict(parentNodeAgg) - val impurity = parentNodeAgg.calculate() - - (predict, impurity) - } - - /** - * Find the best split for a node. - * @param binAggregates Bin statistics. - * @return tuple for best split: (Split, information gain, prediction at node) - */ - private def binsToBestSplit( - binAggregates: DTStatsAggregator, - splits: Array[Array[Split]], - featuresForNode: Option[Array[Int]], - node: Node): (Split, InformationGainStats, Predict) = { - - // calculate predict and impurity if current node is top node - val level = Node.indexToLevel(node.id) - var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { - None - } else { - Some((node.predict, node.impurity)) - } - - // For each (feature, split), calculate the gain, and select the best (feature, split). - val (bestSplit, bestSplitStats) = - Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIdx, gainStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIndex, gainStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numBins = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) - * - * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = if (binAggregates.metadata.isMulticlass) { - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - Range(0, numBins).map { case featureValue => - val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - categoryStats.calculate() - } else { - Double.MaxValue - } - (featureValue, centroid) - } - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // the bins are ordered by the centroid of their corresponding labels. - Range(0, numBins).map { case featureValue => - val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - categoryStats.predict - } else { - Double.MaxValue - } - (featureValue, centroid) - } - } - - logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - - logDebug("Sorted centroids for categorical variable = " + - categoriesSortedByCentroid.mkString(",")) - - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 - } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIndex, gainStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) - (bestFeatureSplit, bestFeatureGainStats) - } - }.maxBy(_._2.gain) - - (bestSplit, bestSplitStats, predictWithImpurity.get._1) - } - - /** - * Returns splits and bins for decision tree calculation. - * Continuous and categorical features are handled differently. - * - * Continuous features: - * For each feature, there are numBins - 1 possible splits representing the possible binary - * decisions at each node in the tree. - * This finds locations (feature values) for splits using a subsample of the data. - * - * Categorical features: - * For each feature, there is 1 bin per split. - * Splits and bins are handled in 2 ways: - * (a) "unordered features" - * For multiclass classification with a low-arity feature - * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), - * the feature is split based on subsets of categories. - * (b) "ordered features" - * For regression and binary classification, - * and for multiclass classification with a high-arity feature, - * there is one bin per category. - * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param metadata Learning and dataset metadata - * @return A tuple of (splits, bins). - * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] - * of size (numFeatures, numSplits). - * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] - * of size (numFeatures, numBins). - */ - protected[tree] def findSplitsBins( - input: RDD[LabeledPoint], - metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { - - logDebug("isMulticlass = " + metadata.isMulticlass) - - val numFeatures = metadata.numFeatures - - // Sample the input only if there are continuous features. - val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) - val sampledInput = if (continuousFeatures.nonEmpty) { - // Calculate the number of samples for approximate quantile calculation. - val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) - val fraction = if (requiredSamples < metadata.numExamples) { - requiredSamples.toDouble / metadata.numExamples - } else { - 1.0 - } - logDebug("fraction of data used for calculating quantiles = " + fraction) - input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()) - } else { - input.sparkContext.emptyRDD[LabeledPoint] - } - - metadata.quantileStrategy match { - case Sort => - findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures) - case MinMax => - throw new UnsupportedOperationException("minmax not supported yet.") - case ApproxHist => - throw new UnsupportedOperationException("approximate histogram not supported yet.") - } - } - - private def findSplitsBinsBySorting( - input: RDD[LabeledPoint], - metadata: DecisionTreeMetadata, - continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = { - def findSplits( - featureIndex: Int, - featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = { - val splits = { - val featureSplits = findSplitsForContinuousFeature( - featureSamples.toArray, - metadata, - featureIndex) - logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}") - - featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil)) - } - - val bins = { - val lowSplit = new DummyLowSplit(featureIndex, Continuous) - val highSplit = new DummyHighSplit(featureIndex, Continuous) - - // tack the dummy splits on either side of the computed splits - val allSplits = lowSplit +: splits.toSeq :+ highSplit - - // slide across the split points pairwise to allocate the bins - allSplits.sliding(2).map { - case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue) - }.toArray - } - - (featureIndex, (splits, bins)) - } - - val continuousSplits = { - // reduce the parallelism for split computations when there are less - // continuous features than input partitions. this prevents tasks from - // being spun up that will definitely do no work. - val numPartitions = math.min(continuousFeatures.length, input.partitions.length) - - input - .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) - .groupByKey(numPartitions) - .map { case (k, v) => findSplits(k, v) } - .collectAsMap() - } - - val numFeatures = metadata.numFeatures - val (splits, bins) = Range(0, numFeatures).unzip { - case i if metadata.isContinuous(i) => - val (split, bin) = continuousSplits(i) - metadata.setNumSplits(i, split.length) - (split, bin) - - case i if metadata.isCategorical(i) && metadata.isUnordered(i) => - // Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations - val featureArity = metadata.featureArity(i) - val split = Range(0, metadata.numSplits(i)).map { splitIndex => - val categories = extractMultiClassCategories(splitIndex + 1, featureArity) - new Split(i, Double.MinValue, Categorical, categories) - } - - // For unordered categorical features, there is no need to construct the bins. - // since there is a one-to-one correspondence between the splits and the bins. - (split.toArray, Array.empty[Bin]) - - case i if metadata.isCategorical(i) => - // Ordered features - // Bins correspond to feature values, so we do not need to compute splits or bins - // beforehand. Splits are constructed as needed during training. - (Array.empty[Split], Array.empty[Bin]) - } - - (splits.toArray, bins.toArray) - } - - /** - * Nested method to extract list of eligible categories given an index. It extracts the - * position of ones in a binary representation of the input. If binary - * representation of an number is 01101 (13), the output list should (3.0, 2.0, - * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. - */ - private[tree] def extractMultiClassCategories( - input: Int, - maxFeatureValue: Int): List[Double] = { - var categories = List[Double]() - var j = 0 - var bitShiftedInput = input - while (j < maxFeatureValue) { - if (bitShiftedInput % 2 != 0) { - // updating the list of categories. - categories = j.toDouble :: categories - } - // Right shift by one - bitShiftedInput = bitShiftedInput >> 1 - j += 1 - } - categories - } - - /** - * Find splits for a continuous feature - * NOTE: Returned number of splits is set based on `featureSamples` and - * could be different from the specified `numSplits`. - * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. - * @param featureSamples feature values of each sample - * @param metadata decision tree metadata - * NOTE: `metadata.numbins` will be changed accordingly - * if there are not enough splits to be found - * @param featureIndex feature index to find splits - * @return array of splits - */ - private[tree] def findSplitsForContinuousFeature( - featureSamples: Array[Double], - metadata: DecisionTreeMetadata, - featureIndex: Int): Array[Double] = { - require(metadata.isContinuous(featureIndex), - "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - - val splits = { - val numSplits = metadata.numSplits(featureIndex) - - // get count for each distinct value - val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) => - m + ((x, m.getOrElse(x, 0) + 1)) - } - // sort distinct values - val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray - - // if possible splits is not enough or just enough, just return all possible splits - val possibleSplits = valueCounts.length - if (possibleSplits <= numSplits) { - valueCounts.map(_._1) - } else { - // stride between splits - val stride: Double = featureSamples.length.toDouble / (numSplits + 1) - logDebug("stride = " + stride) - - // iterate `valueCount` to find splits - val splitsBuilder = Array.newBuilder[Double] - var index = 1 - // currentCount: sum of counts of values that have been visited - var currentCount = valueCounts(0)._2 - // targetCount: target value for `currentCount`. - // If `currentCount` is closest value to `targetCount`, - // then current value is a split threshold. - // After finding a split threshold, `targetCount` is added by stride. - var targetCount = stride - while (index < valueCounts.length) { - val previousCount = currentCount - currentCount += valueCounts(index)._2 - val previousGap = math.abs(previousCount - targetCount) - val currentGap = math.abs(currentCount - targetCount) - // If adding count of current value to currentCount - // makes the gap between currentCount and targetCount smaller, - // previous value is a split threshold. - if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 - targetCount += stride - } - index += 1 - } - - splitsBuilder.result() - } - } - - // TODO: Do not fail; just ignore the useless feature. - assert(splits.length > 0, - s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + - " Please remove this feature and then try again.") - - // the split metadata must be updated on the driver - - splits - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 729a211574822..7fe60e2d99e4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -17,18 +17,14 @@ package org.apache.spark.mllib.tree -import org.apache.spark.Logging import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer +import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.impurity.Variance -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel /** * A class that implements @@ -47,29 +43,31 @@ import org.apache.spark.storage.StorageLevel * for other loss functions. * * @param boostingStrategy Parameters for the gradient boosting algorithm. + * @param seed Random seed. */ @Since("1.2.0") -class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy) +class GradientBoostedTrees private[spark] ( + private val boostingStrategy: BoostingStrategy, + private val seed: Int) extends Serializable with Logging { + /** + * @param boostingStrategy Parameters for the gradient boosting algorithm. + */ + @Since("1.2.0") + def this(boostingStrategy: BoostingStrategy) = this(boostingStrategy, seed = 0) + /** * Method to train a gradient boosting model + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo - algo match { - case Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) - case Classification => - // Map labels to -1, +1 so binary classification can be treated as regression. - val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) - case _ => - throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") - } + val (trees, treeWeights) = NewGBT.run(input, boostingStrategy, seed.toLong) + new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } /** @@ -82,33 +80,23 @@ class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: Boosti /** * Method to validate a gradient boosting model + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. * E.g., these two datasets could be created from an original dataset * by using [[org.apache.spark.rdd.RDD.randomSplit()]] - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.4.0") def runWithValidation( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo - algo match { - case Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) - case Classification => - // Map labels to -1, +1 so binary classification can be treated as regression. - val remappedInput = input.map( - x => new LabeledPoint((x.label * 2) - 1, x.features)) - val remappedValidationInput = validationInput.map( - x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true) - case _ => - throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") - } + val (trees, treeWeights) = NewGBT.runWithValidation(input, validationInput, boostingStrategy, + seed.toLong) + new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } /** @@ -132,13 +120,13 @@ object GradientBoostedTrees extends Logging { * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. * @param boostingStrategy Configuration options for the boosting algorithm. - * @return a gradient boosted trees model that can be used for prediction + * @return GradientBoostedTreesModel that can be used for prediction. */ @Since("1.2.0") def train( input: RDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { - new GradientBoostedTrees(boostingStrategy).run(input) + new GradientBoostedTrees(boostingStrategy, seed = 0).run(input) } /** @@ -150,145 +138,4 @@ object GradientBoostedTrees extends Logging { boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { train(input.rdd, boostingStrategy) } - - /** - * Internal method for performing regression using trees as base learners. - * @param input training dataset - * @param validationInput validation dataset, ignored if validate is set to false. - * @param boostingStrategy boosting parameters - * @param validate whether or not to use the validation dataset. - * @return a gradient boosted trees model that can be used for prediction - */ - private def boost( - input: RDD[LabeledPoint], - validationInput: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy, - validate: Boolean): GradientBoostedTreesModel = { - val timer = new TimeTracker() - timer.start("total") - timer.start("init") - - boostingStrategy.assertValid() - - // Initialize gradient boosting parameters - val numIterations = boostingStrategy.numIterations - val baseLearners = new Array[DecisionTreeModel](numIterations) - val baseLearnerWeights = new Array[Double](numIterations) - val loss = boostingStrategy.loss - val learningRate = boostingStrategy.learningRate - // Prepare strategy for individual trees, which use regression with variance impurity. - val treeStrategy = boostingStrategy.treeStrategy.copy - val validationTol = boostingStrategy.validationTol - treeStrategy.algo = Regression - treeStrategy.impurity = Variance - treeStrategy.assertValid() - - // Cache input - val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { - input.persist(StorageLevel.MEMORY_AND_DISK) - true - } else { - false - } - - // Prepare periodic checkpointers - val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( - treeStrategy.getCheckpointInterval, input.sparkContext) - val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( - treeStrategy.getCheckpointInterval, input.sparkContext) - - timer.stop("init") - - logDebug("##########") - logDebug("Building tree 0") - logDebug("##########") - - // Initialize tree - timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(input) - val firstTreeWeight = 1.0 - baseLearners(0) = firstTreeModel - baseLearnerWeights(0) = firstTreeWeight - - var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. - computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) - predErrorCheckpointer.update(predError) - logDebug("error of gbt = " + predError.values.mean()) - - // Note: A model of type regression is used since we require raw prediction - timer.stop("building tree 0") - - var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. - computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) - if (validate) validatePredErrorCheckpointer.update(validatePredError) - var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 - var bestM = 1 - - var m = 1 - var doneLearning = false - while (m < numIterations && !doneLearning) { - // Update data with pseudo-residuals - val data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } - - timer.start(s"building tree $m") - logDebug("###################################################") - logDebug("Gradient boosting tree iteration " + m) - logDebug("###################################################") - val model = new DecisionTree(treeStrategy).run(data) - timer.stop(s"building tree $m") - // Update partial model - baseLearners(m) = model - // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. - // Technically, the weight should be optimized for the particular loss. - // However, the behavior should be reasonable, though not optimal. - baseLearnerWeights(m) = learningRate - - predError = GradientBoostedTreesModel.updatePredictionError( - input, predError, baseLearnerWeights(m), baseLearners(m), loss) - predErrorCheckpointer.update(predError) - logDebug("error of gbt = " + predError.values.mean()) - - if (validate) { - // Stop training early if - // 1. Reduction in error is less than the validationTol or - // 2. If the error increases, that is if the model is overfit. - // We want the model returned corresponding to the best validation error. - - validatePredError = GradientBoostedTreesModel.updatePredictionError( - validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) - validatePredErrorCheckpointer.update(validatePredError) - val currentValidateError = validatePredError.values.mean() - if (bestValidateError - currentValidateError < validationTol * Math.max( - currentValidateError, 0.01)) { - doneLearning = true - } else if (currentValidateError < bestValidateError) { - bestValidateError = currentValidateError - bestM = m + 1 - } - } - m += 1 - } - - timer.stop("total") - - logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") - - predErrorCheckpointer.deleteAllCheckpoints() - validatePredErrorCheckpointer.deleteAllCheckpoints() - if (persistedInput) input.unpersist() - - if (validate) { - new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, bestM), - baseLearnerWeights.slice(0, bestM)) - } else { - new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) - } - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index a684cdd18c2fc..ca7fb7f51c3fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,26 +17,23 @@ package org.apache.spark.mllib.tree -import java.io.IOException - -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.util.Try -import org.apache.spark.Logging import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams} +import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache, - TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impurity.Impurities import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -import org.apache.spark.util.random.SamplingUtils + /** * A class that implements a [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] @@ -53,16 +50,21 @@ import org.apache.spark.util.random.SamplingUtils * random forests]] * * @param strategy The configuration parameters for the random forest algorithm which specify - * the type of algorithm (classification, regression, etc.), feature type + * the type of random forest (classification or regression), feature type * (continuous, categorical), depth of the tree, quantile calculation strategy, * etc. * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * Supported numerical values: "(0.0-1.0]", "[1-n]". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. + * If a real value "n" in the range (0, 1.0] is set, + * use n * number of features. + * If an integer value "n" in the range (1, num features) is set, + * use n features. * @param seed Random seed for bootstrapping and choosing feature subsets. */ private class RandomForest ( @@ -72,188 +74,25 @@ private class RandomForest ( private val seed: Int) extends Serializable with Logging { - /* - ALGORITHM - This is a sketch of the algorithm to help new developers. - - The algorithm partitions data by instances (rows). - On each iteration, the algorithm splits a set of nodes. In order to choose the best split - for a given node, sufficient statistics are collected from the distributed data. - For each node, the statistics are collected to some worker node, and that worker selects - the best split. - - This setup requires discretization of continuous features. This binning is done in the - findSplitsBins() method during initialization, after which each continuous feature becomes - an ordered discretized feature with at most maxBins possible values. - - The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes - lie at the periphery of the tree being trained. If multiple trees are being trained at once, - then this queue contains nodes from all of them. Each iteration works roughly as follows: - On the master node: - - Some number of nodes are pulled off of the queue (based on the amount of memory - required for their sufficient statistics). - - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate - features are chosen for each node. See method selectNodesToSplit(). - On worker nodes, via method findBestSplits(): - - The worker makes one pass over its subset of instances. - - For each (tree, node, feature, split) tuple, the worker collects statistics about - splitting. Note that the set of (tree, node) pairs is limited to the nodes selected - from the queue for this iteration. The set of features considered can also be limited - based on featureSubsetStrategy. - - For each node, the statistics for that node are aggregated to a particular worker - via reduceByKey(). The designated worker chooses the best (feature, split) pair, - or chooses to stop splitting if the stopping criteria are met. - On the master node: - - The master collects all decisions about splitting nodes and updates the model. - - The updated model is passed to the workers on the next iteration. - This process continues until the node queue is empty. - - Most of the methods in this implementation support the statistics aggregation, which is - the heaviest part of the computation. In general, this implementation is bound by either - the cost of statistics computation on workers or by communicating the sufficient statistics. - */ - strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") - require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), + require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) + || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess + || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess, s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + - s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.") + s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" (0.0-1.0], [1-n].") /** * Method to train a decision tree model over an RDD - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return a random forest model that can be used for prediction + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return RandomForestModel that can be used for prediction. */ def run(input: RDD[LabeledPoint]): RandomForestModel = { - - val timer = new TimeTracker() - - timer.start("total") - - timer.start("init") - - val retaggedInput = input.retag(classOf[LabeledPoint]) - val metadata = - DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) - logDebug("algo = " + strategy.algo) - logDebug("numTrees = " + numTrees) - logDebug("seed = " + seed) - logDebug("maxBins = " + metadata.maxBins) - logDebug("featureSubsetStrategy = " + featureSubsetStrategy) - logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) - logDebug("subsamplingRate = " + strategy.subsamplingRate) - - // Find the splits and the corresponding bins (interval between the splits) using a sample - // of the input data. - timer.start("findSplitsBins") - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) - timer.stop("findSplitsBins") - logDebug("numBins: feature: number of bins") - logDebug(Range(0, metadata.numFeatures).map { featureIndex => - s"\t$featureIndex\t${metadata.numBins(featureIndex)}" - }.mkString("\n")) - - // Bin feature values (TreePoint representation). - // Cache input RDD for speedup during multiple passes. - val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - - val withReplacement = if (numTrees > 1) true else false - - val baggedInput - = BaggedPoint.convertToBaggedRDD(treeInput, - strategy.subsamplingRate, numTrees, - withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK) - - // depth of the decision tree - val maxDepth = strategy.maxDepth - require(maxDepth <= 30, - s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") - - // Max memory usage for aggregates - // TODO: Calculate memory usage more precisely. - val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L - logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val maxMemoryPerNode = { - val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. - Some(metadata.numBins.zipWithIndex.sortBy(- _._1) - .take(metadata.numFeaturesPerNode).map(_._2)) - } else { - None - } - RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L - } - require(maxMemoryPerNode <= maxMemoryUsage, - s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," + - " which is too small for the given features." + - s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") - - timer.stop("init") - - /* - * The main idea here is to perform group-wise training of the decision tree nodes thus - * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). - * Each data sample is handled by a particular node (or it reaches a leaf and is not used - * in lower levels). - */ - - // Create an RDD of node Id cache. - // At first, all the rows belong to the root nodes (node Id == 1). - val nodeIdCache = if (strategy.useNodeIdCache) { - Some(NodeIdCache.init( - data = baggedInput, - numTrees = numTrees, - checkpointInterval = strategy.checkpointInterval, - initVal = 1)) - } else { - None - } - - // FIFO queue of nodes to train: (treeIndex, node) - val nodeQueue = new mutable.Queue[(Int, Node)]() - - val rng = new scala.util.Random() - rng.setSeed(seed) - - // Allocate and queue root nodes. - val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1)) - Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) - - while (nodeQueue.nonEmpty) { - // Collect some nodes to split, and choose features for each node (if subsampling). - // Each group of nodes may come from one or multiple trees, and at multiple levels. - val (nodesForGroup, treeToNodeToIndexInfo) = - RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) - // Sanity check (should never occur): - assert(nodesForGroup.size > 0, - s"RandomForest selected empty nodesForGroup. Error for unknown reason.") - - // Choose node splits, and enqueue new nodes as needed. - timer.start("findBestSplits") - DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache) - timer.stop("findBestSplits") - } - - baggedInput.unpersist() - - timer.stop("total") - - logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") - - // Delete any remaining checkpoints used for node Id cache. - if (nodeIdCache.nonEmpty) { - try { - nodeIdCache.get.deleteAllCheckpoints() - } catch { - case e: IOException => - logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") - } - } - - val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) - new RandomForestModel(strategy.algo, trees) + val trees: Array[NewDTModel] = + NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong) + new RandomForestModel(strategy.algo, trees.map(_.toOld)) } } @@ -269,12 +108,12 @@ object RandomForest extends Serializable with Logging { * @param strategy Parameters for training each tree in the forest. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt". - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainClassifier( @@ -294,25 +133,25 @@ object RandomForest extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels should take values {0, 1, ..., numClasses-1}. - * @param numClasses number of classes for classification. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numClasses Number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt". * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainClassifier( @@ -358,12 +197,12 @@ object RandomForest extends Serializable with Logging { * @param strategy Parameters for training each tree in the forest. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "onethird". - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainRegressor( @@ -383,24 +222,24 @@ object RandomForest extends Serializable with Logging { * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * Labels are real numbers. - * @param categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param numTrees Number of trees in the random forest. * @param featureSubsetStrategy Number of features to consider for splits at each node. - * Supported: "auto", "all", "sqrt", "log2", "onethird". + * Supported values: "auto", "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "onethird". * @param impurity Criterion used for information gain calculation. - * Supported values: "variance". - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (suggested value: 4) - * @param maxBins maximum number of bins used for splitting features - * (suggested value: 100) - * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return a random forest model that can be used for prediction + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree. (e.g., depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. */ @Since("1.2.0") def trainRegressor( @@ -440,86 +279,5 @@ object RandomForest extends Serializable with Logging { * List of supported feature subset sampling strategies. */ @Since("1.2.0") - val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "sqrt", "log2", "onethird") - - private[tree] class NodeIndexInfo( - val nodeIndexInGroup: Int, - val featureSubset: Option[Array[Int]]) extends Serializable - - /** - * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. - * This tracks the memory usage for aggregates and stops adding nodes when too much memory - * will be needed; this allows an adaptive number of nodes since different nodes may require - * different amounts of memory (if featureSubsetStrategy is not "all"). - * - * @param nodeQueue Queue of nodes to split. - * @param maxMemoryUsage Bound on size of aggregate statistics. - * @return (nodesForGroup, treeToNodeToIndexInfo). - * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. - * - * treeToNodeToIndexInfo holds indices selected features for each node: - * treeIndex --> (global) node index --> (node index in group, feature indices). - * The (global) node index is the index in the tree; the node index in group is the - * index in [0, numNodesInGroup) of the node in this group. - * The feature indices are None if not subsampling features. - */ - private[tree] def selectNodesToSplit( - nodeQueue: mutable.Queue[(Int, Node)], - maxMemoryUsage: Long, - metadata: DecisionTreeMetadata, - rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = { - // Collect some nodes to split: - // nodesForGroup(treeIndex) = nodes to split - val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]() - val mutableTreeToNodeToIndexInfo = - new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() - var memUsage: Long = 0L - var numNodesInGroup = 0 - while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) { - val (treeIndex, node) = nodeQueue.head - // Choose subset of features for node (if subsampling). - val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - Some(SamplingUtils.reservoirSampleAndCount(Range(0, - metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) - } else { - None - } - // Check if enough memory remains to add this node to the group. - val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L - if (memUsage + nodeMemUsage <= maxMemoryUsage) { - nodeQueue.dequeue() - mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node - mutableTreeToNodeToIndexInfo - .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) - = new NodeIndexInfo(numNodesInGroup, featureSubset) - } - numNodesInGroup += 1 - memUsage += nodeMemUsage - } - // Convert mutable maps to immutable ones. - val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap - val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap - (nodesForGroup, treeToNodeToIndexInfo) - } - - /** - * Get the number of values to be stored for this node in the bin aggregates. - * @param featureSubset Indices of features which may be split at this node. - * If None, then use all features. - */ - private[tree] def aggregateSizeForNode( - metadata: DecisionTreeMetadata, - featureSubset: Option[Array[Int]]): Long = { - val totalBins = if (featureSubset.nonEmpty) { - featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum - } else { - metadata.numBins.map(_.toLong).sum - } - if (metadata.isClassification) { - metadata.numClasses * totalBins - } else { - 3 * totalBins - } - } + val supportedFeatureSubsetStrategies: Array[String] = NewRFParams.supportedFeatureSubsetStrategies } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index d2513a9d5c5bb..d8405d13ce904 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,7 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} +import org.apache.spark.mllib.tree.loss.{LogLoss, Loss, SquaredError} /** * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. @@ -59,7 +59,7 @@ case class BoostingStrategy @Since("1.4.0") ( * Check validity of parameters. * Throws exception if invalid. */ - private[tree] def assertValid(): Unit = { + private[spark] def assertValid(): Unit = { treeStrategy.algo match { case Classification => require(treeStrategy.numClasses == 2, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 372d6617a4014..b34e1b1b56c43 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -21,9 +21,9 @@ import scala.beans.BeanProperty import scala.collection.JavaConverters._ import org.apache.spark.annotation.Since -import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} /** * Stores all the configuration options for tree construction @@ -34,8 +34,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], * [[org.apache.spark.mllib.tree.impurity.Entropy]]. * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). * @param numClasses Number of classes for classification. * (Ignored for regression.) * Default value is 2 (binary classification). @@ -45,10 +45,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the - * number of discrete values they take. For example, an entry (n -> - * k) implies the feature n is categorical with k categories 0, - * 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * number of discrete values they take. An entry (n -> k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. * @param minInstancesPerNode Minimum number of instances each child must have after split. * Default value is 1. If a split cause left or right child * to have less than minInstancesPerNode, @@ -57,10 +56,11 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * If a split has less information gain than minInfoGain, * this split will not be considered as a valid split. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is - * 256 MB. + * 256 MB. If too small, then 1 node will be split per iteration, and + * its aggregates may exceed this size. * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will - * maintain a separate RDD of node Id cache for each row. + * maintain a separate RDD of node Id cache for each row. * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. * E.g. 10 means that the cache will get checkpointed every 10 updates. If * the checkpoint directory is not set in @@ -134,7 +134,7 @@ class Strategy @Since("1.3.0") ( * Check validity of parameters. * Throws exception if invalid. */ - private[tree] def assertValid(): Unit = { + private[spark] def assertValid(): Unit = { algo match { case Classification => require(numClasses >= 2, @@ -202,8 +202,4 @@ object Strategy { numClasses = 0) } - @deprecated("Use Strategy.defaultStrategy instead.", "1.5.0") - @Since("1.2.0") - def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo) - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala deleted file mode 100644 index 1c611976a9308..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import scala.collection.mutable - -import org.apache.hadoop.fs.{Path, FileSystem} - -import org.apache.spark.rdd.RDD -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.mllib.tree.model.{Bin, Node, Split} - -/** - * :: DeveloperApi :: - * This is used by the node id cache to find the child id that a data point would belong to. - * @param split Split information. - * @param nodeIndex The current node index of a data point that this will update. - */ -@DeveloperApi -private[tree] case class NodeIndexUpdater( - split: Split, - nodeIndex: Int) { - /** - * Determine a child node index based on the feature value and the split. - * @param binnedFeatures Binned feature values. - * @param bins Bin information to convert the bin indices to approximate feature values. - * @return Child node index to update to. - */ - def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = { - if (split.featureType == Continuous) { - val featureIndex = split.feature - val binIndex = binnedFeatures(featureIndex) - val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold - if (featureValueUpperBound <= split.threshold) { - Node.leftChildIndex(nodeIndex) - } else { - Node.rightChildIndex(nodeIndex) - } - } else { - if (split.categories.contains(binnedFeatures(split.feature).toDouble)) { - Node.leftChildIndex(nodeIndex) - } else { - Node.rightChildIndex(nodeIndex) - } - } - } -} - -/** - * :: DeveloperApi :: - * A given TreePoint would belong to a particular node per tree. - * Each row in the nodeIdsForInstances RDD is an array over trees of the node index - * in each tree. Initially, values should all be 1 for root node. - * The nodeIdsForInstances RDD needs to be updated at each iteration. - * @param nodeIdsForInstances The initial values in the cache - * (should be an Array of all 1's (meaning the root nodes)). - * @param checkpointInterval The checkpointing interval - * (how often should the cache be checkpointed.). - */ -@DeveloperApi -private[spark] class NodeIdCache( - var nodeIdsForInstances: RDD[Array[Int]], - val checkpointInterval: Int) { - - // Keep a reference to a previous node Ids for instances. - // Because we will keep on re-persisting updated node Ids, - // we want to unpersist the previous RDD. - private var prevNodeIdsForInstances: RDD[Array[Int]] = null - - // To keep track of the past checkpointed RDDs. - private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() - private var rddUpdateCount = 0 - - /** - * Update the node index values in the cache. - * This updates the RDD and its lineage. - * TODO: Passing bin information to executors seems unnecessary and costly. - * @param data The RDD of training rows. - * @param nodeIdUpdaters A map of node index updaters. - * The key is the indices of nodes that we want to update. - * @param bins Bin information needed to find child node indices. - */ - def updateNodeIndices( - data: RDD[BaggedPoint[TreePoint]], - nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], - bins: Array[Array[Bin]]): Unit = { - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() - } - - prevNodeIdsForInstances = nodeIdsForInstances - nodeIdsForInstances = data.zip(nodeIdsForInstances).map { - case (point, node) => { - var treeId = 0 - while (treeId < nodeIdUpdaters.length) { - val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null) - if (nodeIdUpdater != null) { - val newNodeIndex = nodeIdUpdater.updateNodeIndex( - binnedFeatures = point.datum.binnedFeatures, - bins = bins) - node(treeId) = newNodeIndex - } - - treeId += 1 - } - - node - } - } - - // Keep on persisting new ones. - nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) - rddUpdateCount += 1 - - // Handle checkpointing if the directory is not None. - if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty && - (rddUpdateCount % checkpointInterval) == 0) { - // Let's see if we can delete previous checkpoints. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // We can delete the oldest checkpoint iff - // the next checkpoint actually exists in the file system. - if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) { - val old = checkpointQueue.dequeue() - - // Since the old checkpoint is not deleted by Spark, - // we'll manually delete it here. - val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(old.getCheckpointFile.get), true) - } else { - canDelete = false - } - } - - nodeIdsForInstances.checkpoint() - checkpointQueue.enqueue(nodeIdsForInstances) - } - } - - /** - * Call this after training is finished to delete any remaining checkpoints. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.nonEmpty) { - val old = checkpointQueue.dequeue() - for (checkpointFile <- old.getCheckpointFile) { - val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(checkpointFile), true) - } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() - } - } -} - -@DeveloperApi -private[spark] object NodeIdCache { - /** - * Initialize the node Id cache with initial node Id values. - * @param data The RDD of training rows. - * @param numTrees The number of trees that we want to create cache for. - * @param checkpointInterval The checkpointing interval - * (how often should the cache be checkpointed.). - * @param initVal The initial values in the cache. - * @return A node Id cache containing an RDD of initial root node Indices. - */ - def init( - data: RDD[BaggedPoint[TreePoint]], - numTrees: Int, - checkpointInterval: Int, - initVal: Int = 1): NodeIdCache = { - new NodeIdCache( - data.map(_ => Array.fill[Int](numTrees)(initVal)), - checkpointInterval) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala deleted file mode 100644 index 21919d69a38a3..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.Bin -import org.apache.spark.rdd.RDD - - -/** - * Internal representation of LabeledPoint for DecisionTree. - * This bins feature values based on a subsampled of data as follows: - * (a) Continuous features are binned into ranges. - * (b) Unordered categorical features are binned based on subsets of feature values. - * "Unordered categorical features" are categorical features with low arity used in - * multiclass classification. - * (c) Ordered categorical features are binned based on feature values. - * "Ordered categorical features" are categorical features with high arity, - * or any categorical feature used in regression or binary classification. - * - * @param label Label from LabeledPoint - * @param binnedFeatures Binned feature values. - * Same length as LabeledPoint.features, but values are bin indices. - */ -private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) - extends Serializable { -} - -private[spark] object TreePoint { - - /** - * Convert an input dataset into its TreePoint representation, - * binning feature values in preparation for DecisionTree training. - * @param input Input dataset. - * @param bins Bins for features, of size (numFeatures, numBins). - * @param metadata Learning and dataset metadata - * @return TreePoint dataset representation - */ - def convertToTreeRDD( - input: RDD[LabeledPoint], - bins: Array[Array[Bin]], - metadata: DecisionTreeMetadata): RDD[TreePoint] = { - // Construct arrays for featureArity for efficiency in the inner loop. - val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) - var featureIndex = 0 - while (featureIndex < metadata.numFeatures) { - featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) - featureIndex += 1 - } - input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, featureArity) - } - } - - /** - * Convert one LabeledPoint into its TreePoint representation. - * @param bins Bins for features, of size (numFeatures, numBins). - * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories - * for categorical features. - */ - private def labeledPointToTreePoint( - labeledPoint: LabeledPoint, - bins: Array[Array[Bin]], - featureArity: Array[Int]): TreePoint = { - val numFeatures = labeledPoint.features.size - val arr = new Array[Int](numFeatures) - var featureIndex = 0 - while (featureIndex < numFeatures) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), - bins) - featureIndex += 1 - } - new TreePoint(labeledPoint.label, arr) - } - - /** - * Find bin for one (labeledPoint, feature). - * - * @param featureArity 0 for continuous features; number of categories for categorical features. - * @param bins Bins for features, of size (numFeatures, numBins). - */ - private def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - featureArity: Int, - bins: Array[Array[Bin]]): Int = { - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } else if (lowThreshold >= feature) { - right = mid - 1 - } else { - left = mid + 1 - } - } - -1 - } - - if (featureArity == 0) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new RuntimeException("No bin was found for continuous feature." + - " This error can occur when given invalid data values (such as NaN)." + - s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") - } - binIndex - } else { - // Categorical feature bins are indexed by feature values. - val featureValue = labeledPoint.features(featureIndex) - if (featureValue < 0 || featureValue >= featureArity) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureArity - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - featureValue.toInt - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 73df6b054a8ce..ff7700d2d1b7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -85,7 +85,7 @@ object Entropy extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[tree] class EntropyAggregator(numClasses: Int) +private[spark] class EntropyAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** @@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int) def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = { new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray) } - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index f21845b21a802..58dc79b7398e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -81,7 +81,7 @@ object Gini extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[tree] class GiniAggregator(numClasses: Int) +private[spark] class GiniAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** @@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int) def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = { new GiniCalculator(allStats.view(offset, offset + statsSize).toArray) } - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 4637dcceea7f8..65f0163ec6059 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * @param offset Start index of stats for this (node, feature, bin). */ def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator - } /** @@ -179,3 +178,21 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten } } + +private[spark] object ImpurityCalculator { + + /** + * Create an [[ImpurityCalculator]] instance of the given impurity type and with + * the given stats. + */ + def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { + impurity match { + case "gini" => new GiniCalculator(stats) + case "entropy" => new EntropyCalculator(stats) + case "variance" => new VarianceCalculator(stats) + case _ => + throw new IllegalArgumentException( + s"ImpurityCalculator builder did not recognize impurity type: $impurity") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index a74197278d6f7..2423516123b82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -71,7 +71,7 @@ object Variance extends Impurity { * in order to compute impurity from a sample. * Note: Instances of this class do not hold the data; they operate on views of the data. */ -private[tree] class VarianceAggregator() +private[spark] class VarianceAggregator() extends ImpurityAggregator(statsSize = 3) with Serializable { /** @@ -93,7 +93,6 @@ private[tree] class VarianceAggregator() def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) } - } /** @@ -104,9 +103,9 @@ private[tree] class VarianceAggregator() */ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { - require(stats.size == 3, + require(stats.length == 3, s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + - s" but was given array of length ${stats.size}.") + s" but was given array of length ${stats.length}.") /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index bab7b8c6cadf2..9b60d018d0eda 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -18,8 +18,6 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.TreeEnsembleModel /** @@ -47,7 +45,7 @@ object AbsoluteError extends Loss { if (label - prediction < 0) 1.0 else -1.0 } - override private[mllib] def computeError(prediction: Double, label: Double): Double = { + override private[spark] def computeError(prediction: Double, label: Double): Double = { val err = label - prediction math.abs(err) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index b2b4594712f0d..5d92ce495b04d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -18,8 +18,6 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils @@ -49,7 +47,7 @@ object LogLoss extends Loss { - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } - override private[mllib] def computeError(prediction: Double, label: Double): Double = { + override private[spark] def computeError(prediction: Double, label: Double): Double = { val margin = 2.0 * label * prediction // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. 2.0 * MLUtils.log1pExp(-margin) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 687cde325ffed..de14ddf024d75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -61,5 +61,5 @@ trait Loss extends Serializable { * @param label True label. * @return Measure of model error on datapoint. */ - private[mllib] def computeError(prediction: Double, label: Double): Double + private[spark] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 3f7d3d38be16c..4eb6810c46b20 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -18,8 +18,6 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.TreeEnsembleModel /** @@ -47,7 +45,7 @@ object SquaredError extends Loss { - 2.0 * (label - prediction) } - override private[mllib] def computeError(prediction: Double, label: Double): Double = { + override private[spark] def computeError(prediction: Double, label: Double): Double = { val err = label - prediction err * err } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala deleted file mode 100644 index 0cad473782af1..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import org.apache.spark.mllib.tree.configuration.FeatureType._ - -/** - * Used for "binning" the feature values for faster best split calculation. - * - * For a continuous feature, the bin is determined by a low and a high split, - * where an example with featureValue falls into the bin s.t. - * lowSplit.threshold < featureValue <= highSplit.threshold. - * - * For ordered categorical features, there is a 1-1-1 correspondence between - * bins, splits, and feature values. The bin is determined by category/feature value. - * However, the bins are not necessarily ordered by feature value; - * they are ordered using impurity. - * - * For unordered categorical features, there is a 1-1 correspondence between bins, splits, - * where bins and splits correspond to subsets of feature values (in highSplit.categories). - * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all - * partitionings of categories into 2 disjoint, non-empty sets. - * - * @param lowSplit signifying the lower threshold for the continuous feature to be - * accepted in the bin - * @param highSplit signifying the upper threshold for the continuous feature to be - * accepted in the bin - * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for ordered features - */ -private[tree] -case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 54c136aecf660..a87f8a6cde318 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -23,9 +23,10 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} import org.apache.spark.mllib.tree.configuration.Algo._ @@ -155,7 +156,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { feature: Int, threshold: Double, featureType: Int, - categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed + categories: Seq[Double]) { def toSplit: Split = { new Split(feature, threshold, FeatureType(featureType), categories.toList) } @@ -201,7 +202,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -242,15 +243,15 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[NodeData](dataRDD.schema) - val nodes = dataRDD.map(NodeData.apply) + val nodes = dataRDD.rdd.map(NodeData.apply) // Build node data into a tree. val trees = constructTrees(nodes) - assert(trees.size == 1, + assert(trees.length == 1, "Decision tree should contain exactly one tree but got ${trees.size} trees.") val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." + @@ -266,7 +267,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { .map { case (treeId, data) => (treeId, constructTree(data)) }.sortBy(_._1) - val numTrees = trees.size + val numTrees = trees.length val treeIndices = trees.map(_._1).toSeq assert(treeIndices == (0 until numTrees), s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 091a0462c204f..f3dbfd96e1815 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -68,18 +68,7 @@ class InformationGainStats( } } -private[spark] object InformationGainStats { - /** - * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to - * denote that current split doesn't satisfies minimum info gain or - * minimum number of instances per node. - */ - val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, - new Predict(0.0, 0.0), new Predict(0.0, 0.0)) -} - /** - * :: DeveloperApi :: * Impurity statistics for each split * @param gain information gain value * @param impurity current node impurity @@ -89,7 +78,6 @@ private[spark] object InformationGainStats { * @param valid whether the current split satisfies minimum info gain or * minimum number of instances per node */ -@DeveloperApi private[spark] class ImpurityStats( val gain: Double, val impurity: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index ea6e5aa5d94e7..5fd053647aa46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -18,9 +18,9 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.Logging -import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.FeatureType._ /** * :: DeveloperApi :: @@ -56,34 +56,13 @@ class Node @Since("1.2.0") ( s"split = $split, stats = $stats" } - /** - * build the left node and right nodes if not leaf - * @param nodes array of nodes - */ - @Since("1.0.0") - @deprecated("build should no longer be used since trees are constructed on-the-fly in training", - "1.2.0") - def build(nodes: Array[Node]): Unit = { - logDebug("building node " + id + " at level " + Node.indexToLevel(id)) - logDebug("id = " + id + ", split = " + split) - logDebug("stats = " + stats) - logDebug("predict = " + predict) - logDebug("impurity = " + impurity) - if (!isLeaf) { - leftNode = Some(nodes(Node.leftChildIndex(id))) - rightNode = Some(nodes(Node.rightChildIndex(id))) - leftNode.get.build(nodes) - rightNode.get.build(nodes) - } - } - /** * predict value if node is not leaf * @param features feature value * @return predicted value */ @Since("1.1.0") - def predict(features: Vector) : Double = { + def predict(features: Vector): Double = { if (isLeaf) { predict.predict } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index b85a66c05a81d..5cef9d0631b59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -19,8 +19,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType -import org.apache.spark.mllib.tree.configuration.FeatureType -import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 90e032e3d9842..cbf49b6d5821a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -24,9 +24,10 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Since +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo @@ -186,6 +187,7 @@ class GradientBoostedTreesModel @Since("1.2.0") ( object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { /** + * :: DeveloperApi :: * Compute the initial predictions and errors for a dataset for the first * iteration of gradient boosting. * @param data: training data. @@ -196,6 +198,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * corresponding to every sample. */ @Since("1.4.0") + @DeveloperApi def computeInitialPredictionAndError( data: RDD[LabeledPoint], initTreeWeight: Double, @@ -209,6 +212,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { } /** + * :: DeveloperApi :: * Update a zipped predictionError RDD * (as obtained with computeInitialPredictionAndError) * @param data: training data. @@ -220,6 +224,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * corresponding to each sample. */ @Since("1.4.0") + @DeveloperApi def updatePredictionError( data: RDD[LabeledPoint], predictionAndError: RDD[(Double, Double)], @@ -408,7 +413,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -468,8 +473,8 @@ private[tree] object TreeEnsembleModel extends Logging { path: String, treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) - val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) + val sqlContext = SQLContext.getOrCreate(sc) + val nodes = sqlContext.read.parquet(datapath).rdd.map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index dffe6e78939e8..2c712d8f821a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.util -import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.internal.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 6ff07eed6cfd2..58fd010e4905f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -24,7 +24,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{BLAS, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -131,42 +131,34 @@ object LinearDataGenerator { eps: Double, sparsity: Double): Seq[LabeledPoint] = { require(0.0 <= sparsity && sparsity <= 1.0) - val rnd = new Random(seed) - val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextDouble())) - - val sparseRnd = new Random(seed) - x.foreach { v => - var i = 0 - val len = v.length - while (i < len) { - if (sparseRnd.nextDouble() < sparsity) { - v(i) = 0.0 - } else { - v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) - } - i += 1 - } - } - val y = x.map { xi => - blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() - } + val rnd = new Random(seed) + def rndElement(i: Int) = {(rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)} - y.zip(x).map { p => - if (sparsity == 0.0) { + if (sparsity == 0.0) { + (0 until nPoints).map { _ => + val features = Vectors.dense(weights.indices.map { rndElement(_) }.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() // Return LabeledPoints with DenseVector - LabeledPoint(p._1, Vectors.dense(p._2)) - } else { + LabeledPoint(label, features) + } + } else { + (0 until nPoints).map { _ => + val indices = weights.indices.filter { _ => rnd.nextDouble() <= sparsity} + val values = indices.map { rndElement(_) } + val features = Vectors.sparse(weights.length, indices.toArray, values.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() // Return LabeledPoints with SparseVector - LabeledPoint(p._1, Vectors.dense(p._2).toSparse) + LabeledPoint(label, features) } } } /** * Generate an RDD containing sample data for Linear Regression models - including Ridge, Lasso, - * and uregularized variants. + * and unregularized variants. * * @param sc SparkContext to be used for generating the RDD. * @param nexamples Number of examples that will be contained in the RDD. @@ -183,7 +175,7 @@ object LinearDataGenerator { nfeatures: Int, eps: Double, nparts: Int = 2, - intercept: Double = 0.0) : RDD[LabeledPoint] = { + intercept: Double = 0.0): RDD[LabeledPoint] = { val random = new Random(42) // Random values distributed uniformly in [-0.5, 0.5] val w = Array.fill(nfeatures)(random.nextDouble() - 0.5) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 33477ee20ebbd..68835bc79677f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,11 +19,11 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 906bd30563bd0..898a09e51636c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -23,7 +23,7 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Since, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD @@ -105,8 +105,7 @@ object MFDataGenerator { // optionally generate testing data if (test) { - val testSampSize = math.min( - math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt + val testSampSize = math.min(math.round(sampSize * testSampFact).toInt, mn - sampSize) val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 414ea99cfd8c8..774170ff401e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.util import scala.reflect.ClassTag -import org.apache.spark.annotation.Since import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PartitionwiseSampledRDD -import org.apache.spark.util.random.BernoulliCellSampler -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.BernoulliCellSampler /** * Helper methods to load, save and pre-process data used in ML Lib. @@ -68,41 +67,14 @@ object MLUtils { path: String, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = { - val parsed = sc.textFile(path, minPartitions) - .map(_.trim) - .filter(line => !(line.isEmpty || line.startsWith("#"))) - .map { line => - val items = line.split(' ') - val label = items.head.toDouble - val (indices, values) = items.tail.filter(_.nonEmpty).map { item => - val indexAndValue = item.split(':') - val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. - val value = indexAndValue(1).toDouble - (index, value) - }.unzip - - // check if indices are one-based and in ascending order - var previous = -1 - var i = 0 - val indicesLength = indices.length - while (i < indicesLength) { - val current = indices(i) - require(current > previous, "indices should be one-based and in ascending order" ) - previous = current - i += 1 - } - - (label, indices.toArray, values.toArray) - } + val parsed = parseLibSVMFile(sc, path, minPartitions) // Determine number of features. val d = if (numFeatures > 0) { numFeatures } else { parsed.persist(StorageLevel.MEMORY_ONLY) - parsed.map { case (label, indices, values) => - indices.lastOption.getOrElse(0) - }.reduce(math.max) + 1 + computeNumFeatures(parsed) } parsed.map { case (label, indices, values) => @@ -110,17 +82,46 @@ object MLUtils { } } - // Convenient methods for `loadLibSVMFile`. + private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = { + rdd.map { case (label, indices, values) => + indices.lastOption.getOrElse(0) + }.reduce(math.max) + 1 + } - @Since("1.0.0") - @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") - def loadLibSVMFile( + private[spark] def parseLibSVMFile( sc: SparkContext, path: String, - multiclass: Boolean, - numFeatures: Int, - minPartitions: Int): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, numFeatures, minPartitions) + minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = { + sc.textFile(path, minPartitions) + .map(_.trim) + .filter(line => !(line.isEmpty || line.startsWith("#"))) + .map(parseLibSVMRecord) + } + + private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = { + val items = line.split(' ') + val label = items.head.toDouble + val (indices, values) = items.tail.filter(_.nonEmpty).map { item => + val indexAndValue = item.split(':') + val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. + val value = indexAndValue(1).toDouble + (index, value) + }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, s"indices should be one-based and in ascending order;" + + " found current=$current, previous=$previous; line=\"$line\"") + previous = current + i += 1 + } + + (label, indices.toArray, values.toArray) + } /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of @@ -133,23 +134,6 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) - @Since("1.0.0") - @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") - def loadLibSVMFile( - sc: SparkContext, - path: String, - multiclass: Boolean, - numFeatures: Int): RDD[LabeledPoint] = - loadLibSVMFile(sc, path, numFeatures) - - @Since("1.0.0") - @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") - def loadLibSVMFile( - sc: SparkContext, - path: String, - multiclass: Boolean): RDD[LabeledPoint] = - loadLibSVMFile(sc, path) - /** * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. @@ -216,48 +200,6 @@ object MLUtils { def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) - /** - * Load labeled data from a file. The data format used here is - * L, f1 f2 ... - * where f1, f2 are feature values in Double and L is the corresponding label as Double. - * - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is - * the label, and the second element represents the feature values (an array of Double). - * - * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and - * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. - */ - @Since("1.0.0") - @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") - def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { - sc.textFile(dir).map { line => - val parts = line.split(',') - val label = parts(0).toDouble - val features = Vectors.dense(parts(1).trim().split(' ').map(_.toDouble)) - LabeledPoint(label, features) - } - } - - /** - * Save labeled data to a file. The data format used here is - * L, f1 f2 ... - * where f1, f2 are feature values in Double and L is the corresponding label as Double. - * - * @param data An RDD of LabeledPoints containing data to be saved. - * @param dir Directory to save the data. - * - * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and - * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. - */ - @Since("1.0.0") - @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") - def saveLabeledData(data: RDD[LabeledPoint], dir: String) { - val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" ")) - dataStr.saveAsTextFile(dir) - } - /** * Return a k element array of pairs of RDDs with the first element of each pair * containing the training data, a complement of the validation data and the second @@ -265,6 +207,14 @@ object MLUtils { */ @Since("1.0.0") def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { + kFold(rdd, numFolds, seed.toLong) + } + + /** + * Version of [[kFold()]] taking a Long seed. + */ + @Since("2.0.0") + def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index a841c5caf0142..2c613348c2d92 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -98,7 +98,7 @@ private[mllib] object NumericParser { } } else if (token == ")") { parsing = false - } else if (token.trim.isEmpty){ + } else if (token.trim.isEmpty) { // ignore whitespaces between delim chars, e.g. ", [" } else { // expecting a number diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 0a8c9e5954676..60a4a1d2ea2af 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.ml; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -26,7 +28,6 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +38,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset dataset; @Before public void setUp() { @@ -65,7 +66,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 60f25e5cce437..1f23682621594 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -21,16 +21,17 @@ import java.util.HashMap; import java.util.Map; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; public class JavaDecisionTreeClassifierSuite implements Serializable { @@ -56,8 +57,8 @@ public void runDT() { JavaRDD data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - Map categoricalFeatures = new HashMap(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Map categoricalFeatures = new HashMap<>(); + Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. DecisionTreeClassifier dt = new DecisionTreeClassifier() diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java index 3c69467fa119e..74841058a21b0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -27,10 +27,11 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaGBTClassifierSuite implements Serializable { @@ -56,8 +57,8 @@ public void runDT() { JavaRDD data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - Map categoricalFeatures = new HashMap(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Map categoricalFeatures = new HashMap<>(); + Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. GBTClassifier rf = new GBTClassifier() diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index fd22eb6dca018..e160a5a47e304 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.ml.classification; import java.io.Serializable; -import java.lang.Math; import java.util.List; import org.junit.After; @@ -31,16 +30,16 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset dataset; private transient JavaRDD datasetRDD; private double eps = 1e-5; @@ -67,7 +66,7 @@ public void logisticRegressionDefaultParams() { Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults Assert.assertEquals(0.5, model.getThreshold(), eps); @@ -96,14 +95,14 @@ public void logisticRegressionWithSetters() { // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); - DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); + Dataset predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); for (Row r: predAllZero.collectAsList()) { Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); - DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + Dataset predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; for (Row r: predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; @@ -129,8 +128,8 @@ public void logisticRegressionPredictorClassifierMethods() { Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); - DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); - for (Row row: trans1.collect()) { + Dataset trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collectAsList()) { Vector raw = (Vector)row.get(0); Vector prob = (Vector)row.get(1); Assert.assertEquals(raw.size(), 2); @@ -140,8 +139,8 @@ public void logisticRegressionPredictorClassifierMethods() { Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } - DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); - for (Row row: trans2.collect()) { + Dataset trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collectAsList()) { double pred = row.getDouble(0); Vector prob = (Vector)row.get(1); double probOfPred = prob.apply((int)pred); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index ec6b4bf3c0f8c..bc955f3cf6b0f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -28,7 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -52,7 +53,7 @@ public void tearDown() { @Test public void testMLPC() { - DataFrame dataFrame = sqlContext.createDataFrame( + Dataset dataFrame = sqlContext.createDataFrame( jsc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), @@ -62,11 +63,11 @@ public void testMLPC() { MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() .setLayers(new int[] {2, 5, 2}) .setBlockSize(1) - .setSeed(11L) + .setSeed(123L) .setMaxIter(100); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); - DataFrame result = model.transform(dataFrame); - Row[] predictionAndLabels = result.select("prediction", "label").collect(); + Dataset result = model.transform(dataFrame); + List predictionAndLabels = result.select("prediction", "label").collectAsList(); for (Row r: predictionAndLabels) { Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index f5f690eabd12c..45101f286c6d2 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -26,11 +26,10 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -56,8 +55,8 @@ public void tearDown() { jsc = null; } - public void validatePrediction(DataFrame predictionAndLabels) { - for (Row r : predictionAndLabels.collect()) { + public void validatePrediction(Dataset predictionAndLabels) { + for (Row r : predictionAndLabels.collectAsList()) { double prediction = r.getAs(0); double label = r.getAs(1); assertEquals(label, prediction, 1E-5); @@ -89,11 +88,11 @@ public void testNaiveBayes() { new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset dataset = jsql.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); - DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); + Dataset predictionAndLabels = model.transform(dataset).select("prediction", "label"); validatePrediction(predictionAndLabels); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index cbabafe1b541d..00f4476841af1 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -20,6 +20,7 @@ import java.io.Serializable; import java.util.List; +import org.apache.spark.sql.Row; import scala.collection.JavaConverters; import org.junit.After; @@ -31,14 +32,14 @@ import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.SQLContext; public class JavaOneVsRestSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset dataset; private transient JavaRDD datasetRDD; @Before @@ -47,7 +48,8 @@ public void setUp() { jsql = new SQLContext(jsc); int nPoints = 3; - // The following coefficients and xMean/xVariance are computed from iris dataset with lambda=0.2. + // The following coefficients and xMean/xVariance are computed from iris dataset with + // lambda=0.2. // As a result, we are drawing samples from probability distribution of an actual model. double[] coefficients = { -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, @@ -75,7 +77,7 @@ public void oneVsRestDefaultParams() { Assert.assertEquals(ova.getLabelCol() , "label"); Assert.assertEquals(ova.getPredictionCol() , "prediction"); OneVsRestModel ovaModel = ova.fit(dataset); - DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction"); + Dataset predictions = ovaModel.transform(dataset).select("label", "prediction"); predictions.collectAsList(); Assert.assertEquals(ovaModel.getLabelCol(), "label"); Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index a66a1e12927be..5aec52ac72b18 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -22,16 +22,18 @@ import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaRandomForestClassifierSuite implements Serializable { @@ -57,8 +59,8 @@ public void runDT() { JavaRDD data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - Map categoricalFeatures = new HashMap(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Map categoricalFeatures = new HashMap<>(); + Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. RandomForestClassifier rf = new RandomForestClassifier() @@ -79,6 +81,24 @@ public void runDT() { for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } + String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; + for (String strategy: realStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String integerStrategies[] = {"1", "10", "100", "1000", "10000"}; + for (String strategy: integerStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; + for (String strategy: invalidStrategies) { + try { + rf.setFeatureSubsetStrategy(strategy); + Assert.fail("Expected exception to be thrown for invalid strategies"); + } catch (Exception e) { + Assert.assertTrue(e instanceof IllegalArgumentException); + } + } + RandomForestClassificationModel model = rf.fit(dataFrame); model.transform(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java index d09fa7fd5637c..a3fcdb54ee7ad 100644 --- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -24,20 +24,20 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaKMeansSuite implements Serializable { private transient int k = 5; private transient JavaSparkContext sc; - private transient DataFrame dataset; + private transient Dataset dataset; private transient SQLContext sql; @Before @@ -62,7 +62,7 @@ public void fitAndTransform() { Vector[] centers = model.clusterCenters(); assertEquals(k, centers.length); - DataFrame transformed = model.transform(dataset); + Dataset transformed = model.transform(dataset); List columns = Arrays.asList(transformed.columns()); List expectedColumns = Arrays.asList("features", "prediction"); for (String column: expectedColumns) { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index 8a1e5ef015659..77e3a489a93aa 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -18,15 +18,15 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -58,7 +58,7 @@ public void bucketizerTest() { StructType schema = new StructType(new StructField[] { new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame( + Dataset dataset = jsql.createDataFrame( Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), @@ -71,7 +71,7 @@ public void bucketizerTest() { .setOutputCol("result") .setSplits(splits); - Row[] result = bucketizer.transform(dataset).select("result").collect(); + List result = bucketizer.transform(dataset).select("result").collectAsList(); for (Row r : result) { double index = r.getDouble(0); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 39da47381b129..ed1ad4c3a316a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; import org.junit.After; @@ -25,12 +26,11 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -57,7 +57,7 @@ public void tearDown() { @Test public void javaCompatibilityTest() { double[] input = new double[] {1D, 2D, 3D, 4D}; - DataFrame dataset = jsql.createDataFrame( + Dataset dataset = jsql.createDataFrame( Arrays.asList(RowFactory.create(Vectors.dense(input))), new StructType(new StructField[]{ new StructField("vec", (new VectorUDT()), false, Metadata.empty()) @@ -70,8 +70,8 @@ public void javaCompatibilityTest() { .setInputCol("vec") .setOutputCol("resultVec"); - Row[] result = dct.transform(dataset).select("resultVec").collect(); - Vector resultVec = result[0].getAs("resultVec"); + List result = dct.transform(dataset).select("resultVec").collectAsList(); + Vector resultVec = result.get(0).getAs("resultVec"); Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index d12332c2a02a3..6e2cc7e8877c6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -25,10 +25,9 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -66,21 +65,21 @@ public void hashingTF() { new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceData = jsql.createDataFrame(data, schema); + Dataset sentenceData = jsql.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer() .setInputCol("sentence") .setOutputCol("words"); - DataFrame wordsData = tokenizer.transform(sentenceData); + Dataset wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); - DataFrame featurizedData = hashingTF.transform(wordsData); + Dataset featurizedData = hashingTF.transform(wordsData); IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); IDFModel idfModel = idf.fit(featurizedData); - DataFrame rescaledData = idfModel.transform(featurizedData); - for (Row r : rescaledData.select("features", "label").take(3)) { + Dataset rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").takeAsList(3)) { Vector features = r.getAs(0); Assert.assertEquals(features.size(), numFeatures); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index e17d549c5059b..5bbd9634b2c27 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -26,7 +26,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaNormalizerSuite { @@ -53,17 +54,17 @@ public void normalizer() { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) )); - DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); + Dataset dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); // Normalize each Vector using $L^2$ norm. - DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); + Dataset l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); l2NormData.count(); // Normalize each Vector using $L^\infty$ norm. - DataFrame lInfNormData = + Dataset lInfNormData = normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); lInfNormData.count(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index e8f329f9cf29e..1389d17e7e07a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -100,7 +100,7 @@ public VectorPair call(Tuple2 pair) { } ); - DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); + Dataset df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); PCAModel pca = new PCA() .setInputCol("features") .setOutputCol("pca_features") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index bf8eefd71905c..6a8bb6480174a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -25,12 +25,11 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -78,11 +77,11 @@ public void polynomialExpansionTest() { new StructField("expected", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset dataset = jsql.createDataFrame(data, schema); - Row[] pairs = polyExpansion.transform(dataset) + List pairs = polyExpansion.transform(dataset) .select("polyFeatures", "expected") - .collect(); + .collectAsList(); for (Row r : pairs) { double[] polyFeatures = ((Vector)r.get(0)).toArray(); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index ed74363f59e34..3f6fc333e4e13 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -26,7 +26,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaStandardScalerSuite { @@ -53,7 +54,7 @@ public void standardScaler() { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) ); - DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + Dataset dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), VectorIndexerSuite.FeatureData.class); StandardScaler scaler = new StandardScaler() .setInputCol("features") @@ -65,7 +66,7 @@ public void standardScaler() { StandardScalerModel scalerModel = scaler.fit(dataFrame); // Normalize each feature to have unit standard deviation. - DataFrame scaledData = scalerModel.transform(dataFrame); + Dataset scaledData = scalerModel.transform(dataFrame); scaledData.count(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java index 848d9f8aa9288..bdcbde5e26223 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -24,9 +24,8 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -64,9 +63,10 @@ public void javaCompatibilityTest() { RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) ); StructType schema = new StructType(new StructField[] { - new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, + Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset dataset = jsql.createDataFrame(data, schema); remover.transform(dataset).collect(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 6b2c48ef1c342..431779cd2e72e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -25,9 +25,8 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -58,21 +57,21 @@ public void testStringIndexer() { createStructField("label", StringType, false) }); List data = Arrays.asList( - c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")); - DataFrame dataset = sqlContext.createDataFrame(data, schema); + cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); + Dataset dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex"); - DataFrame output = indexer.fit(dataset).transform(dataset); + Dataset output = indexer.fit(dataset).transform(dataset); - Assert.assertArrayEquals( - new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, - output.orderBy("id").select("id", "labelIndex").collect()); + Assert.assertEquals( + Arrays.asList(cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0)), + output.orderBy("id").select("id", "labelIndex").collectAsList()); } /** An alias for RowFactory.create. */ - private Row c(Object... values) { + private Row cr(Object... values) { return RowFactory.create(values); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 02309ce63219a..83d16cbd0e7a1 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -26,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -53,6 +54,7 @@ public void regexTokenizer() { .setOutputCol("tokens") .setPattern("\\s") .setGaps(true) + .setToLowercase(false) .setMinTokenLength(3); @@ -60,11 +62,11 @@ public void regexTokenizer() { new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) )); - DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + Dataset dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); - Row[] pairs = myRegExTokenizer.transform(dataset) + List pairs = myRegExTokenizer.transform(dataset) .select("tokens", "wantedTokens") - .collect(); + .collectAsList(); for (Row r : pairs) { Assert.assertEquals(r.get(0), r.get(1)); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index e283777570930..e45e19804345a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -24,12 +24,11 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -65,11 +64,11 @@ public void testVectorAssembler() { Row row = RowFactory.create( 0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); - DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); + Dataset dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() .setInputCols(new String[] {"x", "y", "z", "n"}) .setOutputCol("features"); - DataFrame output = assembler.transform(dataset); + Dataset output = assembler.transform(dataset); Assert.assertEquals( Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), output.select("features").first().getAs(0)); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index bfcca62fa1c98..fec6cac8bec30 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -57,7 +58,7 @@ public void vectorIndexerAPI() { new FeatureData(Vectors.dense(1.0, 4.0)) ); SQLContext sqlContext = new SQLContext(sc); - DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + Dataset data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -66,6 +67,6 @@ public void vectorIndexerAPI() { Assert.assertEquals(model.numFeatures(), 2); Map> categoryMaps = model.javaCategoryMaps(); Assert.assertEquals(categoryMaps.size(), 1); - DataFrame indexedData = model.transform(data); + Dataset indexedData = model.transform(data); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index 00174e6a683d6..e2da11183b93f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -25,14 +25,13 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.attribute.Attribute; import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -69,16 +68,17 @@ public void vectorSlice() { RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) ); - DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + Dataset dataset = + jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); - DataFrame output = vectorSlicer.transform(dataset); + Dataset output = vectorSlicer.transform(dataset); - for (Row r : output.select("userFeatures", "features").take(2)) { + for (Row r : output.select("userFeatures", "features").takeAsList(2)) { Vector features = r.getAs(1); Assert.assertEquals(features.size(), 2); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index 0c0c1c4d12d0f..7517b70cc9bee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -24,10 +24,9 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -54,7 +53,7 @@ public void testJavaWord2Vec() { StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - DataFrame documentDF = sqlContext.createDataFrame( + Dataset documentDF = sqlContext.createDataFrame( Arrays.asList( RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), @@ -67,9 +66,9 @@ public void testJavaWord2Vec() { .setVectorSize(3) .setMinCount(0); Word2VecModel model = word2Vec.fit(documentDF); - DataFrame result = model.transform(documentDF); + Dataset result = model.transform(documentDF); - for (Row r: result.select("result").collect()) { + for (Row r: result.select("result").collectAsList()) { double[] polyFeatures = ((Vector)r.get(0)).toArray(); Assert.assertEquals(polyFeatures.length, 3); } diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 65841182df9b4..06f7fbb86e88e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -89,7 +89,7 @@ private void init() { myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); List validStrings = Arrays.asList("a", "b"); - myStringParam_ = new Param(this, "myStringParam", "this is a string param", + myStringParam_ = new Param<>(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); myDoubleArrayParam_ = new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index ebe800e749e05..fa3b28ed4f302 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -27,10 +27,11 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaDecisionTreeRegressorSuite implements Serializable { @@ -56,8 +57,8 @@ public void runDT() { JavaRDD data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - Map categoricalFeatures = new HashMap(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Map categoricalFeatures = new HashMap<>(); + Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); // This tests setters. Training with various options is tested in Scala. DecisionTreeRegressor dt = new DecisionTreeRegressor() diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index fc8c13db07e6f..8413ea0e0a940 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -27,10 +27,11 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaGBTRegressorSuite implements Serializable { @@ -56,8 +57,8 @@ public void runDT() { JavaRDD data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - Map categoricalFeatures = new HashMap(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Map categoricalFeatures = new HashMap<>(); + Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); GBTRegressor rf = new GBTRegressor() .setMaxDepth(2) diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 4fb0b0d1092b6..9f817515eb86d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -28,7 +28,8 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite .generateLogisticInputAsList; @@ -38,7 +39,7 @@ public class JavaLinearRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset dataset; private transient JavaRDD datasetRDD; @Before @@ -64,7 +65,7 @@ public void linearRegressionDefaultParams() { assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); + Dataset predictions = jsql.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults assertEquals("features", model.getFeaturesCol()); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index a00ce5e249c34..a8736669f72e7 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -22,16 +22,18 @@ import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.classification.LogisticRegressionSuite; -import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaRandomForestRegressorSuite implements Serializable { @@ -57,8 +59,8 @@ public void runDT() { JavaRDD data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - Map categoricalFeatures = new HashMap(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Map categoricalFeatures = new HashMap<>(); + Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); // This tests setters. Training with various options is tested in Scala. RandomForestRegressor rf = new RandomForestRegressor() @@ -79,6 +81,24 @@ public void runDT() { for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } + String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; + for (String strategy: realStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String integerStrategies[] = {"1", "10", "100", "1000", "10000"}; + for (String strategy: integerStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; + for (String strategy: invalidStrategies) { + try { + rf.setFeatureSubsetStrategy(strategy); + Assert.fail("Expected exception to be thrown for invalid strategies"); + } catch (Exception e) { + Assert.assertTrue(e instanceof IllegalArgumentException); + } + } + RandomForestRegressionModel model = rf.fit(dataFrame); model.transform(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 2976b38e45031..1c18b2b266fef 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -19,8 +19,8 @@ import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; -import com.google.common.base.Charsets; import com.google.common.io.Files; import org.junit.After; @@ -31,7 +31,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.util.Utils; @@ -55,7 +55,7 @@ public void setUp() throws IOException { tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; - Files.write(s, file, Charsets.US_ASCII); + Files.write(s, file, StandardCharsets.UTF_8); path = tempDir.toURI().toString(); } @@ -68,7 +68,7 @@ public void tearDown() { @Test public void verifyLibSVMDF() { - DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + Dataset dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 08eeca53f0721..24b0097454fe0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +39,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset dataset; @Before public void setUp() { diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java new file mode 100644 index 0000000000000..01ff1ea658610 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util; + +import java.io.File; +import java.io.IOException; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + +public class JavaDefaultReadWriteSuite { + + JavaSparkContext jsc = null; + SQLContext sqlContext = null; + File tempDir = null; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); + SQLContext.clearActive(); + sqlContext = new SQLContext(jsc); + SQLContext.setActive(sqlContext); + tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); + } + + @After + public void tearDown() { + sqlContext = null; + SQLContext.clearActive(); + if (jsc != null) { + jsc.stop(); + jsc = null; + } + Utils.deleteRecursively(tempDir); + } + + @Test + public void testDefaultReadWrite() throws IOException { + String uid = "my_params"; + MyParams instance = new MyParams(uid); + instance.set(instance.intParam(), 2); + String outputPath = new File(tempDir, uid).getPath(); + instance.save(outputPath); + try { + instance.save(outputPath); + Assert.fail( + "Write without overwrite enabled should fail if the output directory already exists."); + } catch (IOException e) { + // expected + } + instance.write().context(sqlContext).overwrite().save(outputPath); + MyParams newInstance = MyParams.load(outputPath); + Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); + Assert.assertEquals("Params should be preserved.", + 2, newInstance.getOrDefault(newInstance.intParam())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index c9e5ee22f3273..62c6d9b7e390a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -66,8 +66,8 @@ public void javaAPI() { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java new file mode 100644 index 0000000000000..a714620ff7e4b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaBisectingKMeansSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", this.getClass().getSimpleName()); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void twoDimensionalData() { + JavaRDD points = sc.parallelize(Lists.newArrayList( + Vectors.dense(4, -1), + Vectors.dense(4, 1), + Vectors.sparse(2, new int[] {0}, new double[] {1.0}) + ), 2); + + BisectingKMeans bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(2) + .setSeed(1L); + BisectingKMeansModel model = bkm.run(points); + Assert.assertEquals(3, model.k()); + Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); + for (ClusteringTreeNode child: model.root().children()) { + double[] center = child.center().toArray(); + if (center[0] > 2) { + Assert.assertEquals(2, child.size()); + Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); + } else { + Assert.assertEquals(1, child.size()); + Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); + } + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 3fea359a3b46c..db19b309f65ae 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -45,9 +45,9 @@ public class JavaLDASuite implements Serializable { @Before public void setUp() { sc = new JavaSparkContext("local", "JavaLDA"); - ArrayList> tinyCorpus = new ArrayList>(); + ArrayList> tinyCorpus = new ArrayList<>(); for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2((Long)LDASuite.tinyCorpus()[i]._1(), + tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(), LDASuite.tinyCorpus()[i]._2())); } JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); @@ -144,7 +144,7 @@ public Boolean call(Tuple2 tuple2) { } @Test - public void OnlineOptimizerCompatibility() { + public void onlineOptimizerCompatibility() { int k = 3; double topicSmoothing = 1.2; double termSmoothing = 1.2; @@ -189,8 +189,8 @@ public void localLdaMethods() { double logPerplexity = toyModel.logPerplexity(pairedDocs); // check: logLikelihood. - ArrayList> docsSingleWord = new ArrayList>(); - docsSingleWord.add(new Tuple2(0L, Vectors.dense(1.0, 0.0, 0.0))); + ArrayList> docsSingleWord = new ArrayList<>(); + docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0))); JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); double logLikelihood = toyModel.logLikelihood(single); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index d644766d1e54d..62edbd3a298c0 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -66,8 +66,8 @@ public void javaAPI() { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingKMeans skmeans = new StreamingKMeans() diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 154f75d75e4a6..916fff14a7214 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.fpm; +import java.io.File; import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -28,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.util.Utils; public class JavaFPGrowthSuite implements Serializable { private transient JavaSparkContext sc; @@ -69,4 +71,44 @@ public void runFPGrowth() { long freq = itemset.freq(); } } + + @Test + public void runFPGrowthSaveLoad() { + + @SuppressWarnings("unchecked") + JavaRDD> rdd = sc.parallelize(Arrays.asList( + Arrays.asList("r z h k p".split(" ")), + Arrays.asList("z y x w v u t s".split(" ")), + Arrays.asList("s x o n r".split(" ")), + Arrays.asList("x z y m t s q e".split(" ")), + Arrays.asList("z".split(" ")), + Arrays.asList("x z y r q t p".split(" "))), 2); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd); + + File tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite"); + String outputPath = tempDir.getPath(); + + try { + model.save(sc.sc(), outputPath); + @SuppressWarnings("unchecked") + FPGrowthModel newModel = + (FPGrowthModel) FPGrowthModel.load(sc.sc(), outputPath); + List> freqItemsets = newModel.freqItemsets().toJavaRDD() + .collect(); + assertEquals(18, freqItemsets.size()); + + for (FPGrowth.FreqItemset itemset: freqItemsets) { + // Test return types. + List items = itemset.javaItems(); + long freq = itemset.freq(); + } + } finally { + Utils.deleteRecursively(tempDir); + } + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java index 34daf5fbde80f..8a67793abc142 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.fpm; +import java.io.File; import java.util.Arrays; import java.util.List; @@ -28,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; +import org.apache.spark.util.Utils; public class JavaPrefixSpanSuite { private transient JavaSparkContext sc; @@ -64,4 +66,39 @@ public void runPrefixSpan() { long freq = freqSeq.freq(); } } + + @Test + public void runPrefixSpanSaveLoad() { + JavaRDD>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) + ), 2); + PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); + PrefixSpanModel model = prefixSpan.run(sequences); + + File tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite"); + String outputPath = tempDir.getPath(); + + try { + model.save(sc.sc(), outputPath); + PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath); + JavaRDD> freqSeqs = newModel.freqSequences().toJavaRDD(); + List> localFreqSeqs = freqSeqs.collect(); + Assert.assertEquals(5, localFreqSeqs.size()); + // Check that each frequent sequence could be materialized. + for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + List> seq = freqSeq.javaSequence(); + long freq = freqSeq.freq(); + } + } finally { + Utils.deleteRecursively(tempDir); + } + + + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 77c8c6274f374..4ba8e543a9a6b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -37,8 +37,8 @@ public void denseArrayConstruction() { public void sparseArrayConstruction() { @SuppressWarnings("unchecked") Vector v = Vectors.sparse(3, Arrays.asList( - new Tuple2(0, 2.0), - new Tuple2(2, 3.0))); + new Tuple2<>(0, 2.0), + new Tuple2<>(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index 5728df5aeebdc..be58691f4d87e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -166,7 +166,7 @@ public void testNormalVectorRDD() { @SuppressWarnings("unchecked") public void testLogNormalVectorRDD() { double mean = 4.0; - double std = 2.0; + double std = 2.0; long m = 100L; int n = 10; int p = 2; diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index 271dda4662e0d..d0bf7f556dcc0 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -24,7 +24,6 @@ import scala.Tuple2; import scala.Tuple3; -import org.jblas.DoubleMatrix; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -48,18 +47,18 @@ public void tearDown() { sc = null; } - void validatePrediction( + private void validatePrediction( MatrixFactorizationModel model, int users, int products, - DoubleMatrix trueRatings, + double[] trueRatings, double matchThreshold, boolean implicitPrefs, - DoubleMatrix truePrefs) { - List> localUsersProducts = new ArrayList(users * products); + double[] truePrefs) { + List> localUsersProducts = new ArrayList<>(users * products); for (int u=0; u < users; ++u) { for (int p=0; p < products; ++p) { - localUsersProducts.add(new Tuple2(u, p)); + localUsersProducts.add(new Tuple2<>(u, p)); } } JavaPairRDD usersProducts = sc.parallelizePairs(localUsersProducts); @@ -68,7 +67,7 @@ void validatePrediction( if (!implicitPrefs) { for (Rating r: predictedRatings) { double prediction = r.rating(); - double correct = trueRatings.get(r.user(), r.product()); + double correct = trueRatings[r.product() * users + r.user()]; Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); } @@ -79,9 +78,9 @@ void validatePrediction( double denom = 0.0; for (Rating r: predictedRatings) { double prediction = r.rating(); - double truePref = truePrefs.get(r.user(), r.product()); + double truePref = truePrefs[r.product() * users + r.user()]; double confidence = 1.0 + - /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product())); + /* alpha = 1.0 * ... */ Math.abs(trueRatings[r.product() * users + r.user()]); double err = confidence * (truePref - prediction) * (truePref - prediction); sqErr += err; denom += confidence; @@ -98,8 +97,8 @@ public void runALSUsingStaticMethods() { int iterations = 15; int users = 50; int products = 100; - Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, false, false); + Tuple3, double[], double[]> testData = + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); @@ -112,8 +111,8 @@ public void runALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, false, false); + Tuple3, double[], double[]> testData = + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); JavaRDD data = sc.parallelize(testData._1()); @@ -129,8 +128,8 @@ public void runImplicitALSUsingStaticMethods() { int iterations = 15; int users = 80; int products = 160; - Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, true, false); + Tuple3, double[], double[]> testData = + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); @@ -143,8 +142,8 @@ public void runImplicitALSUsingConstructor() { int iterations = 15; int users = 100; int products = 200; - Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, true, false); + Tuple3, double[], double[]> testData = + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); JavaRDD data = sc.parallelize(testData._1()); @@ -161,8 +160,8 @@ public void runImplicitALSWithNegativeWeight() { int iterations = 15; int users = 80; int products = 160; - Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, true, true); + Tuple3, double[], double[]> testData = + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) @@ -179,9 +178,9 @@ public void runRecommend() { int iterations = 10; int users = 200; int products = 50; - Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7, true, false); - JavaRDD data = sc.parallelize(testData._1()); + List testData = ALSSuite.generateRatingsAsJava( + users, products, features, 0.7, true, false)._1(); + JavaRDD data = sc.parallelize(testData); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .setImplicitPrefs(true) diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index 32c2f4f3395b7..3db9b39e740e7 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -36,11 +36,11 @@ public class JavaIsotonicRegressionSuite implements Serializable { private transient JavaSparkContext sc; - private List> generateIsotonicInput(double[] labels) { - ArrayList> input = new ArrayList(labels.length); + private static List> generateIsotonicInput(double[] labels) { + List> input = new ArrayList<>(labels.length); for (int i = 1; i <= labels.length; i++) { - input.add(new Tuple3(labels[i-1], (double) i, 1d)); + input.add(new Tuple3<>(labels[i-1], (double) i, 1.0)); } return input; @@ -70,7 +70,7 @@ public void testIsotonicRegressionJavaRDD() { runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); Assert.assertArrayEquals( - new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14); + new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14); } @Test @@ -81,10 +81,10 @@ public void testIsotonicRegressionPredictionsJavaRDD() { JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); - Assert.assertTrue(predictions.get(0) == 1d); - Assert.assertTrue(predictions.get(1) == 1d); - Assert.assertTrue(predictions.get(2) == 10d); - Assert.assertTrue(predictions.get(3) == 12d); - Assert.assertTrue(predictions.get(4) == 12d); + Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14); + Assert.assertEquals(1.0, predictions.get(1).doubleValue(), 1.0e-14); + Assert.assertEquals(10.0, predictions.get(2).doubleValue(), 1.0e-14); + Assert.assertEquals(12.0, predictions.get(3).doubleValue(), 1.0e-14); + Assert.assertEquals(12.0, predictions.get(4).doubleValue(), 1.0e-14); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java index 7266eec235800..c56db703ea0b4 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -19,14 +19,13 @@ import java.io.Serializable; import java.util.List; +import java.util.Random; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.jblas.DoubleMatrix; - import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.util.LinearDataGenerator; @@ -45,7 +44,8 @@ public void tearDown() { sc = null; } - double predictionError(List validationData, RidgeRegressionModel model) { + private static double predictionError(List validationData, + RidgeRegressionModel model) { double errorSum = 0; for (LabeledPoint point: validationData) { Double prediction = model.predict(point.features()); @@ -54,11 +54,14 @@ public void tearDown() { return errorSum / validationData.size(); } - List generateRidgeData(int numPoints, int numFeatures, double std) { - org.jblas.util.Random.seed(42); + private static List generateRidgeData(int numPoints, int numFeatures, double std) { // Pick weights as random values distributed uniformly in [-0.5, 0.5] - DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5); - return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std); + Random random = new Random(42); + double[] w = new double[numFeatures]; + for (int i = 0; i < w.length; i++) { + w[i] = random.nextDouble() - 0.5; + } + return LinearDataGenerator.generateLinearInputAsList(0.0, w, numPoints, 42, std); } @Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index dbf6488d41085..ea0ccd7448986 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -65,8 +65,8 @@ public void javaAPI() { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 4795809e47a46..66b2ceacb05f2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -18,34 +18,49 @@ package org.apache.spark.mllib.stat; import java.io.Serializable; - import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.apache.spark.streaming.JavaTestUtils.*; import static org.junit.Assert.assertEquals; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.BinarySample; import org.apache.spark.mllib.stat.test.ChiSqTestResult; import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; +import org.apache.spark.mllib.stat.test.StreamingTest; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; + private transient JavaStreamingContext ssc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaStatistics"); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("JavaStatistics") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + sc = new JavaSparkContext(conf); + ssc = new JavaStreamingContext(sc, new Duration(1000)); + ssc.checkpoint("checkpoint"); } @After public void tearDown() { - sc.stop(); + ssc.stop(); + ssc = null; sc = null; } @@ -76,4 +91,21 @@ public void chiSqTest() { new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); ChiSqTestResult[] testResults = Statistics.chiSqTest(data); } + + @Test + public void streamingTest() { + List trainingBatch = Arrays.asList( + new BinarySample(true, 1.0), + new BinarySample(false, 2.0)); + JavaDStream training = + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + int numBatches = 2; + StreamingTest model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod("welch"); + model.registerStream(training); + attachTestOutputStream(training); + runStreams(ssc, numBatches, numBatches); + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 9925aae441af9..8dd29061daaad 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -64,7 +64,7 @@ int validatePrediction(List validationData, DecisionTreeModel mode public void runDTUsingConstructor() { List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); JavaRDD rdd = sc.parallelize(arr); - HashMap categoricalFeaturesInfo = new HashMap(); + HashMap categoricalFeaturesInfo = new HashMap<>(); categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories int maxDepth = 4; @@ -84,7 +84,7 @@ public void runDTUsingConstructor() { public void runDTUsingStaticMethods() { List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); JavaRDD rdd = sc.parallelize(arr); - HashMap categoricalFeaturesInfo = new HashMap(); + HashMap categoricalFeaturesInfo = new HashMap<>(); categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories int maxDepth = 4; diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 75e3b53a093f6..fd51f8faf56b9 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 1f2c9b75b617b..a8c4ac6d0561d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -19,17 +19,22 @@ package org.apache.spark.ml import scala.collection.JavaConverters._ +import org.apache.hadoop.fs.Path import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.Pipeline.SharedReadWrite +import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.StructType -class PipelineSuite extends SparkFunSuite { +class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { abstract class MyModel extends Model[MyModel] @@ -46,6 +51,12 @@ class PipelineSuite extends SparkFunSuite { val dataset3 = mock[DataFrame] val dataset4 = mock[DataFrame] + when(dataset0.toDF).thenReturn(dataset0) + when(dataset1.toDF).thenReturn(dataset1) + when(dataset2.toDF).thenReturn(dataset2) + when(dataset3.toDF).thenReturn(dataset3) + when(dataset4.toDF).thenReturn(dataset4) + when(estimator0.copy(any[ParamMap])).thenReturn(estimator0) when(model0.copy(any[ParamMap])).thenReturn(model0) when(transformer1.copy(any[ParamMap])).thenReturn(transformer1) @@ -111,4 +122,125 @@ class PipelineSuite extends SparkFunSuite { assert(pipelineModel1.uid === "pipeline1") assert(pipelineModel1.stages === stages) } + + test("Pipeline read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = new Pipeline().setStages(Array(writableStage)) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.getStages.length === 1) + assert(pipeline2.getStages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + } + + test("Pipeline read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage)) + withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } + + test("PipelineModel read/write") { + val writableStage = new WritableStage("writableStage").setIntParam(56) + val pipeline = + new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer])) + + val pipeline2 = testDefaultReadWrite(pipeline, testParams = false) + assert(pipeline2.stages.length === 1) + assert(pipeline2.stages(0).isInstanceOf[WritableStage]) + val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage] + assert(writableStage.getIntParam === writableStage2.getIntParam) + } + + test("PipelineModel read/write: getStagePath") { + val stageUid = "myStage" + val stagesDir = new Path("pipeline", "stages").toString + def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = { + val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir) + val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString + assert(path === expected) + } + testStage(0, 1, "0") + testStage(0, 9, "0") + testStage(0, 10, "00") + testStage(1, 10, "01") + testStage(12, 999, "012") + } + + test("PipelineModel read/write with non-Writable stage") { + val unWritableStage = new UnWritableStage("unwritableStage") + val unWritablePipeline = + new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer])) + withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") { + intercept[UnsupportedOperationException] { + unWritablePipeline.write + } + } + } + + test("pipeline validateParams") { + val df = sqlContext.createDataFrame( + Seq( + (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "features", "label") + + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("features_scaled") + .setMin(10) + .setMax(0) + val pipeline = new Pipeline().setStages(Array(scaler)) + pipeline.fit(df) + } + } +} + + +/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +class WritableStage(override val uid: String) extends Transformer with MLWritable { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + def getIntParam: Int = $(intParam) + + def setIntParam(value: Int): this.type = set(intParam, value) + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) + + override def write: MLWriter = new DefaultParamsWriter(this) + + override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF + + override def transformSchema(schema: StructType): StructType = schema +} + +object WritableStage extends MLReadable[WritableStage] { + + override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage] + + override def load(path: String): WritableStage = super.load(path) +} + +/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ +class UnWritableStage(override val uid: String) extends Transformer { + + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + + setDefault(intParam -> 0) + + override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) + + override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF + + override def transformSchema(schema: StructType): StructType = schema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala index 1292e57d7c01a..dc91fc5f9e458 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -42,7 +42,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) - val initialWeights = FeedForwardModel(topology, 23124).weights() + val initialWeights = FeedForwardModel(topology, 23124).weights val trainer = new FeedForwardTrainer(topology, 2, 1) trainer.setWeights(initialWeights) trainer.LBFGSOptimizer.setNumIterations(20) @@ -76,10 +76,11 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) - val initialWeights = FeedForwardModel(topology, 23124).weights() + val initialWeights = FeedForwardModel(topology, 23124).weights val trainer = new FeedForwardTrainer(topology, 2, 2) - trainer.SGDOptimizer.setNumIterations(2000) - trainer.setWeights(initialWeights) + // TODO: add a test for SGD + trainer.LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(20) + trainer.setWeights(initialWeights).setStackSize(1) val model = trainer.train(rddData) val predictionAndLabels = rddData.map { case (input, label) => (model.predict(input), label) diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala new file mode 100644 index 0000000000000..04cc426c40b5e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class GradientSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("Gradient computation against numerical differentiation") { + val input = new BDM[Double](3, 1, Array(1.0, 1.0, 1.0)) + // output must contain zeros and one 1 for SoftMax + val target = new BDM[Double](2, 1, Array(0.0, 1.0)) + val topology = FeedForwardTopology.multiLayerPerceptron(Array(3, 4, 2), softmaxOnTop = false) + val layersWithErrors = Seq( + new SigmoidLayerWithSquaredError(), + new SoftmaxLayerWithCrossEntropyLoss() + ) + // check all layers that provide loss computation + // 1) compute loss and gradient given the model and initial weights + // 2) modify weights with small number epsilon (per dimension i) + // 3) compute new loss + // 4) ((newLoss - loss) / epsilon) should be close to the i-th component of the gradient + for (layerWithError <- layersWithErrors) { + topology.layers(topology.layers.length - 1) = layerWithError + val model = topology.model(seed = 12L) + val weights = model.weights.toArray + val numWeights = weights.size + val gradient = Vectors.dense(Array.fill[Double](numWeights)(0.0)) + val loss = model.computeGradient(input, target, gradient, 1) + val eps = 1e-4 + var i = 0 + val tol = 1e-4 + while (i < numWeights) { + val originalValue = weights(i) + weights(i) += eps + val newModel = topology.model(Vectors.dense(weights)) + val newLoss = computeLoss(input, target, newModel) + val derivativeEstimate = (newLoss - loss) / eps + assert(math.abs(gradient(i) - derivativeEstimate) < tol, "Layer failed gradient check: " + + layerWithError.getClass) + weights(i) = originalValue + i += 1 + } + } + } + + private def computeLoss(input: BDM[Double], target: BDM[Double], model: TopologyModel): Double = { + val outputs = model.forward(input) + model.layerModels.last match { + case layerWithLoss: LossFunction => + layerWithLoss.loss(outputs.last, target, new BDM[Double](target.rows, target.cols)) + case _ => + throw new UnsupportedOperationException("Top layer is required to have loss." + + " Failed layer:" + model.layerModels.last.getClass) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala new file mode 100644 index 0000000000000..d0e3fe7ad14b6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +object ClassifierSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "rawPredictionCol" -> "myRawPrediction" + ) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 815f6fd997584..fe839e15e9572 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -18,19 +18,19 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} -class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class DecisionTreeClassifierSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs @@ -72,7 +72,8 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setImpurity("gini") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + .setSeed(1) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) } @@ -174,7 +175,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } test("Multiclass classification tree with 10-ary (ordered) categorical features," + - " with just enough bins") { + " with just enough bins") { val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD val dt = new DecisionTreeClassifier() .setImpurity("Gini") @@ -213,7 +214,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setMaxBins(2) .setMaxDepth(2) .setMinInstancesPerNode(2) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val numClasses = 2 compareAPIs(rdd, dt, categoricalFeatures, numClasses) } @@ -271,32 +272,108 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte )) val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) val dt = new DecisionTreeClassifier().setMaxDepth(3) + dt.fit(df) + } + + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val data = sc.parallelize(arr) + val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(1) + .setMaxBins(3) + val model = dt.fit(df) + model.rootNode match { + case n: InternalNode => + n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + case other => + fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.") + } + case other => + fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.") + } + } + + test("Feature importance with toy data") { + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val numFeatures = data.first().features.size + val categoricalFeatures = (0 to numFeatures).map(i => (i, 2)).toMap + val df = TreeTests.setMetadata(data, categoricalFeatures, 2) + val model = dt.fit(df) + + val importances = model.featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + + test("should support all NumericType labels and not support other types") { + val dt = new DecisionTreeClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( + dt, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } } ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* - test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification) - val newModel = DecisionTreeClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = DecisionTreeClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + test("read/write") { + def checkModelData( + model: DecisionTreeClassificationModel, + model2: DecisionTreeClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + assert(model.numClasses === model2.numClasses) } + + val dt = new DecisionTreeClassifier() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy") + + // Categorical splits with tree depth 2 + val categoricalData: DataFrame = + TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) + testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData) + + // Continuous splits with tree depth 2 + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData) + + // Continuous splits with tree depth 0 + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), + checkModelData) } - */ } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 039141aeb6f67..7e6aec6b1bb80 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -31,11 +31,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.util.Utils - /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTClassifierSuite.compareAPIs @@ -74,6 +74,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setLossType("logistic") .setMaxIter(maxIter) .setStepSize(learningRate) + .setSeed(123) compareAPIs(data, None, gbt, categoricalFeatures) } } @@ -91,6 +92,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxIter(5) .setStepSize(0.1) .setCheckpointInterval(2) + .setSeed(123) val model = gbt.fit(df) // copied model must have the same parent. @@ -100,6 +102,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("should support all NumericType labels and not support other types") { + val gbt = new GBTClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier]( + gbt, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { @@ -118,31 +128,52 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } */ + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val numClasses = 2 + val gbt = new GBTClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setMaxIter(5) + .setSubsamplingRate(1.0) + .setStepSize(0.5) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) - val newModel = GBTClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = GBTClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTClassificationModel, + model2: GBTClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTClassifier() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTClassifierSuite extends SparkFunSuite { @@ -159,7 +190,7 @@ private object GBTClassifierSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val oldGBT = new OldGBT(oldBoostingStrategy) + val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 325faf37e8eea..48db4281309b7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -23,17 +23,19 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.lit -class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var binaryDataset: DataFrame = _ private val eps: Double = 1e-5 @@ -42,20 +44,6 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) - /* - Here is the instruction describing how to export the test data into CSV format - so we can validate the training accuracy compared with R's glmnet package. - - import org.apache.spark.mllib.classification.LogisticRegressionSuite - val nPoints = 10000 - val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) - val xMean = Array(5.843, 3.057, 3.758, 1.199) - val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( - coefficients, xMean, xVariance, true, nPoints, 42), 1) - data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " - + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") - */ binaryDataset = { val nPoints = 10000 val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) @@ -63,12 +51,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) val testData = - generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) + generateMultinomialLogisticInput(coefficients, xMean, xVariance, + addIntercept = true, nPoints, 42) sqlContext.createDataFrame(sc.parallelize(testData, 4)) } } + /** + * Enable the ignored test to export the dataset into CSV format, + * so we can validate the training accuracy compared with R's glmnet package. + */ + ignore("export test data into CSV format") { + binaryDataset.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/binaryDataset") + } + test("params") { ParamsSuite.checkParams(new LogisticRegression) val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0) @@ -98,6 +97,17 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.hasParent) } + test("empty probabilityCol") { + val lr = new LogisticRegression().setProbabilityCol("") + val model = lr.fit(dataset) + assert(model.hasSummary) + // Validate that we re-insert a probability column for evaluation + val fieldNames = model.summary.predictions.schema.fieldNames + assert(dataset.schema.fieldNames.toSet.subsetOf( + fieldNames.toSet)) + assert(fieldNames.exists(s => s.startsWith("probability_"))) + } + test("setThreshold, getThreshold") { val lr = new LogisticRegression // default @@ -722,7 +732,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model1 = trainer1.fit(binaryDataset) val model2 = trainer2.fit(binaryDataset) - val histogram = binaryDataset.map { case Row(label: Double, features: Vector) => label } + val histogram = binaryDataset.rdd.map { case Row(label: Double, features: Vector) => label } .treeAggregate(new MultiClassSummarizer)( seqOp = (c, v) => (c, v) match { case (classSummarizer: MultiClassSummarizer, label: Double) => classSummarizer.add(label) @@ -869,6 +879,88 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3) assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + } + test("logistic regression with all labels the same") { + val sameLabels = dataset + .withColumn("zeroLabel", lit(0.0)) + .withColumn("oneLabel", lit(1.0)) + + // fitIntercept=true + val lrIntercept = new LogisticRegression() + .setFitIntercept(true) + .setMaxIter(3) + + val allZeroInterceptModel = lrIntercept + .setLabelCol("zeroLabel") + .fit(sameLabels) + assert(allZeroInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3) + assert(allZeroInterceptModel.intercept === Double.NegativeInfinity) + assert(allZeroInterceptModel.summary.totalIterations === 0) + + val allOneInterceptModel = lrIntercept + .setLabelCol("oneLabel") + .fit(sameLabels) + assert(allOneInterceptModel.coefficients ~== Vectors.dense(0.0) absTol 1E-3) + assert(allOneInterceptModel.intercept === Double.PositiveInfinity) + assert(allOneInterceptModel.summary.totalIterations === 0) + + // fitIntercept=false + val lrNoIntercept = new LogisticRegression() + .setFitIntercept(false) + .setMaxIter(3) + + val allZeroNoInterceptModel = lrNoIntercept + .setLabelCol("zeroLabel") + .fit(sameLabels) + assert(allZeroNoInterceptModel.intercept === 0.0) + assert(allZeroNoInterceptModel.summary.totalIterations > 0) + + val allOneNoInterceptModel = lrNoIntercept + .setLabelCol("oneLabel") + .fit(sameLabels) + assert(allOneNoInterceptModel.intercept === 0.0) + assert(allOneNoInterceptModel.summary.totalIterations > 0) } + + test("read/write") { + def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients.toArray === model2.coefficients.toArray) + assert(model.numClasses === model2.numClasses) + assert(model.numFeatures === model2.numFeatures) + } + val lr = new LogisticRegression() + testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings, + checkModelData) + } + + test("should support all NumericType labels and not support other types") { + val lr = new LogisticRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( + lr, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients.toArray === actual.coefficients.toArray) + } + } +} + +object LogisticRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map( + "probabilityCol" -> "myProbability", + "thresholds" -> Array(0.4, 0.6), + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "threshold" -> 0.6 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 17db8c44777d4..80547fad6af8d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -18,38 +18,87 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Dataset, Row} -class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class MultilayerPerceptronClassifierSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("XOR function learning as binary classification problem with two outputs.") { - val dataFrame = sqlContext.createDataFrame(Seq( + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = sqlContext.createDataFrame(Seq( (Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0), (Vectors.dense(1.0, 1.0), 0.0)) ).toDF("features", "label") + } + + test("Input Validation") { + val mlpc = new MultilayerPerceptronClassifier() + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int]()) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](0, 1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1, 0)) + } + mlpc.setLayers(Array[Int](1, 1)) + } + + test("XOR function learning as binary classification problem with two outputs.") { val layers = Array[Int](2, 5, 2) val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) - .setSeed(11L) + .setSeed(123L) .setMaxIter(100) - val model = trainer.fit(dataFrame) - val result = model.transform(dataFrame) + val model = trainer.fit(dataset) + val result = model.transform(dataset) val predictionAndLabels = result.select("prediction", "label").collect() predictionAndLabels.foreach { case Row(p: Double, l: Double) => assert(p == l) } } - // TODO: implement a more rigorous test + test("Test setWeights by training restart") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(12L) + .setMaxIter(1) + .setTol(1e-6) + val initialWeights = trainer.fit(dataFrame).weights + trainer.setWeights(initialWeights.copy) + val weights1 = trainer.fit(dataFrame).weights + trainer.setWeights(initialWeights.copy) + val weights2 = trainer.fit(dataFrame).weights + assert(weights1 ~== weights2 absTol 10e-5, + "Training should produce the same weights given equal initial weights and number of steps") + } + test("3 class classification with 2 hidden layers") { val nPoints = 1000 @@ -61,8 +110,9 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + // the input seed is somewhat magic, to make this test pass val rdd = sc.parallelize(generateMultinomialLogisticInput( - coefficients, xMean, xVariance, true, nPoints, 42), 2) + coefficients, xMean, xVariance, true, nPoints, 1), 2) val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") val numClasses = 3 val numIterations = 100 @@ -70,13 +120,14 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) - .setSeed(11L) + .setSeed(11L) // currently this seed is ignored .setMaxIter(numIterations) val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") - .map { case Row(p: Double, l: Double) => (p, l) } + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map { + case Row(p: Double, l: Double) => (p, l) + } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) @@ -90,4 +141,37 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) } + + test("read/write: MultilayerPerceptronClassifier") { + val mlp = new MultilayerPerceptronClassifier() + .setLayers(Array(2, 3, 2)) + .setMaxIter(5) + .setBlockSize(2) + .setSeed(42) + .setTol(0.1) + .setFeaturesCol("myFeatures") + .setLabelCol("myLabel") + .setPredictionCol("myPrediction") + + testDefaultReadWrite(mlp, testParams = true) + } + + test("read/write: MultilayerPerceptronClassificationModel") { + val mlp = new MultilayerPerceptronClassifier().setLayers(Array(2, 3, 2)).setMaxIter(5) + val mlpModel = mlp.fit(dataset) + val newMlpModel = testDefaultReadWrite(mlpModel, testParams = true) + assert(newMlpModel.layers === mlpModel.layers) + assert(newMlpModel.weights === mlpModel.weights) + } + + test("should support all NumericType labels and not support other types") { + val layers = Array(3, 2) + val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) + MLTestingUtils.checkNumericTypes[ + MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( + mpc, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.layers === actual.layers) + assert(expected.weights === actual.weights) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 98bc9511163e7..80a46fc70c75d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} +import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.classification.NaiveBayesSuite._ -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Dataset, Row} + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + } def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { @@ -71,7 +86,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { model: NaiveBayesModel, modelType: String): Unit = { featureAndProbabilities.collect().foreach { - case Row(features: Vector, probability: Vector) => { + case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { case Multinomial => @@ -82,7 +97,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { throw new UnknownError(s"Invalid modelType: $modelType.") } assert(probability ~== expected relTol 1.0e-10) - } } } @@ -161,4 +175,35 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "bernoulli") } + + test("read/write") { + def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = { + assert(model.pi === model2.pi) + assert(model.theta === model2.theta) + } + val nb = new NaiveBayes() + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + } + + test("should support all NumericType labels and not support other types") { + val nb = new NaiveBayes() + MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( + nb, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.pi === actual.pi) + assert(expected.theta === actual.theta) + } + } +} + +object NaiveBayesSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "smoothing" -> 0.1 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 5ea71c5317b7a..f3e8fd11b296a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,21 +21,21 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var rdd: RDD[LabeledPoint] = _ override def beforeAll(): Unit = { @@ -74,16 +74,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { // copied model must have the same parent. MLTestingUtils.checkCopy(ovaModel) - assert(ovaModel.models.size === numClasses) + assert(ovaModel.models.length === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset - .select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) + val ovaResults = transformedDataset.select("prediction", "label").rdd.map { + row => (row.getDouble(0), row.getDouble(1)) + } val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) @@ -160,6 +160,84 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { require(m.getThreshold === 0.1, "copy should handle extra model params") } } + + test("read/write: OneVsRest") { + val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01) + + val ova = new OneVsRest() + .setClassifier(lr) + .setLabelCol("myLabel") + .setFeaturesCol("myFeature") + .setPredictionCol("myPrediction") + + val ova2 = testDefaultReadWrite(ova, testParams = false) + assert(ova.uid === ova2.uid) + assert(ova.getFeaturesCol === ova2.getFeaturesCol) + assert(ova.getLabelCol === ova2.getLabelCol) + assert(ova.getPredictionCol === ova2.getPredictionCol) + + ova2.getClassifier match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + assert(lr.getRegParam === lr2.getRegParam) + case other => + throw new AssertionError(s"Loaded OneVsRest expected classifier of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + } + + test("read/write: OneVsRestModel") { + def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = { + assert(model.uid === model2.uid) + assert(model.getFeaturesCol === model2.getFeaturesCol) + assert(model.getLabelCol === model2.getLabelCol) + assert(model.getPredictionCol === model2.getPredictionCol) + + val classifier = model.getClassifier.asInstanceOf[LogisticRegression] + + model2.getClassifier match { + case lr2: LogisticRegression => + assert(classifier.uid === lr2.uid) + assert(classifier.getMaxIter === lr2.getMaxIter) + assert(classifier.getRegParam === lr2.getRegParam) + case other => + throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + assert(model.labelMetadata === model2.labelMetadata) + model.models.zip(model2.models).foreach { + case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) => + assert(lrModel1.uid === lrModel2.uid) + assert(lrModel1.coefficients === lrModel2.coefficients) + assert(lrModel1.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded OneVsRestModel expected model of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + } + + val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01) + val ova = new OneVsRest().setClassifier(lr) + val ovaModel = ova.fit(dataset) + val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false) + checkModelData(ovaModel, newOvaModel) + } + + test("should support all NumericType labels and not support other types") { + val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) + MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( + ovr, isClassification = true, sqlContext) { (expected, actual) => + val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel]) + val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel]) + assert(expectedModels.length === actualModels.length) + expectedModels.zip(actualModels).foreach { case (e, a) => + assert(e.intercept === a.intercept) + assert(e.coefficients.toArray === a.coefficients.toArray) + } + } + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { @@ -168,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid setMaxIter(1) - override protected def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val labelSchema = dataset.schema($(labelCol)) // check for label attribute propagation. assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index fb5f00e0646c6..cfa75ecf387cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -57,3 +57,17 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) } } + +object ProbabilisticClassifierSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = ClassifierSuite.allParamSettings ++ Map( + "probabilityCol" -> "myProbability", + "thresholds" -> Array(0.4, 0.6) + ) + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index deb8ec771cb27..aaaa429103478 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -34,7 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class RandomForestClassifierSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs @@ -105,7 +106,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, rf, categoricalFeatures, numClasses) } - test("subsampling rate in RandomForest"){ + test("subsampling rate in RandomForest") { val rdd = orderedLabeledPoints5_20 val categoricalFeatures = Map.empty[Int, Int] val numClasses = 2 @@ -167,46 +168,47 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setSeed(123) // In this data, feature 1 is very important. - val data: RDD[LabeledPoint] = sc.parallelize(Seq( - new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), - new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) - )) + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) val categoricalFeatures = Map.empty[Int, Int] val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val importances = rf.fit(df).featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + + test("should support all NumericType labels and not support other types") { + val rf = new RandomForestClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( + rf, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } } ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* - test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = - Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray - val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees) - val newModel = RandomForestClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = RandomForestClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + test("read/write") { + def checkModelData( + model: RandomForestClassificationModel, + model2: RandomForestClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + assert(model.numClasses === model2.numClasses) } + + val rf = new RandomForestClassifier().setNumTrees(2) + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) } - */ } private object RandomForestClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala new file mode 100644 index 0000000000000..e641d79c1707b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset} + +class BisectingKMeansSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + final val k = 5 + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + } + + test("default parameters") { + val bkm = new BisectingKMeans() + + assert(bkm.getK === 4) + assert(bkm.getFeaturesCol === "features") + assert(bkm.getPredictionCol === "prediction") + assert(bkm.getMaxIter === 20) + assert(bkm.getMinDivisibleClusterSize === 1.0) + } + + test("setter/getter") { + val bkm = new BisectingKMeans() + .setK(9) + .setMinDivisibleClusterSize(2.0) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setMaxIter(33) + .setSeed(123) + + assert(bkm.getK === 9) + assert(bkm.getFeaturesCol === "test_feature") + assert(bkm.getPredictionCol === "test_prediction") + assert(bkm.getMaxIter === 33) + assert(bkm.getMinDivisibleClusterSize === 2.0) + assert(bkm.getSeed === 123) + + intercept[IllegalArgumentException] { + new BisectingKMeans().setK(1) + } + + intercept[IllegalArgumentException] { + new BisectingKMeans().setMinDivisibleClusterSize(0) + } + } + + test("fit & transform") { + val predictionColName = "bisecting_kmeans_prediction" + val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) + val model = bkm.fit(dataset) + assert(model.clusterCenters.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + val clusters = + transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) + assert(model.hasParent) + } + + test("read/write") { + def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val bisectingKMeans = new BisectingKMeans() + testEstimatorAndModelReadWrite( + bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData) + } +} + +object BisectingKMeansSuite { + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "seed" -> -1L, + "minDivisibleClusterSize" -> 2.0 + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala new file mode 100644 index 0000000000000..1a274aea291f4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset} + + +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + final val k = 5 + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + } + + test("default parameters") { + val gm = new GaussianMixture() + + assert(gm.getK === 2) + assert(gm.getFeaturesCol === "features") + assert(gm.getPredictionCol === "prediction") + assert(gm.getMaxIter === 100) + assert(gm.getTol === 0.01) + } + + test("set parameters") { + val gm = new GaussianMixture() + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setProbabilityCol("test_probability") + .setMaxIter(33) + .setSeed(123) + .setTol(1e-3) + + assert(gm.getK === 9) + assert(gm.getFeaturesCol === "test_feature") + assert(gm.getPredictionCol === "test_prediction") + assert(gm.getProbabilityCol === "test_probability") + assert(gm.getMaxIter === 33) + assert(gm.getSeed === 123) + assert(gm.getTol === 1e-3) + } + + test("parameters validation") { + intercept[IllegalArgumentException] { + new GaussianMixture().setK(1) + } + } + + test("fit, transform, and summary") { + val predictionColName = "gm_prediction" + val probabilityColName = "gm_probability" + val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName) + .setProbabilityCol(probabilityColName).setSeed(1) + val model = gm.fit(dataset) + assert(model.hasParent) + assert(model.weights.length === k) + assert(model.gaussians.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName, probabilityColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: GaussianMixtureSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.probabilityCol === probabilityColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, probabilityColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + assert(summary.probability.columns === Array(probabilityColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) + } + + test("read/write") { + def checkModelData(model: GaussianMixtureModel, model2: GaussianMixtureModel): Unit = { + assert(model.weights === model2.weights) + assert(model.gaussians.map(_.mu) === model2.gaussians.map(_.mu)) + assert(model.gaussians.map(_.sigma) === model2.gaussians.map(_.sigma)) + } + val gm = new GaussianMixture() + testEstimatorAndModelReadWrite(gm, dataset, + GaussianMixtureSuite.allParamSettings, checkModelData) + } +} + +object GaussianMixtureSuite { + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "probabilityCol" -> "myProbability", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c05f90550d161..2ca386e4229ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -18,26 +18,18 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} private[clustering] case class TestRow(features: Vector) -object KMeansSuite { - def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { - val sc = sql.sparkContext - val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) - .map(v => new TestRow(v)) - sql.createDataFrame(rdd) - } -} - -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -90,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } - test("fit & transform") { + test("fit, transform, and summary") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) @@ -101,9 +93,56 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { expectedColumns.foreach { column => assert(transformed.columns.contains(column)) } - val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet + val clusters = + transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet assert(clusters.size === k) assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) + assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: KMeansSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) } + + test("read/write") { + def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val kmeans = new KMeans() + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + } +} + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala new file mode 100644 index 0000000000000..17d6e9fc2ee76 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} + + +object LDASuite { + def generateLDAData( + sql: SQLContext, + rows: Int, + k: Int, + vocabSize: Int): DataFrame = { + val avgWC = 1 // average instances of each word in a doc + val sc = sql.sparkContext + val rng = new java.util.Random() + rng.setSeed(1) + val rdd = sc.parallelize(1 to rows).map { i => + Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) + }.map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "checkpointInterval" -> 30, + "learningOffset" -> 1023.0, + "learningDecay" -> 0.52, + "subsamplingRate" -> 0.051, + "docConcentration" -> Array(2.0) + ) +} + + +class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + val k: Int = 5 + val vocabSize: Int = 30 + @transient var dataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize) + } + + test("default parameters") { + val lda = new LDA() + + assert(lda.getFeaturesCol === "features") + assert(lda.getMaxIter === 20) + assert(lda.isDefined(lda.seed)) + assert(lda.getCheckpointInterval === 10) + assert(lda.getK === 10) + assert(!lda.isSet(lda.docConcentration)) + assert(!lda.isSet(lda.topicConcentration)) + assert(lda.getOptimizer === "online") + assert(lda.getLearningDecay === 0.51) + assert(lda.getLearningOffset === 1024) + assert(lda.getSubsamplingRate === 0.05) + assert(lda.getOptimizeDocConcentration) + assert(lda.getTopicDistributionCol === "topicDistribution") + } + + test("set parameters") { + val lda = new LDA() + .setFeaturesCol("test_feature") + .setMaxIter(33) + .setSeed(123) + .setCheckpointInterval(7) + .setK(9) + .setTopicConcentration(0.56) + .setTopicDistributionCol("myOutput") + + assert(lda.getFeaturesCol === "test_feature") + assert(lda.getMaxIter === 33) + assert(lda.getSeed === 123) + assert(lda.getCheckpointInterval === 7) + assert(lda.getK === 9) + assert(lda.getTopicConcentration === 0.56) + assert(lda.getTopicDistributionCol === "myOutput") + + + // setOptimizer + lda.setOptimizer("em") + assert(lda.getOptimizer === "em") + lda.setOptimizer("online") + assert(lda.getOptimizer === "online") + lda.setLearningDecay(0.53) + assert(lda.getLearningDecay === 0.53) + lda.setLearningOffset(1027) + assert(lda.getLearningOffset === 1027) + lda.setSubsamplingRate(0.06) + assert(lda.getSubsamplingRate === 0.06) + lda.setOptimizeDocConcentration(false) + assert(!lda.getOptimizeDocConcentration) + } + + test("parameters validation") { + val lda = new LDA() + + // misc Params + intercept[IllegalArgumentException] { + new LDA().setK(1) + } + intercept[IllegalArgumentException] { + new LDA().setOptimizer("no_such_optimizer") + } + intercept[IllegalArgumentException] { + new LDA().setDocConcentration(-1.1) + } + intercept[IllegalArgumentException] { + new LDA().setTopicConcentration(-1.1) + } + + val dummyDF = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features") + // validate parameters + lda.transformSchema(dummyDF.schema) + lda.setDocConcentration(1.1) + lda.transformSchema(dummyDF.schema) + lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray) + lda.transformSchema(dummyDF.schema) + lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray) + withClue("LDA docConcentration validity check failed for bad array length") { + intercept[IllegalArgumentException] { + lda.transformSchema(dummyDF.schema) + } + } + + // Online LDA + intercept[IllegalArgumentException] { + new LDA().setLearningOffset(0) + } + intercept[IllegalArgumentException] { + new LDA().setLearningDecay(0) + } + intercept[IllegalArgumentException] { + new LDA().setSubsamplingRate(0) + } + intercept[IllegalArgumentException] { + new LDA().setSubsamplingRate(1.1) + } + } + + test("fit & transform with Online LDA") { + val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2) + val model = lda.fit(dataset) + + MLTestingUtils.checkCopy(model) + + assert(model.isInstanceOf[LocalLDAModel]) + assert(model.vocabSize === vocabSize) + assert(model.estimatedDocConcentration.size === k) + assert(model.topicsMatrix.numRows === vocabSize) + assert(model.topicsMatrix.numCols === k) + assert(!model.isDistributed) + + // transform() + val transformed = model.transform(dataset) + val expectedColumns = Array("features", lda.getTopicDistributionCol) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + transformed.select(lda.getTopicDistributionCol).collect().foreach { r => + val topicDistribution = r.getAs[Vector](0) + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + } + + // logLikelihood, logPerplexity + val ll = model.logLikelihood(dataset) + assert(ll <= 0.0 && ll != Double.NegativeInfinity) + val lp = model.logPerplexity(dataset) + assert(lp >= 0.0 && lp != Double.PositiveInfinity) + + // describeTopics + val topics = model.describeTopics(3) + assert(topics.count() === k) + assert(topics.select("topic").rdd.map(_.getInt(0)).collect().toSet === Range(0, k).toSet) + topics.select("termIndices").collect().foreach { case r: Row => + val termIndices = r.getAs[Seq[Int]](0) + assert(termIndices.length === 3 && termIndices.toSet.size === 3) + } + topics.select("termWeights").collect().foreach { case r: Row => + val termWeights = r.getAs[Seq[Double]](0) + assert(termWeights.length === 3 && termWeights.forall(w => w >= 0.0 && w <= 1.0)) + } + } + + test("fit & transform with EM LDA") { + val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2) + val model_ = lda.fit(dataset) + + MLTestingUtils.checkCopy(model_) + + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + assert(model.vocabSize === vocabSize) + assert(model.estimatedDocConcentration.size === k) + assert(model.topicsMatrix.numRows === vocabSize) + assert(model.topicsMatrix.numCols === k) + assert(model.isDistributed) + + val localModel = model.toLocal + assert(localModel.isInstanceOf[LocalLDAModel]) + + // training logLikelihood, logPrior + val ll = model.trainingLogLikelihood + assert(ll <= 0.0 && ll != Double.NegativeInfinity) + val lp = model.logPrior + assert(lp <= 0.0 && lp != Double.NegativeInfinity) + } + + test("read/write LocalLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + } + + test("read/write DistributedLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) + } + + test("EM LDA checkpointing: save last checkpoint") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + // There should be 1 checkpoint remaining. + assert(model.getCheckpointFiles.length === 1) + val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + assert(fs.exists(new Path(model.getCheckpointFiles.head))) + model.deleteCheckpointFiles() + assert(model.getCheckpointFiles.isEmpty) + } + + test("EM LDA checkpointing: remove last checkpoint") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1) + .setKeepLastCheckpoint(false) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } + + test("EM LDA disable checkpointing") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3) + .setCheckpointInterval(-1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index def869fe66777..27349950dc119 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,10 +19,53 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext -class BinaryClassificationEvaluatorSuite extends SparkFunSuite { +class BinaryClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } + + test("read/write") { + val evaluator = new BinaryClassificationEvaluator() + .setRawPredictionCol("myRawPrediction") + .setLabelCol("myLabel") + .setMetricName("areaUnderPR") + testDefaultReadWrite(evaluator) + } + + test("should accept both vector and double raw prediction col") { + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") + + val vectorDF = sqlContext.createDataFrame(Seq( + (0d, Vectors.dense(12, 2.5)), + (1d, Vectors.dense(1, 3)), + (0d, Vectors.dense(10, 2)) + )).toDF("label", "rawPrediction") + assert(evaluator.evaluate(vectorDF) === 1.0) + + val doubleDF = sqlContext.createDataFrame(Seq( + (0d, 0d), + (1d, 1d), + (0d, 0d) + )).toDF("label", "rawPrediction") + assert(evaluator.evaluate(doubleDF) === 1.0) + + val stringDF = sqlContext.createDataFrame(Seq( + (0d, "0d"), + (1d, "1d"), + (0d, "0d") + )).toDF("label", "rawPrediction") + val thrown = intercept[IllegalArgumentException] { + evaluator.evaluate(stringDF) + } + assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + + "equal to one of the following types: [DoubleType, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 6d8412b0b3701..7ee65975d22f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { +class MulticlassClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new MulticlassClassificationEvaluator) } + + test("read/write") { + val evaluator = new MulticlassClassificationEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("recall") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index aa722da323935..954d3bedc14bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RegressionEvaluator) @@ -63,14 +65,22 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.1013829 absTol 0.01) // r2 score evaluator.setMetricName("r2") - assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.9998387 absTol 0.01) // mae evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01) + } + + test("read/write") { + val evaluator = new RegressionEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("r2") + testDefaultReadWrite(evaluator) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 2086043983661..714b9db3aa19f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Double] = _ @@ -66,4 +68,47 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x === y, "The feature value is not correct after binarization.") } } + + test("Binarize vector of continuous features with default parameter") { + val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(data), Vectors.dense(defaultBinarized)) + )).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } + + test("Binarize vector of continuous features with setter") { + val threshold: Double = 0.2 + val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(data), Vectors.dense(defaultBinarized)) + )).toDF("feature", "expected") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(threshold) + + binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x == y, "The feature value is not correct after binarization.") + } + } + + + test("read/write") { + val t = new Binarizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setThreshold(0.1) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 0eba34fda6228..9ea7d431763a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -21,13 +21,13 @@ import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Bucketizer) @@ -112,6 +112,14 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } + + test("read/write") { + val t = new Bucketizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setSplits(Array(0.1, 0.8, 0.9)) + testDefaultReadWrite(t) + } } private object BucketizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index e5a42967bd2c8..7827db2794cf3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + test("Test Chi-Square selector") { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ @@ -58,4 +62,20 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vec1 ~== vec2 absTol 1e-1) } } + + test("ChiSqSelector read/write") { + val t = new ChiSqSelector() + .setFeaturesCol("myFeaturesCol") + .setLabelCol("myLabelCol") + .setOutputCol("myOutputCol") + .setNumTopFeatures(2) + testDefaultReadWrite(t) + } + + test("ChiSqSelectorModel read/write") { + val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) + val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.selectedFeatures === instance.selectedFeatures) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index e192fa4850af0..7641e3b8cf668 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { test("params") { + ParamsSuite.checkParams(new CountVectorizer) ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } @@ -56,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), - (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), - (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))), + (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))), + (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .fit(df) - assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) cv.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => @@ -154,7 +158,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { (3, split("e e e e e"), Vectors.sparse(4, Seq()))) ).toDF("id", "words", "expected") - // minTF: count + // minTF: set frequency val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") @@ -164,4 +168,53 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(features ~== expected absTol 1e-14) } } + + test("CountVectorizerModel and CountVectorizer with binary") { + val df = sqlContext.createDataFrame(Seq( + (0, split("a a a a b b b b c d"), + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), + (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))), + (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))) + )).toDF("id", "words", "expected") + + // CountVectorizer test + val cv = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setBinary(true) + .fit(df) + cv.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // CountVectorizerModel test + val cv2 = new CountVectorizerModel(cv.vocabulary) + .setInputCol("words") + .setOutputCol("features") + .setBinary(true) + cv2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer read/write") { + val t = new CountVectorizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDF(0.5) + .setMinTF(3.0) + .setVocabSize(10) + testDefaultReadWrite(t) + } + + test("CountVectorizerModel read/write") { + val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTF(3.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.vocabulary === instance.vocabulary) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 37ed2367c33f7..36cafa290f083 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -22,14 +22,15 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Row @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("forward transform of discrete cosine matches jTransforms result") { val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) @@ -45,6 +46,14 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { testDCT(data, inverse) } + test("read/write") { + val t = new DCT() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setInverse(true) + testDefaultReadWrite(t) + } + private def testDCT(data: Vector, inverse: Boolean): Unit = { val expectedResultBuffer = data.toArray.clone() if (inverse) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala new file mode 100644 index 0000000000000..fc1c05de233ea --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class ElementwiseProductSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("read/write") { + val ep = new ElementwiseProduct() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setScalingVec(Vectors.dense(0.1, 0.2)) + testDefaultReadWrite(ep) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 4157b84b29d01..addd733c20b5a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new HashingTF) @@ -45,9 +46,39 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { require(attrGroup.numAttributes === Some(n)) val features = output.select("features").first().getAs[Vector](0) // Assume perfect hash on "a", "b", "c", and "d". - def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n) + def idx: Any => Int = featureIdx(n) val expected = Vectors.sparse(n, Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) assert(features ~== expected absTol 1e-14) } + + test("applying binary term freqs") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a b c c c".split(" ").toSeq) + )).toDF("id", "words") + val n = 100 + val hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("features") + .setNumFeatures(n) + .setBinary(true) + val output = hashingTF.transform(df) + val features = output.select("features").first().getAs[Vector](0) + def idx: Any => Int = featureIdx(n) // Assume perfect hash on input features + val expected = Vectors.sparse(n, + Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0))) + assert(features ~== expected absTol 1e-14) + } + + test("read/write") { + val t = new HashingTF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setNumFeatures(10) + testDefaultReadWrite(t) + } + + private def featureIdx(numFeatures: Int)(term: Any): Int = { + Utils.nonNegativeMod(term.##, numFeatures) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 08f80af03429b..bc958c15857ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { @@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("IDF read/write") { + val t = new IDF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDocFreq(5) + testDefaultReadWrite(t) + } + + test("IDFModel read/write") { + val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0))) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.idf === instance.idf) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 2beb62ca08233..0d4e00668ddb8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -22,11 +22,12 @@ import scala.collection.mutable.ArrayBuilder import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.functions.col -class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Interaction()) } @@ -162,4 +163,11 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) assert(attrs === expectedAttrs) } + + test("read/write") { + val t = new Interaction() + .setInputCols(Array("myInputCol", "myInputCol2")) + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala new file mode 100644 index 0000000000000..e083d4713680e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + test("MaxAbsScaler fit basic case") { + val data = Array( + Vectors.dense(1, 0, 100), + Vectors.dense(2, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(-2, -100)), + Vectors.sparse(3, Array(0), Array(-1.5))) + + val expected: Array[Vector] = Array( + Vectors.dense(0.5, 0, 1), + Vectors.dense(1, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(-1, -1)), + Vectors.sparse(3, Array(0), Array(-0.75))) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val scaler = new MaxAbsScaler() + .setInputCol("features") + .setOutputCol("scaled") + + val model = scaler.fit(df) + model.transform(df).select("expected", "scaled").collect() + .foreach { case Row(vector1: Vector, vector2: Vector) => + assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") + } + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + } + + test("MaxAbsScaler read/write") { + val t = new MaxAbsScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } + + test("MaxAbsScalerModel read/write") { + val instance = new MaxAbsScalerModel( + "myMaxAbsScalerModel", Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.maxAbs === instance.maxAbs) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c04dda41eea34..87206c777e352 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,16 +18,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { - val sqlContext = new SQLContext(sc) - val data = Array( Vectors.dense(1, 0, Long.MinValue), Vectors.dense(2, 0, 0), @@ -59,14 +57,37 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { test("MinMaxScaler arguments max must be larger than min") { withClue("arguments max must be larger than min") { + val dummyDF = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature") intercept[IllegalArgumentException] { - val scaler = new MinMaxScaler().setMin(10).setMax(0) - scaler.validateParams() + val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature") + scaler.transformSchema(dummyDF.schema) } intercept[IllegalArgumentException] { - val scaler = new MinMaxScaler().setMin(0).setMax(0) - scaler.validateParams() + val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature") + scaler.transformSchema(dummyDF.schema) } } } + + test("MinMaxScaler read/write") { + val t = new MinMaxScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMax(1.0) + .setMin(-1.0) + testDefaultReadWrite(t) + } + + test("MinMaxScalerModel read/write") { + val instance = new MinMaxScalerModel( + "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMin(-1.0) + .setMax(1.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.originalMin === instance.originalMin) + assert(newInstance.originalMax === instance.originalMax) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index ab97e3dbc6ee0..e4e15f43310ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import org.apache.spark.ml.feature.NGramSuite._ test("default behavior yields bigram features") { @@ -79,11 +80,19 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { ))) testNGram(nGram, dataset) } + + test("read/write") { + val t = new NGram() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setN(3) + testDefaultReadWrite(t) + } } object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: DataFrame): Unit = { + def testNGram(t: NGram, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("nGrams", "wantedNGrams") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9f03470b7f328..468833901995a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ @@ -60,7 +61,6 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.sparse(3, Seq()) ) - val sqlContext = new SQLContext(sc) dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") @@ -104,6 +104,14 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { assertValues(result, l1Normalized) } + + test("read/write") { + val t = new Normalizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setP(3.0) + testDefaultReadWrite(t) + } } private object NormalizerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 321eeb843941c..49803aef71587 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -20,12 +20,15 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ -class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { +class OneHotEncoderSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -49,7 +52,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { .setDropLast(false) val encoded = encoder.transform(transformed) - val output = encoded.select("id", "labelVec").map { r => + val output = encoded.select("id", "labelVec").rdd.map { r => val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet @@ -66,7 +69,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { .setOutputCol("labelVec") val encoded = encoder.transform(transformed) - val output = encoded.select("id", "labelVec").map { r => + val output = encoded.select("id", "labelVec").rdd.map { r => val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet @@ -101,4 +104,40 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } + + test("read/write") { + val t = new OneHotEncoder() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setDropLast(false) + testDefaultReadWrite(t) + } + + test("OneHotEncoder with varying types") { + val df = stringIndexed() + val dfWithTypes = df + .withColumn("shortLabel", df("labelIndex").cast(ShortType)) + .withColumn("longLabel", df("labelIndex").cast(LongType)) + .withColumn("intLabel", df("labelIndex").cast(IntegerType)) + .withColumn("floatLabel", df("labelIndex").cast(FloatType)) + .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) + val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", + "floatLabel", "decimalLabel") + for (col <- cols) { + val encoder = new OneHotEncoder() + .setInputCol(col) + .setOutputCol("labelVec") + .setDropLast(false) + val encoded = encoder.transform(dfWithTypes) + + val output = encoded.select("id", "labelVec").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1), vec(2)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), + (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) + assert(output === expected) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 30c500f87a769..f372ec58269e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,20 +19,20 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] - val model = new PCAModel("pca", new OldPCAModel(2, mat)) + val explainedVariance = Vectors.dense(0.5, 0.5).asInstanceOf[DenseVector] + val model = new PCAModel("pca", mat, explainedVariance) ParamsSuite.checkParams(model) } @@ -65,4 +65,20 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("PCA read/write") { + val t = new PCA() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setK(3) + testDefaultReadWrite(t) + } + + test("PCAModel read/write") { + val instance = new PCAModel("myPCAModel", + Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix], + Vectors.dense(0.5, 0.5).asInstanceOf[DenseVector]) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.pc === instance.pc) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 29eebd8960ebc..86dbee1cf4a5a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,37 +17,48 @@ package org.apache.spark.ml.feature -import org.apache.spark.ml.param.ParamsSuite import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { +class PolynomialExpansionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PolynomialExpansion) } - test("Polynomial expansion with default parameter") { - val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0, 0.0), - Vectors.dense(0.6, -1.1, -3.0), - Vectors.sparse(3, Seq()) - ) - - val twoDegreeExpansion: Array[Vector] = Array( - Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)), - Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29), - Vectors.dense(new Array[Double](9)), - Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), - Vectors.sparse(9, Array.empty, Array.empty)) + private val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + private val twoDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)), + Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29), + Vectors.dense(new Array[Double](9)), + Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), + Vectors.sparse(9, Array.empty, Array.empty)) + + private val threeDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), + Vectors.dense(new Array[Double](19)), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, + -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), + Vectors.sparse(19, Array.empty, Array.empty)) + test("Polynomial expansion with default parameter") { val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() @@ -65,23 +76,6 @@ class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext } test("Polynomial expansion with setter") { - val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0, 0.0), - Vectors.dense(0.6, -1.1, -3.0), - Vectors.sparse(3, Seq()) - ) - - val threeDegreeExpansion: Array[Vector] = Array( - Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8), - Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), - Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), - Vectors.dense(new Array[Double](19)), - Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, - -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), - Vectors.sparse(19, Array.empty, Array.empty)) - val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() @@ -98,5 +92,29 @@ class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext throw new TestFailedException("Unmatched data types after polynomial expansion", 0) } } + + test("Polynomial expansion with degree 1 is identity on vectors") { + val df = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(1) + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: Vector, expected: Vector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } + + test("read/write") { + val t = new PolynomialExpansion() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setDegree(3) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b2bdd8935f903..8895d630a0879 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,82 +17,80 @@ package org.apache.spark.ml.feature -import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions.udf -class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext { - import org.apache.spark.ml.feature.QuantileDiscretizerSuite._ +class QuantileDiscretizerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("Test quantile discretizer") { - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 10, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) - - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 4, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) - - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 3, - Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2), - Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity")) + test("Test observed number of buckets and their sizes match expected values") { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 2, - Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1), - Array("-Infinity, 2.0", "2.0, Infinity")) + val datasetSize = 100000 + val numBuckets = 5 + val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) - } + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") - test("Test getting splits") { - val splitTestPoints = Array( - Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.NegativeInfinity, Double.PositiveInfinity) - -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity), - Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity) - ) - for ((ori, res) <- splitTestPoints) { - assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.") + val relativeError = discretizer.getRelativeError + val isGoodBucket = udf { + (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) } + val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } -} - -private object QuantileDiscretizerSuite extends SparkFunSuite { - def checkDiscretizedData( - sc: SparkContext, - data: Array[Double], - numBucket: Int, - expectedResult: Array[Double], - expectedAttrs: Array[String]): Unit = { + test("Test transform method on unseen data") { val sqlCtx = SQLContext.getOrCreate(sc) import sqlCtx.implicits._ - val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") - val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result") - .setNumBuckets(numBucket) - val result = discretizer.fit(df).transform(df) + val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") + val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(5) + + val result = discretizer.fit(trainDF).transform(testDF) + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count + + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + } - val transformedFeatures = result.select("result").collect() - .map { case Row(transformedFeature: Double) => transformedFeature } - val transformedAttrs = Attribute.fromStructField(result.schema("result")) - .asInstanceOf[NominalAttribute].values.get + test("read/write") { + val t = new QuantileDiscretizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setNumBuckets(6) + testDefaultReadWrite(t) + } + + test("Verify resulting model has parent") { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ - assert(transformedFeatures === expectedResult, - "Transformed features do not equal expected features.") - assert(transformedAttrs === expectedAttrs, - "Transformed attributes do not equal expected attributes.") + val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(5) + val model = discretizer.fit(df) + assert(model.hasParent) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index dc20a5ec2152d..e1b269b5b681f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RFormula()) } @@ -143,6 +144,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(attrs === expectedAttrs) } + test("vector attribute generation") { + val formula = new RFormula().setFormula("id ~ vec") + val original = sqlContext.createDataFrame( + Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + ).toDF("id", "vec") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("vec_0"), Some(1)), + new NumericAttribute(Some("vec_1"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("vector attribute generation with unnamed input attrs") { + val formula = new RFormula().setFormula("id ~ vec2") + val base = sqlContext.createDataFrame( + Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + ).toDF("id", "vec") + val metadata = new AttributeGroup( + "vec2", + Array[Attribute]( + NumericAttribute.defaultAttr, + NumericAttribute.defaultAttr)).toMetadata + val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("vec2_0"), Some(1)), + new NumericAttribute(Some("vec2_1"), Some(2)))) + assert(attrs === expectedAttrs) + } + test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = sqlContext.createDataFrame( @@ -214,4 +253,41 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) assert(attrs === expectedAttrs) } + + test("read/write: RFormula") { + val rFormula = new RFormula() + .setFormula("id ~ a:b") + .setFeaturesCol("myFeatures") + .setLabelCol("myLabels") + + testDefaultReadWrite(rFormula) + } + + test("read/write: RFormulaModel") { + def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = { + assert(model.uid === model2.uid) + + assert(model.resolvedFormula.label === model2.resolvedFormula.label) + assert(model.resolvedFormula.terms === model2.resolvedFormula.terms) + assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept) + + assert(model.pipelineModel.uid === model2.pipelineModel.uid) + + model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach { + case (transformer1, transformer2) => + assert(transformer1.uid === transformer2.uid) + assert(transformer1.params === transformer2.params) + } + } + + val dataset = sqlContext.createDataFrame( + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) + ).toDF("id", "a", "b") + + val rFormula = new RFormula().setFormula("id ~ a:b") + + val model = rFormula.fit(dataset) + val newModel = testDefaultReadWrite(model) + checkModelData(model, newModel) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index d19052881ae45..e213e17d0d9de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -19,9 +19,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} -class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { +class SQLTransformerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new SQLTransformer()) @@ -41,4 +44,19 @@ class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(resultSchema == expected.schema) assert(result.collect().toSeq == expected.collect().toSeq) } + + test("read/write") { + val t = new SQLTransformer() + .setStatement("select * from __THIS__") + testDefaultReadWrite(t) + } + + test("transformSchema") { + val df = sqlContext.range(10) + val outputSchema = new SQLTransformer() + .setStatement("SELECT id + 1 AS id1 FROM __THIS__") + .transformSchema(df.schema) + val expected = StructType(Seq(StructField("id1", LongType, nullable = false))) + assert(outputSchema === expected) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala new file mode 100644 index 0000000000000..8c5e47a22c969 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + @transient var data: Array[Vector] = _ + @transient var resWithStd: Array[Vector] = _ + @transient var resWithMean: Array[Vector] = _ + @transient var resWithBoth: Array[Vector] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + data = Array( + Vectors.dense(-2.0, 2.3, 0.0), + Vectors.dense(0.0, -5.1, 1.0), + Vectors.dense(1.7, -0.6, 3.3) + ) + resWithMean = Array( + Vectors.dense(-1.9, 3.433333333333, -1.433333333333), + Vectors.dense(0.1, -3.966666666667, -0.433333333333), + Vectors.dense(1.8, 0.533333333333, 1.866666666667) + ) + resWithStd = Array( + Vectors.dense(-1.079898494312, 0.616834091415, 0.0), + Vectors.dense(0.0, -1.367762550529, 0.590968109266), + Vectors.dense(0.917913720165, -0.160913241239, 1.950194760579) + ) + resWithBoth = Array( + Vectors.dense(-1.0259035695965, 0.920781324866, -0.8470542899497), + Vectors.dense(0.0539949247156, -1.063815317078, -0.256086180682), + Vectors.dense(0.9719086448809, 0.143033992212, 1.103140470631) + ) + } + + def assertResult(df: DataFrame): Unit = { + df.select("standardized_features", "expected").collect().foreach { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, + "The vector value is not correct after standardization.") + } + } + + test("params") { + ParamsSuite.checkParams(new StandardScaler) + ParamsSuite.checkParams(new StandardScalerModel("empty", + Vectors.dense(1.0), Vectors.dense(2.0))) + } + + test("Standardization with default parameter") { + val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") + + val standardScaler0 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standardized_features") + .fit(df0) + + assertResult(standardScaler0.transform(df0)) + } + + test("Standardization with setter") { + val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") + val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") + val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + + val standardScaler1 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standardized_features") + .setWithMean(true) + .setWithStd(true) + .fit(df1) + + val standardScaler2 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standardized_features") + .setWithMean(true) + .setWithStd(false) + .fit(df2) + + val standardScaler3 = new StandardScaler() + .setInputCol("features") + .setOutputCol("standardized_features") + .setWithMean(false) + .setWithStd(false) + .fit(df3) + + assertResult(standardScaler1.transform(df1)) + assertResult(standardScaler2.transform(df2)) + assertResult(standardScaler3.transform(df3)) + } + + test("StandardScaler read/write") { + val t = new StandardScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setWithStd(false) + .setWithMean(true) + testDefaultReadWrite(t) + } + + test("StandardScalerModel read/write") { + val instance = new StandardScalerModel("myStandardScalerModel", + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.std === instance.std) + assert(newInstance.mean === instance.mean) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index e0d433f566c25..3505befdf8e37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = { + def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("filtered", "expected") .collect() @@ -32,7 +33,9 @@ object StopWordsRemoverSuite extends SparkFunSuite { } } -class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { +class StopWordsRemoverSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import StopWordsRemoverSuite._ test("StopWordsRemover default") { @@ -77,4 +80,28 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { testStopWordsRemover(remover, dataSet) } + + test("read/write") { + val t = new StopWordsRemover() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setStopWords(Array("the", "a")) + .setCaseSensitive(true) + testDefaultReadWrite(t) + } + + test("StopWordsRemover output column already exists") { + val outputCol = "expected" + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol(outputCol) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("The", "the", "swift"), Seq("swift")) + )).toDF("raw", outputCol) + + val thrown = intercept[IllegalArgumentException] { + testStopWordsRemover(remover, dataSet) + } + assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index ddcdb5f4212be..d0f3cdc841d11 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,16 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} -class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { +class StringIndexerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new StringIndexer) @@ -51,7 +52,7 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -82,7 +83,7 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 @@ -101,7 +102,7 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").map { r => + val output = transformed.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 @@ -113,8 +114,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") .setOutputCol("labelIndex") - val df = sqlContext.range(0L, 10L) - assert(indexerModel.transform(df).eq(df)) + val df = sqlContext.range(0L, 10L).toDF() + assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) + } + + test("StringIndexerModel can't overwrite output column") { + val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + val indexer = new StringIndexer() + .setInputCol("input") + .setOutputCol("output") + .fit(df) + intercept[IllegalArgumentException] { + indexer.transform(df) + } + } + + test("StringIndexer read/write") { + val t = new StringIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + testDefaultReadWrite(t) + } + + test("StringIndexerModel read/write") { + val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.labels === instance.labels) } test("IndexToString params") { @@ -173,4 +202,25 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val outSchema = idxToStr.transformSchema(inSchema) assert(outSchema("output").dataType === StringType) } + + test("IndexToString read/write") { + val t = new IndexToString() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setLabels(Array("a", "b", "c")) + testDefaultReadWrite(t) + } + + test("StringIndexer metadata") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attrs = + NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) + assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index e5fd21c3f6fca..299f6223b20b7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -21,20 +21,30 @@ import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite { +class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) } + + test("read/write") { + val t = new Tokenizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } -class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ test("params") { @@ -48,13 +58,13 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset0 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), - TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer0, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) )) tokenizer0.setMinTokenLength(3) @@ -64,16 +74,39 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset2 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Array("Te,st.", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("te,st.", "punct")) )) testRegexTokenizer(tokenizer2, dataset2) } + + test("RegexTokenizer with toLowercase false") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setToLowercase(false) + val dataset = sqlContext.createDataFrame(Seq( + TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), + TokenizerTestData("java scala", Array("java", "scala")) + )) + testRegexTokenizer(tokenizer, dataset) + } + + test("read/write") { + val t = new RegexTokenizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTokenLength(2) + .setGaps(false) + .setPattern("hi") + .setToLowercase(false) + testDefaultReadWrite(t) + } } object RegexTokenizerSuite extends SparkFunSuite { - def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("tokens", "wantedTokens") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index bb4d5b983e0d4..dce994fdbd056 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorAssemblerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new VectorAssembler) @@ -67,6 +69,17 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("transform should throw an exception in case of unsupported type") { + val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val assembler = new VectorAssembler() + .setInputCols(Array("a", "b", "c")) + .setOutputCol("features") + val thrown = intercept[SparkException] { + assembler.transform(df) + } + assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + } + test("ML attributes") { val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) @@ -98,7 +111,14 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) val userSalaryOut = features.getAttr(4) assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) - assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) - assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0")) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1")) + } + + test("read/write") { + val t = new VectorAssembler() + .setInputCols(Array("myInputCol", "myInputCol2")) + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 8cb0a2cf14d37..1ffc62b38e856 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,16 +19,18 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.apache.spark.{Logging, SparkException, SparkFunSuite} +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest with Logging { import VectorIndexerSuite.FeatureData @@ -158,7 +160,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.map(_.getAs[Vector](0)) + val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) assert(featureAttrs.name === "indexed") assert(featureAttrs.attributes.get.length === model.numFeatures) @@ -215,7 +217,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = model.transform(data).select("indexed").map(_.getAs[Vector](0)).collect() + val indexedPoints = + model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() points.zip(indexedPoints).foreach { case (orig: SparseVector, indexed: SparseVector) => assert(orig.indices.length == indexed.indices.length) @@ -251,6 +254,23 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L } } } + + test("VectorIndexer read/write") { + val t = new VectorIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxCategories(30) + testDefaultReadWrite(t) + } + + test("VectorIndexerModel read/write") { + val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 1 -> Map(0.0 -> 0, 1.0 -> 1, + 2.0 -> 2, 3.0 -> 3), 2 -> Map(0.0 -> 0, -1.0 -> 1, 2.0 -> 2)) + val instance = new VectorIndexerModel("myVectorIndexerModel", 3, categoryMaps) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.numFeatures === instance.numFeatures) + assert(newInstance.categoryMaps === instance.categoryMaps) + } } private[feature] object VectorIndexerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index a6c2fba8360dd..6bb4678dc5f97 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -20,21 +20,22 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.{StructField, StructType} -class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { - val slicer = new VectorSlicer + val slicer = new VectorSlicer().setInputCol("feature") ParamsSuite.checkParams(slicer) assert(slicer.getIndices.length === 0) assert(slicer.getNames.length === 0) withClue("VectorSlicer should not have any features selected by default") { intercept[IllegalArgumentException] { - slicer.validateParams() + slicer.transformSchema(StructType(Seq(StructField("feature", new VectorUDT, true)))) } } } @@ -53,8 +54,6 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Test vector slicer") { - val sqlContext = new SQLContext(sc) - val data = Array( Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))), Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0), @@ -106,4 +105,13 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) validateResults(vectorSlicer.transform(df)) } + + test("read/write") { + val t = new VectorSlicer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setIndices(Array(1, 3)) + .setNames(Array("a", "d")) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a2e46f2029956..80c177b8d3188 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} +import org.apache.spark.sql.Row -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Word2Vec") { - val sqlContext = new SQLContext(sc) + + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -66,15 +67,18 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // copied model must have the same parent. MLTestingUtils.checkCopy(model) + // These expectations are just magic values, characterizing the current + // behavior. The test needs to be updated to be more general, see SPARK-11502 + val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) model.transform(docDF).select("result", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") + assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") } } test("getVectors") { - val sqlContext = new SQLContext(sc) + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -96,11 +100,18 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) - val realVectors = model.getVectors.sort("word").select("vector").map { + val realVectors = model.getVectors.sort("word").select("vector").rdd.map { case Row(v: Vector) => v }.collect() + // These expectations are just magic values, characterizing the current + // behavior. The test needs to be updated to be more general, see SPARK-11502 + val magicExpected = Seq( + Vectors.dense(0.3326166272163391, -0.5603077411651611, -0.2309209555387497), + Vectors.dense(0.32463887333869934, -0.9306551218032837, 1.393115520477295), + Vectors.dense(-0.27150997519493103, 0.4372006058692932, -0.13465698063373566) + ) - realVectors.zip(expectedVectors).foreach { + realVectors.zip(magicExpected).foreach { case (real, expected) => assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") } @@ -108,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { test("findSynonyms") { - val sqlContext = new SQLContext(sc) + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -122,8 +133,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) - val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) - val (synonyms, similarity) = model.findSynonyms("a", 2).map { + val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078) + val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip @@ -131,7 +142,69 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { expectedSimilarity.zip(similarity).map { case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) } + } + + test("window size") { + + val sqlContext = this.sqlContext + import sqlContext.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setWindowSize(2) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + val (synonyms, similarity) = model.findSynonyms("a", 6).rdd.map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + // Increase the window size + val biggerModel = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .setWindowSize(10) + .fit(docDF) + + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).rdd.map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + // The similarity score should be very different with the larger window + assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) + } + + test("Word2Vec read/write") { + val t = new Word2Vec() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxIter(2) + .setMinCount(8) + .setNumPartitions(1) + .setSeed(42L) + .setStepSize(0.01) + .setVectorSize(100) + testDefaultReadWrite(t) + } + + test("Word2VecModel read/write") { + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val oldModel = new OldWord2VecModel(word2VecMap) + val instance = new Word2VecModel("myWord2VecModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala deleted file mode 100644 index 460849c79f04f..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.impl - -import scala.collection.JavaConverters._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} -import org.apache.spark.ml.tree._ -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, DataFrame} - - -private[ml] object TreeTests extends SparkFunSuite { - - /** - * Convert the given data to a DataFrame, and set the features and label metadata. - * @param data Dataset. Categorical features and labels must already have 0-based indices. - * This must be non-empty. - * @param categoricalFeatures Map: categorical feature index -> number of distinct values - * @param numClasses Number of classes label can take. If 0, mark as continuous. - * @return DataFrame with metadata - */ - def setMetadata( - data: RDD[LabeledPoint], - categoricalFeatures: Map[Int, Int], - numClasses: Int): DataFrame = { - val sqlContext = new SQLContext(data.sparkContext) - import sqlContext.implicits._ - val df = data.toDF() - val numFeatures = data.first().features.size - val featuresAttributes = Range(0, numFeatures).map { feature => - if (categoricalFeatures.contains(feature)) { - NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) - } else { - NumericAttribute.defaultAttr.withIndex(feature) - } - }.toArray - val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() - val labelAttribute = if (numClasses == 0) { - NumericAttribute.defaultAttr.withName("label") - } else { - NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) - } - val labelMetadata = labelAttribute.toMetadata() - df.select(df("features").as("features", featuresMetadata), - df("label").as("label", labelMetadata)) - } - - /** Java-friendly version of [[setMetadata()]] */ - def setMetadata( - data: JavaRDD[LabeledPoint], - categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], - numClasses: Int): DataFrame = { - setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, - numClasses) - } - - /** - * Check if the two trees are exactly the same. - * Note: I hesitate to override Node.equals since it could cause problems if users - * make mistakes such as creating loops of Nodes. - * If the trees are not equal, this prints the two trees and throws an exception. - */ - def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { - try { - checkEqual(a.rootNode, b.rootNode) - } catch { - case ex: Exception => - throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + - "TREE A:\n" + a.toDebugString + "\n" + - "TREE B:\n" + b.toDebugString + "\n", ex) - } - } - - /** - * Return true iff the two nodes and their descendants are exactly the same. - * Note: I hesitate to override Node.equals since it could cause problems if users - * make mistakes such as creating loops of Nodes. - */ - private def checkEqual(a: Node, b: Node): Unit = { - assert(a.prediction === b.prediction) - assert(a.impurity === b.impurity) - (a, b) match { - case (aye: InternalNode, bee: InternalNode) => - assert(aye.split === bee.split) - checkEqual(aye.leftChild, bee.leftChild) - checkEqual(aye.rightChild, bee.rightChild) - case (aye: LeafNode, bee: LeafNode) => // do nothing - case _ => - throw new AssertionError("Found mismatched nodes") - } - } - - /** - * Check if the two models are exactly the same. - * If the models are not equal, this throws an exception. - */ - def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { - try { - a.trees.zip(b.trees).foreach { case (treeA, treeB) => - TreeTests.checkEqual(treeA, treeB) - } - assert(a.treeWeights === b.treeWeights) - } catch { - case ex: Exception => throw new AssertionError( - "checkEqual failed since the two tree ensembles were not identical") - } - } - - /** - * Helper method for constructing a tree for testing. - * Given left, right children, construct a parent node. - * @param split Split for parent node - * @return Parent node with children attached - */ - def buildParentNode(left: Node, right: Node, split: Split): Node = { - val leftImp = left.impurityStats - val rightImp = right.impurityStats - val parentImp = leftImp.copy.add(rightImp) - val leftWeight = leftImp.count / parentImp.count.toDouble - val rightWeight = rightImp.count / parentImp.count.toDouble - val gain = parentImp.calculate() - - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) - val pred = parentImp.predict - new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala new file mode 100644 index 0000000000000..604021220a139 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.optim + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { + + private var instances1: RDD[Instance] = _ + private var instances2: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) + b <- c(1, 0, 1, 0) + w <- c(1, 2, 3, 4) + */ + instances1 = sc.parallelize(Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + ), 2) + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + */ + instances2 = sc.parallelize(Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) + } + + test("IRLS against GLM with Binomial errors") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- glm(formula, family="binomial", data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -0.30216651 -0.04452045 + [1] 3.5651651 -1.2334085 -0.7348971 + */ + val expected = Seq( + Vectors.dense(0.0, -0.30216651, -0.04452045), + Vectors.dense(3.5651651, -1.2334085, -0.7348971)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val newInstances = instances1.map { instance => + val mu = (instance.label + 0.5) / 2.0 + val eta = math.log(mu / (1.0 - mu)) + Instance(eta, instance.weight, instance.features) + } + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + val irls = new IterativelyReweightedLeastSquares(initial, BinomialReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances1) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("IRLS against GLM with Poisson errors") { + /* + R code: + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- glm(formula, family="poisson", data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] -0.09607792 0.18375613 + [1] 6.299947 3.324107 -1.081766 + */ + val expected = Seq( + Vectors.dense(0.0, -0.09607792, 0.18375613), + Vectors.dense(6.299947, 3.324107, -1.081766)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val yMean = instances2.map(_.label).mean + val newInstances = instances2.map { instance => + val mu = (instance.label + yMean) / 2.0 + val eta = math.log(mu) + Instance(eta, instance.weight, instance.features) + } + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + val irls = new IterativelyReweightedLeastSquares(initial, PoissonReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances2) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } + + test("IRLS against L1Regression") { + /* + R code: + + library(quantreg) + + df <- as.data.frame(cbind(A, b)) + for (formula in c(b ~ . -1, b ~ .)) { + model <- rq(formula, data=df, weights=w) + print(as.vector(coef(model))) + } + + [1] 1.266667 0.400000 + [1] 29.5 17.0 -5.5 + */ + val expected = Seq( + Vectors.dense(0.0, 1.266667, 0.400000), + Vectors.dense(29.5, 17.0, -5.5)) + + import IterativelyReweightedLeastSquaresSuite._ + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(instances2) + val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc, + fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2) + val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + idx += 1 + } + } +} + +object IterativelyReweightedLeastSquaresSuite { + + def BinomialReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val mu = 1.0 / (1.0 + math.exp(-1.0 * eta)) + val z = eta + (instance.label - mu) / (mu * (1.0 - mu)) + val w = mu * (1 - mu) * instance.weight + (z, w) + } + + def PoissonReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val mu = math.exp(eta) + val z = eta + (instance.label - mu) / mu + val w = mu * instance.weight + (z, w) + } + + def L1RegressionReweightFunc( + instance: Instance, + model: WeightedLeastSquaresModel): (Double, Double) = { + val eta = model.predict(instance.features) + val e = math.max(math.abs(eta - instance.label), 1e-7) + val w = 1 / e + val y = instance.label + (y, w) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index b542ba3dc54d2..0b58a9821f57b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { private var instances: RDD[Instance] = _ + private var instancesConstLabel: RDD[Instance] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2) + + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(17, 17, 17, 17) + w <- c(1, 2, 3, 4) + */ + instancesConstLabel = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2) } test("WLS against lm") { @@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam = 0.0, standardizeFeatures = false, standardizeLabel = false) - .fit(instances) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + for (standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = standardization, + standardizeLabel = standardization).fit(instances) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } + idx += 1 + } + } + + test("WLS against lm when label is constant and no regularization") { + /* + R code: + + df.const.label <- as.data.frame(cbind(A, b.const)) + for (formula in c(b.const ~ . -1, b.const ~ .)) { + model <- lm(formula, data=df.const.label, weights=w) + print(as.vector(coef(model))) + } + + [1] -9.221298 3.394343 + [1] 17 0 0 + */ + + val expected = Seq( + Vectors.dense(0.0, -9.221298, 3.394343), + Vectors.dense(17.0, 0.0, 0.0)) + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + for (standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares( + fitIntercept, regParam = 0.0, standardizeFeatures = standardization, + standardizeLabel = standardization).fit(instancesConstLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } idx += 1 } } + test("WLS with regularization when label is constant") { + // if regParam is non-zero and standardization is true, the problem is ill-defined and + // an exception is thrown. + val wls = new WeightedLeastSquares( + fitIntercept = false, regParam = 0.1, standardizeFeatures = true, + standardizeLabel = true) + intercept[IllegalArgumentException]{ + wls.fit(instancesConstLabel) + } + } + test("WLS against glmnet") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index eeb03dba2f825..a3366c0e5934c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,7 +17,11 @@ package org.apache.spark.ml.param +import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream} + import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MyParams +import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { @@ -80,7 +84,7 @@ class ParamsSuite extends SparkFunSuite { } } - { // StringParam + { // Param[String] val param = new Param[String](dummy, "name", "doc") // Currently we do not support null. for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) { @@ -89,6 +93,19 @@ class ParamsSuite extends SparkFunSuite { } } + { // Param[Vector] + val param = new Param[Vector](dummy, "name", "doc") + val values = Seq( + Vectors.dense(Array.empty[Double]), + Vectors.dense(0.0, 2.0), + Vectors.sparse(0, Array.empty, Array.empty), + Vectors.sparse(2, Array(1), Array(2.0))) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + { // IntArrayParam val param = new IntArrayParam(dummy, "name", "doc") val values: Seq[Array[Int]] = Seq( @@ -138,7 +155,7 @@ class ParamsSuite extends SparkFunSuite { test("param") { val solver = new TestParams() val uid = solver.uid - import solver.{maxIter, inputCol} + import solver.{inputCol, maxIter} assert(maxIter.name === "maxIter") assert(maxIter.doc === "maximum number of iterations (>= 0)") @@ -181,7 +198,7 @@ class ParamsSuite extends SparkFunSuite { test("param map") { val solver = new TestParams() - import solver.{maxIter, inputCol} + import solver.{inputCol, maxIter} val map0 = ParamMap.empty @@ -220,7 +237,7 @@ class ParamsSuite extends SparkFunSuite { test("params") { val solver = new TestParams() - import solver.{handleInvalid, maxIter, inputCol} + import solver.{handleInvalid, inputCol, maxIter} val params = solver.params assert(params.length === 3) @@ -251,15 +268,10 @@ class ParamsSuite extends SparkFunSuite { solver.getParam("abc") } - intercept[IllegalArgumentException] { - solver.validateParams() - } - solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") - solver.validateParams() intercept[IllegalArgumentException] { ParamMap(maxIter -> -10) } @@ -335,6 +347,31 @@ class ParamsSuite extends SparkFunSuite { val t3 = t.copy(ParamMap(t.maxIter -> 20)) assert(t3.isSet(t3.maxIter)) } + + test("Filtering ParamMap") { + val params1 = new MyParams("my_params1") + val params2 = new MyParams("my_params2") + val paramMap = ParamMap( + params1.intParam -> 1, + params2.intParam -> 1, + params1.doubleParam -> 0.2, + params2.doubleParam -> 0.2) + val filteredParamMap = paramMap.filter(params1) + + assert(filteredParamMap.size === 2) + filteredParamMap.toSeq.foreach { + case ParamPair(p, _) => + assert(p.parent === params1.uid) + } + + // At the previous implementation of ParamMap#filter, + // mutable.Map#filterKeys was used internally but + // the return type of the method is not serializable (see SI-6654). + // Now mutable.Map#filter is used instead of filterKeys and the return type is serializable. + // So let's ensure serializability. + val objOut = new ObjectOutputStream(new ByteArrayOutputStream()) + objOut.writeObject(filteredParamMap) + } } object ParamsSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 9d23547f28447..7d990ce0bcfd8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid def clearMaxIter(): this.type = clear(maxIter) - override def validateParams(): Unit = { - super.validateParams() - require(isDefined(inputCol)) - } - override def copy(extra: ParamMap): TestParams = defaultCopy(extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index eadc80e0e62b1..dac76aa7a12c8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.recommendation -import java.io.File import java.util.Random import scala.collection.mutable @@ -26,28 +25,25 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.{Logging, SparkException, SparkFunSuite} +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.{DataFrame, Row} -class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - - private var tempDir: File = _ +class ALSSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() - tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) super.afterAll() } @@ -186,7 +182,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] var i = 0 - while (i < compressed.srcIds.size) { + while (i < compressed.srcIds.length) { var j = compressed.dstPtrs(i) while (j < compressed.dstPtrs(i + 1)) { val dstEncodedIndex = compressed.dstEncodedIndices(j) @@ -346,11 +342,10 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()) - .select("rating", "prediction") - .map { case Row(rating: Float, prediction: Float) => + val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { + case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) - } + } val rmse = if (implicitPrefs) { // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. @@ -483,4 +478,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true, seed = 0) } + + test("read/write") { + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val als = new ALS() + allEstimatorParamSettings.foreach { case (p, v) => + als.set(als.getParam(p), v) + } + val sqlContext = this.sqlContext + import sqlContext.implicits._ + val model = als.fit(ratings.toDF()) + + // Test Estimator save/load + val als2 = testDefaultReadWrite(als) + allEstimatorParamSettings.foreach { case (p, v) => + val param = als.getParam(p) + assert(als.get(param).get === als2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + allModelParamSettings.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + assert(model.rank === model2.rank) + def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { + df.select("id", "features").collect().map { case r => + (r.getInt(0), r.getAs[Array[Float]](1)) + }.toSet + } + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } +} + +object ALSSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allModelParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPredictionCol" + ) + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map( + "maxIter" -> 1, + "rank" -> 1, + "regParam" -> 0.01, + "numUserBlocks" -> 2, + "numItemBlocks" -> 2, + "implicitPrefs" -> true, + "alpha" -> 0.9, + "nonnegative" -> true, + "checkpointInterval" -> 20 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 359f31027172b..76891ad562811 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -21,17 +21,19 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} -class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class AFTSurvivalRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ + @transient var datasetUnivariateScaled: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -41,6 +43,24 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex datasetMultivariate = sqlContext.createDataFrame( sc.parallelize(generateAFTInput( 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + datasetUnivariateScaled = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => + AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) + }) + } + + /** + * Enable the ignored test to export the dataset into CSV format, + * so we can validate the training accuracy compared with R's survival package. + */ + ignore("export test data into CSV format") { + datasetUnivariate.rdd.map { case Row(features: Vector, label: Double, censor: Double) => + features.toArray.mkString(",") + "," + censor + "," + label + }.repartition(1).saveAsTextFile("target/tmp/AFTSurvivalRegressionSuite/datasetUnivariate") + datasetMultivariate.rdd.map { case Row(features: Vector, label: Double, censor: Double) => + features.toArray.mkString(",") + "," + censor + "," + label + }.repartition(1).saveAsTextFile("target/tmp/AFTSurvivalRegressionSuite/datasetMultivariate") } test("params") { @@ -332,4 +352,57 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex assert(prediction ~== model.predict(features) relTol 1E-5) } } + + test("should support all NumericType labels") { + val aft = new AFTSurvivalRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( + aft, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } + + test("numerical stability of standardization") { + val trainer = new AFTSurvivalRegression() + val model1 = trainer.fit(datasetUnivariate) + val model2 = trainer.fit(datasetUnivariateScaled) + + /** + * During training we standardize the dataset first, so no matter how we multiple + * a scaling factor into the dataset, the convergence rate should be the same, + * and the coefficients should equal to the original coefficients multiple by + * the scaling factor. It will have no effect on the intercept and scale. + */ + assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01) + assert(model1.intercept ~== model2.intercept absTol 0.01) + assert(model1.scale ~== model2.scale absTol 0.01) + } + + test("read/write") { + def checkModelData( + model: AFTSurvivalRegressionModel, + model2: AFTSurvivalRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + assert(model.scale === model2.scale) + } + val aft = new AFTSurvivalRegression() + testEstimatorAndModelReadWrite(aft, datasetMultivariate, + AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + } +} + +object AFTSurvivalRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "fitIntercept" -> true, + "maxIter" -> 2, + "tol" -> 0.01 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 868fb8eecb8bb..e9fb2677b215b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -18,17 +18,18 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} - -class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class DecisionTreeRegressorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import DecisionTreeRegressorSuite.compareAPIs @@ -49,7 +50,8 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + .setSeed(1) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } @@ -58,12 +60,12 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } test("copied model must have the same parent") { - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) val model = new DecisionTreeRegressor() .setImpurity("variance") @@ -72,11 +74,88 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex MLTestingUtils.checkCopy(model) } + test("predictVariance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + .setPredictionCol("") + .setVarianceCol("variance") + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = dt.fit(df) + + val predictions = model.transform(df) + .select(model.getFeaturesCol, model.getVarianceCol) + .collect() + + predictions.foreach { case Row(features: Vector, variance: Double) => + val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate() + assert(variance === expectedVariance, + s"Expected variance $expectedVariance but got $variance.") + } + } + + test("Feature importance with toy data") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = dt.fit(df) + + val importances = model.featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + + test("should support all NumericType labels and not support other types") { + val dt = new DecisionTreeRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( + dt, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: test("model save/load") SPARK-6725 + test("read/write") { + def checkModelData( + model: DecisionTreeRegressionModel, + model2: DecisionTreeRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + } + + val dt = new DecisionTreeRegressor() + val rdd = TreeTests.getTreeReadWriteData(sc) + + // Categorical splits with tree depth 2 + val categoricalData: DataFrame = + TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0) + testEstimatorAndModelReadWrite(dt, categoricalData, + TreeTests.allParamSettings, checkModelData) + + // Continuous splits with tree depth 2 + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(dt, continuousData, + TreeTests.allParamSettings, checkModelData) + + // Continuous splits with tree depth 0 + testEstimatorAndModelReadWrite(dt, continuousData, + TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) + } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 09326600e620f..216377959e090 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -29,11 +29,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.util.Utils - /** * Test suite for [[GBTRegressor]]. */ -class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTRegressorSuite.compareAPIs @@ -54,7 +54,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) } - test("Regression with continuous features: SquaredError") { + test("Regression with continuous features") { val categoricalFeatures = Map.empty[Int, Int] GBTRegressor.supportedLossTypes.foreach { loss => testCombinations.foreach { @@ -65,6 +65,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setLossType(loss) .setMaxIter(maxIter) .setStepSize(learningRate) + .setSeed(123) compareAPIs(data, None, gbt, categoricalFeatures) } } @@ -87,7 +88,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { // copied model must have the same parent. MLTestingUtils.checkCopy(model) val preds = model.transform(df) - val predictions = preds.select("prediction").map(_.getDouble(0)) + val predictions = preds.select("prediction").rdd.map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) assert(predictions.max() > 2) assert(predictions.min() < -1) @@ -104,11 +105,19 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { .setMaxIter(5) .setStepSize(0.1) .setCheckpointInterval(2) + .setSeed(123) val model = gbt.fit(df) sc.checkpointDir = None Utils.deleteRecursively(tempDir) + } + test("should support all NumericType labels and not support other types") { + val gbt = new GBTRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor]( + gbt, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } } // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 @@ -129,31 +138,49 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { } */ + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val gbt = new GBTRegressor() + .setMaxDepth(3) + .setMaxIter(5) + .setSubsamplingRate(1.0) + .setStepSize(0.5) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val importances = gbt.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) - val newModel = GBTRegressionModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = GBTRegressionModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTRegressionModel, + model2: GBTRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTRegressor() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTRegressorSuite extends SparkFunSuite { @@ -169,7 +196,7 @@ private object GBTRegressorSuite extends SparkFunSuite { categoricalFeatures: Map[Int, Int]): Unit = { val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val oldGBT = new OldGBT(oldBoostingStrategy) + val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala new file mode 100644 index 0000000000000..3ecc210abdfca --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -0,0 +1,1061 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vector, Vectors} +import org.apache.spark.mllib.random._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ + +class GeneralizedLinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + private val seed: Int = 42 + @transient var datasetGaussianIdentity: DataFrame = _ + @transient var datasetGaussianLog: DataFrame = _ + @transient var datasetGaussianInverse: DataFrame = _ + @transient var datasetBinomial: DataFrame = _ + @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonIdentity: DataFrame = _ + @transient var datasetPoissonSqrt: DataFrame = _ + @transient var datasetGammaInverse: DataFrame = _ + @transient var datasetGammaIdentity: DataFrame = _ + @transient var datasetGammaLog: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + import GeneralizedLinearRegressionSuite._ + + datasetGaussianIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "identity"), 2)) + + datasetGaussianLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "log"), 2)) + + datasetGaussianInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "inverse"), 2)) + + datasetBinomial = { + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, + addIntercept = true, nPoints, seed) + + sqlContext.createDataFrame(sc.parallelize(testData, 2)) + } + + datasetPoissonLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "log"), 2)) + + datasetPoissonIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "identity"), 2)) + + datasetPoissonSqrt = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "sqrt"), 2)) + + datasetGammaInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "inverse"), 2)) + + datasetGammaIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "identity"), 2)) + + datasetGammaLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "log"), 2)) + } + + /** + * Enable the ignored test to export the dataset into CSV format, + * so we can validate the training accuracy compared with R's glm and glmnet package. + */ + ignore("export test data into CSV format") { + datasetGaussianIdentity.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetGaussianIdentity") + datasetGaussianLog.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetGaussianLog") + datasetGaussianInverse.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetGaussianInverse") + datasetBinomial.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetBinomial") + datasetPoissonLog.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLog") + datasetPoissonIdentity.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonIdentity") + datasetPoissonSqrt.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonSqrt") + datasetGammaInverse.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetGammaInverse") + datasetGammaIdentity.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetGammaIdentity") + datasetGammaLog.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/GeneralizedLinearRegressionSuite/datasetGammaLog") + } + + test("params") { + ParamsSuite.checkParams(new GeneralizedLinearRegression) + val model = new GeneralizedLinearRegressionModel("genLinReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("generalized linear regression: default params") { + val glr = new GeneralizedLinearRegression + assert(glr.getLabelCol === "label") + assert(glr.getFeaturesCol === "features") + assert(glr.getPredictionCol === "prediction") + assert(glr.getFitIntercept) + assert(glr.getTol === 1E-6) + assert(glr.getWeightCol === "") + assert(glr.getRegParam === 0.0) + assert(glr.getSolver == "irls") + // TODO: Construct model directly instead of via fitting. + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + assert(model.getFamily === "gaussian") + assert(model.getLink === "identity") + } + + test("generalized linear regression: gaussian family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="gaussian", data=data) + print(as.vector(coef(model))) + } + + [1] 2.2960999 0.8087933 + [1] 2.5002642 2.2000403 0.5999485 + + data <- read.csv("path", header=FALSE) + model1 <- glm(f1, family=gaussian(link=log), data=data, start=c(0,0)) + model2 <- glm(f2, family=gaussian(link=log), data=data, start=c(0,0,0)) + print(as.vector(coef(model1))) + print(as.vector(coef(model2))) + + [1] 0.23069326 0.07993778 + [1] 0.25001858 0.22002452 0.05998789 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=gaussian(link=inverse), data=data) + print(as.vector(coef(model))) + } + + [1] 2.3010179 0.8198976 + [1] 2.4108902 2.2130248 0.6086152 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2960999, 0.8087933), + Vectors.dense(2.5002642, 2.2000403, 0.5999485), + Vectors.dense(0.0, 0.23069326, 0.07993778), + Vectors.dense(0.25001858, 0.22002452, 0.05998789), + Vectors.dense(0.0, 2.3010179, 0.8198976), + Vectors.dense(2.4108902, 2.2130248, 0.6086152)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("identity", datasetGaussianIdentity), ("log", datasetGaussianLog), + ("inverse", datasetGaussianInverse))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gaussian family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gaussian family against glmnet") { + /* + R code: + library(glmnet) + data <- read.csv("path", header=FALSE) + label = data$V1 + features = as.matrix(data.frame(data$V2, data$V3)) + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.0, 0.1, 1.0)) { + model <- glmnet(features, label, family="gaussian", intercept=intercept, + lambda=lambda, alpha=0, thresh=1E-14) + print(as.vector(coef(model))) + } + } + + [1] 0.0000000 2.2961005 0.8087932 + [1] 0.0000000 2.2130368 0.8309556 + [1] 0.0000000 1.7176137 0.9610657 + [1] 2.5002642 2.2000403 0.5999485 + [1] 3.1106389 2.0935142 0.5712711 + [1] 6.7597127 1.4581054 0.3994266 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2961005, 0.8087932), + Vectors.dense(0.0, 2.2130368, 0.8309556), + Vectors.dense(0.0, 1.7176137, 0.9610657), + Vectors.dense(2.5002642, 2.2000403, 0.5999485), + Vectors.dense(3.1106389, 2.0935142, 0.5712711), + Vectors.dense(6.7597127, 1.4581054, 0.3994266)) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.0, 0.1, 1.0)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian") + .setFitIntercept(fitIntercept).setRegParam(regParam) + val model = trainer.fit(datasetGaussianIdentity) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"fitIntercept = $fitIntercept and regParam = $regParam.") + + idx += 1 + } + } + + test("generalized linear regression: binomial family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 + data <- read.csv("path", header=FALSE) + + for (formula in c(f1, f2)) { + model <- glm(formula, family="binomial", data=data) + print(as.vector(coef(model))) + } + + [1] -0.3560284 1.3010002 -0.3570805 -0.7406762 + [1] 2.8367406 -0.5896187 0.8931655 -0.3925169 -0.7996989 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=probit), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2134390 0.7800646 -0.2144267 -0.4438358 + [1] 1.6995366 -0.3524694 0.5332651 -0.2352985 -0.4780850 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=cloglog), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2832198 0.8434144 -0.2524727 -0.5293452 + [1] 1.5063590 -0.4038015 0.6133664 -0.2687882 -0.5541758 + */ + val expected = Seq( + Vectors.dense(0.0, -0.3560284, 1.3010002, -0.3570805, -0.7406762), + Vectors.dense(2.8367406, -0.5896187, 0.8931655, -0.3925169, -0.7996989), + Vectors.dense(0.0, -0.2134390, 0.7800646, -0.2144267, -0.4438358), + Vectors.dense(1.6995366, -0.3524694, 0.5332651, -0.2352985, -0.4780850), + Vectors.dense(0.0, -0.2832198, 0.8434144, -0.2524727, -0.5293452), + Vectors.dense(1.5063590, -0.4038015, 0.6133664, -0.2687882, -0.5541758)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("logit", datasetBinomial), ("probit", datasetBinomial), + ("cloglog", datasetBinomial))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1), + model.coefficients(2), model.coefficients(3)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Binomial, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"binomial family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: poisson family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="poisson", data=data) + print(as.vector(coef(model))) + } + + [1] 0.22999393 0.08047088 + [1] 0.25022353 0.21998599 0.05998621 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2929501 0.8119415 + [1] 2.5012730 2.1999407 0.5999107 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=sqrt), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2958947 0.8090515 + [1] 2.5000480 2.1999972 0.5999968 + */ + val expected = Seq( + Vectors.dense(0.0, 0.22999393, 0.08047088), + Vectors.dense(0.25022353, 0.21998599, 0.05998621), + Vectors.dense(0.0, 2.2929501, 0.8119415), + Vectors.dense(2.5012730, 2.1999407, 0.5999107), + Vectors.dense(0.0, 2.2958947, 0.8090515), + Vectors.dense(2.5000480, 2.1999972, 0.5999968)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("log", datasetPoissonLog), ("identity", datasetPoissonIdentity), + ("sqrt", datasetPoissonSqrt))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Poisson, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"poisson family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gamma family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="Gamma", data=data) + print(as.vector(coef(model))) + } + + [1] 2.3392419 0.8058058 + [1] 2.3507700 2.2533574 0.6042991 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2908883 0.8147796 + [1] 2.5002406 2.1998346 0.6000059 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=log), data=data) + print(as.vector(coef(model))) + } + + [1] 0.22958970 0.08091066 + [1] 0.25003210 0.21996957 0.06000215 + */ + val expected = Seq( + Vectors.dense(0.0, 2.3392419, 0.8058058), + Vectors.dense(2.3507700, 2.2533574, 0.6042991), + Vectors.dense(0.0, 2.2908883, 0.8147796), + Vectors.dense(2.5002406, 2.1998346, 0.6000059), + Vectors.dense(0.0, 0.22958970, 0.08091066), + Vectors.dense(0.25003210, 0.21996957, 0.06000215)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("inverse", datasetGammaInverse), + ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gamma, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gamma family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("glm summary: gaussian family with weight") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + /* + R code: + + model <- glm(formula = "b ~ .", family="gaussian", data = df, weights = w) + summary(model) + + Deviance Residuals: + 1 2 3 4 + 1.920 -1.358 -1.109 0.960 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 18.080 9.608 1.882 0.311 + V1 6.080 5.556 1.094 0.471 + V2 -0.600 1.960 -0.306 0.811 + + (Dispersion parameter for gaussian family taken to be 7.68) + + Null deviance: 202.00 on 3 degrees of freedom + Residual deviance: 7.68 on 1 degrees of freedom + AIC: 18.783 + + Number of Fisher Scoring iterations: 2 + + residuals(model, type="pearson") + 1 2 3 4 + 1.920000 -1.357645 -1.108513 0.960000 + + residuals(model, type="working") + 1 2 3 4 + 1.92 -0.96 -0.64 0.48 + + residuals(model, type="response") + 1 2 3 4 + 1.92 -0.96 -0.64 0.48 + */ + val trainer = new GeneralizedLinearRegression() + .setWeightCol("weight") + + val model = trainer.fit(datasetWithWeight) + + val coefficientsR = Vectors.dense(Array(6.080, -0.600)) + val interceptR = 18.080 + val devianceResidualsR = Array(1.920, -1.358, -1.109, 0.960) + val pearsonResidualsR = Array(1.920000, -1.357645, -1.108513, 0.960000) + val workingResidualsR = Array(1.92, -0.96, -0.64, 0.48) + val responseResidualsR = Array(1.92, -0.96, -0.64, 0.48) + val seCoefR = Array(5.556, 1.960, 9.608) + val tValsR = Array(1.094, -0.306, 1.882) + val pValsR = Array(0.471, 0.811, 0.311) + val dispersionR = 7.68 + val nullDevianceR = 202.00 + val residualDevianceR = 7.68 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 18.783 + + val summary = model.summary + + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") + } + + test("glm summary: binomial family with weight") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) + b <- c(1, 0, 1, 0) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + ), 2)) + /* + R code: + + model <- glm(formula = "b ~ . -1", family="binomial", data = df, weights = w) + summary(model) + + Deviance Residuals: + 1 2 3 4 + 1.273 -1.437 2.533 -1.556 + + Coefficients: + Estimate Std. Error z value Pr(>|z|) + V1 -0.30217 0.46242 -0.653 0.513 + V2 -0.04452 0.37124 -0.120 0.905 + + (Dispersion parameter for binomial family taken to be 1) + + Null deviance: 13.863 on 4 degrees of freedom + Residual deviance: 12.524 on 2 degrees of freedom + AIC: 16.524 + + Number of Fisher Scoring iterations: 5 + + residuals(model, type="pearson") + 1 2 3 4 + 1.117731 -1.162962 2.395838 -1.189005 + + residuals(model, type="working") + 1 2 3 4 + 2.249324 -1.676240 2.913346 -1.353433 + + residuals(model, type="response") + 1 2 3 4 + 0.5554219 -0.4034267 0.6567520 -0.2611382 + */ + val trainer = new GeneralizedLinearRegression() + .setFamily("binomial") + .setWeightCol("weight") + .setFitIntercept(false) + + val model = trainer.fit(datasetWithWeight) + + val coefficientsR = Vectors.dense(Array(-0.30217, -0.04452)) + val interceptR = 0.0 + val devianceResidualsR = Array(1.273, -1.437, 2.533, -1.556) + val pearsonResidualsR = Array(1.117731, -1.162962, 2.395838, -1.189005) + val workingResidualsR = Array(2.249324, -1.676240, 2.913346, -1.353433) + val responseResidualsR = Array(0.5554219, -0.4034267, 0.6567520, -0.2611382) + val seCoefR = Array(0.46242, 0.37124) + val tValsR = Array(-0.653, -0.120) + val pValsR = Array(0.513, 0.905) + val dispersionR = 1.0 + val nullDevianceR = 13.863 + val residualDevianceR = 12.524 + val residualDegreeOfFreedomNullR = 4 + val residualDegreeOfFreedomR = 2 + val aicR = 16.524 + + val summary = model.summary + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") + } + + test("glm summary: poisson family with weight") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + /* + R code: + + model <- glm(formula = "b ~ .", family="poisson", data = df, weights = w) + summary(model) + + Deviance Residuals: + 1 2 3 4 + -0.28952 0.11048 0.14839 -0.07268 + + Coefficients: + Estimate Std. Error z value Pr(>|z|) + (Intercept) 6.2999 1.6086 3.916 8.99e-05 *** + V1 3.3241 1.0184 3.264 0.00110 ** + V2 -1.0818 0.3522 -3.071 0.00213 ** + --- + Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 + + (Dispersion parameter for poisson family taken to be 1) + + Null deviance: 15.38066 on 3 degrees of freedom + Residual deviance: 0.12333 on 1 degrees of freedom + AIC: 41.803 + + Number of Fisher Scoring iterations: 3 + + residuals(model, type="pearson") + 1 2 3 4 + -0.28043145 0.11099310 0.14963714 -0.07253611 + + residuals(model, type="working") + 1 2 3 4 + -0.17960679 0.02813593 0.05113852 -0.01201650 + + residuals(model, type="response") + 1 2 3 4 + -0.4378554 0.2189277 0.1459518 -0.1094638 + */ + val trainer = new GeneralizedLinearRegression() + .setFamily("poisson") + .setWeightCol("weight") + .setFitIntercept(true) + + val model = trainer.fit(datasetWithWeight) + + val coefficientsR = Vectors.dense(Array(3.3241, -1.0818)) + val interceptR = 6.2999 + val devianceResidualsR = Array(-0.28952, 0.11048, 0.14839, -0.07268) + val pearsonResidualsR = Array(-0.28043145, 0.11099310, 0.14963714, -0.07253611) + val workingResidualsR = Array(-0.17960679, 0.02813593, 0.05113852, -0.01201650) + val responseResidualsR = Array(-0.4378554, 0.2189277, 0.1459518, -0.1094638) + val seCoefR = Array(1.0184, 0.3522, 1.6086) + val tValsR = Array(3.264, -3.071, 3.916) + val pValsR = Array(0.00110, 0.00213, 0.00009) + val dispersionR = 1.0 + val nullDevianceR = 15.38066 + val residualDevianceR = 0.12333 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 41.803 + + val summary = model.summary + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") + } + + test("glm summary: gamma family with weight") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(2, 8, 3, 9) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + /* + R code: + + model <- glm(formula = "b ~ .", family="Gamma", data = df, weights = w) + summary(model) + + Deviance Residuals: + 1 2 3 4 + -0.26343 0.05761 0.12818 -0.03484 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) -0.81511 0.23449 -3.476 0.178 + V1 -0.72730 0.16137 -4.507 0.139 + V2 0.23894 0.05481 4.359 0.144 + + (Dispersion parameter for Gamma family taken to be 0.07986091) + + Null deviance: 2.937462 on 3 degrees of freedom + Residual deviance: 0.090358 on 1 degrees of freedom + AIC: 23.202 + + Number of Fisher Scoring iterations: 4 + + residuals(model, type="pearson") + 1 2 3 4 + -0.24082508 0.05839241 0.13135766 -0.03463621 + + residuals(model, type="working") + 1 2 3 4 + 0.091414181 -0.005374314 -0.027196998 0.001890910 + + residuals(model, type="response") + 1 2 3 4 + -0.6344390 0.3172195 0.2114797 -0.1586097 + */ + val trainer = new GeneralizedLinearRegression() + .setFamily("gamma") + .setWeightCol("weight") + + val model = trainer.fit(datasetWithWeight) + + val coefficientsR = Vectors.dense(Array(-0.72730, 0.23894)) + val interceptR = -0.81511 + val devianceResidualsR = Array(-0.26343, 0.05761, 0.12818, -0.03484) + val pearsonResidualsR = Array(-0.24082508, 0.05839241, 0.13135766, -0.03463621) + val workingResidualsR = Array(0.091414181, -0.005374314, -0.027196998, 0.001890910) + val responseResidualsR = Array(-0.6344390, 0.3172195, 0.2114797, -0.1586097) + val seCoefR = Array(0.16137, 0.05481, 0.23449) + val tValsR = Array(-4.507, 4.359, -3.476) + val pValsR = Array(0.139, 0.144, 0.178) + val dispersionR = 0.07986091 + val nullDevianceR = 2.937462 + val residualDevianceR = 0.090358 + val residualDegreeOfFreedomNullR = 3 + val residualDegreeOfFreedomR = 1 + val aicR = 23.202 + + val summary = model.summary + val devianceResiduals = summary.residuals() + .select(col("devianceResiduals")) + .collect() + .map(_.getDouble(0)) + val pearsonResiduals = summary.residuals("pearson") + .select(col("pearsonResiduals")) + .collect() + .map(_.getDouble(0)) + val workingResiduals = summary.residuals("working") + .select(col("workingResiduals")) + .collect() + .map(_.getDouble(0)) + val responseResiduals = summary.residuals("response") + .select(col("responseResiduals")) + .collect() + .map(_.getDouble(0)) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + pearsonResiduals.zip(pearsonResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + workingResiduals.zip(workingResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + responseResiduals.zip(responseResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + assert(summary.dispersion ~== dispersionR absTol 1E-3) + assert(summary.nullDeviance ~== nullDevianceR absTol 1E-3) + assert(summary.deviance ~== residualDevianceR absTol 1E-3) + assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) + assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) + assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") + } + + test("read/write") { + def checkModelData( + model: GeneralizedLinearRegressionModel, + model2: GeneralizedLinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients.toArray === model2.coefficients.toArray) + } + + val glr = new GeneralizedLinearRegression() + testEstimatorAndModelReadWrite(glr, datasetPoissonLog, + GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) + } + + test("should support all NumericType labels and not support other types") { + val glr = new GeneralizedLinearRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[ + GeneralizedLinearRegressionModel, GeneralizedLinearRegression]( + glr, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } + + test("glm accepts Dataset[LabeledPoint]") { + val context = sqlContext + import context.implicits._ + new GeneralizedLinearRegression() + .setFamily("gaussian") + .fit(datasetGaussianIdentity.as[LabeledPoint]) + } +} + +object GeneralizedLinearRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "family" -> "poisson", + "link" -> "log", + "fitIntercept" -> true, + "maxIter" -> 2, // intentionally small + "tol" -> 0.8, + "regParam" -> 0.01, + "predictionCol" -> "myPrediction") + + def generateGeneralizedLinearRegressionInput( + intercept: Double, + coefficients: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + noiseLevel: Double, + family: String, + link: String): Seq[LabeledPoint] = { + + val rnd = new Random(seed) + def rndElement(i: Int) = { + (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + } + val (generator, mean) = family match { + case "gaussian" => (new StandardNormalGenerator, 0.0) + case "poisson" => (new PoissonGenerator(1.0), 1.0) + case "gamma" => (new GammaGenerator(1.0, 1.0), 1.0) + } + generator.setSeed(seed) + + (0 until nPoints).map { _ => + val features = Vectors.dense(coefficients.indices.map(rndElement).toArray) + val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept + val mu = link match { + case "identity" => eta + case "log" => math.exp(eta) + case "sqrt" => math.pow(eta, 2.0) + case "inverse" => 1.0 / eta + } + val label = mu + noiseLevel * (generator.nextValue() - mean) + // Return LabeledPoints with DenseVector + LabeledPoint(label, features) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 59f4193abc8f0..3a10ad7ed060a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class IsotonicRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { sqlContext.createDataFrame( labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } @@ -44,7 +46,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val predictions = model .transform(dataset) - .select("prediction").map { case Row(pred) => + .select("prediction").rdd.map { case Row(pred) => pred }.collect() @@ -64,7 +66,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val predictions = model .transform(features) - .select("prediction").map { + .select("prediction").rdd.map { case Row(pred) => pred }.collect() @@ -158,10 +160,47 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val predictions = model .transform(features) - .select("prediction").map { + .select("prediction").rdd.map { case Row(pred) => pred }.collect() assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) } + + test("read/write") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + + def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = { + assert(model.boundaries === model2.boundaries) + assert(model.predictions === model2.predictions) + assert(model.isotonic === model2.isotonic) + } + + val ir = new IsotonicRegression() + testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, + checkModelData) + } + + test("should support all NumericType labels and not support other types") { + val ir = new IsotonicRegression() + MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression]( + ir, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.boundaries === actual.boundaries) + assert(expected.predictions === actual.predictions) + } + } +} + +object IsotonicRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "isotonic" -> true, + "featureIndex" -> 0 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a1d86fe8fedad..eb19d130939e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,33 +22,24 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ @transient var datasetWithSparseFeature: DataFrame = _ @transient var datasetWithWeight: DataFrame = _ + @transient var datasetWithWeightConstantLabel: DataFrame = _ + @transient var datasetWithWeightZeroLabel: DataFrame = _ - /* - In `LinearRegressionSuite`, we will make sure that the model trained by SparkML - is the same as the one trained by R's glmnet package. The following instruction - describes how to reproduce the data in R. - - import org.apache.spark.mllib.util.LinearDataGenerator - val data = - sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), - Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) - data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) - .saveAsTextFile("path") - */ override def beforeAll(): Unit = { super.beforeAll() datasetWithDenseFeature = sqlContext.createDataFrame( @@ -56,8 +47,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2)) /* - datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating - training model without intercept + datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing + but is useful for illustrating training model without intercept */ datasetWithDenseFeatureWithoutIntercept = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -70,9 +61,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val featureSize = 4100 datasetWithSparseFeature = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray, - xMean = Seq.fill(featureSize)(r.nextDouble).toArray, - xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200, + intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray, + xMean = Seq.fill(featureSize)(r.nextDouble()).toArray, + xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200, seed, eps = 0.1, sparsity = 0.7), 2)) /* @@ -90,6 +81,49 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2)) + + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b.const <- c(17, 17, 17, 17) + w <- c(1, 2, 3, 4) + df.const.label <- as.data.frame(cbind(A, b.const)) + */ + datasetWithWeightConstantLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + datasetWithWeightZeroLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) + } + + /** + * Enable the ignored test to export the dataset into CSV format, + * so we can validate the training accuracy compared with R's glmnet package. + */ + ignore("export test data into CSV format") { + datasetWithDenseFeature.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile("target/tmp/LinearRegressionSuite/datasetWithDenseFeature") + + datasetWithDenseFeatureWithoutIntercept.rdd.map { + case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile( + "target/tmp/LinearRegressionSuite/datasetWithDenseFeatureWithoutIntercept") + + datasetWithSparseFeature.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile("target/tmp/LinearRegressionSuite/datasetWithSparseFeature") } test("params") { @@ -183,19 +217,19 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.995908 - as.numeric.data.V3. 5.275131 + as.numeric.data.V2. 6.973403 + as.numeric.data.V3. 5.284370 */ - val coefficientsR = Vectors.dense(6.995908, 5.275131) + val coefficientsR = Vectors.dense(6.973403, 5.284370) - assert(model1.intercept ~== 0 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR relTol 1E-3) - assert(model2.intercept ~== 0 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR relTol 1E-3) + assert(model1.intercept ~== 0 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR relTol 1E-2) + assert(model2.intercept ~== 0 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) /* Then again with the data with no intercept: - > coefficientsWithourIntercept + > coefficientsWithoutIntercept 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -234,14 +268,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.24300 - as.numeric.data.V2. 4.024821 - as.numeric.data.V3. 6.679841 + (Intercept) 6.242284 + as.numeric.d1.V2. 4.019605 + as.numeric.d1.V3. 6.679538 */ - val interceptR1 = 6.24300 - val coefficientsR1 = Vectors.dense(4.024821, 6.679841) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + val interceptR1 = 6.242284 + val coefficientsR1 = Vectors.dense(4.019605, 6.679538) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, @@ -295,14 +329,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.299752 - as.numeric.data.V3. 4.772913 + as.numeric.data.V2. 6.272927 + as.numeric.data.V3. 4.782604 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(6.299752, 4.772913) + val coefficientsR1 = Vectors.dense(6.272927, 4.782604) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, @@ -311,14 +345,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.232193 - as.numeric.data.V3. 4.764229 + as.numeric.data.V2. 6.207817 + as.numeric.data.V3. 4.775780 */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(6.232193, 4.764229) + val coefficientsR2 = Vectors.dense(6.207817, 4.775780) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -346,15 +380,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 5.269376 - as.numeric.data.V2. 3.736216 - as.numeric.data.V3. 5.712356) + (Intercept) 5.260103 + as.numeric.d1.V2. 3.725522 + as.numeric.d1.V3. 5.711203 */ - val interceptR1 = 5.269376 - val coefficientsR1 = Vectors.dense(3.736216, 5.712356) + val interceptR1 = 5.260103 + val coefficientsR1 = Vectors.dense(3.725522, 5.711203) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, @@ -362,15 +396,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 5.791109 - as.numeric.data.V2. 3.435466 - as.numeric.data.V3. 5.910406 + (Intercept) 5.790885 + as.numeric.d1.V2. 3.432373 + as.numeric.d1.V3. 5.919196 */ - val interceptR2 = 5.791109 - val coefficientsR2 = Vectors.dense(3.435466, 5.910406) + val interceptR2 = 5.790885 + val coefficientsR2 = Vectors.dense(3.432373, 5.919196) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -397,15 +431,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.data.V2. 5.522875 - as.numeric.data.V3. 4.214502 + (Intercept) . + as.numeric.d1.V2. 5.493430 + as.numeric.d1.V3. 4.223082 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.522875, 4.214502) + val coefficientsR1 = Vectors.dense(5.493430, 4.223082) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, @@ -414,14 +448,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 5.263704 - as.numeric.data.V3. 4.187419 + as.numeric.d1.V2. 5.244324 + as.numeric.d1.V3. 4.203106 */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.263704, 4.187419) + val coefficientsR2 = Vectors.dense(5.244324, 4.203106) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -456,15 +490,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.324108 - as.numeric.data.V2. 3.168435 - as.numeric.data.V3. 5.200403 + (Intercept) 5.689855 + as.numeric.d1.V2. 3.661181 + as.numeric.d1.V3. 6.000274 */ - val interceptR1 = 5.696056 - val coefficientsR1 = Vectors.dense(3.670489, 6.001122) + val interceptR1 = 5.689855 + val coefficientsR1 = Vectors.dense(3.661181, 6.000274) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 @@ -472,15 +506,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.114723 - as.numeric.data.V2. 3.409937 - as.numeric.data.V3. 6.146531 + (Intercept) 6.113890 + as.numeric.d1.V2. 3.407021 + as.numeric.d1.V3. 6.152512 */ - val interceptR2 = 6.114723 - val coefficientsR2 = Vectors.dense(3.409937, 6.146531) + val interceptR2 = 6.113890 + val coefficientsR2 = Vectors.dense(3.407021, 6.152512) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -517,15 +551,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.dataM.V2. 5.673348 - as.numeric.dataM.V3. 4.322251 + (Intercept) . + as.numeric.d1.V2. 5.643748 + as.numeric.d1.V3. 4.331519 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.673348, 4.322251) + val coefficientsR1 = Vectors.dense(5.643748, 4.331519) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, @@ -534,14 +568,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 5.477988 - as.numeric.data.V3. 4.297622 + as.numeric.d1.V2. 5.455902 + as.numeric.d1.V3. 4.312266 + */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.477988, 4.297622) + val coefficientsR2 = Vectors.dense(5.455902, 4.312266) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -555,6 +590,86 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("linear regression model with constant label") { + /* + R code: + for (formula in c(b.const ~ . -1, b.const ~ .)) { + model <- lm(formula, data=df.const.label, weights=w) + print(as.vector(coef(model))) + } + [1] -9.221298 3.394343 + [1] 17 0 0 + */ + val expected = Seq( + Vectors.dense(0.0, -9.221298, 3.394343), + Vectors.dense(17.0, 0.0, 0.0)) + + Seq("auto", "l-bfgs", "normal").foreach { solver => + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightConstantLabel) + val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0), + model1.coefficients(1)) + assert(actual1 ~== expected(idx) absTol 1e-4) + + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightZeroLabel) + val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0), + model2.coefficients(1)) + assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4) + idx += 1 + } + } + } + + test("regularized linear regression through origin with constant label") { + // The problem is ill-defined if fitIntercept=false, regParam is non-zero. + // An exception is thrown in this case. + Seq("auto", "l-bfgs", "normal").foreach { solver => + for (standardization <- Seq(false, true)) { + val model = new LinearRegression().setFitIntercept(false) + .setRegParam(0.1).setStandardization(standardization).setSolver(solver) + intercept[IllegalArgumentException] { + model.fit(datasetWithWeightConstantLabel) + } + } + } + } + + test("linear regression with l-bfgs when training is not needed") { + // When label is constant, l-bfgs solver returns results without training. + // There are two possibilities: If the label is non-zero but constant, + // and fitIntercept is true, then the model return yMean as intercept without training. + // If label is all zeros, then all coefficients are zero regardless of fitIntercept, so + // no training is needed. + for (fitIntercept <- Seq(false, true)) { + for (standardization <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setStandardization(standardization) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightConstantLabel) + if (fitIntercept) { + assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightZeroLabel) + assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + } + } + test("linear regression model training summary") { Seq("auto", "l-bfgs", "normal").foreach { solver => val trainer = new LinearRegression().setSolver(solver) @@ -572,40 +687,67 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { // Validate that we re-insert a prediction column for evaluation val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames - assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf( + assert(datasetWithDenseFeature.schema.fieldNames.toSet.subsetOf( modelNoPredictionColFieldNames.toSet)) assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = datasetWithDenseFeature.select("features", "label") + .rdd .map { case Row(features: DenseVector, label: Double) => - val prediction = - features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + - model.intercept - label - prediction - } - .zip(model.summary.residuals.map(_.getDouble(0))) + val prediction = + features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + + model.intercept + label - prediction + } + .zip(model.summary.residuals.rdd.map(_.getDouble(0))) .collect() .foreach { case (manualResidual: Double, resultResidual: Double) => - assert(manualResidual ~== resultResidual relTol 1E-5) - } + assert(manualResidual ~== resultResidual relTol 1E-5) + } /* - Use the following R code to generate model training results. - - predictions <- predict(fit, newx=features) - residuals <- label - predictions - > mean(residuals^2) # MSE - [1] 0.009720325 - > mean(abs(residuals)) # MAD - [1] 0.07863206 - > cor(predictions, label)^2# r^2 - [,1] - s0 0.9998749 + # Use the following R code to generate model training results. + + # path/part-00000 is the file generated by running LinearDataGenerator.generateLinearInput + # as described before the beforeAll() method. + d1 <- read.csv("path/part-00000", header=FALSE, stringsAsFactors=FALSE) + fit <- glm(V1 ~ V2 + V3, data = d1, family = "gaussian") + names(f1)[1] = c("V2") + names(f1)[2] = c("V3") + f1 <- data.frame(as.numeric(d1$V2), as.numeric(d1$V3)) + predictions <- predict(fit, newdata=f1) + l1 <- as.numeric(d1$V1) + + residuals <- l1 - predictions + > mean(residuals^2) # MSE + [1] 0.00985449 + > mean(abs(residuals)) # MAD + [1] 0.07961668 + > cor(predictions, l1)^2 # r^2 + [1] 0.9998737 + + > summary(fit) + + Call: + glm(formula = V1 ~ V2 + V3, family = "gaussian", data = d1) + + Deviance Residuals: + Min 1Q Median 3Q Max + -0.47082 -0.06797 0.00002 0.06725 0.34635 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 6.3022157 0.0018600 3388 <2e-16 *** + V2 4.6982442 0.0011805 3980 <2e-16 *** + V3 7.1994344 0.0009044 7961 <2e-16 *** + --- + + .... */ - assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) - assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) - assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4) + assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) + assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) // Normal solver uses "WeightedLeastSquares". This algorithm does not generate // objective history because it does not run through iterations. @@ -617,17 +759,17 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .sliding(2) .forall(x => x(0) >= x(1))) } else { - // To clalify that the normal solver is used here. + // To clarify that the normal solver is used here. assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) - val devianceResidualsR = Array(-0.35566, 0.34504) - val seCoefR = Array(0.0011756, 0.0009032, 0.0018489) - val tValsR = Array(3998, 7971, 3407) + val devianceResidualsR = Array(-0.47082, 0.34635) + val seCoefR = Array(0.0011805, 0.0009044, 0.0018600) + val tValsR = Array(3980, 7961, 3388) val pValsR = Array(0, 0, 0) model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => - assert(x._1 ~== x._2 absTol 1E-5) } + assert(x._1 ~== x._2 absTol 1E-4) } model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => - assert(x._1 ~== x._2 absTol 1E-5) } + assert(x._1 ~== x._2 absTol 1E-4) } model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } } @@ -822,7 +964,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { V1 -3.7271 2.9032 -1.284 0.3279 V2 3.0100 0.6022 4.998 0.0378 * --- - Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 + Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 (Dispersion parameter for gaussian family taken to be 17.4376) @@ -854,4 +996,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } + + test("read/write") { + def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + } + val lr = new LinearRegression() + testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, + checkModelData) + } + + test("should support all NumericType labels and not support other types") { + val lr = new LinearRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( + lr, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } +} + +object LinearRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "solver" -> "l-bfgs" + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 7e751e4b553b6..ca400e1914518 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -18,9 +18,8 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -31,7 +30,8 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest{ import RandomForestRegressorSuite.compareAPIs @@ -82,49 +82,48 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setSeed(123) // In this data, feature 1 is very important. - val data: RDD[LabeledPoint] = sc.parallelize(Seq( - new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), - new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) - )) + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) val categoricalFeatures = Map.empty[Int, Int] val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) val model = rf.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + + test("should support all NumericType labels and not support other types") { + val rf = new RandomForestRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor]( + rf, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } } ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* - test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees) - val newModel = RandomForestRegressionModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = RandomForestRegressionModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + test("read/write") { + def checkModelData( + model: RandomForestRegressionModel, + model2: RandomForestRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val rf = new RandomForestRegressor().setNumTrees(2) + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) } - */ } private object RandomForestRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 997f574e51f6a..0bd14978b2bbf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -18,17 +18,19 @@ package org.apache.spark.ml.source.libsvm import java.io.File +import java.nio.charset.StandardCharsets -import com.google.common.base.Charsets import com.google.common.io.Files -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.SaveMode import org.apache.spark.util.Utils + class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { - var tempDir: File = _ + // Path for dataset var path: String = _ override def beforeAll(): Unit = { @@ -39,15 +41,18 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - tempDir = Utils.createTempDir() - val file = new File(tempDir, "part-00000") - Files.write(lines, file, Charsets.US_ASCII) - path = tempDir.toURI.toString + val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data") + val file = new File(dir, "part-00000") + Files.write(lines, file, StandardCharsets.UTF_8) + path = dir.toURI.toString } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) - super.afterAll() + try { + Utils.deleteRecursively(new File(path)) + } finally { + super.afterAll() + } } test("select as sparse vector") { @@ -79,4 +84,24 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { val v = row1.getAs[SparseVector](1) assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } + + test("write libsvm data and read it again") { + val df = sqlContext.read.format("libsvm").load(path) + val tempDir2 = new File(tempDir, "read_write_test") + val writepath = tempDir2.toURI.toString + // TODO: Remove requirement to coalesce by supporting multiple reads. + df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) + + val df2 = sqlContext.read.format("libsvm").load(writepath) + val row1 = df2.first() + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("write libsvm data failed due to invalid schema") { + val df = sqlContext.read.format("text").load(path) + intercept[SparkException] { + df.write.format("libsvm").save(path + "_2") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala similarity index 99% rename from mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala index 9d756da410325..77ab3d8bb75f7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala new file mode 100644 index 0000000000000..fecf372c3d843 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite} +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[GradientBoostedTrees]]. + */ +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + + test("runWithValidation stops early and performs better on a validation dataset") { + // Set numIterations large enough so that it stops early. + val numIterations = 20 + val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2) + val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2) + val trainDF = sqlContext.createDataFrame(trainRdd) + val validateDF = sqlContext.createDataFrame(validateRdd) + + val algos = Array(Regression, Regression, Classification) + val losses = Array(SquaredError, AbsoluteError, LogLoss) + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val (validateTrees, validateTreeWeights) = GradientBoostedTrees + .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L) + val numTrees = validateTrees.length + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (GradientBoostedTrees.computeError(remappedRdd, trees, treeWeights, loss), + GradientBoostedTrees.computeError(remappedRdd, validateTrees, + validateTreeWeights, loss)) + } else { + (GradientBoostedTrees.computeError(validateRdd, trees, treeWeights, loss), + GradientBoostedTrees.computeError(validateRdd, validateTrees, + validateTreeWeights, loss)) + } + } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. + // Note that convergenceTol is set to 0.0 + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validateRdd, trees, treeWeights, loss, algo) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index d5c238e9ae164..1719f9fab5345 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -17,12 +17,16 @@ package org.apache.spark.ml.tree.impl +import scala.collection.mutable + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.impl.TreeTests -import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node} +import org.apache.spark.ml.tree._ import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.collection.OpenHashMap @@ -34,6 +38,453 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestSuite.mapToVec + ///////////////////////////////////////////////////////////////////////////// + // Tests for split calculation + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification with continuous features: split calculation") { + val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + assert(splits(0).length === 99) + } + + test("Binary classification with binary (ordered) categorical features: split calculation") { + val arr = OldDTSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + test("Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category: split calculation") { + val arr = OldDTSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + test("find splits for a continuous feature") { + // find splits for normal case + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(6), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array.fill(200000)(math.random) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 5) + assert(fakeMetadata.numSplits(0) === 5) + assert(fakeMetadata.numBins(0) === 6) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits should not return identical splits + // when there are not enough split candidates, reduce the number of splits in metadata + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(5), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 3) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits when most samples close to the minimum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 2) + assert(splits(0) === 2.0) + assert(splits(1) === 3.0) + } + + // find splits when most samples close to the maximum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 1) + assert(splits(0) === 1.0) + } + } + + test("Multiclass classification with unordered categorical features: split calculations") { + val arr = OldDTSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + assert(splits(0).length === 3) + assert(metadata.numSplits(0) === 3) + assert(metadata.numBins(0) === 3) + assert(metadata.numSplits(1) === 3) + assert(metadata.numBins(1) === 3) + + // Expecting 2^2 - 1 = 3 splits per feature + def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = { + assert(s.featureIndex === featureIndex) + assert(s.isInstanceOf[CategoricalSplit]) + val s0 = s.asInstanceOf[CategoricalSplit] + assert(s0.leftCategories === leftCategories) + assert(s0.numCategories === 3) // for this unit test + } + // Feature 0 + checkCategoricalSplit(splits(0)(0), 0, Array(0.0)) + checkCategoricalSplit(splits(0)(1), 0, Array(1.0)) + checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0)) + // Feature 1 + checkCategoricalSplit(splits(1)(0), 1, Array(0.0)) + checkCategoricalSplit(splits(1)(1), 1, Array(1.0)) + checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0)) + } + + test("Multiclass classification with ordered categorical features: split calculations") { + val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + assert(arr.length === 3000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + // 2^(10-1) - 1 > 100, so categorical features will be ordered + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val splits = RandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of other algorithm internals + ///////////////////////////////////////////////////////////////////////////// + + test("extract categories from a number for multiclass classification") { + val l = RandomForest.extractMultiClassCategories(13, 10) + assert(l.length === 3) + assert(Seq(3.0, 2.0, 0.0) === l) + } + + test("Avoid aggregation on the last level") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val splits = RandomForest.findSplits(input, metadata, seed = 42) + + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false) + + val topNode = LearningNode.emptyNode(nodeIndex = 1) + assert(topNode.isLeaf === false) + assert(topNode.stats === null) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + RandomForest.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.stats !== null) + assert(topNode.stats.impurity > 0.0) + + // set impurity and predict for child nodes + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.stats.impurity === 0.0) + assert(topNode.rightChild.get.stats.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val splits = RandomForest.findSplits(input, metadata, seed = 42) + + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false) + + val topNode = LearningNode.emptyNode(nodeIndex = 1) + assert(topNode.isLeaf === false) + assert(topNode.stats === null) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + RandomForest.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.stats !== null) + assert(topNode.stats.impurity > 0.0) + + // set impurity and predict for child nodes + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.stats.impurity === 0.0) + assert(topNode.rightChild.get.stats.impurity === 0.0) + } + + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val input = sc.parallelize(arr) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) + + val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 42).head + model.rootNode match { + case n: InternalNode => n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit") + } + case _ => throw new AssertionError("model.rootNode was not an InternalNode") + } + } + + test("Second level node building with vs. without groups") { + val arr = OldDTSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + // For tree with 1 group + val strategy1 = + new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 1000) + // For tree with multiple groups + val strategy2 = + new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0) + + val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", + seed = 42).head + val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", + seed = 42).head + + def getChildren(rootNode: Node): Array[InternalNode] = rootNode match { + case n: InternalNode => + assert(n.leftChild.isInstanceOf[InternalNode]) + assert(n.rightChild.isInstanceOf[InternalNode]) + Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) + case _ => throw new AssertionError("rootNode was not an InternalNode") + } + + // Single group second level tree construction. + val children1 = getChildren(tree1.rootNode) + val children2 = getChildren(tree2.rootNode) + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until 2) { + assert(children1(i).gain > 0) + assert(children2(i).gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).impurity === children2(i).impurity) + assert(children1(i).impurityStats.stats === children2(i).impurityStats.stats) + assert(children1(i).leftChild.impurity === children2(i).leftChild.impurity) + assert(children1(i).rightChild.impurity === children2(i).rightChild.impurity) + assert(children1(i).prediction === children2(i).prediction) + } + } + + def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) { + val numFeatures = 50 + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) + val rdd = sc.parallelize(arr) + + // Select feature subset for top nodes. Return true if OK. + def checkFeatureSubsetStrategy( + numTrees: Int, + featureSubsetStrategy: String, + numFeaturesPerNode: Int): Unit = { + val seeds = Array(123, 5354, 230, 349867, 23987) + val maxMemoryUsage: Long = 128 * 1024L * 1024L + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) + seeds.foreach { seed => + val failString = s"Failed on test with:" + + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + + s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" + val nodeQueue = new mutable.Queue[(Int, LearningNode)]() + val topNodes: Array[LearningNode] = new Array[LearningNode](numTrees) + Range(0, numTrees).foreach { treeIndex => + topNodes(treeIndex) = LearningNode.emptyNode(nodeIndex = 1) + nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) + } + val rng = new scala.util.Random(seed = seed) + val (nodesForGroup: Map[Int, Array[LearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + + assert(nodesForGroup.size === numTrees, failString) + assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree + + if (numFeaturesPerNode == numFeatures) { + // featureSubset values should all be None + assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + failString) + } else { + // Check number of features. + assert(treeToNodeToIndexInfo.values.forall(_.values.forall( + _.featureSubset.get.length === numFeaturesPerNode)), failString) + } + } + } + + checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + + val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0") + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val integerStrategies = Array("1", "10", "100", "1000", "10000") + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") + for (invalidStrategy <- invalidStrategies) { + intercept[IllegalArgumentException]{ + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) + } + } + + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + for (invalidStrategy <- invalidStrategies) { + intercept[IllegalArgumentException]{ + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) + } + } + } + + test("Binary classification with continuous features: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + + test("Binary classification with continuous features and node Id cache: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + test("computeFeatureImportance, featureImportances") { /* Build tree for testing, with this structure: grandParent @@ -58,7 +509,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Test feature importance computed at different subtrees. def testNode(node: Node, expected: Map[Int, Double]): Unit = { val map = new OpenHashMap[Int, Double]() - RandomForest.computeFeatureImportance(node, map) + TreeEnsembleModel.computeFeatureImportance(node, map) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } @@ -80,7 +531,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) .asInstanceOf[DecisionTreeModel] } - val importances: Vector = RandomForest.featureImportances(trees, 2) + val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, (feature1importance / tree2norm) / 2.0) @@ -91,7 +542,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val map = new OpenHashMap[Int, Double]() map(0) = 1.0 map(2) = 2.0 - RandomForest.normalizeMapValues(map) + TreeEnsembleModel.normalizeMapValues(map) val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala new file mode 100644 index 0000000000000..b650a9f092b0e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext} + +private[ml] object TreeTests extends SparkFunSuite { + + /** + * Convert the given data to a DataFrame, and set the features and label metadata. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @return DataFrame with metadata + */ + def setMetadata( + data: RDD[LabeledPoint], + categoricalFeatures: Map[Int, Int], + numClasses: Int): DataFrame = { + val sqlContext = SQLContext.getOrCreate(data.sparkContext) + import sqlContext.implicits._ + val df = data.toDF() + val numFeatures = data.first().features.size + val featuresAttributes = Range(0, numFeatures).map { feature => + if (categoricalFeatures.contains(feature)) { + NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) + } else { + NumericAttribute.defaultAttr.withIndex(feature) + } + }.toArray + val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + df.select(df("features").as("features", featuresMetadata), + df("label").as("label", labelMetadata)) + } + + /** Java-friendly version of [[setMetadata()]] */ + def setMetadata( + data: JavaRDD[LabeledPoint], + categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], + numClasses: Int): DataFrame = { + setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numClasses) + } + + /** + * Set label metadata (particularly the number of classes) on a DataFrame. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @param labelColName Name of the label column on which to set the metadata. + * @return DataFrame with metadata + */ + def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): DataFrame = { + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName(labelColName) + } else { + NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + data.select(data("features"), data(labelColName).as(labelColName, labelMetadata)) + } + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + checkEqual(a.rootNode, b.rootNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendants are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.prediction === b.prediction) + assert(a.impurity === b.impurity) + (a, b) match { + case (aye: InternalNode, bee: InternalNode) => + assert(aye.split === bee.split) + checkEqual(aye.leftChild, bee.leftChild) + checkEqual(aye.rightChild, bee.rightChild) + case (aye: LeafNode, bee: LeafNode) => // do nothing + case _ => + throw new AssertionError("Found mismatched nodes") + } + } + + /** + * Check if the two models are exactly the same. + * If the models are not equal, this throws an exception. + */ + def checkEqual[M <: DecisionTreeModel](a: TreeEnsembleModel[M], b: TreeEnsembleModel[M]): Unit = { + try { + a.trees.zip(b.trees).foreach { case (treeA, treeB) => + TreeTests.checkEqual(treeA, treeB) + } + assert(a.treeWeights === b.treeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + + /** + * Helper method for constructing a tree for testing. + * Given left, right children, construct a parent node. + * @param split Split for parent node + * @return Parent node with children attached + */ + def buildParentNode(left: Node, right: Node, split: Split): Node = { + val leftImp = left.impurityStats + val rightImp = right.impurityStats + val parentImp = leftImp.copy.add(rightImp) + val leftWeight = leftImp.count / parentImp.count.toDouble + val rightWeight = rightImp.count / parentImp.count.toDouble + val gain = parentImp.calculate() - + (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) + val pred = parentImp.predict + new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + } + + /** + * Create some toy data for testing feature importances. + */ + def featureImportanceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + * + * This set of Params is for all Decision Tree-based models. + */ + val allParamSettings: Map[String, Any] = Map( + "checkpointInterval" -> 7, + "seed" -> 543L, + "maxDepth" -> 2, + "maxBins" -> 20, + "minInstancesPerNode" -> 2, + "minInfoGain" -> 1e-14, + "maxMemoryInMB" -> 257, + "cacheNodeIds" -> true + ) + + /** Data for tree read/write tests which produces a non-trivial tree. */ + def getTreeReadWriteData(sc: SparkContext): RDD[LabeledPoint] = { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 2.0))) + sc.parallelize(arr) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index cbe09292a0337..3e734aabc5544 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,25 +18,27 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.{Estimator, Model, Pipeline} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.{StructField, StructType} -class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class CrossValidatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() - val sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } @@ -94,8 +96,8 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(cvModel2.avgMetrics.length === lrParamMaps.length) } - test("validateParams should check estimatorParamMaps") { - import CrossValidatorSuite._ + test("transformSchema should check estimatorParamMaps") { + import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -108,30 +110,214 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) - cv.validateParams() // This should pass. + cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") cv.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.validateParams() + cv.transformSchema(new StructType()) } } + + test("read/write: CrossValidator with simple estimator") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with complex estimator") { + // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]] + val lrEvaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + + val lr = new LogisticRegression().setMaxIter(3) + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val lrcv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(lrEvaluator) + .setEstimatorParamMaps(lrParamMaps) + + val hashingTF = new HashingTF() + val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv)) + val paramMaps = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 20)) + .addGrid(lr.elasticNetParam, Array(0.0, 1.0)) + .build() + val evaluator = new BinaryClassificationEvaluator() + + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.getEstimator match { + case pipeline2: Pipeline => + assert(pipeline.uid === pipeline2.uid) + pipeline2.getStages match { + case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) => + assert(hashingTF.uid === hashingTF2.uid) + lrcv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded internal CrossValidator expected to be" + + s" LogisticRegression but found type ${other.getClass.getName}") + } + assert(lrcv.uid === lrcv2.uid) + assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) + CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + case other => + throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + + " but found: " + other.map(_.getClass.getName).mkString(", ")) + } + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" CrossValidator but found ${other.getClass.getName}") + } + } + + test("read/write: CrossValidator fails for extraneous Param") { + val lr = new LogisticRegression() + val lr2 = new LogisticRegression() + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .addGrid(lr2.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setEstimatorParamMaps(paramMaps) + withClue("CrossValidator.write failed to catch extraneous Param error") { + intercept[IllegalArgumentException] { + cv.write + } + } + } + + test("read/write: CrossValidatorModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6)) + cv.set(cv.estimator, lr) + .set(cv.evaluator, evaluator) + .set(cv.numFolds, 20) + .set(cv.estimatorParamMaps, paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getThreshold === lr2.getThreshold) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.bestModel match { + case lrModel2: LogisticRegressionModel => + assert(lrModel.uid === lrModel2.uid) + assert(lrModel.getThreshold === lrModel2.getThreshold) + assert(lrModel.coefficients === lrModel2.coefficients) + assert(lrModel.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + assert(cv.avgMetrics === cv2.avgMetrics) + } } -object CrossValidatorSuite { +object CrossValidatorSuite extends SparkFunSuite { + + /** + * Assert sequences of estimatorParamMaps are identical. + * Params must be simple types comparable with `===`. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + assert(pMap2(p) === v) + } + } + } abstract class MyModel extends Model[MyModel] class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def validateParams(): Unit = require($(inputCol).nonEmpty) - - override def fit(dataset: DataFrame): MyModel = { + override def fit(dataset: Dataset[_]): MyModel = { throw new UnsupportedOperationException } override def transformSchema(schema: StructType): StructType = { - throw new UnsupportedOperationException + require($(inputCol).nonEmpty) + schema } override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) @@ -139,7 +325,7 @@ object CrossValidatorSuite { class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = { + override def evaluate(dataset: Dataset[_]): Double = { throw new UnsupportedOperationException } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 5fb80091d0b4b..dbee47c8475d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -19,17 +19,20 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType -class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext { +class TrainValidationSplitSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("train validation with logistic regression") { val dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) @@ -45,6 +48,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(cv.getTrainRatio === 0.5) @@ -69,6 +73,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) @@ -83,7 +88,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext assert(cvModel2.validationMetrics.length === lrParamMaps.length) } - test("validateParams should check estimatorParamMaps") { + test("transformSchema should check estimatorParamMaps") { import TrainValidationSplitSuite._ val est = new MyEstimator("est") @@ -97,14 +102,54 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) - cv.validateParams() // This should pass. + cv.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") cv.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.validateParams() + cv.transformSchema(new StructType()) } } + + test("read/write: TrainValidationSplit") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val tvs = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(evaluator) + .setTrainRatio(0.5) + .setEstimatorParamMaps(paramMaps) + .setSeed(42L) + + val tvs2 = testDefaultReadWrite(tvs, testParams = false) + + assert(tvs.getTrainRatio === tvs2.getTrainRatio) + } + + test("read/write: TrainValidationSplitModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val tvs = new TrainValidationSplitModel("cvUid", lrModel, Array(0.3, 0.6)) + tvs.set(tvs.estimator, lr) + .set(tvs.evaluator, evaluator) + .set(tvs.trainRatio, 0.5) + .set(tvs.estimatorParamMaps, paramMaps) + .set(tvs.seed, 42L) + + val tvs2 = testDefaultReadWrite(tvs, testParams = false) + + assert(tvs.getTrainRatio === tvs2.getTrainRatio) + assert(tvs.validationMetrics === tvs2.validationMetrics) + } } object TrainValidationSplitSuite { @@ -113,14 +158,13 @@ object TrainValidationSplitSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def validateParams(): Unit = require($(inputCol).nonEmpty) - - override def fit(dataset: DataFrame): MyModel = { + override def fit(dataset: Dataset[_]): MyModel = { throw new UnsupportedOperationException } override def transformSchema(schema: StructType): StructType = { - throw new UnsupportedOperationException + require($(inputCol).nonEmpty) + schema } override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) @@ -128,7 +172,7 @@ object TrainValidationSplitSuite { class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = { + override def evaluate(dataset: Dataset[_]): Double = { throw new UnsupportedOperationException } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala new file mode 100644 index 0000000000000..7ebd7eb144632 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.scalatest.Suite + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset} + +trait DefaultReadWriteTest extends TempDirectory { self: Suite => + + /** + * Checks "overwrite" option and params. + * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name + * in order to avoid conflicts from multiple calls to this method. + * + * @param instance ML instance to test saving/loading + * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. + * @tparam T ML instance type + * @return Instance loaded from file + */ + def testDefaultReadWrite[T <: Params with MLWritable]( + instance: T, + testParams: Boolean = true): T = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("test") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.save(path) + intercept[IOException] { + instance.save(path) + } + instance.write.overwrite().save(path) + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] + val newInstance = loader.load(path) + + assert(newInstance.uid === instance.uid) + if (testParams) { + instance.params.foreach { p => + if (instance.isDefined(p)) { + (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { + case (Array(values), Array(newValues)) => + assert(values === newValues, s"Values do not match on param ${p.name}.") + case (value, newValue) => + assert(value === newValue, s"Values do not match on param ${p.name}.") + } + } else { + assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") + } + } + } + + val load = instance.getClass.getMethod("load", classOf[String]) + val another = load.invoke(instance, path).asInstanceOf[T] + assert(another.uid === instance.uid) + another + } + + /** + * Default test for Estimator, Model pairs: + * - Explicitly set Params, and train model + * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model + * - Check Params on Estimator and Model + * - Compare model data + * + * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. + * + * @param estimator Estimator to test + * @param dataset Dataset to pass to [[Estimator.fit()]] + * @param testParams Set of [[Param]] values to set in estimator + * @param checkModelData Method which takes the original and loaded [[Model]] and compares their + * data. This method does not need to check [[Param]] values. + * @tparam E Type of [[Estimator]] + * @tparam M Type of [[Model]] produced by estimator + */ + def testEstimatorAndModelReadWrite[ + E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( + estimator: E, + dataset: Dataset[_], + testParams: Map[String, Any], + checkModelData: (M, M) => Unit): Unit = { + // Set some Params to make sure set Params are serialized. + testParams.foreach { case (p, v) => + estimator.set(estimator.getParam(p), v) + } + val model = estimator.fit(dataset) + + // Test Estimator save/load + val estimator2 = testDefaultReadWrite(estimator) + testParams.foreach { case (p, v) => + val param = estimator.getParam(p) + assert(estimator.get(param).get === estimator2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + testParams.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + + checkModelData(model, model2) + } +} + +class MyParams(override val uid: String) extends Params with MLWritable { + + final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") + final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") + final val longParam: LongParam = new LongParam(this, "longParam", "doc") + final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc") + final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc") + final val doubleArrayParam: DoubleArrayParam = + new DoubleArrayParam(this, "doubleArrayParam", "doc") + final val stringArrayParam: StringArrayParam = + new StringArrayParam(this, "stringArrayParam", "doc") + + setDefault(intParamWithDefault -> 0) + set(intParam -> 1) + set(floatParam -> 2.0f) + set(doubleParam -> 3.0) + set(longParam -> 4L) + set(stringParam -> "5") + set(intArrayParam -> Array(6, 7)) + set(doubleArrayParam -> Array(8.0, 9.0)) + set(stringArrayParam -> Array("10", "11")) + + override def copy(extra: ParamMap): Params = defaultCopy(extra) + + override def write: MLWriter = new DefaultParamsWriter(this) +} + +object MyParams extends MLReadable[MyParams] { + + override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] + + override def load(path: String): MyParams = super.load(path) +} + +class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + test("default read/write") { + val myParams = new MyParams("my_params") + testDefaultReadWrite(myParams) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index d290cc9b06e73..810846051866c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -17,14 +17,96 @@ package org.apache.spark.ml.util -import org.apache.spark.ml.Model +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ -object MLTestingUtils { +object MLTestingUtils extends SparkFunSuite { def checkCopy(model: Model[_]): Unit = { val copied = model.copy(ParamMap.empty) .asInstanceOf[Model[_]] assert(copied.parent.uid == model.parent.uid) assert(copied.parent == model.parent) } + + def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( + estimator: T, + isClassification: Boolean, + sqlContext: SQLContext)(check: (M, M) => Unit): Unit = { + val dfs = if (isClassification) { + genClassifDFWithNumericLabelCol(sqlContext) + } else { + genRegressionDFWithNumericLabelCol(sqlContext) + } + val expected = estimator.fit(dfs(DoubleType)) + val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t))) + actuals.foreach(actual => check(expected, actual)) + + val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(dfWithStringLabels) + } + assert(thrown.getMessage contains + "Column label must be of type NumericType but was actually of type StringType") + } + + def genClassifDFWithNumericLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features"): Map[NumericType, DataFrame] = { + val df = sqlContext.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 1)), + (0, Vectors.dense(0, 2, 2)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF(labelColName, featuresColName) + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types.map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName))) + .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) } + .toMap + } + + def genRegressionDFWithNumericLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features", + censorColName: String = "censor"): Map[NumericType, DataFrame] = { + val df = sqlContext.createDataFrame(Seq( + (0, Vectors.dense(0)), + (1, Vectors.dense(1)), + (2, Vectors.dense(2)), + (3, Vectors.dense(3)), + (4, Vectors.dense(4)) + )).toDF(labelColName, featuresColName) + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types + .map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName))) + .map { case (t, d) => + t -> TreeTests.setMetadata(d, 0, labelColName).withColumn(censorColName, lit(0.0)) + } + .toMap + } + + def generateDFWithStringLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features", + censorColName: String = "censor"): DataFrame = + sqlContext.createDataFrame(Seq( + ("0", Vectors.dense(0, 2, 3), 0.0), + ("1", Vectors.dense(0, 3, 1), 1.0), + ("0", Vectors.dense(0, 2, 2), 0.0), + ("1", Vectors.dense(0, 3, 9), 1.0), + ("0", Vectors.dense(0, 2, 6), 0.0) + )).toDF(labelColName, featuresColName, censorColName) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala new file mode 100644 index 0000000000000..8f11bbc8e47af --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.File + +import org.scalatest.{BeforeAndAfterAll, Suite} + +import org.apache.spark.util.Utils + +/** + * Trait that creates a temporary directory before all tests and deletes it after all. + */ +trait TempDirectory extends BeforeAndAfterAll { self: Suite => + + private var _tempDir: File = _ + + /** Returns the temporary directory as a [[File]] instance. */ + protected def tempDir: File = _tempDir + + override def beforeAll(): Unit = { + super.beforeAll() + _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName) + } + + override def afterAll(): Unit = { + try { + Utils.deleteRecursively(_tempDir) + } finally { + super.afterAll() + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index 59944416d96a6..0eb839f20c003 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.mllib.api.python import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors} import org.apache.spark.mllib.recommendation.Rating +import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 8d14bb6572155..28fada7053d65 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -25,9 +25,11 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -170,6 +172,37 @@ object LogisticRegressionSuite { class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { + + @transient var binaryDataset: RDD[LabeledPoint] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + /* + Here is the instruction describing how to export the test data into CSV format + so we can validate the training accuracy compared with R's glmnet package. + + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( + coefficients, xMean, xVariance, true, nPoints, 42), 1) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + */ + binaryDataset = { + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = LogisticRegressionSuite.generateMultinomialLogisticInput( + coefficients, xMean, xVariance, true, nPoints, 42) + + sc.parallelize(testData, 2) + } + } + def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -215,6 +248,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w // Test if we can correctly learn A, B where Y = logistic(A + B*X) test("logistic regression with LBFGS") { + val updaters: List[Updater] = List(new SquaredL2Updater(), new L1Updater()) + updaters.foreach(testLBFGS) + } + + private def testLBFGS(myUpdater: Updater): Unit = { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -223,7 +261,15 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + // Override the updater + class LogisticRegressionWithLBFGSCustomUpdater + extends LogisticRegressionWithLBFGS { + override val optimizer = + new LBFGS(new LogisticGradient, myUpdater) + } + + val lr = new LogisticRegressionWithLBFGSCustomUpdater().setIntercept(true) val model = lr.run(testRDD) @@ -396,10 +442,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01) // Training data with different scales without feature standardization - // will not yield the same result in the scaled space due to poor - // convergence rate. - assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1) - assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) + // should still converge quickly since the model still uses standardization but + // simply modifies the regularization function. See regParamL1Fun and related + // inside of LogisticRegression + assert(modelB1.weights(0) ~== modelB2.weights(0) * 1.0E3 absTol 0.1) + assert(modelB1.weights(0) ~== modelB3.weights(0) * 1.0E6 absTol 0.1) } test("multinomial logistic regression with LBFGS") { @@ -449,7 +496,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0)) * - * The model weights of mutinomial logstic regression in R have `K` set of linear predictors + * The model weights of multinomial logistic regression in R have `K` set of linear predictors * for `K` classes classification problem; however, only `K-1` set is required if the first * outcome is chosen as a "pivot", and the other `K-1` outcomes are separately regressed against * the pivot outcome. This can be done by subtracting the first weights from those `K-1` set @@ -540,6 +587,322 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w } } + /** + * From Spark 2.0, MLlib LogisticRegressionWithLBFGS will call the LogisticRegression + * implementation in ML to train model. We copies test cases from ML to guarantee + * they produce the same result. + */ + test("binary logistic regression with intercept without regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.8366423 + data.V2 -0.5895848 + data.V3 0.8931147 + data.V4 -0.3925051 + data.V5 -0.7996864 + */ + val interceptR = 2.8366423 + val coefficientsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.weights ~= coefficientsR relTol 1E-3) + + // Without regularization, with or without feature scaling will converge to the same solution. + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.weights ~= coefficientsR relTol 1E-3) + } + + test("binary logistic regression without intercept without regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = + coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.3534996 + data.V3 1.2964482 + data.V4 -0.3571741 + data.V5 -0.7407946 + */ + val interceptR = 0.0 + val coefficientsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.weights ~= coefficientsR relTol 1E-2) + + // Without regularization, with or without feature scaling should converge to the same solution. + assert(model2.intercept ~== interceptR relTol 1E-3) + assert(model2.weights ~= coefficientsR relTol 1E-2) + } + + test("binary logistic regression with intercept with L1 regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true) + trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.05627428 + data.V2 . + data.V3 . + data.V4 -0.04325749 + data.V5 -0.02481551 + */ + val interceptR1 = -0.05627428 + val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) + + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.weights ~= coefficientsR1 absTol 2E-2) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + standardize=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.3722152 + data.V2 . + data.V3 . + data.V4 -0.1665453 + data.V5 . + */ + val interceptR2 = 0.3722152 + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) + + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.weights ~= coefficientsR2 absTol 1E-3) + } + + test("binary logistic regression without intercept with L1 regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true) + trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false) + trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.05189203 + data.V5 -0.03891782 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= coefficientsR1 absTol 1E-3) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE, standardize=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.08420782 + data.V5 . + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= coefficientsR2 absTol 1E-3) + } + + test("binary logistic regression with intercept with L2 regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true) + trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.15021751 + data.V2 -0.07251837 + data.V3 0.10724191 + data.V4 -0.04865309 + data.V5 -0.10062872 + */ + val interceptR1 = 0.15021751 + val coefficientsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) + + assert(model1.intercept ~== interceptR1 relTol 1E-3) + assert(model1.weights ~= coefficientsR1 relTol 1E-3) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + standardize=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.48657516 + data.V2 -0.05155371 + data.V3 0.02301057 + data.V4 -0.11482896 + data.V5 -0.06266838 + */ + val interceptR2 = 0.48657516 + val coefficientsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + + assert(model2.intercept ~== interceptR2 relTol 1E-3) + assert(model2.weights ~= coefficientsR2 relTol 1E-3) + } + + test("binary logistic regression without intercept with L2 regularization") { + val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true) + trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37) + val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false) + trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37) + + val model1 = trainer1.run(binaryDataset) + val model2 = trainer2.run(binaryDataset) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.06099165 + data.V3 0.12857058 + data.V4 -0.04708770 + data.V5 -0.09799775 + */ + val interceptR1 = 0.0 + val coefficientsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) + + assert(model1.intercept ~== interceptR1 absTol 1E-3) + assert(model1.weights ~= coefficientsR1 relTol 1E-2) + + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE, standardize=FALSE)) + coefficients + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.005679651 + data.V3 0.048967094 + data.V4 -0.093714016 + data.V5 -0.053314311 + */ + val interceptR2 = 0.0 + val coefficientsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + + assert(model2.intercept ~== interceptR2 absTol 1E-3) + assert(model2.weights ~= coefficientsR2 relTol 1E-2) + } + } class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index cffa1ab700f80..ab54cb06d5aab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV} import breeze.stats.distributions.{Multinomial => BrzMultinomial} +import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -103,17 +104,24 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { piData: Array[Double], thetaData: Array[Array[Double]], model: NaiveBayesModel): Unit = { - def closeFit(d1: Double, d2: Double, precision: Double): Boolean = { - (d1 - d2).abs <= precision - } - val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt)) - for (i <- modelIndex) { - assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05)) - } - for (i <- modelIndex) { - for (j <- 0 until thetaData(i._2).length) { - assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05)) + val modelIndex = piData.indices.zip(model.labels.map(_.toInt)) + try { + for (i <- modelIndex) { + assert(math.exp(piData(i._2)) ~== math.exp(model.pi(i._1)) absTol 0.05) + for (j <- thetaData(i._2).indices) { + assert(math.exp(thetaData(i._2)(j)) ~== math.exp(model.theta(i._1)(j)) absTol 0.05) + } } + } catch { + case e: TestFailedException => + def arr2str(a: Array[Double]): String = a.mkString("[", ", ", "]") + def msg(orig: String): String = orig + "\nvalidateModelFit:\n" + + " piData: " + arr2str(piData) + "\n" + + " thetaData: " + thetaData.map(arr2str).mkString("\n") + "\n" + + " model.labels: " + arr2str(model.labels) + "\n" + + " model.pi: " + arr2str(model.pi) + "\n" + + " model.theta: " + model.theta.map(arr2str).mkString("\n") + throw e.modifyMessage(_.map(msg)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index ee3c85d09a463..3676d9c5debc8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.classification import scala.collection.JavaConverters._ import scala.util.Random -import org.jblas.DoubleMatrix +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors @@ -45,12 +45,11 @@ object SVMSuite { nPoints: Int, seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights : _*) + val weightsMat = new BDV(weights) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => - val yD = new DoubleMatrix(1, xi.length, xi: _*).dot(weightsMat) + - intercept + 0.01 * rnd.nextGaussian() + val yD = new BDV(xi).dot(weightsMat) + intercept + 0.01 * rnd.nextGaussian() if (yD < 0) 0.0 else 1.0 } y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index d7b291d5a6330..bf98bf2f5fde5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.dstream.DStream class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala new file mode 100644 index 0000000000000..35f7932ae8224 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + +class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("default values") { + val bkm0 = new BisectingKMeans() + assert(bkm0.getK === 4) + assert(bkm0.getMaxIterations === 20) + assert(bkm0.getMinDivisibleClusterSize === 1.0) + val bkm1 = new BisectingKMeans() + assert(bkm0.getSeed === bkm1.getSeed, "The default seed should be constant.") + } + + test("setter/getter") { + val bkm = new BisectingKMeans() + + val k = 10 + assert(bkm.getK !== k) + assert(bkm.setK(k).getK === k) + val maxIter = 100 + assert(bkm.getMaxIterations !== maxIter) + assert(bkm.setMaxIterations(maxIter).getMaxIterations === maxIter) + val minSize = 2.0 + assert(bkm.getMinDivisibleClusterSize !== minSize) + assert(bkm.setMinDivisibleClusterSize(minSize).getMinDivisibleClusterSize === minSize) + val seed = 10L + assert(bkm.getSeed !== seed) + assert(bkm.setSeed(seed).getSeed === seed) + + intercept[IllegalArgumentException] { + bkm.setK(0) + } + intercept[IllegalArgumentException] { + bkm.setMaxIterations(0) + } + intercept[IllegalArgumentException] { + bkm.setMinDivisibleClusterSize(0.0) + } + } + + test("1D data") { + val points = Vectors.sparse(1, Array.empty, Array.empty) +: + (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(1) + .setSeed(1L) + // The clusters should be + // (0, 1, 2, 3, 4, 5, 6, 7) + // - (0, 1, 2, 3) + // - (0, 1) + // - (2, 3) + // - (4, 5, 6, 7) + // - (4, 5) + // - (6, 7) + val model = bkm.run(data) + assert(model.k === 4) + // The total cost should be 8 * 0.5 * 0.5 = 2.0. + assert(model.computeCost(data) ~== 2.0 relTol 1e-12) + val predictions = data.map(v => (v(0), model.predict(v))).collectAsMap() + Range(0, 8, 2).foreach { i => + assert(predictions(i) === predictions(i + 1), + s"$i and ${i + 1} should belong to the same cluster.") + } + val root = model.root + assert(root.center(0) ~== 3.5 relTol 1e-12) + assert(root.height ~== 2.0 relTol 1e-12) + assert(root.children.length === 2) + assert(root.children(0).height ~== 1.0 relTol 1e-12) + assert(root.children(1).height ~== 1.0 relTol 1e-12) + } + + test("points are the same") { + val data = sc.parallelize(Seq.fill(8)(Vectors.dense(1.0, 1.0)), 2) + val bkm = new BisectingKMeans() + .setK(2) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 1) + } + + test("more desired clusters than points") { + val data = sc.parallelize(Seq.tabulate(4)(i => Vectors.dense(i)), 2) + val bkm = new BisectingKMeans() + .setK(8) + .setMaxIterations(2) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 4) + } + + test("min divisible cluster") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMinDivisibleClusterSize(10) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + + bkm.setMinDivisibleClusterSize(0.5) + val sameModel = bkm.run(data) + assert(sameModel.k === 3) + } + + test("larger clusters get selected first") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + } + + test("2D data") { + val points = Seq( + (11, 10), (9, 10), (10, 9), (10, 11), + (11, -10), (9, -10), (10, -9), (10, -11), + (0, 1), (0, -1) + ).map { case (x, y) => + if (x == 0) { + Vectors.sparse(2, Array(1), Array(y)) + } else { + Vectors.dense(x, y) + } + } + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(4) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.root.center ~== Vectors.dense(8, 0) relTol 1e-12) + model.root.leafNodes.foreach { node => + if (node.center(0) < 5) { + assert(node.size === 2) + assert(node.center ~== Vectors.dense(0, 0) relTol 1e-12) + } else if (node.center(1) > 0) { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, 10) relTol 1e-12) + } else { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, -10) relTol 1e-12) + } + } + } + + test("BisectingKMeans model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val points = (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val model = new BisectingKMeans().run(data) + try { + model.save(sc, path) + val sameModel = BisectingKMeansModel.load(sc, path) + assert(model.k === sameModel.k) + model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index a72723eb00daf..67e680be73303 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrices} +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -182,7 +182,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) ) - val data2: Array[Vector] = Array.tabulate(25){ i: Int => + val data2: Array[Vector] = Array.tabulate(25) { i: Int => Vectors.dense(Array.tabulate(50)(i + _.toDouble)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 37fb69d68f6be..ea23196d2c801 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import java.util.{ArrayList => JArrayList} -import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} +import breeze.linalg.{argmax, argtopk, max, DenseMatrix => BDM} import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.Edge @@ -366,7 +366,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { (0, 0.99504), (1, 0.99504), (1, 0.99504), (1, 0.99504)) - val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) => + val actualPredictions = ldaModel.topicDistributions(docs).cache() + val topTopics = actualPredictions.map { case (id, topics) => // convert results to expectedPredictions format, which only has highest probability topic val topicsBz = topics.toBreeze.toDenseVector (id, (argmax(topicsBz), max(topicsBz))) @@ -374,9 +375,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { .values .collect() - expectedPredictions.zip(actualPredictions).forall { case (expected, actual) => - expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D) + expectedPredictions.zip(topTopics).foreach { case (expected, actual) => + assert(expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D)) } + + docs.collect() + .map(doc => ldaModel.topicDistribution(doc._2)) + .zip(actualPredictions.map(_._2).collect()) + .foreach { case (single, batch) => + assert(single ~== batch relTol 1E-3D) + } + actualPredictions.unpersist() } test("OnlineLDAOptimizer with asymmetric prior") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 189000512155f..3d81d375c716e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -30,62 +30,65 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon import org.apache.spark.mllib.clustering.PowerIterationClustering._ + /** Generates a circle of points. */ + private def genCircle(r: Double, n: Int): Array[(Double, Double)] = { + Array.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (r * math.cos(theta), r * math.sin(theta)) + } + } + + /** Computes Gaussian similarity. */ + private def sim(x: (Double, Double), y: (Double, Double)): Double = { + val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2) + math.exp(-dist2 / 2.0) + } + test("power iteration clustering") { - /* - We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for - edge (3, 4). - - 15-14 -13 -12 - | | - 4 . 3 - 2 11 - | | x | | - 5 0 - 1 10 - | | - 6 - 7 - 8 - 9 - */ + // Generate two circles following the example in the PIC paper. + val r1 = 1.0 + val n1 = 10 + val r2 = 4.0 + val n2 = 40 + val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + val similarities = for (i <- 1 until n; j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) + } - val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), - (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge - (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), - (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) val model = new PowerIterationClustering() .setK(2) + .setMaxIterations(40) .run(sc.parallelize(similarities, 2)) val predictions = Array.fill(2)(mutable.Set.empty[Long]) model.assignments.collect().foreach { a => predictions(a.cluster) += a.id } - assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) val model2 = new PowerIterationClustering() .setK(2) + .setMaxIterations(10) .setInitializationMode("degree") .run(sc.parallelize(similarities, 2)) val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) model2.assignments.collect().foreach { a => predictions2(a.cluster) += a.id } - assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + assert(predictions2.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) } test("power iteration clustering on graph") { - /* - We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for - edge (3, 4). - - 15-14 -13 -12 - | | - 4 . 3 - 2 11 - | | x | | - 5 0 - 1 10 - | | - 6 - 7 - 8 - 9 - */ - - val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), - (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge - (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), - (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + // Generate two circles following the example in the PIC paper. + val r1 = 1.0 + val n1 = 10 + val r2 = 4.0 + val n2 = 40 + val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + val similarities = for (i <- 1 until n; j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) + } val edges = similarities.flatMap { case (i, j, s) => if (i != j) { @@ -98,22 +101,24 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon val model = new PowerIterationClustering() .setK(2) + .setMaxIterations(40) .run(graph) val predictions = Array.fill(2)(mutable.Set.empty[Long]) model.assignments.collect().foreach { a => predictions(a.cluster) += a.id } - assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) val model2 = new PowerIterationClustering() .setK(2) + .setMaxIterations(10) .setInitializationMode("degree") .run(sc.parallelize(similarities, 2)) val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) model2.assignments.collect().foreach { a => predictions2(a.cluster) += a.id } - assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + assert(predictions2.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) } test("normalize and powerIter") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 3645d29dccdb2..65e37c64d404e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -98,9 +98,16 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { runStreams(ssc, numBatches, numBatches) // check that estimated centers are close to true centers - // NOTE exact assignment depends on the initialization! - assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) - assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + // cluster ordering is arbitrary, so choose closest cluster + val d0 = Vectors.sqdist(kMeans.latestModel().clusterCenters(0), centers(0)) + val d1 = Vectors.sqdist(kMeans.latestModel().clusterCenters(0), centers(1)) + val (c0, c1) = if (d0 < d1) { + (centers(0), centers(1)) + } else { + (centers(1), centers(0)) + } + assert(c0 ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(c1 ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) } test("detecting dying clusters") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index c0924a213a844..77ec49d005398 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 4b7f1be58f99b..f1d517383643d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -22,91 +22,115 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + val obs = List[Double](77, 85, 62, 55, 63, 88, 57, 81, 51) + val eps = 1E-5 test("regression metrics for unbiased (includes intercept term) predictor") { /* Verify results in R: - preds = c(2.25, -0.25, 1.75, 7.75) - obs = c(3.0, -0.5, 2.0, 7.0) - - SStot = sum((obs - mean(obs))^2) - SSreg = sum((preds - mean(obs))^2) - SSerr = sum((obs - preds)^2) - - explainedVariance = SSreg / length(obs) - explainedVariance - > [1] 8.796875 - meanAbsoluteError = mean(abs(preds - obs)) - meanAbsoluteError - > [1] 0.5 - meanSquaredError = mean((preds - obs)^2) - meanSquaredError - > [1] 0.3125 - rmse = sqrt(meanSquaredError) - rmse - > [1] 0.559017 - r2 = 1 - SSerr / SStot - r2 - > [1] 0.9571734 + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + x = c(16, 22, 14, 10, 13, 19, 12, 18, 11) + df <- as.data.frame(cbind(x, y)) + model <- lm(y ~ x, data=df) + preds = signif(predict(model), digits = 4) + preds + 1 2 3 4 5 6 7 8 9 + 72.08 91.88 65.48 52.28 62.18 81.98 58.88 78.68 55.58 + options(digits=8) + explainedVariance = mean((preds - mean(y))^2) + [1] 157.3 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 3.7355556 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 17.539511 + rmse = sqrt(meanSquaredError) + rmse + [1] 4.18802 + r2 = summary(model)$r.squared + r2 + [1] 0.89968225 */ - val predictionAndObservations = sc.parallelize( - Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2) + val preds = List(72.08, 91.88, 65.48, 52.28, 62.18, 81.98, 58.88, 78.68, 55.58) + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + assert(metrics.explainedVariance ~== 157.3 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 3.7355556 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 17.539511 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 4.18802 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.89968225 absTol eps, "r2 score mismatch") } test("regression metrics for biased (no intercept term) predictor") { /* Verify results in R: - preds = c(2.5, 0.0, 2.0, 8.0) - obs = c(3.0, -0.5, 2.0, 7.0) - - SStot = sum((obs - mean(obs))^2) - SSreg = sum((preds - mean(obs))^2) - SSerr = sum((obs - preds)^2) - - explainedVariance = SSreg / length(obs) - explainedVariance - > [1] 8.859375 - meanAbsoluteError = mean(abs(preds - obs)) - meanAbsoluteError - > [1] 0.5 - meanSquaredError = mean((preds - obs)^2) - meanSquaredError - > [1] 0.375 - rmse = sqrt(meanSquaredError) - rmse - > [1] 0.6123724 - r2 = 1 - SSerr / SStot - r2 - > [1] 0.9486081 + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + x = c(16, 22, 14, 10, 13, 19, 12, 18, 11) + df <- as.data.frame(cbind(x, y)) + model <- lm(y ~ 0 + x, data=df) + preds = signif(predict(model), digits = 4) + preds + 1 2 3 4 5 6 7 8 9 + 72.12 99.17 63.11 45.08 58.60 85.65 54.09 81.14 49.58 + options(digits=8) + explainedVariance = mean((preds - mean(y))^2) + explainedVariance + [1] 294.88167 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 4.5888889 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 39.958711 + rmse = sqrt(meanSquaredError) + rmse + [1] 6.3212903 + r2 = summary(model)$r.squared + r2 + [1] 0.99185395 */ - val predictionAndObservations = sc.parallelize( - Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) - val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5, + val preds = List(72.12, 99.17, 63.11, 45.08, 58.6, 85.65, 54.09, 81.14, 49.58) + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) + val metrics = new RegressionMetrics(predictionAndObservations, true) + assert(metrics.explainedVariance ~== 294.88167 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 4.5888889 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 39.958711 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 6.3212903 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.99185395 absTol eps, "r2 score mismatch") } test("regression metrics with complete fitting") { - val predictionAndObservations = sc.parallelize( - Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) + /* Verify results in R: + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + preds = y + explainedVariance = mean((preds - mean(y))^2) + explainedVariance + [1] 174.8395 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 0 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 0 + rmse = sqrt(meanSquaredError) + rmse + [1] 0 + r2 = 1 - sum((preds - y)^2)/sum((y - mean(y))^2) + r2 + [1] 1 + */ + val preds = obs + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5, + assert(metrics.explainedVariance ~== 174.83951 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 0.0 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.0 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.0 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index cf279c02334e9..6c07e3a5cef2e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -48,4 +49,15 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { val docs = sc.parallelize(localDocs, 2) assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet) } + + test("applying binary term freqs") { + val hashingTF = new HashingTF(100).setBinary(true) + val doc = "a a b c c c".split(" ") + val n = hashingTF.numFeatures + val expected = Vectors.sparse(n, Seq( + (hashingTF.indexOf("a"), 1.0), + (hashingTF.indexOf("b"), 1.0), + (hashingTF.indexOf("c"), 1.0))) + assert(hashingTF.transform(doc) ~== expected absTol 1e-14) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 21163633051e5..5c938a61ed990 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index e57f49191378f..a8d82932d3904 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -37,11 +37,12 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { val pca = new PCA(k).fit(dataRDD) val mat = new RowMatrix(dataRDD) - val pc = mat.computePrincipalComponents(k) + val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) val pca_transform = pca.transform(dataRDD).collect() val mat_multiply = mat.multiply(pc).rows.collect() assert(pca_transform.toSet === mat_multiply.toSet) + assert(pca.explainedVariance === explainedVariance) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 6ab2fa6770123..b4e26b2aeb3cf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index a864eec460f2b..4fcf417d5f82e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext - -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -92,4 +90,23 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("big model load / save") { + // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25 + val word2VecMap = Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala index 77a2773c36f56..dcb1f398b04b8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -42,6 +42,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { .collect() /* Verify results using the `R` code: + library(arules) transactions = as(sapply( list("r z h k p", "z y x w v u t s", @@ -52,7 +53,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { FUN=function(x) strsplit(x," ",fixed=TRUE)), "transactions") ars = apriori(transactions, - parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + parameter = list(support = 0.5, confidence = 0.9, target="rules", minlen=2)) arsDF = as(ars, "data.frame") arsDF$support = arsDF$support * length(transactions) names(arsDF)[names(arsDF) == "support"] = "freq" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 4a9bfdb348d9f..dc44c58e97eb4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -16,8 +16,11 @@ */ package org.apache.spark.mllib.fpm +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -274,4 +277,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { */ assert(model1.freqItemsets.count() === 65) } + + test("model save/load with String type") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val model3 = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model3.save(sc, path) + val newModel = FPGrowthModel.load(sc, path) + val newFreqItemsets = newModel.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + assert(freqItemsets3.toSet === newFreqItemsets.toSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load with Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val model3 = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model3.save(sc, path) + val newModel = FPGrowthModel.load(sc, path) + val newFreqItemsets = newModel.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + assert(freqItemsets3.toSet === newFreqItemsets.toSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index a83e543859b8a..6d8c7b47d8373 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("model save/load") { + val sequences = Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6))) + val rdd = sc.parallelize(sequences, 2).cache() + + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + val model = prefixSpan.run(rdd) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model.save(sc, path) + val newModel = PrefixSpanModel.load(sc, path) + val originalSet = model.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + val newSet = newModel.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + assert(originalSet === newSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + private def compareResults[Item]( expectedValue: Array[(Array[Array[Item]], Long)], actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 96e5ffef7a131..80da03cc2efeb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.linalg import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.util.TestingUtils._ class BLASSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index dc04258e41d27..de2c3c13bd923 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.linalg -import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} +import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM} import org.apache.spark.SparkFunSuite diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 1833cf3833671..e289724cdaa3c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg import java.util.Random +import breeze.linalg.{CSCMatrix, Matrix => BM} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} @@ -150,6 +151,10 @@ class MatricesSuite extends SparkFunSuite { sparseMat.update(0, 0, 10.0) } + intercept[NoSuchElementException] { + sparseMat.update(2, 1, 10.0) + } + sparseMat.update(0, 1, 10.0) assert(sparseMat(0, 1) === 10.0) assert(sparseMat.values(2) === 10.0) @@ -494,4 +499,28 @@ class MatricesSuite extends SparkFunSuite { assert(sm1.numNonzeros === 1) assert(sm1.numActives === 3) } + + test("fromBreeze with sparse matrix") { + // colPtr.last does NOT always equal to values.length in breeze SCSMatrix and + // invocation of compact() may be necessary. Refer to SPARK-11507 + val bm1: BM[Double] = new CSCMatrix[Double]( + Array(1.0, 1, 1), 3, 3, Array(0, 1, 2, 3), Array(0, 1, 2)) + val bm2: BM[Double] = new CSCMatrix[Double]( + Array(1.0, 2, 2, 4), 3, 3, Array(0, 0, 2, 4), Array(1, 2, 1, 2)) + val sum = bm1 + bm2 + Matrices.fromBreeze(sum) + } + + test("row/col iterator") { + val dm = new DenseMatrix(3, 2, Array(0, 1, 2, 3, 4, 0)) + val sm = dm.toSparse + val rows = Seq(Vectors.dense(0, 3), Vectors.dense(1, 4), Vectors.dense(2, 0)) + val cols = Seq(Vectors.dense(0, 1, 2), Vectors.dense(3, 4, 0)) + for (m <- Seq(dm, sm)) { + assert(m.rowIter.toSeq === rows) + assert(m.colIter.toSeq === cols) + assert(m.transpose.rowIter.toSeq === cols) + assert(m.transpose.colIter.toSeq === rows) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 6508ddeba4206..e5567492a2c76 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.mllib.linalg import scala.util.Random -import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} +import breeze.linalg.{squaredDistance => breezeSquaredDistance, DenseMatrix => BDM} +import org.json4s.jackson.JsonMethods.{parse => parseJson} -import org.apache.spark.{Logging, SparkException, SparkFunSuite} +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.mllib.util.TestingUtils._ class VectorsSuite extends SparkFunSuite with Logging { @@ -374,4 +376,20 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) } + + test("toJson/fromJson") { + val sv0 = Vectors.sparse(0, Array.empty, Array.empty) + val sv1 = Vectors.sparse(1, Array.empty, Array.empty) + val sv2 = Vectors.sparse(2, Array(1), Array(2.0)) + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0) + val dv2 = Vectors.dense(0.0, 2.0) + for (v <- Seq(sv0, sv1, sv2, dv0, dv1, dv2)) { + val json = v.toJson + parseJson(json) // `json` should be a valid JSON string + val u = Vectors.fromJson(json) + assert(u.getClass === v.getClass, "toJson/fromJson should preserve vector types.") + assert(u === v, "toJson/fromJson should preserve vector values.") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index b8eb10305801c..f37eaf225ab88 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.linalg.distributed import java.{util => ju} -import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Matrices, Matrix, SparseMatrix, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -134,6 +134,33 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(rowMat.numRows() === m) assert(rowMat.numCols() === n) assert(rowMat.toBreeze() === gridBasedMat.toBreeze()) + + val rows = 1 + val cols = 10 + + val matDense = new DenseMatrix(rows, cols, + Array(1.0, 1.0, 3.0, 2.0, 5.0, 6.0, 7.0, 1.0, 2.0, 3.0)) + val matSparse = new SparseMatrix(rows, cols, + Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1), Array(0), Array(1.0)) + + val vectors: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), matDense), + ((1, 0), matSparse)) + + val rdd = sc.parallelize(vectors) + val B = new BlockMatrix(rdd, rows, cols) + + val C = B.toIndexedRowMatrix.rows.collect + + (C(0).vector.toBreeze, C(1).vector.toBreeze) match { + case (denseVector: BDV[Double], sparseVector: BSV[Double]) => + assert(denseVector.length === sparseVector.length) + + assert(matDense.toArray === denseVector.toArray) + assert(matSparse.toArray === sparseVector.toArray) + case _ => + throw new RuntimeException("IndexedRow returns vectors of unexpected type") + } } test("toBreeze and toLocalMatrix") { @@ -192,6 +219,49 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(sparseBM.add(sparseBM).toBreeze() === sparseBM.add(denseBM).toBreeze()) } + test("subtract") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 0), new DenseMatrix(1, 2, Array(1.0, 0.0))), // Added block that doesn't exist in A + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val B = new BlockMatrix(rdd, rowPerPart, colPerPart) + + val expected = BDM( + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 0.0), + (-1.0, 0.0, 0.0, 0.0)) + + val AsubtractB = gridBasedMat.subtract(B) + assert(AsubtractB.numRows() === m) + assert(AsubtractB.numCols() === B.numCols()) + assert(AsubtractB.toBreeze() === expected) + + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m, n + 1) // columns don't match + intercept[IllegalArgumentException] { + gridBasedMat.subtract(C) + } + val largerBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(4, 4, new Array[Double](16))), + ((1, 0), new DenseMatrix(1, 4, Array(1.0, 0.0, 1.0, 5.0)))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4, m, n) + intercept[SparkException] { // partitioning doesn't match + gridBasedMat.subtract(C2) + } + // subtracting BlockMatrices composed of SparseMatrices + val sparseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), SparseMatrix.speye(4)) + val denseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), DenseMatrix.eye(4)) + val sparseBM = new BlockMatrix(sc.makeRDD(sparseBlocks, 4), 4, 4, 8, 8) + val denseBM = new BlockMatrix(sc.makeRDD(denseBlocks, 4), 4, 4, 8, 8) + + assert(sparseBM.subtract(sparseBM).toBreeze() === sparseBM.subtract(denseBM).toBreeze()) + } + test("multiply") { // identity matrix val blocks: Seq[((Int, Int), Matrix)] = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index f3728cd036a3f..37d75103d18d2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 6de6cf2fa8634..5b7ccb90158b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrices, Vectors} class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 4abb98fb6fe4e..2dff52c601d81 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.mllib.linalg.distributed +import java.util.Arrays + import scala.util.Random +import breeze.linalg.{norm => brzNorm, svd => brzSvd, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics.abs -import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -49,6 +51,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { (0.0, 1.0, 0.0), (math.sqrt(2.0) / 2.0, 0.0, math.sqrt(2.0) / 2.0), (math.sqrt(2.0) / 2.0, 0.0, - math.sqrt(2.0) / 2.0)) + val explainedVariance = BDV(4.0 / 7.0, 3.0 / 7.0, 0.0) var denseMat: RowMatrix = _ var sparseMat: RowMatrix = _ @@ -201,10 +204,15 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { test("pca") { for (mat <- Seq(denseMat, sparseMat); k <- 1 to n) { - val pc = denseMat.computePrincipalComponents(k) + val (pc, expVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) assert(pc.numRows === n) assert(pc.numCols === k) assertColumnEqualUpToSign(pc.toBreeze.asInstanceOf[BDM[Double]], principalComponents, k) + assert( + closeToZero(BDV(expVariance.toArray) - + BDV(Arrays.copyOfRange(explainedVariance.data, 0, k)))) + // Check that this method returns the same answer + assert(pc === mat.computePrincipalComponents(k)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 36ac7d267243d..1c9b7c78e5b8d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext, MLUtils} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index d8f9b8c33963d..4ec3dc0df03b5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -19,28 +19,22 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.jblas.{DoubleMatrix, SimpleBlas} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ class NNLSSuite extends SparkFunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ - def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { - val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) - val b = A.mmul(DoubleMatrix.ones(n, 1)) - - val ata = A.transpose.mmul(A) - val atb = A.transpose.mmul(b) - - (ata, atb) + def genOnesData(n: Int, rand: Random): (BDM[Double], BDV[Double]) = { + val A = new BDM(n, n, Array.fill(n*n)(rand.nextDouble())) + val b = A * new BDV(Array.fill(n)(1.0)) + (A.t * A, A.t * b) } /** Compute the objective value */ - def computeObjectiveValue(ata: DoubleMatrix, atb: DoubleMatrix, x: DoubleMatrix): Double = { - val res = (x.transpose().mmul(ata).mmul(x)).mul(0.5).sub(atb.dot(x)) - res.get(0) - } + def computeObjectiveValue(ata: BDM[Double], atb: BDV[Double], x: BDV[Double]): Double = + (x.t * ata * x) / 2.0 - atb.dot(x) test("NNLS: exact solution cases") { val n = 20 @@ -54,12 +48,15 @@ class NNLSSuite extends SparkFunSuite { for (k <- 0 until 100) { val (ata, atb) = genOnesData(n, rand) - val x = new DoubleMatrix(NNLS.solve(ata.data, atb.data, ws)) + val x = new BDV(NNLS.solve(ata.data, atb.data, ws)) assert(x.length === n) - val answer = DoubleMatrix.ones(n, 1) - SimpleBlas.axpy(-1.0, answer, x) - val solved = (x.norm2 < 1e-2) && (x.normmax < 1e-3) - if (solved) numSolved = numSolved + 1 + val answer = new BDV(Array.fill(n)(1.0)) + val solved = + (breeze.linalg.norm(x - answer) < 0.01) && // L2 norm + ((x - answer).toArray.map(_.abs).max < 0.001) // inf norm + if (solved) { + numSolved += 1 + } } assert(numSolved > 50) @@ -67,20 +64,18 @@ class NNLSSuite extends SparkFunSuite { test("NNLS: nonnegativity constraint active") { val n = 5 - // scalastyle:off - val ata = new DoubleMatrix(Array( - Array( 4.377, -3.531, -1.306, -0.139, 3.418), - Array(-3.531, 4.344, 0.934, 0.305, -2.140), - Array(-1.306, 0.934, 2.644, -0.203, -0.170), - Array(-0.139, 0.305, -0.203, 5.883, 1.428), - Array( 3.418, -2.140, -0.170, 1.428, 4.684))) - // scalastyle:on - val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636)) + val ata = Array( + 4.377, -3.531, -1.306, -0.139, 3.418, + -3.531, 4.344, 0.934, 0.305, -2.140, + -1.306, 0.934, 2.644, -0.203, -0.170, + -0.139, 0.305, -0.203, 5.883, 1.428, + 3.418, -2.140, -0.170, 1.428, 4.684) + val atb = Array(-1.632, 2.115, 1.094, -1.025, -0.636) val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) val ws = NNLS.createWorkspace(n) - val x = NNLS.solve(ata.data, atb.data, ws) + val x = NNLS.solve(ata, atb, ws) for (i <- 0 until n) { assert(x(i) ~== goodx(i) absTol 1E-3) assert(x(i) >= 0) @@ -89,23 +84,21 @@ class NNLSSuite extends SparkFunSuite { test("NNLS: objective value test") { val n = 5 - val ata = new DoubleMatrix(5, 5 - , 517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283 - , 242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884 - , -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049 - , 130802.84503, 81785.36128, -45401.12659, 67457.31310, -253747.03819 - , -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814 - ) - val atb = new DoubleMatrix(5, 1, - -31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017) + val ata = new BDM(5, 5, Array( + 517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283, + 242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884, + -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049, + 130802.84503, 81785.36128, -45401.12659, 67457.31310, -253747.03819, + -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814)) + val atb = new BDV(Array(-31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017)) /** reference solution obtained from matlab function quadprog */ - val refx = new DoubleMatrix(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627)) + val refx = new BDV(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627)) val refObj = computeObjectiveValue(ata, atb, refx) val ws = NNLS.createWorkspace(n) - val x = new DoubleMatrix(NNLS.solve(ata.data, atb.data, ws)) + val x = new BDV(NNLS.solve(ata.data, atb.data, ws)) val obj = computeObjectiveValue(ata, atb, x) assert(obj < refObj + 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 413db2000d6d7..f464d25c3fbda 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -20,9 +20,8 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} +import org.apache.spark.mllib.rdd.{RandomRDD, RandomRDDPartition} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.util.StatCounter diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala index 10f5a2be48f7c..56231429859ee 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ +import org.apache.spark.mllib.util.MLlibTestSparkContext class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index bc64172614830..0e931fca6cf07 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.util.MLlibTestSparkContext class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -28,9 +28,12 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { for (numPartitions <- 1 to 8) { val rdd = sc.parallelize(data, numPartitions) for (windowSize <- 1 to 6) { - val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList - val expected = data.sliding(windowSize).map(_.toList).toList - assert(sliding === expected) + for (step <- 1 to 3) { + val sliding = rdd.sliding(windowSize, step).collect().map(_.toList).toList + val expected = data.sliding(windowSize, step) + .map(_.toList).toList.filter(l => l.size == windowSize) + assert(sliding === expected) + } } assert(rdd.sliding(7).collect().isEmpty, "Should return an empty RDD if the window size is greater than the number of items.") @@ -40,7 +43,7 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding with empty partitions") { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) - assert(rdd.partitions.size === data.length) + assert(rdd.partitions.length === data.length) val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 045135f7f8d60..d9dc557e3b2b9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import scala.math.abs import scala.util.Random -import org.jblas.DoubleMatrix +import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -29,16 +29,16 @@ import org.apache.spark.storage.StorageLevel object ALSSuite { - def generateRatingsAsJavaList( + def generateRatingsAsJava( users: Int, products: Int, features: Int, samplingRate: Double, implicitPrefs: Boolean, - negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { + negativeWeights: Boolean): (java.util.List[Rating], Array[Double], Array[Double]) = { val (sampledRatings, trueRatings, truePrefs) = - generateRatings(users, products, features, samplingRate, implicitPrefs) - (sampledRatings.asJava, trueRatings, truePrefs) + generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights) + (sampledRatings.asJava, trueRatings.toArray, if (truePrefs == null) null else truePrefs.toArray) } def generateRatings( @@ -48,35 +48,36 @@ object ALSSuite { samplingRate: Double, implicitPrefs: Boolean = false, negativeWeights: Boolean = false, - negativeFactors: Boolean = true): (Seq[Rating], DoubleMatrix, DoubleMatrix) = { + negativeFactors: Boolean = true): (Seq[Rating], BDM[Double], BDM[Double]) = { val rand = new Random(42) // Create a random matrix with uniform values from -1 to 1 def randomMatrix(m: Int, n: Int) = { if (negativeFactors) { - new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*) + new BDM(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1)) } else { - new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble()): _*) + new BDM(m, n, Array.fill(m * n)(rand.nextDouble())) } } val userMatrix = randomMatrix(users, features) val productMatrix = randomMatrix(features, products) - val (trueRatings, truePrefs) = implicitPrefs match { - case true => + val (trueRatings, truePrefs) = + if (implicitPrefs) { // Generate raw values from [0,9], or if negativeWeights, from [-2,7] - val raw = new DoubleMatrix(users, products, + val raw = new BDM(users, products, Array.fill(users * products)( - (if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*) + (if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble)) val prefs = - new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*) + new BDM(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0)) (raw, prefs) - case false => (userMatrix.mmul(productMatrix), null) - } + } else { + (userMatrix * productMatrix, null) + } val sampledRatings = { for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate) - yield Rating(u, p, trueRatings.get(u, p)) + yield Rating(u, p, trueRatings(u, p)) } (sampledRatings, trueRatings, truePrefs) @@ -149,8 +150,8 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(1) .setFinalRDDStorageLevel(storageLevel) .run(ratings) - assert(model.productFeatures.getStorageLevel == storageLevel); - assert(model.userFeatures.getStorageLevel == storageLevel); + assert(model.productFeatures.getStorageLevel == storageLevel) + assert(model.userFeatures.getStorageLevel == storageLevel) storageLevel = StorageLevel.DISK_ONLY model = new ALS() .setRank(5) @@ -160,8 +161,8 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(1) .setFinalRDDStorageLevel(storageLevel) .run(ratings) - assert(model.productFeatures.getStorageLevel == storageLevel); - assert(model.userFeatures.getStorageLevel == storageLevel); + assert(model.productFeatures.getStorageLevel == storageLevel) + assert(model.userFeatures.getStorageLevel == storageLevel) } test("negative ids") { @@ -178,7 +179,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { val u = r.user + 25 val p = r.product + 25 val v = r.rating - val error = v - correct.get(u, p) + val error = v - correct(u, p) assert(math.abs(error) < 0.4) } } @@ -197,7 +198,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { * @param samplingRate what fraction of the user-product pairs are known * @param matchThreshold max difference allowed to consider a predicted rating correct * @param implicitPrefs flag to test implicit feedback - * @param bulkPredict flag to test bulk predicition + * @param bulkPredict flag to test bulk prediction * @param negativeWeights whether the generated data can contain negative values * @param numUserBlocks number of user blocks to partition users into * @param numProductBlocks number of product blocks to partition products into @@ -234,30 +235,31 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { .setNonnegative(!negativeFactors) .run(sc.parallelize(sampledRatings)) - val predictedU = new DoubleMatrix(users, features) + val predictedU = new BDM[Double](users, features) for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { - predictedU.put(u, i, vec(i)) + predictedU(u, i) = vec(i) } - val predictedP = new DoubleMatrix(products, features) + val predictedP = new BDM[Double](products, features) for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) { - predictedP.put(p, i, vec(i)) + predictedP(p, i) = vec(i) } - val predictedRatings = bulkPredict match { - case false => predictedU.mmul(predictedP.transpose) - case true => - val allRatings = new DoubleMatrix(users, products) + val predictedRatings = + if (bulkPredict) { + val allRatings = new BDM[Double](users, products) val usersProducts = for (u <- 0 until users; p <- 0 until products) yield (u, p) val userProductsRDD = sc.parallelize(usersProducts) model.predict(userProductsRDD).collect().foreach { elem => - allRatings.put(elem.user, elem.product, elem.rating) + allRatings(elem.user, elem.product) = elem.rating } allRatings - } + } else { + predictedU * predictedP.t + } if (!implicitPrefs) { for (u <- 0 until users; p <- 0 until products) { - val prediction = predictedRatings.get(u, p) - val correct = trueRatings.get(u, p) + val prediction = predictedRatings(u, p) + val correct = trueRatings(u, p) if (math.abs(prediction - correct) > matchThreshold) { fail(("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s") .format(u, p, correct, prediction, trueRatings, predictedRatings, predictedU, @@ -269,9 +271,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { var sqErr = 0.0 var denom = 0.0 for (u <- 0 until users; p <- 0 until products) { - val prediction = predictedRatings.get(u, p) - val truePref = truePrefs.get(u, p) - val confidence = 1 + 1.0 * abs(trueRatings.get(u, p)) + val prediction = predictedRatings(u, p) + val truePref = truePrefs(u, p) + val confidence = 1.0 + abs(trueRatings(u, p)) val err = confidence * (truePref - prediction) * (truePref - prediction) sqErr += err denom += confidence diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 39537e7bb4c72..d96103d01e4ab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index f88a1c33c9f7c..0694079b9df9e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 7a781fee634c8..815be32d2e510 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -19,11 +19,9 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.jblas.DoubleMatrix - import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils @@ -38,7 +36,7 @@ class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => (prediction - expected.label) * (prediction - expected.label) - }.reduceLeft(_ + _) / predictions.size + }.sum / predictions.size } test("ridge regression can help avoid overfitting") { @@ -49,12 +47,12 @@ class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val numExamples = 50 val numFeatures = 20 - org.jblas.util.Random.seed(42) // Pick weights as random values distributed uniformly in [-0.5, 0.5] - val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5) + val random = new Random(42) + val w = Array.fill(numFeatures)(random.nextDouble() - 0.5) // Use half of data for training and other half for validation - val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0) + val data = LinearDataGenerator.generateLinearInput(3.0, w, 2 * numExamples, 42, 10.0) val testData = data.take(numExamples) val validationData = data.takeRight(numExamples) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index c3eeda012571c..eaa819c2e6e39 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.mllib.stat import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 142b90e764a7c..46fcebe132749 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -144,7 +144,7 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(chi.size === numCols) assert(chi(1000) != null) // SPARK-3087 - // Detect continous features or labels + // Detect continuous features or labels val random = new Random(11L) val continuousLabel = Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index d3e9ef4ff079c..0921fdba339ca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.mllib.stat import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest, StreamingTestResult, + StudentTTest, WelchTTest} import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter @@ -26,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { - override def maxWaitTimeMillis : Int = 30000 + override def maxWaitTimeMillis: Int = 30000 test("accuracy for null hypothesis using welch t-test") { // set parameters @@ -48,7 +49,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -75,7 +76,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -102,7 +103,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) @@ -130,7 +131,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -157,13 +158,13 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, - (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + (inputDStream: DStream[BinarySample]) => model.summarizeByKeyAndWindow(inputDStream)) val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) val outputCounts = outputBatches.flatten.map(_._2.count) // number of batches seen so far does not exceed testWindow, expect counts to continue growing for (i <- 0 until testWindow) { - assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2)) + assert(outputCounts.slice(2 * i, 2 * i + 2).forall(_ == (i + 1) * pointsPerBatch / 2)) } // number of batches seen exceeds testWindow, expect counts to be constant @@ -190,7 +191,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.dropPeacePeriod(inputDStream)) val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) @@ -210,11 +211,11 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { .setPeacePeriod(0) val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) - .map(batch => batch.filter(_._1)) // only keep one test group + .map(batch => batch.filter(_.isExperiment)) // only keep one test group // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) @@ -228,13 +229,13 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { stdevA: Double, meanB: Double, stdevB: Double, - seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + seed: Int): (IndexedSeq[IndexedSeq[BinarySample]]) = { val rand = new XORShiftRandom(seed) val numTrues = pointsPerBatch / 2 val data = (0 until numBatches).map { i => - (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (0 until numTrues).map { idx => BinarySample(true, meanA + stdevA * rand.nextGaussian())} ++ (pointsPerBatch / 2 until pointsPerBatch).map { idx => - (false, meanB + stdevB * rand.nextGaussian()) + BinarySample(false, meanB + stdevB * rand.nextGaussian()) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index 6e7a003475458..669d44223d713 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat.distribution import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{ Vectors, Matrices } +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 1a4299db4eab2..49cb7e1f24e35 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ -import scala.collection.mutable import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -35,378 +34,6 @@ import org.apache.spark.util.Utils class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { - ///////////////////////////////////////////////////////////////////////////// - // Tests examining individual elements of training - ///////////////////////////////////////////////////////////////////////////// - - test("Binary classification with continuous features: split and bin calculation") { - val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(!metadata.isUnordered(featureIndex = 0)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - assert(splits(0).length === 99) - assert(bins(0).length === 100) - } - - test("Binary classification with binary (ordered) categorical features:" + - " split and bin calculation") { - val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 2, - numClasses = 2, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(!metadata.isUnordered(featureIndex = 0)) - assert(!metadata.isUnordered(featureIndex = 1)) - assert(splits.length === 2) - assert(bins.length === 2) - // no bins or splits pre-computed for ordered categorical features - assert(splits(0).length === 0) - assert(bins(0).length === 0) - } - - test("Binary classification with 3-ary (ordered) categorical features," + - " with no samples for one category") { - val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 2, - numClasses = 2, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(!metadata.isUnordered(featureIndex = 0)) - assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - // no bins or splits pre-computed for ordered categorical features - assert(splits(0).length === 0) - assert(bins(0).length === 0) - } - - test("extract categories from a number for multiclass classification") { - val l = DecisionTree.extractMultiClassCategories(13, 10) - assert(l.length === 3) - assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) - } - - test("find splits for a continuous feature") { - // find splits for normal case - { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array.fill(200000)(math.random) - val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 5) - assert(fakeMetadata.numSplits(0) === 5) - assert(fakeMetadata.numBins(0) === 6) - // check returned splits are distinct - assert(splits.distinct.length === splits.length) - } - - // find splits should not return identical splits - // when there are not enough split candidates, reduce the number of splits in metadata - { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) - val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 3) - // check returned splits are distinct - assert(splits.distinct.length === splits.length) - } - - // find splits when most samples close to the minimum - { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) - val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 2) - assert(splits(0) === 2.0) - assert(splits(1) === 3.0) - } - - // find splits when most samples close to the maximum - { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, - Map(), Set(), - Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 - ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) - val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits.length === 1) - assert(splits(0) === 1.0) - } - } - - test("Multiclass classification with unordered categorical features:" + - " split and bin calculations") { - val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 2, - numClasses = 100, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(metadata.isUnordered(featureIndex = 0)) - assert(metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - assert(splits(0).length === 3) - assert(bins(0).length === 0) - - // Expecting 2^2 - 1 = 3 bins/splits - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(0.0)) - assert(splits(1)(0).feature === 1) - assert(splits(1)(0).threshold === Double.MinValue) - assert(splits(1)(0).featureType === Categorical) - assert(splits(1)(0).categories.length === 1) - assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 1) - assert(splits(0)(1).categories.contains(1.0)) - assert(splits(1)(1).feature === 1) - assert(splits(1)(1).threshold === Double.MinValue) - assert(splits(1)(1).featureType === Categorical) - assert(splits(1)(1).categories.length === 1) - assert(splits(1)(1).categories.contains(1.0)) - - assert(splits(0)(2).feature === 0) - assert(splits(0)(2).threshold === Double.MinValue) - assert(splits(0)(2).featureType === Categorical) - assert(splits(0)(2).categories.length === 2) - assert(splits(0)(2).categories.contains(0.0)) - assert(splits(0)(2).categories.contains(1.0)) - assert(splits(1)(2).feature === 1) - assert(splits(1)(2).threshold === Double.MinValue) - assert(splits(1)(2).featureType === Categorical) - assert(splits(1)(2).categories.length === 2) - assert(splits(1)(2).categories.contains(0.0)) - assert(splits(1)(2).categories.contains(1.0)) - - } - - test("Multiclass classification with ordered categorical features: split and bin calculations") { - val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - assert(arr.length === 3000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 2, - numClasses = 100, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - // 2^(10-1) - 1 > 100, so categorical features will be ordered - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(!metadata.isUnordered(featureIndex = 0)) - assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - // no bins or splits pre-computed for ordered categorical features - assert(splits(0).length === 0) - assert(bins(0).length === 0) - } - - test("Avoid aggregation on the last level") { - val arr = Array( - LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } - - test("Avoid aggregation if impurity is 0.0") { - val arr = Array( - LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } - - test("Second level node building with vs. without groups") { - val arr = DecisionTreeSuite.generateOrderedLabeledPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - - // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, - numClasses = 2, maxBins = 100) - val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNode1 = modelOneNode.topNode.deepCopy() - val rootNode2 = modelOneNode.topNode.deepCopy() - assert(rootNode1.leftNode.nonEmpty) - assert(rootNode1.rightNode.nonEmpty) - - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - // Single group second level tree construction. - val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) - val treeToNodeToIndexInfo = Map((0, Map( - (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), - (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - val children1 = new Array[Node](2) - children1(0) = rootNode1.leftNode.get - children1(1) = rootNode1.rightNode.get - - // Train one second-level node at a time. - val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) - val treeToNodeToIndexInfoA = Map((0, Map( - (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) - val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) - val treeToNodeToIndexInfoB = Map((0, Map( - (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) - val children2 = new Array[Node](2) - children2(0) = rootNode2.leftNode.get - children2(1) = rootNode2.rightNode.get - - // Verify whether the splits obtained using single group and multiple group level - // construction strategies are the same. - for (i <- 0 until 2) { - assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) - assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) - assert(children1(i).split === children2(i).split) - assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) - val stats1 = children1(i).stats.get - val stats2 = children2(i).stats.get - assert(stats1.gain === stats2.gain) - assert(stats1.impurity === stats2.impurity) - assert(stats1.leftImpurity === stats2.leftImpurity) - assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict.predict === children2(i).predict.predict) - } - } - ///////////////////////////////////////////////////////////////////////////// // Tests calling train() ///////////////////////////////////////////////////////////////////////////// @@ -421,24 +48,13 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - assert(!metadata.isUnordered(featureIndex = 0)) - assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(bins.length === 2) - // no bins or splits pre-computed for ordered categorical features - assert(splits(0).length === 0) - assert(bins(0).length === 0) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.categories === List(1.0)) assert(split.featureType === Categorical) - assert(split.threshold === Double.MinValue) val stats = rootNode.stats.get assert(stats.gain > 0) @@ -455,7 +71,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -467,7 +83,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(split.categories.length === 1) assert(split.categories.contains(1.0)) assert(split.featureType === Categorical) - assert(split.threshold === Double.MinValue) val stats = rootNode.stats.get assert(stats.gain > 0) @@ -484,7 +99,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -505,18 +120,11 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - val rootNode = DecisionTree.train(rdd, strategy).topNode - val stats = rootNode.stats.get - assert(stats.gain === 0) - assert(stats.leftImpurity === 0) - assert(stats.rightImpurity === 0) + assert(rootNode.impurity === 0) + assert(rootNode.stats.isEmpty) + assert(rootNode.predict.predict === 0) } test("Binary classification stump with fixed label 1 for Gini") { @@ -529,18 +137,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - val rootNode = DecisionTree.train(rdd, strategy).topNode - val stats = rootNode.stats.get - assert(stats.gain === 0) - assert(stats.leftImpurity === 0) - assert(stats.rightImpurity === 0) + assert(rootNode.impurity === 0) + assert(rootNode.stats.isEmpty) assert(rootNode.predict.predict === 1) } @@ -554,18 +154,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - val rootNode = DecisionTree.train(rdd, strategy).topNode - val stats = rootNode.stats.get - assert(stats.gain === 0) - assert(stats.leftImpurity === 0) - assert(stats.rightImpurity === 0) + assert(rootNode.impurity === 0) + assert(rootNode.stats.isEmpty) assert(rootNode.predict.predict === 0) } @@ -579,18 +171,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - val rootNode = DecisionTree.train(rdd, strategy).topNode - val stats = rootNode.stats.get - assert(stats.gain === 0) - assert(stats.leftImpurity === 0) - assert(stats.rightImpurity === 0) + assert(rootNode.impurity === 0) + assert(rootNode.stats.isEmpty) assert(rootNode.predict.predict === 1) } @@ -684,7 +268,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val model = DecisionTree.train(rdd, strategy) DecisionTreeSuite.validateClassifier(model, arr, 0.9) @@ -773,8 +356,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // test when no valid split can be found val rootNode = model.topNode - val gain = rootNode.stats.get - assert(gain == InformationGainStats.invalidInformationGainStats) + assert(rootNode.stats.isEmpty) } test("do not choose split that does not satisfy min instance per node requirements") { @@ -788,15 +370,16 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, - maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2), numClasses = 2, minInstancesPerNode = 2) val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get - val gain = rootNode.stats.get + val gainStats = rootNode.stats.get assert(split.feature == 1) - assert(gain != InformationGainStats.invalidInformationGainStats) + assert(gainStats.gain >= 0) + assert(gainStats.impurity >= 0) } test("split must satisfy min info gain requirements") { @@ -818,10 +401,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { } // test when no valid split can be found - val rootNode = model.topNode - - val gain = rootNode.stats.get - assert(gain == InformationGainStats.invalidInformationGainStats) + assert(model.topNode.stats.isEmpty) } ///////////////////////////////////////////////////////////////////////////// @@ -1045,7 +625,7 @@ object DecisionTreeSuite extends SparkFunSuite { assert(a.isLeaf === b.isLeaf) assert(a.split === b.split) (a.stats, b.stats) match { - // TODO: Check other fields besides the infomation gain. + // TODO: Check other fields besides the information gain. case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain) case (None, None) => case _ => throw new AssertionError( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index 3d3f80063f904..1cc8f342021a0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.tree +import scala.collection.mutable + import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.util.StatCounter -import scala.collection.mutable - object EnsembleTestHelper { /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 6fc9e8df621df..c61f89322d35f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.mllib.tree -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.Variance -import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError} import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils - /** * Test suite for [[GradientBoostedTrees]]. */ @@ -158,49 +158,6 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext } } - test("runWithValidation stops early and performs better on a validation dataset") { - // Set numIterations large enough so that it stops early. - val numIterations = 20 - val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) - val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) - - val algos = Array(Regression, Regression, Classification) - val losses = Array(SquaredError, AbsoluteError, LogLoss) - algos.zip(losses).foreach { case (algo, loss) => - val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) - .runWithValidation(trainRdd, validateRdd) - val numTrees = gbtValidate.numTrees - assert(numTrees !== numIterations) - - // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) - val (errorWithoutValidation, errorWithValidation) = { - if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) - } else { - (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) - } - } - assert(errorWithValidation <= errorWithoutValidation) - - // Test that results from evaluateEachIteration comply with runWithValidation. - // Note that convergenceTol is set to 0.0 - val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) - assert(evaluationArray.length === numIterations) - assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) - var i = 1 - while (i < numTrees) { - assert(evaluationArray(i) <= evaluationArray(i - 1)) - i += 1 - } - } - } - test("Checkpointing") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString @@ -220,7 +177,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext } -private object GradientBoostedTreesSuite { +private[spark] object GradientBoostedTreesSuite { // Combinations for estimators, learning rates and subsamplingRate val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 49aff21fe7914..14152cdd63bc7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} -import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. */ -class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index e6df5d974bf36..bec61ba6a003c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.mllib.tree -import scala.collection.mutable - import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} +import org.apache.spark.mllib.tree.model.RandomForestModel import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils @@ -42,7 +39,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.trees.size === 1) + assert(rf.trees.length === 1) val rfTree = rf.trees(0) val dt = DecisionTree.train(rdd, strategy) @@ -78,7 +75,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.trees.size === 1) + assert(rf.trees.length === 1) val rfTree = rf.trees(0) val dt = DecisionTree.train(rdd, strategy) @@ -108,80 +105,6 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { regressionTestWithContinuousFeatures(strategy) } - def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) { - val numFeatures = 50 - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) - val rdd = sc.parallelize(arr) - - // Select feature subset for top nodes. Return true if OK. - def checkFeatureSubsetStrategy( - numTrees: Int, - featureSubsetStrategy: String, - numFeaturesPerNode: Int): Unit = { - val seeds = Array(123, 5354, 230, 349867, 23987) - val maxMemoryUsage: Long = 128 * 1024L * 1024L - val metadata = - DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) - seeds.foreach { seed => - val failString = s"Failed on test with:" + - s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + - s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" - val nodeQueue = new mutable.Queue[(Int, Node)]() - val topNodes: Array[Node] = new Array[Node](numTrees) - Range(0, numTrees).foreach { treeIndex => - topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1) - nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) - } - val rng = new scala.util.Random(seed = seed) - val (nodesForGroup: Map[Int, Array[Node]], - treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = - RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) - - assert(nodesForGroup.size === numTrees, failString) - assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree - - if (numFeaturesPerNode == numFeatures) { - // featureSubset values should all be None - assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), - failString) - } else { - // Check number of features. - assert(treeToNodeToIndexInfo.values.forall(_.values.forall( - _.featureSubset.get.size === numFeaturesPerNode)), failString) - } - } - } - - checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) - checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) - checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 1, "log2", - (math.log(numFeatures) / math.log(2)).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) - - checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) - checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 2, "log2", - (math.log(numFeatures) / math.log(2)).ceil.toInt) - checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) - } - - test("Binary classification with continuous features: subsampling features") { - val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) - binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) - } - - test("Binary classification with continuous features and node Id cache: subsampling features") { - val categoricalFeaturesInfo = Map.empty[Int, Int] - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, - useNodeIdCache = true) - binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) - } - test("alternating categorical and continuous features with multiclass labels to test indexing") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)) @@ -197,7 +120,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { featureSubsetStrategy = "sqrt", seed = 12345) } - test("subsampling rate in RandomForest"){ + test("subsampling rate in RandomForest") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala index 525ab68c7921a..95d874b8432eb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.util -import org.scalatest.{Suite, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfterAll, Suite} import org.apache.spark.{SparkConf, SparkContext} @@ -25,18 +25,21 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ override def beforeAll() { + super.beforeAll() val conf = new SparkConf() .setMaster("local-cluster[2, 1, 1024]") .setAppName("test-cluster") - .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data + .set("spark.rpc.message.maxSize", "1") // set to 1MB to detect direct serialization of data sc = new SparkContext(conf) - super.beforeAll() } override def afterAll() { - if (sc != null) { - sc.stop() + try { + if (sc != null) { + sc.stop() + } + } finally { + super.afterAll() } - super.afterAll() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 70219e9ad9d3e..e542f21a1802c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.mllib.util import java.io.File +import java.nio.charset.StandardCharsets import scala.io.Source import breeze.linalg.{squaredDistance => breezeSquaredDistance} -import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.spark.SparkException @@ -84,7 +84,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, Charsets.US_ASCII) + Files.write(lines, file, StandardCharsets.UTF_8) val path = tempDir.toURI.toString val pointsWithNumFeatures = loadLibSVMFile(sc, path, 6).collect() @@ -117,7 +117,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, Charsets.US_ASCII) + Files.write(lines, file, StandardCharsets.UTF_8) val path = tempDir.toURI.toString intercept[SparkException] { @@ -134,7 +134,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, Charsets.US_ASCII) + Files.write(lines, file, StandardCharsets.UTF_8) val path = tempDir.toURI.toString intercept[SparkException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 5d1796ef65722..cb1efd525134a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -17,14 +17,20 @@ package org.apache.spark.mllib.util -import org.scalatest.{BeforeAndAfterAll, Suite} +import java.io.File + +import org.apache.hadoop.fs.Path +import org.scalatest.Suite import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.util.TempDirectory import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils -trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => +trait MLlibTestSparkContext extends TempDirectory { self: Suite => @transient var sc: SparkContext = _ @transient var sqlContext: SQLContext = _ + @transient var checkpointDir: String = _ override def beforeAll() { super.beforeAll() @@ -32,15 +38,24 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => .setMaster("local[2]") .setAppName("MLlibUnitTest") sc = new SparkContext(conf) + SQLContext.clearActive() sqlContext = new SQLContext(sc) + SQLContext.setActive(sqlContext) + checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString + sc.setCheckpointDir(checkpointDir) } override def afterAll() { - sqlContext = null - if (sc != null) { - sc.stop() + try { + Utils.deleteRecursively(new File(checkpointDir)) + sqlContext = null + SQLContext.clearActive() + if (sc != null) { + sc.stop() + } + sc = null + } finally { + super.afterAll() } - sc = null - super.afterAll() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 352193a67860c..6de9aaf94f1b2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.util -import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.scalatest.exceptions.TestFailedException +import org.apache.spark.mllib.linalg.{Matrix, Vector} + object TestingUtils { val ABS_TOL_MSG = " using absolute tolerance" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 8f475f30249d6..44c39704e5b92 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.util +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -import org.scalatest.exceptions.TestFailedException class TestingUtilsSuite extends SparkFunSuite { diff --git a/network/common/pom.xml b/network/common/pom.xml deleted file mode 100644 index 9af6cc5e925f9..0000000000000 --- a/network/common/pom.xml +++ /dev/null @@ -1,108 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-network-common_2.10 - jar - Spark Project Networking - http://spark.apache.org/ - - network-common - - - - - - io.netty - netty-all - - - - - org.slf4j - slf4j-api - provided - - - com.google.code.findbugs - jsr305 - - - - com.google.guava - guava - compile - - - - - log4j - log4j - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - org.mockito - mockito-core - test - - - org.slf4j - slf4j-log4j12 - test - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - - org.apache.maven.plugins - maven-jar-plugin - - - test-jar-on-test-compile - test-compile - - test-jar - - - - - - - diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java deleted file mode 100644 index 6cce97c807dc0..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.util.List; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToMessageEncoder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Encoder used by the server side to encode server-to-client responses. - * This encoder is stateless so it is safe to be shared by multiple threads. - */ -@ChannelHandler.Sharable -public final class MessageEncoder extends MessageToMessageEncoder { - - private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); - - /*** - * Encodes a Message by invoking its encode() method. For non-data messages, we will add one - * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. - * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the - * data to 'out', in order to enable zero-copy transfer. - */ - @Override - public void encode(ChannelHandlerContext ctx, Message in, List out) { - Object body = null; - long bodyLength = 0; - boolean isBodyInFrame = false; - - // Detect ResponseWithBody messages and get the data buffer out of them. - // The body is used in order to enable zero-copy transfer for the payload. - if (in instanceof ResponseWithBody) { - ResponseWithBody resp = (ResponseWithBody) in; - try { - bodyLength = resp.body.size(); - body = resp.body.convertToNetty(); - isBodyInFrame = resp.isBodyInFrame; - } catch (Exception e) { - // Re-encode this message as a failure response. - String error = e.getMessage() != null ? e.getMessage() : "null"; - logger.error(String.format("Error processing %s for client %s", - resp, ctx.channel().remoteAddress()), e); - encode(ctx, resp.createFailureResponse(error), out); - return; - } - } - - Message.Type msgType = in.type(); - // All messages have the frame length, message type, and message itself. The frame length - // may optionally include the length of the body data, depending on what message is being - // sent. - int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); - long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0); - ByteBuf header = ctx.alloc().heapBuffer(headerLength); - header.writeLong(frameLength); - msgType.encode(header); - in.encode(header); - assert header.writableBytes() == 0; - - if (body != null && bodyLength > 0) { - out.add(new MessageWithHeader(header, body, bodyLength)); - } else { - out.add(header); - } - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java deleted file mode 100644 index d686a951467cf..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.io.IOException; -import java.nio.channels.WritableByteChannel; - -import com.google.common.base.Preconditions; -import io.netty.buffer.ByteBuf; -import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; -import io.netty.util.ReferenceCountUtil; - -/** - * A wrapper message that holds two separate pieces (a header and a body). - * - * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion. - */ -class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { - - private final ByteBuf header; - private final int headerLength; - private final Object body; - private final long bodyLength; - private long totalBytesTransferred; - - MessageWithHeader(ByteBuf header, Object body, long bodyLength) { - Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, - "Body must be a ByteBuf or a FileRegion."); - this.header = header; - this.headerLength = header.readableBytes(); - this.body = body; - this.bodyLength = bodyLength; - } - - @Override - public long count() { - return headerLength + bodyLength; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transfered() { - return totalBytesTransferred; - } - - /** - * This code is more complicated than you would think because we might require multiple - * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting. - * - * The contract is that the caller will ensure position is properly set to the total number - * of bytes transferred so far (i.e. value returned by transfered()). - */ - @Override - public long transferTo(final WritableByteChannel target, final long position) throws IOException { - Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position."); - // Bytes written for header in this call. - long writtenHeader = 0; - if (header.readableBytes() > 0) { - writtenHeader = copyByteBuf(header, target); - totalBytesTransferred += writtenHeader; - if (header.readableBytes() > 0) { - return writtenHeader; - } - } - - // Bytes written for body in this call. - long writtenBody = 0; - if (body instanceof FileRegion) { - writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength); - } else if (body instanceof ByteBuf) { - writtenBody = copyByteBuf((ByteBuf) body, target); - } - totalBytesTransferred += writtenBody; - - return writtenHeader + writtenBody; - } - - @Override - protected void deallocate() { - header.release(); - ReferenceCountUtil.release(body); - } - - private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - int written = target.write(buf.nioBuffer()); - buf.skipBytes(written); - return written; - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java deleted file mode 100644 index 67be77e39f711..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * Abstract class for response messages that contain a large data portion kept in a separate - * buffer. These messages are treated especially by MessageEncoder. - */ -public abstract class ResponseWithBody implements ResponseMessage { - public final ManagedBuffer body; - public final boolean isBodyInFrame; - - protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) { - this.body = body; - this.isBodyInFrame = isBodyInFrame; - } - - public abstract ResponseMessage createFailureResponse(String error); -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java deleted file mode 100644 index 745039db742fa..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.util.Arrays; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -/** - * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. - * This will correspond to a single - * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). - */ -public final class RpcRequest implements RequestMessage { - /** Used to link an RPC request with its response. */ - public final long requestId; - - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; - - public RpcRequest(long requestId, byte[] message) { - this.requestId = requestId; - this.message = message; - } - - @Override - public Type type() { return Type.RpcRequest; } - - @Override - public int encodedLength() { - return 8 + Encoders.ByteArrays.encodedLength(message); - } - - @Override - public void encode(ByteBuf buf) { - buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, message); - } - - public static RpcRequest decode(ByteBuf buf) { - long requestId = buf.readLong(); - byte[] message = Encoders.ByteArrays.decode(buf); - return new RpcRequest(requestId, message); - } - - @Override - public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(message)); - } - - @Override - public boolean equals(Object other) { - if (other instanceof RpcRequest) { - RpcRequest o = (RpcRequest) other; - return requestId == o.requestId && Arrays.equals(message, o.message); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("requestId", requestId) - .add("message", message) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java deleted file mode 100644 index 1671cd444f039..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.util.Arrays; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -/** Response to {@link RpcRequest} for a successful RPC. */ -public final class RpcResponse implements ResponseMessage { - public final long requestId; - public final byte[] response; - - public RpcResponse(long requestId, byte[] response) { - this.requestId = requestId; - this.response = response; - } - - @Override - public Type type() { return Type.RpcResponse; } - - @Override - public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } - - @Override - public void encode(ByteBuf buf) { - buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, response); - } - - public static RpcResponse decode(ByteBuf buf) { - long requestId = buf.readLong(); - byte[] response = Encoders.ByteArrays.decode(buf); - return new RpcResponse(requestId, response); - } - - @Override - public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(response)); - } - - @Override - public boolean equals(Object other) { - if (other instanceof RpcResponse) { - RpcResponse o = (RpcResponse) other; - return requestId == o.requestId && Arrays.equals(response, o.response); - } - return false; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("requestId", requestId) - .add("response", response) - .toString(); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java deleted file mode 100644 index cad76ab7aa54e..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl; - -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.Encoders; - -/** - * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged - * with the given appId. This appId allows a single SaslRpcHandler to multiplex different - * applications which may be using different sets of credentials. - */ -class SaslMessage implements Encodable { - - /** Serialization tag used to catch incorrect payloads. */ - private static final byte TAG_BYTE = (byte) 0xEA; - - public final String appId; - public final byte[] payload; - - public SaslMessage(String appId, byte[] payload) { - this.appId = appId; - this.payload = payload; - } - - @Override - public int encodedLength() { - return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); - } - - @Override - public void encode(ByteBuf buf) { - buf.writeByte(TAG_BYTE); - Encoders.Strings.encode(buf, appId); - Encoders.ByteArrays.encode(buf, payload); - } - - public static SaslMessage decode(ByteBuf buf) { - if (buf.readByte() != TAG_BYTE) { - throw new IllegalStateException("Expected SaslMessage, received something else" - + " (maybe your client does not have SASL enabled?)"); - } - - String appId = Encoders.Strings.decode(buf); - byte[] payload = Encoders.ByteArrays.decode(buf); - return new SaslMessage(appId, payload); - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java deleted file mode 100644 index dbb7f95f55bc0..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; - -/** - * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. - */ -public abstract class RpcHandler { - /** - * Receive a single RPC message. Any exception thrown while in this method will be sent back to - * the client in string form as a standard RPC failure. - * - * This method will not be called in parallel for a single TransportClient (i.e., channel). - * - * @param client A channel client which enables the handler to make requests back to the sender - * of this RPC. This will always be the exact same object for a particular channel. - * @param message The serialized bytes of the RPC. - * @param callback Callback which should be invoked exactly once upon success or failure of the - * RPC. - */ - public abstract void receive( - TransportClient client, - byte[] message, - RpcResponseCallback callback); - - /** - * Returns the StreamManager which contains the state about which streams are currently being - * fetched by a TransportClient. - */ - public abstract StreamManager getStreamManager(); - - /** - * Invoked when the connection associated with the given client has been invalidated. - * No further requests will come from this client. - */ - public void connectionTerminated(TransportClient client) { } - - public void exceptionCaught(Throwable cause, TransportClient client) { } -} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java deleted file mode 100644 index 8e0ee709e38e3..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.server; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.timeout.IdleState; -import io.netty.handler.timeout.IdleStateEvent; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.Message; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.ResponseMessage; -import org.apache.spark.network.util.NettyUtils; - -/** - * The single Transport-level Channel handler which is used for delegating requests to the - * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}. - * - * All channels created in the transport layer are bidirectional. When the Client initiates a Netty - * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server - * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server - * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the - * Client. - * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, - * for the Client's responses to the Server's requests. - * - * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}. - * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic - * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not - * timeout if the client is continuously sending but getting no responses, for simplicity. - */ -public class TransportChannelHandler extends SimpleChannelInboundHandler { - private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); - - private final TransportClient client; - private final TransportResponseHandler responseHandler; - private final TransportRequestHandler requestHandler; - private final long requestTimeoutNs; - - public TransportChannelHandler( - TransportClient client, - TransportResponseHandler responseHandler, - TransportRequestHandler requestHandler, - long requestTimeoutMs) { - this.client = client; - this.responseHandler = responseHandler; - this.requestHandler = requestHandler; - this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000; - } - - public TransportClient getClient() { - return client; - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), - cause); - requestHandler.exceptionCaught(cause); - responseHandler.exceptionCaught(cause); - ctx.close(); - } - - @Override - public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { - try { - requestHandler.channelUnregistered(); - } catch (RuntimeException e) { - logger.error("Exception from request handler while unregistering channel", e); - } - try { - responseHandler.channelUnregistered(); - } catch (RuntimeException e) { - logger.error("Exception from response handler while unregistering channel", e); - } - super.channelUnregistered(ctx); - } - - @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) { - if (request instanceof RequestMessage) { - requestHandler.handle((RequestMessage) request); - } else { - responseHandler.handle((ResponseMessage) request); - } - } - - /** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}. */ - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { - if (evt instanceof IdleStateEvent) { - IdleStateEvent e = (IdleStateEvent) evt; - // See class comment for timeout semantics. In addition to ensuring we only timeout while - // there are outstanding requests, we also do a secondary consistency check to ensure - // there's no race between the idle timeout and incrementing the numOutstandingRequests. - boolean hasInFlightRequests = responseHandler.numOutstandingRequests() > 0; - boolean isActuallyOverdue = - System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); - } - } - } -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java deleted file mode 100644 index 3b2eff377955a..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import com.google.common.primitives.Ints; - -/** - * A central location that tracks all the settings we expose to users. - */ -public class TransportConf { - private final ConfigProvider conf; - - public TransportConf(ConfigProvider conf) { - this.conf = conf; - } - - /** IO mode: nio or epoll */ - public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } - - /** If true, we will prefer allocating off-heap byte buffers within Netty. */ - public boolean preferDirectBufs() { - return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); - } - - /** Connect timeout in milliseconds. Default 120 secs. */ - public int connectionTimeoutMs() { - long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( - conf.get("spark.network.timeout", "120s")); - long defaultTimeoutMs = JavaUtils.timeStringAsSec( - conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; - return (int) defaultTimeoutMs; - } - - /** Number of concurrent connections between two nodes for fetching data. */ - public int numConnectionsPerPeer() { - return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 1); - } - - /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ - public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } - - /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } - - /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ - public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } - - /** - * Receive buffer size (SO_RCVBUF). - * Note: the optimal size for receive buffer and send buffer should be - * latency * network_bandwidth. - * Assuming latency = 1ms, network_bandwidth = 10Gbps - * buffer size should be ~ 1.25MB - */ - public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } - - /** Send buffer size (SO_SNDBUF). */ - public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } - - /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; - } - - /** - * Max number of times we will try IO exceptions (such as connection timeouts) per request. - * If set to 0, we will not do any retries. - */ - public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } - - /** - * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. - * Only relevant if maxIORetries > 0. - */ - public int ioRetryWaitTimeMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; - } - - /** - * Minimum size of a block that we should start using memory map rather than reading in through - * normal IO operations. This prevents Spark from memory mapping very small blocks. In general, - * memory mapping has high overhead for blocks close to or below the page size of the OS. - */ - public int memoryMapBytes() { - return conf.getInt("spark.storage.memoryMapThreshold", 2 * 1024 * 1024); - } - - /** - * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are - * created only when data is going to be transferred. This can reduce the number of open files. - */ - public boolean lazyFileDescriptor() { - return conf.getBoolean("spark.shuffle.io.lazyFD", true); - } - - /** - * Maximum number of retries when binding to a port before giving up. - */ - public int portMaxRetries() { - return conf.getInt("spark.port.maxRetries", 16); - } - - /** - * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. - */ - public int maxSaslEncryptedBlockSize() { - return Ints.checkedCast(JavaUtils.byteStringAsBytes( - conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); - } - - /** - * Whether the server should enforce encryption on SASL-authenticated connections. - */ - public boolean saslServerAlwaysEncrypt() { - return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); - } - -} diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java deleted file mode 100644 index 272ea84e6180d..0000000000000 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import com.google.common.base.Preconditions; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; - -/** - * A customized frame decoder that allows intercepting raw data. - *

    - * This behaves like Netty's frame decoder (with harcoded parameters that match this library's - * needs), except it allows an interceptor to be installed to read data directly before it's - * framed. - *

    - * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's - * decoded, instead of building as many frames as the current buffer allows and dispatching - * all of them. This allows a child handler to install an interceptor if needed. - *

    - * If an interceptor is installed, framing stops, and data is instead fed directly to the - * interceptor. When the interceptor indicates that it doesn't need to read any more data, - * framing resumes. Interceptors should not hold references to the data buffers provided - * to their handle() method. - */ -public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { - - public static final String HANDLER_NAME = "frameDecoder"; - private static final int LENGTH_SIZE = 8; - private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; - - private CompositeByteBuf buffer; - private volatile Interceptor interceptor; - - @Override - public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { - ByteBuf in = (ByteBuf) data; - - if (buffer == null) { - buffer = in.alloc().compositeBuffer(); - } - - buffer.writeBytes(in); - - while (buffer.isReadable()) { - feedInterceptor(); - if (interceptor != null) { - continue; - } - - ByteBuf frame = decodeNext(); - if (frame != null) { - ctx.fireChannelRead(frame); - } else { - break; - } - } - - // We can't discard read sub-buffers if there are other references to the buffer (e.g. - // through slices used for framing). This assumes that code that retains references - // will call retain() from the thread that called "fireChannelRead()" above, otherwise - // ref counting will go awry. - if (buffer != null && buffer.refCnt() == 1) { - buffer.discardReadComponents(); - } - } - - protected ByteBuf decodeNext() throws Exception { - if (buffer.readableBytes() < LENGTH_SIZE) { - return null; - } - - int frameLen = (int) buffer.readLong() - LENGTH_SIZE; - if (buffer.readableBytes() < frameLen) { - buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE); - return null; - } - - Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen); - Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen); - - ByteBuf frame = buffer.readSlice(frameLen); - frame.retain(); - return frame; - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - if (buffer != null) { - if (buffer.isReadable()) { - feedInterceptor(); - } - buffer.release(); - } - if (interceptor != null) { - interceptor.channelInactive(); - } - super.channelInactive(ctx); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - if (interceptor != null) { - interceptor.exceptionCaught(cause); - } - super.exceptionCaught(ctx, cause); - } - - public void setInterceptor(Interceptor interceptor) { - Preconditions.checkState(this.interceptor == null, "Already have an interceptor."); - this.interceptor = interceptor; - } - - private void feedInterceptor() throws Exception { - if (interceptor != null && !interceptor.handle(buffer)) { - interceptor = null; - } - } - - public static interface Interceptor { - - /** - * Handles data received from the remote end. - * - * @param data Buffer containing data. - * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor. - */ - boolean handle(ByteBuf data) throws Exception; - - /** Called if an exception is thrown in the channel pipeline. */ - void exceptionCaught(Throwable cause) throws Exception; - - /** Called if the channel is closed and the interceptor is still installed. */ - void channelInactive() throws Exception; - - } - -} diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java deleted file mode 100644 index 84ebb337e6d54..0000000000000 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import com.google.common.collect.Maps; -import com.google.common.util.concurrent.Uninterruptibles; -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; -import org.junit.*; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.*; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; - -/** - * Suite which ensures that requests that go without a response for the network timeout period are - * failed, and the connection closed. - * - * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests, - * to ensure stability in different test environments. - */ -public class RequestTimeoutIntegrationSuite { - - private TransportServer server; - private TransportClientFactory clientFactory; - - private StreamManager defaultManager; - private TransportConf conf; - - // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever. - private final int FOREVER = 60 * 1000; - - @Before - public void setUp() throws Exception { - Map configMap = Maps.newHashMap(); - configMap.put("spark.shuffle.io.connectionTimeout", "2s"); - conf = new TransportConf(new MapConfigProvider(configMap)); - - defaultManager = new StreamManager() { - @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { - throw new UnsupportedOperationException(); - } - }; - } - - @After - public void tearDown() { - if (server != null) { - server.close(); - } - if (clientFactory != null) { - clientFactory.close(); - } - } - - // Basic suite: First request completes quickly, and second waits for longer than network timeout. - @Test - public void timeoutInactiveRequests() throws Exception { - final Semaphore semaphore = new Semaphore(1); - final byte[] response = new byte[16]; - RpcHandler handler = new RpcHandler() { - @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - try { - semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); - } catch (InterruptedException e) { - // do nothing - } - } - - @Override - public StreamManager getStreamManager() { - return defaultManager; - } - }; - - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - - // First completes quickly (semaphore starts at 1). - TestCallback callback0 = new TestCallback(); - synchronized (callback0) { - client.sendRpc(new byte[0], callback0); - callback0.wait(FOREVER); - assert (callback0.success.length == response.length); - } - - // Second times out after 2 seconds, with slack. Must be IOException. - TestCallback callback1 = new TestCallback(); - synchronized (callback1) { - client.sendRpc(new byte[0], callback1); - callback1.wait(4 * 1000); - assert (callback1.failure != null); - assert (callback1.failure instanceof IOException); - } - semaphore.release(); - } - - // A timeout will cause the connection to be closed, invalidating the current TransportClient. - // It should be the case that requesting a client from the factory produces a new, valid one. - @Test - public void timeoutCleanlyClosesClient() throws Exception { - final Semaphore semaphore = new Semaphore(0); - final byte[] response = new byte[16]; - RpcHandler handler = new RpcHandler() { - @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - try { - semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); - } catch (InterruptedException e) { - // do nothing - } - } - - @Override - public StreamManager getStreamManager() { - return defaultManager; - } - }; - - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - - // First request should eventually fail. - TransportClient client0 = - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - TestCallback callback0 = new TestCallback(); - synchronized (callback0) { - client0.sendRpc(new byte[0], callback0); - callback0.wait(FOREVER); - assert (callback0.failure instanceof IOException); - assert (!client0.isActive()); - } - - // Increment the semaphore and the second request should succeed quickly. - semaphore.release(2); - TransportClient client1 = - clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - TestCallback callback1 = new TestCallback(); - synchronized (callback1) { - client1.sendRpc(new byte[0], callback1); - callback1.wait(FOREVER); - assert (callback1.success.length == response.length); - assert (callback1.failure == null); - } - } - - // The timeout is relative to the LAST request sent, which is kinda weird, but still. - // This test also makes sure the timeout works for Fetch requests as well as RPCs. - @Test - public void furtherRequestsDelay() throws Exception { - final byte[] response = new byte[16]; - final StreamManager manager = new StreamManager() { - @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { - Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); - return new NioManagedBuffer(ByteBuffer.wrap(response)); - } - }; - RpcHandler handler = new RpcHandler() { - @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - throw new UnsupportedOperationException(); - } - - @Override - public StreamManager getStreamManager() { - return manager; - } - }; - - TransportContext context = new TransportContext(conf, handler); - server = context.createServer(); - clientFactory = context.createClientFactory(); - TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); - - // Send one request, which will eventually fail. - TestCallback callback0 = new TestCallback(); - client.fetchChunk(0, 0, callback0); - Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); - - // Send a second request before the first has failed. - TestCallback callback1 = new TestCallback(); - client.fetchChunk(0, 1, callback1); - Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); - - synchronized (callback0) { - // not complete yet, but should complete soon - assert (callback0.success == null && callback0.failure == null); - callback0.wait(2 * 1000); - assert (callback0.failure instanceof IOException); - } - - synchronized (callback1) { - // failed at same time as previous - assert (callback0.failure instanceof IOException); - } - } - - /** - * Callback which sets 'success' or 'failure' on completion. - * Additionally notifies all waiters on this callback when invoked. - */ - class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { - - byte[] success; - Throwable failure; - - @Override - public void onSuccess(byte[] response) { - synchronized(this) { - success = response; - this.notifyAll(); - } - } - - @Override - public void onFailure(Throwable e) { - synchronized(this) { - failure = e; - this.notifyAll(); - } - } - - @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - synchronized(this) { - try { - success = buffer.nioByteBuffer().array(); - this.notifyAll(); - } catch (IOException e) { - // weird - } - } - } - - @Override - public void onFailure(int chunkIndex, Throwable e) { - synchronized(this) { - failure = e; - this.notifyAll(); - } - } - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java deleted file mode 100644 index 17a03ebe88a93..0000000000000 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network; - -import io.netty.channel.local.LocalChannel; -import org.junit.Test; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.*; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamChunkId; - -public class TransportResponseHandlerSuite { - @Test - public void handleSuccessfulFetch() { - StreamChunkId streamChunkId = new StreamChunkId(1, 0); - - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(streamChunkId, callback); - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); - verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void handleFailedFetch() { - StreamChunkId streamChunkId = new StreamChunkId(1, 0); - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(streamChunkId, callback); - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); - verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void clearAllOutstandingRequests() { - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(new StreamChunkId(1, 0), callback); - handler.addFetchRequest(new StreamChunkId(1, 1), callback); - handler.addFetchRequest(new StreamChunkId(1, 2), callback); - assertEquals(3, handler.numOutstandingRequests()); - - handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); - handler.exceptionCaught(new Exception("duh duh duhhhh")); - - // should fail both b2 and b3 - verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); - verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); - verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void handleSuccessfulRPC() { - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - RpcResponseCallback callback = mock(RpcResponseCallback.class); - handler.addRpcRequest(12345, callback); - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored - assertEquals(1, handler.numOutstandingRequests()); - - byte[] arr = new byte[10]; - handler.handle(new RpcResponse(12345, arr)); - verify(callback, times(1)).onSuccess(eq(arr)); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void handleFailedRPC() { - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - RpcResponseCallback callback = mock(RpcResponseCallback.class); - handler.addRpcRequest(12345, callback); - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new RpcFailure(12345, "oh no")); - verify(callback, times(1)).onFailure((Throwable) any()); - assertEquals(0, handler.numOutstandingRequests()); - } -} diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java deleted file mode 100644 index 6c98e733b462f..0000000000000 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; -import org.junit.Test; - -import static org.junit.Assert.*; - -import org.apache.spark.network.util.ByteArrayWritableChannel; - -public class MessageWithHeaderSuite { - - @Test - public void testSingleWrite() throws Exception { - testFileRegionBody(8, 8); - } - - @Test - public void testShortWrite() throws Exception { - testFileRegionBody(8, 1); - } - - @Test - public void testByteBufBody() throws Exception { - ByteBuf header = Unpooled.copyLong(42); - ByteBuf body = Unpooled.copyLong(84); - MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); - - ByteBuf result = doWrite(msg, 1); - assertEquals(msg.count(), result.readableBytes()); - assertEquals(42, result.readLong()); - assertEquals(84, result.readLong()); - } - - private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { - ByteBuf header = Unpooled.copyLong(42); - int headerLength = header.readableBytes(); - TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); - MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); - - ByteBuf result = doWrite(msg, totalWrites / writesPerCall); - assertEquals(headerLength + region.count(), result.readableBytes()); - assertEquals(42, result.readLong()); - for (long i = 0; i < 8; i++) { - assertEquals(i, result.readLong()); - } - } - - private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { - int writes = 0; - ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); - while (msg.transfered() < msg.count()) { - msg.transferTo(channel, msg.transfered()); - writes++; - } - assertTrue("Not enough writes!", minExpectedWrites <= writes); - return Unpooled.wrappedBuffer(channel.getData()); - } - - private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { - - private final int writeCount; - private final int writesPerCall; - private int written; - - TestFileRegion(int totalWrites, int writesPerCall) { - this.writeCount = totalWrites; - this.writesPerCall = writesPerCall; - } - - @Override - public long count() { - return 8 * writeCount; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transfered() { - return 8 * written; - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - for (int i = 0; i < writesPerCall; i++) { - ByteBuf buf = Unpooled.copyLong((position / 8) + i); - ByteBuffer nio = buf.nioBuffer(); - while (nio.remaining() > 0) { - target.write(nio); - } - buf.release(); - written++; - } - return 8 * writesPerCall; - } - - @Override - protected void deallocate() { - } - - } - -} diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java deleted file mode 100644 index ca74f0a00cf9d..0000000000000 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.util; - -import java.nio.ByteBuffer; -import java.util.Random; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import org.junit.Test; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - -public class TransportFrameDecoderSuite { - - @Test - public void testFrameDecoding() throws Exception { - Random rnd = new Random(); - TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - - final int frameCount = 100; - ByteBuf data = Unpooled.buffer(); - try { - for (int i = 0; i < frameCount; i++) { - byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)]; - data.writeLong(frame.length + 8); - data.writeBytes(frame); - } - - while (data.isReadable()) { - int size = rnd.nextInt(16 * 1024) + 256; - decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size))); - } - - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); - } finally { - data.release(); - } - } - - @Test - public void testInterception() throws Exception { - final int interceptedReads = 3; - TransportFrameDecoder decoder = new TransportFrameDecoder(); - TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - - byte[] data = new byte[8]; - ByteBuf len = Unpooled.copyLong(8 + data.length); - ByteBuf dataBuf = Unpooled.wrappedBuffer(data); - - try { - decoder.setInterceptor(interceptor); - for (int i = 0; i < interceptedReads; i++) { - decoder.channelRead(ctx, dataBuf); - dataBuf.release(); - dataBuf = Unpooled.wrappedBuffer(data); - } - decoder.channelRead(ctx, len); - decoder.channelRead(ctx, dataBuf); - verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); - verify(ctx).fireChannelRead(any(ByteBuffer.class)); - } finally { - len.release(); - dataBuf.release(); - } - } - - @Test(expected = IllegalArgumentException.class) - public void testNegativeFrameSize() throws Exception { - testInvalidFrame(-1); - } - - @Test(expected = IllegalArgumentException.class) - public void testEmptyFrame() throws Exception { - // 8 because frame size includes the frame length. - testInvalidFrame(8); - } - - @Test(expected = IllegalArgumentException.class) - public void testLargeFrame() throws Exception { - // Frame length includes the frame size field, so need to add a few more bytes. - testInvalidFrame(Integer.MAX_VALUE + 9); - } - - private void testInvalidFrame(long size) throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - ByteBuf frame = Unpooled.copyLong(size); - try { - decoder.channelRead(ctx, frame); - } finally { - frame.release(); - } - } - - private static class MockInterceptor implements TransportFrameDecoder.Interceptor { - - private int remainingReads; - - MockInterceptor(int readCount) { - this.remainingReads = readCount; - } - - @Override - public boolean handle(ByteBuf data) throws Exception { - data.readerIndex(data.readerIndex() + data.readableBytes()); - assertFalse(data.isReadable()); - remainingReads -= 1; - return remainingReads != 0; - } - - @Override - public void exceptionCaught(Throwable cause) throws Exception { - - } - - @Override - public void channelInactive() throws Exception { - - } - - } - -} diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml deleted file mode 100644 index 70ba5cb1995bb..0000000000000 --- a/network/shuffle/pom.xml +++ /dev/null @@ -1,101 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-network-shuffle_2.10 - jar - Spark Project Shuffle Streaming Service - http://spark.apache.org/ - - network-shuffle - - - - - - org.apache.spark - spark-network-common_${scala.binary.version} - ${project.version} - - - - org.fusesource.leveldbjni - leveldbjni-all - 1.8 - - - - com.fasterxml.jackson.core - jackson-databind - - - - com.fasterxml.jackson.core - jackson-annotations - - - - - org.slf4j - slf4j-api - provided - - - com.google.guava - guava - - - - - org.apache.spark - spark-network-common_${scala.binary.version} - ${project.version} - test-jar - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - log4j - log4j - test - - - org.mockito - mockito-core - test - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java deleted file mode 100644 index 7543b6be4f2a1..0000000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.mesos; - -import java.io.IOException; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.shuffle.ExternalShuffleClient; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; -import org.apache.spark.network.util.TransportConf; - -/** - * A client for talking to the external shuffle service in Mesos coarse-grained mode. - * - * This is used by the Spark driver to register with each external shuffle service on the cluster. - * The reason why the driver has to talk to the service is for cleaning up shuffle files reliably - * after the application exits. Mesos does not provide a great alternative to do this, so Spark - * has to detect this itself. - */ -public class MesosExternalShuffleClient extends ExternalShuffleClient { - private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class); - - /** - * Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}. - * Please refer to docs on {@link ExternalShuffleClient} for more information. - */ - public MesosExternalShuffleClient( - TransportConf conf, - SecretKeyHolder secretKeyHolder, - boolean saslEnabled, - boolean saslEncryptionEnabled) { - super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); - } - - public void registerDriverWithShuffleService(String host, int port) throws IOException { - checkInit(); - byte[] registerDriver = new RegisterDriver(appId).toByteArray(); - TransportClient client = clientFactory.createClient(host, port); - client.sendRpc(registerDriver, new RpcResponseCallback() { - @Override - public void onSuccess(byte[] response) { - logger.info("Successfully registered app " + appId + " with external shuffle service."); - } - - @Override - public void onFailure(Throwable e) { - logger.warn("Unable to register app " + appId + " with external shuffle service. " + - "Please manually remove shuffle data after driver exit. Error: " + e); - } - }); - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java deleted file mode 100644 index 94a61d6caadc4..0000000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol.mesos; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - -/** - * A message sent from the driver to register with the MesosExternalShuffleService. - */ -public class RegisterDriver extends BlockTransferMessage { - private final String appId; - - public RegisterDriver(String appId) { - this.appId = appId; - } - - public String getAppId() { return appId; } - - @Override - protected Type type() { return Type.REGISTER_DRIVER; } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - } - - @Override - public int hashCode() { - return Objects.hashCode(appId); - } - - public static RegisterDriver decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - return new RegisterDriver(appId); - } -} diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml deleted file mode 100644 index e2360eff5cfe1..0000000000000 --- a/network/yarn/pom.xml +++ /dev/null @@ -1,101 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-network-yarn_2.10 - jar - Spark Project YARN Shuffle Service - http://spark.apache.org/ - - network-yarn - - provided - - - - - - org.apache.spark - spark-network-shuffle_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - - org.apache.hadoop - hadoop-client - - - org.slf4j - slf4j-api - provided - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-shade-plugin - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar - - - *:* - - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - - diff --git a/pom.xml b/pom.xml index 4ed1c0c82dee6..a772d513372e7 100644 --- a/pom.xml +++ b/pom.xml @@ -25,8 +25,8 @@ 14 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -86,27 +86,25 @@ - tags + common/sketch + common/network-common + common/network-shuffle + common/unsafe + common/tags core - bagel graphx mllib + mllib-local tools - network/common - network/shuffle streaming sql/catalyst sql/core sql/hive - unsafe + external/docker-integration-tests assembly - external/twitter external/flume external/flume-sink external/flume-assembly - external/mqtt - external/mqtt-assembly - external/zeromq examples repl launcher @@ -117,55 +115,56 @@ UTF-8 UTF-8 - com.typesafe.akka - 2.3.11 1.7 - 3.3.3 + 3.3.9 spark 0.21.1 shaded-protobuf - 1.7.10 + 1.7.16 1.2.17 2.2.0 2.5.0 ${hadoop.version} - 0.98.7-hadoop2 + 0.98.17-hadoop2 hbase 1.6.0 3.4.5 2.4.0 org.spark-project.hive - 1.2.1.spark + 1.2.1.spark2 1.2.1 10.10.1.1 1.7.0 1.6.0 - 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 - 0.5.0 + 0.8.0 2.4.0 2.0.8 3.1.2 1.7.7 hadoop2 0.7.1 - 1.9.40 - 1.4.0 + 1.6.1 + + 0.10.2 4.3.2 + 4.3.2 3.1 3.4.1 - 2.10.5 - 2.10 + + 3.2.2 + 2.11.8 + 2.11 ${scala.version} org.scala-lang 1.9.13 - 2.4.4 - 1.1.2 + 2.5.3 + 1.1.2.4 1.1.2 1.2.0-incubating 1.10 @@ -180,16 +179,27 @@ 3.5.2 1.3.9 0.9.2 + 4.5.2-1 ${java.home} + + org.spark_project + + + ${project.build.directory}/scala-${scala.binary.version}/jars + + + prepare-package + none + compile @@ -222,93 +232,6 @@ false - - apache-repo - Apache Repository - https://repository.apache.org/content/repositories/releases - - true - - - false - - - - jboss-repo - JBoss Repository - https://repository.jboss.org/nexus/content/repositories/releases - - true - - - false - - - - mqtt-repo - MQTT Repository - https://repo.eclipse.org/content/repositories/paho-releases - - true - - - false - - - - cloudera-repo - Cloudera Repository - https://repository.cloudera.com/artifactory/cloudera-repos - - true - - - false - - - - spark-hive-staging - Staging Repo for Hive 1.2.1 (Spark Version) - https://oss.sonatype.org/content/repositories/orgspark-project-1113 - - true - - - - mapr-repo - MapR Repository - http://repository.mapr.com/maven/ - - true - - - false - - - - - spring-releases - Spring Release Repository - https://repo.spring.io/libs-release - - false - - - false - - - - - twttr-repo - Twttr Repository - http://maven.twttr.com - - true - - - false - - @@ -323,15 +246,6 @@ - - - org.spark-project.spark - unused - 1.0.0 - + + org.apache.xbean + xbean-asm5-shaded + 4.4 test @@ -580,37 +487,6 @@ ${protobuf.version} ${hadoop.deps.scope} - - ${akka.group} - akka-actor_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-remote_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-slf4j_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-testkit_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-zeromq_${scala.binary.version} - ${akka.version} - - - ${akka.group} - akka-actor_${scala.binary.version} - - - org.apache.mesos mesos @@ -623,6 +499,11 @@ + + org.roaringbitmap + RoaringBitmap + 0.5.11 + commons-net commons-net @@ -633,6 +514,11 @@ netty-all 4.0.29.Final + + io.netty + netty + 3.8.0.Final + org.apache.derby derby @@ -698,6 +584,28 @@ ${jersey.version} ${hadoop.deps.scope} + + org.scalanlp + breeze_${scala.binary.version} + 0.11.2 + + + + junit + junit + + + org.apache.commons + commons-math3 + + + + + org.json4s + json4s-jackson_${scala.binary.version} + 3.2.10 + com.sun.jersey jersey-json @@ -737,25 +645,25 @@ org.scalatest scalatest_${scala.binary.version} - 2.2.1 + 2.2.6 test org.mockito mockito-core - 1.9.5 + 1.10.19 test org.scalacheck scalacheck_${scala.binary.version} - 1.11.3 + 1.12.5 test junit junit - 4.11 + 4.12 test @@ -776,6 +684,47 @@ 0.11 test + + com.spotify + docker-client + shaded + 3.6.6 + test + + + guava + com.google.guava + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + commons-logging + httpclient + + + commons-logging + commons-logging + + + + + mysql + mysql-connector-java + 5.1.38 + test + + + org.postgresql + postgresql + 9.4.1207.jre7 + test + org.apache.curator curator-recipes @@ -786,6 +735,10 @@ org.jboss.netty netty + + jline + jline + @@ -882,6 +835,14 @@ + + + org.apache.avro + avro-ipc + tests + ${avro.version} + test + org.apache.avro avro-mapred @@ -1070,6 +1031,12 @@ zookeeper ${zookeeper.version} ${hadoop.deps.scope} + + + org.jboss.netty + netty + + org.codehaus.jackson @@ -1338,6 +1305,10 @@ commons-logging commons-logging + + org.codehaus.groovy + groovy-all + @@ -1409,6 +1380,10 @@ commons-logging commons-logging + + org.codehaus.groovy + groovy-all + @@ -1503,6 +1478,10 @@ commons-logging commons-logging + + org.codehaus.groovy + groovy-all + @@ -1548,6 +1527,14 @@ org.apache.thrift libthrift + + org.codehaus.groovy + groovy-all + + + javax.servlet + servlet-api + @@ -1596,6 +1583,10 @@ commons-logging commons-logging + + org.codehaus.groovy + groovy-all + @@ -1781,6 +1772,11 @@ + + org.antlr + antlr4-runtime + ${antlr4.version} + @@ -1790,7 +1786,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4 + 1.4.1 enforce-versions @@ -1805,6 +1801,20 @@ ${java.version} + + + + org.jboss.netty + org.codehaus.groovy + + true + @@ -1813,7 +1823,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.9.1 + 1.10 net.alchim31.maven @@ -1876,7 +1886,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.3 + 3.5.1 ${java.version} ${java.version} @@ -1888,11 +1898,21 @@ + + org.antlr + antlr3-maven-plugin + 3.5.2 + + + org.antlr + antlr4-maven-plugin + ${antlr4.version} + org.apache.maven.plugins maven-surefire-plugin - 2.18.1 + 2.19.1 @@ -1910,10 +1930,12 @@ --> ${test_classpath} 1 + ${scala.binary.version} 1 ${test.java.home} + file:src/test/resources/log4j.properties test true ${project.build.directory}/tmp @@ -1922,7 +1944,6 @@ false false false - true true src @@ -1930,6 +1951,14 @@ false ${test.exclude.tags} + + + test + + test + + + @@ -1950,10 +1979,12 @@ --> ${test_classpath} 1 + ${scala.binary.version} 1 ${test.java.home} + file:src/test/resources/log4j.properties test true ${project.build.directory}/tmp @@ -1961,7 +1992,6 @@ 1 false false - true true __not_used__ @@ -2007,7 +2037,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.6.1 + 3.0.0 @@ -2035,12 +2065,12 @@ org.apache.maven.plugins maven-assembly-plugin - 2.5.5 + 2.6 org.apache.maven.plugins maven-shade-plugin - 2.4.1 + 2.4.3 org.apache.maven.plugins @@ -2052,6 +2082,23 @@ maven-deploy-plugin 2.8.2 + + org.apache.maven.plugins + maven-dependency-plugin + + + default-cli + + build-classpath + + + + runtime + + + + @@ -2115,6 +2162,7 @@ 2.10 + generate-test-classpath test-compile build-classpath @@ -2124,6 +2172,17 @@ test_classpath + + copy-module-dependencies + ${build.copyDependenciesPhase} + + copy-dependencies + + + runtime + ${jars.target.dir} + + @@ -2138,9 +2197,6 @@ false - - org.spark-project.spark:unused - org.eclipse.jetty:jetty-io org.eclipse.jetty:jetty-http org.eclipse.jetty:jetty-continuation @@ -2155,25 +2211,14 @@ org.eclipse.jetty - org.spark-project.jetty + ${spark.shade.packageName}.jetty org.eclipse.jetty.** com.google.common - org.spark-project.guava - - - com/google/common/base/Absent* - com/google/common/base/Function - com/google/common/base/Optional* - com/google/common/base/Present* - com/google/common/base/Supplier - + ${spark.shade.packageName}.guava @@ -2201,7 +2246,7 @@ org.scalastyle scalastyle-maven-plugin - 0.7.0 + 0.8.0 false true @@ -2222,6 +2267,30 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + false + false + true + false + ${basedir}/src/main/java,${basedir}/src/main/scala + ${basedir}/src/test/java + dev/checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + org.apache.maven.plugins @@ -2258,7 +2327,7 @@ prepare-test-jar - prepare-package + ${build.testJarPhase} test-jar @@ -2294,7 +2363,7 @@ spark-ganglia-lgpl - extras/spark-ganglia-lgpl + external/spark-ganglia-lgpl @@ -2302,34 +2371,19 @@ kinesis-asl - extras/kinesis-asl - extras/kinesis-asl-assembly + external/kinesis-asl + external/kinesis-asl-assembly java8-tests - - - - - org.apache.maven.plugins - maven-jar-plugin - - - - test-jar - - - - - - - + + [1.8,) + - extras/java8-tests + external/java8-tests - @@ -2357,19 +2411,6 @@ http://hadoop.apache.org/docs/ra.b.c/hadoop-project-dist/hadoop-common/dependency-analysis.html --> - - hadoop-1 - - 1.2.1 - 2.4.1 - 0.98.7-hadoop1 - hadoop1 - 1.8.8 - org.spark-project.akka - 2.3.4-spark - - - hadoop-2.2 @@ -2401,11 +2442,21 @@ + + hadoop-2.7 + + 2.7.0 + 0.9.3 + 3.4.6 + 2.6.0 + + + yarn yarn - network/yarn + common/network-yarn @@ -2419,10 +2470,10 @@ scala-2.10 - !scala-2.11 + scala-2.10 - 2.10.5 + 2.10.6 2.10 ${scala.version} org.scala-lang @@ -2451,10 +2502,10 @@ scala-2.11 - scala-2.11 + !scala-2.10 - 2.11.7 + 2.11.8 2.11 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 519052620246f..3dc1ceacde19a 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -42,14 +42,11 @@ object MimaBuild { ProblemFilters.exclude[IncompatibleFieldTypeProblem](fullName) ) - // Exclude a single class and its corresponding object + // Exclude a single class def excludeClass(className: String) = Seq( excludePackage(className), ProblemFilters.exclude[MissingClassProblem](className), - ProblemFilters.exclude[MissingTypesProblem](className), - excludePackage(className + "$"), - ProblemFilters.exclude[MissingClassProblem](className + "$"), - ProblemFilters.exclude[MissingTypesProblem](className + "$") + ProblemFilters.exclude[MissingTypesProblem](className) ) // Exclude a Spark class, that is in the package org.apache.spark @@ -91,8 +88,8 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.5.0" - val fullId = "spark-" + projectRef.project + "_2.10" + val previousSparkVersion = "1.6.0" + val fullId = "spark-" + projectRef.project + "_2.11" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 90dc947d4e588..71f337ce1f63e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -21,8 +21,7 @@ import com.typesafe.tools.mima.core.ProblemFilters._ /** * Additional excludes for checking of Spark's binary compatibility. * - * The Mima build will automatically exclude @DeveloperApi and @Experimental classes. This acts - * as an official audit of cases where we excluded other classes. Please use the narrowest + * This acts as an official audit of cases where we excluded other classes. Please use the narrowest * possible exclude here. MIMA will usually tell you what exclude to use, e.g.: * * ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.take") @@ -30,9 +29,609 @@ import com.typesafe.tools.mima.core.ProblemFilters._ * It is also possible to exclude Spark classes and packages. This should be used sparingly: * * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") + * + * For a new Spark version, please update MimaBuild.scala to reflect the previous version. */ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("2.0") => + Seq( + excludePackage("org.apache.spark.rpc"), + excludePackage("org.spark-project.jetty"), + excludePackage("org.apache.spark.unused"), + excludePackage("org.apache.spark.unsafe"), + excludePackage("org.apache.spark.util.collection.unsafe"), + excludePackage("org.apache.spark.sql.catalyst"), + excludePackage("org.apache.spark.sql.execution"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.$default$5"), + // SPARK-12600 Remove SQL deprecated methods + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.applySchema"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.parquetFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jdbc"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.dialectClassName"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.getSQLDialect"), + // SPARK-13664 Replace HadoopFsRelation with FileFormat + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.LibSVMRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelationProvider"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelation$FileStatusCache") + ) ++ Seq( + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory"), + // SPARK-14358 SparkListener from trait to abstract class + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.addSparkListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.JavaSparkListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkFirehoseListener"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.scheduler.SparkListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.jobs.JobProgressListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.exec.ExecutorsListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.env.EnvironmentListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.storage.StorageListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.StorageStatusListener") + ) ++ + Seq( + // SPARK-3369 Fix Iterable/Iterator in Java API + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapFunction2.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapFunction2.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.PairFlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.PairFlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.CoGroupFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.CoGroupFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.MapPartitionsFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.MapPartitionsFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.function.FlatMapGroupsFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.function.FlatMapGroupsFunction.call") + ) ++ + Seq( + // SPARK-4819 replace Guava Optional + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getSparkHome"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner") + ) ++ + Seq( + // SPARK-12481 Remove Hadoop 1.x + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil"), + // SPARK-12615 Remove deprecated APIs in core + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.$default$6"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.numericRDDToDoubleRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intToIntWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.writableWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToAsyncRDDActions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.boolToBoolWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longToLongWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToOrderedRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.booleanWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringToText"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleToDoubleWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToSequenceFileRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesToBytesWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatToFloatWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addOnCompleteCallback"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.runningLocally"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.attemptId"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.defaultMinSplits"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.runJob"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.runJob"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.tachyonFolderName"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.initLocalProperties"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearJars"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearFiles"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.toArray"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.filterWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.foreachWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.SequenceFileRDDFunctions.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.splits"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.externalBlockStoreFolderName"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore") + ) ++ Seq( + // SPARK-12149 Added new fields to ExecutorSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") + ) ++ + // SPARK-12665 Remove deprecated and unused classes + Seq( + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.graphx.GraphKryoRegistrator"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") + ) ++ Seq( + // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge") + ) ++ Seq( + // SPARK-12510 Refactor ActorReceiver to support Java + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") + ) ++ Seq( + // SPARK-12895 Implement TaskMetrics using accumulators + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators") + ) ++ Seq( + // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Accumulator.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.Accumulator.initialValue") + ) ++ Seq( + // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") + ) ++ Seq( + // SPARK-12689 Migrate DDL parsing to the newly absorbed parser + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser") + ) ++ Seq( + // SPARK-7799 Add "streaming-akka" project + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$6"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$5"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$3"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.streaming.zeromq.ZeroMQReceiver"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver$Supervisor") + ) ++ Seq( + // SPARK-12348 Remove deprecated Streaming APIs. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.dstream.DStream.foreach"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.awaitTermination"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.networkStream"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.api.java.JavaStreamingContextFactory"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.awaitTermination"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.sc"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreach"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.getOrCreate") + ) ++ Seq( + // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") + ) ++ Seq( + // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") + ) ++ Seq( + // SPARK-6363 Make Scala 2.11 the default Scala version + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") + ) ++ Seq( + // SPARK-7889 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI"), + // SPARK-13296 + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$") + ) ++ Seq( + // SPARK-12995 Remove deprecated APIs in graphx + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.lib.SVDPlusPlus.runSVDPlusPlus"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") + ) ++ Seq( + // SPARK-13426 Remove the support of SIMR + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") + ) ++ Seq( + // SPARK-13413 Remove SparkContext.metricsSystem/schedulerBackend_ setter + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metricsSystem"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.schedulerBackend_=") + ) ++ Seq( + // SPARK-13220 Deprecate yarn-client and yarn-cluster mode + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-13465 TaskContext. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") + ) ++ Seq ( + // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") + ) ++ Seq( + // SPARK-13526 Move SQLContext per-session states to new class + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.sql.UDFRegistration.this") + ) ++ Seq( + // [SPARK-13486][SQL] Move SQLConf into an internal package + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$") + ) ++ Seq( + //SPARK-11011 UserDefinedType serialization should be strongly typed + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition") + ) ++ Seq( + // [SPARK-13244][SQL] Migrates DataFrame to Dataset + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.tables"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.sql"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.baseRelationToDataFrame"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.table"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrame.apply"), + + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.localSeqToDataFrameHolder"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.stringRddToDataFrameHolder"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.rddToDataFrameHolder"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"), + + // [SPARK-14451][SQL] Move encoder definition into Aggregator interface + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"), + + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions") + ) ++ Seq( + // [SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `reqParam` to (Streaming)LinearRegressionWithSGD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD.this") + ) ++ Seq( + // SPARK-13920: MIMA checks should apply to @Experimental and @DeveloperAPI APIs + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineCombinersByKey"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineValuesByKey"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.run"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.runJob"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.actorSystem"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.cacheManager"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getConfigurationFromJobContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTaskAttemptIDFromTaskAttemptContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.newConfiguration"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback_="), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.setBytesReadCallback"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.updateBytesRead"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decFetchWaitTime"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decLocalBlocksFetched"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRecordsRead"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBlocksFetched"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBytesRead"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleBytesWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleRecordsWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleWriteTime"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleBytesWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleRecordsWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleWriteTime"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.setShuffleRecordsWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.PCAModel.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.taskMetrics"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.TaskInfo.attempt"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.ExperimentalMethods.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUDF"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUdf"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.cumeDist"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.denseRank"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.inputFileName"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.isNaN"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.percentRank"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.rowNumber"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.sparkPartitionId"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.externalBlockStoreSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsedByRdd"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatusListener.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.scheduler.BatchInfo.streamIdToNumRecords"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.storageStatusList"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.storage.StorageListener.storageStatusList"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.apply"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.copy"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.InputMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.OutputMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transformImpl"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.GBTClassifier.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassifier.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayes.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRest.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRestModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.RandomForestClassifier.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeans.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logLikelihood"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logPerplexity"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.RegressionEvaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Binarizer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Bucketizer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelector.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizer.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.HashingTF.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDF.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDFModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IndexToString.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Interaction.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCA.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCAModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormula.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormulaModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.SQLTransformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StopWordsRemover.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexer.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorAssembler.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexer.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorSlicer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2Vec.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALS.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.GBTRegressor.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.extractWeightedLabeledPoints"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.extractWeightedLabeledPoints"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegression.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionTrainingSummary.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressor.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplit.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.RegressionMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameWriter.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.broadcast"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.callUDF"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.InsertableRelation.insert"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.fMeasureByThreshold"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.pr"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.precisionByThreshold"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.predictions"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.recallByThreshold"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.roc"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.describeTopics"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.getVectors"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.itemFactors"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.userFactors"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.predictions"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.residuals"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.name"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.value"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.drop"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.fill"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.replace"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.jdbc"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.json"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.load"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.orc"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.parquet"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.table"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.text"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.crosstab"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.freqItems"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.sampleBy"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.createExternalTable"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.emptyDataFrame"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.range"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.functions.udf"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.JobLogger"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorHelper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.functions$"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Predictor.train"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert") + ) ++ Seq( + // [SPARK-13926] Automatically use Kryo serializer when shuffling RDDs with simple types + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ShuffleDependency.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ShuffleDependency.serializer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.Serializer$") + ) ++ Seq( + // SPARK-13927: add row/column iterator to local matrices + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter") + ) ++ Seq( + // SPARK-13948: MiMa Check should catch if the visibility change to `private` + // TODO(josh): Some of these may be legitimate incompatibilities; we should follow up before the 2.0.0 release + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.toDS"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.askTimeout"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.lookupTimeout"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.UnaryTransformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassifier.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressor.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.select"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.toDF"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Logging.initializeLogIfNecessary"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerEvent.logEvent"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance") + ) ++ Seq( + // [SPARK-14014] Replace existing analysis.Catalog with SessionCatalog + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.this") + ) ++ Seq( + // [SPARK-13928] Move org.apache.spark.Logging into org.apache.spark.internal.Logging + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Logging"), + (problem: Problem) => problem match { + case MissingTypesProblem(_, missing) + if missing.map(_.fullName).sameElements(Seq("org.apache.spark.Logging")) => false + case _ => true + } + ) ++ Seq( + // [SPARK-13990] Automatically pick serializer when caching RDDs + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.uploadBlock") + ) ++ Seq( + // [SPARK-14089][CORE][MLLIB] Remove methods that has been deprecated since 1.1, 1.2, 1.3, 1.4, and 1.5 + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.getThreadLocal"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeReduce"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeAggregate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.defaultStategy"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.saveLabeledData"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol") + ) ++ Seq( + // [SPARK-14205][SQL] remove trait Queryable + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset") + ) ++ Seq( + // [SPARK-11262][ML] Unit test for gradient, loss layers, memory management + // for multilayer perceptron. + // This class is marked as `private`. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction") + ) ++ Seq( + // [SPARK-13674][SQL] Add wholestage codegen support to Sample + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this") + ) ++ Seq( + // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") + ) ++ Seq( + // [SPARK-14437][Core] Use the address that NettyBlockTransferService listens to create BlockManagerId + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.this") + ) ++ Seq( + // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") + ) ++ Seq( + // [SPARK-14475] Propagate user-defined context from driver to executors + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), + // [SPARK-14617] Remove deprecated APIs in TaskMetrics + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$") + ) case v if v.startsWith("1.6") => Seq( MimaBuild.excludeSparkPackage("deploy"), @@ -48,20 +647,18 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.columnar"), // The shuffle package is considered private. excludePackage("org.apache.spark.shuffle"), - // The collections utlities are considered pricate. + // The collections utilities are considered private. excludePackage("org.apache.spark.util.collection") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticAggregator.add"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticAggregator.count"), + // MiMa does not deal properly with sealed traits ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") + ) ++ Seq( + // SPARK-11530 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this") ) ++ Seq( // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. // This class is marked as `private` but MiMa still seems to be confused by the change. @@ -113,58 +710,66 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") ) ++ Seq( + // SPARK-11485 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), + // SPARK-11541 mark various JDBC dialects as private + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.PostgresDialect$"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") + ) ++ Seq ( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.StageData.this") + ) ++ Seq( + // SPARK-11766 add toJson to Vector ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$2"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$3"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$4"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$5"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$6"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$7"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$8"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$9"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$10"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$11"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$12"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$13"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$14"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$15"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$16"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$17"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$18"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$19"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$20"), + "org.apache.spark.mllib.linalg.Vector.toJson") + ) ++ Seq( + // SPARK-9065 Support message handler in Kafka Python API ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$21"), + "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$22"), + "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") + ) ++ Seq( + // SPARK-4557 Changed foreachRDD to use VoidFunction ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$23"), + "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") + ) ++ Seq( + // SPARK-11996 Make the executor thread dump work again + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") + ) ++ Seq( + // SPARK-3580 Add getNumPartitions method to JavaRDD ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") + "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") ) ++ Seq( - // SPARK-11485 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df") - ) + // SPARK-12149 Added new fields to ExecutorSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") + ) ++ + // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a + // private class. + MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), @@ -245,7 +850,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), @@ -254,10 +858,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), @@ -268,7 +870,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), @@ -282,7 +883,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), @@ -347,7 +947,7 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("ml"), - // SPARK-7910 Adding a method to get the partioner to JavaRDD, + // SPARK-7910 Adding a method to get the partitioner to JavaRDD, ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), @@ -365,7 +965,7 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") ) ++ Seq( - // SPARK-4655 - Making Stage an Abstract class broke binary compatility even though + // SPARK-4655 - Making Stage an Abstract class broke binary compatibility even though // the stage class is defined as private[spark] ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") ) ++ Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 75c36930decef..a58dd7e7f125c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -16,16 +16,19 @@ */ import java.io._ +import java.nio.file.Files import scala.util.Properties import scala.collection.JavaConverters._ +import scala.collection.mutable.Stack import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} -import net.virtualvoid.sbt.graph.Plugin.graphSettings +import com.typesafe.tools.mima.plugin.MimaKeys import spray.revolver.RevolverPlugin._ @@ -33,28 +36,44 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) = - Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", - "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) - - val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "streaming-kinesis-asl").map(ProjectRef(buildLocation, _)) - - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver" + ).map(ProjectRef(buildLocation, _)) + + val streamingProjects@Seq( + streaming, streamingFlumeSink, streamingFlume, streamingKafka + ) = Seq( + "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka" + ).map(ProjectRef(buildLocation, _)) + + val allProjects@Seq( + core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* + ) = Seq( + "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", + "test-tags", "sketch" + ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects + + val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl, + streamingKinesisAsl, dockerIntegrationTests) = + Seq("yarn", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl", + "docker-integration-tests").map(ProjectRef(buildLocation, _)) + + val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = + Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) + val copyJarsProjects@Seq(assembly, examples) = Seq("assembly", "examples") + .map(ProjectRef(buildLocation, _)) + val tools = ProjectRef(buildLocation, "tools") // Root project. val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation val testTempDir = s"$sparkHome/target/tmp" + + val javacJVMVersion = settingKey[String]("source and target JVM version for javac") + val scalacJVMVersion = settingKey[String]("source and target JVM version for scalac") } object SparkBuild extends PomBuild { @@ -67,7 +86,6 @@ object SparkBuild extends PomBuild { // Provides compatibility for older versions of the Spark build def backwardCompatibility = { import scala.collection.mutable - var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq("sbt") // scalastyle:off println if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { @@ -80,7 +98,6 @@ object SparkBuild extends PomBuild { } Properties.envOrNone("SPARK_HADOOP_VERSION") match { case Some(v) => - if (v.matches("0.23.*")) isAlphaYarn = true println("NOTE: SPARK_HADOOP_VERSION is deprecated, please use -Dhadoop.version=" + v) System.setProperty("hadoop.version", v) case None => @@ -105,11 +122,11 @@ object SparkBuild extends PomBuild { v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } - if (System.getProperty("scala-2.11") == "") { - // To activate scala-2.11 profile, replace empty property value to non-empty value + if (System.getProperty("scala-2.10") == "") { + // To activate scala-2.10 profile, replace empty property value to non-empty value // in the same way as Maven which handles -Dname as -Dname=true before executes build process. // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.11", "true") + System.setProperty("scala-2.10", "true") } profiles } @@ -130,17 +147,23 @@ object SparkBuild extends PomBuild { "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full), scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java"))) - lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq ( + lazy val sharedSettings = sparkGenjavadocSettings ++ Seq ( + exportJars in Compile := true, + exportJars in Test := false, javaHome := sys.env.get("JAVA_HOME") .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), incOptions := incOptions.value.withNameHashing(true), - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, unidocGenjavadocVersion := "0.9-spark0", - resolvers += Resolver.mavenLocal, + // Override SBT's default resolvers: + resolvers := Seq( + DefaultMavenRepository, + Resolver.mavenLocal, + Resolver.file("local", file(Path.userHome.absolutePath + "/.ivy2/local"))(Resolver.ivyStylePatterns) + ), + externalResolvers := resolvers.value, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) @@ -150,13 +173,28 @@ object SparkBuild extends PomBuild { publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn, javacOptions in (Compile, doc) ++= { - val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3) - if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty + val versionParts = System.getProperty("java.version").split("[+.\\-]+", 3) + var major = versionParts(0).toInt + if (major == 1) major = versionParts(1).toInt + if (major >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty }, - javacOptions in Compile ++= Seq("-encoding", "UTF-8"), + javacJVMVersion := "1.7", + scalacJVMVersion := "1.7", + + javacOptions in Compile ++= Seq( + "-encoding", "UTF-8", + "-source", javacJVMVersion.value + ), + // This -target option cannot be set in the Compile configuration scope since `javadoc` doesn't + // play nicely with it; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629 for + // additional discussion and explanation. + javacOptions in (Compile, compile) ++= Seq( + "-target", javacJVMVersion.value + ), scalacOptions in Compile ++= Seq( + s"-target:jvm-${scalacJVMVersion.value}", "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath // Required for relative source links in scaladoc ), @@ -206,32 +244,48 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ - (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) + (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) + .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ + ExcludedDependencies.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, unsafe, testTags).contains(x)).foreach { - x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) - } + val mimaProjects = allProjects.filterNot { x => + Seq( + spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, + unsafe, testTags, sketch, mllibLocal + ).contains(x) + } + + mimaProjects.foreach { x => + enable(MimaBuild.mimaSettings(sparkHome, x))(x) + } /* Unsafe settings */ enable(Unsafe.settings)(unsafe) + /* + * Set up tasks to copy dependencies during packaging. This step can be disabled in the command + * line, so that dev/mima can run without trying to copy these files again and potentially + * causing issues. + */ + if (!"false".equals(System.getProperty("copyDependencies"))) { + copyJarsProjects.foreach(enable(CopyDependencies.settings)) + } + /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) - /* Enable Assembly for streamingMqtt test */ - enable(inConfig(Test)(Assembly.settings))(streamingMqtt) - /* Package pyspark artifacts in a separate zip file for YARN. */ enable(PySparkAssembly.settings)(assembly) /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Catalyst ANTLR generation settings */ + enable(Catalyst.settings)(catalyst) + /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -240,6 +294,9 @@ object SparkBuild extends PomBuild { enable(Flume.settings)(streamingFlumeSink) + enable(Java8TestSettings.settings)(java8Tests) + + enable(DockerIntegrationTests.settings)(dockerIntegrationTests) /** * Adds the ability to run the spark shell directly from SBT without building an assembly @@ -248,6 +305,11 @@ object SparkBuild extends PomBuild { * Usage: `build/sbt sparkShell` */ val sparkShell = taskKey[Unit]("start a spark-shell.") + val sparkPackage = inputKey[Unit]( + s""" + |Download and run a spark package. + |Usage `builds/sbt "sparkPackage [args] + """.stripMargin) val sparkSql = taskKey[Unit]("starts the spark sql CLI.") enable(Seq( @@ -261,6 +323,16 @@ object SparkBuild extends PomBuild { (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value }, + sparkPackage := { + import complete.DefaultParsers._ + val packages :: className :: otherArgs = spaceDelimited(" [args]").parsed.toList + val scalaRun = (runner in run).value + val classpath = (fullClasspath in Runtime).value + val args = Seq("--packages", packages, "--class", className, (Keys.`package` in Compile in "core").value.getCanonicalPath) ++ otherArgs + println(args) + scalaRun.run("org.apache.spark.deploy.SparkSubmit", classpath.map(_.data), args, streams.value.log) + }, + javaOptions in Compile += "-Dspark.master=local", sparkSql := { @@ -291,6 +363,23 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +object DockerIntegrationTests { + // This serves to override the override specified in DependencyOverrides: + lazy val settings = Seq( + dependencyOverrides += "com.google.guava" % "guava" % "18.0", + resolvers ++= Seq("DB2" at "https://app.camunda.com/nexus/content/repositories/public/") + ) + +} + +/** + * Overrides to work around sbt's dependency resolution being different from Maven's. + */ +object DependencyOverrides { + lazy val settings = Seq( + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") +} + /** This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. @@ -302,27 +391,31 @@ object ExcludedDependencies { } /** - * Following project only exists to pull previous artifacts of Spark for generating - * Mima ignores. For more information see: SPARK 2071 + * Project to pull previous artifacts of Spark for generating Mima excludes. */ object OldDeps { lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings) - def versionArtifact(id: String): Option[sbt.ModuleID] = { - val fullId = id + "_2.10" - Some("org.apache.spark" % fullId % "1.2.0") + lazy val allPreviousArtifactKeys = Def.settingDyn[Seq[Option[ModuleID]]] { + SparkBuild.mimaProjects + .map { project => MimaKeys.previousArtifact in project } + .map(k => Def.setting(k.value)) + .join } def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", scalaVersion := "2.10.5", - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", - libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", - "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", - "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", - "spark-core").map(versionArtifact(_).get intransitive()) + libraryDependencies := allPreviousArtifactKeys.value.flatten + ) +} + +object Catalyst { + lazy val settings = antlr4Settings ++ Seq( + antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser"), + antlr4GenListener in Antlr4 := true, + antlr4GenVisitor in Antlr4 := true ) } @@ -385,7 +478,6 @@ object Hive { // new query tests. fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } ) - } object Assembly { @@ -402,7 +494,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { @@ -413,7 +505,6 @@ object Assembly { s"${mName}-test-${v}.jar" }, mergeStrategy in assembly := { - case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard case "log4j.properties" => MergeStrategy.discard @@ -487,8 +578,9 @@ object Unidoc { private def ignoreUndocumentedPackages(packages: Seq[Seq[File]]): Seq[Seq[File]] = { packages .map(_.filterNot(_.getName.contains("$"))) - .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/examples"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/memory"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) @@ -506,9 +598,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), // Skip actual catalyst, but include the subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. @@ -527,8 +619,7 @@ object Unidoc { "-public", "-group", "Core Java API", packageList("api.java", "api.java.function"), "-group", "Spark Streaming", packageList( - "streaming.api.java", "streaming.flume", "streaming.kafka", - "streaming.mqtt", "streaming.twitter", "streaming.zeromq", "streaming.kinesis" + "streaming.api.java", "streaming.flume", "streaming.kafka", "streaming.kinesis" ), "-group", "MLlib", packageList( "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg", @@ -544,7 +635,7 @@ object Unidoc { "-noqualifier", "java.lang" ), - // Use GitHub repository for Scaladoc source linke + // Use GitHub repository for Scaladoc source links unidocSourceBase := s"https://github.com/apache/spark/tree/v${version.value}", scalacOptions in (ScalaUnidoc, unidoc) ++= Seq( @@ -560,9 +651,54 @@ object Unidoc { ) } +object CopyDependencies { + + val copyDeps = TaskKey[Unit]("copyDeps", "Copies needed dependencies to the build directory.") + val destPath = (crossTarget in Compile) / "jars" + + lazy val settings = Seq( + copyDeps := { + val dest = destPath.value + if (!dest.isDirectory() && !dest.mkdirs()) { + throw new IOException("Failed to create jars directory.") + } + + (dependencyClasspath in Compile).value.map(_.data) + .filter { jar => jar.isFile() } + .foreach { jar => + val destJar = new File(dest, jar.getName()) + if (destJar.isFile()) { + destJar.delete() + } + Files.copy(jar.toPath(), destJar.toPath()) + } + }, + crossTarget in (Compile, packageBin) := destPath.value, + packageBin in Compile <<= (packageBin in Compile).dependsOn(copyDeps) + ) + +} + +object Java8TestSettings { + import BuildCommons._ + + lazy val settings = Seq( + javacJVMVersion := "1.8", + // Targeting Java 8 bytecode is only supported in Scala 2.11.4 and higher: + scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8") + ) +} + object TestSettings { import BuildCommons._ + private val scalaBinaryVersion = + if (System.getProperty("scala-2.10") == "true") { + "2.10" + } else { + "2.11" + } + lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, @@ -572,6 +708,7 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "SPARK_PREPEND_CLASSES" -> "1", + "SPARK_SCALA_VERSION" -> scalaBinaryVersion, "SPARK_TESTING" -> "1", "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", @@ -581,7 +718,6 @@ object TestSettings { javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", - javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", @@ -609,8 +745,21 @@ object TestSettings { parallelExecution in Test := false, // Make sure the test temp directory exists. resourceGenerators in Test <+= resourceManaged in Test map { outDir: File => - if (!new File(testTempDir).isDirectory()) { - require(new File(testTempDir).mkdirs()) + var dir = new File(testTempDir) + if (!dir.isDirectory()) { + // Because File.mkdirs() can fail if multiple callers are trying to create the same + // parent directory, this code tries to create parents one at a time, and avoids + // failures when the directories have been created by somebody else. + val stack = new Stack[File]() + while (!dir.isDirectory()) { + stack.push(dir) + dir = dir.getParentFile() + } + + while (stack.nonEmpty) { + val d = stack.pop() + require(d.mkdir() || d.isDirectory(), s"Failed to create directory $d") + } } Seq[File]() }, @@ -619,7 +768,6 @@ object TestSettings { scalacOptions in (Compile, doc) := Seq( "-groups", "-skip-packages", Seq( - "akka", "org.apache.spark.api.python", "org.apache.spark.network", "org.apache.spark.deploy", diff --git a/project/build.properties b/project/build.properties index 064ec843da9ea..1e38156e0b577 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.7 +sbt.version=0.13.11 diff --git a/project/plugins.sbt b/project/plugins.sbt index c06687d8f197b..44ec3a12ae709 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,25 +1,12 @@ -resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) - -resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" - -resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" - addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") - -addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "4.0.0") -// For Sonatype publishing -//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") -//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") -addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") - -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") - -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.9") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") @@ -27,8 +14,13 @@ addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") -addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") +addSbtPlugin("io.spray" % "sbt-revolver" % "0.8.0") libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" + +// TODO I am not sure we want such a dep. +resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases" + +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 471d00bd8223f..cbb88dc7dd1dd 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -19,9 +19,8 @@ import sbt._ import sbt.Keys._ /** - * This plugin project is there to define new scala style rules for spark. This is - * a plugin project so that this gets compiled first and is put on the classpath and - * becomes available for scalastyle sbt plugin. + * This plugin project is there because we use our custom fork of sbt-pom-reader plugin. This is + * a plugin project so that this gets compiled first and is available on the classpath for SBT build. */ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) diff --git a/python/docs/Makefile b/python/docs/Makefile index 4cec74f057fbe..905e0215c20c2 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -2,12 +2,12 @@ # # You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -BUILDDIR = _build +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +PAPER ?= +BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9.2-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/docs/conf.py b/python/docs/conf.py index 365d6af514177..d35bf73c30510 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -334,3 +334,6 @@ # If false, no index is generated. #epub_use_index = True + +# Skip sample endpoint link (not expected to resolve) +linkcheck_ignore = [r'https://kinesis.us-east-1.amazonaws.com'] diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst index fc52a647543e7..25ceabac0a541 100644 --- a/python/docs/pyspark.streaming.rst +++ b/python/docs/pyspark.streaming.rst @@ -29,10 +29,3 @@ pyspark.streaming.flume.module :members: :undoc-members: :show-inheritance: - -pyspark.streaming.mqtt module ------------------------------ -.. automodule:: pyspark.streaming.mqtt - :members: - :undoc-members: - :show-inheritance: diff --git a/python/lib/py4j-0.9-src.zip b/python/lib/py4j-0.9-src.zip deleted file mode 100644 index dace2d0fe3b0b..0000000000000 Binary files a/python/lib/py4j-0.9-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.9.2-src.zip b/python/lib/py4j-0.9.2-src.zip new file mode 100644 index 0000000000000..881bb759d7823 Binary files /dev/null and b/python/lib/py4j-0.9.2-src.zip differ diff --git a/pylintrc b/python/pylintrc similarity index 100% rename from pylintrc rename to python/pylintrc diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 8475dfb1c6ad0..111ebaafee3e1 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -37,6 +37,8 @@ """ +import types + from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.rdd import RDD @@ -64,8 +66,26 @@ def deco(f): return deco +def copy_func(f, name=None, sinceversion=None, doc=None): + """ + Returns a function with same code, globals, defaults, closure, and + name (or provide a new name). + """ + # See + # http://stackoverflow.com/questions/6527633/how-can-i-make-a-deepcopy-of-a-function-in-python + fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__, f.__defaults__, + f.__closure__) + # in case f was given attrs (note this dict is a shallow copy): + fn.__dict__.update(f.__dict__) + if doc is not None: + fn.__doc__ = doc + if sinceversion is not None: + fn = since(sinceversion)(fn) + return fn + + # for back compatibility -from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row +from pyspark.sql import SQLContext, HiveContext, Row __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 663c9abe0881e..a0b819220e6d3 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -99,11 +99,26 @@ def value(self): def unpersist(self, blocking=False): """ - Delete cached copies of this broadcast on the executors. + Delete cached copies of this broadcast on the executors. If the + broadcast is used after this is called, it will need to be + re-sent to each executor. + + :param blocking: Whether to block until unpersisting has completed """ if self._jbroadcast is None: raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) + + def destroy(self): + """ + Destroy all data and metadata related to this broadcast variable. + Use this with caution; once a broadcast variable has been destroyed, + it cannot be used again. This method blocks until destroy has + completed. + """ + if self._jbroadcast is None: + raise Exception("Broadcast can only be destroyed in driver") + self._jbroadcast.destroy() os.unlink(self._path) def __reduce__(self): diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 95b3abc74244b..e56e22a9b920e 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -241,6 +241,7 @@ def save_function_tuple(self, func): save(f_globals) save(defaults) save(dct) + save(func.__module__) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple @@ -698,13 +699,14 @@ def _genpartial(func, args, kwds): return partial(func, *args, **kwds) -def _fill_function(func, globals, defaults, dict): +def _fill_function(func, globals, defaults, dict, module): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). """ func.__globals__.update(globals) func.__defaults__ = defaults func.__dict__ = dict + func.__module__ = module return func diff --git a/python/pyspark/context.py b/python/pyspark/context.py index afd74d937a413..cb15b4b91f913 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,7 @@ import shutil import signal import sys +import threading from threading import RLock from tempfile import NamedTemporaryFile @@ -221,8 +222,11 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # create a signal handler which would be invoked on receiving SIGINT def signal_handler(signal, frame): self.cancelAllJobs() + raise KeyboardInterrupt() - signal.signal(signal.SIGINT, signal_handler) + # see http://stackoverflow.com/questions/23206787/ + if isinstance(threading.current_thread(), threading._MainThread): + signal.signal(signal.SIGINT, signal_handler) def _initialize_context(self, jconf): """ @@ -424,15 +428,19 @@ def f(split, iterator): # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - # Make sure we distribute data evenly if it's smaller than self.batchSize - if "__len__" not in dir(c): - c = list(c) # Make it a list so we can compute its length - batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) - serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - serializer.dump_stream(c, tempFile) - tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + try: + # Make sure we distribute data evenly if it's smaller than self.batchSize + if "__len__" not in dir(c): + c = list(c) # Make it a list so we can compute its length + batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) + serializer.dump_stream(c, tempFile) + tempFile.close() + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + finally: + # readRDDFromFile eagerily reads the file so we can delete right after. + os.unlink(tempFile.name) return RDD(jrdd, self, serializer) def pickleFile(self, name, minPartitions=None): diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 94df3990164d6..c1f5362648f6e 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -93,7 +93,7 @@ def dispatch(seq): vbuf.append(None) if not wbuf: wbuf.append(None) - return [(v, w) for v in vbuf for w in wbuf] + return ((v, w) for v in vbuf for w in wbuf) return _do_python_join(rdd, other, numPartitions, dispatch) diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index 327a11b14b5aa..25cfac02f383d 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -15,6 +15,7 @@ # limitations under the License. # -from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel +from pyspark.ml.base import Estimator, Model, Transformer +from pyspark.ml.pipeline import Pipeline, PipelineModel __all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"] diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py new file mode 100644 index 0000000000000..a7a58e17a43ed --- /dev/null +++ b/python/pyspark/ml/base.py @@ -0,0 +1,118 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta, abstractmethod + +from pyspark import since +from pyspark.ml.param import Params +from pyspark.mllib.common import inherit_doc + + +@inherit_doc +class Estimator(Params): + """ + Abstract class for estimators that fit models to data. + + .. versionadded:: 1.3.0 + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def _fit(self, dataset): + """ + Fits a model to the input dataset. This is called by the default implementation of fit. + + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` + :returns: fitted model + """ + raise NotImplementedError() + + @since("1.3.0") + def fit(self, dataset, params=None): + """ + Fits a model to the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded params. If a list/tuple of + param maps is given, this calls fit on each param map and returns a list of + models. + :returns: fitted model(s) + """ + if params is None: + params = dict() + if isinstance(params, (list, tuple)): + return [self.fit(dataset, paramMap) for paramMap in params] + elif isinstance(params, dict): + if params: + return self.copy(params)._fit(dataset) + else: + return self._fit(dataset) + else: + raise ValueError("Params must be either a param map or a list/tuple of param maps, " + "but got %s." % type(params)) + + +@inherit_doc +class Transformer(Params): + """ + Abstract class for transformers that transform one dataset into another. + + .. versionadded:: 1.3.0 + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def _transform(self, dataset): + """ + Transforms the input dataset. + + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` + :returns: transformed dataset + """ + raise NotImplementedError() + + @since("1.3.0") + def transform(self, dataset, params=None): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded params. + :returns: transformed dataset + """ + if params is None: + params = dict() + if isinstance(params, dict): + if params: + return self.copy(params)._transform(dataset) + else: + return self._transform(dataset) + else: + raise ValueError("Params must be a param map but got %s." % type(params)) + + +@inherit_doc +class Model(Transformer): + """ + Abstract class for models that are fitted by estimators. + + .. versionadded:: 1.4.0 + """ + + __metaclass__ = ABCMeta diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 2e468f67b8987..6ef119a4265fd 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -18,25 +18,31 @@ import warnings from pyspark import since -from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper +from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.mllib.common import inherit_doc +from pyspark.sql import DataFrame -__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', - 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', - 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel', 'MultilayerPerceptronClassifier', - 'MultilayerPerceptronClassificationModel'] +__all__ = ['LogisticRegression', 'LogisticRegressionModel', + 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary', + 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary', + 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', + 'GBTClassifier', 'GBTClassificationModel', + 'RandomForestClassifier', 'RandomForestClassificationModel', + 'NaiveBayes', 'NaiveBayesModel', + 'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel'] @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, - HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, + HasWeightCol, JavaMLWritable, JavaMLReadable): """ Logistic regression. Currently, this class only supports binary classification. @@ -44,11 +50,11 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors >>> df = sc.parallelize([ - ... Row(label=1.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) - >>> model.weights + >>> model.coefficients DenseVector([5.5...]) >>> model.intercept -2.68... @@ -67,47 +73,58 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> lr_path = temp_path + "/lr" + >>> lr.save(lr_path) + >>> lr2 = LogisticRegression.load(lr_path) + >>> lr2.getMaxIter() + 5 + >>> model_path = temp_path + "/lr_model" + >>> model.save(model_path) + >>> model2 = LogisticRegressionModel.load(model_path) + >>> model.coefficients[0] == model2.coefficients[0] + True + >>> model.intercept == model2.intercept + True + + .. versionadded:: 1.3.0 """ - # a placeholder to make it appear in the generated doc threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." + - " If threshold and thresholds are both set, they must match.") + " If threshold and thresholds are both set, they must match.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for threshold in binary classification, in range [0, 1]. - self.threshold = Param(self, "threshold", - "Threshold in binary classification prediction, in range [0, 1]." + - " If threshold and thresholds are both set, they must match.") - self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) + self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @keyword_only + @since("1.3.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ @@ -119,6 +136,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) + @since("1.4.0") def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. @@ -129,6 +147,7 @@ def setThreshold(self, value): del self._paramMap[self.thresholds] return self + @since("1.4.0") def getThreshold(self): """ Gets the value of threshold or its default value. @@ -144,6 +163,7 @@ def getThreshold(self): else: return self.getOrDefault(self.threshold) + @since("1.5.0") def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. @@ -154,6 +174,7 @@ def setThresholds(self, value): del self._paramMap[self.threshold] return self + @since("1.5.0") def getThresholds(self): """ If :py:attr:`thresholds` is set, return its value. @@ -182,12 +203,15 @@ def _checkThresholdConsistency(self): " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) -class LogisticRegressionModel(JavaModel): +class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by LogisticRegression. + + .. versionadded:: 1.3.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. @@ -205,32 +229,244 @@ def coefficients(self): return self._call_java("coefficients") @property + @since("1.4.0") def intercept(self): """ Model intercept. """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, mse, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_blrt_summary = self._call_java("summary") + # Note: Once multiclass is added, update this to return correct summary + return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_blr_summary = self._call_java("evaluate", dataset) + return BinaryLogisticRegressionSummary(java_blr_summary) + + +class LogisticRegressionSummary(JavaWrapper): + """ + Abstraction for Logistic Regression Results for a given model. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def probabilityCol(self): + """ + Field in "predictions" which gives the probability + of each class as a vector. + """ + return self._call_java("probabilityCol") + + @property + @since("2.0.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("2.0.0") + def featuresCol(self): + """ + Field in "predictions" which gives the features of each instance + as a vector. + """ + return self._call_java("featuresCol") + + +@inherit_doc +class LogisticRegressionTrainingSummary(LogisticRegressionSummary): + """ + Abstraction for multinomial Logistic Regression Training results. + Currently, the training summary ignores the training weights except + for the objective trace. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. + """ + return self._call_java("objectiveHistory") + + @property + @since("2.0.0") + def totalIterations(self): + """ + Number of training iterations until termination. + """ + return self._call_java("totalIterations") + + +@inherit_doc +class BinaryLogisticRegressionSummary(LogisticRegressionSummary): + """ + .. note:: Experimental + + Binary Logistic regression results for a given model. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def roc(self): + """ + Returns the receiver operating characteristic (ROC) curve, + which is an Dataframe having two fields (FPR, TPR) with + (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + Reference: http://en.wikipedia.org/wiki/Receiver_operating_characteristic + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("roc") + + @property + @since("2.0.0") + def areaUnderROC(self): + """ + Computes the area under the receiver operating characteristic + (ROC) curve. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("areaUnderROC") + + @property + @since("2.0.0") + def pr(self): + """ + Returns the precision-recall curve, which is an Dataframe + containing two fields recall, precision with (0.0, 1.0) prepended + to it. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("pr") + + @property + @since("2.0.0") + def fMeasureByThreshold(self): + """ + Returns a dataframe with two fields (threshold, F-Measure) curve + with beta = 1.0. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("fMeasureByThreshold") + + @property + @since("2.0.0") + def precisionByThreshold(self): + """ + Returns a dataframe with two fields (threshold, precision) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the precision. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("precisionByThreshold") + + @property + @since("2.0.0") + def recallByThreshold(self): + """ + Returns a dataframe with two fields (threshold, recall) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the recall. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("recallByThreshold") + + +@inherit_doc +class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary, + LogisticRegressionTrainingSummary): + """ + .. note:: Experimental + + Binary Logistic regression training results for a given model. + + .. versionadded:: 2.0.0 + """ + pass + class TreeClassifierParams(object): """ Private class to track supported impurity measures. + + .. versionadded:: 1.4.0 """ supportedImpurities = ["entropy", "gini"] - # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + - ", ".join(supportedImpurities)) + ", ".join(supportedImpurities), typeConverter=TypeConverters.toString) def __init__(self): super(TreeClassifierParams, self).__init__() - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = Param(self, "impurity", "Criterion used for information " + - "gain calculation (case-insensitive). Supported options: " + - ", ".join(self.supportedImpurities)) + @since("1.6.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -238,6 +474,7 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.6.0") def getImpurity(self): """ Gets the value of impurity or its default value. @@ -248,6 +485,8 @@ def getImpurity(self): class GBTParams(TreeEnsembleParams): """ Private class to track supported GBT params. + + .. versionadded:: 1.4.0 """ supportedLossTypes = ["logistic"] @@ -255,7 +494,8 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, - TreeClassifierParams, HasCheckpointInterval): + TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. @@ -276,6 +516,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred 3 >>> model.depth 1 + >>> model.featureImportances + SparseVector(1, {0: 1.0}) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> result = model.transform(test0).head() >>> result.prediction @@ -287,18 +529,33 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + >>> dtc_path = temp_path + "/dtc" + >>> dt.save(dtc_path) + >>> dt2 = DecisionTreeClassifier.load(dtc_path) + >>> dt2.getMaxDepth() + 2 + >>> model_path = temp_path + "/dtc_model" + >>> model.save(model_path) + >>> model2 = DecisionTreeClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + + .. versionadded:: 1.4.0 """ @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", + seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ + seed=None) """ super(DecisionTreeClassifier, self).__init__() self._java_obj = self._new_java_obj( @@ -310,16 +567,18 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="gini"): + impurity="gini", seed=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ + seed=None) Sets params for the DecisionTreeClassifier. """ kwargs = self.setParams._input_kwargs @@ -330,16 +589,40 @@ def _create_model(self, java_model): @inherit_doc -class DecisionTreeClassificationModel(DecisionTreeModel): +class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): """ Model fitted by DecisionTreeClassifier. + + .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def featureImportances(self): + """ + Estimate of the importance of each feature. + + This generalizes the idea of "Gini" importance to other losses, + following the explanation of Gini importance from "Random Forests" documentation + by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + + This feature importance is calculated as follows: + - importance(feature j) = sum (over nodes which split on feature j) of the gain, + where gain is scaled by the number of instances passing through node + - Normalize importances for tree to sum to 1. + + Note: Feature importance for single decision trees can have high variance due to + correlated predictor variables. Consider using a :py:class:`RandomForestClassifier` + to determine feature importance instead. + """ + return self._call_java("featureImportances") + @inherit_doc class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, HasRawPredictionCol, HasProbabilityCol, - RandomForestParams, TreeClassifierParams, HasCheckpointInterval): + RandomForestParams, TreeClassifierParams, HasCheckpointInterval, + JavaMLWritable, JavaMLReadable): """ `http://en.wikipedia.org/wiki/Random_forest Random Forest` learning algorithm for classification. @@ -358,6 +641,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> td = si_model.transform(df) >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) + >>> model.featureImportances + SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) @@ -371,6 +656,18 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> rfc_path = temp_path + "/rfc" + >>> rf.save(rfc_path) + >>> rf2 = RandomForestClassifier.load(rfc_path) + >>> rf2.getNumTrees() + 3 + >>> model_path = temp_path + "/rfc_model" + >>> model.save(model_path) + >>> model2 = RandomForestClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + + .. versionadded:: 1.4.0 """ @keyword_only @@ -396,6 +693,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, @@ -416,15 +714,33 @@ def _create_model(self, java_model): return RandomForestClassificationModel(java_model) -class RandomForestClassificationModel(TreeEnsembleModels): +class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by RandomForestClassifier. + + .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def featureImportances(self): + """ + Estimate of the importance of each feature. + + Each feature's importance is the average of its importance across all trees in the ensemble + The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + and follows the implementation from scikit-learn. + + .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances` + """ + return self._call_java("featureImportances") + @inherit_doc class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for classification. @@ -440,8 +756,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) - >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed") + >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) >>> model = gbt.fit(td) + >>> model.featureImportances + SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) @@ -450,47 +768,58 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtc_path = temp_path + "gbtc" + >>> gbt.save(gbtc_path) + >>> gbt2 = GBTClassifier.load(gbtc_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtc_model" + >>> model.save(model_path) + >>> model2 = GBTClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True + + .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + "Supported options: " + ", ".join(GBTParams.supportedLossTypes), + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1): + maxIter=20, stepSize=0.1, seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None) """ super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) - #: param for Loss function which GBT tries to minimize (case-insensitive). - self.lossType = Param(self, "lossType", - "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1): + lossType="logistic", maxIter=20, stepSize=0.1, seed=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None) Sets params for Gradient Boosted Tree Classification. """ kwargs = self.setParams._input_kwargs @@ -499,6 +828,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTClassificationModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -506,6 +836,7 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. @@ -513,15 +844,32 @@ def getLossType(self): return self.getOrDefault(self.lossType) -class GBTClassificationModel(TreeEnsembleModels): +class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. + + .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def featureImportances(self): + """ + Estimate of the importance of each feature. + + Each feature's importance is the average of its importance across all trees in the ensemble + The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + and follows the implementation from scikit-learn. + + .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances` + """ + return self._call_java("featureImportances") + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol): + HasRawPredictionCol, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. Multinomial NB @@ -555,13 +903,27 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 + >>> nb_path = temp_path + "/nb" + >>> nb.save(nb_path) + >>> nb2 = NaiveBayes.load(nb_path) + >>> nb2.getSmoothing() + 1.0 + >>> model_path = temp_path + "/nb_model" + >>> model.save(model_path) + >>> model2 = NaiveBayesModel.load(model_path) + >>> model.pi == model2.pi + True + >>> model.theta == model2.theta + True + + .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + - "default is 1.0") + "default is 1.0", typeConverter=TypeConverters.toFloat) modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + - "(case-sensitive). Supported options: multinomial (default) and bernoulli.") + "(case-sensitive). Supported options: multinomial (default) and bernoulli.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -575,18 +937,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.NaiveBayes", self.uid) - #: param for the smoothing parameter. - self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " + - "default is 1.0") - #: param for the model type. - self.modelType = Param(self, "modelType", "The model type which is a string " + - "(case-sensitive). Supported options: multinomial (default) " + - "and bernoulli.") self._setDefault(smoothing=1.0, modelType="multinomial") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, modelType="multinomial"): @@ -602,6 +958,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return NaiveBayesModel(java_model) + @since("1.5.0") def setSmoothing(self, value): """ Sets the value of :py:attr:`smoothing`. @@ -609,12 +966,14 @@ def setSmoothing(self, value): self._paramMap[self.smoothing] = value return self + @since("1.5.0") def getSmoothing(self): """ Gets the value of smoothing or its default value. """ return self.getOrDefault(self.smoothing) + @since("1.5.0") def setModelType(self, value): """ Sets the value of :py:attr:`modelType`. @@ -622,6 +981,7 @@ def setModelType(self, value): self._paramMap[self.modelType] = value return self + @since("1.5.0") def getModelType(self): """ Gets the value of modelType or its default value. @@ -629,12 +989,15 @@ def getModelType(self): return self.getOrDefault(self.modelType) -class NaiveBayesModel(JavaModel): +class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by NaiveBayes. + + .. versionadded:: 1.5.0 """ @property + @since("1.5.0") def pi(self): """ log of class priors. @@ -642,6 +1005,7 @@ def pi(self): return self._call_java("pi") @property + @since("1.5.0") def theta(self): """ log of class conditional probabilities. @@ -651,7 +1015,7 @@ def theta(self): @inherit_doc class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasMaxIter, HasTol, HasSeed): + HasMaxIter, HasTol, HasSeed, JavaMLWritable, JavaMLReadable): """ Classifier trainer based on the Multilayer Perceptron. Each layer has sigmoid activation function, output layer has softmax. @@ -664,7 +1028,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, ... (1.0, Vectors.dense([0.0, 1.0])), ... (1.0, Vectors.dense([1.0, 0.0])), ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) - >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11) + >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=123) >>> model = mlp.fit(df) >>> model.layers [2, 5, 2] @@ -681,16 +1045,31 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, |[0.0,0.0]| 0.0| +---------+----------+ ... + >>> mlp_path = temp_path + "/mlp" + >>> mlp.save(mlp_path) + >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path) + >>> mlp2.getBlockSize() + 1 + >>> model_path = temp_path + "/mlp_model" + >>> model.save(model_path) + >>> model2 = MultilayerPerceptronClassificationModel.load(model_path) + >>> model.layers == model2.layers + True + >>> model.weights == model2.weights + True + + .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + - "neurons and output layer of 10 neurons, default is [1, 1].") + "neurons and output layer of 10 neurons, default is [1, 1].", + typeConverter=TypeConverters.toListInt) blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " + "matrices. Data is stacked within partitions. If block size is more than " + "remaining data in a partition then it is adjusted to the size of this " + - "data. Recommended size is between 10 and 1000, default is 128.") + "data. Recommended size is between 10 and 1000, default is 128.", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -702,19 +1081,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) - self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + - "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + - "100 neurons and output layer of 10 neurons, default is [1, 1].") - self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + - "matrices. Data is stacked within partitions. If block size is " + - "more than remaining data in a partition then it is adjusted to " + - "the size of this data. Recommended size is between 10 and 1000, " + - "default is 128.") self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): """ @@ -731,6 +1103,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return MultilayerPerceptronClassificationModel(java_model) + @since("1.6.0") def setLayers(self, value): """ Sets the value of :py:attr:`layers`. @@ -738,12 +1111,14 @@ def setLayers(self, value): self._paramMap[self.layers] = value return self + @since("1.6.0") def getLayers(self): """ Gets the value of layers or its default value. """ return self.getOrDefault(self.layers) + @since("1.6.0") def setBlockSize(self, value): """ Sets the value of :py:attr:`blockSize`. @@ -751,6 +1126,7 @@ def setBlockSize(self, value): self._paramMap[self.blockSize] = value return self + @since("1.6.0") def getBlockSize(self): """ Gets the value of blockSize or its default value. @@ -758,12 +1134,15 @@ def getBlockSize(self): return self.getOrDefault(self.blockSize) -class MultilayerPerceptronClassificationModel(JavaModel): +class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by MultilayerPerceptronClassifier. + + .. versionadded:: 1.6.0 """ @property + @since("1.6.0") def layers(self): """ array of layer sizes including input and output layers. @@ -771,6 +1150,7 @@ def layers(self): return self._call_java("javaLayers") @property + @since("1.6.0") def weights(self): """ vector of initial weights for the model that consists of the weights of layers. @@ -780,17 +1160,27 @@ def weights(self): if __name__ == "__main__": import doctest + import pyspark.ml.classification from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.classification.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.classification tests") sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 7bb8ab94e17df..f071c597c87f3 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -16,15 +16,16 @@ # from pyspark import since -from pyspark.ml.util import keyword_only +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -__all__ = ['KMeans', 'KMeansModel'] +__all__ = ['BisectingKMeans', 'BisectingKMeansModel', + 'KMeans', 'KMeansModel'] -class KMeansModel(JavaModel): +class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by KMeans. @@ -36,9 +37,18 @@ def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] + @since("2.0.0") + def computeCost(self, dataset): + """ + Return the K-means cost (sum of squared distances of points to their nearest center) + for this model on the given data. + """ + return self._call_java("computeCost", dataset) + @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, + JavaMLWritable, JavaMLReadable): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, @@ -53,23 +63,38 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol >>> centers = model.clusterCenters() >>> len(centers) 2 + >>> model.computeCost(df) + 2.000... >>> transformed = model.transform(df).select("features", "prediction") >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction True >>> rows[2].prediction == rows[3].prediction True + >>> kmeans_path = temp_path + "/kmeans" + >>> kmeans.save(kmeans_path) + >>> kmeans2 = KMeans.load(kmeans_path) + >>> kmeans2.getK() + 2 + >>> model_path = temp_path + "/kmeans_model" + >>> model.save(model_path) + >>> model2 = KMeansModel.load(model_path) + >>> model.clusterCenters()[0] == model2.clusterCenters()[0] + array([ True, True], dtype=bool) + >>> model.clusterCenters()[1] == model2.clusterCenters()[1] + array([ True, True], dtype=bool) .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc - k = Param(Params._dummy(), "k", "number of clusters to create") + k = Param(Params._dummy(), "k", "number of clusters to create", + typeConverter=TypeConverters.toInt) initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + - "to use a parallel variant of k-means++") - initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") + "to use a parallel variant of k-means++", TypeConverters.toString) + initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -80,12 +105,6 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) - self.k = Param(self, "k", "number of clusters to create") - self.initMode = Param(self, "initMode", - "the initialization algorithm. This can be either \"random\" to " + - "choose random points as initial cluster centers, or \"k-means||\" " + - "to use a parallel variant of k-means++") - self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -110,10 +129,6 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, def setK(self, value): """ Sets the value of :py:attr:`k`. - - >>> algo = KMeans().setK(10) - >>> algo.getK() - 10 """ self._paramMap[self.k] = value return self @@ -129,13 +144,6 @@ def getK(self): def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. - - >>> algo = KMeans() - >>> algo.getInitMode() - 'k-means||' - >>> algo = algo.setInitMode("random") - >>> algo.getInitMode() - 'random' """ self._paramMap[self.initMode] = value return self @@ -151,10 +159,6 @@ def getInitMode(self): def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. - - >>> algo = KMeans().setInitSteps(10) - >>> algo.getInitSteps() - 10 """ self._paramMap[self.initSteps] = value return self @@ -167,18 +171,167 @@ def getInitSteps(self): return self.getOrDefault(self.initSteps) +class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + Model fitted by BisectingKMeans. + + .. versionadded:: 2.0.0 + """ + + @since("2.0.0") + def clusterCenters(self): + """Get the cluster centers, represented as a list of NumPy arrays.""" + return [c.toArray() for c in self._call_java("clusterCenters")] + + @since("2.0.0") + def computeCost(self, dataset): + """ + Computes the sum of squared distances between the input points + and their corresponding cluster centers. + """ + return self._call_java("computeCost", dataset) + + +@inherit_doc +class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, + JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + A bisecting k-means algorithm based on the paper "A comparison of document clustering + techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. + The algorithm starts from a single cluster that contains all points. + Iteratively it finds divisible clusters on the bottom level and bisects each of them using + k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + The bisecting steps of clusters on the same level are grouped together to increase parallelism. + If bisecting all divisible clusters on the bottom level would result more than `k` leaf + clusters, larger clusters get higher priority. + + >>> from pyspark.mllib.linalg import Vectors + >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), + ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] + >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0) + >>> model = bkm.fit(df) + >>> centers = model.clusterCenters() + >>> len(centers) + 2 + >>> model.computeCost(df) + 2.000... + >>> transformed = model.transform(df).select("features", "prediction") + >>> rows = transformed.collect() + >>> rows[0].prediction == rows[1].prediction + True + >>> rows[2].prediction == rows[3].prediction + True + >>> bkm_path = temp_path + "/bkm" + >>> bkm.save(bkm_path) + >>> bkm2 = BisectingKMeans.load(bkm_path) + >>> bkm2.getK() + 2 + >>> model_path = temp_path + "/bkm_model" + >>> model.save(model_path) + >>> model2 = BisectingKMeansModel.load(model_path) + >>> model.clusterCenters()[0] == model2.clusterCenters()[0] + array([ True, True], dtype=bool) + >>> model.clusterCenters()[1] == model2.clusterCenters()[1] + array([ True, True], dtype=bool) + + .. versionadded:: 2.0.0 + """ + + k = Param(Params._dummy(), "k", "number of clusters to create", + typeConverter=TypeConverters.toInt) + minDivisibleClusterSize = Param(Params._dummy(), "minDivisibleClusterSize", + "the minimum number of points (if >= 1.0) " + + "or the minimum proportion", + typeConverter=TypeConverters.toFloat) + + @keyword_only + def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, + seed=None, k=4, minDivisibleClusterSize=1.0): + """ + __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ + seed=None, k=4, minDivisibleClusterSize=1.0) + """ + super(BisectingKMeans, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", + self.uid) + self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, + seed=None, k=4, minDivisibleClusterSize=1.0): + """ + setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ + seed=None, k=4, minDivisibleClusterSize=1.0) + Sets params for BisectingKMeans. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + """ + self._paramMap[self.k] = value + return self + + @since("2.0.0") + def getK(self): + """ + Gets the value of `k` or its default value. + """ + return self.getOrDefault(self.k) + + @since("2.0.0") + def setMinDivisibleClusterSize(self, value): + """ + Sets the value of :py:attr:`minDivisibleClusterSize`. + """ + self._paramMap[self.minDivisibleClusterSize] = value + return self + + @since("2.0.0") + def getMinDivisibleClusterSize(self): + """ + Gets the value of `minDivisibleClusterSize` or its default value. + """ + return self.getOrDefault(self.minDivisibleClusterSize) + + def _create_model(self, java_model): + return BisectingKMeansModel(java_model) + + if __name__ == "__main__": import doctest + import pyspark.ml.clustering from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.clustering tests") sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index dcc1738ec518b..4b0bade102802 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -18,7 +18,7 @@ from abc import abstractmethod, ABCMeta from pyspark import since -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only @@ -81,7 +81,7 @@ def isLargerBetter(self): @inherit_doc -class JavaEvaluator(Evaluator, JavaWrapper): +class JavaEvaluator(JavaParams, Evaluator): """ Base class for :py:class:`Evaluator`s that wrap Java/Scala implementations. @@ -106,8 +106,9 @@ def isLargerBetter(self): @inherit_doc class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): """ - Evaluator for binary classification, which expects two input - columns: rawPrediction and label. + Evaluator for binary classification, which expects two input columns: rawPrediction and label. + The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label + 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities). >>> from pyspark.mllib.linalg import Vectors >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]), @@ -123,7 +124,6 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)") @@ -137,9 +137,6 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", super(BinaryClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) - #: param for metric name in evaluation (areaUnderROC|areaUnderPR) - self.metricName = Param(self, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)") self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") kwargs = self.__init__._input_kwargs @@ -209,9 +206,6 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(RegressionEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) - #: param for metric name in evaluation (mse|rmse|r2|mae) - self.metricName = Param(self, "metricName", - "metric name in evaluation (mse|rmse|r2|mae)") self._setDefault(predictionCol="prediction", labelCol="label", metricName="rmse") kwargs = self.__init__._input_kwargs @@ -264,7 +258,6 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation " "(f1|precision|recall|weightedPrecision|weightedRecall)") @@ -279,10 +272,6 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(MulticlassClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) - # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) - self.metricName = Param(self, "metricName", - "metric name in evaluation" - " (f1|precision|recall|weightedPrecision|weightedRecall)") self._setDefault(predictionCol="prediction", labelCol="label", metricName="f1") kwargs = self.__init__._input_kwargs diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c7b6dd926c3e8..809a513316f9f 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -22,22 +22,43 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * -from pyspark.ml.util import keyword_only +from pyspark.ml.util import keyword_only, JavaMLReadable, JavaMLWritable from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', - 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', - 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', - 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', - 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', +__all__ = ['Binarizer', + 'Bucketizer', + 'ChiSqSelector', 'ChiSqSelectorModel', + 'CountVectorizer', 'CountVectorizerModel', + 'DCT', + 'ElementwiseProduct', + 'HashingTF', + 'IDF', 'IDFModel', + 'IndexToString', + 'MaxAbsScaler', 'MaxAbsScalerModel', + 'MinMaxScaler', 'MinMaxScalerModel', + 'NGram', + 'Normalizer', + 'OneHotEncoder', + 'PCA', 'PCAModel', + 'PolynomialExpansion', + 'QuantileDiscretizer', + 'RegexTokenizer', + 'RFormula', 'RFormulaModel', + 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', + 'StringIndexer', 'StringIndexerModel', + 'Tokenizer', + 'VectorAssembler', + 'VectorIndexer', 'VectorIndexerModel', + 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] @inherit_doc -class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): +class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -52,13 +73,18 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {binarizer.threshold: -0.5, binarizer.outputCol: "vector"} >>> binarizer.transform(df, params).head().vector 1.0 + >>> binarizerPath = temp_path + "/binarizer" + >>> binarizer.save(binarizerPath) + >>> loadedBinarizer = Binarizer.load(binarizerPath) + >>> loadedBinarizer.getThreshold() == binarizer.getThreshold() + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc threshold = Param(Params._dummy(), "threshold", - "threshold in binary classification prediction, in range [0, 1]") + "threshold in binary classification prediction, in range [0, 1]", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, threshold=0.0, inputCol=None, outputCol=None): @@ -67,8 +93,6 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): """ super(Binarizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) - self.threshold = Param(self, "threshold", - "threshold in binary classification prediction, in range [0, 1]") self._setDefault(threshold=0.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -100,7 +124,7 @@ def getThreshold(self): @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -120,11 +144,15 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): 2.0 >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 + >>> bucketizerPath = temp_path + "/bucketizer" + >>> bucketizer.save(bucketizerPath) + >>> loadedBucketizer = Bucketizer.load(bucketizerPath) + >>> loadedBucketizer.getSplits() == bucketizer.getSplits() + True .. versionadded:: 1.3.0 """ - # a placeholder to make it appear in the generated doc splits = \ Param(Params._dummy(), "splits", "Split points for mapping continuous features into buckets. With n+1 splits, " + @@ -132,7 +160,8 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): "range [x,y) except the last bucket, which also includes y. The splits " + "should be strictly increasing. Values at -inf, inf must be explicitly " + "provided to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.") + "specified will be treated as errors.", + typeConverter=TypeConverters.toListFloat) @keyword_only def __init__(self, splits=None, inputCol=None, outputCol=None): @@ -141,19 +170,6 @@ def __init__(self, splits=None, inputCol=None, outputCol=None): """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) - #: param for Splitting points for mapping continuous features into buckets. With n+1 splits, - # there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) - # except the last bucket, which also includes y. The splits should be strictly increasing. - # Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, - # values outside the splits specified will be treated as errors. - self.splits = \ - Param(self, "splits", - "Split points for mapping continuous features into buckets. With n+1 splits, " + - "there are n buckets. A bucket defined by splits x,y holds values in the " + - "range [x,y) except the last bucket, which also includes y. The splits " + - "should be strictly increasing. Values at -inf, inf must be explicitly " + - "provided to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -184,7 +200,7 @@ def getSplits(self): @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -203,62 +219,70 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| +-----+---------------+-------------------------+ ... - >>> sorted(map(str, model.vocabulary)) - ['a', 'b', 'c'] + >>> sorted(model.vocabulary) == ['a', 'b', 'c'] + True + >>> countVectorizerPath = temp_path + "/count-vectorizer" + >>> cv.save(countVectorizerPath) + >>> loadedCv = CountVectorizer.load(countVectorizerPath) + >>> loadedCv.getMinDF() == cv.getMinDF() + True + >>> loadedCv.getMinTF() == cv.getMinTF() + True + >>> loadedCv.getVocabSize() == cv.getVocabSize() + True + >>> modelPath = temp_path + "/count-vectorizer-model" + >>> model.save(modelPath) + >>> loadedModel = CountVectorizerModel.load(modelPath) + >>> loadedModel.vocabulary == model.vocabulary + True .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc minTF = Param( Params._dummy(), "minTF", "Filter to ignore rare words in" + " a document. For each document, terms with frequency/count less than the given" + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + " times the term must appear in the document); if this is a double in [0,1), then this " + "specifies a fraction (out of the document's token count). Note that the parameter is " + - "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0") + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", + typeConverter=TypeConverters.toFloat) minDF = Param( Params._dummy(), "minDF", "Specifies the minimum number of" + " different documents a term must appear in to be included in the vocabulary." + " If this is an integer >= 1, this specifies the number of documents the term must" + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + - " Default 1.0") + " Default 1.0", typeConverter=TypeConverters.toFloat) vocabSize = Param( - Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.") + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", + typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, + outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ + outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self.minTF = Param( - self, "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then " + - "this specifies a fraction (out of the document's token count). Note that the " + - "parameter is only used in transform of CountVectorizerModel and does not affect" + - "fitting. Default 1.0") - self.minDF = Param( - self, "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of " + - "documents. Default 1.0") - self.vocabSize = Param( - self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, + outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ + outputCol=None) Set the params for the CountVectorizer """ kwargs = self.setParams._input_kwargs @@ -309,11 +333,26 @@ def getVocabSize(self): """ return self.getOrDefault(self.vocabSize) + @since("2.0.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + self._paramMap[self.binary] = value + return self + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + def _create_model(self, java_model): return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel): +class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -332,7 +371,7 @@ def vocabulary(self): @inherit_doc -class DCT(JavaTransformer, HasInputCol, HasOutputCol): +class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -354,13 +393,17 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol): >>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2) >>> df3.head().origVec DenseVector([5.0, 8.0, 6.0]) + >>> dctPath = temp_path + "/dct" + >>> dct.save(dctPath) + >>> loadedDtc = DCT.load(dctPath) + >>> loadedDtc.getInverse() + False .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " + - "default False.") + "default False.", typeConverter=TypeConverters.toBoolean) @keyword_only def __init__(self, inverse=False, inputCol=None, outputCol=None): @@ -369,8 +412,6 @@ def __init__(self, inverse=False, inputCol=None, outputCol=None): """ super(DCT, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) - self.inverse = Param(self, "inverse", "Set transformer to perform inverse DCT, " + - "default False.") self._setDefault(inverse=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -402,7 +443,8 @@ def getInverse(self): @inherit_doc -class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -418,13 +460,17 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): DenseVector([2.0, 2.0, 9.0]) >>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod DenseVector([4.0, 3.0, 15.0]) + >>> elementwiseProductPath = temp_path + "/elementwise-product" + >>> ep.save(elementwiseProductPath) + >>> loadedEp = ElementwiseProduct.load(elementwiseProductPath) + >>> loadedEp.getScalingVec() == ep.getScalingVec() + True .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc - scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " + - "it must be MLlib Vector type.") + scalingVec = Param(Params._dummy(), "scalingVec", "Vector for hadamard product.", + typeConverter=TypeConverters.toVector) @keyword_only def __init__(self, scalingVec=None, inputCol=None, outputCol=None): @@ -434,8 +480,6 @@ def __init__(self, scalingVec=None, inputCol=None, outputCol=None): super(ElementwiseProduct, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", self.uid) - self.scalingVec = Param(self, "scalingVec", "vector for hadamard product, " + - "it must be MLlib Vector type.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -466,7 +510,8 @@ def getScalingVec(self): @inherit_doc -class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -482,18 +527,28 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} >>> hashingTF.transform(df, params).head().vector SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) + >>> hashingTFPath = temp_path + "/hashing-tf" + >>> hashingTF.save(hashingTFPath) + >>> loadedHashingTF = HashingTF.load(hashingTFPath) + >>> loadedHashingTF.getNumFeatures() == hashingTF.getNumFeatures() + True .. versionadded:: 1.3.0 """ + binary = Param(Params._dummy(), "binary", "If True, all non zero counts are set to 1. " + + "This is useful for discrete probabilistic models that model binary events " + + "rather than integer counts. Default False.", + typeConverter=TypeConverters.toBoolean) + @keyword_only - def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): + def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None): """ __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) - self._setDefault(numFeatures=1 << 18) + self._setDefault(numFeatures=1 << 18, binary=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -507,9 +562,24 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("2.0.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + self._paramMap[self.binary] = value + return self + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + @inherit_doc -class IDF(JavaEstimator, HasInputCol, HasOutputCol): +class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -519,20 +589,31 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol): >>> df = sqlContext.createDataFrame([(DenseVector([1.0, 2.0]),), ... (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"]) >>> idf = IDF(minDocFreq=3, inputCol="tf", outputCol="idf") - >>> idf.fit(df).transform(df).head().idf + >>> model = idf.fit(df) + >>> model.transform(df).head().idf DenseVector([0.0, 0.0]) >>> idf.setParams(outputCol="freqs").fit(df).transform(df).collect()[1].freqs DenseVector([0.0, 0.0]) >>> params = {idf.minDocFreq: 1, idf.outputCol: "vector"} >>> idf.fit(df, params).transform(df).head().vector DenseVector([0.2877, 0.0]) + >>> idfPath = temp_path + "/idf" + >>> idf.save(idfPath) + >>> loadedIdf = IDF.load(idfPath) + >>> loadedIdf.getMinDocFreq() == idf.getMinDocFreq() + True + >>> modelPath = temp_path + "/idf-model" + >>> model.save(modelPath) + >>> loadedModel = IDFModel.load(modelPath) + >>> loadedModel.transform(df).head().idf == model.transform(df).head().idf + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc minDocFreq = Param(Params._dummy(), "minDocFreq", - "minimum of documents in which a term should appear for filtering") + "minimum of documents in which a term should appear for filtering", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): @@ -541,8 +622,6 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): """ super(IDF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) - self.minDocFreq = Param(self, "minDocFreq", - "minimum of documents in which a term should appear for filtering") self._setDefault(minDocFreq=0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -576,7 +655,7 @@ def _create_model(self, java_model): return IDFModel(java_model) -class IDFModel(JavaModel): +class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -587,7 +666,87 @@ class IDFModel(JavaModel): @inherit_doc -class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): +class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Rescale each feature individually to range [-1, 1] by dividing through the largest maximum + absolute value in each feature. It does not shift/center the data, and thus does not destroy + any sparsity. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> maScaler = MaxAbsScaler(inputCol="a", outputCol="scaled") + >>> model = maScaler.fit(df) + >>> model.transform(df).show() + +-----+------+ + | a|scaled| + +-----+------+ + |[1.0]| [0.5]| + |[2.0]| [1.0]| + +-----+------+ + ... + >>> scalerPath = temp_path + "/max-abs-scaler" + >>> maScaler.save(scalerPath) + >>> loadedMAScaler = MaxAbsScaler.load(scalerPath) + >>> loadedMAScaler.getInputCol() == maScaler.getInputCol() + True + >>> loadedMAScaler.getOutputCol() == maScaler.getOutputCol() + True + >>> modelPath = temp_path + "/max-abs-scaler-model" + >>> model.save(modelPath) + >>> loadedModel = MaxAbsScalerModel.load(modelPath) + >>> loadedModel.maxAbs == model.maxAbs + True + + .. versionadded:: 2.0.0 + """ + + @keyword_only + def __init__(self, inputCol=None, outputCol=None): + """ + __init__(self, inputCol=None, outputCol=None) + """ + super(MaxAbsScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid) + self._setDefault() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, inputCol=None, outputCol=None): + """ + setParams(self, inputCol=None, outputCol=None) + Sets params for this MaxAbsScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return MaxAbsScalerModel(java_model) + + +class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Model fitted by :py:class:`MaxAbsScaler`. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def maxAbs(self): + """ + Max Abs vector. + """ + return self._call_java("maxAbs") + + +@inherit_doc +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -606,6 +765,10 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") >>> model = mmScaler.fit(df) + >>> model.originalMin + DenseVector([0.0]) + >>> model.originalMax + DenseVector([2.0]) >>> model.transform(df).show() +-----+------+ | a|scaled| @@ -614,13 +777,28 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): |[2.0]| [1.0]| +-----+------+ ... + >>> minMaxScalerPath = temp_path + "/min-max-scaler" + >>> mmScaler.save(minMaxScalerPath) + >>> loadedMMScaler = MinMaxScaler.load(minMaxScalerPath) + >>> loadedMMScaler.getMin() == mmScaler.getMin() + True + >>> loadedMMScaler.getMax() == mmScaler.getMax() + True + >>> modelPath = temp_path + "/min-max-scaler-model" + >>> model.save(modelPath) + >>> loadedModel = MinMaxScalerModel.load(modelPath) + >>> loadedModel.originalMin == model.originalMin + True + >>> loadedModel.originalMax == model.originalMax + True .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc - min = Param(Params._dummy(), "min", "Lower bound of the output feature range") - max = Param(Params._dummy(), "max", "Upper bound of the output feature range") + min = Param(Params._dummy(), "min", "Lower bound of the output feature range", + typeConverter=TypeConverters.toFloat) + max = Param(Params._dummy(), "max", "Upper bound of the output feature range", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): @@ -629,8 +807,6 @@ def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): """ super(MinMaxScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) - self.min = Param(self, "min", "Lower bound of the output feature range") - self.max = Param(self, "max", "Upper bound of the output feature range") self._setDefault(min=0.0, max=1.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -679,7 +855,7 @@ def _create_model(self, java_model): return MinMaxScalerModel(java_model) -class MinMaxScalerModel(JavaModel): +class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -688,10 +864,26 @@ class MinMaxScalerModel(JavaModel): .. versionadded:: 1.6.0 """ + @property + @since("2.0.0") + def originalMin(self): + """ + Min value for each original column during fitting. + """ + return self._call_java("originalMin") + + @property + @since("2.0.0") + def originalMax(self): + """ + Max value for each original column during fitting. + """ + return self._call_java("originalMax") + @inherit_doc @ignore_unicode_prefix -class NGram(JavaTransformer, HasInputCol, HasOutputCol): +class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -720,12 +912,17 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol): Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> ngramPath = temp_path + "/ngram" + >>> ngram.save(ngramPath) + >>> loadedNGram = NGram.load(ngramPath) + >>> loadedNGram.getN() == ngram.getN() + True .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc - n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") + n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, n=2, inputCol=None, outputCol=None): @@ -734,7 +931,6 @@ def __init__(self, n=2, inputCol=None, outputCol=None): """ super(NGram, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) - self.n = Param(self, "n", "number of elements per n-gram (>=1)") self._setDefault(n=2) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -766,7 +962,7 @@ def getN(self): @inherit_doc -class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): +class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -783,12 +979,17 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {normalizer.p: 1.0, normalizer.inputCol: "dense", normalizer.outputCol: "vector"} >>> normalizer.transform(df, params).head().vector DenseVector([0.4286, -0.5714]) + >>> normalizerPath = temp_path + "/normalizer" + >>> normalizer.save(normalizerPath) + >>> loadedNormalizer = Normalizer.load(normalizerPath) + >>> loadedNormalizer.getP() == normalizer.getP() + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc - p = Param(Params._dummy(), "p", "the p norm value.") + p = Param(Params._dummy(), "p", "the p norm value.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, p=2.0, inputCol=None, outputCol=None): @@ -797,7 +998,6 @@ def __init__(self, p=2.0, inputCol=None, outputCol=None): """ super(Normalizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) - self.p = Param(self, "p", "the p norm value.") self._setDefault(p=2.0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -829,7 +1029,7 @@ def getP(self): @inherit_doc -class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): +class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -862,12 +1062,17 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} >>> encoder.transform(td, params).head().test SparseVector(3, {0: 1.0}) + >>> onehotEncoderPath = temp_path + "/onehot-encoder" + >>> encoder.save(onehotEncoderPath) + >>> loadedEncoder = OneHotEncoder.load(onehotEncoderPath) + >>> loadedEncoder.getDropLast() == encoder.getDropLast() + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc - dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category", + typeConverter=TypeConverters.toBoolean) @keyword_only def __init__(self, dropLast=True, inputCol=None, outputCol=None): @@ -876,7 +1081,6 @@ def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) - self.dropLast = Param(self, "dropLast", "whether to drop the last category") self._setDefault(dropLast=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -908,7 +1112,8 @@ def getDropLast(self): @inherit_doc -class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): +class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -925,12 +1130,17 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) >>> px.setParams(outputCol="test").transform(df).head().test DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) + >>> polyExpansionPath = temp_path + "/poly-expansion" + >>> px.save(polyExpansionPath) + >>> loadedPx = PolynomialExpansion.load(polyExpansionPath) + >>> loadedPx.getDegree() == px.getDegree() + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc - degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)") + degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, degree=2, inputCol=None, outputCol=None): @@ -940,7 +1150,6 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): super(PolynomialExpansion, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) - self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)") self._setDefault(degree=2) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -971,9 +1180,99 @@ def getDegree(self): return self.getOrDefault(self.degree) +@inherit_doc +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, + JavaMLWritable): + """ + .. note:: Experimental + + `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned + categorical features. The bin ranges are chosen by taking a sample of the data and dividing it + into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, + covering all real values. This attempts to find numBuckets partitions based on a sample of data, + but it may find fewer depending on the data sample values. + + >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) + >>> qds = QuantileDiscretizer(numBuckets=2, + ... inputCol="values", outputCol="buckets", seed=123) + >>> qds.getSeed() + 123 + >>> bucketizer = qds.fit(df) + >>> splits = bucketizer.getSplits() + >>> splits[0] + -inf + >>> print("%2.1f" % round(splits[1], 1)) + 0.4 + >>> bucketed = bucketizer.transform(df).head() + >>> bucketed.buckets + 0.0 + >>> quantileDiscretizerPath = temp_path + "/quantile-discretizer" + >>> qds.save(quantileDiscretizerPath) + >>> loadedQds = QuantileDiscretizer.load(quantileDiscretizerPath) + >>> loadedQds.getNumBuckets() == qds.getNumBuckets() + True + + .. versionadded:: 2.0.0 + """ + + # a placeholder to make it appear in the generated doc + numBuckets = Param(Params._dummy(), "numBuckets", + "Maximum number of buckets (quantiles, or " + + "categories) into which data points are grouped. Must be >= 2. Default 2.", + typeConverter=TypeConverters.toInt) + + @keyword_only + def __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): + """ + __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None) + """ + super(QuantileDiscretizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", + self.uid) + self.numBuckets = Param(self, "numBuckets", + "Maximum number of buckets (quantiles, or " + + "categories) into which data points are grouped. Must be >= 2.") + self._setDefault(numBuckets=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): + """ + setParams(self, numBuckets=2, inputCol=None, outputCol=None, seed=None) + Set the params for the QuantileDiscretizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setNumBuckets(self, value): + """ + Sets the value of :py:attr:`numBuckets`. + """ + self._paramMap[self.numBuckets] = value + return self + + @since("2.0.0") + def getNumBuckets(self): + """ + Gets the value of numBuckets or its default value. + """ + return self.getOrDefault(self.numBuckets) + + def _create_model(self, java_model): + """ + Private method to convert the java_model to a Python model. + """ + return Bucketizer(splits=list(java_model.getSplits()), + inputCol=self.getInputCol(), + outputCol=self.getOutputCol()) + + @inherit_doc @ignore_unicode_prefix -class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): +class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -984,51 +1283,62 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): length. It returns an array of strings that can be empty. - >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) + >>> df = sqlContext.createDataFrame([("A B c",)], ["text"]) >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words") >>> reTokenizer.transform(df).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text=u'A B c', words=[u'a', u'b', u'c']) >>> # Change a parameter. >>> reTokenizer.setParams(outputCol="tokens").transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text=u'A B c', tokens=[u'a', u'b', u'c']) >>> # Temporarily modify a parameter. >>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text=u'A B c', words=[u'a', u'b', u'c']) >>> reTokenizer.transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text=u'A B c', tokens=[u'a', u'b', u'c']) >>> # Must use keyword arguments to specify params. >>> reTokenizer.setParams("text") Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> regexTokenizerPath = temp_path + "/regex-tokenizer" + >>> reTokenizer.save(regexTokenizerPath) + >>> loadedReTokenizer = RegexTokenizer.load(regexTokenizerPath) + >>> loadedReTokenizer.getMinTokenLength() == reTokenizer.getMinTokenLength() + True + >>> loadedReTokenizer.getGaps() == reTokenizer.getGaps() + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc - minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)") + minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)", + typeConverter=TypeConverters.toInt) gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") - pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing") + pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing", + TypeConverters.toString) + toLowercase = Param(Params._dummy(), "toLowercase", "whether to convert all characters to " + + "lowercase before tokenizing", TypeConverters.toBoolean) @keyword_only - def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): + def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, + outputCol=None, toLowercase=True): """ - __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) + __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \ + outputCol=None, toLowercase=True) """ super(RegexTokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) - self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)") - self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens") - self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing") - self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+") + self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): + def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, + outputCol=None, toLowercase=True): """ - setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) + setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \ + outputCol=None, toLowercase=True) Sets params for this RegexTokenizer. """ kwargs = self.setParams._input_kwargs @@ -1079,9 +1389,24 @@ def getPattern(self): """ return self.getOrDefault(self.pattern) + @since("2.0.0") + def setToLowercase(self, value): + """ + Sets the value of :py:attr:`toLowercase`. + """ + self._paramMap[self.toLowercase] = value + return self + + @since("2.0.0") + def getToLowercase(self): + """ + Gets the value of toLowercase or its default value. + """ + return self.getOrDefault(self.toLowercase) + @inherit_doc -class SQLTransformer(JavaTransformer): +class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1094,12 +1419,16 @@ class SQLTransformer(JavaTransformer): ... statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") >>> sqlTrans.transform(df).head() Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0) + >>> sqlTransformerPath = temp_path + "/sql-transformer" + >>> sqlTrans.save(sqlTransformerPath) + >>> loadedSqlTrans = SQLTransformer.load(sqlTransformerPath) + >>> loadedSqlTrans.getStatement() == sqlTrans.getStatement() + True .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc - statement = Param(Params._dummy(), "statement", "SQL statement") + statement = Param(Params._dummy(), "statement", "SQL statement", TypeConverters.toString) @keyword_only def __init__(self, statement=None): @@ -1108,7 +1437,6 @@ def __init__(self, statement=None): """ super(SQLTransformer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) - self.statement = Param(self, "statement", "SQL statement") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1139,7 +1467,7 @@ def getStatement(self): @inherit_doc -class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): +class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1156,13 +1484,27 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([1.4142]) >>> model.transform(df).collect()[1].scaled DenseVector([1.4142]) + >>> standardScalerPath = temp_path + "/standard-scaler" + >>> standardScaler.save(standardScalerPath) + >>> loadedStandardScaler = StandardScaler.load(standardScalerPath) + >>> loadedStandardScaler.getWithMean() == standardScaler.getWithMean() + True + >>> loadedStandardScaler.getWithStd() == standardScaler.getWithStd() + True + >>> modelPath = temp_path + "/standard-scaler-model" + >>> model.save(modelPath) + >>> loadedModel = StandardScalerModel.load(modelPath) + >>> loadedModel.std == model.std + True + >>> loadedModel.mean == model.mean + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc - withMean = Param(Params._dummy(), "withMean", "Center data with mean") - withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation") + withMean = Param(Params._dummy(), "withMean", "Center data with mean", TypeConverters.toBoolean) + withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation", + TypeConverters.toBoolean) @keyword_only def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): @@ -1171,8 +1513,6 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): """ super(StandardScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) - self.withMean = Param(self, "withMean", "Center data with mean") - self.withStd = Param(self, "withStd", "Scale to unit standard deviation") self._setDefault(withMean=False, withStd=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1221,7 +1561,7 @@ def _create_model(self, java_model): return StandardScalerModel(java_model) -class StandardScalerModel(JavaModel): +class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1248,7 +1588,8 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1263,11 +1604,26 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] - >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) + >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels) >>> itd = inverter.transform(td) >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] + >>> stringIndexerPath = temp_path + "/string-indexer" + >>> stringIndexer.save(stringIndexerPath) + >>> loadedIndexer = StringIndexer.load(stringIndexerPath) + >>> loadedIndexer.getHandleInvalid() == stringIndexer.getHandleInvalid() + True + >>> modelPath = temp_path + "/string-indexer-model" + >>> model.save(modelPath) + >>> loadedModel = StringIndexerModel.load(modelPath) + >>> loadedModel.labels == model.labels + True + >>> indexToStringPath = temp_path + "/index-to-string" + >>> inverter.save(indexToStringPath) + >>> loadedInverter = IndexToString.load(indexToStringPath) + >>> loadedInverter.getLabels() == inverter.getLabels() + True .. versionadded:: 1.4.0 """ @@ -1297,7 +1653,7 @@ def _create_model(self, java_model): return StringIndexerModel(java_model) -class StringIndexerModel(JavaModel): +class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1305,17 +1661,18 @@ class StringIndexerModel(JavaModel): .. versionadded:: 1.4.0 """ + @property @since("1.5.0") def labels(self): """ Ordered list of labels, corresponding to indices to be assigned. """ - return self._java_obj.labels + return self._call_java("labels") @inherit_doc -class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1328,10 +1685,10 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): .. versionadded:: 1.6.0 """ - # a placeholder to make the labels show up in generated doc labels = Param(Params._dummy(), "labels", "Optional array of labels specifying index-string mapping." + - " If not provided or if empty, then metadata from inputCol is used instead.") + " If not provided or if empty, then metadata from inputCol is used instead.", + typeConverter=TypeConverters.toListString) @keyword_only def __init__(self, inputCol=None, outputCol=None, labels=None): @@ -1341,9 +1698,6 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): super(IndexToString, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) - self.labels = Param(self, "labels", - "Optional array of labels specifying index-string mapping. If not" + - " provided or if empty, then metadata from inputCol is used instead.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1373,19 +1727,32 @@ def getLabels(self): return self.getOrDefault(self.labels) -class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental A feature transformer that filters out stop words from input. Note: null values from input array are preserved unless adding null to stopWords explicitly. + >>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["text"]) + >>> remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=["b"]) + >>> remover.transform(df).head().words == ['a', 'c'] + True + >>> stopWordsRemoverPath = temp_path + "/stopwords-remover" + >>> remover.save(stopWordsRemoverPath) + >>> loadedRemover = StopWordsRemover.load(stopWordsRemoverPath) + >>> loadedRemover.getStopWords() == remover.getStopWords() + True + >>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive() + True + .. versionadded:: 1.6.0 """ - # a placeholder to make the stopwords show up in generated doc - stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") + + stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out", + typeConverter=TypeConverters.toListString) caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + - "comparison over the stop words") + "comparison over the stop words", TypeConverters.toBoolean) @keyword_only def __init__(self, inputCol=None, outputCol=None, stopWords=None, @@ -1397,12 +1764,9 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) - self.stopWords = Param(self, "stopWords", "The words to be filtered out") - self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " + - "sensitive comparison over the stop words") stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords defaultStopWords = stopWordsObj.English() - self._setDefault(stopWords=defaultStopWords) + self._setDefault(stopWords=defaultStopWords, caseSensitive=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1451,7 +1815,7 @@ def getCaseSensitive(self): @inherit_doc @ignore_unicode_prefix -class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1475,6 +1839,11 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> tokenizerPath = temp_path + "/tokenizer" + >>> tokenizer.save(tokenizerPath) + >>> loadedTokenizer = Tokenizer.load(tokenizerPath) + >>> loadedTokenizer.transform(df).head().tokens == tokenizer.transform(df).head().tokens + True .. versionadded:: 1.3.0 """ @@ -1501,7 +1870,7 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1516,6 +1885,11 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"} >>> vecAssembler.transform(df, params).head().vector DenseVector([0.0, 1.0]) + >>> vectorAssemblerPath = temp_path + "/vector-assembler" + >>> vecAssembler.save(vectorAssemblerPath) + >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath) + >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs + True .. versionadded:: 1.4.0 """ @@ -1542,7 +1916,7 @@ def setParams(self, inputCols=None, outputCol=None): @inherit_doc -class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1598,15 +1972,26 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> model2 = indexer.fit(df, params) >>> model2.transform(df).head().vector DenseVector([1.0, 0.0]) + >>> vectorIndexerPath = temp_path + "/vector-indexer" + >>> indexer.save(vectorIndexerPath) + >>> loadedIndexer = VectorIndexer.load(vectorIndexerPath) + >>> loadedIndexer.getMaxCategories() == indexer.getMaxCategories() + True + >>> modelPath = temp_path + "/vector-indexer-model" + >>> model.save(modelPath) + >>> loadedModel = VectorIndexerModel.load(modelPath) + >>> loadedModel.numFeatures == model.numFeatures + True + >>> loadedModel.categoryMaps == model.categoryMaps + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc maxCategories = Param(Params._dummy(), "maxCategories", "Threshold for the number of values a categorical feature can take " + "(>= 2). If a feature is found to have > maxCategories values, then " + - "it is declared continuous.") + "it is declared continuous.", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, maxCategories=20, inputCol=None, outputCol=None): @@ -1615,10 +2000,6 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): """ super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) - self.maxCategories = Param(self, "maxCategories", - "Threshold for the number of values a categorical feature " + - "can take (>= 2). If a feature is found to have " + - "> maxCategories values, then it is declared continuous.") self._setDefault(maxCategories=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1652,7 +2033,7 @@ def _create_model(self, java_model): return VectorIndexerModel(java_model) -class VectorIndexerModel(JavaModel): +class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1681,7 +2062,7 @@ def categoryMaps(self): @inherit_doc -class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1703,17 +2084,24 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): >>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4]) >>> vs.transform(df).head().sliced DenseVector([2.3, 1.0]) + >>> vectorSlicerPath = temp_path + "/vector-slicer" + >>> vs.save(vectorSlicerPath) + >>> loadedVs = VectorSlicer.load(vectorSlicerPath) + >>> loadedVs.getIndices() == vs.getIndices() + True + >>> loadedVs.getNames() == vs.getNames() + True .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc indices = Param(Params._dummy(), "indices", "An array of indices to select features from " + - "a vector column. There can be no overlap with names.") + "a vector column. There can be no overlap with names.", + typeConverter=TypeConverters.toListInt) names = Param(Params._dummy(), "names", "An array of feature names to select features from " + "a vector column. These names must be specified by ML " + "org.apache.spark.ml.attribute.Attribute. There can be no overlap with " + - "indices.") + "indices.", typeConverter=TypeConverters.toListString) @keyword_only def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): @@ -1722,12 +2110,7 @@ def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): """ super(VectorSlicer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) - self.indices = Param(self, "indices", "An array of indices to select features from " + - "a vector column. There can be no overlap with names.") - self.names = Param(self, "names", "An array of feature names to select features from " + - "a vector column. These names must be specified by ML " + - "org.apache.spark.ml.attribute.Attribute. There can be no overlap " + - "with indices.") + self._setDefault(indices=[], names=[]) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1774,7 +2157,8 @@ def getNames(self): @inherit_doc @ignore_unicode_prefix -class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol): +class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1783,38 +2167,56 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has >>> sent = ("a b " * 100 + "a c " * 10).split(" ") >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"]) - >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc) + >>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model") + >>> model = word2Vec.fit(doc) >>> model.getVectors().show() +----+--------------------+ |word| vector| +----+--------------------+ - | a|[-0.3511952459812...| - | b|[0.29077222943305...| - | c|[0.02315592765808...| + | a|[0.09461779892444...| + | b|[1.15474212169647...| + | c|[-0.3794820010662...| +----+--------------------+ ... >>> model.findSynonyms("a", 2).show() +----+-------------------+ |word| similarity| +----+-------------------+ - | b|0.29255685145799626| - | c|-0.5414068302988307| + | b| 0.2505344027513247| + | c|-0.6980510075367647| +----+-------------------+ ... >>> model.transform(doc).head().model - DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) + DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461]) + >>> word2vecPath = temp_path + "/word2vec" + >>> word2Vec.save(word2vecPath) + >>> loadedWord2Vec = Word2Vec.load(word2vecPath) + >>> loadedWord2Vec.getVectorSize() == word2Vec.getVectorSize() + True + >>> loadedWord2Vec.getNumPartitions() == word2Vec.getNumPartitions() + True + >>> loadedWord2Vec.getMinCount() == word2Vec.getMinCount() + True + >>> modelPath = temp_path + "/word2vec-model" + >>> model.save(modelPath) + >>> loadedModel = Word2VecModel.load(modelPath) + >>> loadedModel.getVectors().first().word == model.getVectors().first().word + True + >>> loadedModel.getVectors().first().vector == model.getVectors().first().vector + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc vectorSize = Param(Params._dummy(), "vectorSize", - "the dimension of codes after transforming from words") + "the dimension of codes after transforming from words", + typeConverter=TypeConverters.toInt) numPartitions = Param(Params._dummy(), "numPartitions", - "number of partitions for sentences of words") + "number of partitions for sentences of words", + typeConverter=TypeConverters.toInt) minCount = Param(Params._dummy(), "minCount", "the minimum number of times a token must appear to be included in the " + - "word2vec model's vocabulary") + "word2vec model's vocabulary", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, @@ -1825,13 +2227,6 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, """ super(Word2Vec, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) - self.vectorSize = Param(self, "vectorSize", - "the dimension of codes after transforming from words") - self.numPartitions = Param(self, "numPartitions", - "number of partitions for sentences of words") - self.minCount = Param(self, "minCount", - "the minimum number of times a token must appear to be included " + - "in the word2vec model's vocabulary") self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None) kwargs = self.__init__._input_kwargs @@ -1898,7 +2293,7 @@ def _create_model(self, java_model): return Word2VecModel(java_model) -class Word2VecModel(JavaModel): +class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1929,7 +2324,7 @@ def findSynonyms(self, word, num): @inherit_doc -class PCA(JavaEstimator, HasInputCol, HasOutputCol): +class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1944,12 +2339,26 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol): >>> model = pca.fit(df) >>> model.transform(df).collect()[0].pca_features DenseVector([1.648..., -4.013...]) + >>> model.explainedVariance + DenseVector([0.794..., 0.205...]) + >>> pcaPath = temp_path + "/pca" + >>> pca.save(pcaPath) + >>> loadedPca = PCA.load(pcaPath) + >>> loadedPca.getK() == pca.getK() + True + >>> modelPath = temp_path + "/pca-model" + >>> model.save(modelPath) + >>> loadedModel = PCAModel.load(modelPath) + >>> loadedModel.pc == model.pc + True + >>> loadedModel.explainedVariance == model.explainedVariance + True .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc - k = Param(Params._dummy(), "k", "the number of principal components") + k = Param(Params._dummy(), "k", "the number of principal components", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, k=None, inputCol=None, outputCol=None): @@ -1958,7 +2367,6 @@ def __init__(self, k=None, inputCol=None, outputCol=None): """ super(PCA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) - self.k = Param(self, "k", "the number of principal components") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1991,7 +2399,7 @@ def _create_model(self, java_model): return PCAModel(java_model) -class PCAModel(JavaModel): +class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2000,9 +2408,27 @@ class PCAModel(JavaModel): .. versionadded:: 1.5.0 """ + @property + @since("2.0.0") + def pc(self): + """ + Returns a principal components Matrix. + Each column is one principal component. + """ + return self._call_java("pc") + + @property + @since("2.0.0") + def explainedVariance(self): + """ + Returns a vector of proportions of variance + explained by each principal component. + """ + return self._call_java("explainedVariance") + @inherit_doc -class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2018,7 +2444,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): ... (0.0, 0.0, "a") ... ], ["y", "x", "s"]) >>> rf = RFormula(formula="y ~ x + s") - >>> rf.fit(df).transform(df).show() + >>> model = rf.fit(df) + >>> model.transform(df).show() +---+---+---+---------+-----+ | y| x| s| features|label| +---+---+---+---------+-----+ @@ -2036,12 +2463,34 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): |0.0|0.0| a| [0.0]| 0.0| +---+---+---+--------+-----+ ... + >>> rFormulaPath = temp_path + "/rFormula" + >>> rf.save(rFormulaPath) + >>> loadedRF = RFormula.load(rFormulaPath) + >>> loadedRF.getFormula() == rf.getFormula() + True + >>> loadedRF.getFeaturesCol() == rf.getFeaturesCol() + True + >>> loadedRF.getLabelCol() == rf.getLabelCol() + True + >>> modelPath = temp_path + "/rFormulaModel" + >>> model.save(modelPath) + >>> loadedModel = RFormulaModel.load(modelPath) + >>> loadedModel.uid == model.uid + True + >>> loadedModel.transform(df).show() + +---+---+---+---------+-----+ + | y| x| s| features|label| + +---+---+---+---------+-----+ + |1.0|1.0| a|[1.0,1.0]| 1.0| + |0.0|2.0| b|[2.0,0.0]| 0.0| + |0.0|0.0| a|[0.0,1.0]| 0.0| + +---+---+---+---------+-----+ + ... .. versionadded:: 1.5.0 """ - # a placeholder to make it appear in the generated doc - formula = Param(Params._dummy(), "formula", "R model formula") + formula = Param(Params._dummy(), "formula", "R model formula", TypeConverters.toString) @keyword_only def __init__(self, formula=None, featuresCol="features", labelCol="label"): @@ -2050,7 +2499,6 @@ def __init__(self, formula=None, featuresCol="features", labelCol="label"): """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) - self.formula = Param(self, "formula", "R model formula") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -2083,7 +2531,7 @@ def _create_model(self, java_model): return RFormulaModel(java_model) -class RFormulaModel(JavaModel): +class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2093,11 +2541,118 @@ class RFormulaModel(JavaModel): """ +@inherit_doc +class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable, + JavaMLWritable): + """ + .. note:: Experimental + + Chi-Squared feature selection, which selects categorical features to use for predicting a + categorical label. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame( + ... [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), + ... (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), + ... (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)], + ... ["features", "label"]) + >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") + >>> model = selector.fit(df) + >>> model.transform(df).head().selectedFeatures + DenseVector([1.0]) + >>> model.selectedFeatures + [3] + >>> chiSqSelectorPath = temp_path + "/chi-sq-selector" + >>> selector.save(chiSqSelectorPath) + >>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath) + >>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures() + True + >>> modelPath = temp_path + "/chi-sq-selector-model" + >>> model.save(modelPath) + >>> loadedModel = ChiSqSelectorModel.load(modelPath) + >>> loadedModel.selectedFeatures == model.selectedFeatures + True + + .. versionadded:: 2.0.0 + """ + + numTopFeatures = \ + Param(Params._dummy(), "numTopFeatures", + "Number of features that selector will select, ordered by statistics value " + + "descending. If the number of features is < numTopFeatures, then this will select " + + "all features.", typeConverter=TypeConverters.toInt) + + @keyword_only + def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"): + """ + __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label") + """ + super(ChiSqSelector, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, + labelCol="labels"): + """ + setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\ + labelCol="labels") + Sets params for this ChiSqSelector. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setNumTopFeatures(self, value): + """ + Sets the value of :py:attr:`numTopFeatures`. + """ + self._paramMap[self.numTopFeatures] = value + return self + + @since("2.0.0") + def getNumTopFeatures(self): + """ + Gets the value of numTopFeatures or its default value. + """ + return self.getOrDefault(self.numTopFeatures) + + def _create_model(self, java_model): + return ChiSqSelectorModel(java_model) + + +class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Model fitted by ChiSqSelector. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def selectedFeatures(self): + """ + List of indices to select (filter). Must be ordered asc. + """ + return self._call_java("selectedFeatures") + + if __name__ == "__main__": import doctest + import tempfile + + import pyspark.ml.feature from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext + globs = globals().copy() + features = pyspark.ml.feature.__dict__.copy() + globs.update(features) + # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.feature tests") @@ -2108,7 +2663,16 @@ class RFormulaModel(JavaModel): Row(id=2, label="c"), Row(id=3, label="a"), Row(id=4, label="a"), Row(id=5, label="c")], 2) globs['stringIndDf'] = sqlContext.createDataFrame(testData) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 35c9b776a3d5e..a1265294a1e9e 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -14,30 +14,56 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import array +import sys +if sys.version > '3': + basestring = str + xrange = range + unicode = str from abc import ABCMeta import copy +import numpy as np +import warnings from pyspark import since from pyspark.ml.util import Identifiable +from pyspark.mllib.linalg import DenseVector, Vector -__all__ = ['Param', 'Params'] +__all__ = ['Param', 'Params', 'TypeConverters'] class Param(object): """ A param with self-contained documentation. + Note: `expectedType` is deprecated and will be removed in 2.1. Use typeConverter instead, + as a keyword argument. + .. versionadded:: 1.3.0 """ - def __init__(self, parent, name, doc): + def __init__(self, parent, name, doc, expectedType=None, typeConverter=None): if not isinstance(parent, Identifiable): raise TypeError("Parent must be an Identifiable but got type %s." % type(parent)) self.parent = parent.uid self.name = str(name) self.doc = str(doc) + self.expectedType = expectedType + if expectedType is not None: + warnings.warn("expectedType is deprecated and will be removed in 2.1. " + + "Use typeConverter instead, as a keyword argument.") + self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter + + def _copy_new_parent(self, parent): + """Copy the current param to a new parent, must be a dummy param.""" + if self.parent == "undefined": + param = copy.copy(self) + param.parent = parent.uid + return param + else: + raise ValueError("Cannot copy from non-dummy parent %s." % parent) def __str__(self): return str(self.parent) + "__" + self.name @@ -55,6 +81,146 @@ def __eq__(self, other): return False +class TypeConverters(object): + """ + .. note:: DeveloperApi + + Factory methods for common type conversion functions for `Param.typeConverter`. + + .. versionadded:: 2.0.0 + """ + + @staticmethod + def _is_numeric(value): + vtype = type(value) + return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long' + + @staticmethod + def _is_integer(value): + return TypeConverters._is_numeric(value) and float(value).is_integer() + + @staticmethod + def _can_convert_to_list(value): + vtype = type(value) + return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector) + + @staticmethod + def _can_convert_to_string(value): + vtype = type(value) + return isinstance(value, basestring) or vtype in [np.unicode_, np.string_, np.str_] + + @staticmethod + def identity(value): + """ + Dummy converter that just returns value. + """ + return value + + @staticmethod + def toList(value): + """ + Convert a value to a list, if possible. + """ + if type(value) == list: + return value + elif type(value) in [np.ndarray, tuple, xrange, array.array]: + return list(value) + elif isinstance(value, Vector): + return list(value.toArray()) + else: + raise TypeError("Could not convert %s to list" % value) + + @staticmethod + def toListFloat(value): + """ + Convert a value to list of floats, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + if all(map(lambda v: TypeConverters._is_numeric(v), value)): + return [float(v) for v in value] + raise TypeError("Could not convert %s to list of floats" % value) + + @staticmethod + def toListInt(value): + """ + Convert a value to list of ints, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + if all(map(lambda v: TypeConverters._is_integer(v), value)): + return [int(v) for v in value] + raise TypeError("Could not convert %s to list of ints" % value) + + @staticmethod + def toListString(value): + """ + Convert a value to list of strings, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + if all(map(lambda v: TypeConverters._can_convert_to_string(v), value)): + return [TypeConverters.toString(v) for v in value] + raise TypeError("Could not convert %s to list of strings" % value) + + @staticmethod + def toVector(value): + """ + Convert a value to a MLlib Vector, if possible. + """ + if isinstance(value, Vector): + return value + elif TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + if all(map(lambda v: TypeConverters._is_numeric(v), value)): + return DenseVector(value) + raise TypeError("Could not convert %s to vector" % value) + + @staticmethod + def toFloat(value): + """ + Convert a value to a float, if possible. + """ + if TypeConverters._is_numeric(value): + return float(value) + else: + raise TypeError("Could not convert %s to float" % value) + + @staticmethod + def toInt(value): + """ + Convert a value to an int, if possible. + """ + if TypeConverters._is_integer(value): + return int(value) + else: + raise TypeError("Could not convert %s to int" % value) + + @staticmethod + def toString(value): + """ + Convert a value to a string, if possible. + """ + if isinstance(value, basestring): + return value + elif type(value) in [np.string_, np.str_]: + return str(value) + elif type(value) == np.unicode_: + return unicode(value) + else: + raise TypeError("Could not convert %s to string type" % type(value)) + + @staticmethod + def toBoolean(value): + """ + Convert a value to a boolean, if possible. + """ + if type(value) == bool: + return value + else: + raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value)) + + class Params(Identifiable): """ Components that take parameters. This also provides an internal @@ -76,6 +242,19 @@ def __init__(self): #: value returned by :py:func:`params` self._params = None + # Copy the params from the class to the object + self._copy_params() + + def _copy_params(self): + """ + Copy all params defined on the class to current object. + """ + cls = type(self) + src_name_attrs = [(x, getattr(cls, x)) for x in dir(cls)] + src_params = list(filter(lambda nameAttr: isinstance(nameAttr[1], Param), src_name_attrs)) + for name, param in src_params: + setattr(self, name, param._copy_new_parent(self)) + @property @since("1.3.0") def params(self): @@ -86,7 +265,8 @@ def params(self): """ if self._params is None: self._params = list(filter(lambda attr: isinstance(attr, Param), - [getattr(self, x) for x in dir(self) if x != "params"])) + [getattr(self, x) for x in dir(self) if x != "params" and + not isinstance(getattr(type(self), x, None), property)])) return self._params @since("1.4.0") @@ -156,8 +336,11 @@ def hasParam(self, paramName): Tests whether this instance contains a param with a given (string) name. """ - param = self._resolveParam(paramName) - return param in self.params + if isinstance(paramName, str): + p = getattr(self, paramName, None) + return isinstance(p, Param) + else: + raise TypeError("hasParam(): paramName must be a string") @since("1.4.0") def getOrDefault(self, param): @@ -247,7 +430,13 @@ def _set(self, **kwargs): Sets user-supplied params. """ for param, value in kwargs.items(): - self._paramMap[getattr(self, param)] = value + p = getattr(self, param) + if value is not None: + try: + value = p.typeConverter(value) + except TypeError as e: + raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e)) + self._paramMap[p] = value return self def _setDefault(self, **kwargs): @@ -274,3 +463,27 @@ def _copyValues(self, to, extra=None): if p in paramMap and to.hasParam(p.name): to._set(**{p.name: paramMap[p]}) return to + + def _resetUid(self, newUid): + """ + Changes the uid of this instance. This updates both + the stored uid and the parent uid of params and param maps. + This is used by persistence (loading). + :param newUid: new uid to use + :return: same instance, but with the uid and Param.parent values + updated, including within param maps + """ + self.uid = newUid + newDefaultParamMap = dict() + newParamMap = dict() + for param in self.params: + newParam = copy.copy(param) + newParam.parent = newUid + if param in self._defaultParamMap: + newDefaultParamMap[newParam] = self._defaultParamMap[param] + if param in self._paramMap: + newParamMap[newParam] = self._paramMap[param] + param.parent = newUid + self._defaultParamMap = newDefaultParamMap + self._paramMap = newParamMap + return self diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 070c5db01ae73..a7615c43bee24 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -38,7 +38,7 @@ # python _shared_params_code_gen.py > shared.py -def _gen_param_header(name, doc, defaultValueStr): +def _gen_param_header(name, doc, defaultValueStr, typeConverter): """ Generates the header part for shared variables @@ -50,23 +50,24 @@ def _gen_param_header(name, doc, defaultValueStr): Mixin for param $name: $doc """ - # a placeholder to make it appear in the generated doc - $name = Param(Params._dummy(), "$name", "$doc") + $name = Param(Params._dummy(), "$name", "$doc", typeConverter=$typeConverter) def __init__(self): - super(Has$Name, self).__init__() - #: param for $doc - self.$name = Param(self, "$name", "$doc")''' + super(Has$Name, self).__init__()''' + if defaultValueStr is not None: template += ''' self._setDefault($name=$defaultValueStr)''' Name = name[0].upper() + name[1:] + if typeConverter is None: + typeConverter = str(None) return template \ .replace("$name", name) \ .replace("$Name", Name) \ .replace("$doc", doc) \ - .replace("$defaultValueStr", str(defaultValueStr)) + .replace("$defaultValueStr", str(defaultValueStr)) \ + .replace("$typeConverter", typeConverter) def _gen_param_code(name, doc, defaultValueStr): @@ -84,7 +85,7 @@ def set$Name(self, value): """ Sets the value of :py:attr:`$name`. """ - self._paramMap[self.$name] = value + self._set($name=value) return self def get$Name(self): @@ -103,83 +104,96 @@ def get$Name(self): if __name__ == "__main__": print(header) print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n") - print("from pyspark.ml.param import Param, Params\n\n") + print("from pyspark.ml.param import *\n\n") shared = [ - ("maxIter", "max number of iterations (>= 0).", None), - ("regParam", "regularization parameter (>= 0).", None), - ("featuresCol", "features column name.", "'features'"), - ("labelCol", "label column name.", "'label'"), - ("predictionCol", "prediction column name.", "'prediction'"), + ("maxIter", "max number of iterations (>= 0).", None, "TypeConverters.toInt"), + ("regParam", "regularization parameter (>= 0).", None, "TypeConverters.toFloat"), + ("featuresCol", "features column name.", "'features'", "TypeConverters.toString"), + ("labelCol", "label column name.", "'label'", "TypeConverters.toString"), + ("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"), ("probabilityCol", "Column name for predicted class conditional probabilities. " + "Note: Not all models output well-calibrated probability estimates! These probabilities " + - "should be treated as confidences, not precise probabilities.", "'probability'"), - ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'"), - ("inputCol", "input column name.", None), - ("inputCols", "input column names.", None), - ("outputCol", "output column name.", "self.uid + '__output'"), - ("numFeatures", "number of features.", None), - ("checkpointInterval", "checkpoint interval (>= 1).", None), - ("seed", "random seed.", "hash(type(self).__name__)"), - ("tol", "the convergence tolerance for iterative algorithms.", None), - ("stepSize", "Step size to be used for each iteration of optimization.", None), + "should be treated as confidences, not precise probabilities.", "'probability'", + "TypeConverters.toString"), + ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'", + "TypeConverters.toString"), + ("inputCol", "input column name.", None, "TypeConverters.toString"), + ("inputCols", "input column names.", None, "TypeConverters.toListString"), + ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), + ("numFeatures", "number of features.", None, "TypeConverters.toInt"), + ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, + "TypeConverters.toInt"), + ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), + ("tol", "the convergence tolerance for iterative algorithms.", None, + "TypeConverters.toFloat"), + ("stepSize", "Step size to be used for each iteration of optimization.", None, + "TypeConverters.toFloat"), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None), + "added later.", None, "TypeConverters.toBoolean"), ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"), - ("fitIntercept", "whether to fit an intercept term.", "True"), + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", + "TypeConverters.toFloat"), + ("fitIntercept", "whether to fit an intercept term.", "True", "TypeConverters.toBoolean"), ("standardization", "whether to standardize the training features before fitting the " + - "model.", "True"), + "model.", "True", "TypeConverters.toBoolean"), ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + "predicting each class. Array must have length equal to the number of classes, with " + "values >= 0. The class with largest value p/t is predicted, where p is the original " + - "probability of that class and t is the class' threshold.", None), + "probability of that class and t is the class' threshold.", None, + "TypeConverters.toListFloat"), ("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.", None), + "all instance weights as 1.0.", None, "TypeConverters.toString"), ("solver", "the solver algorithm for optimization. If this is not set or empty, " + - "default value is 'auto'.", "'auto'")] + "default value is 'auto'.", "'auto'", "TypeConverters.toString"), + ("varianceCol", "column name for the biased sample variance of prediction.", + None, "TypeConverters.toString")] code = [] - for name, doc, defaultValueStr in shared: - param_code = _gen_param_header(name, doc, defaultValueStr) + for name, doc, defaultValueStr, typeConverter in shared: + param_code = _gen_param_header(name, doc, defaultValueStr, typeConverter) code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr)) decisionTreeParams = [ ("maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; " + - "depth 1 means 1 internal node + 2 leaf nodes."), + "depth 1 means 1 internal node + 2 leaf nodes.", "TypeConverters.toInt"), ("maxBins", "Max number of bins for" + " discretizing continuous features. Must be >=2 and >= number of categories for any" + - " categorical feature."), + " categorical feature.", "TypeConverters.toInt"), ("minInstancesPerNode", "Minimum number of instances each child must have after split. " + "If a split causes the left or right child to have fewer than minInstancesPerNode, the " + - "split will be discarded as invalid. Should be >= 1."), - ("minInfoGain", "Minimum information gain for a split to be considered at a tree node."), - ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), + "split will be discarded as invalid. Should be >= 1.", "TypeConverters.toInt"), + ("minInfoGain", "Minimum information gain for a split to be considered at a tree node.", + "TypeConverters.toFloat"), + ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation. If too small," + + " then 1 node will be split per iteration, and its aggregates may exceed this size.", + "TypeConverters.toInt"), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + - "Caching can speed up training of deeper trees.")] + "Caching can speed up training of deeper trees. Users can set how often should the " + + "cache be checkpointed or disable it by setting checkpointInterval.", + "TypeConverters.toBoolean")] decisionTreeCode = '''class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. """ - # a placeholder to make it appear in the generated doc $dummyPlaceHolders def __init__(self): - super(DecisionTreeParams, self).__init__() - $realParams''' + super(DecisionTreeParams, self).__init__()''' dtParamMethods = "" dummyPlaceholders = "" - realParams = "" - paramTemplate = """$name = Param($owner, "$name", "$doc")""" - for name, doc in decisionTreeParams: - variable = paramTemplate.replace("$name", name).replace("$doc", doc) + paramTemplate = """$name = Param($owner, "$name", "$doc", typeConverter=$typeConverterStr)""" + for name, doc, typeConverterStr in decisionTreeParams: + if typeConverterStr is None: + typeConverterStr = str(None) + variable = paramTemplate.replace("$name", name).replace("$doc", doc) \ + .replace("$typeConverterStr", typeConverterStr) dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n " - realParams += "#: param for " + doc + "\n " - realParams += "self." + variable.replace("$owner", "self") + "\n " dtParamMethods += _gen_param_code(name, doc, None) + "\n" - code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) - .replace("$realParams", realParams) + dtParamMethods) + code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) + "\n" + + dtParamMethods) print("\n\n\n".join(code)) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4bdf2a8cc563f..c9e975525ce1f 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -17,7 +17,7 @@ # DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. -from pyspark.ml.param import Param, Params +from pyspark.ml.param import * class HasMaxIter(Params): @@ -25,19 +25,16 @@ class HasMaxIter(Params): Mixin for param maxIter: max number of iterations (>= 0). """ - # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).") + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", typeConverter=TypeConverters.toInt) def __init__(self): super(HasMaxIter, self).__init__() - #: param for max number of iterations (>= 0). - self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0).") def setMaxIter(self, value): """ Sets the value of :py:attr:`maxIter`. """ - self._paramMap[self.maxIter] = value + self._set(maxIter=value) return self def getMaxIter(self): @@ -52,19 +49,16 @@ class HasRegParam(Params): Mixin for param regParam: regularization parameter (>= 0). """ - # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).") + regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasRegParam, self).__init__() - #: param for regularization parameter (>= 0). - self.regParam = Param(self, "regParam", "regularization parameter (>= 0).") def setRegParam(self, value): """ Sets the value of :py:attr:`regParam`. """ - self._paramMap[self.regParam] = value + self._set(regParam=value) return self def getRegParam(self): @@ -79,20 +73,17 @@ class HasFeaturesCol(Params): Mixin for param featuresCol: features column name. """ - # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name.") + featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasFeaturesCol, self).__init__() - #: param for features column name. - self.featuresCol = Param(self, "featuresCol", "features column name.") self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ Sets the value of :py:attr:`featuresCol`. """ - self._paramMap[self.featuresCol] = value + self._set(featuresCol=value) return self def getFeaturesCol(self): @@ -107,20 +98,17 @@ class HasLabelCol(Params): Mixin for param labelCol: label column name. """ - # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name.") + labelCol = Param(Params._dummy(), "labelCol", "label column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasLabelCol, self).__init__() - #: param for label column name. - self.labelCol = Param(self, "labelCol", "label column name.") self._setDefault(labelCol='label') def setLabelCol(self, value): """ Sets the value of :py:attr:`labelCol`. """ - self._paramMap[self.labelCol] = value + self._set(labelCol=value) return self def getLabelCol(self): @@ -135,20 +123,17 @@ class HasPredictionCol(Params): Mixin for param predictionCol: prediction column name. """ - # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.") + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasPredictionCol, self).__init__() - #: param for prediction column name. - self.predictionCol = Param(self, "predictionCol", "prediction column name.") self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ Sets the value of :py:attr:`predictionCol`. """ - self._paramMap[self.predictionCol] = value + self._set(predictionCol=value) return self def getPredictionCol(self): @@ -163,20 +148,17 @@ class HasProbabilityCol(Params): Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. """ - # a placeholder to make it appear in the generated doc - probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", typeConverter=TypeConverters.toString) def __init__(self): super(HasProbabilityCol, self).__init__() - #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. - self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") self._setDefault(probabilityCol='probability') def setProbabilityCol(self, value): """ Sets the value of :py:attr:`probabilityCol`. """ - self._paramMap[self.probabilityCol] = value + self._set(probabilityCol=value) return self def getProbabilityCol(self): @@ -191,20 +173,17 @@ class HasRawPredictionCol(Params): Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name. """ - # a placeholder to make it appear in the generated doc - rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.") + rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasRawPredictionCol, self).__init__() - #: param for raw prediction (a.k.a. confidence) column name. - self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.") self._setDefault(rawPredictionCol='rawPrediction') def setRawPredictionCol(self, value): """ Sets the value of :py:attr:`rawPredictionCol`. """ - self._paramMap[self.rawPredictionCol] = value + self._set(rawPredictionCol=value) return self def getRawPredictionCol(self): @@ -219,19 +198,16 @@ class HasInputCol(Params): Mixin for param inputCol: input column name. """ - # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name.") + inputCol = Param(Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasInputCol, self).__init__() - #: param for input column name. - self.inputCol = Param(self, "inputCol", "input column name.") def setInputCol(self, value): """ Sets the value of :py:attr:`inputCol`. """ - self._paramMap[self.inputCol] = value + self._set(inputCol=value) return self def getInputCol(self): @@ -246,19 +222,16 @@ class HasInputCols(Params): Mixin for param inputCols: input column names. """ - # a placeholder to make it appear in the generated doc - inputCols = Param(Params._dummy(), "inputCols", "input column names.") + inputCols = Param(Params._dummy(), "inputCols", "input column names.", typeConverter=TypeConverters.toListString) def __init__(self): super(HasInputCols, self).__init__() - #: param for input column names. - self.inputCols = Param(self, "inputCols", "input column names.") def setInputCols(self, value): """ Sets the value of :py:attr:`inputCols`. """ - self._paramMap[self.inputCols] = value + self._set(inputCols=value) return self def getInputCols(self): @@ -273,20 +246,17 @@ class HasOutputCol(Params): Mixin for param outputCol: output column name. """ - # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name.") + outputCol = Param(Params._dummy(), "outputCol", "output column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasOutputCol, self).__init__() - #: param for output column name. - self.outputCol = Param(self, "outputCol", "output column name.") self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): """ Sets the value of :py:attr:`outputCol`. """ - self._paramMap[self.outputCol] = value + self._set(outputCol=value) return self def getOutputCol(self): @@ -301,19 +271,16 @@ class HasNumFeatures(Params): Mixin for param numFeatures: number of features. """ - # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features.") + numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasNumFeatures, self).__init__() - #: param for number of features. - self.numFeatures = Param(self, "numFeatures", "number of features.") def setNumFeatures(self, value): """ Sets the value of :py:attr:`numFeatures`. """ - self._paramMap[self.numFeatures] = value + self._set(numFeatures=value) return self def getNumFeatures(self): @@ -325,22 +292,19 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: checkpoint interval (>= 1). + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ - # a placeholder to make it appear in the generated doc - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1).") + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasCheckpointInterval, self).__init__() - #: param for checkpoint interval (>= 1). - self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1).") def setCheckpointInterval(self, value): """ Sets the value of :py:attr:`checkpointInterval`. """ - self._paramMap[self.checkpointInterval] = value + self._set(checkpointInterval=value) return self def getCheckpointInterval(self): @@ -355,20 +319,17 @@ class HasSeed(Params): Mixin for param seed: random seed. """ - # a placeholder to make it appear in the generated doc - seed = Param(Params._dummy(), "seed", "random seed.") + seed = Param(Params._dummy(), "seed", "random seed.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasSeed, self).__init__() - #: param for random seed. - self.seed = Param(self, "seed", "random seed.") self._setDefault(seed=hash(type(self).__name__)) def setSeed(self, value): """ Sets the value of :py:attr:`seed`. """ - self._paramMap[self.seed] = value + self._set(seed=value) return self def getSeed(self): @@ -383,19 +344,16 @@ class HasTol(Params): Mixin for param tol: the convergence tolerance for iterative algorithms. """ - # a placeholder to make it appear in the generated doc - tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.") + tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasTol, self).__init__() - #: param for the convergence tolerance for iterative algorithms. - self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms.") def setTol(self, value): """ Sets the value of :py:attr:`tol`. """ - self._paramMap[self.tol] = value + self._set(tol=value) return self def getTol(self): @@ -410,19 +368,16 @@ class HasStepSize(Params): Mixin for param stepSize: Step size to be used for each iteration of optimization. """ - # a placeholder to make it appear in the generated doc - stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.") + stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasStepSize, self).__init__() - #: param for Step size to be used for each iteration of optimization. - self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.") def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. """ - self._paramMap[self.stepSize] = value + self._set(stepSize=value) return self def getStepSize(self): @@ -437,19 +392,16 @@ class HasHandleInvalid(Params): Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. """ - # a placeholder to make it appear in the generated doc - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toBoolean) def __init__(self): super(HasHandleInvalid, self).__init__() - #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. - self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.") def setHandleInvalid(self, value): """ Sets the value of :py:attr:`handleInvalid`. """ - self._paramMap[self.handleInvalid] = value + self._set(handleInvalid=value) return self def getHandleInvalid(self): @@ -464,20 +416,17 @@ class HasElasticNetParam(Params): Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. """ - # a placeholder to make it appear in the generated doc - elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasElasticNetParam, self).__init__() - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") self._setDefault(elasticNetParam=0.0) def setElasticNetParam(self, value): """ Sets the value of :py:attr:`elasticNetParam`. """ - self._paramMap[self.elasticNetParam] = value + self._set(elasticNetParam=value) return self def getElasticNetParam(self): @@ -492,20 +441,17 @@ class HasFitIntercept(Params): Mixin for param fitIntercept: whether to fit an intercept term. """ - # a placeholder to make it appear in the generated doc - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", typeConverter=TypeConverters.toBoolean) def __init__(self): super(HasFitIntercept, self).__init__() - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") self._setDefault(fitIntercept=True) def setFitIntercept(self, value): """ Sets the value of :py:attr:`fitIntercept`. """ - self._paramMap[self.fitIntercept] = value + self._set(fitIntercept=value) return self def getFitIntercept(self): @@ -520,20 +466,17 @@ class HasStandardization(Params): Mixin for param standardization: whether to standardize the training features before fitting the model. """ - # a placeholder to make it appear in the generated doc - standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", typeConverter=TypeConverters.toBoolean) def __init__(self): super(HasStandardization, self).__init__() - #: param for whether to standardize the training features before fitting the model. - self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") self._setDefault(standardization=True) def setStandardization(self, value): """ Sets the value of :py:attr:`standardization`. """ - self._paramMap[self.standardization] = value + self._set(standardization=value) return self def getStandardization(self): @@ -548,19 +491,16 @@ class HasThresholds(Params): Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. """ - # a placeholder to make it appear in the generated doc - thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat) def __init__(self): super(HasThresholds, self).__init__() - #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. - self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. """ - self._paramMap[self.thresholds] = value + self._set(thresholds=value) return self def getThresholds(self): @@ -575,19 +515,16 @@ class HasWeightCol(Params): Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. """ - # a placeholder to make it appear in the generated doc - weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", typeConverter=TypeConverters.toString) def __init__(self): super(HasWeightCol, self).__init__() - #: param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. - self.weightCol = Param(self, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") def setWeightCol(self, value): """ Sets the value of :py:attr:`weightCol`. """ - self._paramMap[self.weightCol] = value + self._set(weightCol=value) return self def getWeightCol(self): @@ -602,20 +539,17 @@ class HasSolver(Params): Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. """ - # a placeholder to make it appear in the generated doc - solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", typeConverter=TypeConverters.toString) def __init__(self): super(HasSolver, self).__init__() - #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. - self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") self._setDefault(solver='auto') def setSolver(self, value): """ Sets the value of :py:attr:`solver`. """ - self._paramMap[self.solver] = value + self._set(solver=value) return self def getSolver(self): @@ -625,40 +559,51 @@ def getSolver(self): return self.getOrDefault(self.solver) +class HasVarianceCol(Params): + """ + Mixin for param varianceCol: column name for the biased sample variance of prediction. + """ + + varianceCol = Param(Params._dummy(), "varianceCol", "column name for the biased sample variance of prediction.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasVarianceCol, self).__init__() + + def setVarianceCol(self, value): + """ + Sets the value of :py:attr:`varianceCol`. + """ + self._set(varianceCol=value) + return self + + def getVarianceCol(self): + """ + Gets the value of varianceCol or its default value. + """ + return self.getOrDefault(self.varianceCol) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. """ - # a placeholder to make it appear in the generated doc - maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") - maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") - minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") - minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") - maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", typeConverter=TypeConverters.toInt) + maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.", typeConverter=TypeConverters.toInt) + minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.", typeConverter=TypeConverters.toInt) + minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.", typeConverter=TypeConverters.toFloat) + maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size.", typeConverter=TypeConverters.toInt) + cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.", typeConverter=TypeConverters.toBoolean) def __init__(self): super(DecisionTreeParams, self).__init__() - #: param for Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - self.maxDepth = Param(self, "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") - #: param for Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature. - self.maxBins = Param(self, "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.") - #: param for Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1. - self.minInstancesPerNode = Param(self, "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") - #: param for Minimum information gain for a split to be considered at a tree node. - self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") - #: param for Maximum memory in MB allocated to histogram aggregation. - self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. - self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. """ - self._paramMap[self.maxDepth] = value + self._set(maxDepth=value) return self def getMaxDepth(self): @@ -671,7 +616,7 @@ def setMaxBins(self, value): """ Sets the value of :py:attr:`maxBins`. """ - self._paramMap[self.maxBins] = value + self._set(maxBins=value) return self def getMaxBins(self): @@ -684,7 +629,7 @@ def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. """ - self._paramMap[self.minInstancesPerNode] = value + self._set(minInstancesPerNode=value) return self def getMinInstancesPerNode(self): @@ -697,7 +642,7 @@ def setMinInfoGain(self, value): """ Sets the value of :py:attr:`minInfoGain`. """ - self._paramMap[self.minInfoGain] = value + self._set(minInfoGain=value) return self def getMinInfoGain(self): @@ -710,7 +655,7 @@ def setMaxMemoryInMB(self, value): """ Sets the value of :py:attr:`maxMemoryInMB`. """ - self._paramMap[self.maxMemoryInMB] = value + self._set(maxMemoryInMB=value) return self def getMaxMemoryInMB(self): @@ -723,7 +668,7 @@ def setCacheNodeIds(self, value): """ Sets the value of :py:attr:`cacheNodeIds`. """ - self._paramMap[self.cacheNodeIds] = value + self._set(cacheNodeIds=value) return self def getCacheNodeIds(self): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 4475451edb781..9d654e8b0f8d0 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -15,120 +15,40 @@ # limitations under the License. # -from abc import ABCMeta, abstractmethod +import sys +if sys.version > '3': + basestring = str + +from pyspark import SparkContext from pyspark import since +from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params -from pyspark.ml.util import keyword_only +from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable +from pyspark.ml.wrapper import JavaParams from pyspark.mllib.common import inherit_doc @inherit_doc -class Estimator(Params): +class PipelineMLWriter(JavaMLWriter): """ - Abstract class for estimators that fit models to data. + Private Pipeline utility class that can save ML instances through their Scala implementation. - .. versionadded:: 1.3.0 + We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java. """ - __metaclass__ = ABCMeta - - @abstractmethod - def _fit(self, dataset): - """ - Fits a model to the input dataset. This is called by the - default implementation of fit. - - :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.DataFrame` - :returns: fitted model - """ - raise NotImplementedError() - - @since("1.3.0") - def fit(self, dataset, params=None): - """ - Fits a model to the input dataset with optional parameters. - - :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.DataFrame` - :param params: an optional param map that overrides embedded - params. If a list/tuple of param maps is given, - this calls fit on each param map and returns a - list of models. - :returns: fitted model(s) - """ - if params is None: - params = dict() - if isinstance(params, (list, tuple)): - return [self.fit(dataset, paramMap) for paramMap in params] - elif isinstance(params, dict): - if params: - return self.copy(params)._fit(dataset) - else: - return self._fit(dataset) - else: - raise ValueError("Params must be either a param map or a list/tuple of param maps, " - "but got %s." % type(params)) - @inherit_doc -class Transformer(Params): +class PipelineMLReader(JavaMLReader): """ - Abstract class for transformers that transform one dataset into - another. + Private utility class that can load Pipeline instances through their Scala implementation. - .. versionadded:: 1.3.0 + We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java. """ - __metaclass__ = ABCMeta - - @abstractmethod - def _transform(self, dataset): - """ - Transforms the input dataset with optional parameters. - - :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.DataFrame` - :returns: transformed dataset - """ - raise NotImplementedError() - - @since("1.3.0") - def transform(self, dataset, params=None): - """ - Transforms the input dataset with optional parameters. - - :param dataset: input dataset, which is an instance of - :py:class:`pyspark.sql.DataFrame` - :param params: an optional param map that overrides embedded - params. - :returns: transformed dataset - """ - if params is None: - params = dict() - if isinstance(params, dict): - if params: - return self.copy(params,)._transform(dataset) - else: - return self._transform(dataset) - else: - raise ValueError("Params must be either a param map but got %s." % type(params)) - @inherit_doc -class Model(Transformer): - """ - Abstract class for models that are fitted by estimators. - - .. versionadded:: 1.4.0 - """ - - __metaclass__ = ABCMeta - - -@inherit_doc -class Pipeline(Estimator): +class Pipeline(Estimator, MLReadable, MLWritable): """ A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each of which is either an @@ -149,6 +69,8 @@ class Pipeline(Estimator): .. versionadded:: 1.3.0 """ + stages = Param(Params._dummy(), "stages", "pipeline stages") + @keyword_only def __init__(self, stages=None): """ @@ -157,8 +79,6 @@ def __init__(self, stages=None): if stages is None: stages = [] super(Pipeline, self).__init__() - #: Param for pipeline stages. - self.stages = Param(self, "stages", "pipeline stages") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -232,9 +152,78 @@ def copy(self, extra=None): stages = [stage.copy(extra) for stage in that.getStages()] return that.setStages(stages) + @since("2.0.0") + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return PipelineMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return PipelineMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java Pipeline, create and return a Python wrapper of it. + Used for ML persistence. + """ + # Create a new instance of this stage. + py_stage = cls() + # Load information from java_stage to the instance. + py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()] + py_stage.setStages(py_stages) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java Pipeline. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage + java_stages = gateway.new_array(cls, len(self.getStages())) + for idx, stage in enumerate(self.getStages()): + java_stages[idx] = stage._to_java() + + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) + _java_obj.setStages(java_stages) + + return _java_obj + @inherit_doc -class PipelineModel(Model): +class PipelineModelMLWriter(JavaMLWriter): + """ + Private PipelineModel utility class that can save ML instances through their Scala + implementation. + + We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements + _to_java. + """ + + +@inherit_doc +class PipelineModelMLReader(JavaMLReader): + """ + Private utility class that can load PipelineModel instances through their Scala implementation. + + We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements + _from_java. + """ + + +@inherit_doc +class PipelineModel(Model, MLReadable, MLWritable): """ Represents a compiled pipeline with transformers and fitted models. @@ -262,3 +251,50 @@ def copy(self, extra=None): extra = dict() stages = [stage.copy(extra) for stage in self.stages] return PipelineModel(stages) + + @since("2.0.0") + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return PipelineModelMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return PipelineModelMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java PipelineModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + # Load information from java_stage to the instance. + py_stages = [JavaParams._from_java(s) for s in java_stage.stages()] + # Create a new instance of this stage. + py_stage = cls(py_stages) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java PipelineModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.Transformer + java_stages = gateway.new_array(cls, len(self.stages)) + for idx, stage in enumerate(self.stages): + java_stages[idx] = stage._to_java() + + _java_obj =\ + JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) + + return _java_obj diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index ec5748a1cfe94..7c7a1b67a100e 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -16,7 +16,7 @@ # from pyspark import since -from pyspark.ml.util import keyword_only +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc @@ -26,7 +26,8 @@ @inherit_doc -class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed): +class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed, + JavaMLWritable, JavaMLReadable): """ Alternating Least Squares (ALS) matrix factorization. @@ -76,26 +77,46 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] - Row(user=0, item=2, prediction=0.39...) + Row(user=0, item=2, prediction=-0.13807615637779236) >>> predictions[1] - Row(user=1, item=0, prediction=3.19...) + Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] - Row(user=2, item=0, prediction=-1.15...) + Row(user=2, item=0, prediction=-1.5018409490585327) + >>> als_path = temp_path + "/als" + >>> als.save(als_path) + >>> als2 = ALS.load(als_path) + >>> als.getMaxIter() + 5 + >>> model_path = temp_path + "/als_model" + >>> model.save(model_path) + >>> model2 = ALSModel.load(model_path) + >>> model.rank == model2.rank + True + >>> sorted(model.userFactors.collect()) == sorted(model2.userFactors.collect()) + True + >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect()) + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc - rank = Param(Params._dummy(), "rank", "rank of the factorization") - numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks") - numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks") - implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference") - alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference") - userCol = Param(Params._dummy(), "userCol", "column name for user ids") - itemCol = Param(Params._dummy(), "itemCol", "column name for item ids") - ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings") + rank = Param(Params._dummy(), "rank", "rank of the factorization", + typeConverter=TypeConverters.toInt) + numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks", + typeConverter=TypeConverters.toInt) + numItemBlocks = Param(Params._dummy(), "numItemBlocks", "number of item blocks", + typeConverter=TypeConverters.toInt) + implicitPrefs = Param(Params._dummy(), "implicitPrefs", "whether to use implicit preference", + TypeConverters.toBoolean) + alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference", + typeConverter=TypeConverters.toFloat) + userCol = Param(Params._dummy(), "userCol", "column name for user ids", TypeConverters.toString) + itemCol = Param(Params._dummy(), "itemCol", "column name for item ids", TypeConverters.toString) + ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings", + TypeConverters.toString) nonnegative = Param(Params._dummy(), "nonnegative", - "whether to use nonnegative constraint for least squares") + "whether to use nonnegative constraint for least squares", + TypeConverters.toBoolean) @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, @@ -108,16 +129,6 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB """ super(ALS, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) - self.rank = Param(self, "rank", "rank of the factorization") - self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks") - self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks") - self.implicitPrefs = Param(self, "implicitPrefs", "whether to use implicit preference") - self.alpha = Param(self, "alpha", "alpha for implicit preference") - self.userCol = Param(self, "userCol", "column name for user ids") - self.itemCol = Param(self, "itemCol", "column name for item ids") - self.ratingCol = Param(self, "ratingCol", "column name for ratings") - self.nonnegative = Param(self, "nonnegative", - "whether to use nonnegative constraint for least squares") self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10) @@ -285,7 +296,7 @@ def getNonnegative(self): return self.getOrDefault(self.nonnegative) -class ALSModel(JavaModel): +class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by ALS. @@ -319,16 +330,27 @@ def itemFactors(self): if __name__ == "__main__": import doctest + import pyspark.ml.recommendation from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.recommendation.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.recommendation tests") sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 7648bf13266bf..3c7852526a481 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -18,29 +18,32 @@ import warnings from pyspark import since -from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.mllib.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', 'GBTRegressionModel', + 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel', 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', + 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', 'RandomForestRegressor', 'RandomForestRegressionModel'] @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver): + HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable): """ Linear regression. The learning objective is to minimize the squared error, with regularization. - The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^ + The specific squared error loss function used is: L = 1/2n ||A coefficients - y||^2^ This support multiple types of regularization: - none (a.k.a. ordinary least squares) @@ -50,9 +53,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ - ... (1.0, Vectors.dense(1.0)), - ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") + ... (1.0, 2.0, Vectors.dense(1.0)), + ... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 @@ -68,6 +71,18 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + >>> lr_path = temp_path + "/lr" + >>> lr.save(lr_path) + >>> lr2 = LinearRegression.load(lr_path) + >>> lr2.getMaxIter() + 5 + >>> model_path = temp_path + "/lr_model" + >>> model.save(model_path) + >>> model2 = LinearRegressionModel.load(model_path) + >>> model.coefficients[0] == model2.coefficients[0] + True + >>> model.intercept == model2.intercept + True .. versionadded:: 1.4.0 """ @@ -75,11 +90,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -92,11 +107,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -106,7 +121,7 @@ def _create_model(self, java_model): return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel): +class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by LinearRegression. @@ -119,7 +134,6 @@ def weights(self): """ Model weights. """ - warnings.warn("weights is deprecated. Use coefficients instead.") return self._call_java("weights") @@ -139,10 +153,259 @@ def intercept(self): """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, mse, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_lrt_summary = self._call_java("summary") + return LinearRegressionTrainingSummary(java_lrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_lr_summary = self._call_java("evaluate", dataset) + return LinearRegressionSummary(java_lr_summary) + + +class LinearRegressionSummary(JavaWrapper): + """ + .. note:: Experimental + + Linear regression results evaluated on a dataset. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def predictionCol(self): + """ + Field in "predictions" which gives the predicted value of + the label at each instance. + """ + return self._call_java("predictionCol") + + @property + @since("2.0.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("2.0.0") + def featuresCol(self): + """ + Field in "predictions" which gives the features of each instance + as a vector. + """ + return self._call_java("featuresCol") + + @property + @since("2.0.0") + def explainedVariance(self): + """ + Returns the explained variance regression score. + explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + Reference: http://en.wikipedia.org/wiki/Explained_variation + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("explainedVariance") + + @property + @since("2.0.0") + def meanAbsoluteError(self): + """ + Returns the mean absolute error, which is a risk function + corresponding to the expected value of the absolute error + loss or l1-norm loss. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("meanAbsoluteError") + + @property + @since("2.0.0") + def meanSquaredError(self): + """ + Returns the mean squared error, which is a risk function + corresponding to the expected value of the squared error + loss or quadratic loss. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("meanSquaredError") + + @property + @since("2.0.0") + def rootMeanSquaredError(self): + """ + Returns the root mean squared error, which is defined as the + square root of the mean squared error. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("rootMeanSquaredError") + + @property + @since("2.0.0") + def r2(self): + """ + Returns R^2^, the coefficient of determination. + Reference: http://en.wikipedia.org/wiki/Coefficient_of_determination + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("r2") + + @property + @since("2.0.0") + def residuals(self): + """ + Residuals (label - predicted value) + """ + return self._call_java("residuals") + + @property + @since("2.0.0") + def numInstances(self): + """ + Number of instances in DataFrame predictions + """ + return self._call_java("numInstances") + + @property + @since("2.0.0") + def devianceResiduals(self): + """ + The weighted residuals, the usual residuals rescaled by the + square root of the instance weights. + """ + return self._call_java("devianceResiduals") + + @property + @since("2.0.0") + def coefficientStandardErrors(self): + """ + Standard error of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + If :py:attr:`LinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("coefficientStandardErrors") + + @property + @since("2.0.0") + def tValues(self): + """ + T-statistic of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + If :py:attr:`LinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("tValues") + + @property + @since("2.0.0") + def pValues(self): + """ + Two-sided p-value of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + If :py:attr:`LinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("pValues") + + +@inherit_doc +class LinearRegressionTrainingSummary(LinearRegressionSummary): + """ + .. note:: Experimental + + Linear regression training results. Currently, the training summary ignores the + training weights except for the objective trace. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. + This value is only available when using the "l-bfgs" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("objectiveHistory") + + @property + @since("2.0.0") + def totalIterations(self): + """ + Number of training iterations until termination. + This value is only available when using the "l-bfgs" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("totalIterations") + @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasWeightCol): + HasWeightCol, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -160,16 +423,28 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti 0.0 >>> model.boundaries DenseVector([0.0, 1.0]) + >>> ir_path = temp_path + "/ir" + >>> ir.save(ir_path) + >>> ir2 = IsotonicRegression.load(ir_path) + >>> ir2.getIsotonic() + True + >>> model_path = temp_path + "/ir_model" + >>> model.save(model_path) + >>> model2 = IsotonicRegressionModel.load(model_path) + >>> model.boundaries == model2.boundaries + True + >>> model.predictions == model2.predictions + True """ - # a placeholder to make it appear in the generated doc isotonic = \ Param(Params._dummy(), "isotonic", "whether the output sequence should be isotonic/increasing (true) or" + - "antitonic/decreasing (false).") + "antitonic/decreasing (false).", typeConverter=TypeConverters.toBoolean) featureIndex = \ Param(Params._dummy(), "featureIndex", - "The index of the feature if featuresCol is a vector column, no effect otherwise.") + "The index of the feature if featuresCol is a vector column, no effect otherwise.", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -181,14 +456,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(IsotonicRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.IsotonicRegression", self.uid) - self.isotonic = \ - Param(self, "isotonic", - "whether the output sequence should be isotonic/increasing (true) or" + - "antitonic/decreasing (false).") - self.featureIndex = \ - Param(self, "featureIndex", - "The index of the feature if featuresCol is a vector column, no effect " + - "otherwise.") self._setDefault(isotonic=True, featureIndex=0) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -234,7 +501,7 @@ def getFeatureIndex(self): return self.getOrDefault(self.featureIndex) -class IsotonicRegressionModel(JavaModel): +class IsotonicRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -262,15 +529,12 @@ class TreeEnsembleParams(DecisionTreeParams): Mixin for Decision Tree-based ensemble algorithms parameters. """ - # a placeholder to make it appear in the generated doc subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " + - "used for learning each decision tree, in range (0, 1].") + "used for learning each decision tree, in range (0, 1].", + typeConverter=TypeConverters.toFloat) def __init__(self): super(TreeEnsembleParams, self).__init__() - #: param for Fraction of the training data, in range (0, 1]. - self.subsamplingRate = Param(self, "subsamplingRate", "Fraction of the training data " + - "used for learning each decision tree, in range (0, 1].") @since("1.4.0") def setSubsamplingRate(self, value): @@ -294,7 +558,6 @@ class TreeRegressorParams(Params): """ supportedImpurities = ["variance"] - # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + "Supported options: " + @@ -302,10 +565,6 @@ class TreeRegressorParams(Params): def __init__(self): super(TreeRegressorParams, self).__init__() - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = Param(self, "impurity", "Criterion used for information " + - "gain calculation (case-insensitive). Supported options: " + - ", ".join(self.supportedImpurities)) @since("1.4.0") def setImpurity(self, value): @@ -329,22 +588,16 @@ class RandomForestParams(TreeEnsembleParams): """ supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] - # a placeholder to make it appear in the generated doc - numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).") + numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).", + typeConverter=TypeConverters.toInt) featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies)) + "options: " + ", ".join(supportedFeatureSubsetStrategies), + typeConverter=TypeConverters.toString) def __init__(self): super(RandomForestParams, self).__init__() - #: param for Number of trees to train (>= 1). - self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1).") - #: param for The number of features to consider for splits at each tree node. - self.featureSubsetStrategy = \ - Param(self, "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(self.supportedFeatureSubsetStrategies)) @since("1.4.0") def setNumTrees(self, value): @@ -386,7 +639,8 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval): + DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, + HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -396,18 +650,34 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> dt = DecisionTreeRegressor(maxDepth=2) + >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance") >>> model = dt.fit(df) >>> model.depth 1 >>> model.numNodes 3 + >>> model.featureImportances + SparseVector(1, {0: 1.0}) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> dtr_path = temp_path + "/dtr" + >>> dt.save(dtr_path) + >>> dt2 = DecisionTreeRegressor.load(dtr_path) + >>> dt2.getMaxDepth() + 2 + >>> model_path = temp_path + "/dtr_model" + >>> model.save(model_path) + >>> model2 = DecisionTreeRegressionModel.load(model_path) + >>> model.numNodes == model2.numNodes + True + >>> model.depth == model2.depth + True + >>> model.transform(test1).head().variance + 0.0 .. versionadded:: 1.4.0 """ @@ -415,11 +685,13 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", + seed=None, varianceCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", seed=None, varianceCol=None) """ super(DecisionTreeRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -435,11 +707,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance"): + impurity="variance", seed=None, varianceCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", seed=None, varianceCol=None) Sets params for the DecisionTreeRegressor. """ kwargs = self.setParams._input_kwargs @@ -490,17 +763,39 @@ def __repr__(self): @inherit_doc -class DecisionTreeRegressionModel(DecisionTreeModel): +class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): """ Model fitted by DecisionTreeRegressor. .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def featureImportances(self): + """ + Estimate of the importance of each feature. + + This generalizes the idea of "Gini" importance to other losses, + following the explanation of Gini importance from "Random Forests" documentation + by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + + This feature importance is calculated as follows: + - importance(feature j) = sum (over nodes which split on feature j) of the gain, + where gain is scaled by the number of instances passing through node + - Normalize importances for tree to sum to 1. + + Note: Feature importance for single decision trees can have high variance due to + correlated predictor variables. Consider using a :py:class:`RandomForestRegressor` + to determine feature importance instead. + """ + return self._call_java("featureImportances") + @inherit_doc class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, - RandomForestParams, TreeRegressorParams, HasCheckpointInterval): + RandomForestParams, TreeRegressorParams, HasCheckpointInterval, + JavaMLWritable, JavaMLReadable): """ `http://en.wikipedia.org/wiki/Random_forest Random Forest` learning algorithm for regression. @@ -513,6 +808,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42) >>> model = rf.fit(df) + >>> model.featureImportances + SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 1.0]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) @@ -521,6 +818,16 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 + >>> rfr_path = temp_path + "/rfr" + >>> rf.save(rfr_path) + >>> rf2 = RandomForestRegressor.load(rfr_path) + >>> rf2.getNumTrees() + 2 + >>> model_path = temp_path + "/rfr_model" + >>> model.save(model_path) + >>> model2 = RandomForestRegressionModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True .. versionadded:: 1.4.0 """ @@ -570,17 +877,33 @@ def _create_model(self, java_model): return RandomForestRegressionModel(java_model) -class RandomForestRegressionModel(TreeEnsembleModels): +class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by RandomForestRegressor. .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def featureImportances(self): + """ + Estimate of the importance of each feature. + + Each feature's importance is the average of its importance across all trees in the ensemble + The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + and follows the implementation from scikit-learn. + + .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances` + """ + return self._call_java("featureImportances") + @inherit_doc class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, + JavaMLReadable): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for regression. @@ -591,8 +914,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> gbt = GBTRegressor(maxIter=5, maxDepth=2) + >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) >>> model = gbt.fit(df) + >>> model.featureImportances + SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) @@ -601,35 +926,44 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> gbtr_path = temp_path + "gbtr" + >>> gbt.save(gbtr_path) + >>> gbt2 = GBTRegressor.load(gbtr_path) + >>> gbt2.getMaxDepth() + 2 + >>> model_path = temp_path + "gbtr_model" + >>> model.save(model_path) + >>> model2 = GBTRegressionModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True + >>> model.treeWeights == model2.treeWeights + True .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + "Supported options: " + ", ".join(GBTParams.supportedLossTypes), + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None) """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) - #: param for Loss function which GBT tries to minimize (case-insensitive). - self.lossType = Param(self, "lossType", - "Loss function which GBT tries to minimize (case-insensitive). " + - "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, + seed=None) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -638,12 +972,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None) Sets params for Gradient Boosted Tree Regression. """ kwargs = self.setParams._input_kwargs @@ -668,17 +1002,32 @@ def getLossType(self): return self.getOrDefault(self.lossType) -class GBTRegressionModel(TreeEnsembleModels): +class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTRegressor. .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def featureImportances(self): + """ + Estimate of the importance of each feature. + + Each feature's importance is the average of its importance across all trees in the ensemble + The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + and follows the implementation from scikit-learn. + + .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances` + """ + return self._call_java("featureImportances") + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol): + HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -705,22 +1054,37 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi | 0.0|(1,[],[])| 0.0| 1.0| +-----+---------+------+----------+ ... + >>> aftsr_path = temp_path + "/aftsr" + >>> aftsr.save(aftsr_path) + >>> aftsr2 = AFTSurvivalRegression.load(aftsr_path) + >>> aftsr2.getMaxIter() + 100 + >>> model_path = temp_path + "/aftsr_model" + >>> model.save(model_path) + >>> model2 = AFTSurvivalRegressionModel.load(model_path) + >>> model.coefficients == model2.coefficients + True + >>> model.intercept == model2.intercept + True + >>> model.scale == model2.scale + True .. versionadded:: 1.6.0 """ - # a placeholder to make it appear in the generated doc censorCol = Param(Params._dummy(), "censorCol", "censor column name. The value of this column could be 0 or 1. " + "If the value is 1, it means the event has occurred i.e. " + - "uncensored; otherwise censored.") + "uncensored; otherwise censored.", typeConverter=TypeConverters.toString) quantileProbabilities = \ Param(Params._dummy(), "quantileProbabilities", "quantile probabilities array. Values of the quantile probabilities array " + - "should be in the range (0, 1) and the array should be non-empty.") + "should be in the range (0, 1) and the array should be non-empty.", + typeConverter=TypeConverters.toListFloat) quantilesCol = Param(Params._dummy(), "quantilesCol", "quantiles column name. This column will output quantiles of " + - "corresponding quantileProbabilities if it is set.") + "corresponding quantileProbabilities if it is set.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -735,20 +1099,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid) - #: Param for censor column name - self.censorCol = Param(self, "censorCol", - "censor column name. The value of this column could be 0 or 1. " + - "If the value is 1, it means the event has occurred i.e. " + - "uncensored; otherwise censored.") - #: Param for quantile probabilities array - self.quantileProbabilities = \ - Param(self, "quantileProbabilities", - "quantile probabilities array. Values of the quantile probabilities array " + - "should be in the range (0, 1) and the array should be non-empty.") - #: Param for quantiles column name - self.quantilesCol = Param(self, "quantilesCol", - "quantiles column name. This column will output quantiles of " + - "corresponding quantileProbabilities if it is set.") self._setDefault(censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]) kwargs = self.__init__._input_kwargs @@ -817,7 +1167,7 @@ def getQuantilesCol(self): return self.getOrDefault(self.quantilesCol) -class AFTSurvivalRegressionModel(JavaModel): +class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by AFTSurvivalRegression. @@ -861,18 +1211,173 @@ def predict(self, features): return self._call_java("predict", features) +@inherit_doc +class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol, + HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol, + HasSolver, JavaMLWritable, JavaMLReadable): + """ + Generalized Linear Regression. + + Fit a Generalized Linear Model specified by giving a symbolic description of the linear + predictor (link function) and a description of the error distribution (family). It supports + "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family + is listed below. The first link function of each family is the default one. + - "gaussian" -> "identity", "log", "inverse" + - "binomial" -> "logit", "probit", "cloglog" + - "poisson" -> "log", "identity", "sqrt" + - "gamma" -> "inverse", "identity", "log" + + .. seealso:: `GLM `_ + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(0.0, 0.0)), + ... (1.0, Vectors.dense(1.0, 2.0)), + ... (2.0, Vectors.dense(0.0, 0.0)), + ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"]) + >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity") + >>> model = glr.fit(df) + >>> abs(model.transform(df).head().prediction - 1.5) < 0.001 + True + >>> model.coefficients + DenseVector([1.5..., -1.0...]) + >>> abs(model.intercept - 1.5) < 0.001 + True + >>> glr_path = temp_path + "/glr" + >>> glr.save(glr_path) + >>> glr2 = GeneralizedLinearRegression.load(glr_path) + >>> glr.getFamily() == glr2.getFamily() + True + >>> model_path = temp_path + "/glr_model" + >>> model.save(model_path) + >>> model2 = GeneralizedLinearRegressionModel.load(model_path) + >>> model.intercept == model2.intercept + True + >>> model.coefficients[0] == model2.coefficients[0] + True + + .. versionadded:: 2.0.0 + """ + + family = Param(Params._dummy(), "family", "The name of family which is a description of " + + "the error distribution to be used in the model. Supported options: " + + "gaussian(default), binomial, poisson and gamma.") + link = Param(Params._dummy(), "link", "The name of link function which provides the " + + "relationship between the linear predictor and the mean of the distribution " + + "function. Supported options: identity, log, inverse, logit, probit, cloglog " + + "and sqrt.") + + @keyword_only + def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, + regParam=0.0, weightCol=None, solver="irls"): + """ + __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ + regParam=0.0, weightCol=None, solver="irls") + """ + super(GeneralizedLinearRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, + regParam=0.0, weightCol=None, solver="irls"): + """ + setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ + regParam=0.0, weightCol=None, solver="irls") + Sets params for generalized linear regression. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return GeneralizedLinearRegressionModel(java_model) + + @since("2.0.0") + def setFamily(self, value): + """ + Sets the value of :py:attr:`family`. + """ + self._paramMap[self.family] = value + return self + + @since("2.0.0") + def getFamily(self): + """ + Gets the value of family or its default value. + """ + return self.getOrDefault(self.family) + + @since("2.0.0") + def setLink(self, value): + """ + Sets the value of :py:attr:`link`. + """ + self._paramMap[self.link] = value + return self + + @since("2.0.0") + def getLink(self): + """ + Gets the value of link or its default value. + """ + return self.getOrDefault(self.link) + + +class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): + """ + Model fitted by GeneralizedLinearRegression. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + + @property + @since("2.0.0") + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + + if __name__ == "__main__": import doctest + import pyspark.ml.regression from pyspark.context import SparkContext from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.ml.regression.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.regression tests") sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 7a16cf52cccb2..86c0254a2b7b5 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -18,8 +18,11 @@ """ Unit tests for Spark ML Python APIs. """ - +import array import sys +if sys.version > '3': + xrange = range + try: import xmlrunner except ImportError: @@ -34,17 +37,26 @@ else: import unittest -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext, Row -from pyspark.sql.functions import rand -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.ml.param import Param, Params +from shutil import rmtree +import tempfile +import numpy as np + +from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer +from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier +from pyspark.ml.clustering import KMeans +from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator +from pyspark.ml.feature import * +from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed +from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor +from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only -from pyspark.ml import Estimator, Model, Pipeline, Transformer -from pyspark.ml.feature import * -from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel -from pyspark.mllib.linalg import DenseVector +from pyspark.ml.util import MLWritable, MLWriter +from pyspark.ml.wrapper import JavaParams +from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector +from pyspark.sql import DataFrame, SQLContext, Row +from pyspark.sql.functions import rand +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase class MockDataset(DataFrame): @@ -92,6 +104,72 @@ class MockModel(MockTransformer, Model, HasFake): pass +class ParamTypeConversionTests(PySparkTestCase): + """ + Test that param type conversion happens. + """ + + def test_int(self): + lr = LogisticRegression(maxIter=5.0) + self.assertEqual(lr.getMaxIter(), 5) + self.assertTrue(type(lr.getMaxIter()) == int) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt")) + self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1)) + + def test_float(self): + lr = LogisticRegression(tol=1) + self.assertEqual(lr.getTol(), 1.0) + self.assertTrue(type(lr.getTol()) == float) + self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat")) + + def test_vector(self): + ewp = ElementwiseProduct(scalingVec=[1, 3]) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0])) + ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4])) + self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4])) + self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"])) + + def test_list(self): + l = [0, 1] + for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l), + array.array('l', l), xrange(2), tuple(l)]: + converted = TypeConverters.toList(lst_like) + self.assertEqual(type(converted), list) + self.assertListEqual(converted, l) + + def test_list_int(self): + for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), + SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), + array.array('d', [1.0, 2.0])]: + vs = VectorSlicer(indices=indices) + self.assertListEqual(vs.getIndices(), [1, 2]) + self.assertTrue(all([type(v) == int for v in vs.getIndices()])) + self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"])) + + def test_list_float(self): + b = Bucketizer(splits=[1, 4]) + self.assertEqual(b.getSplits(), [1.0, 4.0]) + self.assertTrue(all([type(v) == float for v in b.getSplits()])) + self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0])) + + def test_list_string(self): + for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]: + idx_to_string = IndexToString(labels=labels) + self.assertListEqual(idx_to_string.getLabels(), ['a', 'b']) + self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2])) + + def test_string(self): + lr = LogisticRegression() + for col in ['features', u'features', np.str_('features')]: + lr.setFeaturesCol(col) + self.assertEqual(lr.getFeaturesCol(), 'features') + self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3)) + + def test_bool(self): + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) + self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) + + class PipelineTests(PySparkTestCase): def test_pipeline(self): @@ -161,8 +239,31 @@ def setParams(self, seed=None): return self._set(**kwargs) +class HasThrowableProperty(Params): + + def __init__(self): + super(HasThrowableProperty, self).__init__() + self.p = Param(self, "none", "empty param") + + @property + def test_property(self): + raise RuntimeError("Test property to raise error when invoked") + + class ParamTests(PySparkTestCase): + def test_copy_new_parent(self): + testParams = TestParams() + # Copying an instantiated param should fail + with self.assertRaises(ValueError): + testParams.maxIter._copy_new_parent(testParams) + # Copying a dummy param should succeed + TestParams.maxIter._copy_new_parent(testParams) + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") + self.assertTrue(maxIter.parent == testParams.uid) + def test_param(self): testParams = TestParams() maxIter = testParams.maxIter @@ -170,6 +271,11 @@ def test_param(self): self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") self.assertTrue(maxIter.parent == testParams.uid) + def test_hasparam(self): + testParams = TestParams() + self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) + self.assertFalse(testParams.hasParam("notAParameter")) + def test_params(self): testParams = TestParams() maxIter = testParams.maxIter @@ -179,7 +285,7 @@ def test_params(self): params = testParams.params self.assertEqual(params, [inputCol, maxIter, seed]) - self.assertTrue(testParams.hasParam(maxIter)) + self.assertTrue(testParams.hasParam(maxIter.name)) self.assertTrue(testParams.hasDefault(maxIter)) self.assertFalse(testParams.isSet(maxIter)) self.assertTrue(testParams.isDefined(maxIter)) @@ -188,7 +294,7 @@ def test_params(self): self.assertTrue(testParams.isSet(maxIter)) self.assertEqual(testParams.getMaxIter(), 100) - self.assertTrue(testParams.hasParam(inputCol)) + self.assertTrue(testParams.hasParam(inputCol.name)) self.assertFalse(testParams.hasDefault(inputCol)) self.assertFalse(testParams.isSet(inputCol)) self.assertFalse(testParams.isDefined(inputCol)) @@ -205,6 +311,14 @@ def test_params(self): "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", "seed: random seed. (default: 41, current: 43)"])) + def test_kmeans_param(self): + algo = KMeans() + self.assertEqual(algo.getInitMode(), "k-means||") + algo.setK(10) + self.assertEqual(algo.getK(), 10) + algo.setInitSteps(10) + self.assertEqual(algo.getInitSteps(), 10) + def test_hasseed(self): noSeedSpecd = TestParams() withSeedSpecd = TestParams(seed=42) @@ -219,6 +333,12 @@ def test_hasseed(self): # Check that a different class has a different seed self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) + def test_param_property_error(self): + param_store = HasThrowableProperty() + self.assertRaises(RuntimeError, lambda: param_store.test_property) + params = param_store.params # should not invoke the property 'test_property' + self.assertEqual(len(params), 1) + class FeatureTests(PySparkTestCase): @@ -286,6 +406,22 @@ def test_stopwordsremover(self): transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, ["a"]) + def test_count_vectorizer_with_binary(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) + cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") + model = cv.fit(dataset) + + transformedList = model.transform(dataset).select("features", "expected").collect() + + for r in transformedList: + feature, expected = r + self.assertEqual(feature, expected) + class HasInducedError(Params): @@ -370,8 +506,368 @@ def test_fit_maximize_metric(self): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + + +class TrainValidationSplitTests(PySparkTestCase): + + def test_fit_minimize_metric(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + bestModel = tvsModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + + def test_fit_maximize_metric(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + bestModel = tvsModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsPath = temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + + +class PersistenceTest(PySparkTestCase): + + def test_linear_regression(self): + lr = LinearRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/lr" + lr.save(lr_path) + lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LinearRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_logistic_regression(self): + lr = LogisticRegression(maxIter=1) + path = tempfile.mkdtemp() + lr_path = path + "/logreg" + lr.save(lr_path) + lr2 = LogisticRegression.load(lr_path) + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LogisticRegression instance uid (%s) " + "did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter], + "Loaded LogisticRegression instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaParams): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ + sqlContext = SQLContext(self.sc) + temp_path = tempfile.mkdtemp() + + try: + df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + pl = Pipeline(stages=[tf, pca]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + sqlContext = SQLContext(self.sc) + temp_path = tempfile.mkdtemp() + + try: + df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + + def test_write_property(self): + lr = LinearRegression(maxIter=1) + self.assertTrue(isinstance(lr.write, MLWriter)) + + def test_decisiontree_classifier(self): + dt = DecisionTreeClassifier(maxDepth=1) + path = tempfile.mkdtemp() + dtc_path = path + "/dtc" + dt.save(dtc_path) + dt2 = DecisionTreeClassifier.load(dtc_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeClassifier instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeClassifier instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + def test_decisiontree_regressor(self): + dt = DecisionTreeRegressor(maxDepth=1) + path = tempfile.mkdtemp() + dtr_path = path + "/dtr" + dt.save(dtr_path) + dt2 = DecisionTreeClassifier.load(dtr_path) + self.assertEqual(dt2.uid, dt2.maxDepth.parent, + "Loaded DecisionTreeRegressor instance uid (%s) " + "did not match Param's uid (%s)" + % (dt2.uid, dt2.maxDepth.parent)) + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], + "Loaded DecisionTreeRegressor instance default params did not match " + + "original defaults") + try: + rmtree(path) + except OSError: + pass + + +class TrainingSummaryTest(PySparkTestCase): + + def test_linear_regression_summary(self): + from pyspark.mllib.linalg import Vectors + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertAlmostEqual(s.explainedVariance, 0.25, 2) + self.assertAlmostEqual(s.meanAbsoluteError, 0.0) + self.assertAlmostEqual(s.meanSquaredError, 0.0) + self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) + self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertTrue(isinstance(s.residuals, DataFrame)) + self.assertEqual(s.numInstances, 2) + devResiduals = s.devianceResiduals + self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) + + def test_logistic_regression_summary(self): + from pyspark.mllib.linalg import Vectors + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + + +class HashingTFTest(PySparkTestCase): + + def test_apply_binary_term_freqs(self): + sqlContext = SQLContext(self.sc) + + df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + n = 100 + hashingTF = HashingTF() + hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) + output = hashingTF.transform(df) + features = output.select("features").first().features.toArray() + expected = Vectors.sparse(n, {(ord("a") % n): 1.0, + (ord("b") % n): 1.0, + (ord("c") % n): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(features[i])) + if __name__ == "__main__": + from pyspark.ml.tests import * if xmlrunner: unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) else: diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 705ee53685752..456d79d897e00 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,13 +18,18 @@ import itertools import numpy as np +from pyspark import SparkContext from pyspark import since -from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model -from pyspark.ml.util import keyword_only +from pyspark.ml.param import Params, Param, TypeConverters +from pyspark.ml.param.shared import HasSeed +from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable +from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand +from pyspark.mllib.common import inherit_doc, _py2java -__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel'] +__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit', + 'TrainValidationSplitModel'] class ParamGridBuilder(object): @@ -89,121 +94,135 @@ def build(self): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] -class CrossValidator(Estimator): +class ValidatorParams(HasSeed): """ - K-fold cross validation. - - >>> from pyspark.ml.classification import LogisticRegression - >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator - >>> from pyspark.mllib.linalg import Vectors - >>> dataset = sqlContext.createDataFrame( - ... [(Vectors.dense([0.0]), 0.0), - ... (Vectors.dense([0.4]), 1.0), - ... (Vectors.dense([0.5]), 0.0), - ... (Vectors.dense([0.6]), 1.0), - ... (Vectors.dense([1.0]), 1.0)] * 10, - ... ["features", "label"]) - >>> lr = LogisticRegression() - >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() - >>> evaluator = BinaryClassificationEvaluator() - >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - >>> cvModel = cv.fit(dataset) - >>> evaluator.evaluate(cvModel.transform(dataset)) - 0.8333... - - .. versionadded:: 1.4.0 + Common params for TrainValidationSplit and CrossValidator. """ - # a placeholder to make it appear in the generated doc estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated") - - # a placeholder to make it appear in the generated doc estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") - - # a placeholder to make it appear in the generated doc evaluator = Param( Params._dummy(), "evaluator", - "evaluator used to select hyper-parameters that maximize the cross-validated metric") - - # a placeholder to make it appear in the generated doc - numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") + "evaluator used to select hyper-parameters that maximize the validator metric") - @keyword_only - def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): - """ - __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) - """ - super(CrossValidator, self).__init__() - #: param for estimator to be cross-validated - self.estimator = Param(self, "estimator", "estimator to be cross-validated") - #: param for estimator param maps - self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps") - #: param for the evaluator used to select hyper-parameters that - #: maximize the cross-validated metric - self.evaluator = Param( - self, "evaluator", - "evaluator used to select hyper-parameters that maximize the cross-validated metric") - #: param for number of folds for cross validation - self.numFolds = Param(self, "numFolds", "number of folds for cross validation") - self._setDefault(numFolds=3) - kwargs = self.__init__._input_kwargs - self._set(**kwargs) - - @keyword_only - @since("1.4.0") - def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): - """ - setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): - Sets params for cross validator. - """ - kwargs = self.setParams._input_kwargs - return self._set(**kwargs) - - @since("1.4.0") def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. """ - self._paramMap[self.estimator] = value - return self + return self._set(estimator=value) - @since("1.4.0") def getEstimator(self): """ Gets the value of estimator or its default value. """ return self.getOrDefault(self.estimator) - @since("1.4.0") def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. """ - self._paramMap[self.estimatorParamMaps] = value - return self + return self._set(estimatorParamMaps=value) - @since("1.4.0") def getEstimatorParamMaps(self): """ Gets the value of estimatorParamMaps or its default value. """ return self.getOrDefault(self.estimatorParamMaps) - @since("1.4.0") def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. """ - self._paramMap[self.evaluator] = value - return self + return self._set(evaluator=value) - @since("1.4.0") def getEvaluator(self): """ Gets the value of evaluator or its default value. """ return self.getOrDefault(self.evaluator) + @classmethod + def _from_java_impl(cls, java_stage): + """ + Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams. + """ + + # Load information from java_stage to the instance. + estimator = JavaParams._from_java(java_stage.getEstimator()) + evaluator = JavaParams._from_java(java_stage.getEvaluator()) + epms = [estimator._transfer_param_map_from_java(epm) + for epm in java_stage.getEstimatorParamMaps()] + return estimator, epms, evaluator + + def _to_java_impl(self): + """ + Return Java estimator, estimatorParamMaps, and evaluator from this Python instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap + + java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps())) + for idx, epm in enumerate(self.getEstimatorParamMaps()): + java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm) + + java_estimator = self.getEstimator()._to_java() + java_evaluator = self.getEvaluator()._to_java() + return java_estimator, java_epms, java_evaluator + + +class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): + """ + K-fold cross validation. + + >>> from pyspark.ml.classification import LogisticRegression + >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator + >>> from pyspark.mllib.linalg import Vectors + >>> dataset = sqlContext.createDataFrame( + ... [(Vectors.dense([0.0]), 0.0), + ... (Vectors.dense([0.4]), 1.0), + ... (Vectors.dense([0.5]), 0.0), + ... (Vectors.dense([0.6]), 1.0), + ... (Vectors.dense([1.0]), 1.0)] * 10, + ... ["features", "label"]) + >>> lr = LogisticRegression() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + >>> evaluator = BinaryClassificationEvaluator() + >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> cvModel = cv.fit(dataset) + >>> evaluator.evaluate(cvModel.transform(dataset)) + 0.8333... + + .. versionadded:: 1.4.0 + """ + + numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation", + typeConverter=TypeConverters.toInt) + + @keyword_only + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): + """ + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None) + """ + super(CrossValidator, self).__init__() + self._setDefault(numFolds=3) + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + @keyword_only + @since("1.4.0") + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): + """ + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None): + Sets params for cross validator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + @since("1.4.0") def setNumFolds(self, value): """ @@ -225,9 +244,10 @@ def _fit(self, dataset): numModels = len(epm) eva = self.getOrDefault(self.evaluator) nFolds = self.getOrDefault(self.numFolds) + seed = self.getOrDefault(self.seed) h = 1.0 / nFolds randCol = self.uid + "_rand" - df = dataset.select("*", rand(0).alias(randCol)) + df = dataset.select("*", rand(seed).alias(randCol)) metrics = np.zeros(numModels) for i in range(nFolds): validateLB = i * h @@ -246,7 +266,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return CrossValidatorModel(bestModel) + return self._copyValues(CrossValidatorModel(bestModel)) @since("1.4.0") def copy(self, extra=None): @@ -268,8 +288,58 @@ def copy(self, extra=None): newCV.setEvaluator(self.getEvaluator().copy(extra)) return newCV + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java CrossValidator, create and return a Python wrapper of it. + Used for ML persistence. + """ + + estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage) + numFolds = java_stage.getNumFolds() + seed = java_stage.getSeed() + # Create a new instance of this stage. + py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, + numFolds=numFolds, seed=seed) + py_stage._resetUid(java_stage.uid()) + return py_stage -class CrossValidatorModel(Model): + def _to_java(self): + """ + Transfer this instance to a Java CrossValidator. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() + + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) + _java_obj.setEstimatorParamMaps(epms) + _java_obj.setEvaluator(evaluator) + _java_obj.setEstimator(estimator) + _java_obj.setSeed(self.getSeed()) + _java_obj.setNumFolds(self.getNumFolds()) + + return _java_obj + + +class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): """ Model from k-fold cross validation. @@ -299,20 +369,321 @@ def copy(self, extra=None): extra = dict() return CrossValidatorModel(self.bestModel.copy(extra)) + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java CrossValidatorModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + + # Load information from java_stage to the instance. + bestModel = JavaParams._from_java(java_stage.bestModel()) + estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) + # Create a new instance of this stage. + py_stage = cls(bestModel=bestModel)\ + .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java CrossValidatorModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + sc = SparkContext._active_spark_context + + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) + estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() + + _java_obj.set("evaluator", evaluator) + _java_obj.set("estimator", estimator) + _java_obj.set("estimatorParamMaps", epms) + return _java_obj + + +class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): + """ + Train-Validation-Split. + + >>> from pyspark.ml.classification import LogisticRegression + >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator + >>> from pyspark.mllib.linalg import Vectors + >>> dataset = sqlContext.createDataFrame( + ... [(Vectors.dense([0.0]), 0.0), + ... (Vectors.dense([0.4]), 1.0), + ... (Vectors.dense([0.5]), 0.0), + ... (Vectors.dense([0.6]), 1.0), + ... (Vectors.dense([1.0]), 1.0)] * 10, + ... ["features", "label"]) + >>> lr = LogisticRegression() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + >>> evaluator = BinaryClassificationEvaluator() + >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> tvsModel = tvs.fit(dataset) + >>> evaluator.evaluate(tvsModel.transform(dataset)) + 0.8333... + + .. versionadded:: 2.0.0 + """ + + trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\ + validation data. Must be between 0 and 1.") + + @keyword_only + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, + seed=None): + """ + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ + seed=None) + """ + super(TrainValidationSplit, self).__init__() + self._setDefault(trainRatio=0.75) + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + @since("2.0.0") + @keyword_only + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, + seed=None): + """ + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ + seed=None): + Sets params for the train validation split. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + @since("2.0.0") + def setTrainRatio(self, value): + """ + Sets the value of :py:attr:`trainRatio`. + """ + self._paramMap[self.trainRatio] = value + return self + + @since("2.0.0") + def getTrainRatio(self): + """ + Gets the value of trainRatio or its default value. + """ + return self.getOrDefault(self.trainRatio) + + def _fit(self, dataset): + est = self.getOrDefault(self.estimator) + epm = self.getOrDefault(self.estimatorParamMaps) + numModels = len(epm) + eva = self.getOrDefault(self.evaluator) + tRatio = self.getOrDefault(self.trainRatio) + seed = self.getOrDefault(self.seed) + randCol = self.uid + "_rand" + df = dataset.select("*", rand(seed).alias(randCol)) + metrics = np.zeros(numModels) + condition = (df[randCol] >= tRatio) + validation = df.filter(condition) + train = df.filter(~condition) + for j in range(numModels): + model = est.fit(train, epm[j]) + metric = eva.evaluate(model.transform(validation, epm[j])) + metrics[j] += metric + if eva.isLargerBetter(): + bestIndex = np.argmax(metrics) + else: + bestIndex = np.argmin(metrics) + bestModel = est.fit(dataset, epm[bestIndex]) + return self._copyValues(TrainValidationSplitModel(bestModel)) + + @since("2.0.0") + def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies creates a deep copy of + the embedded paramMap, and copies the embedded and extra parameters over. + + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + if extra is None: + extra = dict() + newTVS = Params.copy(self, extra) + if self.isSet(self.estimator): + newTVS.setEstimator(self.getEstimator().copy(extra)) + # estimatorParamMaps remain the same + if self.isSet(self.evaluator): + newTVS.setEvaluator(self.getEvaluator().copy(extra)) + return newTVS + + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java TrainValidationSplit, create and return a Python wrapper of it. + Used for ML persistence. + """ + + estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage) + trainRatio = java_stage.getTrainRatio() + seed = java_stage.getSeed() + # Create a new instance of this stage. + py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, + trainRatio=trainRatio, seed=seed) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java TrainValidationSplit. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() + + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", + self.uid) + _java_obj.setEstimatorParamMaps(epms) + _java_obj.setEvaluator(evaluator) + _java_obj.setEstimator(estimator) + _java_obj.setTrainRatio(self.getTrainRatio()) + _java_obj.setSeed(self.getSeed()) + + return _java_obj + + +class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): + """ + Model from train validation split. + + .. versionadded:: 2.0.0 + """ + + def __init__(self, bestModel): + super(TrainValidationSplitModel, self).__init__() + #: best model from cross validation + self.bestModel = bestModel + + def _transform(self, dataset): + return self.bestModel.transform(dataset) + + @since("2.0.0") + def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies the underlying bestModel, + creates a deep copy of the embedded paramMap, and + copies the embedded and extra parameters over. + + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + if extra is None: + extra = dict() + return TrainValidationSplitModel(self.bestModel.copy(extra)) + + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java TrainValidationSplitModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + + # Load information from java_stage to the instance. + bestModel = JavaParams._from_java(java_stage.bestModel()) + estimator, epms, evaluator = \ + super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) + # Create a new instance of this stage. + py_stage = cls(bestModel=bestModel)\ + .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + sc = SparkContext._active_spark_context + + _java_obj = JavaParams._new_java_obj( + "org.apache.spark.ml.tuning.TrainValidationSplitModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) + estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl() + + _java_obj.set("evaluator", evaluator) + _java_obj.set("estimator", estimator) + _java_obj.set("estimatorParamMaps", epms) + return _java_obj + if __name__ == "__main__": import doctest + from pyspark.context import SparkContext from pyspark.sql import SQLContext globs = globals().copy() + # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.tuning tests") sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() if failure_count: exit(-1) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index cee9d67b05325..9dfcef0e40d67 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -15,8 +15,27 @@ # limitations under the License. # -from functools import wraps +import sys import uuid +from functools import wraps + +if sys.version > '3': + basestring = str + +from pyspark import SparkContext, since +from pyspark.mllib.common import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") def keyword_only(func): @@ -52,3 +71,186 @@ def _randomUID(cls): concatenates the class name, "_", and 12 random hex chars. """ return cls.__name__ + "_" + uuid.uuid4().hex[12:] + + +@inherit_doc +class MLWriter(object): + """ + .. note:: Experimental + + Utility class that can save ML instances. + + .. versionadded:: 2.0.0 + """ + + def save(self, path): + """Save the ML instance to the input path.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def overwrite(self): + """Overwrites if the output path already exists.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types + """ + + def __init__(self, instance): + super(JavaMLWriter, self).__init__() + _java_obj = instance._to_java() + self._jwrite = _java_obj.write() + + def save(self, path): + """Save the ML instance to the input path.""" + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + self._jwrite.save(path) + + def overwrite(self): + """Overwrites if the output path already exists.""" + self._jwrite.overwrite() + return self + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + self._jwrite.context(sqlContext._ssql_ctx) + return self + + +@inherit_doc +class MLWritable(object): + """ + .. note:: Experimental + + Mixin for ML instances that provide :py:class:`MLWriter`. + + .. versionadded:: 2.0.0 + """ + + @property + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self)) + + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write.save(path) + + +@inherit_doc +class JavaMLWritable(MLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. + """ + + @property + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + +@inherit_doc +class MLReader(object): + """ + .. note:: Experimental + + Utility class that can load ML instances. + + .. versionadded:: 2.0.0 + """ + + def load(self, path): + """Load the ML instance from the input path.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types + """ + + def __init__(self, clazz): + self._clazz = clazz + self._jread = self._load_java_obj(clazz).read() + + def load(self, path): + """Load the ML instance from the input path.""" + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + java_obj = self._jread.load(path) + if not hasattr(self._clazz, "_from_java"): + raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r" + % self._clazz) + return self._clazz._from_java(java_obj) + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + self._jread.context(sqlContext._ssql_ctx) + return self + + @classmethod + def _java_loader_class(cls, clazz): + """ + Returns the full class name of the Java ML instance. The default + implementation replaces "pyspark" by "org.apache.spark" in + the Python full class name. + """ + java_package = clazz.__module__.replace("pyspark", "org.apache.spark") + if clazz.__name__ in ("Pipeline", "PipelineModel"): + # Remove the last package name "pipeline" for Pipeline and PipelineModel. + java_package = ".".join(java_package.split(".")[0:-1]) + return java_package + "." + clazz.__name__ + + @classmethod + def _load_java_obj(cls, clazz): + """Load the peer Java object of the ML instance.""" + java_class = cls._java_loader_class(clazz) + java_obj = _jvm() + for name in java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj + + +@inherit_doc +class MLReadable(object): + """ + .. note:: Experimental + + Mixin for instances that provide :py:class:`MLReader`. + + .. versionadded:: 2.0.0 + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls) + + @classmethod + def load(cls, path): + """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" + return cls.read().load(path) + + +@inherit_doc +class JavaMLReadable(MLReadable): + """ + (Private) Mixin for instances that provide JavaMLReader. + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return JavaMLReader(cls) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 4bcb4aaec89de..cd0e5b80d5559 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -15,45 +15,42 @@ # limitations under the License. # -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from pyspark import SparkContext from pyspark.sql import DataFrame +from pyspark.ml import Estimator, Transformer, Model from pyspark.ml.param import Params -from pyspark.ml.pipeline import Estimator, Transformer, Model +from pyspark.ml.util import _jvm from pyspark.mllib.common import inherit_doc, _java2py, _py2java -def _jvm(): +class JavaWrapper(object): """ - Returns the JVM view associated with SparkContext. Must be called - after SparkContext is initialized. + Wrapper class for a Java companion object """ - jvm = SparkContext._jvm - if jvm: - return jvm - else: - raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") + def __init__(self, java_obj=None): + super(JavaWrapper, self).__init__() + self._java_obj = java_obj + @classmethod + def _create_from_java_class(cls, java_class, *args): + """ + Construct this object from given Java classname and arguments + """ + java_obj = JavaWrapper._new_java_obj(java_class, *args) + return cls(java_obj) -@inherit_doc -class JavaWrapper(Params): - """ - Utility class to help create wrapper classes from Java/Scala - implementations of pipeline components. - """ - - __metaclass__ = ABCMeta - - #: The wrapped Java companion object. Subclasses should initialize - #: it properly. The param values in the Java object should be - #: synced with the Python wrapper in fit/transform/evaluate/copy. - _java_obj = None + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) @staticmethod def _new_java_obj(java_class, *args): """ - Construct a new Java object. + Returns a new Java object. """ sc = SparkContext._active_spark_context java_obj = _jvm() @@ -62,6 +59,18 @@ def _new_java_obj(java_class, *args): java_args = [_py2java(sc, arg) for arg in args] return java_obj(*java_args) + +@inherit_doc +class JavaParams(JavaWrapper, Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + #: The param values in the Java object should be + #: synced with the Python wrapper in fit/transform/evaluate/copy. + + __metaclass__ = ABCMeta + def _make_java_param_pair(self, param, value): """ Makes a Java parm pair. @@ -82,6 +91,17 @@ def _transfer_params_to_java(self): pair = self._make_java_param_pair(param, paramMap[param]) self._java_obj.set(pair) + def _transfer_param_map_to_java(self, pyParamMap): + """ + Transforms a Python ParamMap into a Java ParamMap. + """ + paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap") + for param in self.params: + if param in pyParamMap: + pair = self._make_java_param_pair(param, pyParamMap[param]) + paramMap.put([pair]) + return paramMap + def _transfer_params_from_java(self): """ Transforms the embedded params from the companion Java object. @@ -90,8 +110,21 @@ def _transfer_params_from_java(self): for param in self.params: if self._java_obj.hasParam(param.name): java_param = self._java_obj.getParam(param.name) - value = _java2py(sc, self._java_obj.getOrDefault(java_param)) - self._paramMap[param] = value + if self._java_obj.isDefined(java_param): + value = _java2py(sc, self._java_obj.getOrDefault(java_param)) + self._paramMap[param] = value + + def _transfer_param_map_from_java(self, javaParamMap): + """ + Transforms a Java ParamMap into a Python ParamMap. + """ + sc = SparkContext._active_spark_context + paramMap = dict() + for pair in javaParamMap.toList(): + param = pair.param() + if self.hasParam(str(param.name())): + paramMap[self.getParam(param.name())] = _java2py(sc, pair.value()) + return paramMap @staticmethod def _empty_java_param_map(): @@ -100,9 +133,55 @@ def _empty_java_param_map(): """ return _jvm().org.apache.spark.ml.param.ParamMap() + def _to_java(self): + """ + Transfer this instance's Params to the wrapped Java object, and return the Java object. + Used for ML persistence. + + Meta-algorithms such as Pipeline should override this method. + + :return: Java object equivalent to this instance. + """ + self._transfer_params_to_java() + return self._java_obj + + @staticmethod + def _from_java(java_stage): + """ + Given a Java object, create and return a Python wrapper of it. + Used for ML persistence. + + Meta-algorithms such as Pipeline should override this method as a classmethod. + """ + def __get_class(clazz): + """ + Loads Python class from its name. + """ + parts = clazz.split('.') + module = ".".join(parts[:-1]) + m = __import__(module) + for comp in parts[1:]: + m = getattr(m, comp) + return m + stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") + # Generate a default new instance from the stage_name class. + py_type = __get_class(stage_name) + if issubclass(py_type, JavaParams): + # Load information from java_stage to the instance. + py_stage = py_type() + py_stage._java_obj = java_stage + py_stage._resetUid(java_stage.uid()) + py_stage._transfer_params_from_java() + elif hasattr(py_type, "_from_java"): + py_stage = py_type._from_java(java_stage) + else: + raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" + % stage_name) + return py_stage + @inherit_doc -class JavaEstimator(Estimator, JavaWrapper): +class JavaEstimator(JavaParams, Estimator): """ Base class for :py:class:`Estimator`s that wrap Java/Scala implementations. @@ -110,6 +189,7 @@ class JavaEstimator(Estimator, JavaWrapper): __metaclass__ = ABCMeta + @abstractmethod def _create_model(self, java_model): """ Creates a model from the input Java model reference. @@ -134,7 +214,7 @@ def _fit(self, dataset): @inherit_doc -class JavaTransformer(Transformer, JavaWrapper): +class JavaTransformer(JavaParams, Transformer): """ Base class for :py:class:`Transformer`s that wrap Java/Scala implementations. Subclasses should ensure they have the transformer Java object @@ -149,7 +229,7 @@ def _transform(self, dataset): @inherit_doc -class JavaModel(Model, JavaTransformer): +class JavaModel(JavaTransformer, Model): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. Subclasses should inherit this class before @@ -158,15 +238,23 @@ class JavaModel(Model, JavaTransformer): __metaclass__ = ABCMeta - def __init__(self, java_model): + def __init__(self, java_model=None): """ Initialize this instance with a Java model object. Subclasses should call this constructor, initialize params, - and then call _transformer_params_from_java. + and then call _transfer_params_from_java. + + This instance can be instantiated without specifying java_model, + it will be assigned after that, but this scenario only used by + :py:class:`JavaMLReader` to load models. This is a bit of a + hack, but it is easiest since a proper fix would require + MLReader (in pyspark.ml.util) to depend on these wrappers, but + these wrappers depend on pyspark.ml.util (both directly and via + other ML classes). """ - super(JavaModel, self).__init__() - self._java_obj = java_model - self.uid = java_model.uid() + super(JavaModel, self).__init__(java_model) + if java_model is not None: + self.uid = java_model.uid() def copy(self, extra=None): """ @@ -181,12 +269,7 @@ def copy(self, extra=None): if extra is None: extra = dict() that = super(JavaModel, self).copy(extra) - that._java_obj = self._java_obj.copy(self._empty_java_param_map()) - that._transfer_params_to_java() + if self._java_obj is not None: + that._java_obj = self._java_obj.copy(self._empty_java_param_map()) + that._transfer_params_to_java() return that - - def _call_java(self, name, *args): - m = getattr(self._java_obj, name) - sc = SparkContext._active_spark_context - java_args = [_py2java(sc, arg) for arg in args] - return _java2py(sc, m(*java_args)) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index aab4015ba80f8..57106f8690a7d 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -94,16 +94,19 @@ class LogisticRegressionModel(LinearClassificationModel): Classification model trained using Multinomial/Binary Logistic Regression. - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. (Only used - in Binary Logistic Regression. In Multinomial Logistic - Regression, the intercepts will not be a single value, - so the intercepts will be part of the weights.) - :param numFeatures: the dimension of the features. - :param numClasses: the number of possible outcomes for k classes - classification problem in Multinomial Logistic Regression. - By default, it is binary logistic regression so numClasses - will be set to 2. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. (Only used in Binary Logistic + Regression. In Multinomial Logistic Regression, the intercepts will + not bea single value, so the intercepts will be part of the + weights.) + :param numFeatures: + The dimension of the features. + :param numClasses: + The number of possible outcomes for k classes classification problem + in Multinomial Logistic Regression. By default, it is binary + logistic regression so numClasses will be set to 2. >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -189,8 +192,8 @@ def numFeatures(self): @since('1.4.0') def numClasses(self): """ - Number of possible outcomes for k classes classification problem in Multinomial - Logistic Regression. + Number of possible outcomes for k classes classification problem + in Multinomial Logistic Regression. """ return self._numClasses @@ -272,37 +275,42 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.01). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization - - "l2" for using L2 regularization - - None for no regularization - - (default: "l2") - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param regType: + The type of regularizer used for training our model. + Supported values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization (default) + - None for no regularization + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), @@ -318,43 +326,50 @@ class LogisticRegressionWithLBFGS(object): """ @classmethod @since('1.2.0') - def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", - intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): + def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType="l2", + intercept=False, corrections=10, tolerance=1e-6, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.01). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization - - "l2" for using L2 regularization - - None for no regularization - - (default: "l2") - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param corrections: The number of corrections used in the - LBFGS update (default: 10). - :param tolerance: The convergence tolerance of iterations - for L-BFGS (default: 1e-4). - :param validateData: Boolean parameter which indicates if the - algorithm should validate data before - training. (default: True) - :param numClasses: The number of classes (i.e., outcomes) a - label can take in Multinomial Logistic - Regression (default: 2). + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.0) + :param regType: + The type of regularizer used for training our model. + Supported values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization (default) + - None for no regularization + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param corrections: + The number of corrections used in the LBFGS update. + If a known updater is used for binary classification, + it calls the ml implementation and this parameter will + have no effect. (default: 10) + :param tolerance: + The convergence tolerance of iterations for L-BFGS. + (default: 1e-6) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param numClasses: + The number of classes (i.e., outcomes) a label can take in + Multinomial Logistic Regression. + (default: 2) >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -387,8 +402,10 @@ class SVMModel(LinearClassificationModel): """ Model for Support Vector Machines (SVMs). - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -490,37 +507,42 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, """ Train a support vector machine on the given data. - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization - - "l2" for using L2 regularization - - None for no regularization - - (default: "l2") - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regType: + The type of regularizer used for training our model. + Allowed values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization (default) + - None for no regularization + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), @@ -536,11 +558,13 @@ class NaiveBayesModel(Saveable, Loader): """ Model for Naive Bayes classifiers. - :param labels: list of labels. - :param pi: log of class priors, whose dimension is C, - number of labels. - :param theta: log of class conditional probabilities, whose - dimension is C-by-D, where D is number of features. + :param labels: + List of labels. + :param pi: + Log of class priors, whose dimension is C, number of labels. + :param theta: + Log of class conditional probabilities, whose dimension is C-by-D, + where D is number of features. >>> data = [ ... LabeledPoint(0.0, [0.0, 0.0]), @@ -639,8 +663,11 @@ def train(cls, data, lambda_=1.0): it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). The input feature values must be nonnegative. - :param data: RDD of LabeledPoint. - :param lambda_: The smoothing parameter (default: 1.0). + :param data: + RDD of LabeledPoint. + :param lambda_: + The smoothing parameter. + (default: 1.0) """ first = data.first() if not isinstance(first, LabeledPoint): @@ -652,21 +679,34 @@ def train(cls, data, lambda_=1.0): @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LogisticRegression with SGD on a batch of data. - - The weights obtained at the end of training a stream are used as initial - weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Number of iterations run for each batch of data. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param regParam: L2 Regularization parameter. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a logistic regression model on streaming data. + Training uses Stochastic Gradient Descent to update the model based on + each new batch of incoming data from a DStream. + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param regParam: + L2 Regularization parameter. + (default: 0.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01, + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 8629aa5a17164..23d118bd40900 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -38,12 +38,140 @@ from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable from pyspark.streaming import DStream -__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', - 'PowerIterationClusteringModel', 'PowerIterationClustering', - 'StreamingKMeans', 'StreamingKMeansModel', +__all__ = ['BisectingKMeansModel', 'BisectingKMeans', 'KMeansModel', 'KMeans', + 'GaussianMixtureModel', 'GaussianMixture', 'PowerIterationClusteringModel', + 'PowerIterationClustering', 'StreamingKMeans', 'StreamingKMeansModel', 'LDA', 'LDAModel'] +@inherit_doc +class BisectingKMeansModel(JavaModelWrapper): + """ + .. note:: Experimental + + A clustering model derived from the bisecting k-means method. + + >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) + >>> bskm = BisectingKMeans() + >>> model = bskm.train(sc.parallelize(data, 2), k=4) + >>> p = array([0.0, 0.0]) + >>> model.predict(p) + 0 + >>> model.k + 4 + >>> model.computeCost(p) + 0.0 + + .. versionadded:: 2.0.0 + """ + + def __init__(self, java_model): + super(BisectingKMeansModel, self).__init__(java_model) + self.centers = [c.toArray() for c in self.call("clusterCenters")] + + @property + @since('2.0.0') + def clusterCenters(self): + """Get the cluster centers, represented as a list of NumPy + arrays.""" + return self.centers + + @property + @since('2.0.0') + def k(self): + """Get the number of clusters""" + return self.call("k") + + @since('2.0.0') + def predict(self, x): + """ + Find the cluster that each of the points belongs to in this + model. + + :param x: + A data point (or RDD of points) to determine cluster index. + :return: + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. + """ + if isinstance(x, RDD): + vecs = x.map(_convert_to_vector) + return self.call("predict", vecs) + + x = _convert_to_vector(x) + return self.call("predict", x) + + @since('2.0.0') + def computeCost(self, x): + """ + Return the Bisecting K-means cost (sum of squared distances of + points to their nearest center) for this model on the given + data. If provided with an RDD of points returns the sum. + + :param point: + A data point (or RDD of points) to compute the cost(s). + """ + if isinstance(x, RDD): + vecs = x.map(_convert_to_vector) + return self.call("computeCost", vecs) + + return self.call("computeCost", _convert_to_vector(x)) + + +class BisectingKMeans(object): + """ + .. note:: Experimental + + A bisecting k-means algorithm based on the paper "A comparison of + document clustering techniques" by Steinbach, Karypis, and Kumar, + with modification to fit Spark. + The algorithm starts from a single cluster that contains all points. + Iteratively it finds divisible clusters on the bottom level and + bisects each of them using k-means, until there are `k` leaf + clusters in total or no leaf clusters are divisible. + The bisecting steps of clusters on the same level are grouped + together to increase parallelism. If bisecting all divisible + clusters on the bottom level would result more than `k` leaf + clusters, larger clusters get higher priority. + + Based on + U{http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf} + Steinbach, Karypis, and Kumar, A comparison of document clustering + techniques, KDD Workshop on Text Mining, 2000. + + .. versionadded:: 2.0.0 + """ + + @classmethod + @since('2.0.0') + def train(self, rdd, k=4, maxIterations=20, minDivisibleClusterSize=1.0, seed=-1888008604): + """ + Runs the bisecting k-means algorithm return the model. + + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + The desired number of leaf clusters. The actual number could + be smaller if there are no divisible leaf clusters. + (default: 4) + :param maxIterations: + Maximum number of iterations allowed to split clusters. + (default: 20) + :param minDivisibleClusterSize: + Minimum number of points (if >= 1.0) or the minimum proportion + of points (if < 1.0) of a divisible cluster. + (default: 1) + :param seed: + Random seed value for cluster initialization. + (default: -1888008604 from classOf[BisectingKMeans].getName.##) + """ + java_model = callMLlibFunc( + "trainBisectingKMeans", rdd.map(_convert_to_vector), + k, maxIterations, minDivisibleClusterSize, seed) + return BisectingKMeansModel(java_model) + + @inherit_doc class KMeansModel(Saveable, Loader): @@ -118,7 +246,16 @@ def k(self): @since('0.9.0') def predict(self, x): - """Find the cluster to which x belongs in this model.""" + """ + Find the cluster that each of the points belongs to in this + model. + + :param x: + A data point (or RDD of points) to determine cluster index. + :return: + Predicted cluster index or an RDD of predicted cluster indices + if the input is an RDD. + """ best = 0 best_distance = float("inf") if isinstance(x, RDD): @@ -136,7 +273,11 @@ def predict(self, x): def computeCost(self, rdd): """ Return the K-means cost (sum of squared distances of points to - their nearest center) for this model on the given data. + their nearest center) for this model on the given + data. + + :param rdd: + The RDD of points to compute the cost on. """ cost = callMLlibFunc("computeCostKmeansModel", rdd.map(_convert_to_vector), [_convert_to_vector(c) for c in self.centers]) @@ -170,10 +311,47 @@ class KMeans(object): @since('0.9.0') def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None): - """Train a k-means clustering model.""" + """ + Train a k-means clustering model. + + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + Number of clusters to create. + :param maxIterations: + Maximum number of iterations allowed. + (default: 100) + :param runs: + Number of runs to execute in parallel. The best model according + to the cost function will be returned (deprecated in 1.6.0). + (default: 1) + :param initializationMode: + The initialization algorithm. This can be either "random" or + "k-means||". + (default: "k-means||") + :param seed: + Random seed value for cluster initialization. Set as None to + generate seed based on system time. + (default: None) + :param initializationSteps: + Number of steps for the k-means|| initialization mode. + This is an advanced setting -- the default of 5 is almost + always enough. + (default: 5) + :param epsilon: + Distance threshold within which a center will be considered to + have converged. If all centers move less than this Euclidean + distance, iterations are stopped. + (default: 1e-4) + :param initialModel: + Initial cluster centers can be provided as a KMeansModel object + rather than using the random or k-means|| initializationModel. + (default: None) + """ if runs != 1: warnings.warn( - "Support for runs is deprecated in 1.6.0. This param will have no effect in 1.7.0.") + "Support for runs is deprecated in 1.6.0. This param will have no effect in 2.0.0.") clusterInitialModel = [] if initialModel is not None: if not isinstance(initialModel, KMeansModel): @@ -202,16 +380,25 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, - ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) + ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2) >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001, ... maxIterations=50, seed=10) >>> labels = model.predict(clusterdata_1).collect() >>> labels[0]==labels[1] False >>> labels[1]==labels[2] - True + False >>> labels[4]==labels[5] True + >>> model.predict([-0.1,-0.05]) + 0 + >>> softPredicted = model.predictSoft([-0.1,-0.05]) + >>> abs(softPredicted[0] - 1.0) < 0.001 + True + >>> abs(softPredicted[1] - 0.0) < 0.001 + True + >>> abs(softPredicted[2] - 0.0) < 0.001 + True >>> path = tempfile.mkdtemp() >>> model.save(sc, path) @@ -266,7 +453,7 @@ def gaussians(self): """ return [ MultivariateGaussian(gaussian[0], gaussian[1]) - for gaussian in zip(*self.call("gaussians"))] + for gaussian in self.call("gaussians")] @property @since('1.4.0') @@ -277,26 +464,32 @@ def k(self): @since('1.3.0') def predict(self, x): """ - Find the cluster to which the points in 'x' has maximum membership - in this model. - - :param x: RDD of data points. - :return: cluster_labels. RDD of cluster labels. + Find the cluster to which the point 'x' or each point in RDD 'x' + has maximum membership in this model. + + :param x: + A feature vector or an RDD of vectors representing data points. + :return: + Predicted cluster label or an RDD of predicted cluster labels + if the input is an RDD. """ if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) return cluster_labels else: - raise TypeError("x should be represented by an RDD, " - "but got %s." % type(x)) + z = self.predictSoft(x) + return z.argmax() @since('1.3.0') def predictSoft(self, x): """ - Find the membership of each point in 'x' to all mixture components. + Find the membership of point 'x' or each point in RDD 'x' to all mixture components. - :param x: RDD of data points. - :return: membership_matrix. RDD of array of double values. + :param x: + A feature vector or an RDD of vectors representing data points. + :return: + The membership value to all mixture components for vector 'x' + or each vector in RDD 'x'. """ if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) @@ -304,16 +497,17 @@ def predictSoft(self, x): _convert_to_vector(self.weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) else: - raise TypeError("x should be represented by an RDD, " - "but got %s." % type(x)) + return self.call("predictSoft", _convert_to_vector(x)).toArray() @classmethod @since('1.5.0') def load(cls, sc, path): """Load the GaussianMixtureModel from disk. - :param sc: SparkContext - :param path: str, path to where the model is stored. + :param sc: + SparkContext. + :param path: + Path to where the model is stored. """ model = cls._load_java(sc, path) wrapper = sc._jvm.GaussianMixtureModelWrapper(model) @@ -326,19 +520,35 @@ class GaussianMixture(object): Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. - :param data: RDD of data points - :param k: Number of components - :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 - :param maxIterations: Number of iterations. Default to 100 - :param seed: Random Seed - :param initialModel: GaussianMixtureModel for initializing learning - .. versionadded:: 1.3.0 """ @classmethod @since('1.3.0') def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): - """Train a Gaussian Mixture clustering model.""" + """ + Train a Gaussian Mixture clustering model. + + :param rdd: + Training points as an `RDD` of `Vector` or convertible + sequence types. + :param k: + Number of independent Gaussians in the mixture model. + :param convergenceTol: + Maximum change in log-likelihood at which convergence is + considered to have occurred. + (default: 1e-3) + :param maxIterations: + Maximum number of iterations allowed. + (default: 100) + :param seed: + Random seed for initial Gaussian distribution. Set as None to + generate seed based on system time. + (default: None) + :param initialModel: + Initial GMM starting point, bypassing the random + initialization. + (default: None) + """ initialModelWeights = None initialModelMu = None initialModelSigma = None @@ -346,7 +556,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia if initialModel.k != k: raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" % (initialModel.k, k)) - initialModelWeights = initialModel.weights + initialModelWeights = list(initialModel.weights) initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), @@ -362,12 +572,25 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): Model produced by [[PowerIterationClustering]]. - >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0), - ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), - ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0), - ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)] - >>> rdd = sc.parallelize(data, 2) - >>> model = PowerIterationClustering.train(rdd, 2, 100) + >>> import math + >>> def genCircle(r, n): + ... points = [] + ... for i in range(0, n): + ... theta = 2.0 * math.pi * i / n + ... points.append((r * math.cos(theta), r * math.sin(theta))) + ... return points + >>> def sim(x, y): + ... dist2 = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1]) + ... return math.exp(-dist2 / 2.0) + >>> r1 = 1.0 + >>> n1 = 10 + >>> r2 = 4.0 + >>> n2 = 40 + >>> n = n1 + n2 + >>> points = genCircle(r1, n1) + genCircle(r2, n2) + >>> similarities = [(i, j, sim(points[i], points[j])) for i in range(1, n) for j in range(0, i)] + >>> rdd = sc.parallelize(similarities, 2) + >>> model = PowerIterationClustering.train(rdd, 2, 40) >>> model.k 2 >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) @@ -439,18 +662,24 @@ class PowerIterationClustering(object): @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): """ - :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the - affinity matrix, which is the matrix A in the PIC paper. - The similarity s,,ij,, must be nonnegative. - This is a symmetric matrix and hence s,,ij,, = s,,ji,,. - For any (i, j) with nonzero similarity, there should be - either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. - Tuples with i = j are ignored, because we assume - s,,ij,, = 0.0. - :param k: Number of clusters. - :param maxIterations: Maximum number of iterations of the - PIC algorithm. - :param initMode: Initialization mode. + :param rdd: + An RDD of (i, j, s\ :sub:`ij`\) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. The + similarity s\ :sub:`ij`\ must be nonnegative. This is a symmetric + matrix and hence s\ :sub:`ij`\ = s\ :sub:`ji`\ For any (i, j) with + nonzero similarity, there should be either (i, j, s\ :sub:`ij`\) or + (j, i, s\ :sub:`ji`\) in the input. Tuples with i = j are ignored, + because it is assumed s\ :sub:`ij`\ = 0.0. + :param k: + Number of clusters. + :param maxIterations: + Maximum number of iterations of the PIC algorithm. + (default: 100) + :param initMode: + Initialization mode. This can be either "random" to use + a random vector as vertex properties, or "degree" to use + normalized sum similarities. + (default: "random") """ model = callMLlibFunc("trainPowerIterationClusteringModel", rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) @@ -490,8 +719,10 @@ class StreamingKMeansModel(KMeansModel): and new data. If it set to zero, the old centroids are completely forgotten. - :param clusterCenters: Initial cluster centers. - :param clusterWeights: List of weights assigned to each cluster. + :param clusterCenters: + Initial cluster centers. + :param clusterWeights: + List of weights assigned to each cluster. >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] >>> initWeights = [1.0, 1.0] @@ -538,11 +769,14 @@ def clusterWeights(self): def update(self, data, decayFactor, timeUnit): """Update the centroids, according to data - :param data: Should be a RDD that represents the new data. - :param decayFactor: forgetfulness of the previous centroids. - :param timeUnit: Can be "batches" or "points". If points, then the - decay factor is raised to the power of number of new - points and if batches, it is used as it is. + :param data: + RDD with new data for the model update. + :param decayFactor: + Forgetfulness of the previous centroids. + :param timeUnit: + Can be "batches" or "points". If points, then the decay factor + is raised to the power of number of new points and if batches, + then decay factor will be used as is. """ if not isinstance(data, RDD): raise TypeError("Data should be of an RDD, got %s." % type(data)) @@ -569,10 +803,17 @@ class StreamingKMeans(object): More details on how the centroids are updated are provided under the docs of StreamingKMeansModel. - :param k: int, number of clusters - :param decayFactor: float, forgetfulness of the previous centroids. - :param timeUnit: can be "batches" or "points". If points, then the - decayfactor is raised to the power of no. of new points. + :param k: + Number of clusters. + (default: 2) + :param decayFactor: + Forgetfulness of the previous centroids. + (default: 1.0) + :param timeUnit: + Can be "batches" or "points". If points, then the decay factor is + raised to the power of number of new points and if batches, then + decay factor will be used as is. + (default: "batches") .. versionadded:: 1.5.0 """ @@ -671,7 +912,7 @@ def predictOnValues(self, dstream): return dstream.mapValues(lambda x: self._model.predict(x)) -class LDAModel(JavaModelWrapper): +class LDAModel(JavaModelWrapper, JavaSaveable, Loader): """ A clustering model derived from the LDA method. @@ -691,9 +932,14 @@ class LDAModel(JavaModelWrapper): ... [2, SparseVector(2, {0: 1.0})], ... ] >>> rdd = sc.parallelize(data) - >>> model = LDA.train(rdd, k=2) + >>> model = LDA.train(rdd, k=2, seed=1) >>> model.vocabSize() 2 + >>> model.describeTopics() + [([1, 0], [0.5..., 0.49...]), ([0, 1], [0.5..., 0.49...])] + >>> model.describeTopics(1) + [([1], [0.5...]), ([0], [0.5...])] + >>> topics = model.topicsMatrix() >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) >>> assert_almost_equal(topics, topics_expect, 1) @@ -724,34 +970,42 @@ def vocabSize(self): """Vocabulary size (number of terms or terms in the vocabulary)""" return self.call("vocabSize") - @since('1.5.0') - def save(self, sc, path): - """Save the LDAModel on to disk. + @since('1.6.0') + def describeTopics(self, maxTermsPerTopic=None): + """Return the topics described by weighted terms. + + WARNING: If vocabSize and k are large, this can return a large object! - :param sc: SparkContext - :param path: str, path to where the model needs to be stored. + :param maxTermsPerTopic: + Maximum number of terms to collect for each topic. + (default: vocabulary size) + :return: + Array over topics. Each topic is represented as a pair of + matching arrays: (term indices, term weights in topic). + Each topic's terms are sorted in order of decreasing weight. """ - if not isinstance(sc, SparkContext): - raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - self._java_model.save(sc._jsc.sc(), path) + if maxTermsPerTopic is None: + topics = self.call("describeTopics") + else: + topics = self.call("describeTopics", maxTermsPerTopic) + return topics @classmethod @since('1.5.0') def load(cls, sc, path): """Load the LDAModel from disk. - :param sc: SparkContext - :param path: str, path to where the model is stored. + :param sc: + SparkContext. + :param path: + Path to where the model is stored. """ if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) - java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load( - sc._jsc.sc(), path) - return cls(java_model) + model = callMLlibFunc("loadLDAModel", sc, path) + return LDAModel(model) class LDA(object): @@ -765,17 +1019,38 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): """Train a LDA model. - :param rdd: RDD of data points - :param k: Number of clusters you want - :param maxIterations: Number of iterations. Default to 20 - :param docConcentration: Concentration parameter (commonly named "alpha") - for the prior placed on documents' distributions over topics ("theta"). - :param topicConcentration: Concentration parameter (commonly named "beta" or "eta") - for the prior placed on topics' distributions over terms. - :param seed: Random Seed - :param checkpointInterval: Period (in iterations) between checkpoints. - :param optimizer: LDAOptimizer used to perform the actual calculation. - Currently "em", "online" are supported. Default to "em". + :param rdd: + RDD of documents, which are tuples of document IDs and term + (word) count vectors. The term count vectors are "bags of + words" with a fixed-size vocabulary (where the vocabulary size + is the length of the vector). Document IDs must be unique + and >= 0. + :param k: + Number of topics to infer, i.e., the number of soft cluster + centers. + (default: 10) + :param maxIterations: + Maximum number of iterations allowed. + (default: 20) + :param docConcentration: + Concentration parameter (commonly named "alpha") for the prior + placed on documents' distributions over topics ("theta"). + (default: -1.0) + :param topicConcentration: + Concentration parameter (commonly named "beta" or "eta") for + the prior placed on topics' distributions over terms. + (default: -1.0) + :param seed: + Random seed for cluster initialization. Set as None to generate + seed based on system time. + (default: None) + :param checkpointInterval: + Period (in iterations) between checkpoints. + (default: 10) + :param optimizer: + LDAOptimizer used to perform the actual calculation. Currently + "em", "online" are supported. + (default: "em") """ model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, docConcentration, topicConcentration, seed, diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index a439a488de5cc..6bc2b1e64651e 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -101,8 +101,8 @@ def _java2py(sc, r, encoding="bytes"): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) - if clsName == 'DataFrame': - return DataFrame(r, SQLContext(sc)) + if clsName == 'Dataset': + return DataFrame(r, SQLContext.getOrCreate(sc)) if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) @@ -125,7 +125,7 @@ def callJavaFunc(sc, func, *args): def callMLlibFunc(name, *args): """ Call API in PythonMLLibAPI """ - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() api = getattr(sc._jvm.PythonMLLibAPI(), name) return callJavaFunc(sc, api, *args) @@ -135,7 +135,7 @@ class JavaModelWrapper(object): Wrapper for the model in JVM """ def __init__(self, java_model): - self._sc = SparkContext._active_spark_context + self._sc = SparkContext.getOrCreate() self._java_model = java_model def __del__(self): diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 8c87ee9df2132..22e68ea5b4511 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -44,7 +44,7 @@ class BinaryClassificationMetrics(JavaModelWrapper): def __init__(self, scoreAndLabels): sc = scoreAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ StructField("score", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -103,7 +103,7 @@ class RegressionMetrics(JavaModelWrapper): def __init__(self, predictionAndObservations): sc = predictionAndObservations.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("observation", DoubleType(), nullable=False)])) @@ -197,7 +197,7 @@ class MulticlassMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -338,7 +338,7 @@ class RankingMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_model = callMLlibFunc("newRankingMetrics", df._jdf) @@ -424,7 +424,7 @@ class MultilabelMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7b077b058c3fd..b3dd2f63a5d80 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -30,7 +30,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import SparkContext, since +from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( @@ -100,8 +100,6 @@ def transform(self, vector): :return: normalized vector. If the norm of the input is zero, it will return the input vector. """ - sc = SparkContext._active_spark_context - assert sc is not None, "SparkContext should be initialized first" if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: @@ -174,6 +172,38 @@ def setWithStd(self, withStd): self.call("setWithStd", withStd) return self + @property + @since('2.0.0') + def withStd(self): + """ + Returns if the model scales the data to unit standard deviation. + """ + return self.call("withStd") + + @property + @since('2.0.0') + def withMean(self): + """ + Returns if the model centers the data before scaling. + """ + return self.call("withMean") + + @property + @since('2.0.0') + def std(self): + """ + Return the column standard deviation values. + """ + return self.call("std") + + @property + @since('2.0.0') + def mean(self): + """ + Return the column mean values. + """ + return self.call("mean") + class StandardScaler(object): """ @@ -198,6 +228,14 @@ class StandardScaler(object): >>> for r in result.collect(): r DenseVector([-0.7071, 0.7071, -0.7071]) DenseVector([0.7071, -0.7071, 0.7071]) + >>> int(model.std[0]) + 4 + >>> int(model.mean[0]*10) + 9 + >>> model.withStd + True + >>> model.withMean + True .. versionadded:: 1.2.0 """ @@ -341,6 +379,17 @@ class HashingTF(object): """ def __init__(self, numFeatures=1 << 20): self.numFeatures = numFeatures + self.binary = False + + @since("2.0.0") + def setBinary(self, value): + """ + If True, term frequency vector will be binary such that non-zero + term counts will be set to 1 + (default: False) + """ + self.binary = value + return self @since('1.2.0') def indexOf(self, term): @@ -360,7 +409,7 @@ def transform(self, document): freq = {} for term in document: i = self.indexOf(term) - freq[i] = freq.get(i, 0) + 1.0 + freq[i] = 1.0 if self.binary else freq.get(i, 0) + 1.0 return Vectors.sparse(self.numFeatures, freq.items()) @@ -504,7 +553,8 @@ def load(cls, sc, path): """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) - return Word2VecModel(jmodel) + model = sc._jvm.Word2VecModelWrapper(jmodel) + return Word2VecModel(model) @ignore_unicode_prefix @@ -546,6 +596,9 @@ class Word2Vec(object): >>> sameModel = Word2VecModel.load(sc, path) >>> model.transform("a") == sameModel.transform("a") True + >>> syms = sameModel.findSynonyms("a", 2) + >>> [s[0] for s in syms] + [u'b', u'c'] >>> from shutil import rmtree >>> try: ... rmtree(path) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 2039decc0cb3c..f339e50891166 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -21,15 +21,15 @@ from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc __all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel'] @inherit_doc @ignore_unicode_prefix -class FPGrowthModel(JavaModelWrapper): - +class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ .. note:: Experimental @@ -41,6 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + >>> model_path = temp_path + "/fpm" + >>> model.save(sc, model_path) + >>> sameModel = FPGrowthModel.load(sc, model_path) + >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect()) + True .. versionadded:: 1.4.0 """ @@ -52,6 +57,16 @@ def freqItemsets(self): """ return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1]))) + @classmethod + @since("2.0.0") + def load(cls, sc, path): + """ + Load a model from the given path. + """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.FPGrowthModelWrapper(model) + return FPGrowthModel(wrapper) + class FPGrowth(object): """ @@ -68,11 +83,15 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. - :param data: The input data set, each element contains a - transaction. - :param minSupport: The minimal support level (default: `0.3`). - :param numPartitions: The number of partitions used by - parallel FP-growth (default: same as input data). + :param data: + The input data set, each element contains a transaction. + :param minSupport: + The minimal support level. + (default: 0.3) + :param numPartitions: + The number of partitions used by parallel FP-growth. A value + of -1 will use the same number as input data. + (default: -1) """ model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) return FPGrowthModel(model) @@ -108,7 +127,7 @@ class PrefixSpanModel(JavaModelWrapper): @since("1.6.0") def freqSequences(self): - """Gets frequence sequences""" + """Gets frequent sequences""" return self.call("getFreqSequences").map(lambda x: PrefixSpan.FreqSequence(x[0], x[1])) @@ -128,17 +147,27 @@ class PrefixSpan(object): @since("1.6.0") def train(cls, data, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000): """ - Finds the complete set of frequent sequential patterns in the input sequences of itemsets. - - :param data: The input data set, each element contains a sequnce of itemsets. - :param minSupport: the minimal support level of the sequential pattern, any pattern appears - more than (minSupport * size-of-the-dataset) times will be output (default: `0.1`) - :param maxPatternLength: the maximal length of the sequential pattern, any pattern appears - less than maxPatternLength will be output. (default: `10`) - :param maxLocalProjDBSize: The maximum number of items (including delimiters used in - the internal storage format) allowed in a projected database before local - processing. If a projected database exceeds this size, another - iteration of distributed prefix growth is run. (default: `32000000`) + Finds the complete set of frequent sequential patterns in the + input sequences of itemsets. + + :param data: + The input data set, each element contains a sequence of + itemsets. + :param minSupport: + The minimal support level of the sequential pattern, any + pattern that appears more than (minSupport * + size-of-the-dataset) times will be output. + (default: 0.1) + :param maxPatternLength: + The maximal length of the sequential pattern, any pattern + that appears less than maxPatternLength will be output. + (default: 10) + :param maxLocalProjDBSize: + The maximum number of items (including delimiters used in the + internal storage format) allowed in a projected database before + local processing. If a projected database exceeds this size, + another iteration of distributed prefix growth is run. + (default: 32000000) """ model = callMLlibFunc("trainPrefixSpanModel", data, minSupport, maxPatternLength, maxLocalProjDBSize) @@ -157,8 +186,19 @@ def _test(): import pyspark.mllib.fpm globs = pyspark.mllib.fpm.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + import tempfile + + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index ae9ce58450905..abf00a4737948 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -528,7 +528,9 @@ def __init__(self, size, *args): assert len(self.indices) == len(self.values), "index and value arrays not same length" for i in xrange(len(self.indices) - 1): if self.indices[i] >= self.indices[i + 1]: - raise TypeError("indices array must be sorted") + raise TypeError( + "Indices %s and %s are not strictly increasing" + % (self.indices[i], self.indices[i + 1])) def numNonzeros(self): """ @@ -556,7 +558,7 @@ def __reduce__(self): @staticmethod def parse(s): """ - Parse string representation back into the DenseVector. + Parse string representation back into the SparseVector. >>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )') SparseVector(4, {0: 4.0, 1: 5.0}) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 0e76050788630..43cb0beef1bd3 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -256,7 +256,7 @@ def rows(self): # on the Scala/Java side. Then we map each Row in the # DataFrame back to an IndexedRow on this side. rows_df = callMLlibFunc("getIndexedRows", self._java_matrix_wrapper._java_model) - rows = rows_df.map(lambda row: IndexedRow(row[0], row[1])) + rows = rows_df.rdd.map(lambda row: IndexedRow(row[0], row[1])) return rows def numRows(self): @@ -297,6 +297,20 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") + def columnSimilarities(self): + """ + Compute all cosine similarities between columns. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows) + >>> cs = mat.columnSimilarities() + >>> print(cs.numCols()) + 3 + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("columnSimilarities") + return CoordinateMatrix(java_coordinate_matrix) + def toRowMatrix(self): """ Convert this matrix to a RowMatrix. @@ -461,7 +475,7 @@ def entries(self): # DataFrame on the Scala/Java side. Then we map each Row in # the DataFrame back to a MatrixEntry on this side. entries_df = callMLlibFunc("getMatrixEntries", self._java_matrix_wrapper._java_model) - entries = entries_df.map(lambda row: MatrixEntry(row[0], row[1], row[2])) + entries = entries_df.rdd.map(lambda row: MatrixEntry(row[0], row[1], row[2])) return entries def numRows(self): @@ -686,7 +700,7 @@ def blocks(self): # DataFrame on the Scala/Java side. Then we map each Row in # the DataFrame back to a sub-matrix block on this side. blocks_df = callMLlibFunc("getMatrixBlocks", self._java_matrix_wrapper._java_model) - blocks = blocks_df.map(lambda row: ((row[0][0], row[0][1]), row[1])) + blocks = blocks_df.rdd.map(lambda row: ((row[0][0], row[0][1]), row[1])) return blocks @property diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index b9442b0d16c0f..7e60255d43ead 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -101,12 +101,12 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) - 3.8... + 3.73... >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)]) >>> model = ALS.train(df, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) - 3.8... + 3.73... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) @@ -138,7 +138,8 @@ def predict(self, user, product): @since("0.9.0") def predictAll(self, user_product): """ - Returns a list of predicted ratings for input user and product pairs. + Returns a list of predicted ratings for input user and product + pairs. """ assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() @@ -165,28 +166,33 @@ def productFeatures(self): @since("1.4.0") def recommendUsers(self, product, num): """ - Recommends the top "num" number of users for a given product and returns a list - of Rating objects sorted by the predicted rating in descending order. + Recommends the top "num" number of users for a given product and + returns a list of Rating objects sorted by the predicted rating in + descending order. """ return list(self.call("recommendUsers", product, num)) @since("1.4.0") def recommendProducts(self, user, num): """ - Recommends the top "num" number of products for a given user and returns a list - of Rating objects sorted by the predicted rating in descending order. + Recommends the top "num" number of products for a given user and + returns a list of Rating objects sorted by the predicted rating in + descending order. """ return list(self.call("recommendProducts", user, num)) def recommendProductsForUsers(self, num): """ - Recommends top "num" products for all users. The number returned may be less than this. + Recommends the top "num" number of products for all users. The + number of recommendations returned per user may be less than "num". """ return self.call("wrappedRecommendProductsForUsers", num) def recommendUsersForProducts(self, num): """ - Recommends top "num" users for all products. The number returned may be less than this. + Recommends the top "num" number of users for all products. The + number of recommendations returned per product may be less than + "num". """ return self.call("wrappedRecommendUsersForProducts", num) @@ -234,11 +240,34 @@ def _prepare(cls, ratings): def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, seed=None): """ - Train a matrix factorization model given an RDD of ratings given by users to some products, - in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the - product of two lower-rank matrices of a given rank (number of features). To solve for these - features, we run a given number of iterations of ALS. This is done using a level of - parallelism given by `blocks`. + Train a matrix factorization model given an RDD of ratings by users + for a subset of products. The ratings matrix is approximated as the + product of two lower-rank matrices of a given rank (number of + features). To solve for these features, ALS is run iteratively with + a configurable level of parallelism. + + :param ratings: + RDD of `Rating` or (userID, productID, rating) tuple. + :param rank: + Rank of the feature matrices computed (number of features). + :param iterations: + Number of iterations of ALS. + (default: 5) + :param lambda_: + Regularization parameter. + (default: 0.01) + :param blocks: + Number of blocks used to parallelize the computation. A value + of -1 will use an auto-configured number of blocks. + (default: -1) + :param nonnegative: + A value of True will solve least-squares with nonnegativity + constraints. + (default: False) + :param seed: + Random seed for initial matrix factorization model. A value + of None will use system time as the seed. + (default: None) """ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, nonnegative, seed) @@ -249,11 +278,37 @@ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, nonnegative=False, seed=None): """ - Train a matrix factorization model given an RDD of 'implicit preferences' given by users - to some products, in the form of (userID, productID, preference) pairs. We approximate the - ratings matrix as the product of two lower-rank matrices of a given rank (number of - features). To solve for these features, we run a given number of iterations of ALS. - This is done using a level of parallelism given by `blocks`. + Train a matrix factorization model given an RDD of 'implicit + preferences' of users for a subset of products. The ratings matrix + is approximated as the product of two lower-rank matrices of a + given rank (number of features). To solve for these features, ALS + is run iteratively with a configurable level of parallelism. + + :param ratings: + RDD of `Rating` or (userID, productID, rating) tuple. + :param rank: + Rank of the feature matrices computed (number of features). + :param iterations: + Number of iterations of ALS. + (default: 5) + :param lambda_: + Regularization parameter. + (default: 0.01) + :param blocks: + Number of blocks used to parallelize the computation. A value + of -1 will use an auto-configured number of blocks. + (default: -1) + :param alpha: + A constant used in computing confidence. + (default: 0.01) + :param nonnegative: + A value of True will solve least-squares with nonnegativity + constraints. + (default: False) + :param seed: + Random seed for initial matrix factorization model. A value + of None will use system time as the seed. + (default: None) """ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, alpha, nonnegative, seed) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 6f00d1df209c0..3b77a6200054f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -37,10 +37,11 @@ class LabeledPoint(object): """ Class that represents the features and labels of a data point. - :param label: Label for this data point. - :param features: Vector of features for this point (NumPy array, - list, pyspark.mllib.linalg.SparseVector, or scipy.sparse - column matrix) + :param label: + Label for this data point. + :param features: + Vector of features for this point (NumPy array, list, + pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). Note: 'label' and 'features' are accessible as class attributes. @@ -66,8 +67,10 @@ class LinearModel(object): """ A linear model that has a vector of coefficients and an intercept. - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. .. versionadded:: 0.9.0 """ @@ -217,17 +220,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): class LinearRegressionWithSGD(object): """ - Train a linear regression model with no regularization using Stochastic Gradient Descent. - This solves the least squares regression formulation - f(weights) = 1/n ||A weights-y||^2^ - (which is the mean squared error). - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @@ -235,47 +229,52 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient - Descent (SGD). - This solves the least squares regression formulation - - f(weights) = 1/(2n) ||A weights - y||^2, - - which is the mean squared error. - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.0). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization (lasso), - - "l2" for using L2 regularization (ridge), - - None for no regularization - - (default: None) - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Descent (SGD). This solves the least squares regression + formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + + which is the mean squared error. Here the data matrix has n rows, + and the input RDD holds the set of rows of A, each with its + corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.0) + :param regType: + The type of regularizer used for training our model. + Supported values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization (default) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -366,54 +365,53 @@ def load(cls, sc, path): class LassoWithSGD(object): """ - Train a regression model with L1-regularization using Stochastic Gradient Descent. - This solves the l1-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L1-regularization using - Stochastic Gradient Descent. - This solves the l1-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L1-regularization using Stochastic + Gradient Descent. This solves the l1-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -504,54 +502,53 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): """ - Train a regression model with L2-regularization using Stochastic Gradient Descent. - This solves the l2-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L2-regularization using - Stochastic Gradient Descent. - This solves the l2-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L2-regularization using Stochastic + Gradient Descent. This solves the l2-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), @@ -566,12 +563,14 @@ class IsotonicRegressionModel(Saveable, Loader): """ Regression model for isotonic regression. - :param boundaries: Array of boundaries for which predictions are - known. Boundaries must be sorted in increasing order. - :param predictions: Array of predictions associated to the - boundaries at the same index. Results of isotonic - regression and therefore monotone. - :param isotonic: indicates whether this is isotonic or antitonic. + :param boundaries: + Array of boundaries for which predictions are known. Boundaries + must be sorted in increasing order. + :param predictions: + Array of predictions associated to the boundaries at the same + index. Results of isotonic regression and therefore monotone. + :param isotonic: + Indicates whether this is isotonic or antitonic. >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) @@ -622,7 +621,8 @@ def predict(self, x): values with the same boundary then the same rules as in 2) are used. - :param x: Feature or RDD of Features to be labeled. + :param x: + Feature or RDD of Features to be labeled. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -651,21 +651,23 @@ def load(cls, sc, path): class IsotonicRegression(object): """ Isotonic regression. - Currently implemented using parallelized pool adjacent violators algorithm. - Only univariate (single feature) algorithm supported. + Currently implemented using parallelized pool adjacent violators + algorithm. Only univariate (single feature) algorithm supported. Sequential PAV implementation based on: - Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + + Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. - Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] + Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf Sequential PAV parallelization based on: - Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. - "An approach to parallelizing isotonic regression." - Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. - Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] - @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] + Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + "An approach to parallelizing isotonic regression." + Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. + Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf + + See `Isotonic regression (Wikipedia) `_. .. versionadded:: 1.4.0 """ @@ -676,8 +678,11 @@ def train(cls, data, isotonic=True): """ Train a isotonic regression model on the given data. - :param data: RDD of (label, feature, weight) tuples. - :param isotonic: Whether this is isotonic or antitonic. + :param data: + RDD of (label, feature, weight) tuples. + :param isotonic: + Whether this is isotonic (which is default) or antitonic. + (default: True) """ boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", data.map(_convert_to_vector), bool(isotonic)) @@ -713,9 +718,11 @@ def _validate(self, dstream): @since("1.5.0") def predictOn(self, dstream): """ - Make predictions on a dstream. + Use the model to make predictions on batches of data from a + DStream. - :return: Transformed dstream object. + :return: + DStream containing predictions. """ self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) @@ -723,9 +730,11 @@ def predictOn(self, dstream): @since("1.5.0") def predictOnValues(self, dstream): """ - Make predictions on a keyed dstream. + Use the model to make predictions on the values of a DStream and + carry over its keys. - :return: Transformed dstream object. + :return: + DStream containing the input keys and the predictions as values. """ self._validate(dstream) return dstream.mapValues(lambda x: self._model.predict(x)) @@ -734,17 +743,28 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LinearRegression with SGD on a batch of data. - - The problem minimized is (1 / n_samples) * (y - weights'X)**2. - After training on a batch of data, the weights obtained at the end of - training are used as initial weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Total number of iterations run. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a linear regression model on streaming data. + Training uses Stochastic Gradient Descent to update the model + based on each new batch of incoming data from a DStream + (see `LinearRegressionWithSGD` for model equation). + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight vector must + be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f8e8e0e0adbea..ac55fbf79841f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -54,9 +54,11 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import HashingTF from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler, ElementwiseProduct @@ -76,21 +78,24 @@ pass ser = PickleSerializer() -sc = SparkContext('local[4]', "MLlib tests") class MLlibTestCase(unittest.TestCase): def setUp(self): - self.sc = sc + self.sc = SparkContext('local[4]', "MLlib tests") + + def tearDown(self): + self.sc.stop() class MLLibStreamingTestCase(unittest.TestCase): def setUp(self): - self.sc = sc + self.sc = SparkContext('local[4]', "MLlib tests") self.ssc = StreamingContext(self.sc, 1.0) def tearDown(self): self.ssc.stop(False) + self.sc.stop() @staticmethod def _eventually(condition, timeout=30.0, catch_assertions=False): @@ -418,6 +423,17 @@ class ListTests(MLlibTestCase): as NumPy arrays. """ + def test_bisecting_kmeans(self): + from pyspark.mllib.clustering import BisectingKMeans + data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2) + bskm = BisectingKMeans() + model = bskm.train(self.sc.parallelize(data, 2), k=4) + p = array([0.0, 0.0]) + rdd_p = self.sc.parallelize([p]) + self.assertEqual(model.predict(p), model.predict(rdd_p).first()) + self.assertEqual(model.computeCost(p), model.computeCost(rdd_p)) + self.assertEqual(model.k, len(model.clusterCenters)) + def test_kmeans(self): from pyspark.mllib.clustering import KMeans data = [ @@ -474,6 +490,18 @@ def test_gmm_deterministic(self): for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEqual(round(c1, 7), round(c2, 7)) + def test_gmm_with_initial_model(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + (-10, -5), (-9, -4), (10, 5), (9, 4) + ]) + + gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63) + gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63, initialModel=gmm1) + self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ @@ -670,7 +698,7 @@ def test_infer_schema(self): schema = df.schema field = [f for f in schema.fields if f.name == "features"][0] self.assertEqual(field.dataType, self.udt) - vectors = df.map(lambda p: p.features).collect() + vectors = df.rdd.map(lambda p: p.features).collect() self.assertEqual(len(vectors), 2) for v in vectors: if isinstance(v, SparseVector): @@ -702,7 +730,7 @@ def test_infer_schema(self): df = rdd.toDF() schema = df.schema self.assertTrue(schema.fields[1].dataType, self.udt) - matrices = df.map(lambda x: x._2).collect() + matrices = df.rdd.map(lambda x: x._2).collect() self.assertEqual(len(matrices), 2) for m in matrices: if isinstance(m, DenseMatrix): @@ -1142,7 +1170,7 @@ def test_predictOn_model(self): clusterWeights=[1.0, 1.0, 1.0, 1.0]) predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] - predict_data = [sc.parallelize(batch, 1) for batch in predict_data] + predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data] predict_stream = self.ssc.queueStream(predict_data) predict_val = stkm.predictOn(predict_stream) @@ -1162,6 +1190,7 @@ def condition(): self._eventually(condition, catch_assertions=True) + @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") def test_trainOn_predictOn(self): """Test that prediction happens on the updated model.""" stkm = StreamingKMeans(decayFactor=0.0, k=2) @@ -1173,7 +1202,7 @@ def test_trainOn_predictOn(self): # classification based in the initial model would have been 0 # proving that the model is updated. batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] - batches = [sc.parallelize(batch) for batch in batches] + batches = [self.sc.parallelize(batch) for batch in batches] input_stream = self.ssc.queueStream(batches) predict_results = [] @@ -1206,7 +1235,7 @@ def test_dim(self): self.assertEqual(len(point.features), 3) linear_data = LinearDataGenerator.generateLinearRDD( - sc=sc, nexamples=6, nfeatures=2, eps=0.1, + sc=self.sc, nexamples=6, nfeatures=2, eps=0.1, nParts=2, intercept=0.0).collect() self.assertEqual(len(linear_data), 6) for point in linear_data: @@ -1382,7 +1411,7 @@ def test_parameter_accuracy(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) input_stream = self.ssc.queueStream(batches) slr.trainOn(input_stream) @@ -1406,7 +1435,7 @@ def test_parameter_convergence(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) model_weights = [] input_stream = self.ssc.queueStream(batches) @@ -1439,7 +1468,7 @@ def test_prediction(self): 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], 100, 42 + i, 0.1) batches.append( - sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) input_stream = self.ssc.queueStream(batches) output_stream = slr.predictOnValues(input_stream) @@ -1470,7 +1499,7 @@ def test_train_prediction(self): for i in range(10): batch = LinearDataGenerator.generateLinearInput( 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) - batches.append(sc.parallelize(batch)) + batches.append(self.sc.parallelize(batch)) predict_batches = [ b.map(lambda lp: (lp.label, lp.features)) for b in batches] @@ -1539,7 +1568,39 @@ def test_load_vectors(self): shutil.rmtree(load_vectors_path) +class ALSTests(MLlibTestCase): + + def test_als_ratings_serialize(self): + r = Rating(7, 1123, 3.14) + jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r))) + nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr))) + self.assertEqual(r.user, nr.user) + self.assertEqual(r.product, nr.product) + self.assertAlmostEqual(r.rating, nr.rating, 2) + + def test_als_ratings_id_long_error(self): + r = Rating(1205640308657491975, 50233468418, 1.0) + # rating user id exceeds max int value, should fail when pickled + self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r))) + + +class HashingTFTest(MLlibTestCase): + + def test_binary_term_freqs(self): + hashingTF = HashingTF(100).setBinary(True) + doc = "a a b c c c".split(" ") + n = hashingTF.numFeatures + output = hashingTF.transform(doc).toArray() + expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, + hashingTF.indexOf("b"): 1.0, + hashingTF.indexOf("c"): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(output[i])) + + if __name__ == "__main__": + from pyspark.mllib.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 0001b60093a69..f7ea466b43291 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -60,8 +60,7 @@ def numTrees(self): @since("1.3.0") def totalNumNodes(self): """ - Get total number of nodes, summed over all trees in the - ensemble. + Get total number of nodes, summed over all trees in the ensemble. """ return self.call("totalNumNodes") @@ -92,8 +91,9 @@ def predict(self, x): transformation or action. Call predict directly on the RDD instead. - :param x: Data point (feature vector), - or an RDD of data points (feature vectors). + :param x: + Data point (feature vector), or an RDD of data points (feature + vectors). """ if isinstance(x, RDD): return self.call("predict", x.map(_convert_to_vector)) @@ -108,8 +108,9 @@ def numNodes(self): @since("1.1.0") def depth(self): - """Get depth of tree. - E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ + Get depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). """ return self._java_model.depth() @@ -152,24 +153,37 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ - Train a DecisionTreeModel for classification. - - :param data: Training data: RDD of LabeledPoint. - Labels are integers {0,1,...,numClasses}. - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: Supported values: "entropy" or "gini" - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child - nodes to create the parent split - :param minInfoGain: Min info gain required to create a split - :return: DecisionTreeModel + Train a decision tree model for classification. + + :param data: + Training data: RDD of LabeledPoint. Labels should take values + {0, 1, ..., numClasses-1}. + :param numClasses: + Number of classes for classification. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param impurity: + Criterion used for information gain calculation. + Supported values: "gini" or "entropy". + (default: "gini") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 5) + :param maxBins: + Number of bins used for finding splits at each node. + (default: 32) + :param minInstancesPerNode: + Minimum number of instances required at child nodes to create + the parent split. + (default: 1) + :param minInfoGain: + Minimum info gain required to create a split. + (default: 0.0) + :return: + DecisionTreeModel. Example usage: @@ -211,23 +225,34 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ - Train a DecisionTreeModel for regression. - - :param data: Training data: RDD of LabeledPoint. - Labels are real numbers. - :param categoricalFeaturesInfo: Map from categorical feature - index to number of categories. - Any feature not in this map is treated as continuous. - :param impurity: Supported values: "variance" - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each - node. - :param minInstancesPerNode: Min number of instances required at - child nodes to create the parent split - :param minInfoGain: Min info gain required to create a split - :return: DecisionTreeModel + Train a decision tree model for regression. + + :param data: + Training data: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param impurity: + Criterion used for information gain calculation. + The only supported value for regression is "variance". + (default: "variance") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 5) + :param maxBins: + Number of bins used for finding splits at each node. + (default: 32) + :param minInstancesPerNode: + Minimum number of instances required at child nodes to create + the parent split. + (default: 1) + :param minInfoGain: + Minimum info gain required to create a split. + (default: 0.0) + :return: + DecisionTreeModel. Example usage: @@ -302,34 +327,44 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): """ - Method to train a decision tree model for binary or multiclass + Train a random forest model for binary or multiclass classification. - :param data: Training dataset: RDD of LabeledPoint. Labels - should take values {0, 1, ..., numClasses-1}. - :param numClasses: number of classes for classification. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that - feature n is categorical with k categories indexed - from 0: {0, 1, ..., k-1}. - :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "sqrt". - :param impurity: Criterion used for information gain calculation. - Supported values: "gini" (recommended) or "entropy". - :param maxDepth: Maximum depth of the tree. - E.g., depth 0 means 1 leaf node; depth 1 means - 1 internal node + 2 leaf nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting - features - (default: 32) - :param seed: Random seed for bootstrapping and choosing feature - subsets. - :return: RandomForestModel that can be used for prediction + :param data: + Training dataset: RDD of LabeledPoint. Labels should take values + {0, 1, ..., numClasses-1}. + :param numClasses: + Number of classes for classification. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param numTrees: + Number of trees in the random forest. + :param featureSubsetStrategy: + Number of features to consider for splits at each node. + Supported values: "auto", "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + (default: "auto") + :param impurity: + Criterion used for information gain calculation. + Supported values: "gini" or "entropy". + (default: "gini") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 4) + :param maxBins: + Maximum number of bins used for splitting features. + (default: 32) + :param seed: + Random seed for bootstrapping and choosing feature subsets. + Set as None to generate seed based on system time. + (default: None) + :return: + RandomForestModel that can be used for prediction. Example usage: @@ -383,32 +418,40 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ - Method to train a decision tree model for regression. - - :param data: Training dataset: RDD of LabeledPoint. Labels are - real numbers. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param numTrees: Number of trees in the random forest. - :param featureSubsetStrategy: Number of features to consider for - splits at each node. - Supported: "auto" (default), "all", "sqrt", "log2", "onethird". - If "auto" is set, this parameter is set based on numTrees: - if numTrees == 1, set to "all"; - if numTrees > 1 (forest) set to "onethird" for regression. - :param impurity: Criterion used for information gain - calculation. - Supported values: "variance". - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. (default: 4) - :param maxBins: maximum number of bins used for splitting - features (default: 32) - :param seed: Random seed for bootstrapping and choosing feature - subsets. - :return: RandomForestModel that can be used for prediction + Train a random forest model for regression. + + :param data: + Training dataset: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param numTrees: + Number of trees in the random forest. + :param featureSubsetStrategy: + Number of features to consider for splits at each node. + Supported values: "auto", "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. + (default: "auto") + :param impurity: + Criterion used for information gain calculation. + The only supported value for regression is "variance". + (default: "variance") + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 4) + :param maxBins: + Maximum number of bins used for splitting features. + (default: 32) + :param seed: + Random seed for bootstrapping and choosing feature subsets. + Set as None to generate seed based on system time. + (default: None) + :return: + RandomForestModel that can be used for prediction. Example usage: @@ -480,31 +523,37 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): """ - Method to train a gradient-boosted trees model for - classification. - - :param data: Training dataset: RDD of LabeledPoint. - Labels should take values {0, 1}. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient - boosting. Supported: {"logLoss" (default), - "leastSquaresError", "leastAbsoluteError"}. - :param numIterations: Number of iterations of boosting. - (default: 100) - :param learningRate: Learning rate for shrinking the - contribution of each estimator. The learning rate - should be between in the interval (0, 1]. - (default: 0.1) - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. (default: 3) - :param maxBins: maximum number of bins used for splitting - features (default: 32) DecisionTree requires maxBins >= max categories - :return: GradientBoostedTreesModel that can be used for - prediction + Train a gradient-boosted trees model for classification. + + :param data: + Training dataset: RDD of LabeledPoint. Labels should take values + {0, 1}. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param loss: + Loss function used for minimization during gradient boosting. + Supported values: "logLoss", "leastSquaresError", + "leastAbsoluteError". + (default: "logLoss") + :param numIterations: + Number of iterations of boosting. + (default: 100) + :param learningRate: + Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 3) + :param maxBins: + Maximum number of bins used for splitting features. DecisionTree + requires maxBins >= max categories. + (default: 32) + :return: + GradientBoostedTreesModel that can be used for prediction. Example usage: @@ -543,30 +592,36 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): """ - Method to train a gradient-boosted trees model for regression. - - :param data: Training dataset: RDD of LabeledPoint. Labels are - real numbers. - :param categoricalFeaturesInfo: Map storing arity of categorical - features. E.g., an entry (n -> k) indicates that feature - n is categorical with k categories indexed from 0: - {0, 1, ..., k-1}. - :param loss: Loss function used for minimization during gradient - boosting. Supported: {"logLoss" (default), - "leastSquaresError", "leastAbsoluteError"}. - :param numIterations: Number of iterations of boosting. - (default: 100) - :param learningRate: Learning rate for shrinking the - contribution of each estimator. The learning rate - should be between in the interval (0, 1]. - (default: 0.1) - :param maxBins: maximum number of bins used for splitting - features (default: 32) DecisionTree requires maxBins >= max categories - :param maxDepth: Maximum depth of the tree. E.g., depth 0 means - 1 leaf node; depth 1 means 1 internal node + 2 leaf - nodes. (default: 3) - :return: GradientBoostedTreesModel that can be used for - prediction + Train a gradient-boosted trees model for regression. + + :param data: + Training dataset: RDD of LabeledPoint. Labels are real numbers. + :param categoricalFeaturesInfo: + Map storing arity of categorical features. An entry (n -> k) + indicates that feature n is categorical with k categories + indexed from 0: {0, 1, ..., k-1}. + :param loss: + Loss function used for minimization during gradient boosting. + Supported values: "logLoss", "leastSquaresError", + "leastAbsoluteError". + (default: "leastSquaresError") + :param numIterations: + Number of iterations of boosting. + (default: 100) + :param learningRate: + Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1]. + (default: 0.1) + :param maxDepth: + Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 + means 1 internal node + 2 leaf nodes). + (default: 3) + :param maxBins: + Maximum number of bins used for splitting features. DecisionTree + requires maxBins >= max categories. + (default: 32) + :return: + GradientBoostedTreesModel that can be used for prediction. Example usage: diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 56e892243c79c..8978f028c5928 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -115,7 +115,7 @@ def _parse_memory(s): 2048 """ units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024} - if s[-1] not in units: + if s[-1].lower() not in units: raise ValueError("invalid format: " + s) return int(float(s[:-1]) * units[s[-1].lower()]) @@ -220,18 +220,18 @@ def context(self): def cache(self): """ - Persist this RDD with the default storage level (C{MEMORY_ONLY_SER}). + Persist this RDD with the default storage level (C{MEMORY_ONLY}). """ self.is_cached = True - self.persist(StorageLevel.MEMORY_ONLY_SER) + self.persist(StorageLevel.MEMORY_ONLY) return self - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): """ Set this RDD's storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + If no storage level is specified defaults to (C{MEMORY_ONLY}). >>> rdd = sc.parallelize(["b", "a", "c"]) >>> rdd.persist().is_cached @@ -426,6 +426,9 @@ def takeSample(self, withReplacement, num, seed=None): """ Return a fixed-size sampled subset of this RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + >>> rdd = sc.parallelize(range(0, 10)) >>> len(rdd.takeSample(True, 20, 1)) 20 @@ -766,6 +769,8 @@ def func(it): def collect(self): """ Return a list that contains all of the elements in this RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) @@ -839,8 +844,7 @@ def op(x, y): def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all - the partitions, using a given associative and commutative function and - a neutral "zero value." + the partitions, using a given associative function and a neutral "zero value." The function C{op(t1, t2)} is allowed to modify C{t1} and return it as its result value to avoid object allocation; however, it should not @@ -861,7 +865,7 @@ def fold(self, zeroValue, op): def func(iterator): acc = zeroValue for obj in iterator: - acc = op(obj, acc) + acc = op(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided @@ -1213,6 +1217,9 @@ def top(self, num, key=None): """ Get the top N elements from a RDD. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + Note: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) @@ -1235,6 +1242,9 @@ def takeOrdered(self, num, key=None): Get the N elements from a RDD ordered in ascending order or as specified by the optional key function. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x) @@ -1254,6 +1264,9 @@ def take(self, num): that partition to estimate the number of additional partitions needed to satisfy the limit. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + Translated from the Scala implementation in RDD#take(). >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) @@ -1511,6 +1524,9 @@ def collectAsMap(self): """ Return the key-value pairs in this RDD to the master as a dictionary. + Note that this method should only be used if the resulting data is expected + to be small, as all the data is loaded into the driver's memory. + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() >>> m[1] 2 @@ -1541,7 +1557,7 @@ def values(self): def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ - Merge the values for each key using an associative reduce function. + Merge the values for each key using an associative and commutative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. @@ -1559,7 +1575,7 @@ def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): def reduceByKeyLocally(self, func): """ - Merge the values for each key using an associative reduce function, but + Merge the values for each key using an associative and commutative reduce function, but return the results immediately to the master as a dictionary. This will also perform the merging locally on each mapper before @@ -1760,7 +1776,6 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, In addition, users can control the partitioning of the output RDD. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> def f(x): return x >>> def add(a, b): return a + str(b) >>> sorted(x.combineByKey(str, add, add).collect()) [('a', '11'), ('b', '1')] @@ -2016,7 +2031,7 @@ def coalesce(self, numPartitions, shuffle=False): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ - jrdd = self._jrdd.coalesce(numPartitions) + jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) def zip(self, other): @@ -2284,17 +2299,17 @@ def toLocalIterator(self): """ Return an iterator that contains all of the elements in this RDD. The iterator will consume as much memory as the largest partition in this RDD. + >>> rdd = sc.parallelize(range(10)) >>> [x for x in rdd.toLocalIterator()] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ - for partition in range(self.getNumPartitions()): - rows = self.context.runJob(self, lambda x: x, [partition]) - for row in rows: - yield row + with SCCallSiteSync(self.context) as css: + port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + return _load_from_socket(port, self._jrdd_deserializer) -def _prepare_for_python_RDD(sc, command, obj=None): +def _prepare_for_python_RDD(sc, command): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) @@ -2314,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None): return pickled_command, broadcast_vars, env, includes +def _wrap_function(sc, func, deserializer, serializer, profiler=None): + assert deserializer, "deserializer should not be empty" + assert serializer, "serializer should not be empty" + command = (func, profiler, deserializer, serializer) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class PipelinedRDD(RDD): """ @@ -2375,14 +2399,10 @@ def _jrdd(self): else: profiler = None - command = (self.func, profiler, self._prev_jrdd_deserializer, - self._jrdd_deserializer) - pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self) - python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_cmd), - env, includes, self.preservesPartitioning, - self.ctx.pythonExec, self.ctx.pythonVer, - bvars, self.ctx._javaAccumulator) + wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer, + self._jrdd_deserializer, profiler) + python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, + self.preservesPartitioning) self._jrdd_val = python_rdd.asJavaRDD() if profiler: diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 99331297c19f0..7c37f75193473 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -32,15 +32,10 @@ from pyspark.sql import SQLContext, HiveContext from pyspark.storagelevel import StorageLevel -# this is the deprecated equivalent of ADD_JARS -add_files = None -if os.environ.get("ADD_FILES") is not None: - add_files = os.environ.get("ADD_FILES").split(',') - if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) -sc = SparkContext(pyFiles=add_files) +sc = SparkContext() atexit.register(lambda: sc.stop()) try: @@ -68,12 +63,10 @@ platform.python_build()[1])) print("SparkContext available as sc, %s available as sqlContext." % sqlContext.__class__.__name__) -if add_files is not None: - print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead") - print("Adding files: [%s]" % ", ".join(add_files)) - # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + with open(_pythonstartup) as f: + code = compile(f.read(), _pythonstartup, 'exec') + exec(code) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 98eaf52866d23..0b06c8339f501 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -47,7 +47,7 @@ from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column -from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions +from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions from pyspark.sql.group import GroupedData from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter from pyspark.sql.window import Window, WindowSpec diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 9ca8e1f264cfa..43e9baece2de9 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -22,13 +22,12 @@ basestring = str long = int -from pyspark import since +from pyspark import copy_func, since from pyspark.context import SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.types import * -__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", - "DataFrameStatFunctions"] +__all__ = ["DataFrame", "Column", "DataFrameNaFunctions", "DataFrameStatFunctions"] def _create_column_from_literal(literal): @@ -220,17 +219,17 @@ def getField(self, name): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() >>> df.select(df.r.getField("b")).show() - +----+ - |r[b]| - +----+ - | b| - +----+ + +---+ + |r.b| + +---+ + | b| + +---+ >>> df.select(df.r.a).show() - +----+ - |r[a]| - +----+ - | 1| - +----+ + +---+ + |r.a| + +---+ + | 1| + +---+ """ return self[name] @@ -272,23 +271,6 @@ def substr(self, startPos, length): __getslice__ = substr - @ignore_unicode_prefix - @since(1.3) - def inSet(self, *cols): - """ - A boolean expression that is evaluated to true if the value of this - expression is contained by the evaluated values of the arguments. - - >>> df[df.name.inSet("Bob", "Mike")].collect() - [Row(age=5, name=u'Bob')] - >>> df[df.age.inSet([1, 2, 3])].collect() - [Row(age=2, name=u'Alice')] - - .. note:: Deprecated in 1.5, use :func:`Column.isin` instead. - """ - warnings.warn("inSet is deprecated. Use isin() instead.") - return self.isin(*cols) - @ignore_unicode_prefix @since(1.5) def isin(self, *cols): @@ -333,6 +315,8 @@ def alias(self, *alias): sc = SparkContext._active_spark_context return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) + name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.") + @ignore_unicode_prefix @since(1.3) def cast(self, dataType): @@ -346,15 +330,16 @@ def cast(self, dataType): if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) + from pyspark.sql import SQLContext + sc = SparkContext.getOrCreate() + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) else: raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) - astype = cast + astype = copy_func(cast, sinceversion=1.4, doc=":func:`astype` is an alias for :func:`cast`.") @since(1.3) def between(self, lowerBound, upperBound): @@ -363,12 +348,12 @@ def between(self, lowerBound, upperBound): expression is between the given columns. >>> df.select(df.name, df.age.between(2, 4)).show() - +-----+--------------------------+ - | name|((age >= 2) && (age <= 4))| - +-----+--------------------------+ - |Alice| true| - | Bob| false| - +-----+--------------------------+ + +-----+---------------------------+ + | name|((age >= 2) AND (age <= 4))| + +-----+---------------------------+ + |Alice| true| + | Bob| false| + +-----+---------------------------+ """ return (self >= lowerBound) & (self <= upperBound) @@ -385,12 +370,12 @@ def when(self, condition, value): >>> from pyspark.sql import functions as F >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() - +-----+--------------------------------------------------------+ - | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| - +-----+--------------------------------------------------------+ - |Alice| -1| - | Bob| 1| - +-----+--------------------------------------------------------+ + +-----+------------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END| + +-----+------------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+------------------------------------------------------------+ """ if not isinstance(condition, Column): raise TypeError("condition should be a Column") @@ -410,12 +395,12 @@ def otherwise(self, value): >>> from pyspark.sql import functions as F >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() - +-----+---------------------------------+ - | name|CASE WHEN (age > 3) THEN 1 ELSE 0| - +-----+---------------------------------+ - |Alice| 0| - | Bob| 1| - +-----+---------------------------------+ + +-----+-------------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END| + +-----+-------------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+-------------------------------------+ """ v = value._jc if isinstance(value, Column) else value jc = self._jc.otherwise(v) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 924bb6433de0e..11dfcfe13ee0d 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -15,9 +15,11 @@ # limitations under the License. # +from __future__ import print_function import sys import warnings import json +from functools import reduce if sys.version >= '3': basestring = unicode = str @@ -27,10 +29,10 @@ from py4j.protocol import Py4JError from pyspark import since -from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer -from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ - _infer_schema, _has_nulltype, _merge_type, _create_converter +from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \ + _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.utils import install_exception_handler @@ -90,9 +92,9 @@ def __init__(self, sparkContext, sqlContext=None): >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() - [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \ - time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] - >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() + [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \ + dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ self._sc = sparkContext @@ -195,29 +197,30 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. + """Registers a python function (including lambda function) as a UDF + so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param samplingRatio: lambda function + :param f: python function :param returnType: a :class:`DataType` object >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(_c0=u'4')] + [Row(stringLengthString(test)=u'4')] >>> from pyspark.sql.types import IntegerType >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(_c0=4)] + [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(_c0=4)] + [Row(stringLengthInt(test)=4)] """ udf = UserDefinedFunction(f, returnType, name) self._ssql_ctx.udf().registerPython(name, udf._judf) @@ -235,14 +238,9 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) + schema = reduce(_merge_type, map(_infer_schema, data)) if _has_nulltype(schema): - for r in data: - schema = _merge_type(schema, _infer_schema(r)) - if not _has_nulltype(schema): - break - else: - raise ValueError("Some of types cannot be determined after inferring") + raise ValueError("Some of types cannot be determined after inferring") return schema def _inferSchema(self, rdd, samplingRatio=None): @@ -277,33 +275,6 @@ def _inferSchema(self, rdd, samplingRatio=None): schema = rdd.map(_infer_schema).reduce(_merge_type) return schema - @ignore_unicode_prefix - def inferSchema(self, rdd, samplingRatio=None): - """ - .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. - """ - warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - return self.createDataFrame(rdd, None, samplingRatio) - - @ignore_unicode_prefix - def applySchema(self, rdd, schema): - """ - .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. - """ - warnings.warn("applySchema is deprecated, please use createDataFrame instead") - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType, but got %s" % type(schema)) - - return self.createDataFrame(rdd, schema) - def _createFromRDD(self, rdd, schema, samplingRatio): """ Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. @@ -330,11 +301,6 @@ def _createFromLocal(self, data, schema): Create an RDD for DataFrame from an list or pandas.DataFrame, returns the RDD and schema. """ - if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] - # make sure data could consumed multiple times if not isinstance(data, list): data = list(data) @@ -362,8 +328,7 @@ def _createFromLocal(self, data, schema): @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): """ - Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, - list or :class:`pandas.DataFrame`. + Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. When ``schema`` is a list of column names, the type of each column will be inferred from ``data``. @@ -372,15 +337,29 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): from ``data``, which should be an RDD of :class:`Row`, or :class:`namedtuple`, or :class:`dict`. + When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or + exception will be thrown at runtime. If the given schema is not StructType, it will be + wrapped into a StructType as its only field, and the field name will be "value", each record + will also be wrapped into a tuple, which can be converted to row later. + If schema inference is needed, ``samplingRatio`` is used to determined the ratio of rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. - :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`, - :class:`list`, or :class:`pandas.DataFrame`. - :param schema: a :class:`StructType` or list of column names. default None. + :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean, + etc.), or :class:`list`, or :class:`pandas.DataFrame`. + :param schema: a :class:`DataType` or a datatype string or a list of column names, default + is None. The data type string format equals to `DataType.simpleString`, except that + top level struct type can omit the `struct<>` and atomic types use `typeName()` as + their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int` + as a short name for IntegerType. :param samplingRatio: the sample ratio of rows used for inferring :return: :class:`DataFrame` + .. versionchanged:: 2.0 + The schema parameter can be a DataType or a datatype string after 2.0. If it's not a + StructType, it will be wrapped into a StructType and each record will also be wrapped + into a tuple. + >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] @@ -415,16 +394,48 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] - >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]]).collect()) # doctest: +SKIP + >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP [Row(0=1, 1=2)] + + >>> sqlContext.createDataFrame(rdd, "a: string, b: int").collect() + [Row(a=u'Alice', b=1)] + >>> rdd = rdd.map(lambda row: row[1]) + >>> sqlContext.createDataFrame(rdd, "int").collect() + [Row(value=1)] + >>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + Py4JJavaError: ... """ if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") + if isinstance(schema, basestring): + schema = _parse_datatype_string(schema) + + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = [str(x) for x in data.columns] + data = [r.tolist() for r in data.to_records(index=False)] + + if isinstance(schema, StructType): + def prepare(obj): + _verify_type(obj, schema) + return obj + elif isinstance(schema, DataType): + datatype = schema + + def prepare(obj): + _verify_type(obj, datatype) + return (obj, ) + schema = StructType().add("value", datatype) + else: + prepare = lambda obj: obj + if isinstance(data, RDD): - rdd, schema = self._createFromRDD(data, schema, samplingRatio) + rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio) else: - rdd, schema = self._createFromLocal(data, schema) + rdd, schema = self._createFromLocal(map(prepare, data), schema) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) df = DataFrame(jdf, self) @@ -444,89 +455,14 @@ def registerDataFrameAsTable(self, df, tableName): else: raise ValueError("Can only register DataFrame as table") - def parquetFile(self, *paths): - """Loads a Parquet file, returning the result as a :class:`DataFrame`. - - .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. - - >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes - [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] - """ - warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") - gateway = self._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] - jdf = self._ssql_ctx.parquetFile(jpaths) - return DataFrame(jdf, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - - .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. - - >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes - [('age', 'bigint'), ('name', 'string')] - """ - warnings.warn("jsonFile is deprecated. Use read.json() instead.") - if schema is None: - df = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonFile(path, scala_datatype) - return DataFrame(df, self) - - @ignore_unicode_prefix - @since(1.0) - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. - - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - - >>> df1 = sqlContext.jsonRDD(json) - >>> df1.first() - Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - - >>> df2 = sqlContext.jsonRDD(json, df1.schema) - >>> df2.first() - Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))])) - ... ]) - >>> df3 = sqlContext.jsonRDD(json, schema) - >>> df3.first() - Row(field2=u'row1', field3=Row(field5=None)) - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return DataFrame(df, self) - - def load(self, path=None, source=None, schema=None, **options): - """Returns the dataset in a data source as a :class:`DataFrame`. + @since(1.6) + def dropTempTable(self, tableName): + """ Remove the temp table from catalog. - .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> sqlContext.dropTempTable("table1") """ - warnings.warn("load is deprecated. Use read.load() instead.") - return self.read.load(path, source, schema, **options) + self._ssql_ctx.dropTempTable(tableName) @since(1.3) def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): @@ -618,7 +554,7 @@ def tableNames(self, dbName=None): >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() True - >>> "table1" in sqlContext.tableNames("db") + >>> "table1" in sqlContext.tableNames("default") True """ if dbName is None: @@ -676,9 +612,10 @@ def _ssql_ctx(self): self._scala_HiveContext = self._get_hive_ctx() return self._scala_HiveContext except Py4JError as e: - raise Exception("You must build Spark with Hive. " - "Export 'SPARK_HIVE=true' and run " - "build/sbt assembly", e) + print("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " + "build/sbt assembly", file=sys.stderr) + raise def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3baff8147753d..b4fa8368936a4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -26,7 +26,7 @@ else: from itertools import imap as map -from pyspark import since +from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel @@ -36,7 +36,7 @@ from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * -__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"] +__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] class DataFrame(object): @@ -60,7 +60,7 @@ class DataFrame(object): people = sqlContext.read.parquet("...") department = sqlContext.read.parquet("...") - people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + people.filter(people.age > 30).join(department, people.deptId == department.id)\ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) .. note:: Experimental @@ -113,14 +113,6 @@ def toJSON(self, use_unicode=True): rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - def saveAsParquetFile(self, path): - """Saves the contents as a Parquet file, preserving the schema. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. - """ - warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") - self._jdf.saveAsParquetFile(path) - @since(1.3) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. @@ -135,38 +127,6 @@ def registerTempTable(self, name): """ self._jdf.registerTempTable(name) - def registerAsTable(self, name): - """ - .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. - """ - warnings.warn("Use registerTempTable instead of registerAsTable.") - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this :class:`DataFrame` into the specified table. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. - """ - warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") - self.write.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName, source=None, mode="error", **options): - """Saves the contents of this :class:`DataFrame` to a data source as a table. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. - """ - warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.") - self.write.saveAsTable(tableName, source, mode, **options) - - @since(1.3) - def save(self, path=None, source=None, mode="error", **options): - """Saves the contents of the :class:`DataFrame` to a data source. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. - """ - warnings.warn("insertInto is deprecated. Use write.save() instead.") - return self.write.save(path, source, mode, **options) - @property @since(1.4) def write(self): @@ -213,7 +173,7 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan PhysicalRDD[age#0,name#1] + Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == @@ -277,9 +237,23 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + @ignore_unicode_prefix + @since(2.0) + def toLocalIterator(self): + """ + Returns an iterator that contains all of the rows in this :class:`DataFrame`. + The iterator will consume as much memory as the largest partition in this DataFrame. + + >>> list(df.toLocalIterator()) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.toPythonIterator() + return _load_from_socket(port, BatchedSerializer(PickleSerializer())) + @ignore_unicode_prefix @since(1.3) def limit(self, num): @@ -302,48 +276,10 @@ def take(self, num): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.org.apache.spark.sql.execution.EvaluatePython.takeAndServe( + port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe( self._jdf, num) return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) - @ignore_unicode_prefix - @since(1.3) - def map(self, f): - """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. - - This is a shorthand for ``df.rdd.map()``. - - >>> df.map(lambda p: p.name).collect() - [u'Alice', u'Bob'] - """ - return self.rdd.map(f) - - @ignore_unicode_prefix - @since(1.3) - def flatMap(self, f): - """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, - and then flattening the results. - - This is a shorthand for ``df.rdd.flatMap()``. - - >>> df.flatMap(lambda p: p.name).collect() - [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b'] - """ - return self.rdd.flatMap(f) - - @since(1.3) - def mapPartitions(self, f, preservesPartitioning=False): - """Returns a new :class:`RDD` by applying the ``f`` function to each partition. - - This is a shorthand for ``df.rdd.mapPartitions()``. - - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(iterator): yield 1 - >>> rdd.mapPartitions(f).sum() - 4 - """ - return self.rdd.mapPartitions(f, preservesPartitioning) - @since(1.3) def foreach(self, f): """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. @@ -354,7 +290,7 @@ def foreach(self, f): ... print(person.name) >>> df.foreach(f) """ - return self.rdd.foreach(f) + self.rdd.foreach(f) @since(1.3) def foreachPartition(self, f): @@ -367,22 +303,22 @@ def foreachPartition(self, f): ... print(person.name) >>> df.foreachPartition(f) """ - return self.rdd.foreachPartition(f) + self.rdd.foreachPartition(f) @since(1.3) def cache(self): - """ Persists with the default storage level (C{MEMORY_ONLY_SER}). + """ Persists with the default storage level (C{MEMORY_ONLY}). """ self.is_cached = True self._jdf.cache() return self @since(1.3) - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + If no storage level is specified defaults to (C{MEMORY_ONLY}). """ self.is_cached = True javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) @@ -422,6 +358,67 @@ def repartition(self, numPartitions): """ return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + @since(1.3) + def repartition(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is hash partitioned. + + ``numPartitions`` can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + .. versionchanged:: 1.6 + Added optional arguments to specify the partitioning columns. Also made numPartitions + optional if partitioning columns are specified. + + >>> df.repartition(10).rdd.getNumPartitions() + 10 + >>> data = df.union(df).repartition("age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + >>> data = data.repartition(7, "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + >>> data.rdd.getNumPartitions() + 7 + >>> data = data.repartition("name", "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + else: + return DataFrame( + self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions, ) + cols + return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -436,7 +433,7 @@ def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 42).count() - 1 + 2 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction seed = seed if seed is not None else random.randint(0, sys.maxsize) @@ -463,8 +460,8 @@ def sampleBy(self, col, fractions, seed=None): +---+-----+ |key|count| +---+-----+ - | 0| 3| - | 1| 8| + | 0| 5| + | 1| 9| +---+-----+ """ @@ -530,7 +527,7 @@ def alias(self, alias): >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') - >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() + >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").collect() [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] """ assert isinstance(alias, basestring), "alias should be a string" @@ -547,16 +544,19 @@ def join(self, other, on=None, how=None): :param on: a string for join column name, a list of column names, , a join expression (Column) or a list of Columns. If `on` is a string or a list of string indicating the name of the join column(s), - the column(s) must exist on both sides, and this performs an inner equi-join. + the column(s) must exist on both sides, and this performs an equi-join. :param how: str, default 'inner'. - One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + + >>> df.join(df2, 'name', 'outer').select('name', 'height').collect() + [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> cond = [df.name == df3.name, df.age == df3.age] >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() - [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] @@ -589,6 +589,26 @@ def join(self, other, on=None, how=None): jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) + @since(1.6) + def sortWithinPartitions(self, *cols, **kwargs): + """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s). + + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. + + >>> df.sortWithinPartitions("age", ascending=False).show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix @since(1.3) def sort(self, *cols, **kwargs): @@ -613,22 +633,7 @@ def sort(self, *cols, **kwargs): >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ - if not cols: - raise ValueError("should sort by at least one column") - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - jcols = [_to_java_column(c) for c in cols] - ascending = kwargs.get('ascending', True) - if isinstance(ascending, (bool, int)): - if not ascending: - jcols = [jc.desc() for jc in jcols] - elif isinstance(ascending, list): - jcols = [jc if asc else jc.desc() - for asc, jc in zip(ascending, jcols)] - else: - raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) - - jdf = self._jdf.sort(self._jseq(jcols)) + jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx) orderBy = sort @@ -650,6 +655,25 @@ def _jcols(self, *cols): cols = cols[0] return self._jseq(cols, _to_java_column) + def _sort_cols(self, cols, kwargs): + """ Return a JVM Seq of Columns that describes the sort order + """ + if not cols: + raise ValueError("should sort by at least one column") + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) + return self._jseq(jcols) + @since("1.3.1") def describe(self, *cols): """Computes statistics for numeric columns. @@ -691,6 +715,9 @@ def describe(self, *cols): def head(self, n=None): """Returns the first ``n`` rows. + Note that this method should only be used if the resulting array is expected + to be small, as all the data is loaded into the driver's memory. + :param n: int, default 1. Number of rows to return. :return: If n is greater than 1, return a list of :class:`Row`. If n is 1, return a single Row. @@ -781,7 +808,7 @@ def selectExpr(self, *expr): This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] + [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] @@ -816,8 +843,6 @@ def filter(self, condition): raise TypeError("condition should be string or Column") return DataFrame(jdf, self.sql_ctx) - where = filter - @ignore_unicode_prefix @since(1.3) def groupBy(self, *cols): @@ -832,12 +857,12 @@ def groupBy(self, *cols): >>> df.groupBy().avg().collect() [Row(avg(age)=3.5)] - >>> df.groupBy('name').agg({'age': 'mean'}).collect() + >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] - >>> df.groupBy(df.name).avg().collect() + >>> sorted(df.groupBy(df.name).avg().collect()) [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] - >>> df.groupBy(['name', df.age]).count().collect() - [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] + >>> sorted(df.groupBy(['name', df.age]).count().collect()) + [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)] """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData @@ -849,15 +874,15 @@ def rollup(self, *cols): Create a multi-dimensional rollup for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. - >>> df.rollup('name', df.age).count().show() + >>> df.rollup("name", df.age).count().orderBy("name", "age").show() +-----+----+-----+ | name| age|count| +-----+----+-----+ - |Alice|null| 1| - | Bob| 5| 1| - | Bob|null| 1| | null|null| 2| + |Alice|null| 1| |Alice| 2| 1| + | Bob|null| 1| + | Bob| 5| 1| +-----+----+-----+ """ jgd = self._jdf.rollup(self._jcols(*cols)) @@ -870,17 +895,17 @@ def cube(self, *cols): Create a multi-dimensional cube for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. - >>> df.cube('name', df.age).count().show() + >>> df.cube("name", df.age).count().orderBy("name", "age").show() +-----+----+-----+ | name| age|count| +-----+----+-----+ + | null|null| 2| | null| 2| 1| - |Alice|null| 1| - | Bob| 5| 1| - | Bob|null| 1| | null| 5| 1| - | null|null| 2| + |Alice|null| 1| |Alice| 2| 1| + | Bob|null| 1| + | Bob| 5| 1| +-----+----+-----+ """ jgd = self._jdf.cube(self._jcols(*cols)) @@ -900,14 +925,24 @@ def agg(self, *exprs): """ return self.groupBy().agg(*exprs) + @since(2.0) + def union(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union + (that does deduplication of elements), use this function followed by a distinct. + """ + return DataFrame(self._jdf.union(other._jdf), self.sql_ctx) + @since(1.3) def unionAll(self, other): """ Return a new :class:`DataFrame` containing union of rows in this frame and another frame. - This is equivalent to `UNION ALL` in SQL. + .. note:: Deprecated in 2.0, use union instead. """ - return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + return self.union(other) @since(1.3) def intersect(self, other): @@ -1126,6 +1161,55 @@ def replace(self, to_replace, value, subset=None): return DataFrame( self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + @since(2.0) + def approxQuantile(self, col, probabilities, relativeError): + """ + Calculates the approximate quantiles of a numerical column of a + DataFrame. + + The result of this algorithm has the following deterministic bound: + If the DataFrame has N elements and if we request the quantile at + probability `p` up to error `err`, then the algorithm will return + a sample `x` from the DataFrame so that the *exact* rank of `x` is + close to (p * N). More precisely, + + floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + + This method implements a variation of the Greenwald-Khanna + algorithm (with some speed optimizations). The algorithm was first + present in [[http://dx.doi.org/10.1145/375663.375670 + Space-efficient Online Computation of Quantile Summaries]] + by Greenwald and Khanna. + + :param col: the name of the numerical column + :param probabilities: a list of quantile probabilities + Each number must belong to [0, 1]. + For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + :param relativeError: The relative target precision to achieve + (>= 0). If set to zero, the exact quantiles are computed, which + could be very expensive. Note that values greater than 1 are + accepted but give the same result as 1. + :return: the approximate quantiles at the given probabilities + """ + if not isinstance(col, str): + raise ValueError("col should be a string.") + + if not isinstance(probabilities, (list, tuple)): + raise ValueError("probabilities should be a list or tuple") + if isinstance(probabilities, tuple): + probabilities = list(probabilities) + for p in probabilities: + if not isinstance(p, (float, int, long)) or p < 0 or p > 1: + raise ValueError("probabilities should be numerical (float, int, long) in [0,1].") + probabilities = _to_list(self._sc, probabilities) + + if not isinstance(relativeError, (float, int, long)) or relativeError < 0: + raise ValueError("relativeError should be numerical (float, int, long) >= 0.") + relativeError = float(relativeError) + + jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError) + return list(jaq) + @since(1.4) def corr(self, col1, col2, method=None): """ @@ -1282,6 +1366,9 @@ def toDF(self, *cols): def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + Note that this method should only be used if the resulting Pandas's DataFrame is expected + to be small, as all the data is loaded into the driver's memory. + This is only available if Pandas is installed and available. >>> df.toPandas() # doctest: +SKIP @@ -1296,14 +1383,20 @@ def toPandas(self): # Pandas compatibility ########################################################################################## - groupby = groupBy - drop_duplicates = dropDuplicates + groupby = copy_func( + groupBy, + sinceversion=1.4, + doc=":func:`groupby` is an alias for :func:`groupBy`.") + drop_duplicates = copy_func( + dropDuplicates, + sinceversion=1.4, + doc=":func:`drop_duplicates` is an alias for :func:`dropDuplicates`.") -# Having SchemaRDD for backward compatibility (for docs) -class SchemaRDD(DataFrame): - """SchemaRDD is deprecated, please use :class:`DataFrame`. - """ + where = copy_func( + filter, + sinceversion=1.3, + doc=":func:`where` is an alias for :func:`filter`.") def _to_scala_map(sc, jm): @@ -1347,6 +1440,11 @@ class DataFrameStatFunctions(object): def __init__(self, df): self.df = df + def approxQuantile(self, col, probabilities, relativeError): + return self.df.approxQuantile(col, probabilities, relativeError) + + approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__ + def corr(self, col1, col2, method=None): return self.df.corr(col1, col2, method) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2f7c2f4aacd47..5017ab5b3646d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -81,8 +81,6 @@ def _(): 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', 'count': 'Aggregate function: returns the number of items in a group.', 'sum': 'Aggregate function: returns the sum of all values in the expression.', 'avg': 'Aggregate function: returns the average of the values in a group.', @@ -124,41 +122,40 @@ def _(): _functions_1_6 = { # unary math functions - "stddev": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_samp": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_pop": "Aggregate function: returns population standard deviation of" + - " the expression in a group.", - "variance": "Aggregate function: returns the population variance of the values in a group.", - "var_samp": "Aggregate function: returns the unbiased variance of the values in a group.", - "var_pop": "Aggregate function: returns the population variance of the values in a group.", - "skewness": "Aggregate function: returns the skewness of the values in a group.", - "kurtosis": "Aggregate function: returns the kurtosis of the values in a group." + 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_pop': 'Aggregate function: returns population standard deviation of' + + ' the expression in a group.', + 'variance': 'Aggregate function: returns the population variance of the values in a group.', + 'var_samp': 'Aggregate function: returns the unbiased variance of the values in a group.', + 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', + 'skewness': 'Aggregate function: returns the skewness of the values in a group.', + 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', + 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', + 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + + ' eliminated.' } # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + 'polar coordinates (r, theta).', - 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', + 'hypot': 'Computes `sqrt(a^2 + b^2)` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } _window_functions = { - 'rowNumber': - """returns a sequential number starting at 1 within a window partition. - - This is equivalent to the ROW_NUMBER function in SQL.""", - 'denseRank': + 'row_number': + """returns a sequential number starting at 1 within a window partition.""", + 'dense_rank': """returns the rank of rows within a window partition, without any gaps. The difference between rank and denseRank is that denseRank leaves no gaps in ranking sequence when there are ties. That is, if you were ranking a competition using denseRank and had three people tie for second place, you would say that all three were in second - place and that the next person came in third. - - This is equivalent to the DENSE_RANK function in SQL.""", + place and that the next person came in third.""", 'rank': """returns the rank of rows within a window partition. @@ -168,15 +165,11 @@ def _(): place and that the next person came in third. This is equivalent to the RANK function in SQL.""", - 'cumeDist': + 'cume_dist': """returns the cumulative distribution of values within a window partition, - i.e. the fraction of rows that are below the current row. - - This is equivalent to the CUME_DIST function in SQL.""", - 'percentRank': - """returns the relative rank (i.e. percentile) of rows within a window partition. - - This is equivalent to the PERCENT_RANK function in SQL.""", + i.e. the fraction of rows that are below the current row.""", + 'percent_rank': + """returns the relative rank (i.e. percentile) of rows within a window partition.""", } for _name, _doc in _functions.items(): @@ -186,7 +179,7 @@ def _(): for _name, _doc in _binary_mathfunctions.items(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) for _name, _doc in _window_functions.items(): - globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) + globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) for _name, _doc in _functions_1_6.items(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) del _name, _doc @@ -230,28 +223,73 @@ def coalesce(*cols): +----+----+ >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show() - +-------------+ - |coalesce(a,b)| - +-------------+ - | null| - | 1| - | 2| - +-------------+ + +--------------+ + |coalesce(a, b)| + +--------------+ + | null| + | 1| + | 2| + +--------------+ >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() - +----+----+---------------+ - | a| b|coalesce(a,0.0)| - +----+----+---------------+ - |null|null| 0.0| - | 1|null| 1.0| - |null| 2| 0.0| - +----+----+---------------+ + +----+----+----------------+ + | a| b|coalesce(a, 0.0)| + +----+----+----------------+ + |null|null| 0.0| + | 1|null| 1.0| + |null| 2| 0.0| + +----+----+----------------+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.coalesce(_to_seq(sc, cols, _to_java_column)) return Column(jc) +@since(1.6) +def corr(col1, col2): + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` + and ``col2``. + + >>> a = range(20) + >>> b = [2 * x for x in range(20)] + >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df.agg(corr("a", "b").alias('c')).collect() + [Row(c=1.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.corr(_to_java_column(col1), _to_java_column(col2))) + + +@since(2.0) +def covar_pop(col1, col2): + """Returns a new :class:`Column` for the population covariance of ``col1`` + and ``col2``. + + >>> a = [1] * 10 + >>> b = [1] * 10 + >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df.agg(covar_pop("a", "b").alias('c')).collect() + [Row(c=0.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.covar_pop(_to_java_column(col1), _to_java_column(col2))) + + +@since(2.0) +def covar_samp(col1, col2): + """Returns a new :class:`Column` for the sample covariance of ``col1`` + and ``col2``. + + >>> a = [1] * 10 + >>> b = [1] * 10 + >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df.agg(covar_samp("a", "b").alias('c')).collect() + [Row(c=0.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.covar_samp(_to_java_column(col1), _to_java_column(col2))) + + @since(1.3) def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. @@ -267,8 +305,108 @@ def countDistinct(col, *cols): return Column(jc) -@since(1.4) -def monotonicallyIncreasingId(): +@since(1.3) +def first(col, ignorenulls=False): + """Aggregate function: returns the first value in a group. + + The function by default returns the first values it sees. It will return the first non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) + return Column(jc) + + +@since(2.0) +def grouping(col): + """ + Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + or not, returns 1 for aggregated or 0 for not aggregated in the result set. + + >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show() + +-----+--------------+--------+ + | name|grouping(name)|sum(age)| + +-----+--------------+--------+ + | null| 1| 7| + |Alice| 0| 2| + | Bob| 0| 5| + +-----+--------------+--------+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.grouping(_to_java_column(col)) + return Column(jc) + + +@since(2.0) +def grouping_id(*cols): + """ + Aggregate function: returns the level of grouping, equals to + + (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + + Note: the list of columns should match with grouping columns exactly, or empty (means all the + grouping columns). + + >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() + +-----+-------------+--------+ + | name|grouping_id()|sum(age)| + +-----+-------------+--------+ + | null| 1| 7| + |Alice| 0| 2| + | Bob| 0| 5| + +-----+-------------+--------+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + +@since(1.6) +def input_file_name(): + """Creates a string column for the file name of the current Spark task. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.input_file_name()) + + +@since(1.6) +def isnan(col): + """An expression that returns true iff the column is NaN. + + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnan(_to_java_column(col))) + + +@since(1.6) +def isnull(col): + """An expression that returns true iff the column is null. + + >>> df = sqlContext.createDataFrame([(1, None), (None, 2)], ("a", "b")) + >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnull(_to_java_column(col))) + + +@since(1.3) +def last(col, ignorenulls=False): + """Aggregate function: returns the last value in a group. + + The function by default returns the last values it sees. It will return the last non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) + return Column(jc) + + +@since(1.6) +def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. @@ -281,11 +419,25 @@ def monotonicallyIncreasingId(): 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) - >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + >>> df0.select(monotonically_increasing_id().alias('id')).collect() [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.monotonicallyIncreasingId()) + return Column(sc._jvm.functions.monotonically_increasing_id()) + + +@since(1.6) +def nanvl(col1, col2): + """Returns col1 if it is not NaN, or col2 if col1 is NaN. + + Both inputs should be floating point columns (DoubleType or FloatType). + + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() + [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @since(1.4) @@ -327,7 +479,7 @@ def round(col, scale=0): @since(1.5) def shiftLeft(col, numBits): - """Shift the the given value numBits left. + """Shift the given value numBits left. >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() [Row(r=42)] @@ -338,7 +490,7 @@ def shiftLeft(col, numBits): @since(1.5) def shiftRight(col, numBits): - """Shift the the given value numBits right. + """Shift the given value numBits right. >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() [Row(r=21)] @@ -350,7 +502,7 @@ def shiftRight(col, numBits): @since(1.5) def shiftRightUnsigned(col, numBits): - """Unsigned shift the the given value numBits right. + """Unsigned shift the given value numBits right. >>> df = sqlContext.createDataFrame([(-42,)], ['a']) >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect() @@ -361,17 +513,17 @@ def shiftRightUnsigned(col, numBits): return Column(jc) -@since(1.4) -def sparkPartitionId(): +@since(1.6) +def spark_partition_id(): """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. - >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.sparkPartitionId()) + return Column(sc._jvm.functions.spark_partition_id()) @since(1.5) @@ -379,7 +531,7 @@ def expr(str): """Parses the expression string into the column that it represents >>> df.select(expr("length(name)")).collect() - [Row('length(name)=5), Row('length(name)=3)] + [Row(length(name)=5), Row(length(name)=3)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.expr(str)) @@ -464,10 +616,10 @@ def log(arg1, arg2=None): If there is only one argument, then this takes the natural logarithm of the argument. - >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() + >>> df.select(log(10.0, df.age).alias('ten')).rdd.map(lambda l: str(l.ten)[:7]).collect() ['0.30102', '0.69897'] - >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect() + >>> df.select(log(df.age).alias('e')).rdd.map(lambda l: str(l.e)[:7]).collect() ['0.69314', '1.60943'] """ sc = SparkContext._active_spark_context @@ -901,6 +1053,55 @@ def to_utc_timestamp(timestamp, tz): return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) +@since(2.0) +@ignore_unicode_prefix +def window(timeColumn, windowDuration, slideDuration=None, startTime=None): + """Bucketize rows into one or more time windows given a timestamp specifying column. Window + starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + the order of months are not supported. + + The time column must be of TimestampType. + + Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid + interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. + If the `slideDuration` is not provided, the windows will be tumbling windows. + + The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start + window intervals. For example, in order to have hourly tumbling windows that start 15 minutes + past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. + + The output column will be a struct called 'window' by default with the nested columns 'start' + and 'end', where 'start' and 'end' will be of `TimestampType`. + + >>> df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) + >>> w.select(w.window.start.cast("string").alias("start"), + ... w.window.end.cast("string").alias("end"), "sum").collect() + [Row(start=u'2016-03-11 09:00:05', end=u'2016-03-11 09:00:10', sum=1)] + """ + def check_string_field(field, fieldName): + if not field or type(field) is not str: + raise TypeError("%s should be provided as a string" % fieldName) + + sc = SparkContext._active_spark_context + time_col = _to_java_column(timeColumn) + check_string_field(windowDuration, "windowDuration") + if slideDuration and startTime: + check_string_field(slideDuration, "slideDuration") + check_string_field(startTime, "startTime") + res = sc._jvm.functions.window(time_col, windowDuration, slideDuration, startTime) + elif slideDuration: + check_string_field(slideDuration, "slideDuration") + res = sc._jvm.functions.window(time_col, windowDuration, slideDuration) + elif startTime: + check_string_field(startTime, "startTime") + res = sc._jvm.functions.window(time_col, windowDuration, windowDuration, startTime) + else: + res = sc._jvm.functions.window(time_col, windowDuration) + return Column(res) + + # ---------------------------- misc functions ---------------------------------- @since(1.5) @@ -961,6 +1162,18 @@ def sha2(col, numBits): return Column(jc) +@since(2.0) +def hash(*cols): + """Calculates the hash code of given columns, and returns the result as a int column. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() + [Row(hash=-757602832)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + # ---------------------- String/Binary functions ------------------------------ _string_functions = { @@ -972,7 +1185,7 @@ def sha2(col, numBits): 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', 'reverse': 'Reverses the string column and returns it as a new string column.', - 'ltrim': 'Trim the spaces from right end for the specified string value.', + 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', } @@ -1334,6 +1547,26 @@ def translate(srcCol, matching, replace): # ---------------------- Collection functions ------------------------------ +@ignore_unicode_prefix +@since(2.0) +def create_map(*cols): + """Creates a new map column. + + :param cols: list of column names (string) or list of :class:`Column` expressions that grouped + as key-value pairs, e.g. (key1, value1, key2, value2, ...). + + >>> df.select(create_map('name', 'age').alias("map")).collect() + [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] + >>> df.select(create_map([df.name, df.age]).alias("map")).collect() + [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.map(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + @since(1.4) def array(*cols): """Creates a new array column. @@ -1364,7 +1597,7 @@ def array_contains(col, value): >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(array_contains(df.data, "a")).collect() - [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)] + [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) @@ -1391,6 +1624,45 @@ def explode(col): return Column(jc) +@ignore_unicode_prefix +@since(1.6) +def get_json_object(col, path): + """ + Extracts json object from a json string based on json path specified, and returns json string + of the extracted json object. It will return null if the input json string is invalid. + + :param col: string column in json format + :param path: path to the json object to extract + + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \ + get_json_object(df.jstring, '$.f2').alias("c1") ).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.6) +def json_tuple(col, *fields): + """Creates a new row for a json column according to the given field names. + + :param col: string column in json format + :param fields: list of fields to extract + + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.json_tuple(_to_java_column(col), _to_seq(sc, fields)) + return Column(jc) + + @since(1.5) def size(col): """ @@ -1425,6 +1697,13 @@ def sort_array(col, asc=True): # ---------------------------- User Defined Function ---------------------------------- +def _wrap_function(sc, func, returnType): + command = (func, returnType) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class UserDefinedFunction(object): """ User defined function in Python @@ -1438,19 +1717,16 @@ def __init__(self, func, returnType, name=None): self._judf = self._create_judf(name) def _create_judf(self, name): - f, returnType = self.func, self.returnType # put them in closure `func` - func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) - ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, None, ser, ser) - sc = SparkContext._active_spark_context - pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) + from pyspark.sql import SQLContext + sc = SparkContext.getOrCreate() + wrapped_func = _wrap_function(sc, self.func, self.returnType) + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: + f = self.func name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ - judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, - sc.pythonExec, sc.pythonVer, broadcast_vars, - sc._javaAccumulator, jdt) + judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( + name, wrapped_func, jdt) return judf def __del__(self): diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 71c0bccc5eeff..ee734cb439287 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -17,7 +17,7 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix -from pyspark.sql.column import Column, _to_seq +from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * @@ -74,11 +74,11 @@ def agg(self, *exprs): or a list of :class:`Column`. >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() + >>> sorted(gdf.agg({"*": "count"}).collect()) [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() + >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" @@ -96,7 +96,7 @@ def agg(self, *exprs): def count(self): """Counts the number of records for each group. - >>> df.groupBy(df.age).count().collect() + >>> sorted(df.groupBy(df.age).count().collect()) [Row(age=2, count=1), Row(age=5, count=1)] """ @@ -167,6 +167,31 @@ def sum(self, *cols): [Row(sum(age)=7, sum(height)=165)] """ + @since(1.6) + def pivot(self, pivot_col, values=None): + """ + Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + There are two versions of pivot function: one that requires the caller to specify the list + of distinct values to pivot on, and one that does not. The latter is more concise but less + efficient, because Spark needs to first compute the list of distinct values internally. + + :param pivot_col: Name of the column to pivot. + :param values: List of values that will be translated to columns in the output DataFrame. + + // Compute the sum of earnings for each year by course with each course as a separate column + >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() + [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + + // Or without specifying column values (less efficient) + >>> df4.groupBy("year").pivot("course").sum("earnings").collect() + [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] + """ + if values is None: + jgd = self._jdf.pivot(pivot_col) + else: + jgd = self._jdf.pivot(pivot_col, values) + return GroupedData(jgd, self.sql_ctx) + def _test(): import doctest @@ -182,6 +207,11 @@ def _test(): StructField('name', StringType())])) globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), Row(name='Bob', age=5, height=85)]).toDF() + globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000), + Row(course="Java", year=2012, earnings=20000), + Row(course="dotNET", year=2012, earnings=5000), + Row(course="dotNET", year=2013, earnings=48000), + Row(course="Java", year=2013, earnings=30000)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 97bd90c4db829..0cef37e57cd54 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -26,6 +26,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * +from pyspark.sql import utils __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -108,7 +109,7 @@ def options(self, **options): def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. - :param path: optional string for file-system backed data sources. + :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options @@ -117,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options): ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', ... 'python/test_support/sql/people1.json']) >>> df.dtypes @@ -128,15 +130,9 @@ def load(self, path=None, format=None, schema=None, **options): self.schema(schema) self.options(**options) if path is not None: - if type(path) == list: - paths = path - gateway = self._sqlContext._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] - return self._df(self._jreader.load(jpaths)) - else: - return self._df(self._jreader.load(path)) + if type(path) != list: + path = [path] + return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load()) @@ -153,6 +149,31 @@ def json(self, path, schema=None): or RDD of Strings storing JSON objects. :param schema: an optional :class:`StructType` for the input schema. + You can set the following JSON-specific options to deal with non-standard JSON files: + * ``primitivesAsString`` (default ``false``): infers all primitive values as a string \ + type + * `prefersDecimal` (default `false`): infers all floating-point values as a decimal \ + type. If the values do not fit in decimal, then it infers them as doubles. + * ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records + * ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names + * ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \ + quotes + * ``allowNumericLeadingZeros`` (default ``false``): allows leading zeros in numbers \ + (e.g. 00012) + * ``allowBackslashEscapingAnyCharacter`` (default ``false``): allows accepting quoting \ + of all character using backslash quoting mechanism + * ``mode`` (default ``PERMISSIVE``): allows a mode for dealing with corrupt records \ + during parsing. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record and puts the malformed string into a new field configured by \ + ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ + ``null`` for extra fields. + * ``DROPMALFORMED`` : ignores the whole corrupted records. + * ``FAILFAST`` : throws an exception when it meets corrupted records. + * ``columnNameOfCorruptRecord`` (default ``_corrupt_record``): allows renaming the \ + new field having malformed string created by ``PERMISSIVE`` mode. \ + This overrides ``spark.sql.columnNameOfCorruptRecord``. + >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') >>> df1.dtypes [('age', 'bigint'), ('name', 'string')] @@ -166,8 +187,20 @@ def json(self, path, schema=None): self.schema(schema) if isinstance(path, basestring): return self._df(self._jreader.json(path)) + elif type(path) == list: + return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): - return self._df(self._jreader.json(path._jrdd)) + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = path.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString()) + return self._df(self._jreader.json(jrdd)) else: raise TypeError("path can be only string or RDD") @@ -196,16 +229,37 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, path): - """Loads a text file and returns a [[DataFrame]] with a single string column named "text". + def text(self, paths): + """Loads a text file and returns a [[DataFrame]] with a single string column named "value". Each line in the text file is a new row in the resulting DataFrame. + :param paths: string, or list of strings, for input path(s). + >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') >>> df.collect() - [Row(text=u'hello'), Row(text=u'this')] + [Row(value=u'hello'), Row(value=u'this')] """ - return self._df(self._jreader.text(path)) + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + + @since(2.0) + def csv(self, paths): + """Loads a CSV file and returns the result as a [[DataFrame]]. + + This function goes through the input once to determine the input schema. To avoid going + through the entire data once, specify the schema explicitly using [[schema]]. + + :param paths: string, or list of strings, for input path(s). + + >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv') + >>> df.dtypes + [('C0', 'string'), ('C1', 'string')] + """ + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) @since(1.5) def orc(self, path): @@ -259,8 +313,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) - return self._df(self._jreader.jdbc(url, table, arr, jprop)) + gateway = self._sqlContext._sc._gateway + jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) + return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) @@ -410,7 +465,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None): + def json(self, path, mode=None, compression=None): """Saves the content of the :class:`DataFrame` in JSON format at the specified path. :param path: the path in any Hadoop supported file system @@ -420,13 +475,19 @@ def json(self, path, mode=None): * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + :param compression: compression codec to use when saving to file. This can be one of the + known case-insensitive shorten names (none, bzip2, gzip, lz4, + snappy and deflate). >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ - self.mode(mode)._jwrite.json(path) + self.mode(mode) + if compression is not None: + self.option("compression", compression) + self._jwrite.json(path) @since(1.4) - def parquet(self, path, mode=None, partitionBy=None): + def parquet(self, path, mode=None, partitionBy=None, compression=None): """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. :param path: the path in any Hadoop supported file system @@ -437,25 +498,60 @@ def parquet(self, path, mode=None, partitionBy=None): * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. :param partitionBy: names of partitioning columns + :param compression: compression codec to use when saving to file. This can be one of the + known case-insensitive shorten names (none, snappy, gzip, and lzo). + This will overwrite ``spark.sql.parquet.compression.codec``. >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) if partitionBy is not None: self.partitionBy(partitionBy) + if compression is not None: + self.option("compression", compression) self._jwrite.parquet(path) @since(1.6) - def text(self, path): + def text(self, path, compression=None): """Saves the content of the DataFrame in a text file at the specified path. + :param path: the path in any Hadoop supported file system + :param compression: compression codec to use when saving to file. This can be one of the + known case-insensitive shorten names (none, bzip2, gzip, lz4, + snappy and deflate). + The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. """ + if compression is not None: + self.option("compression", compression) self._jwrite.text(path) + @since(2.0) + def csv(self, path, mode=None, compression=None): + """Saves the content of the [[DataFrame]] in CSV format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + :param compression: compression codec to use when saving to file. This can be one of the + known case-insensitive shorten names (none, bzip2, gzip, lz4, + snappy and deflate). + + >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode) + if compression is not None: + self.option("compression", compression) + self._jwrite.csv(path) + @since(1.5) - def orc(self, path, mode=None, partitionBy=None): + def orc(self, path, mode=None, partitionBy=None, compression=None): """Saves the content of the :class:`DataFrame` in ORC format at the specified path. ::Note: Currently ORC support is only available together with @@ -469,6 +565,9 @@ def orc(self, path, mode=None, partitionBy=None): * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. :param partitionBy: names of partitioning columns + :param compression: compression codec to use when saving to file. This can be one of the + known case-insensitive shorten names (none, snappy, zlib, and lzo). + This will overwrite ``orc.compress``. >>> orc_df = hiveContext.read.orc('python/test_support/sql/orc_partitioned') >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data')) @@ -476,6 +575,8 @@ def orc(self, path, mode=None, partitionBy=None): self.mode(mode) if partitionBy is not None: self.partitionBy(partitionBy) + if compression is not None: + self.option("compression", compression) self._jwrite.orc(path) @since(1.4) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4c03a0d4ffe93..e4f79c911c0d9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -51,7 +51,7 @@ from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction, sha2 from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException, IllegalArgumentException +from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException class UTCOffsetTimezone(datetime.tzinfo): @@ -305,6 +305,25 @@ def test_udf2(self): [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_chained_udf(self): + self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(1)").collect() + self.assertEqual(row[0], 2) + [row] = self.sqlCtx.sql("SELECT double(double(1))").collect() + self.assertEqual(row[0], 4) + [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect() + self.assertEqual(row[0], 6) + + def test_multiple_udfs(self): + self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect() + self.assertEqual(tuple(row), (2, 4)) + [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect() + self.assertEqual(tuple(row), (4, 12)) + self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() + self.assertEqual(tuple(row), (6, 5)) + def test_udf_with_array_type(self): d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) @@ -324,9 +343,18 @@ def test_broadcast_in_udf(self): [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) + def test_udf_with_aggregate_function(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a == 1, BooleanType()) + sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) + self.assertEqual(sel.collect(), [Row(key=1)]) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) + df = self.sqlCtx.read.json(rdd) df.count() df.collect() df.schema @@ -345,14 +373,32 @@ def test_basic_functions(self): df.collect() def test_apply_schema_to_row(self): - df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema) + df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.createDataFrame(df.rdd.map(lambda x: x), df.schema) self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) + def test_infer_schema_to_local(self): + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + df = self.sqlCtx.createDataFrame(input) + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + + def test_create_dataframe_schema_mismatch(self): + input = [Row(a=1)] + rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) + schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) + df = self.sqlCtx.createDataFrame(rdd, schema) + self.assertRaises(Exception, lambda: df.show()) + def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) @@ -362,15 +408,15 @@ def test_serialize_nested_array_and_map(self): self.assertEqual(1, row.l[0].a) self.assertEqual("2", row.d["key"].d) - l = df.map(lambda x: x.l).first() + l = df.rdd.map(lambda x: x.l).first() self.assertEqual(1, len(l)) self.assertEqual('s', l[0].b) - d = df.map(lambda x: x.d).first() + d = df.rdd.map(lambda x: x.d).first() self.assertEqual(1, len(d)) self.assertEqual(1.0, d["key"].c) - row = df.map(lambda x: x.d["key"]).first() + row = df.rdd.map(lambda x: x.d["key"]).first() self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) @@ -379,16 +425,16 @@ def test_infer_schema(self): Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) df = self.sqlCtx.createDataFrame(rdd) - self.assertEqual([], df.map(lambda r: r.l).first()) - self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + self.assertEqual([], df.rdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect()) df.registerTempTable("test") result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema, df2.schema) - self.assertEqual({}, df2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + self.assertEqual({}, df2.rdd.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect()) df2.registerTempTable("test2") result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) @@ -397,12 +443,12 @@ def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), NestedRow([2, 3], {"row2": 2.0})]) - df = self.sqlCtx.inferSchema(nestedRdd1) + df = self.sqlCtx.createDataFrame(nestedRdd1) self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), NestedRow([[2, 3], [3, 4]], [2, 3])]) - df = self.sqlCtx.inferSchema(nestedRdd2) + df = self.sqlCtx.createDataFrame(nestedRdd2) self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) from collections import namedtuple @@ -410,7 +456,7 @@ def test_infer_nested_schema(self): rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) - df = self.sqlCtx.inferSchema(rdd) + df = self.sqlCtx.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): @@ -442,8 +488,8 @@ def test_apply_schema(self): StructField("list1", ArrayType(ByteType(), False), False), StructField("null1", DoubleType(), True)]) df = self.sqlCtx.createDataFrame(rdd, schema) - results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, - x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, + x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) self.assertEqual(r, results.first()) @@ -550,7 +596,7 @@ def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df = self.sqlCtx.createDataFrame([row]) - self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) @@ -558,7 +604,7 @@ def test_udf_with_udt(self): row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df = self.sqlCtx.createDataFrame([row]) - self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) @@ -570,17 +616,35 @@ def test_parquet_with_udt(self): df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") df0.write.parquet(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') - df1 = self.sqlCtx.parquetFile(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_union_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT + row1 = (1.0, ExamplePoint(1.0, 2.0)) + row2 = (2.0, ExamplePoint(3.0, 4.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + df1 = self.sqlCtx.createDataFrame([row1], schema) + df2 = self.sqlCtx.createDataFrame([row2], schema) + + result = df1.union(df2).orderBy("label").collect() + self.assertEqual( + result, + [ + Row(label=1.0, point=ExamplePoint(1.0, 2.0)), + Row(label=2.0, point=ExamplePoint(3.0, 4.0)) + ] + ) + def test_column_operators(self): ci = self.df.key cs = self.df.value @@ -621,6 +685,23 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_first_last_ignorenulls(self): + from pyspark.sql import functions + df = self.sqlCtx.range(0, 100) + df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) + df3 = df2.select(functions.first(df2.id, False).alias('a'), + functions.first(df2.id, True).alias('b'), + functions.last(df2.id, False).alias('c'), + functions.last(df2.id, True).alias('d')) + self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + + def test_approxQuantile(self): + df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF() + aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aq, list)) + self.assertEqual(len(aq), 3) + self.assertTrue(all(isinstance(q, float) for q in aq)) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() @@ -727,6 +808,13 @@ def test_struct_type(self): except ValueError: self.assertEqual(1, 1) + def test_metadata_null(self): + from pyspark.sql.types import StructType, StringType, StructField + schema = StructType([StructField("f1", StringType(), True, None), + StructField("f2", StringType(), True, {'a': None})]) + rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) + self.sqlCtx.createDataFrame(rdd, schema) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() @@ -752,7 +840,7 @@ def test_save_and_load(self): defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.load(path=tmpPath) + actual = self.sqlCtx.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) @@ -785,7 +873,7 @@ def test_save_and_load_builder(self): defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.load(path=tmpPath) + actual = self.sqlCtx.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) @@ -794,7 +882,7 @@ def test_save_and_load_builder(self): def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) + df = self.sqlCtx.read.json(rdd) # render_doc() reproduces the help() exception without printing output pydoc.render_doc(df) pydoc.render_doc(df.foo) @@ -842,8 +930,8 @@ def test_infer_long_type(self): # this saving as Parquet caused issues as well. output_dir = os.path.join(self.tempdir.name, "infer_long_type") - df.saveAsParquetFile(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) + df.write.parquet(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) self.assertEqual('a', df1.first().f1) self.assertEqual(100000000000000, df1.first().f2) @@ -1017,7 +1105,7 @@ def test_expr(self): row = Row(a="length string", b=75) df = self.sqlCtx.createDataFrame([row]) result = df.select(functions.expr("length(a)")).collect()[0].asDict() - self.assertEqual(13, result["'length(a)"]) + self.assertEqual(13, result["length(a)"]) def test_replace(self): schema = StructType([ @@ -1070,8 +1158,9 @@ def test_replace(self): def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - # RuntimeException should not be captured - self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + + def test_capture_parse_exception(self): + self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc")) def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", @@ -1117,6 +1206,42 @@ def test_functions_broadcast(self): # planner should not crash without a join broadcast(df1)._jdf.queryExecution().executedPlan() + def test_toDF_with_schema_string(self): + data = [Row(key=i, value=str(i)) for i in range(100)] + rdd = self.sc.parallelize(data, 5) + + df = rdd.toDF("key: int, value: string") + self.assertEqual(df.schema.simpleString(), "struct") + self.assertEqual(df.collect(), data) + + # different but compatible field types can be used. + df = rdd.toDF("key: string, value: string") + self.assertEqual(df.schema.simpleString(), "struct") + self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)]) + + # field names can differ. + df = rdd.toDF(" a: int, b: string ") + self.assertEqual(df.schema.simpleString(), "struct") + self.assertEqual(df.collect(), data) + + # number of fields must match. + self.assertRaisesRegexp(Exception, "Length of object", + lambda: rdd.toDF("key: int").collect()) + + # field types mismatch will cause exception at runtime. + self.assertRaisesRegexp(Exception, "FloatType can not accept", + lambda: rdd.toDF("key: float, value: string").collect()) + + # flat schema values will be wrapped into row. + df = rdd.map(lambda row: row.key).toDF("int") + self.assertEqual(df.schema.simpleString(), "struct") + self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + + # users can use DataType directly instead of data type string. + df = rdd.map(lambda row: row.key).toDF(IntegerType()) + self.assertEqual(df.schema.simpleString(), "struct") + self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + class HiveContextSQLTests(ReusedPySparkTestCase): @@ -1194,9 +1319,9 @@ def test_window_functions(self): F.max("key").over(w.rowsBetween(0, 1)), F.min("key").over(w.rowsBetween(0, 1)), F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.rowNumber().over(w), + F.row_number().over(w), F.rank().over(w), - F.denseRank().over(w), + F.dense_rank().over(w), F.ntile(2).over(w)) rs = sorted(sel.collect()) expected = [ @@ -1216,9 +1341,9 @@ def test_window_functions_without_partitionBy(self): F.max("key").over(w.rowsBetween(0, 1)), F.min("key").over(w.rowsBetween(0, 1)), F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.rowNumber().over(w), + F.row_number().over(w), F.rank().over(w), - F.denseRank().over(w), + F.dense_rank().over(w), F.ntile(2).over(w)) rs = sorted(sel.collect()) expected = [ @@ -1230,8 +1355,26 @@ def test_window_functions_without_partitionBy(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql import functions + + self.assertEqual( + sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), + [1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), + [1, 1, 1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), + ["1", "2"]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), + ["1", "2", "2", "2"]) + if __name__ == "__main__": + from pyspark.sql.tests import * if xmlrunner: unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) else: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 5bc0773fa8660..734c1533a24bc 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -681,6 +681,129 @@ def __eq__(self, other): for v in [ArrayType, MapType, StructType]) +_FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)") + + +_BRACKETS = {'(': ')', '[': ']', '{': '}'} + + +def _parse_basic_datatype_string(s): + if s in _all_atomic_types.keys(): + return _all_atomic_types[s]() + elif s == "int": + return IntegerType() + elif _FIXED_DECIMAL.match(s): + m = _FIXED_DECIMAL.match(s) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % s) + + +def _ignore_brackets_split(s, separator): + """ + Splits the given string by given separator, but ignore separators inside brackets pairs, e.g. + given "a,b" and separator ",", it will return ["a", "b"], but given "a, d", it will return + ["a", "d"]. + """ + parts = [] + buf = "" + level = 0 + for c in s: + if c in _BRACKETS.keys(): + level += 1 + buf += c + elif c in _BRACKETS.values(): + if level == 0: + raise ValueError("Brackets are not correctly paired: %s" % s) + level -= 1 + buf += c + elif c == separator and level > 0: + buf += c + elif c == separator: + parts.append(buf) + buf = "" + else: + buf += c + + if len(buf) == 0: + raise ValueError("The %s cannot be the last char: %s" % (separator, s)) + parts.append(buf) + return parts + + +def _parse_struct_fields_string(s): + parts = _ignore_brackets_split(s, ",") + fields = [] + for part in parts: + name_and_type = _ignore_brackets_split(part, ":") + if len(name_and_type) != 2: + raise ValueError("The strcut field string format is: 'field_name:field_type', " + + "but got: %s" % part) + field_name = name_and_type[0].strip() + field_type = _parse_datatype_string(name_and_type[1]) + fields.append(StructField(field_name, field_type)) + return StructType(fields) + + +def _parse_datatype_string(s): + """ + Parses the given data type string to a :class:`DataType`. The data type string format equals + to `DataType.simpleString`, except that top level struct type can omit the `struct<>` and + atomic types use `typeName()` as their format, e.g. use `byte` instead of `tinyint` for + ByteType. We can also use `int` as a short name for IntegerType. + + >>> _parse_datatype_string("int ") + IntegerType + >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ") + StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true))) + >>> _parse_datatype_string("a: array< short>") + StructType(List(StructField(a,ArrayType(ShortType,true),true))) + >>> _parse_datatype_string(" map ") + MapType(StringType,StringType,true) + + >>> # Error cases + >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _parse_datatype_string("array>> _parse_datatype_string("map>") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + """ + s = s.strip() + if s.startswith("array<"): + if s[-1] != ">": + raise ValueError("'>' should be the last char, but got: %s" % s) + return ArrayType(_parse_datatype_string(s[6:-1])) + elif s.startswith("map<"): + if s[-1] != ">": + raise ValueError("'>' should be the last char, but got: %s" % s) + parts = _ignore_brackets_split(s[4:-1], ",") + if len(parts) != 2: + raise ValueError("The map type string format is: 'map', " + + "but got: %s" % s) + kt = _parse_datatype_string(parts[0]) + vt = _parse_datatype_string(parts[1]) + return MapType(kt, vt) + elif s.startswith("struct<"): + if s[-1] != ">": + raise ValueError("'>' should be the last char, but got: %s" % s) + return _parse_struct_fields_string(s[7:-1]) + elif ":" in s: + return _parse_struct_fields_string(s) + else: + return _parse_basic_datatype_string(s) + + def _parse_datatype_json_string(json_string): """Parses the given data type JSON string. >>> import pickle @@ -730,9 +853,6 @@ def _parse_datatype_json_string(json_string): return _parse_datatype_json_value(json.loads(json_string)) -_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") - - def _parse_datatype_json_value(json_value): if not isinstance(json_value, dict): if json_value in _all_atomic_types.keys(): @@ -940,9 +1060,6 @@ def convert_struct(obj): return convert_struct -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - def _split_schema_abstract(s): """ split the schema abstract into fields @@ -1091,10 +1208,13 @@ def _infer_schema_type(obj, dataType): } -def _verify_type(obj, dataType): +def _verify_type(obj, dataType, nullable=True): """ - Verify the type of obj against dataType, raise an exception if - they do not match. + Verify the type of obj against dataType, raise a TypeError if they do not match. + + Also verify the value of obj against datatype, raise a ValueError if it's not within the allowed + range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it + will become infinity when cast to Java float if it overflows. >>> _verify_type(None, StructType([])) >>> _verify_type("", StringType()) @@ -1111,10 +1231,35 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> # Check if numeric values are within the allowed range. + >>> _verify_type(12, ByteType()) + >>> _verify_type(1234, ByteType()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type({None: 1}, MapType(StringType(), IntegerType())) + Traceback (most recent call last): + ... + ValueError:... + >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False) + >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... """ - # all objects are nullable if obj is None: - return + if nullable: + return + else: + raise ValueError("This field is not nullable, but got None") # StringType can work with any types if isinstance(dataType, StringType): @@ -1137,21 +1282,33 @@ def _verify_type(obj, dataType): if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) - if isinstance(dataType, ArrayType): + if isinstance(dataType, ByteType): + if obj < -128 or obj > 127: + raise ValueError("object of ByteType out of range, got: %s" % obj) + + elif isinstance(dataType, ShortType): + if obj < -32768 or obj > 32767: + raise ValueError("object of ShortType out of range, got: %s" % obj) + + elif isinstance(dataType, IntegerType): + if obj < -2147483648 or obj > 2147483647: + raise ValueError("object of IntegerType out of range, got: %s" % obj) + + elif isinstance(dataType, ArrayType): for i in obj: - _verify_type(i, dataType.elementType) + _verify_type(i, dataType.elementType, dataType.containsNull) elif isinstance(dataType, MapType): for k, v in obj.items(): - _verify_type(k, dataType.keyType) - _verify_type(v, dataType.valueType) + _verify_type(k, dataType.keyType, False) + _verify_type(v, dataType.valueType, dataType.valueContainsNull) elif isinstance(dataType, StructType): if len(obj) != len(dataType.fields): raise ValueError("Length of object (%d) does not match with " "length of fields (%d)" % (len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType) + _verify_type(v, f.dataType, f.nullable) # This is used to unpickle a Row from JVM diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index c4fda8bd3b891..7ea0e0d5c9bef 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -33,6 +33,12 @@ class AnalysisException(CapturedException): """ +class ParseException(CapturedException): + """ + Failed to parse a SQL command. + """ + + class IllegalArgumentException(CapturedException): """ Passed an illegal or inappropriate argument. @@ -49,6 +55,8 @@ def deco(*a, **kw): e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1], stackTrace) + if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): + raise ParseException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) raise @@ -71,3 +79,16 @@ def install_exception_handler(): patched = capture_sql_exception(original) # only patch the one used in in py4j.java_gateway (call Java API) py4j.java_gateway.get_return_value = patched + + +def toJArray(gateway, jtype, arr): + """ + Convert python list to java type array + :param gateway: Py4j Gateway + :param jtype: java type of element in array + :param arr: python type list + """ + jarr = gateway.new_array(jtype, len(arr)) + for i in range(0, len(arr)): + jarr[i] = arr[i] + return jarr diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 57bbe340bbd4d..46663f69a0881 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -60,7 +60,7 @@ def partitionBy(*cols): @since(1.4) def orderBy(*cols): """ - Creates a :class:`WindowSpec` with the partitioning defined. + Creates a :class:`WindowSpec` with the ordering defined. """ sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols)) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 676aa0f7144aa..ef012d27cb22f 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -23,8 +23,10 @@ class StorageLevel(object): """ Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory - in a serialized format, and whether to replicate the RDD partitions on multiple nodes. - Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY. + in a JAVA-specific serialized format, and whether to replicate the RDD partitions on multiple + nodes. Also contains static constants for some commonly used storage levels, MEMORY_ONLY. + Since the data is always serialized on the Python side, all the constants use the serialized + formats. """ def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication=1): @@ -42,19 +44,28 @@ def __str__(self): result = "" result += "Disk " if self.useDisk else "" result += "Memory " if self.useMemory else "" - result += "Tachyon " if self.useOffHeap else "" + result += "OffHeap " if self.useOffHeap else "" result += "Deserialized " if self.deserialized else "Serialized " result += "%sx Replicated" % self.replication return result StorageLevel.DISK_ONLY = StorageLevel(True, False, False, False) StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, False, 2) -StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, True) -StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, True, 2) -StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False, False) -StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, False, 2) -StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, True) -StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, True, 2) -StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False, False) -StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, False, 2) -StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) +StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, False) +StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, False, 2) +StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, False) +StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, False, 2) +StorageLevel.OFF_HEAP = StorageLevel(True, True, True, False, 1) + +""" +.. note:: The following four storage level constants are deprecated in 2.0, since the records \ +will always be serialized in Python. +""" +StorageLevel.MEMORY_ONLY_SER = StorageLevel.MEMORY_ONLY +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY`` instead.""" +StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel.MEMORY_ONLY_2 +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY_2`` instead.""" +StorageLevel.MEMORY_AND_DISK_SER = StorageLevel.MEMORY_AND_DISK +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_AND_DISK`` instead.""" +StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel.MEMORY_AND_DISK_2 +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_AND_DISK_2`` instead.""" diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py index d2644a1d4ffab..66e8f8ef001e3 100644 --- a/python/pyspark/streaming/__init__.py +++ b/python/pyspark/streaming/__init__.py @@ -17,5 +17,6 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.dstream import DStream +from pyspark.streaming.listener import StreamingListener -__all__ = ['StreamingContext', 'DStream'] +__all__ = ['StreamingContext', 'DStream', 'StreamingListener'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 8be56c9915265..ec3ad9933cf60 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -94,7 +94,7 @@ def _ensure_initialized(cls): # get the GatewayServer object in JVM by ID jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) # update the port of CallbackClient with real port - gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing @@ -116,16 +116,13 @@ def getOrCreate(cls, checkpointPath, setupFunc): gw = SparkContext._gateway # Check whether valid checkpoint information exists in the given path - if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty(): + ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath) + if ssc_option.isEmpty(): ssc = setupFunc() ssc.checkpoint(checkpointPath) return ssc - try: - jssc = gw.jvm.JavaStreamingContext(checkpointPath) - except Exception: - print("failed to load StreamingContext from checkpoint", file=sys.stderr) - raise + jssc = gw.jvm.JavaStreamingContext(ssc_option.get()) # If there is already an active instance of Python SparkContext use it, or create a new one if not SparkContext._active_spark_context: @@ -258,7 +255,7 @@ def checkpoint(self, directory): """ self._jssc.checkpoint(directory) - def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_2): """ Create an input from TCP source hostname:port. Data is received using a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited @@ -363,3 +360,11 @@ def union(self, *dstreams): first = dstreams[0] jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) + + def addStreamingListener(self, streamingListener): + """ + Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + receiving system events related to streaming. + """ + self._jssc.addStreamingListener(self._jvm.JavaStreamingListenerWrapper( + self._jvm.PythonStreamingListenerWrapper(streamingListener))) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 698336cfce18d..2056663872198 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -208,10 +208,10 @@ def func(iterator): def cache(self): """ Persist the RDDs of this DStream with the default storage level - (C{MEMORY_ONLY_SER}). + (C{MEMORY_ONLY}). """ self.is_cached = True - self.persist(StorageLevel.MEMORY_ONLY_SER) + self.persist(StorageLevel.MEMORY_ONLY) return self def persist(self, storageLevel): @@ -247,7 +247,7 @@ def countByValue(self): Return a new DStream in which each RDD contains the counts of each distinct value in each RDD of this DStream. """ - return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() + return self.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x+y) def saveAsTextFiles(self, prefix, suffix=None): """ @@ -453,7 +453,7 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) This is more efficient than `invReduceFunc` is None. - @param reduceFunc: associative reduce function + @param reduceFunc: associative and commutative reduce function @param invReduceFunc: inverse reduce function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @@ -493,7 +493,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, windowDuration, slideDuration, numPartitions) - return counted.filter(lambda kv: kv[1] > 0).count() + return counted.filter(lambda kv: kv[1] > 0) def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): """ @@ -524,8 +524,8 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param reduceFunc: associative reduce function - @param invReduceFunc: inverse function of `reduceFunc` + @param func: associative and commutative reduce function + @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -542,33 +542,34 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None reduced = self.reduceByKey(func, numPartitions) - def reduceFunc(t, a, b): - b = b.reduceByKey(func, numPartitions) - r = a.union(b).reduceByKey(func, numPartitions) if a else b - if filterFunc: - r = r.filter(filterFunc) - return r - - def invReduceFunc(t, a, b): - b = b.reduceByKey(func, numPartitions) - joined = a.leftOuterJoin(b, numPartitions) - return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) - if kv[1] is not None else kv[0]) - - jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) - if invReduceFunc: + if invFunc: + def reduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r + + def invReduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + joined = a.leftOuterJoin(b, numPartitions) + return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) + if kv[1] is not None else kv[0]) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) + if slideDuration is None: + slideDuration = self._slideDuration + dstream = self._sc._jvm.PythonReducedWindowedDStream( + reduced._jdstream.dstream(), + jreduceFunc, jinvReduceFunc, + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) else: - jinvReduceFunc = None - if slideDuration is None: - slideDuration = self._slideDuration - dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), - jreduceFunc, jinvReduceFunc, - self._ssc._jduration(windowDuration), - self._ssc._jduration(slideDuration)) - return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + return reduced.window(windowDuration, slideDuration).reduceByKey(func, numPartitions) - def updateStateByKey(self, updateFunc, numPartitions=None): + def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): """ Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. @@ -579,6 +580,9 @@ def updateStateByKey(self, updateFunc, numPartitions=None): if numPartitions is None: numPartitions = self._sc.defaultParallelism + if initialRDD and not isinstance(initialRDD, RDD): + initialRDD = self._sc.parallelize(initialRDD) + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) @@ -590,7 +594,13 @@ def reduceFunc(t, a, b): jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) - dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + if initialRDD: + initialRDD = initialRDD._reserialize(self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc, + initialRDD._jrdd) + else: + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index b3d1905365925..cd30483fc636a 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -40,7 +40,7 @@ class FlumeUtils(object): @staticmethod def createStream(ssc, hostname, port, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, enableDecompression=False, bodyDecoder=utf8_decoder): """ @@ -55,22 +55,13 @@ def createStream(ssc, hostname, port, :return: A DStream object """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - - try: - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ - .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") - helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): - FlumeUtils._printErrorMsg(ssc.sparkContext) - raise e - + helper = FlumeUtils._get_helper(ssc._sc) + jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) @staticmethod def createPollingStream(ssc, addresses, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, maxBatchSize=1000, parallelism=5, bodyDecoder=utf8_decoder): @@ -95,18 +86,9 @@ def createPollingStream(ssc, addresses, for (host, port) in addresses: hosts.append(host) ports.append(port) - - try: - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") - helper = helperClass.newInstance() - jstream = helper.createPollingStream( - ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): - FlumeUtils._printErrorMsg(ssc.sparkContext) - raise e - + helper = FlumeUtils._get_helper(ssc._sc) + jstream = helper.createPollingStream( + ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) @staticmethod @@ -126,6 +108,15 @@ def func(event): return (headers, body) return stream.map(func) + @staticmethod + def _get_helper(sc): + try: + return sc._jvm.org.apache.spark.streaming.flume.FlumeUtilsPythonHelper() + except TypeError as e: + if str(e) == "'JavaPackage' object is not callable": + FlumeUtils._printErrorMsg(sc) + raise + @staticmethod def _printErrorMsg(sc): print(""" diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 06e159172ab51..02a88699a2886 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -19,12 +19,14 @@ from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel -from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer, PairDeserializer, \ + NoOpSerializer from pyspark.streaming import DStream from pyspark.streaming.dstream import TransformedDStream from pyspark.streaming.util import TransformFunction -__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaMessageAndMetadata', 'KafkaUtils', 'OffsetRange', + 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): @@ -38,7 +40,7 @@ class KafkaUtils(object): @staticmethod def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ Create an input stream that pulls messages from a Kafka Broker. @@ -64,25 +66,16 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, if not isinstance(topics, dict): raise TypeError("topics should be dict") jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - - try: - # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027) - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ - .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") - helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) - except Py4JJavaError as e: - # TODO: use --jar once it also work on driver - if 'ClassNotFoundException' in str(e.java_exception): - KafkaUtils._printErrorMsg(ssc.sparkContext) - raise e + helper = KafkaUtils._get_helper(ssc._sc) + jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) @staticmethod def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, - keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder, + messageHandler=None): """ .. note:: Experimental @@ -107,6 +100,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, point of the stream. :param keyDecoder: A function used to decode key (default is utf8_decoder). :param valueDecoder: A function used to decode value (default is utf8_decoder). + :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess + meta using messageHandler (default is None). :return: A DStream object """ if fromOffsets is None: @@ -116,27 +111,36 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, if not isinstance(kafkaParams, dict): raise TypeError("kafkaParams should be dict") - try: - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") - helper = helperClass.newInstance() - - jfromOffsets = dict([(k._jTopicAndPartition(helper), - v) for (k, v) in fromOffsets.items()]) - jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): - KafkaUtils._printErrorMsg(ssc.sparkContext) - raise e + def funcWithoutMessageHandler(k_v): + return (keyDecoder(k_v[0]), valueDecoder(k_v[1])) - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) \ - .map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + def funcWithMessageHandler(m): + m._set_key_decoder(keyDecoder) + m._set_value_decoder(valueDecoder) + return messageHandler(m) + + helper = KafkaUtils._get_helper(ssc._sc) + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + if messageHandler is None: + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + func = funcWithoutMessageHandler + jstream = helper.createDirectStreamWithoutMessageHandler( + ssc._jssc, kafkaParams, set(topics), jfromOffsets) + else: + ser = AutoBatchedSerializer(PickleSerializer()) + func = funcWithMessageHandler + jstream = helper.createDirectStreamWithMessageHandler( + ssc._jssc, kafkaParams, set(topics), jfromOffsets) + + stream = DStream(jstream, ssc, ser).map(func) return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod def createRDD(sc, kafkaParams, offsetRanges, leaders=None, - keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder, + messageHandler=None): """ .. note:: Experimental @@ -149,6 +153,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, map, in which case leaders will be looked up on the driver. :param keyDecoder: A function used to decode key (default is utf8_decoder) :param valueDecoder: A function used to decode value (default is utf8_decoder) + :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess + meta using messageHandler (default is None). :return: A RDD object """ if leaders is None: @@ -158,22 +164,39 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, if not isinstance(offsetRanges, list): raise TypeError("offsetRanges should be list") + def funcWithoutMessageHandler(k_v): + return (keyDecoder(k_v[0]), valueDecoder(k_v[1])) + + def funcWithMessageHandler(m): + m._set_key_decoder(keyDecoder) + m._set_value_decoder(valueDecoder) + return messageHandler(m) + + helper = KafkaUtils._get_helper(sc) + + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + if messageHandler is None: + jrdd = helper.createRDDWithoutMessageHandler( + sc._jsc, kafkaParams, joffsetRanges, jleaders) + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler) + else: + jrdd = helper.createRDDWithMessageHandler( + sc._jsc, kafkaParams, joffsetRanges, jleaders) + rdd = RDD(jrdd, sc).map(funcWithMessageHandler) + + return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer) + + @staticmethod + def _get_helper(sc): try: - helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") - helper = helperClass.newInstance() - joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] - jleaders = dict([(k._jTopicAndPartition(helper), - v._jBroker(helper)) for (k, v) in leaders.items()]) - jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): + return sc._jvm.org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper() + except TypeError as e: + if str(e) == "'JavaPackage' object is not callable": KafkaUtils._printErrorMsg(sc) - raise e - - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - rdd = RDD(jrdd, sc, ser).map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) - return KafkaRDD(rdd._jrdd, rdd.ctx, rdd._jrdd_deserializer) + raise @staticmethod def _printErrorMsg(sc): @@ -296,16 +319,8 @@ def offsetRanges(self): Get the OffsetRange of specific KafkaRDD. :return: A list of OffsetRange """ - try: - helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") - helper = helperClass.newInstance() - joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd()) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): - KafkaUtils._printErrorMsg(self.ctx) - raise e - + helper = KafkaUtils._get_helper(self.ctx) + joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd()) ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset()) for o in joffsetRanges] return ranges @@ -365,3 +380,53 @@ def _jdstream(self): dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) self._jdstream_val = dstream.asJavaDStream() return self._jdstream_val + + +class KafkaMessageAndMetadata(object): + """ + Kafka message and metadata information. Including topic, partition, offset and message + """ + + def __init__(self, topic, partition, offset, key, message): + """ + Python wrapper of Kafka MessageAndMetadata + :param topic: topic name of this Kafka message + :param partition: partition id of this Kafka message + :param offset: Offset of this Kafka message in the specific partition + :param key: key payload of this Kafka message, can be null if this Kafka message has no key + specified, the return data is undecoded bytearry. + :param message: actual message payload of this Kafka message, the return data is + undecoded bytearray. + """ + self.topic = topic + self.partition = partition + self.offset = offset + self._rawKey = key + self._rawMessage = message + self._keyDecoder = utf8_decoder + self._valueDecoder = utf8_decoder + + def __str__(self): + return "KafkaMessageAndMetadata(topic: %s, partition: %d, offset: %d, key and message...)" \ + % (self.topic, self.partition, self.offset) + + def __repr__(self): + return self.__str__() + + def __reduce__(self): + return (KafkaMessageAndMetadata, + (self.topic, self.partition, self.offset, self._rawKey, self._rawMessage)) + + def _set_key_decoder(self, decoder): + self._keyDecoder = decoder + + def _set_value_decoder(self, decoder): + self._valueDecoder = decoder + + @property + def key(self): + return self._keyDecoder(self._rawKey) + + @property + def message(self): + return self._valueDecoder(self._rawMessage) diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index af72c3d6903f9..434ce83e1e6f9 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -74,16 +74,14 @@ def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, try: # Use KinesisUtilsPythonHelper to access Scala's KinesisUtils - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ - .loadClass("org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper") - helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, - regionName, initialPositionInStream, jduration, jlevel, - awsAccessKeyId, awsSecretKey) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): + helper = ssc._jvm.org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper() + except TypeError as e: + if str(e) == "'JavaPackage' object is not callable": KinesisUtils._printErrorMsg(ssc.sparkContext) - raise e + raise + jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, + regionName, initialPositionInStream, jduration, jlevel, + awsAccessKeyId, awsSecretKey) stream = DStream(jstream, ssc, NoOpSerializer()) return stream.map(lambda v: decoder(v)) diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py new file mode 100644 index 0000000000000..b830797f5c0a0 --- /dev/null +++ b/python/pyspark/streaming/listener.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = ["StreamingListener"] + + +class StreamingListener(object): + + def __init__(self): + pass + + def onReceiverStarted(self, receiverStarted): + """ + Called when a receiver has been started + """ + pass + + def onReceiverError(self, receiverError): + """ + Called when a receiver has reported an error + """ + pass + + def onReceiverStopped(self, receiverStopped): + """ + Called when a receiver has been stopped + """ + pass + + def onBatchSubmitted(self, batchSubmitted): + """ + Called when a batch of jobs has been submitted for processing. + """ + pass + + def onBatchStarted(self, batchStarted): + """ + Called when processing of a batch of jobs has started. + """ + pass + + def onBatchCompleted(self, batchCompleted): + """ + Called when processing of a batch of jobs has completed. + """ + pass + + def onOutputOperationStarted(self, outputOperationStarted): + """ + Called when processing of a job of a batch has started. + """ + pass + + def onOutputOperationCompleted(self, outputOperationCompleted): + """ + Called when processing of a job of a batch has completed + """ + pass + + class Java: + implements = ["org.apache.spark.streaming.api.java.PythonStreamingListener"] diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py deleted file mode 100644 index 1ce4093196e63..0000000000000 --- a/python/pyspark/streaming/mqtt.py +++ /dev/null @@ -1,73 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from py4j.protocol import Py4JJavaError - -from pyspark.storagelevel import StorageLevel -from pyspark.serializers import UTF8Deserializer -from pyspark.streaming import DStream - -__all__ = ['MQTTUtils'] - - -class MQTTUtils(object): - - @staticmethod - def createStream(ssc, brokerUrl, topic, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): - """ - Create an input stream that pulls messages from a Mqtt Broker. - - :param ssc: StreamingContext object - :param brokerUrl: Url of remote mqtt publisher - :param topic: topic name to subscribe to - :param storageLevel: RDD storage level. - :return: A DStream object - """ - jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - - try: - helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") - helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) - except Py4JJavaError as e: - if 'ClassNotFoundException' in str(e.java_exception): - MQTTUtils._printErrorMsg(ssc.sparkContext) - raise e - - return DStream(jstream, ssc, UTF8Deserializer()) - - @staticmethod - def _printErrorMsg(sc): - print(""" -________________________________________________________________________________________________ - - Spark Streaming's MQTT libraries not found in class path. Try one of the following. - - 1. Include the MQTT library and its dependencies with in the - spark-submit command as - - $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ... - - 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, - Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s. - Then, include the jar in the spark-submit command as - - $ bin/spark-submit --jars ... -________________________________________________________________________________________________ -""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 179479625bca4..148bf7e8ff5ce 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -46,8 +46,8 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils -from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream +from pyspark.streaming.listener import StreamingListener class PySparkStreamingTestCase(unittest.TestCase): @@ -278,8 +278,10 @@ def test_countByValue(self): def func(dstream): return dstream.countByValue() - expected = [[4], [4], [3]] - self._test_func(input, func, expected) + expected = [[(1, 2), (2, 2), (3, 2), (4, 2)], + [(5, 2), (6, 2), (7, 1), (8, 1)], + [("a", 2), ("b", 1), ("", 1)]] + self._test_func(input, func, expected, sort=True) def test_groupByKey(self): """Basic operation test for DStream.groupByKey.""" @@ -402,6 +404,216 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_update_state_by_key_initial_rdd(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + initial = [('k', [0, 1])] + initial = self.sc.parallelize(initial, 1) + + input = [[('k', i)] for i in range(2, 5)] + + def func(dstream): + return dstream.updateStateByKey(updater, initialRDD=initial) + + expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + + def test_failed_func(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], time: Time) + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + raise ValueError("This is a special error") + + input_stream.map(failed_func).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func2(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time) + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream1 = self.ssc.queueStream(input) + input_stream2 = self.ssc.queueStream(input) + + def failed_func(rdd1, rdd2): + raise ValueError("This is a special error") + + input_stream1.transformWith(failed_func, input_stream2, True).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func_with_reseting_failure(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + if i == 1: + # Make it fail in the second batch + raise ValueError("This is a special error") + else: + return i + + # We should be able to see the results of the 3rd and 4th batches even if the second batch + # fails + expected = [[0], [2], [3]] + self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3)) + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + +class StreamingListenerTests(PySparkStreamingTestCase): + + duration = .5 + + class BatchInfoCollector(StreamingListener): + + def __init__(self): + super(StreamingListener, self).__init__() + self.batchInfosCompleted = [] + self.batchInfosStarted = [] + self.batchInfosSubmitted = [] + + def onBatchSubmitted(self, batchSubmitted): + self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) + + def onBatchStarted(self, batchStarted): + self.batchInfosStarted.append(batchStarted.batchInfo()) + + def onBatchCompleted(self, batchCompleted): + self.batchInfosCompleted.append(batchCompleted.batchInfo()) + + def test_batch_info_reports(self): + batch_collector = self.BatchInfoCollector() + self.ssc.addStreamingListener(batch_collector) + input = [[1], [2], [3], [4]] + + def func(dstream): + return dstream.map(int) + expected = [[1], [2], [3], [4]] + self._test_func(input, func, expected) + + batchInfosSubmitted = batch_collector.batchInfosSubmitted + batchInfosStarted = batch_collector.batchInfosStarted + batchInfosCompleted = batch_collector.batchInfosCompleted + + self.wait_for(batchInfosCompleted, 4) + + self.assertGreaterEqual(len(batchInfosSubmitted), 4) + for info in batchInfosSubmitted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertEqual(info.schedulingDelay(), -1) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosStarted), 4) + for info in batchInfosStarted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), -1) + self.assertGreaterEqual(outputInfo.endTime(), -1) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertEqual(info.processingDelay(), -1) + self.assertEqual(info.totalDelay(), -1) + self.assertEqual(info.numRecords(), 0) + + self.assertGreaterEqual(len(batchInfosCompleted), 4) + for info in batchInfosCompleted: + self.assertGreaterEqual(info.batchTime().milliseconds(), 0) + self.assertGreaterEqual(info.submissionTime(), 0) + + for streamId in info.streamIdToInputInfo(): + streamInputInfo = info.streamIdToInputInfo()[streamId] + self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0) + self.assertGreaterEqual(streamInputInfo.numRecords, 0) + for key in streamInputInfo.metadata(): + self.assertIsNotNone(streamInputInfo.metadata()[key]) + self.assertIsNotNone(streamInputInfo.metadataDescription()) + + for outputOpId in info.outputOperationInfos(): + outputInfo = info.outputOperationInfos()[outputOpId] + self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0) + self.assertGreaterEqual(outputInfo.id(), 0) + self.assertIsNotNone(outputInfo.name()) + self.assertIsNotNone(outputInfo.description()) + self.assertGreaterEqual(outputInfo.startTime(), 0) + self.assertGreaterEqual(outputInfo.endTime(), 0) + self.assertIsNone(outputInfo.failureReason()) + + self.assertGreaterEqual(info.schedulingDelay(), 0) + self.assertGreaterEqual(info.processingDelay(), 0) + self.assertGreaterEqual(info.totalDelay(), 0) + self.assertEqual(info.numRecords(), 0) + class WindowFunctionTests(PySparkStreamingTestCase): @@ -440,7 +652,16 @@ def test_count_by_value_and_window(self): def func(dstream): return dstream.countByValueAndWindow(2.5, .5) - expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + expected = [[(0, 1)], + [(0, 2), (1, 1)], + [(0, 3), (1, 2), (2, 1)], + [(0, 4), (1, 3), (2, 2), (3, 1)], + [(0, 5), (1, 4), (2, 3), (3, 2), (4, 1)], + [(0, 5), (1, 5), (2, 4), (3, 3), (4, 2), (5, 1)], + [(0, 4), (1, 4), (2, 4), (3, 3), (4, 2), (5, 1)], + [(0, 3), (1, 3), (2, 3), (3, 3), (4, 2), (5, 1)], + [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 1)], + [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]] self._test_func(input, func, expected) def test_group_by_key_and_window(self): @@ -459,6 +680,17 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def test_reduce_by_key_and_window_with_none_invFunc(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.map(lambda x: (x, 1))\ + .reduceByKeyAndWindow(operator.add, None, 5, 1)\ + .filter(lambda kv: kv[1] > 0).count() + + expected = [[2], [4], [6], [6], [6], [6]] + self._test_func(input, func, expected) + class StreamingContextTests(PySparkStreamingTestCase): @@ -611,12 +843,16 @@ class CheckpointTests(unittest.TestCase): @staticmethod def tearDownClass(): # Clean up in the JVM just in case there has been some issues in Python API - jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() - if jStreamingContextOption.nonEmpty(): - jStreamingContextOption.get().stop() - jSparkContextOption = SparkContext._jvm.SparkContext.get() - if jSparkContextOption.nonEmpty(): - jSparkContextOption.get().stop() + if SparkContext._jvm is not None: + jStreamingContextOption = \ + SparkContext._jvm.org.apache.spark.streaming.StreamingContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop() + + def setUp(self): + self.ssc = None + self.sc = None + self.cpd = None def tearDown(self): if self.ssc is not None: @@ -626,6 +862,34 @@ def tearDown(self): if self.cpd is not None: shutil.rmtree(self.cpd) + def test_transform_function_serializer_failure(self): + inputd = tempfile.mkdtemp() + self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure") + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + + # A function that cannot be serialized + def process(time, rdd): + sc.parallelize(range(1, 10)) + + ssc.textFileStream(inputd).foreachRDD(process) + return ssc + + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + try: + self.ssc.start() + except: + import traceback + failure = traceback.format_exc() + self.assertTrue( + "It appears that you are attempting to reference SparkContext" in failure) + return + + self.fail("using SparkContext in process should fail because it's not Serializable") + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -648,7 +912,7 @@ def setup(): self.cpd = tempfile.mkdtemp("test_streaming_cps") self.setupCalled = False self.ssc = StreamingContext.getOrCreate(self.cpd, setup) - self.assertFalse(self.setupCalled) + self.assertTrue(self.setupCalled) self.ssc.start() @@ -694,11 +958,11 @@ def check_output(n): # Verify that getOrCreate() uses existing SparkContext self.ssc.stop(True, True) time.sleep(1) - sc = SparkContext(SparkConf()) + self.sc = SparkContext(conf=SparkConf()) self.setupCalled = False self.ssc = StreamingContext.getOrCreate(self.cpd, setup) self.assertFalse(self.setupCalled) - self.assertTrue(self.ssc.sparkContext == sc) + self.assertTrue(self.ssc.sparkContext == self.sc) # Verify the getActiveOrCreate() recovers from checkpoint files self.ssc.stop(True, True) @@ -717,11 +981,11 @@ def check_output(n): # Verify that getActiveOrCreate() uses existing SparkContext self.ssc.stop(True, True) time.sleep(1) - self.sc = SparkContext(SparkConf()) + self.sc = SparkContext(conf=SparkConf()) self.setupCalled = False self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) self.assertFalse(self.setupCalled) - self.assertTrue(self.ssc.sparkContext == sc) + self.assertTrue(self.ssc.sparkContext == self.sc) # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files self.ssc.stop(True, True) @@ -741,19 +1005,16 @@ class KafkaStreamTests(PySparkStreamingTestCase): def setUp(self): super(KafkaStreamTests, self).setUp() - - kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ - .loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils") - self._kafkaTestUtils = kafkaTestUtilsClz.newInstance() + self._kafkaTestUtils = self.ssc._jvm.org.apache.spark.streaming.kafka.KafkaTestUtils() self._kafkaTestUtils.setup() def tearDown(self): + super(KafkaStreamTests, self).tearDown() + if self._kafkaTestUtils is not None: self._kafkaTestUtils.teardown() self._kafkaTestUtils = None - super(KafkaStreamTests, self).tearDown() - def _randomTopic(self): return "topic-%d" % random.randint(0, 10000) @@ -915,6 +1176,90 @@ def test_topic_and_partition_equality(self): self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_with_checkpoint(self): + """Test the Python direct Kafka stream transform with checkpoint correctly recovered.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + self.ssc.stop(False) + self.ssc = None + tmpdir = "checkpoint-test-%d" % random.randint(0, 10000) + + def setup(): + ssc = StreamingContext(self.sc, 0.5) + ssc.checkpoint(tmpdir) + stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams) + stream.transform(transformWithOffsetRanges).count().pprint() + return ssc + + try: + ssc1 = StreamingContext.getOrCreate(tmpdir, setup) + ssc1.start() + self.wait_for(offsetRanges, 1) + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + # To make sure some checkpoint is written + time.sleep(3) + ssc1.stop(False) + ssc1 = None + + # Restart again to make sure the checkpoint is recovered correctly + ssc2 = StreamingContext.getOrCreate(tmpdir, setup) + ssc2.start() + ssc2.awaitTermination(3) + ssc2.stop(stopSparkContext=False, stopGraceFully=True) + ssc2 = None + finally: + shutil.rmtree(tmpdir) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_message_handler(self): + """Test Python direct Kafka RDD MessageHandler.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 1, "c": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + def getKeyAndDoubleMessage(m): + return m and (m.key, m.message * 2) + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, + messageHandler=getKeyAndDoubleMessage) + self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_message_handler(self): + """Test the Python direct Kafka stream MessageHandler.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + def getKeyAndDoubleMessage(m): + return m and (m.key, m.message * 2) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, + messageHandler=getKeyAndDoubleMessage) + self._validateStreamResult({"aa": 1, "bb": 2, "cc": 3}, stream) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds @@ -922,10 +1267,7 @@ class FlumeStreamTests(PySparkStreamingTestCase): def setUp(self): super(FlumeStreamTests, self).setUp() - - utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils") - self._utils = utilsClz.newInstance() + self._utils = self.ssc._jvm.org.apache.spark.streaming.flume.FlumeTestUtils() def tearDown(self): if self._utils is not None: @@ -990,10 +1332,7 @@ class FlumePollingStreamTests(PySparkStreamingTestCase): maxAttempts = 5 def setUp(self): - utilsClz = \ - self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils") - self._utils = utilsClz.newInstance() + self._utils = self.sc._jvm.org.apache.spark.streaming.flume.PollingFlumeTestUtils() def tearDown(self): if self._utils is not None: @@ -1064,68 +1403,6 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) -class MQTTStreamTests(PySparkStreamingTestCase): - timeout = 20 # seconds - duration = 1 - - def setUp(self): - super(MQTTStreamTests, self).setUp() - - MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils") - self._MQTTTestUtils = MQTTTestUtilsClz.newInstance() - self._MQTTTestUtils.setup() - - def tearDown(self): - if self._MQTTTestUtils is not None: - self._MQTTTestUtils.teardown() - self._MQTTTestUtils = None - - super(MQTTStreamTests, self).tearDown() - - def _randomTopic(self): - return "topic-%d" % random.randint(0, 10000) - - def _startContext(self, topic): - # Start the StreamingContext and also collect the result - stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) - result = [] - - def getOutput(_, rdd): - for data in rdd.collect(): - result.append(data) - - stream.foreachRDD(getOutput) - self.ssc.start() - return result - - def test_mqtt_stream(self): - """Test the Python MQTT stream API.""" - sendData = "MQTT demo for spark streaming" - topic = self._randomTopic() - result = self._startContext(topic) - - def retry(): - self._MQTTTestUtils.publishData(topic, sendData) - # Because "publishData" sends duplicate messages, here we should use > 0 - self.assertTrue(len(result) > 0) - self.assertEqual(sendData, result[0]) - - # Retry it because we don't know when the receiver will start. - self._retry_or_timeout(retry) - - def _retry_or_timeout(self, test_func): - start_time = time.time() - while True: - try: - test_func() - break - except: - if time.time() - start_time > self.timeout: - raise - time.sleep(0.01) - - class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -1149,10 +1426,7 @@ def test_kinesis_stream(self): import random kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) - kinesisTestUtilsClz = \ - self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ - .loadClass("org.apache.spark.streaming.kinesis.KinesisTestUtils") - kinesisTestUtils = kinesisTestUtilsClz.newInstance() + kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils() try: kinesisTestUtils.createStream() aWSCredentials = kinesisTestUtils.getAWSCredentials() @@ -1208,7 +1482,7 @@ def search_kafka_assembly_jar(): raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " + "'build/sbt assembly/package streaming-kafka-assembly/assembly' or " "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " @@ -1234,43 +1508,9 @@ def search_flume_assembly_jar(): return jars[0] -def search_mqtt_assembly_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") - jars = search_jar(mqtt_assembly_dir, "spark-streaming-mqtt-assembly") - if not jars: - raise Exception( - ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + - "You need to build Spark with " - "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " - "'build/mvn package' before running this test") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming MQTT assembly JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - -def search_mqtt_test_jar(): - SPARK_HOME = os.environ["SPARK_HOME"] - mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") - jars = glob.glob( - os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) - if not jars: - raise Exception( - ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) + - "You need to build Spark with " - "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") - elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming MQTT test JARs: %s; please " - "remove all but one") % (", ".join(jars))) - else: - return jars[0] - - def search_kinesis_asl_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] - kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly") + kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") if not jars: return None @@ -1286,31 +1526,29 @@ def search_kinesis_asl_assembly_jar(): are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' if __name__ == "__main__": + from pyspark.streaming.tests import * kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() - mqtt_assembly_jar = search_mqtt_assembly_jar() - mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() if kinesis_asl_assembly_jar is None: kinesis_jar_present = False - jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, - mqtt_test_jar) + jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) else: kinesis_jar_present = True - jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, - mqtt_test_jar, kinesis_asl_assembly_jar) + jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests] + KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, + StreamingListenerTests] if kinesis_jar_present is True: testcases.append(KinesisStreamTests) elif are_kinesis_tests_enabled is False: sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " "not compiled into a JAR. To run these tests, " - "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " "streaming-kinesis-asl-assembly/assembly' or " "'build/mvn -Pkinesis-asl package' before running this test.") else: @@ -1318,15 +1556,20 @@ def search_kinesis_asl_assembly_jar(): ("Failed to find Spark Streaming Kinesis assembly jar in %s. " % kinesis_asl_assembly_dir) + "You need to build Spark with 'build/sbt -Pkinesis-asl " - "assembly/assembly streaming-kinesis-asl-assembly/assembly'" + "assembly/package streaming-kinesis-asl-assembly/assembly'" "or 'build/mvn -Pkinesis-asl package' before running this test.") sys.stderr.write("Running tests: %s \n" % (str(testcases))) + failed = False for testcase in testcases: sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) if xmlrunner: - unittest.main(tests, verbosity=3, - testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + if not result.wasSuccessful(): + failed = True else: - unittest.TextTestRunner(verbosity=3).run(tests) + result = unittest.TextTestRunner(verbosity=3).run(tests) + if not result.wasSuccessful(): + failed = True + sys.exit(failed) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index b20613b1283bd..abbbf6eb9394f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,13 +37,16 @@ def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers - self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.failure = None def rdd_wrapper(self, func): - self._rdd_wrapper = func + self.rdd_wrap_func = func return self def call(self, milliseconds, jrdds): + # Clear the failure + self.failure = None try: if self.ctx is None: self.ctx = SparkContext._active_spark_context @@ -56,14 +59,17 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None + rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) if r: return r._jrdd - except Exception: - traceback.print_exc() + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunction(%s)" % self.func @@ -88,20 +94,29 @@ def __init__(self, ctx, serializer, gateway=None): self.serializer = serializer self.gateway = gateway or self.ctx._gateway self.gateway.jvm.PythonDStream.registerSerializer(self) + self.failure = None def dumps(self, id): + # Clear the failure + self.failure = None try: func = self.gateway.gateway_property.pool[id] - return bytearray(self.serializer.dumps((func.func, func.deserializers))) - except Exception: - traceback.print_exc() + return bytearray(self.serializer.dumps(( + func.func, func.rdd_wrap_func, func.deserializers))) + except: + self.failure = traceback.format_exc() def loads(self, data): + # Clear the failure + self.failure = None try: - f, deserializers = self.serializer.loads(bytes(data)) - return TransformFunction(self.ctx, f, *deserializers) - except Exception: - traceback.print_exc() + f, wrap_func, deserializers = self.serializer.loads(bytes(data)) + return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func) + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 5bd94476597ab..97ea39dde05fa 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -228,6 +228,12 @@ def test_itemgetter(self): getter2 = ser.loads(ser.dumps(getter)) self.assertEqual(getter(d), getter2(d)) + def test_function_module_name(self): + ser = CloudPickleSerializer() + func = lambda x: x + func2 = ser.loads(ser.dumps(func)) + self.assertEqual(func.__module__, func2.__module__) + def test_attrgetter(self): from operator import attrgetter ser = CloudPickleSerializer() @@ -688,6 +694,21 @@ def test_large_broadcast(self): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEqual(N, m) + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM @@ -1893,6 +1914,13 @@ def test_get_or_create(self): with SparkContext.getOrCreate() as sc: self.assertTrue(SparkContext.getOrCreate() is sc) + def test_parallelize_eager_cleanup(self): + with SparkContext() as sc: + temp_files = os.listdir(sc._temp_dir) + rdd = sc.parallelize([0, 1, 2]) + post_parallalize_temp_files = os.listdir(sc._temp_dir) + self.assertEqual(temp_files, post_parallalize_temp_files) + def test_stop(self): sc = SparkContext() self.assertNotEqual(SparkContext._active_spark_context, None) @@ -1960,6 +1988,18 @@ def test_startTime(self): self.assertGreater(sc.startTime, 0) +class ConfTests(unittest.TestCase): + def test_memory_conf(self): + memoryList = ["1T", "1G", "1M", "1024K"] + for memory in memoryList: + sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) + l = list(range(1024)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -2008,6 +2048,7 @@ def test_statcounter_array(self): if __name__ == "__main__": + from pyspark.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 42c2f8b75933e..cf47ab8f96c6d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -29,7 +29,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer from pyspark import shuffle pickleSer = PickleSerializer() @@ -50,6 +50,65 @@ def add_path(path): sys.path.insert(1, path) +def read_command(serializer, file): + command = serializer._read_with_length(file) + if isinstance(command, Broadcast): + command = serializer.loads(command.value) + return command + + +def chain(f, g): + """chain two function together """ + return lambda *a: g(f(*a)) + + +def wrap_udf(f, return_type): + if return_type.needConversion(): + toInternal = return_type.toInternal + return lambda *a: toInternal(f(*a)) + else: + return lambda *a: f(*a) + + +def read_single_udf(pickleSer, infile): + num_arg = read_int(infile) + arg_offsets = [read_int(infile) for i in range(num_arg)] + row_func = None + for i in range(read_int(infile)): + f, return_type = read_command(pickleSer, infile) + if row_func is None: + row_func = f + else: + row_func = chain(row_func, f) + # the last returnType will be the return type of UDF + return arg_offsets, wrap_udf(row_func, return_type) + + +def read_udfs(pickleSer, infile): + num_udfs = read_int(infile) + if num_udfs == 1: + # fast path for single UDF + _, udf = read_single_udf(pickleSer, infile) + mapper = lambda a: udf(*a) + else: + udfs = {} + call_udf = [] + for i in range(num_udfs): + arg_offsets, udf = read_single_udf(pickleSer, infile) + udfs['f%d' % i] = udf + args = ["a[%d]" % o for o in arg_offsets] + call_udf.append("f%d(%s)" % (i, ", ".join(args))) + # Create function like this: + # lambda a: (f0(a0), f1(a1, a2), f2(a3)) + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + mapper = eval(mapper_str, udfs) + + func = lambda _, it: map(mapper, it) + ser = BatchedSerializer(PickleSerializer(), 100) + # profiling is not supported for UDF + return func, None, ser, ser + + def main(infile, outfile): try: boot_time = time.time() @@ -95,10 +154,12 @@ def main(infile, outfile): _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() - command = pickleSer._read_with_length(infile) - if isinstance(command, Broadcast): - command = pickleSer.loads(command.value) - func, profiler, deserializer, serializer = command + is_sql_udf = read_int(infile) + if is_sql_udf: + func, profiler, deserializer, serializer = read_udfs(pickleSer, infile) + else: + func, profiler, deserializer, serializer = read_command(pickleSer, infile) + init_time = time.time() def process(): diff --git a/python/run-tests.py b/python/run-tests.py index f5857f8c62214..38b3bb84c10be 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -53,10 +53,25 @@ def print_red(text): FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() +# Find out where the assembly jars are located. +for scala in ["2.11", "2.10"]: + build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) + if os.path.isdir(build_dir): + SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") + break +else: + raise Exception("Cannot find assembly build directory, please build Spark first.") + def run_individual_python_test(test_name, pyspark_python): env = dict(os.environ) - env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) + env.update({ + 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, + 'SPARK_TESTING': '1', + 'SPARK_PREPEND_CLASSES': '1', + 'PYSPARK_PYTHON': which(pyspark_python), + 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) + }) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -156,7 +171,7 @@ def main(): LOGGER.info("Will test against the following Python executables: %s", python_execs) LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) - task_queue = Queue.Queue() + task_queue = Queue.PriorityQueue() for python_exec in python_execs: python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], @@ -167,12 +182,17 @@ def main(): for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: for test_goal in module.python_test_goals: - task_queue.put((python_exec, test_goal)) + if test_goal in ('pyspark.streaming.tests', 'pyspark.mllib.tests', + 'pyspark.tests', 'pyspark.sql.tests'): + priority = 0 + else: + priority = 100 + task_queue.put((priority, (python_exec, test_goal))) def process_queue(task_queue): while True: try: - (python_exec, test_goal) = task_queue.get_nowait() + (priority, (python_exec, test_goal)) = task_queue.get_nowait() except Queue.Empty: break try: diff --git a/python/test_support/sql/ages.csv b/python/test_support/sql/ages.csv new file mode 100644 index 0000000000000..18991feda788a --- /dev/null +++ b/python/test_support/sql/ages.csv @@ -0,0 +1,4 @@ +Joe,20 +Tom,30 +Hyukjin,25 + diff --git a/repl/pom.xml b/repl/pom.xml index fb0a0e1286c80..0f396c9b809bd 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-repl_2.10 + spark-repl_2.11 jar Spark Project REPL http://spark.apache.org/ @@ -50,12 +50,6 @@ test-jar test - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - runtime - org.apache.spark spark-mllib_${scala.binary.version} @@ -95,6 +89,10 @@ org.apache.spark spark-test-tags_${scala.binary.version} + + org.apache.xbean + xbean-asm5-shaded + @@ -161,7 +159,7 @@ scala-2.10 - !scala-2.11 + scala-2.10 @@ -175,7 +173,7 @@ scala-2.11 - scala-2.11 + !scala-2.10 scala-2.11/src/main/scala diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala index 14b448d076d84..7b4e14bb6aa47 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala @@ -17,9 +17,12 @@ package org.apache.spark.repl -import scala.collection.mutable.Set +import org.apache.spark.internal.Logging + +object Main extends Logging { + + initializeLogIfNecessary(true) -object Main { private var _interp: SparkILoop = _ def interp = _interp diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala index 5fb378112ef92..2b5d56a895902 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -13,7 +13,7 @@ import scala.tools.nsc.interpreter._ import scala.reflect.internal.util.BatchSourceFile import scala.tools.nsc.ast.parser.Tokens.EOF -import org.apache.spark.Logging +import org.apache.spark.internal.Logging private[repl] trait SparkExprTyper extends Logging { val repl: SparkIMain diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 304b1e8cdbed5..c5dc6ba2219f8 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -10,13 +10,10 @@ package org.apache.spark.repl import java.net.URL -import org.apache.spark.annotation.DeveloperApi - import scala.reflect.io.AbstractFile import scala.tools.nsc._ import scala.tools.nsc.backend.JavaPlatform import scala.tools.nsc.interpreter._ - import scala.tools.nsc.interpreter.{Results => IR} import Predef.{println => _, _} import java.io.{BufferedReader, FileReader} @@ -42,9 +39,10 @@ import scala.tools.reflect.StdRuntimeTags._ import java.lang.{Class => jClass} import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} -import org.apache.spark.Logging import org.apache.spark.SparkConf import org.apache.spark.SparkContext +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils @@ -169,7 +167,7 @@ class SparkILoop( } - private def sparkCleanUp(){ + private def sparkCleanUp() { echo("Stopping spark context.") intp.beQuietDuring { command("sc.stop()") @@ -253,7 +251,7 @@ class SparkILoop( case xs => xs find (_.name == cmd) } } - private var fallbackMode = false + private var fallbackMode = false private def toggleFallbackMode() { val old = fallbackMode @@ -261,9 +259,9 @@ class SparkILoop( System.setProperty("spark.repl.fallback", fallbackMode.toString) echo(s""" |Switched ${if (old) "off" else "on"} fallback mode without restarting. - | If you have defined classes in the repl, it would + | If you have defined classes in the repl, it would |be good to redefine them incase you plan to use them. If you still run - |into issues it would be good to restart the repl and turn on `:fallback` + |into issues it would be good to restart the repl and turn on `:fallback` |mode as first command. """.stripMargin) } @@ -350,7 +348,7 @@ class SparkILoop( shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), nullary("fallback", """ - |disable/enable advanced repl changes, these fix some issues but may introduce others. + |disable/enable advanced repl changes, these fix some issues but may introduce others. |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) @@ -799,9 +797,11 @@ class SparkILoop( // echo("Switched " + (if (old) "off" else "on") + " result printing.") } - /** Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. */ + /** + * Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. + */ private[repl] def command(line: String): Result = { if (line startsWith ":") { val cmd = line.tail takeWhile (x => !x.isWhitespace) @@ -843,12 +843,13 @@ class SparkILoop( } import paste.{ ContinueString, PromptString } - /** Interpret expressions starting with the first line. - * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ + /** + * Interpret expressions starting with the first line. + * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ private def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() @@ -1009,8 +1010,13 @@ class SparkILoop( val conf = new SparkConf() .setMaster(getMaster()) .setJars(jars) - .set("spark.repl.class.uri", intp.classServerUri) .setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -1025,7 +1031,7 @@ class SparkILoop( val loader = Utils.getContextOrSparkClassLoader try { sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) - .newInstance(sparkContext).asInstanceOf[SQLContext] + .newInstance(sparkContext).asInstanceOf[SQLContext] logInfo("Created sql context (with Hive support)..") } catch { diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index bd3314d94eed6..99e1e1df33fd8 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -123,18 +123,19 @@ private[repl] trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.interp.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) command(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) command("import org.apache.spark.SparkContext._") command("import sqlContext.implicits._") diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 4ee605fd7f11e..74a04d5a42bb2 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -37,7 +37,8 @@ import scala.reflect.{ ClassTag, classTag } import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable -import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils import org.apache.spark.annotation.DeveloperApi @@ -72,7 +73,7 @@ import org.apache.spark.annotation.DeveloperApi * all variables defined by that code. To extract the result of an * interpreted line to show the user, a second "result object" is created * which imports the variables exported by the above object and then - * exports members called "$eval" and "$print". To accomodate user expressions + * exports members called "$eval" and "$print". To accommodate user expressions * that read from variables or methods defined in previous statements, "import" * statements are used. * @@ -96,10 +97,9 @@ import org.apache.spark.annotation.DeveloperApi private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ - private lazy val outputDir = { - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - Utils.createTempDir(rootDir) + private[repl] val outputDir = { + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + Utils.createTempDir(root = rootDir, namePrefix = "repl") } if (SPARK_DEBUG_REPL) { echo("Output directory: " + outputDir) @@ -114,8 +114,6 @@ import org.apache.spark.annotation.DeveloperApi private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - private val classServerPort = conf.getInt("spark.replClassServer.port", 0) - private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings private var printResults = true // whether to print result lines private var totalSilence = false // whether to print anything @@ -124,22 +122,6 @@ import org.apache.spark.annotation.DeveloperApi private var bindExceptions = true // whether to bind the lastException variable private var _executionWrapper = "" // code to be wrapped around all lines - - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() - if (SPARK_DEBUG_REPL) { - echo("Class server started, URI = " + classServer.uri) - } - - /** - * URI of the class server used to feed REPL compiled classes. - * - * @return The string representing the class server uri - */ - @DeveloperApi - def classServerUri = classServer.uri - /** We're going to go to some trouble to initialize the compiler asynchronously. * It's critical that nothing call into it until it's been initialized or we will * run into unrecoverable issues, but the perceived repl startup time goes @@ -994,7 +976,6 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi def close() { reporter.flush() - classServer.stop() } /** @@ -1221,10 +1202,16 @@ import org.apache.spark.annotation.DeveloperApi ) } - val preamble = """ - |class %s extends Serializable { - | %s%s%s - """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) + val preamble = s""" + |class ${lineRep.readName} extends Serializable { + | ${envLines.map(" " + _ + ";\n").mkString} + | $importsPreamble + | + | // If we need to construct any objects defined in the REPL on an executor we will need + | // to pass the outer scope to the appropriate encoder. + | org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) + | ${indentCode(toCompute)} + """.stripMargin val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + @@ -1529,7 +1516,7 @@ import org.apache.spark.annotation.DeveloperApi exprTyper.symbolOfLine(code) /** - * Constucts type information based on the provided expression's final + * Constructs type information based on the provided expression's final * result or the definition provided. * * @param expr The expression or definition diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala index 1d0fe10d3d817..f22776592c288 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -118,8 +118,9 @@ private[repl] trait SparkImports { case class ReqAndHandler(req: Request, handler: MemberHandler) { } def reqsToUse: List[ReqAndHandler] = { - /** Loop through a list of MemberHandlers and select which ones to keep. - * 'wanted' is the set of names that need to be imported. + /** + * Loop through a list of MemberHandlers and select which ones to keep. + * 'wanted' is the set of names that need to be imported. */ def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = { // Single symbol imports might be implicits! See bug #1752. Rather than diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala index f24d6da72437e..1ba17dfd8e3d0 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala @@ -7,8 +7,6 @@ package org.apache.spark.repl -import org.apache.spark.annotation.DeveloperApi - import scala.tools.nsc._ import scala.tools.nsc.interpreter._ @@ -16,7 +14,9 @@ import scala.tools.jline._ import scala.tools.jline.console.completer._ import Completion._ import scala.collection.mutable.ListBuffer -import org.apache.spark.Logging + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging /** * Represents an auto-completion tool for the supplied interpreter that diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 5674dcd669bee..547da8f713ac7 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -262,6 +262,9 @@ class ReplSuite extends SparkFunSuite { |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() + | + |// Test Dataset Serialization in the REPL + |Seq(TestCaseClass(1)).toDS().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -278,6 +281,29 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("java.lang.ClassNotFoundException", output) } + test("Datasets and encoders") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.{Encoder, Encoders} + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |val simpleSum = new Aggregator[Int, Int, Int] { + | def zero: Int = 0 // The initial value. + | def reduce(b: Int, a: Int) = b + a // Add an element to the running total + | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. + | def finish(b: Int) = b // Return the final result. + | def bufferEncoder: Encoder[Int] = Encoders.scalaInt + | def outputEncoder: Encoder[Int] = Encoders.scalaInt + |}.toColumn + | + |val ds = Seq(1, 2, 3, 4).toDS() + |ds.select(simpleSum).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 627148df80c11..b822ff496c118 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -19,57 +19,67 @@ package org.apache.spark.repl import java.io.File -import scala.tools.nsc.Settings +import scala.tools.nsc.GenericRunnerSettings -import org.apache.spark.util.Utils import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils import org.apache.spark.sql.SQLContext object Main extends Logging { + initializeLogIfNecessary(true) + val conf = new SparkConf() - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - val outputDir = Utils.createTempDir(rootDir) - val s = new Settings() - s.processArguments(List("-Yrepl-class-based", - "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", - "-classpath", getAddedJars.mkString(File.pathSeparator)), true) - // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed - lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl") + var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ - var interp = new SparkILoop // this is a public var because tests reset it. + // this is a public var because tests reset it. + var interp: SparkILoop = _ + + private var hasErrors = false + + private def scalaOptionError(msg: String): Unit = { + hasErrors = true + Console.err.println(msg) + } def main(args: Array[String]) { - if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() - interp.process(s) // Repl starts and goes in loop of R.E.P.L - classServer.stop() - Option(sparkContext).map(_.stop) + doMain(args, new SparkILoop) } - def getAddedJars: Array[String] = { - val envJars = sys.env.get("ADD_JARS") - if (envJars.isDefined) { - logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") + // Visible for testing + private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = { + interp = _interp + val jars = conf.getOption("spark.jars") + .map(_.replace(",", File.pathSeparator)) + .getOrElse("") + val interpArguments = List( + "-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", + "-classpath", jars + ) ++ args.toList + + val settings = new GenericRunnerSettings(scalaOptionError) + settings.processArguments(interpArguments, true) + + if (!hasErrors) { + interp.process(settings) // Repl starts and goes in loop of R.E.P.L + Option(sparkContext).map(_.stop) } - val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } - val jars = propJars.orElse(envJars).getOrElse("") - Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) } def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") - val jars = getAddedJars - val conf = new SparkConf() - .setMaster(getMaster) - .setJars(jars) - .set("spark.repl.class.uri", classServer.uri) - .setIfMissing("spark.app.name", "Spark shell") - logInfo("Spark class server started at " + classServer.uri) + conf.setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + .set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -96,12 +106,4 @@ object Main extends Logging { sqlContext } - private def getMaster: String = { - val master = { - val envMaster = sys.env.get("MASTER") - val propMaster = sys.props.get("spark.master") - propMaster.orElse(envMaster).getOrElse("local[*]") - } - master - } } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 33d262558b1fc..db09d6ace1c65 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -17,14 +17,13 @@ package org.apache.spark.repl -import java.io.{BufferedReader, FileReader} +import java.io.BufferedReader -import Predef.{println => _, _} -import scala.util.Properties.{jdkHome, javaVersion, versionString, javaVmName} - -import scala.tools.nsc.interpreter.{JPrintWriter, ILoop} +import scala.Predef.{println => _, _} import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} import scala.tools.nsc.util.stringFromStream +import scala.util.Properties.{javaVersion, javaVmName, versionString} /** * A Spark-specific interactive shell. @@ -37,18 +36,19 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { processLine(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) processLine(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) processLine("import org.apache.spark.SparkContext._") processLine("import sqlContext.implicits._") @@ -74,18 +74,16 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) echo("Type :help for more information.") } - import LoopCommand.{ cmd, nullary } - - private val blockedCommands = Set("implicits", "javap", "power", "type", "kind") + private val blockedCommands = Set("implicits", "javap", "power", "type", "kind", "reset") - /** Standard commands **/ + /** Standard commands */ lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] = standardCommands.filter(cmd => !blockedCommands(cmd.name)) /** Available commands */ override def commands: List[LoopCommand] = sparkStandardCommands - /** + /** * We override `loadFiles` because we need to initialize Spark *before* the REPL * sees any files, so that the Spark context is visible in those files. This is a bit of a * hack, but there isn't another hook available to us at this point. @@ -98,7 +96,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) object SparkILoop { - /** + /** * Creates an interpreter loop with default settings and feeds * the given code to it as input. */ @@ -111,9 +109,9 @@ object SparkILoop { val output = new JPrintWriter(new OutputStreamWriter(ostream), true) val repl = new SparkILoop(input, output) - if (sets.classpath.isDefault) + if (sets.classpath.isDefault) { sets.classpath.value = sys.props("java.class.path") - + } repl process sets } } diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index bf8997998e00d..d3dafe9c42ee2 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,7 +21,6 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.{SparkContext, SparkFunSuite} @@ -49,14 +48,8 @@ class ReplSuite extends SparkFunSuite { val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH) System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath) - System.setProperty("spark.master", master) - val interp = { - new SparkILoop(in, new PrintWriter(out)) - } - org.apache.spark.repl.Main.interp = interp - Main.s.processArguments(List("-classpath", classpath), true) - Main.main(Array()) // call main - org.apache.spark.repl.Main.interp = null + Main.conf.set("spark.master", master) + Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out))) if (oldExecutorClasspath != null) { System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath) @@ -66,6 +59,10 @@ class ReplSuite extends SparkFunSuite { return out.toString } + // Simulate the paste mode in Scala REPL. + def runInterpreterInPasteMode(master: String, input: String): String = + runInterpreter(master, ":paste\n" + input + 4.toChar) // 4 is the ascii code of CTRL + D + def assertContains(message: String, output: String) { val isContain = output.contains(message) assert(isContain, @@ -256,10 +253,34 @@ class ReplSuite extends SparkFunSuite { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,1024]", """ - |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() + | + |// Test Dataset Serialization in the REPL + |Seq(TestCaseClass(1)).toDS().collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("Datasets and encoders") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.{Encoder, Encoders} + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |val simpleSum = new Aggregator[Int, Int, Int] { + | def zero: Int = 0 // The initial value. + | def reduce(b: Int, a: Int) = b + a // Add an element to the running total + | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. + | def finish(b: Int) = b // Return the final result. + | def bufferEncoder: Encoder[Int] = Encoders.scalaInt + | def outputEncoder: Encoder[Int] = Encoders.scalaInt + |}.toColumn + | + |val ds = Seq(1, 2, 3, 4).toDS() + |ds.select(simpleSum).collect """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -324,4 +345,59 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("Exception", output) assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output) } + + test("line wrapper only initialized once when used as encoder outer scope") { + val output = runInterpreter("local", + """ + |val fileName = "repl-test-" + System.currentTimeMillis + |val tmpDir = System.getProperty("java.io.tmpdir") + |val file = new java.io.File(tmpDir, fileName) + |def createFile(): Unit = file.createNewFile() + | + |createFile();case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect() + | + |file.delete() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("define case class and create Dataset together with paste mode") { + val output = runInterpreterInPasteMode("local-cluster[1,1,1024]", + """ + |import sqlContext.implicits._ + |case class TestClass(value: Int) + |Seq(TestClass(1)).toDS() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("should clone and clean line object in ClosureCleaner") { + val output = runInterpreterInPasteMode("local-cluster[1,4,4096]", + """ + |import org.apache.spark.rdd.RDD + | + |val lines = sc.textFile("pom.xml") + |case class Data(s: String) + |val dataRDD = lines.map(line => Data(line.take(3))) + |dataRDD.cache.count + |val repartitioned = dataRDD.repartition(dataRDD.partitions.size) + |repartitioned.cache.count + | + |def getCacheSize(rdd: RDD[_]) = { + | sc.getRDDStorageInfo.filter(_.id == rdd.id).map(_.memSize).sum + |} + |val cacheSize1 = getCacheSize(dataRDD) + |val cacheSize2 = getCacheSize(repartitioned) + | + |// The cache size of dataRDD and the repartitioned one should be similar. + |val deviation = math.abs(cacheSize2 - cacheSize1).toDouble / cacheSize1 + |assert(deviation < 0.2, + | s"deviation too large: $deviation, first size: $cacheSize1, second size: $cacheSize2") + """.stripMargin) + assertDoesNotContain("AssertionError", output) + assertDoesNotContain("Exception", output) + } } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 004941d5f50ae..4a15d52b570a4 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,27 +17,33 @@ package org.apache.spark.repl -import java.io.{IOException, ByteArrayOutputStream, InputStream} +import java.io.{ByteArrayOutputStream, FilterInputStream, InputStream, IOException} import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.nio.channels.Channels import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.xbean.asm5._ +import org.apache.xbean.asm5.Opcodes._ -import org.apache.spark.{SparkConf, SparkEnv, Logging} +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils -import org.apache.spark.util.ParentClassLoader - -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{ParentClassLoader, Utils} /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. - * Allows the user to specify if user class path should be first + * Allows the user to specify if user class path should be first. + * This class loader delegates getting/finding resources to parent loader, + * which makes sense until REPL never provide resource dynamically. */ -class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, +class ExecutorClassLoader( + conf: SparkConf, + env: SparkEnv, + classUri: String, + parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { val uri = new URI(classUri) val directory = uri.getPath @@ -47,29 +53,62 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 - // Hadoop FileSystem object for our URI, if it isn't using HTTP - var fileSystem: FileSystem = { - if (Set("http", "https", "ftp").contains(uri.getScheme)) { - null + private val fetchFn: (String) => InputStream = uri.getScheme() match { + case "spark" => getClassFileInputStreamFromSparkRPC + case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer + case _ => + val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) + getClassFileInputStreamFromFileSystem(fileSystem) + } + + override def getResource(name: String): URL = { + parentLoader.getResource(name) + } + + override def getResources(name: String): java.util.Enumeration[URL] = { + parentLoader.getResources(name) + } + + override def findClass(name: String): Class[_] = { + if (userClassPathFirst) { + findClassLocally(name).getOrElse(parentLoader.loadClass(name)) } else { - FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) + try { + parentLoader.loadClass(name) + } catch { + case e: ClassNotFoundException => + val classOption = findClassLocally(name) + classOption match { + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). + // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. + throw new ClassNotFoundException(name) + case Some(a) => a + } + } } } - override def findClass(name: String): Class[_] = { - userClassPathFirst match { - case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) - case false => { + private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = { + val channel = env.rpcEnv.openChannel(s"$classUri/$path") + new FilterInputStream(Channels.newInputStream(channel)) { + + override def read(): Int = toClassNotFound(super.read()) + + override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b)) + + override def read(b: Array[Byte], offset: Int, len: Int) = + toClassNotFound(super.read(b, offset, len)) + + private def toClassNotFound(fn: => Int): Int = { try { - parentLoader.loadClass(name) + fn } catch { - case e: ClassNotFoundException => { - val classOption = findClassLocally(name) - classOption match { - case None => throw new ClassNotFoundException(name, e) - case Some(a) => a - } - } + case e: Exception => + throw new ClassNotFoundException(path, e) } } } @@ -111,7 +150,8 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } - private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( + pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) if (fileSystem.exists(path)) { fileSystem.open(path) @@ -124,13 +164,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val pathInDirectory = name.replace('.', '/') + ".class" var inputStream: InputStream = null try { - inputStream = { - if (fileSystem != null) { - getClassFileInputStreamFromFileSystem(pathInDirectory) - } else { - getClassFileInputStreamFromHttpServer(pathInDirectory) - } - } + inputStream = fetchFn(pathInDirectory) val bytes = readAndTransformClass(name, inputStream) Some(defineClass(name, bytes, 0, bytes.length)) } catch { @@ -192,7 +226,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassVisitor(ASM4, cv) { +extends ClassVisitor(ASM5, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) @@ -202,7 +236,7 @@ extends ClassVisitor(ASM4, cv) { // field in the class to point to it, but do nothing otherwise. mv.visitCode() mv.visitVarInsn(ALOAD, 0) // load this - mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V") + mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V", false) mv.visitVarInsn(ALOAD, 0) // load this // val classType = className.replace('.', '/') // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index e2ee9c963a4da..7665bd5e7c070 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index a58eda12b1120..9a143ee36ff46 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -18,19 +18,30 @@ package org.apache.spark.repl import java.io.File -import java.net.{URL, URLClassLoader} +import java.net.{URI, URL, URLClassLoader} +import java.nio.channels.{FileChannel, ReadableByteChannel} +import java.nio.charset.StandardCharsets +import java.nio.file.{Paths, StandardOpenOption} +import java.util import scala.concurrent.duration._ +import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Matchers.anyString +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar -import org.mockito.Mockito._ import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils class ExecutorClassLoaderSuite @@ -41,6 +52,7 @@ class ExecutorClassLoaderSuite val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") + val parentResourceNames = List("fake-resource.txt") var tempDir1: File = _ var tempDir2: File = _ var url1: String = _ @@ -54,22 +66,28 @@ class ExecutorClassLoaderSuite url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) + parentResourceNames.foreach { x => + Files.write("resource".getBytes(StandardCharsets.UTF_8), new File(tempDir2, x)) + } parentClassNames.foreach(TestUtils.createCompiledClass(_, tempDir2, "2")) } override def afterAll() { - super.afterAll() - if (classServer != null) { - classServer.stop() + try { + if (classServer != null) { + classServer.stop() + } + Utils.deleteRecursively(tempDir1) + Utils.deleteRecursively(tempDir2) + SparkEnv.set(null) + } finally { + super.afterAll() } - Utils.deleteRecursively(tempDir1) - Utils.deleteRecursively(tempDir2) - SparkEnv.set(null) } test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") @@ -77,7 +95,7 @@ class ExecutorClassLoaderSuite test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, false) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false) val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -85,7 +103,7 @@ class ExecutorClassLoaderSuite test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -93,12 +111,32 @@ class ExecutorClassLoaderSuite test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() } } + test("resource from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val is = classLoader.getResourceAsStream(resourceName) + assert(is != null, s"Resource $resourceName not found") + val content = Source.fromInputStream(is, "UTF-8").getLines().next() + assert(content.contains("resource"), "File doesn't contain 'resource'") + } + + test("resources from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) + assert(resources.hasMoreElements, s"Resource $resourceName not found") + val fileReader = Source.fromInputStream(resources.nextElement().openStream()).bufferedReader() + assert(fileReader.readLine().contains("resource"), "File doesn't contain 'resource'") + } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class // from the driver's class server would leak a HTTP connection, causing the class server's @@ -113,7 +151,7 @@ class ExecutorClassLoaderSuite SparkEnv.set(mockEnv) // Create an ExecutorClassLoader that's configured to load classes from the HTTP server val parentLoader = new URLClassLoader(Array.empty, null) - val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + val classLoader = new ExecutorClassLoader(conf, null, classServer.uri, parentLoader, false) classLoader.httpUrlConnectionTimeoutMillis = 500 // Check that this class loader can actually load classes that exist val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() @@ -148,4 +186,27 @@ class ExecutorClassLoaderSuite failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) } + test("fetch classes using Spark's RpcEnv") { + val env = mock[SparkEnv] + val rpcEnv = mock[RpcEnv] + when(env.rpcEnv).thenReturn(rpcEnv) + when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() { + override def answer(invocation: InvocationOnMock): ReadableByteChannel = { + val uri = new URI(invocation.getArguments()(0).asInstanceOf[String]) + val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/")) + FileChannel.open(path, StandardOpenOption.READ) + } + }) + + val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", + getClass().getClassLoader(), false) + + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + intercept[java.lang.ClassNotFoundException] { + classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + } + } + } diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index d8d9d00d64ebc..97df433a0b675 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -27,4 +27,4 @@ fi export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9-src.zip:${PYTHONPATH}" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.2-src.zip:${PYTHONPATH}" diff --git a/sbin/start-all.sh b/sbin/start-all.sh index 6217f9bf28e3d..a5d30d274ea6e 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -25,22 +25,11 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -TACHYON_STR="" - -while (( "$#" )); do -case $1 in - --with-tachyon) - TACHYON_STR="--with-tachyon" - ;; - esac -shift -done - # Load the Spark configuration . "${SPARK_HOME}/sbin/spark-config.sh" # Start Master -"${SPARK_HOME}/sbin"/start-master.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-master.sh # Start Workers -"${SPARK_HOME}/sbin"/start-slaves.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-slaves.sh diff --git a/sbin/start-master.sh b/sbin/start-master.sh index c20e19a8412df..ce7f17795997e 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -23,22 +23,21 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -ORIGINAL_ARGS="$@" - -START_TACHYON=false +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.master.Master" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-master.sh [options]" + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi -while (( "$#" )); do -case $1 in - --with-tachyon) - if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then - echo "Error: --with-tachyon specified, but tachyon not found." - exit -1 - fi - START_TACHYON=true - ;; - esac -shift -done +ORIGINAL_ARGS="$@" . "${SPARK_HOME}/sbin/spark-config.sh" @@ -56,12 +55,6 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then SPARK_MASTER_WEBUI_PORT=8080 fi -"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ +"${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \ --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS - -if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon format -s - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon-start.sh master -fi diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index 4777e1668c703..06a966d1c20b4 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -37,5 +37,8 @@ if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then SPARK_MESOS_DISPATCHER_HOST=`hostname` fi +if [ "$SPARK_MESOS_DISPATCHER_NUM" = "" ]; then + SPARK_MESOS_DISPATCHER_NUM=1 +fi -"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" +"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher $SPARK_MESOS_DISPATCHER_NUM --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 21455648d1c6d..8c268b8859155 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -31,18 +31,24 @@ # worker. Subsequent workers will increment this # number. Default is 8081. -usage="Usage: start-slave.sh where is like spark://localhost:7077" - -if [ $# -lt 1 ]; then - echo $usage - echo Called as start-slave.sh $* - exit 1 -fi - if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.worker.Worker" + +if [[ $# -lt 1 ]] || [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-slave.sh [options] " + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi + . "${SPARK_HOME}/sbin/spark-config.sh" . "${SPARK_HOME}/bin/load-spark-env.sh" @@ -72,7 +78,7 @@ function start_instance { fi WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) - "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS $WORKER_NUM \ --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" } diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 51ca81e053b70..5bf2b83b42ce4 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -23,21 +23,6 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi -START_TACHYON=false - -while (( "$#" )); do -case $1 in - --with-tachyon) - if [ ! -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - echo "Error: --with-tachyon specified, but tachyon not found." - exit -1 - fi - START_TACHYON=true - ;; - esac -shift -done - . "${SPARK_HOME}/sbin/spark-config.sh" . "${SPARK_HOME}/bin/load-spark-env.sh" @@ -50,12 +35,5 @@ if [ "$SPARK_MASTER_IP" = "" ]; then SPARK_MASTER_IP="`hostname`" fi -if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" - - # set -t so we can call sudo - SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 -fi - # Launch the slaves "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/stop-master.sh b/sbin/stop-master.sh index e57962bb354d9..14644ea72d43b 100755 --- a/sbin/stop-master.sh +++ b/sbin/stop-master.sh @@ -26,7 +26,3 @@ fi . "${SPARK_HOME}/sbin/spark-config.sh" "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 - -if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.master.Master -fi diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh index 5c0b4e051db38..b13e018c7d41e 100755 --- a/sbin/stop-mesos-dispatcher.sh +++ b/sbin/stop-mesos-dispatcher.sh @@ -24,5 +24,10 @@ fi . "${SPARK_HOME}/sbin/spark-config.sh" -"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 +if [ "$SPARK_MESOS_DISPATCHER_NUM" = "" ]; then + SPARK_MESOS_DISPATCHER_NUM=1 +fi + +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher \ + $SPARK_MESOS_DISPATCHER_NUM diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 63956377629d6..a57441b52a04a 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -25,9 +25,4 @@ fi . "${SPARK_HOME}/bin/load-spark-env.sh" -# do before the below calls as they exec -if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker -fi - "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/stop-slave.sh diff --git a/sbt/sbt b/sbt/sbt deleted file mode 100755 index 41438251f681e..0000000000000 --- a/sbt/sbt +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Determine the current working directory -_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -echo "NOTE: The sbt/sbt script has been relocated to build/sbt." >&2 -echo " Please update references to point to the new location." >&2 -echo "" >&2 -echo " Invoking 'build/sbt $@' now ..." >&2 -echo "" >&2 - -${_DIR}/../build/sbt "$@" diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 64a0c71bbef2a..a14e3e583f870 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -116,7 +116,7 @@ This file is divided into 3 sections: - + @@ -150,6 +150,37 @@ This file is divided into 3 sections: // scalastyle:on println]]> + + @VisibleForTesting + + + + + Runtime\.getRuntime\.addShutdownHook + + + + + mutable\.SynchronizedBuffer + + + Class\.forName - - - ^getConfiguration$|^getTaskAttemptID$ - Instead of calling .getConfiguration() or .getTaskAttemptID() directly, - use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. - + + + java,scala,3rdParty,spark + javax?\..* + scala\..* + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + + + COMMA + + + + + + \)\{ + + + + + (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] + Use Javadoc style indentation for multiline comments + + + + case[^\n>]*=>\s*\{ + Omit braces in case clauses. diff --git a/sql/README.md b/sql/README.md index 63d4dac9829e0..b0903980a59f3 100644 --- a/sql/README.md +++ b/sql/README.md @@ -5,7 +5,7 @@ This module provides support for executing relational queries expressed in eithe Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - - Execution (sql/core) - A query planner / execution engine for translating Catalyst’s logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. + - Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. @@ -20,7 +20,7 @@ If you are working with Hive 0.12.0, you will need to set several environmental ``` export HIVE_HOME="/hive/build/dist" export HIVE_DEV_HOME="/hive/" -export HADOOP_HOME="/hadoop-1.0.4" +export HADOOP_HOME="/hadoop" ``` If you are working with Hive 0.13.1, the following steps are needed: @@ -47,7 +47,7 @@ An interactive scala console can be invoked by running `build/sbt hive/console`. From here you can execute queries with HiveQl and manipulate DataFrame by using DSL. ```scala -catalyst$ build/sbt hive/console +$ build/sbt hive/console [info] Starting scala interpreter... import org.apache.spark.sql.catalyst.analysis._ @@ -61,22 +61,23 @@ import org.apache.spark.sql.execution import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.types._ Type in expressions to have them evaluated. Type :help for more information. scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.DataFrame = org.apache.spark.sql.DataFrame@74448eed +query: org.apache.spark.sql.DataFrame = [key: int, value: string] ``` Query results are `DataFrames` and can be operated as such. ``` scala> query.collect() -res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... +res0: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` You can also build further queries on top of these `DataFrames` using the query DSL. ``` scala> query.where(query("key") > 30).select(avg(query("key"))).collect() -res3: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) +res1: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) ``` diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 61d6fc63554bb..1748fa2778d6a 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-catalyst_2.10 + spark-catalyst_2.11 jar Spark Project Catalyst http://spark.apache.org/ @@ -71,6 +71,14 @@ org.codehaus.janino janino + + org.antlr + antlr4-runtime + + + commons-codec + commons-codec + target/scala-${scala.binary.version}/classes @@ -103,15 +111,21 @@ + + org.antlr + antlr4-maven-plugin + + + + antlr4 + + + + + true + ../catalyst/src/main/antlr4 + + - - - - scala-2.10 - - !scala-2.11 - - - diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 new file mode 100644 index 0000000000000..9cf2dd257e5c1 --- /dev/null +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -0,0 +1,957 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +grammar SqlBase; + +tokens { + DELIMITER +} + +singleStatement + : statement EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +statement + : query #statementDefault + | USE db=identifier #use + | CREATE DATABASE (IF NOT EXISTS)? identifier + (COMMENT comment=STRING)? locationSpec? + (WITH DBPROPERTIES tablePropertyList)? #createDatabase + | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties + | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase + | createTableHeader ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTableUsing + | createTableHeader tableProvider + (OPTIONS tablePropertyList)? AS? query #createTableUsing + | createTableHeader ('(' columns=colTypeList ')')? + (COMMENT STRING)? + (PARTITIONED BY '(' partitionColumns=colTypeList ')')? + bucketSpec? skewSpec? + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier #createTableLike + | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq?)? #analyze + | ALTER (TABLE | VIEW) from=tableIdentifier + RENAME TO to=tableIdentifier #renameTable + | ALTER (TABLE | VIEW) tableIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER (TABLE | VIEW) tableIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER TABLE tableIdentifier bucketSpec #bucketTable + | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable + | ALTER TABLE tableIdentifier NOT SORTED #unsortTable + | ALTER TABLE tableIdentifier skewSpec #skewTable + | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable + | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable + | ALTER TABLE tableIdentifier + SET SKEWED LOCATION skewedLocationList #setTableSkewLocations + | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER VIEW tableIdentifier ADD (IF NOT EXISTS)? + partitionSpec+ #addTablePartition + | ALTER TABLE tableIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER TABLE from=tableIdentifier + EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition + | ALTER TABLE tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER VIEW tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions + | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition + | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition + | ALTER TABLE tableIdentifier partitionSpec? + SET FILEFORMAT fileFormat #setTableFileFormat + | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation + | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable + | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable + | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable + | ALTER TABLE tableIdentifier partitionSpec? + CHANGE COLUMN? oldName=identifier colType + (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn + | ALTER TABLE tableIdentifier partitionSpec? + ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns + | ALTER TABLE tableIdentifier partitionSpec? + REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns + | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? + (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable + | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier + identifierCommentList? (COMMENT STRING)? + (PARTITIONED ON identifierList)? + (TBLPROPERTIES tablePropertyList)? AS query #createView + | ALTER VIEW tableIdentifier AS? query #alterViewQuery + | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction + | EXPLAIN explainOption* statement #explain + | SHOW TABLES ((FROM | IN) db=identifier)? + (LIKE? pattern=STRING)? #showTables + | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases + | SHOW TBLPROPERTIES table=tableIdentifier + ('(' key=tablePropertyKey ')')? #showTblProperties + | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions + | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction + | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? + tableIdentifier partitionSpec? describeColName? #describeTable + | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase + | REFRESH TABLE tableIdentifier #refreshTable + | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable + | UNCACHE TABLE identifier #uncacheTable + | CLEAR CACHE #clearCache + | ADD identifier .*? #addResource + | SET ROLE .*? #failNativeCommand + | SET .*? #setConfiguration + | kws=unsupportedHiveNativeCommands .*? #failNativeCommand + | hiveNativeCommands #executeNativeCommand + ; + +hiveNativeCommands + : DELETE FROM tableIdentifier (WHERE booleanExpression)? + | TRUNCATE TABLE tableIdentifier partitionSpec? + (COLUMNS identifierList)? + | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? + | START TRANSACTION (transactionMode (',' transactionMode)*)? + | COMMIT WORK? + | ROLLBACK WORK? + | SHOW PARTITIONS tableIdentifier partitionSpec? + | DFS .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOAD) .*? + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? + | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=MSCK kw2=REPAIR kw3=TABLE + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +query + : ctes? queryNoWith + ; + +insertInto + : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? + | INSERT INTO TABLE? tableIdentifier partitionSpec? + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +describeColName + : identifier ('.' (identifier | STRING))* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=identifier AS? '(' queryNoWith ')' + ; + +tableProvider + : USING qualifiedName + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=STRING)? + ; + +tablePropertyKey + : looseIdentifier ('.' looseIdentifier)* + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +skewedLocation + : (constant | constantList) EQ STRING + ; + +skewedLocationList + : '(' skewedLocation (',' skewedLocation)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +queryNoWith + : insertInto? queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windows? + (LIMIT limit=expression)? + ; + +multiInsertQueryBody + : insertInto? + querySpecification + queryOrganization + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | TABLE tableIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' queryNoWith ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? + ; + +querySpecification + : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')' + | kind=MAP namedExpressionSeq + | kind=REDUCE namedExpressionSeq)) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + fromClause? + (WHERE where=booleanExpression)?) + | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) + lateralView* + (WHERE where=booleanExpression)? + aggregation? + (HAVING having=booleanExpression)? + windows?) + ; + +fromClause + : FROM relation (',' relation)* lateralView* + ; + +aggregation + : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : left=relation + ((CROSS | joinType) JOIN right=relation joinCriteria? + | NATURAL joinType JOIN right=relation + ) #joinRelation + | relationPrimary #relationDefault + ; + +joinType + : INNER? + | LEFT OUTER? + | LEFT SEMI + | RIGHT OUTER? + | FULL OUTER? + | LEFT? ANTI + ; + +joinCriteria + : ON booleanExpression + | USING '(' identifier (',' identifier)* ')' + ; + +sample + : TABLESAMPLE '(' + ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) + | (expression sampleType=ROWS) + | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) + ')' + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : identifier (',' identifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : identifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier (COMMENT STRING)? + ; + +relationPrimary + : tableIdentifier sample? (AS? identifier)? #tableName + | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery + | '(' relation ')' sample? (AS? identifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + ; + +inlineTable + : VALUES expression (',' expression)* (AS? identifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +tableIdentifier + : (db=identifier '.')? table=identifier + ; + +namedExpression + : expression (AS? (identifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +expression + : booleanExpression + ; + +booleanExpression + : predicated #booleanDefault + | NOT booleanExpression #logicalNot + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + | EXISTS '(' query ')' #exists + ; + +// workaround for: +// https://github.com/antlr/antlr4/issues/780 +// https://github.com/antlr/antlr4/issues/781 +predicated + : valueExpression predicate? + ; + +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? kind=IN '(' query ')' + | NOT? kind=(RLIKE | LIKE) pattern=valueExpression + | IS NOT? kind=NULL + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : constant #constantDefault + | ASTERISK #star + | qualifiedName '.' ASTERISK #star + | '(' expression (',' expression)+ ')' #rowConstructor + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | '(' query ')' #subqueryExpression + | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL intervalField* + ; + +intervalField + : value=intervalValue unit=identifier (TO to=identifier)? + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE) + | STRING + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : identifier ':'? dataType (COMMENT STRING)? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windows + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : identifier AS windowSpec + ; + +windowSpec + : name=identifier #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? + ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + + +explainOption + : LOGICAL | FORMATTED | EXTENDED | CODEGEN + ; + +transactionMode + : ISOLATION LEVEL SNAPSHOT #isolationLevel + | READ accessMode=(ONLY | WRITE) #transactionAccessMode + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). +looseIdentifier + : identifier + | FROM + | TO + | TABLE + | WITH + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : DECIMAL_VALUE #decimalLiteral + | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral + | INTEGER_VALUE #integerLiteral + | BIGINT_LITERAL #bigIntLiteral + | SMALLINT_LITERAL #smallIntLiteral + | TINYINT_LITERAL #tinyIntLiteral + | DOUBLE_LITERAL #doubleLiteral + ; + +nonReserved + : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES + | ADD + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT + | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER + | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED + | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS + | GROUPING | CUBE | ROLLUP + | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN + | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF + | SET + | VIEW | REPLACE + | IF + | NO | DATA + | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL + | SNAPSHOT | READ | WRITE | ONLY + | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION + | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST + | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT + | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE + | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION + ; + +SELECT: 'SELECT'; +FROM: 'FROM'; +ADD: 'ADD'; +AS: 'AS'; +ALL: 'ALL'; +DISTINCT: 'DISTINCT'; +WHERE: 'WHERE'; +GROUP: 'GROUP'; +BY: 'BY'; +GROUPING: 'GROUPING'; +SETS: 'SETS'; +CUBE: 'CUBE'; +ROLLUP: 'ROLLUP'; +ORDER: 'ORDER'; +HAVING: 'HAVING'; +LIMIT: 'LIMIT'; +AT: 'AT'; +OR: 'OR'; +AND: 'AND'; +IN: 'IN'; +NOT: 'NOT' | '!'; +NO: 'NO'; +EXISTS: 'EXISTS'; +BETWEEN: 'BETWEEN'; +LIKE: 'LIKE'; +RLIKE: 'RLIKE' | 'REGEXP'; +IS: 'IS'; +NULL: 'NULL'; +TRUE: 'TRUE'; +FALSE: 'FALSE'; +NULLS: 'NULLS'; +ASC: 'ASC'; +DESC: 'DESC'; +FOR: 'FOR'; +INTERVAL: 'INTERVAL'; +CASE: 'CASE'; +WHEN: 'WHEN'; +THEN: 'THEN'; +ELSE: 'ELSE'; +END: 'END'; +JOIN: 'JOIN'; +CROSS: 'CROSS'; +OUTER: 'OUTER'; +INNER: 'INNER'; +LEFT: 'LEFT'; +SEMI: 'SEMI'; +RIGHT: 'RIGHT'; +FULL: 'FULL'; +NATURAL: 'NATURAL'; +ON: 'ON'; +LATERAL: 'LATERAL'; +WINDOW: 'WINDOW'; +OVER: 'OVER'; +PARTITION: 'PARTITION'; +RANGE: 'RANGE'; +ROWS: 'ROWS'; +UNBOUNDED: 'UNBOUNDED'; +PRECEDING: 'PRECEDING'; +FOLLOWING: 'FOLLOWING'; +CURRENT: 'CURRENT'; +ROW: 'ROW'; +WITH: 'WITH'; +VALUES: 'VALUES'; +CREATE: 'CREATE'; +TABLE: 'TABLE'; +VIEW: 'VIEW'; +REPLACE: 'REPLACE'; +INSERT: 'INSERT'; +DELETE: 'DELETE'; +INTO: 'INTO'; +DESCRIBE: 'DESCRIBE'; +EXPLAIN: 'EXPLAIN'; +FORMAT: 'FORMAT'; +LOGICAL: 'LOGICAL'; +CODEGEN: 'CODEGEN'; +CAST: 'CAST'; +SHOW: 'SHOW'; +TABLES: 'TABLES'; +COLUMNS: 'COLUMNS'; +COLUMN: 'COLUMN'; +USE: 'USE'; +PARTITIONS: 'PARTITIONS'; +FUNCTIONS: 'FUNCTIONS'; +DROP: 'DROP'; +UNION: 'UNION'; +EXCEPT: 'EXCEPT'; +INTERSECT: 'INTERSECT'; +TO: 'TO'; +TABLESAMPLE: 'TABLESAMPLE'; +STRATIFY: 'STRATIFY'; +ALTER: 'ALTER'; +RENAME: 'RENAME'; +ARRAY: 'ARRAY'; +MAP: 'MAP'; +STRUCT: 'STRUCT'; +COMMENT: 'COMMENT'; +SET: 'SET'; +DATA: 'DATA'; +START: 'START'; +TRANSACTION: 'TRANSACTION'; +COMMIT: 'COMMIT'; +ROLLBACK: 'ROLLBACK'; +WORK: 'WORK'; +ISOLATION: 'ISOLATION'; +LEVEL: 'LEVEL'; +SNAPSHOT: 'SNAPSHOT'; +READ: 'READ'; +WRITE: 'WRITE'; +ONLY: 'ONLY'; +MACRO: 'MACRO'; + +IF: 'IF'; + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<='; +GT : '>'; +GTE : '>='; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +DIV: 'DIV'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +HAT: '^'; + +PERCENTLIT: 'PERCENT'; +BUCKET: 'BUCKET'; +OUT: 'OUT'; +OF: 'OF'; + +SORT: 'SORT'; +CLUSTER: 'CLUSTER'; +DISTRIBUTE: 'DISTRIBUTE'; +OVERWRITE: 'OVERWRITE'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; +USING: 'USING'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +DELIMITED: 'DELIMITED'; +FIELDS: 'FIELDS'; +TERMINATED: 'TERMINATED'; +COLLECTION: 'COLLECTION'; +ITEMS: 'ITEMS'; +KEYS: 'KEYS'; +ESCAPED: 'ESCAPED'; +LINES: 'LINES'; +SEPARATED: 'SEPARATED'; +FUNCTION: 'FUNCTION'; +EXTENDED: 'EXTENDED'; +REFRESH: 'REFRESH'; +CLEAR: 'CLEAR'; +CACHE: 'CACHE'; +UNCACHE: 'UNCACHE'; +LAZY: 'LAZY'; +FORMATTED: 'FORMATTED'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +OPTIONS: 'OPTIONS'; +UNSET: 'UNSET'; +TBLPROPERTIES: 'TBLPROPERTIES'; +DBPROPERTIES: 'DBPROPERTIES'; +BUCKETS: 'BUCKETS'; +SKEWED: 'SKEWED'; +STORED: 'STORED'; +DIRECTORIES: 'DIRECTORIES'; +LOCATION: 'LOCATION'; +EXCHANGE: 'EXCHANGE'; +ARCHIVE: 'ARCHIVE'; +UNARCHIVE: 'UNARCHIVE'; +FILEFORMAT: 'FILEFORMAT'; +TOUCH: 'TOUCH'; +COMPACT: 'COMPACT'; +CONCATENATE: 'CONCATENATE'; +CHANGE: 'CHANGE'; +FIRST: 'FIRST'; +AFTER: 'AFTER'; +CASCADE: 'CASCADE'; +RESTRICT: 'RESTRICT'; +CLUSTERED: 'CLUSTERED'; +SORTED: 'SORTED'; +PURGE: 'PURGE'; +INPUTFORMAT: 'INPUTFORMAT'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +INPUTDRIVER: 'INPUTDRIVER'; +OUTPUTDRIVER: 'OUTPUTDRIVER'; +DATABASE: 'DATABASE' | 'SCHEMA'; +DATABASES: 'DATABASES' | 'SCHEMAS'; +DFS: 'DFS'; +TRUNCATE: 'TRUNCATE'; +METADATA: 'METADATA'; +REPLICATION: 'REPLICATION'; +ANALYZE: 'ANALYZE'; +COMPUTE: 'COMPUTE'; +STATISTICS: 'STATISTICS'; +PARTITIONED: 'PARTITIONED'; +EXTERNAL: 'EXTERNAL'; +DEFINED: 'DEFINED'; +REVOKE: 'REVOKE'; +GRANT: 'GRANT'; +LOCK: 'LOCK'; +UNLOCK: 'UNLOCK'; +MSCK: 'MSCK'; +REPAIR: 'REPAIR'; +EXPORT: 'EXPORT'; +IMPORT: 'IMPORT'; +LOAD: 'LOAD'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +COMPACTIONS: 'COMPACTIONS'; +PRINCIPALS: 'PRINCIPALS'; +TRANSACTIONS: 'TRANSACTIONS'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +LOCKS: 'LOCKS'; +OPTION: 'OPTION'; +ANTI: 'ANTI'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +SCIENTIFIC_DECIMAL_VALUE + : DIGIT+ ('.' DIGIT*)? EXPONENT + | '.' DIGIT+ EXPONENT + ; + +DOUBLE_LITERAL + : + (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' .*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 3513960b41813..648625b2cc5d2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -270,8 +271,8 @@ public UnsafeRow getStruct(int ordinal, int numFields) { final int offset = getElementOffset(ordinal); if (offset < 0) return null; final int size = getElementSize(offset, ordinal); - final UnsafeRow row = new UnsafeRow(); - row.pointTo(baseObject, baseOffset + offset, numFields, size); + final UnsafeRow row = new UnsafeRow(numFields); + row.pointTo(baseObject, baseOffset + offset, size); return row; } @@ -299,11 +300,7 @@ public UnsafeMapData getMap(int ordinal) { @Override public int hashCode() { - int result = 37; - for (int i = 0; i < sizeInBytes; i++) { - result = 37 * result + Platform.getByte(baseObject, baseOffset + i); - } - return result; + return Murmur3_x86_32.hashUnsafeBytes(baseObject, baseOffset, sizeInBytes, 42); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5ba14ebdb62a4..dd2f39eb816f2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -68,6 +68,10 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields + 63)/ 64) * 8; } + public static int calculateFixedPortionByteSize(int numFields) { + return 8 * numFields + calculateBitSetWidthInBytes(numFields); + } + /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ @@ -116,11 +120,6 @@ public static boolean isMutable(DataType dt) { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -140,8 +139,16 @@ private void assertIndexIsValid(int index) { /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, * since the value returned by this constructor is equivalent to a null pointer. + * + * @param numFields the number of fields in this row */ - public UnsafeRow() { } + public UnsafeRow(int numFields) { + this.numFields = numFields; + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + } + + // for serializer + public UnsafeRow() {} public Object getBaseObject() { return baseObject; } public long getBaseOffset() { return baseOffset; } @@ -155,15 +162,12 @@ public UnsafeRow() { } * * @param baseObject the base object * @param baseOffset the offset within the base object - * @param numFields the number of fields in this row * @param sizeInBytes the size of this row's backing data, in bytes */ - public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) { + public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; - this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; - this.numFields = numFields; this.sizeInBytes = sizeInBytes; } @@ -171,11 +175,19 @@ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeI * Update this UnsafeRow to point to the underlying byte array. * * @param buf byte array to point to - * @param numFields the number of fields in this row * @param sizeInBytes the number of bytes valid in the byte array */ - public void pointTo(byte[] buf, int numFields, int sizeInBytes) { - pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + public void pointTo(byte[] buf, int sizeInBytes) { + pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + } + + public void setTotalSize(int sizeInBytes) { + this.sizeInBytes = sizeInBytes; + } + + public void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); } @Override @@ -388,7 +400,7 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { return null; } if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); + return Decimal.createUnsafe(getLong(ordinal), precision, scale); } else { byte[] bytes = getBinary(ordinal); BigInteger bigInteger = new BigInteger(bytes); @@ -447,8 +459,8 @@ public UnsafeRow getStruct(int ordinal, int numFields) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - final UnsafeRow row = new UnsafeRow(); - row.pointTo(baseObject, baseOffset + offset, numFields, size); + final UnsafeRow row = new UnsafeRow(numFields); + row.pointTo(baseObject, baseOffset + offset, size); return row; } } @@ -487,7 +499,7 @@ public UnsafeMapData getMap(int ordinal) { */ @Override public UnsafeRow copy() { - UnsafeRow rowCopy = new UnsafeRow(); + UnsafeRow rowCopy = new UnsafeRow(numFields); final byte[] rowDataCopy = new byte[sizeInBytes]; Platform.copyMemory( baseObject, @@ -496,7 +508,7 @@ public UnsafeRow copy() { Platform.BYTE_ARRAY_OFFSET, sizeInBytes ); - rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return rowCopy; } @@ -505,8 +517,8 @@ public UnsafeRow copy() { * The returned row is invalid until we call copyFrom on it. */ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { - final UnsafeRow row = new UnsafeRow(); - row.pointTo(new byte[numBytes], numFields, numBytes); + final UnsafeRow row = new UnsafeRow(numFields); + row.pointTo(new byte[numBytes], numBytes); return row; } @@ -588,10 +600,9 @@ public byte[] getBytes() { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { + if (i != 0) build.append(','); build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); - build.append(','); } - build.deleteCharAt(build.length() - 1); build.append(']'); return build.toString(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java new file mode 100644 index 0000000000000..f37ef83ad92b4 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.unsafe.Platform; + +// scalastyle: off +/** + * xxHash64. A high quality and fast 64 bit hash code by Yann Colet and Mathias Westerdahl. The + * class below is modelled like its Murmur3_x86_32 cousin. + *

    + * This was largely based on the following (original) C and Java implementations: + * https://github.com/Cyan4973/xxHash/blob/master/xxhash.c + * https://github.com/OpenHFT/Zero-Allocation-Hashing/blob/master/src/main/java/net/openhft/hashing/XxHash_r39.java + * https://github.com/airlift/slice/blob/master/src/main/java/io/airlift/slice/XxHash64.java + */ +// scalastyle: on +public final class XXH64 { + + private static final long PRIME64_1 = 0x9E3779B185EBCA87L; + private static final long PRIME64_2 = 0xC2B2AE3D27D4EB4FL; + private static final long PRIME64_3 = 0x165667B19E3779F9L; + private static final long PRIME64_4 = 0x85EBCA77C2B2AE63L; + private static final long PRIME64_5 = 0x27D4EB2F165667C5L; + + private final long seed; + + public XXH64(long seed) { + super(); + this.seed = seed; + } + + @Override + public String toString() { + return "xxHash64(seed=" + seed + ")"; + } + + public long hashInt(int input) { + return hashInt(input, seed); + } + + public static long hashInt(int input, long seed) { + long hash = seed + PRIME64_5 + 4L; + hash ^= (input & 0xFFFFFFFFL) * PRIME64_1; + hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; + return fmix(hash); + } + + public long hashLong(long input) { + return hashLong(input, seed); + } + + public static long hashLong(long input, long seed) { + long hash = seed + PRIME64_5 + 8L; + hash ^= Long.rotateLeft(input * PRIME64_2, 31) * PRIME64_1; + hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; + return fmix(hash); + } + + public long hashUnsafeWords(Object base, long offset, int length) { + return hashUnsafeWords(base, offset, length, seed); + } + + public static long hashUnsafeWords(Object base, long offset, int length, long seed) { + assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long hash = hashBytesByWords(base, offset, length, seed); + return fmix(hash); + } + + public long hashUnsafeBytes(Object base, long offset, int length) { + return hashUnsafeBytes(base, offset, length, seed); + } + + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + assert (length >= 0) : "lengthInBytes cannot be negative"; + long hash = hashBytesByWords(base, offset, length, seed); + long end = offset + length; + offset += length & -8; + + if (offset + 4L <= end) { + hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; + hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; + offset += 4L; + } + + while (offset < end) { + hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; + hash = Long.rotateLeft(hash, 11) * PRIME64_1; + offset++; + } + return fmix(hash); + } + + private static long fmix(long hash) { + hash ^= hash >>> 33; + hash *= PRIME64_2; + hash ^= hash >>> 29; + hash *= PRIME64_3; + hash ^= hash >>> 32; + return hash; + } + + private static long hashBytesByWords(Object base, long offset, int length, long seed) { + long end = offset + length; + long hash; + if (length >= 32) { + long limit = end - 32; + long v1 = seed + PRIME64_1 + PRIME64_2; + long v2 = seed + PRIME64_2; + long v3 = seed; + long v4 = seed - PRIME64_1; + + do { + v1 += Platform.getLong(base, offset) * PRIME64_2; + v1 = Long.rotateLeft(v1, 31); + v1 *= PRIME64_1; + + v2 += Platform.getLong(base, offset + 8) * PRIME64_2; + v2 = Long.rotateLeft(v2, 31); + v2 *= PRIME64_1; + + v3 += Platform.getLong(base, offset + 16) * PRIME64_2; + v3 = Long.rotateLeft(v3, 31); + v3 *= PRIME64_1; + + v4 += Platform.getLong(base, offset + 24) * PRIME64_2; + v4 = Long.rotateLeft(v4, 31); + v4 *= PRIME64_1; + + offset += 32L; + } while (offset <= limit); + + hash = Long.rotateLeft(v1, 1) + + Long.rotateLeft(v2, 7) + + Long.rotateLeft(v3, 12) + + Long.rotateLeft(v4, 18); + + v1 *= PRIME64_2; + v1 = Long.rotateLeft(v1, 31); + v1 *= PRIME64_1; + hash ^= v1; + hash = hash * PRIME64_1 + PRIME64_4; + + v2 *= PRIME64_2; + v2 = Long.rotateLeft(v2, 31); + v2 *= PRIME64_1; + hash ^= v2; + hash = hash * PRIME64_1 + PRIME64_4; + + v3 *= PRIME64_2; + v3 = Long.rotateLeft(v3, 31); + v3 *= PRIME64_1; + hash ^= v3; + hash = hash * PRIME64_1 + PRIME64_4; + + v4 *= PRIME64_2; + v4 = Long.rotateLeft(v4, 31); + v4 *= PRIME64_1; + hash ^= v4; + hash = hash * PRIME64_1 + PRIME64_4; + } else { + hash = seed + PRIME64_5; + } + + hash += length; + + long limit = end - 8; + while (offset <= limit) { + long k1 = Platform.getLong(base, offset); + hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; + hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; + offset += 8L; + } + return hash; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 9c9468678065d..af61e2011f400 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -17,18 +17,43 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer used in `GenerateUnsafeProjection`. + * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and + * automatically re-point the unsafe row to it. * - * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables - * public for ease of use. + * This class can be used to build a one-pass unsafe row writing program, i.e. data will be written + * to the data buffer directly and no extra copy is needed. There should be only one instance of + * this class per writing program, so that the memory segment/data buffer can be reused. Note that + * for each incoming record, we should call `reset` of BufferHolder instance before write the record + * and reuse the data buffer. + * + * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update + * the size of the result row, after writing a record to the buffer. However, we can skip this step + * if the fields of row are all fixed-length, as the size of result row is also fixed. */ public class BufferHolder { - public byte[] buffer = new byte[64]; + public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; + private final UnsafeRow row; + private final int fixedSize; + + public BufferHolder(UnsafeRow row) { + this(row, 64); + } + + public BufferHolder(UnsafeRow row, int initialSize) { + this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields(); + this.buffer = new byte[fixedSize + initialSize]; + this.row = row; + this.row.pointTo(buffer, buffer.length); + } + /** + * Grows the buffer by at least neededSize and points the row to the buffer. + */ public void grow(int neededSize) { final int length = totalSize() + neededSize; if (buffer.length < length) { @@ -41,11 +66,12 @@ public void grow(int neededSize) { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; + row.pointTo(buffer, buffer.length); } } public void reset() { - cursor = Platform.BYTE_ARRAY_OFFSET; + cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } public int totalSize() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 048b7749d8fb4..4776617043878 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -26,27 +26,51 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A helper class to write data into global row buffer using `UnsafeRow` format, - * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. + * A helper class to write data into global row buffer using `UnsafeRow` format. + * + * It will remember the offset of row buffer which it starts to write, and move the cursor of row + * buffer while writing. If new data(can be the input record if this is the outermost writer, or + * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be + * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * `startingOffset` and clear out null bits. + * + * Note that if this is the outermost writer, which means we will always write from the very + * beginning of the global row buffer, we don't need to update `startingOffset` and can just call + * `zeroOutNullBytes` before writing new data. */ public class UnsafeRowWriter { - private BufferHolder holder; + private final BufferHolder holder; // The offset of the global buffer where we start to write this row. private int startingOffset; - private int nullBitsSize; + private final int nullBitsSize; + private final int fixedSize; - public void initialize(BufferHolder holder, int numFields) { + public UnsafeRowWriter(BufferHolder holder, int numFields) { this.holder = holder; - this.startingOffset = holder.cursor; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); + this.fixedSize = nullBitsSize + 8 * numFields; + this.startingOffset = holder.cursor; + } + + /** + * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null + * bits. This should be called before we write a new nested struct to the row buffer. + */ + public void reset() { + this.startingOffset = holder.cursor; // grow the global buffer to make sure it has enough space to write fixed-length data. - final int fixedSize = nullBitsSize + 8 * numFields; holder.grow(fixedSize); holder.cursor += fixedSize; - // zero-out the null bits region + zeroOutNullBytes(); + } + + /** + * Clears out null bits. This should be called before we write a new row to row buffer. + */ + public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { Platform.putLong(holder.buffer, startingOffset + i, 0L); } @@ -58,6 +82,8 @@ private void zeroOutPaddingBytes(int numBytes) { } } + public BufferHolder holder() { return holder; } + public boolean isNullAt(int ordinal) { return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index f7063d1e5c829..7784345a7a966 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -26,10 +26,9 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; @@ -37,7 +36,7 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; -final class UnsafeExternalRowSorter { +public final class UnsafeExternalRowSorter { /** * If positive, forces records to be spilled to disk at the given frequency (measured in numbers @@ -51,7 +50,7 @@ final class UnsafeExternalRowSorter { private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public static abstract class PrefixComputer { + public abstract static class PrefixComputer { abstract long computePrefix(InternalRow row); } @@ -68,6 +67,7 @@ public UnsafeExternalRowSorter( sorter = UnsafeExternalSorter.create( taskContext.taskMemoryManager(), sparkEnv.blockManager(), + sparkEnv.serializerManager(), taskContext, new RowComparator(ordering, schema.length()), prefixComparator, @@ -85,8 +85,7 @@ void setTestSpillFrequency(int frequency) { testSpillFrequency = frequency; } - @VisibleForTesting - void insertRow(UnsafeRow row) throws IOException { + public void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( row.getBaseObject(), @@ -111,8 +110,7 @@ private void cleanupResources() { sorter.cleanupResources(); } - @VisibleForTesting - Iterator sort() throws IOException { + public Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -123,7 +121,7 @@ Iterator sort() throws IOException { return new AbstractScalaRowIterator() { private final int numFields = schema.length(); - private UnsafeRow row = new UnsafeRow(); + private UnsafeRow row = new UnsafeRow(numFields); @Override public boolean hasNext() { @@ -137,7 +135,6 @@ public UnsafeRow next() { row.pointTo( sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), - numFields, sortedIterator.getRecordLength()); if (!hasNext()) { UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page @@ -154,7 +151,7 @@ public UnsafeRow next() { Platform.throwException(e); } throw new RuntimeException("Exception should have been re-thrown in next()"); - }; + } }; } catch (IOException e) { cleanupResources(); @@ -162,7 +159,6 @@ public UnsafeRow next() { } } - public Iterator sort(Iterator inputIterator) throws IOException { while (inputIterator.hasNext()) { insertRow(inputIterator.next()); @@ -170,29 +166,24 @@ public Iterator sort(Iterator inputIterator) throws IOExce return sort(); } - /** - * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. - */ - public static boolean supportsSchema(StructType schema) { - return UnsafeProjection.canSupport(schema); - } - private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; - private final UnsafeRow row1 = new UnsafeRow(); - private final UnsafeRow row2 = new UnsafeRow(); + private final UnsafeRow row1; + private final UnsafeRow row2; - public RowComparator(Ordering ordering, int numFields) { + RowComparator(Ordering ordering, int numFields) { this.numFields = numFields; + this.row1 = new UnsafeRow(numFields); + this.row2 = new UnsafeRow(numFields); this.ordering = ordering; } @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { // TODO: Why are the sizes -1? - row1.pointTo(baseObj1, baseOff1, numFields, -1); - row2.pointTo(baseObj2, baseOff2, numFields, -1); + row1.pointTo(baseObj1, baseOff1, -1); + row2.pointTo(baseObj2, baseOff2, -1); return ordering.compare(row1, row2); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 17659d7d960b0..24adeadf95675 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -201,7 +201,7 @@ public static StructType createStructType(StructField[] fields) { if (fields == null) { throw new IllegalArgumentException("fields should not be null."); } - Set distinctNames = new HashSet(); + Set distinctNames = new HashSet<>(); for (StructField field : fields) { if (field == null) { throw new IllegalArgumentException( diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index df64a878b6b36..1e4e5ede8cc11 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -41,5 +41,5 @@ * Returns an instance of the UserDefinedType which can serialize and deserialize the user * class to and from Catalyst built-in types. */ - Class > udt(); + Class> udt(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index f9992185a4563..d2003fd6892e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -18,6 +18,10 @@ package org.apache.spark.sql import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + + +// TODO: don't swallow original stack trace if it exists /** * :: DeveloperApi :: @@ -27,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, - val startPosition: Option[Int] = None) + val startPosition: Option[Int] = None, + val plan: Option[LogicalPlan] = None) extends Exception with Serializable { def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala new file mode 100644 index 0000000000000..ffa694fcdc07a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.annotation.implicitNotFound +import scala.reflect.ClassTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.types._ + + +/** + * :: Experimental :: + * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. + * + * == Scala == + * Encoders are generally created automatically through implicits from a `SQLContext`, or can be + * explicitly created by calling static methods on [[Encoders]]. + * + * {{{ + * import sqlContext.implicits._ + * + * val ds = Seq(1, 2, 3).toDS() // implicitly provided (sqlContext.implicits.newIntEncoder) + * }}} + * + * == Java == + * Encoders are specified by calling static methods on [[Encoders]]. + * + * {{{ + * List data = Arrays.asList("abc", "abc", "xyz"); + * Dataset ds = context.createDataset(data, Encoders.STRING()); + * }}} + * + * Encoders can be composed into tuples: + * + * {{{ + * Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); + * List> data2 = Arrays.asList(new scala.Tuple2(1, "a"); + * Dataset> ds2 = context.createDataset(data2, encoder2); + * }}} + * + * Or constructed from Java Beans: + * + * {{{ + * Encoders.bean(MyClass.class); + * }}} + * + * == Implementation == + * - Encoders are not required to be thread-safe and thus they do not need to use locks to guard + * against concurrent access if they reuse internal buffers to improve performance. + * + * @since 1.6.0 + */ +@Experimental +@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + + "(Int, String, etc) and Product types (case classes) are supported by importing " + + "sqlContext.implicits._ Support for serializing other types will be added in future " + + "releases.") +trait Encoder[T] extends Serializable { + + /** Returns the schema of encoding this type of object as a Row. */ + def schema: StructType + + /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ + def clsTag: ClassTag[T] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala new file mode 100644 index 0000000000000..3f4df704db755 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.reflect.Modifier + +import scala.reflect.{classTag, ClassTag} +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer} +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Methods for creating an [[Encoder]]. + * + * @since 1.6.0 + */ +@Experimental +object Encoders { + + /** + * An encoder for nullable boolean type. + * The Scala primitive encoder is available as [[scalaBoolean]]. + * @since 1.6.0 + */ + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + + /** + * An encoder for nullable byte type. + * The Scala primitive encoder is available as [[scalaByte]]. + * @since 1.6.0 + */ + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + + /** + * An encoder for nullable short type. + * The Scala primitive encoder is available as [[scalaShort]]. + * @since 1.6.0 + */ + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + + /** + * An encoder for nullable int type. + * The Scala primitive encoder is available as [[scalaInt]]. + * @since 1.6.0 + */ + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + + /** + * An encoder for nullable long type. + * The Scala primitive encoder is available as [[scalaLong]]. + * @since 1.6.0 + */ + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + + /** + * An encoder for nullable float type. + * The Scala primitive encoder is available as [[scalaFloat]]. + * @since 1.6.0 + */ + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + + /** + * An encoder for nullable double type. + * The Scala primitive encoder is available as [[scalaDouble]]. + * @since 1.6.0 + */ + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + + /** + * An encoder for nullable string type. + * + * @since 1.6.0 + */ + def STRING: Encoder[java.lang.String] = ExpressionEncoder() + + /** + * An encoder for nullable decimal type. + * + * @since 1.6.0 + */ + def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() + + /** + * An encoder for nullable date type. + * + * @since 1.6.0 + */ + def DATE: Encoder[java.sql.Date] = ExpressionEncoder() + + /** + * An encoder for nullable timestamp type. + * + * @since 1.6.0 + */ + def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + + /** + * An encoder for arrays of bytes. + * + * @since 1.6.1 + */ + def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() + + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * supported types for java bean field: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal + * - time related: java.sql.Date, java.sql.Timestamp + * - collection types: only array and java.util.List currently, map support is in progress + * - nested java bean. + * + * @since 1.6.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + + /** + * Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + serializer = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + deserializer = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } + + /** + * An encoder for 2-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2]( + e1: Encoder[T1], + e2: Encoder[T2]): Encoder[(T1, T2)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) + } + + /** + * An encoder for 3-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) + } + + /** + * An encoder for 4-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3, T4]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) + } + + /** + * An encoder for 5-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3, T4, T5]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4], + e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + ExpressionEncoder.tuple( + encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) + } + + /** + * An encoder for Scala's product type (tuples, case classes, etc). + * @since 2.0.0 + */ + def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive int type. + * @since 2.0.0 + */ + def scalaInt: Encoder[Int] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive long type. + * @since 2.0.0 + */ + def scalaLong: Encoder[Long] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive double type. + * @since 2.0.0 + */ + def scalaDouble: Encoder[Double] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive float type. + * @since 2.0.0 + */ + def scalaFloat: Encoder[Float] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive byte type. + * @since 2.0.0 + */ + def scalaByte: Encoder[Byte] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive short type. + * @since 2.0.0 + */ + def scalaShort: Encoder[Short] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive boolean type. + * @since 2.0.0 + */ + def scalaBoolean: Encoder[Boolean] = ExpressionEncoder() + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ed2fdf9f2f7cf..1219d4d453e13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.util.hashing.MurmurHash3 -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -152,7 +151,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def apply(i: Int): Any = get(i) @@ -177,7 +176,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def get(i: Int): Any @@ -191,7 +190,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getBoolean(i: Int): Boolean = getAs[Boolean](i) + def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i) /** * Returns the value at position i as a primitive byte. @@ -199,7 +198,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getByte(i: Int): Byte = getAs[Byte](i) + def getByte(i: Int): Byte = getAnyValAs[Byte](i) /** * Returns the value at position i as a primitive short. @@ -207,7 +206,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getShort(i: Int): Short = getAs[Short](i) + def getShort(i: Int): Short = getAnyValAs[Short](i) /** * Returns the value at position i as a primitive int. @@ -215,7 +214,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getInt(i: Int): Int = getAs[Int](i) + def getInt(i: Int): Int = getAnyValAs[Int](i) /** * Returns the value at position i as a primitive long. @@ -223,7 +222,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getLong(i: Int): Long = getAs[Long](i) + def getLong(i: Int): Long = getAnyValAs[Long](i) /** * Returns the value at position i as a primitive float. @@ -232,7 +231,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getFloat(i: Int): Float = getAs[Float](i) + def getFloat(i: Int): Float = getAnyValAs[Float](i) /** * Returns the value at position i as a primitive double. @@ -240,13 +239,12 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getDouble(i: Int): Double = getAs[Double](i) + def getDouble(i: Int): Double = getAnyValAs[Double](i) /** * Returns the value at position i as a String object. * * @throws ClassCastException when data type does not match. - * @throws NullPointerException when value is null. */ def getString(i: Int): String = getAs[String](i) @@ -306,10 +304,20 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getStruct(i: Int): Row = getAs[Row](i) + def getStruct(i: Int): Row = { + // Product and Row both are recognized as StructType in a Row + val t = get(i) + if (t.isInstanceOf[Product]) { + Row.fromTuple(t.asInstanceOf[Product]) + } else { + t.asInstanceOf[Row] + } + } /** * Returns the value at position i. + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws ClassCastException when data type does not match. */ @@ -317,6 +325,8 @@ trait Row extends Serializable { /** * Returns the value of a given fieldName. + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws UnsupportedOperationException when schema is not defined. * @throws IllegalArgumentException when fieldName do not exist. @@ -336,6 +346,8 @@ trait Row extends Serializable { /** * Returns a Map(name -> value) for the requested fieldNames + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws UnsupportedOperationException when schema is not defined. * @throws IllegalArgumentException when fieldName do not exist. @@ -450,4 +462,15 @@ trait Row extends Serializable { * start, end, and separator strings. */ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + + /** + * Returns the value of a given fieldName. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + private def getAnyValAs[T <: AnyVal](i: Int): T = + if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null") + else getAs[T](i) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala deleted file mode 100644 index 04ac4f20c66ec..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers -import scala.util.parsing.input.CharArrayReader.EofCh - -import org.apache.spark.sql.catalyst.plans.logical._ - -private[sql] abstract class AbstractSparkSQLParser - extends StandardTokenParsers with PackratParsers { - - def parse(input: String): LogicalPlan = synchronized { - // Initialize the Keywords. - initLexical - phrase(start)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) - } - } - /* One time initialization of lexical.This avoid reinitialization of lexical in parse method */ - protected lazy val initLexical: Unit = lexical.initialize(reservedWords) - - protected case class Keyword(str: String) { - def normalize: String = lexical.normalizeKeyword(str) - def parser: Parser[String] = normalize - } - - protected implicit def asParser(k: Keyword): Parser[String] = k.parser - - // By default, use Reflection to find the reserved words defined in the sub class. - // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this - // method during the parent class instantiation, because the sub class instance - // isn't created yet. - protected lazy val reservedWords: Seq[String] = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].normalize) - - // Set the keywords as empty by default, will change that later. - override val lexical = new SqlLexical - - protected def start: Parser[LogicalPlan] - - // Returns the whole input string - protected lazy val wholeInput: Parser[String] = new Parser[String] { - def apply(in: Input): ParseResult[String] = - Success(in.source.toString, in.drop(in.source.length())) - } - - // Returns the rest of the input string that are not parsed yet - protected lazy val restInput: Parser[String] = new Parser[String] { - def apply(in: Input): ParseResult[String] = - Success( - in.source.subSequence(in.offset, in.source.length()).toString, - in.drop(in.source.length())) - } -} - -class SqlLexical extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString: String = chars - } - - case class DecimalLit(chars: String) extends Token { - override def toString: String = chars - } - - /* This is a work around to support the lazy setting */ - def initialize(keywords: Seq[String]): Unit = { - reserved.clear() - reserved ++= keywords - } - - /* Normal the keyword string */ - def normalizeKeyword(str: String): String = str.toLowerCase - - delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" - ) - - protected override def processIdent(name: String) = { - val token = normalizeKeyword(name) - if (reserved contains token) Keyword(token) else Identifier(name) - } - - override lazy val token: Parser[Token] = - ( rep1(digit) ~ ('.' ~> digit.*).? ~ (exp ~> sign.? ~ rep1(digit)) ^^ { - case i ~ None ~ (sig ~ rest) => - DecimalLit(i.mkString + "e" + sig.mkString + rest.mkString) - case i ~ Some(d) ~ (sig ~ rest) => - DecimalLit(i.mkString + "." + d.mkString + "e" + sig.mkString + rest.mkString) - } - | digit.* ~ identChar ~ (identChar | digit).* ^^ - { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } - | rep1(digit) ~ ('.' ~> digit.*).? ^^ { - case i ~ None => NumericLit(i.mkString) - case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) - } - | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ - { case chars => StringLit(chars mkString "") } - | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ - { case chars => StringLit(chars mkString "") } - | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^ - { case chars => Identifier(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar: Parser[Elem] = letter | elem('_') - - private lazy val sign: Parser[Elem] = elem("s", c => c == '+' || c == '-') - private lazy val exp: Parser[Elem] = elem("e", c => c == 'E' || c == 'e') - - override def whitespace: Parser[Any] = - ( whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ chrExcept(EofCh, '\n').* - | '#' ~ chrExcept(EofCh, '\n').* - | '-' ~ '-' ~ chrExcept(EofCh, '\n').* - | '/' ~ '*' ~ failure("unclosed comment") - ).* -} - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 3f351b07b37df..2b98aacdd7264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -17,8 +17,25 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis._ + private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean + + def orderByOrdinal: Boolean + def groupByOrdinal: Boolean + + /** + * Returns the [[Resolver]] for the current configuration, which can be used to determine if two + * identifiers are equal. + */ + def resolver: Resolver = { + if (caseSensitiveAnalysis) { + caseSensitiveResolution + } else { + caseInsensitiveResolution + } + } } /** @@ -29,7 +46,19 @@ object EmptyConf extends CatalystConf { override def caseSensitiveAnalysis: Boolean = { throw new UnsupportedOperationException } + override def orderByOrdinal: Boolean = { + throw new UnsupportedOperationException + } + override def groupByOrdinal: Boolean = { + throw new UnsupportedOperationException + } } /** A CatalystConf that can be used for local testing. */ -case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf +case class SimpleCatalystConf( + caseSensitiveAnalysis: Boolean, + orderByOrdinal: Boolean = true, + groupByOrdinal: Boolean = true) + + extends CatalystConf { +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 2ec0ff53c89c0..9bfc381639140 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -136,16 +136,16 @@ object CatalystTypeConverters { override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType) } - private case class UDTConverter( - udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + private case class UDTConverter[A >: Null]( + udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] { // toCatalyst (it calls toCatalystImpl) will do null check. - override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) + override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue) - override def toScala(catalystValue: Any): Any = { + override def toScala(catalystValue: Any): A = { if (catalystValue == null) null else udt.deserialize(catalystValue) } - override def toScalaImpl(row: InternalRow, column: Int): Any = + override def toScalaImpl(row: InternalRow, column: Int): A = toScala(row.get(column, udt.sqlType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 88a457f87ce4e..6f9fbbbead474 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -17,29 +17,35 @@ package org.apache.spark.sql.catalyst -import java.beans.Introspector +import java.beans.{Introspector, PropertyDescriptor} import java.lang.{Iterable => JIterable} -import java.util.{Iterator => JIterator, Map => JMap} +import java.util.{Iterator => JIterator, List => JList, Map => JMap} import scala.language.existentials import com.google.common.reflect.TypeToken + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Type-inference utilities for POJOs and Java collections. */ -private [sql] object JavaTypeInference { +object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val listType = TypeToken.of(classOf[JList[_]]) private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType /** - * Infers the corresponding SQL data type of a JavaClean class. + * Infers the corresponding SQL data type of a JavaBean class. * @param beanClass Java type * @return (SQL data type, nullable) */ @@ -53,12 +59,13 @@ private [sql] object JavaTypeInference { * @return (SQL data type, nullable) */ private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) @@ -88,15 +95,14 @@ private [sql] object JavaTypeInference { (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] - val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) - val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) - val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyType, valueType) = mapKeyValueType(typeToken) val (keyDataType, _) = inferDataType(keyType) val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) case _ => + // TODO: we should only collect properties that have getter and setter. However, some tests + // pass in scala case class as java bean class which doesn't have getter and setter. val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") val fields = properties.map { property => @@ -108,11 +114,302 @@ private [sql] object JavaTypeInference { } } + private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + val beanInfo = Introspector.getBeanInfo(beanClass) + beanInfo.getPropertyDescriptors + .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + } + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] - val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) - val iteratorType = iterableSupertype.resolveType(iteratorReturnType) - val itemType = iteratorType.resolveType(nextReturnType) - itemType + val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSuperType.resolveType(iteratorReturnType) + iteratorType.resolveType(nextReturnType) + } + + private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSuperType.resolveType(keySetReturnType)) + val valueType = elementType(mapSuperType.resolveType(valuesReturnType)) + keyType -> valueType + } + + /** + * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping + * to a native type, an ObjectType is returned. + * + * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers. + */ + private def inferExternalType(cls: Class[_]): DataType = cls match { + case c if c == java.lang.Boolean.TYPE => BooleanType + case c if c == java.lang.Byte.TYPE => ByteType + case c if c == java.lang.Short.TYPE => ShortType + case c if c == java.lang.Integer.TYPE => IntegerType + case c if c == java.lang.Long.TYPE => LongType + case c if c == java.lang.Float.TYPE => FloatType + case c if c == java.lang.Double.TYPE => DoubleType + case c if c == classOf[Array[Byte]] => BinaryType + case _ => ObjectType(cls) + } + + /** + * Returns an expression that can be used to deserialize an internal row to an object of java bean + * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + */ + def deserializerFor(beanClass: Class[_]): Expression = { + deserializerFor(TypeToken.of(beanClass), None) + } + + private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + + typeToken.getRawType match { + case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + + case c if c == classOf[java.lang.Short] => + NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Integer] => + NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Long] => + NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Double] => + NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Byte] => + NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Float] => + NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Boolean] => + NewInstance(c, getPath :: Nil, ObjectType(c)) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(c), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(c), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case c if c == classOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case c if c.isArray => + val elementType = c.getComponentType + val primitiveMethod = elementType match { + case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") + case c if c == java.lang.Byte.TYPE => Some("toByteArray") + case c if c == java.lang.Short.TYPE => Some("toShortArray") + case c if c == java.lang.Integer.TYPE => Some("toIntArray") + case c if c == java.lang.Long.TYPE => Some("toLongArray") + case c if c == java.lang.Float.TYPE => Some("toFloatArray") + case c if c == java.lang.Double.TYPE => Some("toDoubleArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, ObjectType(c)) + }.getOrElse { + Invoke( + MapObjects( + p => deserializerFor(typeToken.getComponentType, Some(p)), + getPath, + inferDataType(elementType)._1), + "array", + ObjectType(c)) + } + + case c if listType.isAssignableFrom(typeToken) => + val et = elementType(typeToken) + val array = + Invoke( + MapObjects( + p => deserializerFor(et, Some(p)), + getPath, + inferDataType(et)._1), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + + case _ if mapType.isAssignableFrom(typeToken) => + val (keyType, valueType) = mapKeyValueType(typeToken) + val keyDataType = inferDataType(keyType)._1 + val valueDataType = inferDataType(valueType)._1 + + val keyData = + Invoke( + MapObjects( + p => deserializerFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(keyDataType)), + keyDataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => deserializerFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueDataType)), + valueDataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData.getClass, + ObjectType(classOf[JMap[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil) + + case other => + val properties = getJavaBeanProperties(other) + assert(properties.length > 0) + + val setters = properties.map { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val (_, nullable) = inferDataType(fieldType) + val constructor = deserializerFor(fieldType, Some(addToPath(fieldName))) + val setter = if (nullable) { + constructor + } else { + AssertNotNull(constructor, Seq("currently no type path record in java")) + } + p.getWriteMethod.getName -> setter + }.toMap + + val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) + val result = InitializeJavaBean(newInstance, setters) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(other)), + result + ) + } else { + result + } + } + } + + /** + * Returns an expression for serializing an object of the given type to an internal row. + */ + def serializerFor(beanClass: Class[_]): CreateNamedStruct = { + val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) + serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + } + + private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { + + def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { + val (dataType, nullable) = inferDataType(elementType) + if (ScalaReflection.isNativeType(dataType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType)) + } + } + + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + typeToken.getRawType match { + case c if c == classOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case c if c == classOf[java.math.BigDecimal] => + StaticInvoke( + Decimal.getClass, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case c if c == classOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + case c if c == classOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case c if c == classOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case c if c == classOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case c if c == classOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + + case _ if typeToken.isArray => + toCatalystArray(inputObject, typeToken.getComponentType) + + case _ if listType.isAssignableFrom(typeToken) => + toCatalystArray(inputObject, elementType(typeToken)) + + case _ if mapType.isAssignableFrom(typeToken) => + // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can + // not guarantee they have same iteration order(which is different from scala map). + // A possible solution is creating a new `MapObjects` that can iterate a map directly. + throw new UnsupportedOperationException("map type is not supported currently") + + case other => + val properties = getJavaBeanProperties(other) + if (properties.length > 0) { + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil + }) + } else { + throw new UnsupportedOperationException( + s"Cannot infer type for class ${other.getName} because it is not bean-compliant") + } + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala deleted file mode 100644 index e21d3c05464b6..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** - * Root class of SQL Parser Dialect, and we don't guarantee the binary - * compatibility for the future release, let's keep it as the internal - * interface for advanced user. - * - */ -@DeveloperApi -abstract class ParserDialect { - // this is the main function that will be implemented by sql parser. - def parse(sqlText: String): LogicalPlan -} - -/** - * Currently we support the default dialect named "sql", associated with the class - * [[DefaultParserDialect]] - * - * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: - * {{{ - *-- switch to "hiveql" dialect - * spark-sql>SET spark.sql.dialect=hiveql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- switch to "sql" dialect - * spark-sql>SET spark.sql.dialect=sql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- register the new SQL dialect - * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- register the non-exist SQL dialect - * spark-sql> SET spark.sql.dialect=NotExistedClass; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- Exception will be thrown and switch to dialect - *-- "sql" (for SQLContext) or - *-- "hiveql" (for HiveContext) - * }}} - */ -private[spark] class DefaultParserDialect extends ParserDialect { - @transient - protected val sqlParser = SqlParser - - override def parse(sqlText: String): LogicalPlan = { - sqlParser.parse(sqlText) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b8a8abd02d67..4795fc25576aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -30,22 +29,13 @@ import org.apache.spark.util.Utils */ object ScalaReflection extends ScalaReflection { val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - // Since we are creating a runtime mirror usign the class loader of current thread, + // Since we are creating a runtime mirror using the class loader of current thread, // we need to use def at here. So, every time we call mirror, it is using the // class loader of the current thread. - override def mirror: universe.Mirror = + // SPARK-13640: Synchronize this because universe.runtimeMirror is not thread-safe in Scala 2.10. + override def mirror: universe.Mirror = ScalaReflectionLock.synchronized { universe.runtimeMirror(Thread.currentThread().getContextClassLoader) -} - -/** - * Support for generating catalyst schemas for scala objects. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror + } import universe._ @@ -53,30 +43,6 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - case class Schema(dataType: DataType, nullable: Boolean) - - /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.toAttributes - } - - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } - - /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). - * - * @see SPARK-5281 - */ - private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe - /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. Special handling is also used for Arrays including @@ -85,37 +51,30 @@ trait ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor(tpe: `Type`): DataType = tpe match { - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< localTypeOf[Array[Byte]] => BinaryType - case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName - className match { - case "scala.Array" => - val TypeRef(_, _, Seq(arrayType)) = tpe - val cls = arrayType match { - case t if t <:< definitions.IntTpe => classOf[Array[Int]] - case t if t <:< definitions.LongTpe => classOf[Array[Long]] - case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] - case t if t <:< definitions.FloatTpe => classOf[Array[Float]] - case t if t <:< definitions.ShortTpe => classOf[Array[Short]] - case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] - case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] - case other => - // There is probably a better way to do this, but I couldn't find it... - val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls - java.lang.reflect.Array.newInstance(elementType, 1).getClass + def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - } - ObjectType(cls) - case other => ObjectType(Utils.classForName(className)) - } + private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + tpe match { + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT + case _ => + val className = getClassNameFromType(tpe) + className match { + case "scala.Array" => + val TypeRef(_, _, Seq(elementType)) = tpe + arrayClassFor(elementType) + case other => + val clazz = getClassFromType(tpe) + ObjectType(clazz) + } + } } /** @@ -123,7 +82,7 @@ trait ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - def arrayClassFor(tpe: `Type`): DataType = { + private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -142,109 +101,133 @@ trait ScalaReflection { } /** - * Returns an expression that can be used to construct an object of type `T` given an input - * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => true + case _ => false + } + + /** + * Returns an expression that can be used to deserialize an input row to an object of type `T` + * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. * * When used on a primitive type, the constructor will instead default to extracting the value * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling unbind/bind with a new schema. + * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) + def deserializerFor[T : TypeTag]: Expression = { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + deserializerFor(tpe, None, walkedTypePath) + } - protected def constructorFor( + private def deserializerFor( tpe: `Type`, - path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { + path: Option[Expression], + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String) = - path + def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { + val newPath = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType) = - path - .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) + def addToPathOrdinal( + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => GetStructField(p, ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = { + val dataType = schemaFor(tpe).dataType + if (path.isDefined) { + path.get + } else { + upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + } + } - /** Returns the current path or throws an error. */ - def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + /** + * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * don't need to cast struct type because there must be `UnresolvedExtractValue` or + * `GetStructField` wrapping it, thus we only need to handle leaf type. + */ + def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _ => UpCast(expr, expected, walkedTypePath) + } + val className = getClassNameFromType(tpe) tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => - getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - val boxedType = optType match { - // For primitive types we must manually box the primitive value. - case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer]) - case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long]) - case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double]) - case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float]) - case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short]) - case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte]) - case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean]) - case _ => None - } - - boxedType.map { boxedType => - val objectType = ObjectType(boxedType) - WrapOption( - objectType, - NewInstance( - boxedType, - getPath :: Nil, - propagateNull = true, - objectType)) - }.getOrElse { - val className: String = optType.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) - val objectType = ObjectType(cls) - - WrapOption(objectType, constructorFor(optType, path)) - } + val className = getClassNameFromType(optType) + val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath + WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", getPath :: Nil, @@ -252,7 +235,7 @@ trait ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", getPath :: Nil, @@ -264,11 +247,13 @@ trait ScalaReflection { case t if t <:< localTypeOf[java.math.BigDecimal] => Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) + // TODO: add runtime null check for primitive array val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -281,136 +266,103 @@ trait ScalaReflection { } primitiveMethod.map { method => - Invoke(getPath, method, dataTypeFor(t)) + Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { - val returnType = dataTypeFor(t) + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + MapObjects( + p => deserializerFor(elementType, Some(p), newTypePath), + getPath, + schemaFor(elementType).dataType), "array", - returnType) + arrayClassFor(elementType)) } - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val primitiveMethodKey = keyType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None + val mapFunction: Expression => Expression = p => { + val converter = deserializerFor(elementType, Some(p), newTypePath) + if (nullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } } + val array = Invoke( + MapObjects(mapFunction, getPath, dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray.getClass, + ObjectType(classOf[Seq[_]]), + "make", + array :: Nil) + + case t if t <:< localTypeOf[Map[_, _]] => + // TODO: add walked type path for map + val TypeRef(_, _, Seq(keyType, valueType)) = t + val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(keyDataType)), - keyDataType), + p => deserializerFor(keyType, Some(p), walkedTypePath), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + schemaFor(keyType).dataType), "array", ObjectType(classOf[Array[Any]])) - val primitiveMethodValue = valueType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(valueDataType)), - valueDataType), + p => deserializerFor(valueType, Some(p), walkedTypePath), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), "array", ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - // Avoid boxing when possible by just wrapping a primitive array. - val primitiveMethod = elementType match { - case _ if nullable => None - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - val arrayData = primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), - "array", - arrayClassFor(elementType)) - } - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } + val params = getConstructorParameters(t) - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) - - val arguments = params.head.zipWithIndex.map { case (p, i) => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = schemaFor(fieldType).dataType + val cls = getClassFromType(tpe) + val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => + val Schema(dataType, nullable) = schemaFor(fieldType) + val clsName = getClassNameFromType(fieldType) + val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + if (cls.getName startsWith "scala.Tuple") { + deserializerFor( + fieldType, + Some(addToPathOrdinal(i, dataType, newTypePath)), + newTypePath) } else { - constructorFor(fieldType, Some(addToPath(fieldName))) + val constructor = deserializerFor( + fieldType, + Some(addToPath(fieldName, dataType, newTypePath)), + newTypePath) + + if (!nullable) { + AssertNotNull(constructor, newTypePath) + } else { + constructor + } } } - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) if (path.nonEmpty) { expressions.If( @@ -422,26 +374,65 @@ trait ScalaReflection { newInstance } + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) } } - /** Returns expressions for extracting all the fields from the given type. */ - def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe) match { - case s: CreateNamedStruct => s - case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) - } + /** + * Returns an expression for serializing an object of type T to an internal row. + * + * If the given type is not supported, i.e. there is no encoder can be built for this type, + * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain + * the type path walked so far and which class we are not supporting. + * There are 4 kinds of type path: + * * the root type: `root class: "abc.xyz.MyClass"` + * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` + * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` + * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` + */ + def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + serializerFor(inputObject, tpe, walkedTypePath) match { + case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ - protected def extractorFor( + private def serializerFor( inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + tpe: `Type`, + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + + def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = silentSchemaFor(elementType) + if (isNativeType(externalDataType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + MapObjects(serializerFor(_, elementType, newPath), input, externalDataType) + } + } + if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { + val className = getClassNameFromType(tpe) tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -485,90 +476,67 @@ trait ScalaReflection { // For non-primitives, we can just extract the object from the Option and then recurse. case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName - val classObj = Utils.classForName(className) - val optionObjectType = ObjectType(classObj) - + val className = getClassNameFromType(optType) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath + + val optionObjectType: DataType = other match { + // Special handling is required for arrays, as getClassFromType() will fail + // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to + // the Java type "[I". + case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t) + case cls => ObjectType(getClassFromType(cls)) + } val unwrapped = UnwrapOption(optionObjectType, inputObject) + expressions.If( IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + serializerFor(unwrapped, optType, newPath)) } case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - CreateNamedStruct(params.head.flatMap { p => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val params = getConstructorParameters(t) + val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil }) + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (!elementDataType.isInstanceOf[AtomicType]) { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } else { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (dataType.isInstanceOf[AtomicType]) { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - val rawMap = inputObject val keys = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + val values = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) NewInstance( classOf[ArrayBasedMapData], - keys :: values :: Nil, + convertedKeys :: convertedValues :: Nil, dataType = MapType(keyDataType, valueDataType, valueNullable)) case t if t <:< localTypeOf[String] => @@ -580,27 +548,28 @@ trait ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) + case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) @@ -620,33 +589,109 @@ trait ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) - case t if t <:< definitions.IntTpe => - BoundReference(0, IntegerType, false) - case t if t <:< definitions.LongTpe => - BoundReference(0, LongType, false) - case t if t <:< definitions.DoubleTpe => - BoundReference(0, DoubleType, false) - case t if t <:< definitions.FloatTpe => - BoundReference(0, FloatType, false) - case t if t <:< definitions.ShortTpe => - BoundReference(0, ShortType, false) - case t if t <:< definitions.ByteTpe => - BoundReference(0, ByteType, false) - case t if t <:< definitions.BooleanTpe => - BoundReference(0, BooleanType, false) + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) case other => - throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } } } + /** + * Returns the parameter names and types for the primary constructor of this class. + * + * Note that it only works for scala classes with primary constructor, and currently doesn't + * support inner class. + */ + def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + val t = classSymbol.selfType + getConstructorParameters(t) + } + + /** + * Returns the parameter names for the primary constructor of this class. + * + * Logically we should call `getConstructorParameters` and throw away the parameter types to get + * parameter names, however there are some weird scala reflection problems and this method is a + * workaround to avoid getting parameter types. + */ + def getConstructorParameterNames(cls: Class[_]): Seq[String] = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + val t = classSymbol.selfType + constructParams(t).map(_.name.toString) + } + + /* + * Retrieves the runtime class corresponding to the provided type. + */ + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) +} + +/** + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + case class Schema(dataType: DataType, nullable: Boolean) + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case Schema(s: StructType, _) => + s.toAttributes + } + + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) + + /** + * Return the Scala Type for `T` in the current classloader mirror. + * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10. + def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized { + val tag = implicitly[TypeTag[T]] + tag.in(mirror).tpe.normalize + } + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) + tpe match { + case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection, // whereas className is from Scala reflection. This can make it hard to find classes // in some cases, such as when a class is enclosed in an object (in which case @@ -672,26 +717,11 @@ trait ScalaReflection { Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( - s => s.isMethod && s.asMethod.isPrimaryConstructor) - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } + val params = getConstructorParameters(t) Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) + params.map { case (fieldName, fieldType) => + val Schema(dataType, nullable) = schemaFor(fieldType) + StructField(fieldName, dataType, nullable) }), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) @@ -719,38 +749,70 @@ trait ScalaReflection { } } - def typeOfObject: PartialFunction[Any, DataType] = { - // The data type can be determined without ambiguity. - case obj: Boolean => BooleanType - case obj: Array[Byte] => BinaryType - case obj: String => StringType - case obj: UTF8String => StringType - case obj: Byte => ByteType - case obj: Short => ShortType - case obj: Int => IntegerType - case obj: Long => LongType - case obj: Float => FloatType - case obj: Double => DoubleType - case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case obj: Decimal => DecimalType.SYSTEM_DEFAULT - case obj: java.sql.Timestamp => TimestampType - case null => NullType - // For other cases, there is no obvious mapping from the type of the given object to a - // Catalyst data type. A user should provide his/her specific rules - // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of - // objects and then compose the user-defined PartialFunction with this one. + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * + * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return + * `NullType` silently instead. + */ + def silentSchemaFor(tpe: `Type`): Schema = try { + schemaFor(tpe) + } catch { + case _: UnsupportedOperationException => Schema(NullType, nullable = true) + } + + /** + * Returns the full class name for a type. The returned name is the canonical + * Scala name, where each component is separated by a period. It is NOT the + * Java-equivalent runtime name (no dollar signs). + * + * In simple cases, both the Scala and Java names are the same, however when Scala + * generates constructs that do not map to a Java equivalent, such as singleton objects + * or nested classes in package objects, it uses the dollar sign ($) to create + * synthetic classes, emulating behaviour in Java bytecode. + */ + def getClassNameFromType(tpe: `Type`): String = { + tpe.erasure.typeSymbol.asClass.fullName + } + + /** + * Returns classes of input parameters of scala function object. + */ + def getParameterTypes(func: AnyRef): Seq[Class[_]] = { + val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) + assert(methods.length == 1) + methods.head.getParameterTypes } - implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { + /** + * Returns the parameter names and types for the primary constructor of this type. + * + * Note that it only works for scala classes with primary constructor, and currently doesn't + * support inner class. + */ + def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { + val formalTypeArgs = tpe.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = tpe + constructParams(tpe).map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } - /** - * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation - * for the the data in the sequence. - */ - def asRelation: LocalRelation = { - val output = attributesFor[A] - LocalRelation.fromProduct(output, data) + protected def constructParams(tpe: Type): Seq[Symbol] = { + val constructorSymbol = tpe.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( + s => s.isMethod && s.asMethod.isPrimaryConstructor) + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } } + params.flatten } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala deleted file mode 100644 index 440e9e28fa783..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ /dev/null @@ -1,516 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import scala.language.implicitConversions - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval - -/** - * A very simple SQL parser. Based loosely on: - * https://github.com/stephentu/scala-sql-parser/blob/master/src/main/scala/parser.scala - * - * Limitations: - * - Only supports a very limited subset of SQL. - * - * This is currently included mostly for illustrative purposes. Users wanting more complete support - * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. - */ -object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - - def parseExpression(input: String): Expression = synchronized { - // Initialize the Keywords. - initLexical - phrase(projection)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) - } - } - - def parseTableIdentifier(input: String): TableIdentifier = synchronized { - // Initialize the Keywords. - initLexical - phrase(tableIdentifier)(new lexical.Scanner(input)) match { - case Success(ident, _) => ident - case failureOrError => sys.error(failureOrError.toString) - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ALL = Keyword("ALL") - protected val AND = Keyword("AND") - protected val APPROXIMATE = Keyword("APPROXIMATE") - protected val AS = Keyword("AS") - protected val ASC = Keyword("ASC") - protected val BETWEEN = Keyword("BETWEEN") - protected val BY = Keyword("BY") - protected val CASE = Keyword("CASE") - protected val CAST = Keyword("CAST") - protected val DESC = Keyword("DESC") - protected val DISTINCT = Keyword("DISTINCT") - protected val ELSE = Keyword("ELSE") - protected val END = Keyword("END") - protected val EXCEPT = Keyword("EXCEPT") - protected val FALSE = Keyword("FALSE") - protected val FROM = Keyword("FROM") - protected val FULL = Keyword("FULL") - protected val GROUP = Keyword("GROUP") - protected val HAVING = Keyword("HAVING") - protected val IN = Keyword("IN") - protected val INNER = Keyword("INNER") - protected val INSERT = Keyword("INSERT") - protected val INTERSECT = Keyword("INTERSECT") - protected val INTERVAL = Keyword("INTERVAL") - protected val INTO = Keyword("INTO") - protected val IS = Keyword("IS") - protected val JOIN = Keyword("JOIN") - protected val LEFT = Keyword("LEFT") - protected val LIKE = Keyword("LIKE") - protected val LIMIT = Keyword("LIMIT") - protected val NOT = Keyword("NOT") - protected val NULL = Keyword("NULL") - protected val ON = Keyword("ON") - protected val OR = Keyword("OR") - protected val ORDER = Keyword("ORDER") - protected val SORT = Keyword("SORT") - protected val OUTER = Keyword("OUTER") - protected val OVERWRITE = Keyword("OVERWRITE") - protected val REGEXP = Keyword("REGEXP") - protected val RIGHT = Keyword("RIGHT") - protected val RLIKE = Keyword("RLIKE") - protected val SELECT = Keyword("SELECT") - protected val SEMI = Keyword("SEMI") - protected val TABLE = Keyword("TABLE") - protected val THEN = Keyword("THEN") - protected val TRUE = Keyword("TRUE") - protected val UNION = Keyword("UNION") - protected val WHEN = Keyword("WHEN") - protected val WHERE = Keyword("WHERE") - protected val WITH = Keyword("WITH") - - protected lazy val start: Parser[LogicalPlan] = - start1 | insert | cte - - protected lazy val start1: Parser[LogicalPlan] = - (select | ("(" ~> select <~ ")")) * - ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } - | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } - | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} - | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) - - protected lazy val select: Parser[LogicalPlan] = - SELECT ~> DISTINCT.? ~ - repsep(projection, ",") ~ - (FROM ~> relations).? ~ - (WHERE ~> expression).? ~ - (GROUP ~ BY ~> rep1sep(expression, ",")).? ~ - (HAVING ~> expression).? ~ - sortType.? ~ - (LIMIT ~> expression).? ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(OneRowRelation) - val withFilter = f.map(Filter(_, base)).getOrElse(base) - val withProjection = g - .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter)) - .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter)) - val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) - val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) - val withOrder = o.map(_(withHaving)).getOrElse(withHaving) - val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder) - withLimit - } - - protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ { - case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o, false) - } - - protected lazy val cte: Parser[LogicalPlan] = - WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ { - case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) - } - - protected lazy val projection: Parser[Expression] = - expression ~ (AS.? ~> ident.?) ^^ { - case e ~ a => a.fold(e)(Alias(e, _)()) - } - - // Based very loosely on the MySQL Grammar. - // http://dev.mysql.com/doc/refman/5.0/en/join.html - protected lazy val relations: Parser[LogicalPlan] = - ( relation ~ rep1("," ~> relation) ^^ { - case r1 ~ joins => joins.foldLeft(r1) { case(lhs, r) => Join(lhs, r, Inner, None) } } - | relation - ) - - protected lazy val relation: Parser[LogicalPlan] = - joinedRelation | relationFactor - - protected lazy val relationFactor: Parser[LogicalPlan] = - ( tableIdentifier ~ (opt(AS) ~> opt(ident)) ^^ { - case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias) - } - | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } - ) - - protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ rep1(joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.?) ^^ { - case r1 ~ joins => - joins.foldLeft(r1) { case (lhs, jt ~ rhs ~ cond) => - Join(lhs, rhs, joinType = jt.getOrElse(Inner), cond) - } - } - - protected lazy val joinConditions: Parser[Expression] = - ON ~> expression - - protected lazy val joinType: Parser[JoinType] = - ( INNER ^^^ Inner - | LEFT ~ SEMI ^^^ LeftSemi - | LEFT ~ OUTER.? ^^^ LeftOuter - | RIGHT ~ OUTER.? ^^^ RightOuter - | FULL ~ OUTER.? ^^^ FullOuter - ) - - protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] = - ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, true, l) } - | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, false, l) } - ) - - protected lazy val ordering: Parser[Seq[SortOrder]] = - ( rep1sep(expression ~ direction.? , ",") ^^ { - case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) - } - ) - - protected lazy val direction: Parser[SortDirection] = - ( ASC ^^^ Ascending - | DESC ^^^ Descending - ) - - protected lazy val expression: Parser[Expression] = - orExpression - - protected lazy val orExpression: Parser[Expression] = - andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) }) - - protected lazy val andExpression: Parser[Expression] = - notExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) - - protected lazy val notExpression: Parser[Expression] = - NOT.? ~ comparisonExpression ^^ { case maybeNot ~ e => maybeNot.map(_ => Not(e)).getOrElse(e) } - - protected lazy val comparisonExpression: Parser[Expression] = - ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) } - | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) } - | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) } - | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) } - | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } - | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } - | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } - | termExpression ~ ("<=>" ~> termExpression) ^^ { case e1 ~ e2 => EqualNullSafe(e1, e2) } - | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { - case e ~ not ~ el ~ eu => - val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - not.fold(betweenExpr)(f => Not(betweenExpr)) - } - | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } - | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } - | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } - | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } - | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { - case e1 ~ e2 => In(e1, e2) - } - | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { - case e1 ~ e2 => Not(In(e1, e2)) - } - | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } - | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } - | termExpression - ) - - protected lazy val termExpression: Parser[Expression] = - productExpression * - ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) } - | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) } - ) - - protected lazy val productExpression: Parser[Expression] = - baseExpression * - ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) } - | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) } - | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) } - | "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAnd(e1, e2) } - | "|" ^^^ { (e1: Expression, e2: Expression) => BitwiseOr(e1, e2) } - | "^" ^^^ { (e1: Expression, e2: Expression) => BitwiseXor(e1, e2) } - ) - - protected lazy val function: Parser[Expression] = - ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => - if (lexical.normalizeKeyword(udfName) == "count") { - Count(Literal(1)) - } else { - throw new AnalysisException(s"invalid expression $udfName(*)") - } - } - | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } - | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => - lexical.normalizeKeyword(udfName) match { - case "sum" => SumDistinct(exprs.head) - case "count" => CountDistinct(exprs) - case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) - } - } - | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => - if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp) - } else { - throw new AnalysisException(s"invalid function approximate $udfName") - } - } - | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => - if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp, s.toDouble) - } else { - throw new AnalysisException(s"invalid function approximate($s) $udfName") - } - } - | CASE ~> whenThenElse ^^ CaseWhen - | CASE ~> expression ~ whenThenElse ^^ - { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } - ) - - protected lazy val whenThenElse: Parser[List[Expression]] = - rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { - case altPart ~ elsePart => - altPart.flatMap { case whenExpr ~ thenExpr => - Seq(whenExpr, thenExpr) - } ++ elsePart - } - - protected lazy val cast: Parser[Expression] = - CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { - case exp ~ t => Cast(exp, t) - } - - protected lazy val literal: Parser[Literal] = - ( numericLiteral - | booleanLiteral - | stringLit ^^ { case s => Literal.create(s, StringType) } - | intervalLiteral - | NULL ^^^ Literal.create(null, NullType) - ) - - protected lazy val booleanLiteral: Parser[Literal] = - ( TRUE ^^^ Literal.create(true, BooleanType) - | FALSE ^^^ Literal.create(false, BooleanType) - ) - - protected lazy val numericLiteral: Parser[Literal] = - ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } - | sign.? ~ unsignedFloat ^^ { - case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) - } - | sign.? ~ unsignedDecimal ^^ { - case s ~ d => Literal(toDecimalOrDouble(s.getOrElse("") + d)) - } - ) - - protected lazy val unsignedFloat: Parser[String] = - ( "." ~> numericLit ^^ { u => "0." + u } - | elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) - ) - - protected lazy val unsignedDecimal: Parser[String] = - ( "." ~> decimalLit ^^ { u => "0." + u } - | elem("scientific_notation", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) - ) - - def decimalLit: Parser[String] = - elem("scientific_notation", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) - - protected lazy val sign: Parser[String] = ("+" | "-") - - protected lazy val integral: Parser[String] = - sign.? ~ numericLit ^^ { case s ~ n => s.getOrElse("") + n } - - private def intervalUnit(unitName: String) = acceptIf { - case lexical.Identifier(str) => - val normalized = lexical.normalizeKeyword(str) - normalized == unitName || normalized == unitName + "s" - case _ => false - } {_ => "wrong interval unit"} - - protected lazy val month: Parser[Int] = - integral <~ intervalUnit("month") ^^ { case num => num.toInt } - - protected lazy val year: Parser[Int] = - integral <~ intervalUnit("year") ^^ { case num => num.toInt * 12 } - - protected lazy val microsecond: Parser[Long] = - integral <~ intervalUnit("microsecond") ^^ { case num => num.toLong } - - protected lazy val millisecond: Parser[Long] = - integral <~ intervalUnit("millisecond") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_MILLI - } - - protected lazy val second: Parser[Long] = - integral <~ intervalUnit("second") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_SECOND - } - - protected lazy val minute: Parser[Long] = - integral <~ intervalUnit("minute") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE - } - - protected lazy val hour: Parser[Long] = - integral <~ intervalUnit("hour") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_HOUR - } - - protected lazy val day: Parser[Long] = - integral <~ intervalUnit("day") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_DAY - } - - protected lazy val week: Parser[Long] = - integral <~ intervalUnit("week") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_WEEK - } - - private def intervalKeyword(keyword: String) = acceptIf { - case lexical.Identifier(str) => - lexical.normalizeKeyword(str) == keyword - case _ => false - } {_ => "wrong interval keyword"} - - protected lazy val intervalLiteral: Parser[Literal] = - ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~ - intervalKeyword("month") ^^ { case s => - Literal(CalendarInterval.fromYearMonthString(s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~ - intervalKeyword("second") ^^ { case s => - Literal(CalendarInterval.fromDayTimeString(s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("year", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("month", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("day", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("hour", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("minute", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("second", s)) - } - | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ - millisecond.? ~ microsecond.? ^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ - millisecond ~ microsecond => - if (!Seq(year, month, week, day, hour, minute, second, - millisecond, microsecond).exists(_.isDefined)) { - throw new AnalysisException( - "at least one time unit should be given for interval literal") - } - val months = Seq(year, month).map(_.getOrElse(0)).sum - val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) - .map(_.getOrElse(0L)).sum - Literal(new CalendarInterval(months, microseconds)) - } - ) - - private def toNarrowestIntegerType(value: String): Any = { - val bigIntValue = BigDecimal(value) - - bigIntValue match { - case v if bigIntValue.isValidInt => v.toIntExact - case v if bigIntValue.isValidLong => v.toLongExact - case v => v.underlying() - } - } - - private def toDecimalOrDouble(value: String): Any = { - val decimal = BigDecimal(value) - // follow the behavior in MS SQL Server - // https://msdn.microsoft.com/en-us/library/ms179899.aspx - if (value.contains('E') || value.contains('e')) { - decimal.doubleValue() - } else { - decimal.underlying() - } - } - - protected lazy val baseExpression: Parser[Expression] = - ( "*" ^^^ UnresolvedStar(None) - | (ident <~ "."). + <~ "*" ^^ { case target => UnresolvedStar(Option(target))} - | primary - ) - - protected lazy val signedPrimary: Parser[Expression] = - sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e } - - protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", { - case lexical.Identifier(str) => str - case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str - }) - - protected lazy val primary: PackratParser[Expression] = - ( literal - | expression ~ ("[" ~> expression <~ "]") ^^ - { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) } - | (expression <~ ".") ~ ident ^^ - { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) } - | cast - | "(" ~> expression <~ ")" - | function - | dotExpressionHeader - | signedPrimary - | "~" ~> expression ^^ BitwiseNot - | attributeName ^^ UnresolvedAttribute.quoted - ) - - protected lazy val dotExpressionHeader: Parser[Expression] = - (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) - } - - protected lazy val tableIdentifier: Parser[TableIdentifier] = - (ident <~ ".").? ~ ident ^^ { - case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala deleted file mode 100644 index 4d4e4ded99477..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -/** - * Identifies a `table` in `database`. If `database` is not defined, the current database is used. - */ -private[sql] case class TableIdentifier(table: String, database: Option[String]) { - def this(table: String) = this(table, None) - - override def toString: String = quotedString - - def quotedString: String = database.map(db => s"`$db`.`$table`").getOrElse(s"`$table`") - - def unquotedString: String = database.map(db => s"$db.$table").getOrElse(table) -} - -private[sql] object TableIdentifier { - def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 899ee67352df4..de40ddde1bdd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,33 +17,43 @@ package org.apache.spark.sql.catalyst.analysis +import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.IntegerIndex +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} +import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.types._ /** - * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing - * when all relations are already filled in and the analyzer needs only to resolve attribute - * references. + * A trivial [[Analyzer]] with an dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]]. + * Used for testing when all relations are already filled in and the analyzer needs only + * to resolve attribute references. */ object SimpleAnalyzer - extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(true)) + extends SimpleAnalyzer( + EmptyFunctionRegistry, + new SimpleCatalystConf(caseSensitiveAnalysis = true)) + +class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf) + extends Analyzer(new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), conf) /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and - * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and - * a [[FunctionRegistry]]. + * [[UnresolvedRelation]]s into fully typed objects using information in a + * [[SessionCatalog]] and a [[FunctionRegistry]]. */ class Analyzer( - catalog: Catalog, - registry: FunctionRegistry, + catalog: SessionCatalog, conf: CatalystConf, maxIterations: Int = 100) extends RuleExecutor[LogicalPlan] with CheckAnalysis { @@ -65,24 +75,36 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, - CTESubstitution :: - WindowsSubstitution :: - Nil : _*), + CTESubstitution, + WindowsSubstitution, + EliminateUnions), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: + ResolveDeserializer :: + ResolveNewInstance :: + ResolveUpCast :: ResolveGroupingAnalytics :: - ResolveSortReferences :: + ResolvePivot :: + ResolveOrdinalInOrderByAndGroupBy :: + ResolveMissingReferences :: ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveSubquery :: + ResolveWindowOrder :: + ResolveWindowFrame :: + ResolveNaturalAndUsingJoin :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + TimeWindowing :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, PullOutNondeterministic), + Batch("UDF", Once, + HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -106,10 +128,16 @@ class Analyzer( // see https://github.com/apache/spark/pull/4929#discussion_r27186638 for more info case u : UnresolvedRelation => val substituted = cteRelations.get(u.tableIdentifier.table).map { relation => - val withAlias = u.alias.map(Subquery(_, relation)) + val withAlias = u.alias.map(SubqueryAlias(_, relation)) withAlias.getOrElse(relation) } substituted.getOrElse(u) + case other => + // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE. + other transformExpressions { + case e: SubqueryExpression => + e.withNewPlan(substituteCTE(e.query, cteRelations)) + } } } } @@ -122,14 +150,12 @@ class Analyzer( // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { - case plan => plan.transformExpressions { + case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => val errorMessage = s"Window specification $windowName is not defined in the WINDOW clause." val windowSpecDefinition = - windowDefinitions - .get(windowName) - .getOrElse(failAnalysis(errorMessage)) + windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) WindowExpression(c, windowSpecDefinition) } } @@ -143,13 +169,14 @@ class Analyzer( private def assignAliases(exprs: Seq[NamedExpression]) = { exprs.zipWithIndex.map { case (expr, i) => - expr transform { - case u @ UnresolvedAlias(child) => child match { + expr transformUp { + case u @ UnresolvedAlias(child, optionalAliasName) => child match { case ne: NamedExpression => ne case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() - case other => Alias(other, s"_c$i")() + case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)() + case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))() } } }.asInstanceOf[Seq[NamedExpression]] @@ -162,8 +189,12 @@ class Analyzer( case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) - case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) => - g.withNewAggs(assignAliases(g.aggregations)) + case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => + g.copy(aggregations = assignAliases(g.aggregations)) + + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) + if child.resolved && hasUnresolvedAlias(groupByExprs) => + Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) @@ -197,53 +228,175 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + private def hasGroupingAttribute(expr: Expression): Boolean = { + expr.collectFirst { + case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u + }.isDefined + } + + private def hasGroupingFunction(e: Expression): Boolean = { + e.collectFirst { + case g: Grouping => g + case g: GroupingID => g + }.isDefined + } + + private def replaceGroupingFunc( + expr: Expression, + groupByExprs: Seq[Expression], + gid: Expression): Expression = { + expr transform { + case e: GroupingID => + if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { + gid + } else { + throw new AnalysisException( + s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + + s"grouping columns (${groupByExprs.mkString(",")})") + } + case Grouping(col: Expression) => + val idx = groupByExprs.indexOf(col) + if (idx >= 0) { + Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), + Literal(1)), ByteType) + } else { + throw new AnalysisException(s"Column of grouping ($col) can't be found " + + s"in grouping columns ${groupByExprs.mkString(",")}") + } + } + } + + // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. - case a: Cube => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) - case a: Rollup => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) - case x: GroupingSets => + case p if p.expressions.exists(hasGroupingAttribute) => + failAnalysis( + s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") + + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => + GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) + case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => + GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) + + // Ensure all the expressions have been resolved. + case x: GroupingSets if x.expressions.forall(_.resolved) => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - // We will insert another Projection if the GROUP BY keys contains the - // non-attribute expressions. And the top operators can references those - // expressions by its alias. - // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> - // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a - - // find all of the non-attribute expressions in the GROUP BY keys - val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() - - // The pair of (the original GROUP BY key, associated attribute) - val groupByExprPairs = x.groupByExprs.map(_ match { - case e: NamedExpression => (e, e.toAttribute) - case other => { - val alias = Alias(other, other.toString)() - nonAttributeGroupByExpressions += alias // add the non-attributes expression alias - (other, alias.toAttribute) - } - }) - // substitute the non-attribute expressions for aggregations. - val aggregation = x.aggregations.map(expr => expr.transformDown { - case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) - }.asInstanceOf[NamedExpression]) + // Expand works by setting grouping expressions to null as determined by the bitmasks. To + // prevent these null values from being used in an aggregate instead of the original value + // we need to create new aliases for all group by expressions that will only be used for + // the intended purpose. + val groupByAliases: Seq[Alias] = x.groupByExprs.map { + case e: NamedExpression => Alias(e, e.name)() + case other => Alias(other, other.toString)() + } - // substitute the group by expressions. - val newGroupByExprs = groupByExprPairs.map(_._2) + val nonNullBitmask = x.bitmasks.reduce(_ & _) - val child = if (nonAttributeGroupByExpressions.length > 0) { - // insert additional projection if contains the - // non-attribute expressions in the GROUP BY keys - Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) - } else { - x.child + val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => + a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0) + } + + val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => + // collect all the found AggregateExpression, so we can check an expression is part of + // any AggregateExpression or not. + val aggsBuffer = ArrayBuffer[Expression]() + // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. + def isPartOfAggregation(e: Expression): Boolean = { + aggsBuffer.exists(a => a.find(_ eq e).isDefined) + } + replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown { + // AggregateExpression should be computed on the unmodified value of its argument + // expressions, so we should not replace any references to grouping expression + // inside it. + case e: AggregateExpression => + aggsBuffer += e + e + case e if isPartOfAggregation(e) => e + case e => + val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) + if (index == -1) { + e + } else { + groupByAttributes(index) + } + }.asInstanceOf[NamedExpression] } Aggregate( - newGroupByExprs :+ VirtualColumn.groupingIdAttribute, - aggregation, - Expand(x.bitmasks, newGroupByExprs, gid, child)) + groupByAttributes :+ gid, + aggregations, + Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)) + + case f @ Filter(cond, child) if hasGroupingFunction(cond) => + val groupingExprs = findGroupingExprs(child) + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) + f.copy(condition = newCond) + + case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) => + val groupingExprs = findGroupingExprs(child) + val gid = VirtualColumn.groupingIdAttribute + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) + s.copy(order = newOrder) + } + + private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { + plan.collectFirst { + case a: Aggregate => + // this Aggregate should have grouping id as the last grouping key. + val gid = a.groupingExpressions.last + if (!gid.isInstanceOf[AttributeReference] + || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + a.groupingExpressions.take(a.groupingExpressions.length - 1) + }.getOrElse { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + } + } + + object ResolvePivot extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + val singleAgg = aggregates.size == 1 + val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => + def ifExpr(expr: Expression) = { + If(EqualTo(pivotColumn, value), expr, Literal(null)) + } + aggregates.map { aggregate => + val filteredAggregate = aggregate.transformDown { + // Assumption is the aggregate function ignores nulls. This is true for all current + // AggregateFunction's with the exception of First and Last in their default mode + // (which we handle) and possibly some Hive UDAF's. + case First(expr, _) => + First(ifExpr(expr), Literal(true)) + case Last(expr, _) => + Last(ifExpr(expr), Literal(true)) + case a: AggregateFunction => + a.withNewChildren(a.children.map(ifExpr)) + }.transform { + // We are duplicating aggregates that are now computing a different value for each + // pivot value. + // TODO: Don't construct the physical container until after analysis. + case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) + } + if (filteredAggregate.fastEquals(aggregate)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$aggregate'") + } + val name = if (singleAgg) value.toString else value + "_" + aggregate.sql + Alias(filteredAggregate, name)() + } + } + val newGroupByExprs = groupByExprs.map { + case UnresolvedAlias(e, _) => e + case e => e + } + Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) } } @@ -251,18 +404,18 @@ class Analyzer( * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ object ResolveRelations extends Rule[LogicalPlan] { - def getTable(u: UnresolvedRelation): LogicalPlan = { + private def getTable(u: UnresolvedRelation): LogicalPlan = { try { catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { case _: NoSuchTableException => - u.failAnalysis(s"Table not found: ${u.tableName}") + u.failAnalysis(s"Table or View not found: ${u.tableName}") } } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => - i.copy(table = EliminateSubQueries(getTable(u))) + i.copy(table = EliminateSubqueryAliases(getTable(u))) case u: UnresolvedRelation => try { getTable(u) @@ -280,20 +433,60 @@ class Analyzer( */ object ResolveReferences extends Rule[LogicalPlan] { /** - * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree - * rooted at each expression. + * Generate a new logical plan for the right child with different expression IDs + * for all conflicting attributes. */ - def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = { - exprs.flatMap { - case s: Star => s.expand(child, resolver) - case e => - e.transformDown { - case f1: UnresolvedFunction if containsStar(f1.children) => - f1.copy(children = f1.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) - } :: Nil + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + + s"between $left and $right") + + right.collect { + // Handle base relations that might appear more than once. + case oldVersion: MultiInstanceRelation + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.newInstance() + (oldVersion, newVersion) + + // Handle projects that create conflicting aliases. + case oldVersion @ Project(projectList, _) + if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) + + case oldVersion @ Aggregate(_, aggregateExpressions, _) + if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + + case oldVersion: Generate + if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + val newOutput = oldVersion.generatorOutput.map(_.newInstance()) + (oldVersion, oldVersion.copy(generatorOutput = newOutput)) + + case oldVersion @ Window(windowExpressions, _, _, child) + if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) + .nonEmpty => + (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) + } + // Only handle first case, others will be fixed on the next pass. + .headOption match { + case None => + /* + * No result implies that there is a logical plan node that produces new references + * that this rule cannot handle. When that is the case, there must be another rule + * that resolves these conflicts. Otherwise, the analysis will fail. + */ + right + case Some((oldRelation, newRelation)) => + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => + attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier) + } + } + newRight } } @@ -301,32 +494,18 @@ class Analyzer( case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. - case p @ Project(projectList, child) if containsStar(projectList) => - Project( - projectList.flatMap { - case s: Star => s.expand(child, resolver) - case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => - val newChildren = expandStarExpressions(args, child) - UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil - case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => - val newChildren = expandStarExpressions(args, child) - Alias(child = f.copy(children = newChildren), name)() :: Nil - case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - } - UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil - case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - } - UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil - case o => o :: Nil - }, - child) - + case p: Project if containsStar(p.projectList) => + p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) + // If the aggregate function argument contains Stars, expand it. + case a: Aggregate if containsStar(a.aggregateExpressions) => + if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) { + failAnalysis( + "Group by position: star is not allowed to use in the select list " + + "when using ordinals in group by") + } else { + a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) + } + // If the script transformation input contains Stars, expand it. case t: ScriptTransformation if containsStar(t.input) => t.copy( input = t.input.flatMap { @@ -334,87 +513,38 @@ class Analyzer( case o => o :: Nil } ) + case g: Generate if containsStar(g.generator.children) => + failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF") - // If the aggregate function argument contains Stars, expand it. - case a: Aggregate if containsStar(a.aggregateExpressions) => - val expanded = expandStarExpressions(a.aggregateExpressions, a.child) - .map(_.asInstanceOf[NamedExpression]) - a.copy(aggregateExpressions = expanded) - - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if !j.selfJoinResolved => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") - - right.collect { - // Handle base relations that might appear more than once. - case oldVersion: MultiInstanceRelation - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = oldVersion.newInstance() - (oldVersion, newVersion) - - // Handle projects that create conflicting aliases. - case oldVersion @ Project(projectList, _) - if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) - - case oldVersion @ Aggregate(_, aggregateExpressions, _) - if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) - - case oldVersion: Generate - if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => - val newOutput = oldVersion.generatorOutput.map(_.newInstance()) - (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - - case oldVersion @ Window(_, windowExpressions, _, _, child) - if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) - .nonEmpty => - (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) - } - // Only handle first case, others will be fixed on the next pass. - .headOption match { - case None => - /* - * No result implies that there is a logical plan node that produces new references - * that this rule cannot handle. When that is the case, there must be another rule - * that resolves these conflicts. Otherwise, the analysis will fail. - */ - j - case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) - } - } - j.copy(right = newRight) - } + // To resolve duplicate expression IDs for Join and Intersect + case j @ Join(left, right, _, _) if !j.duplicateResolved => + j.copy(right = dedupRight(left, right)) + case i @ Intersect(left, right) if !i.duplicateResolved => + i.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as - // we still have chance to resolve it based on grandchild + // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => - val newOrdering = resolveSortOrders(ordering, child, throws = false) + val newOrdering = + ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder]) Sort(newOrdering, global, child) // A special case for Generate, because the output of Generate should not be resolved by // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. - case g @ Generate(generator, join, outer, qualifier, output, child) - if child.resolved && !generator.resolved => - val newG = generator transformUp { - case u @ UnresolvedAttribute(nameParts) => - withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } - case UnresolvedExtractValue(child, fieldExpr) => - ExtractValue(child, fieldExpr, resolver) - } + case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g + + case g @ Generate(generator, join, outer, qualifier, output, child) => + val newG = resolveExpression(generator, child, throws = true) if (newG.fastEquals(generator)) { g } else { Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } + // Skips plan which contains deserializer expressions, as they should be resolved by another + // rule: ResolveDeserializer. + case plan if containsDeserializer(plan.expressions) => plan + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { @@ -431,7 +561,7 @@ class Analyzer( def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { - case a: Alias => Alias(a.child, a.name)() + case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated) case other => other } } @@ -440,30 +570,137 @@ class Analyzer( AttributeSet(projectList.collect { case a: Alias => a.toAttribute }) } + /** + * Build a project list for Project/Aggregate and expand the star if possible + */ + private def buildExpandedProjectList( + exprs: Seq[NamedExpression], + child: LogicalPlan): Seq[NamedExpression] = { + exprs.flatMap { + // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*") + case s: Star => s.expand(child, resolver) + // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b + case UnresolvedAlias(s: Star, _) => s.expand(child, resolver) + case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil + case o => o :: Nil + }.map(_.asInstanceOf[NamedExpression]) + } + /** * Returns true if `exprs` contains a [[Star]]. */ def containsStar(exprs: Seq[Expression]): Boolean = exprs.exists(_.collect { case _: Star => true }.nonEmpty) + + /** + * Expands the matching attribute.*'s in `child`'s output. + */ + def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = { + expr.transformUp { + case f1: UnresolvedFunction if containsStar(f1.children) => + f1.copy(children = f1.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case c: CreateStruct if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case c: CreateArray if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case p: Murmur3Hash if containsStar(p.children) => + p.copy(children = p.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + // count(*) has been replaced by count(1) + case o if containsStar(o.children) => + failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'") + } + } } - private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { - ordering.map { order => - // Resolve SortOrder in one round. - // If throws == false or the desired attribute doesn't exist - // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. - // Else, throw exception. - try { - val newOrder = order transformUp { - case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).getOrElse(u) - case UnresolvedExtractValue(child, fieldName) if child.resolved => - ExtractValue(child, fieldName, resolver) - } - newOrder.asInstanceOf[SortOrder] - } catch { - case a: AnalysisException if !throws => order + private def containsDeserializer(exprs: Seq[Expression]): Boolean = { + exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined) + } + + protected[sql] def resolveExpression( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false) = { + // Resolve expression in one round. + // If throws == false or the desired attribute doesn't exist + // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. + // Else, throw exception. + try { + expr transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) } + } catch { + case a: AnalysisException if !throws => expr + } + } + + /** + * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by + * clauses. This rule is to convert ordinal positions to the corresponding expressions in the + * select list. This support is introduced in Spark 2.0. + * + * - When the sort references or group by expressions are not integer but foldable expressions, + * just ignore them. + * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position + * numbers too. + * + * Before the release of Spark 2.0, the literals in order/sort by and group by clauses + * have no effect on the results. + */ + object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + // Replace the index with the related attribute for ORDER BY, + // which is a 1-base position of the projection list. + case s @ Sort(orders, global, child) + if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) => + val newOrders = orders map { + case s @ SortOrder(IntegerIndex(index), direction) => + if (index > 0 && index <= child.output.size) { + SortOrder(child.output(index - 1), direction) + } else { + throw new UnresolvedException(s, + s"Order/sort By position: $index does not exist " + + s"The Select List is indexed from 1 to ${child.output.size}") + } + case o => o + } + Sort(newOrders, global, child) + + // Replace the index with the corresponding expression in aggregateExpressions. The index is + // a 1-base position of aggregateExpressions, which is output columns (select expression) + case a @ Aggregate(groups, aggs, child) + if conf.groupByOrdinal && aggs.forall(_.resolved) && + groups.exists(IntegerIndex.unapply(_).nonEmpty) => + val newGroups = groups.map { + case IntegerIndex(index) if index > 0 && index <= aggs.size => + aggs(index - 1) match { + case e if ResolveAggregateFunctions.containsAggregate(e) => + throw new UnresolvedException(a, + s"Group by position: the '$index'th column in the select contains an " + + s"aggregate function: ${e.sql}. Aggregate functions are not allowed in GROUP BY") + case o => o + } + case IntegerIndex(index) => + throw new UnresolvedException(a, + s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.") + case o => o + } + Aggregate(newGroups, aggs, child) } } @@ -472,45 +709,106 @@ class Analyzer( * clause. This rule detects such queries and adds the required attributes to the original * projection, so that they will be available during sorting. Another projection is added to * remove these attributes after sorting. + * + * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ - object ResolveSortReferences extends Rule[LogicalPlan] { + object ResolveMissingReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s @ Sort(ordering, global, p @ Project(projectList, child)) - if !s.resolved && p.resolved => - val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) - - // If this rule was not a no-op, return the transformed plan, otherwise return the original. - if (missing.nonEmpty) { - // Add missing attributes and then project them away after the sort. - Project(p.output, - Sort(newOrdering, global, - Project(projectList ++ missing, child))) - } else { - logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") - s // Nothing we can do here. Return original plan. + // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions + case sa @ Sort(_, _, child: Aggregate) => sa + + case s @ Sort(order, _, child) if child.resolved => + try { + val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) + val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) + val missingAttrs = requiredAttrs -- child.outputSet + if (missingAttrs.nonEmpty) { + // Add missing attributes and then project them away after the sort. + Project(child.output, + Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) + } else if (newOrder != order) { + s.copy(order = newOrder) + } else { + s + } + } catch { + // Attempting to resolve it might fail. When this happens, return the original plan. + // Users will see an AnalysisException for resolution failure of missing attributes + // in Sort + case ae: AnalysisException => s + } + + case f @ Filter(cond, child) if child.resolved => + try { + val newCond = resolveExpressionRecursively(cond, child) + val requiredAttrs = newCond.references.filter(_.resolved) + val missingAttrs = requiredAttrs -- child.outputSet + if (missingAttrs.nonEmpty) { + // Add missing attributes and then project them away. + Project(child.output, + Filter(newCond, addMissingAttr(child, missingAttrs))) + } else if (newCond != cond) { + f.copy(condition = newCond) + } else { + f + } + } catch { + // Attempting to resolve it might fail. When this happens, return the original plan. + // Users will see an AnalysisException for resolution failure of missing attributes + case ae: AnalysisException => f } } /** - * Given a child and a grandchild that are present beneath a sort operator, try to resolve - * the sort ordering and returns it with a list of attributes that are missing from the - * child but are present in the grandchild. + * Add the missing attributes into projectList of Project/Window or aggregateExpressions of + * Aggregate. */ - def resolveAndFindMissing( - ordering: Seq[SortOrder], - child: LogicalPlan, - grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) - // Construct a set that contains all of the attributes that we need to evaluate the - // ordering. - val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved) - // Figure out which ones are missing from the projection, so that we can add them and - // remove them after the sort. - val missingInProject = requiredAttributes -- child.output - // It is important to return the new SortOrders here, instead of waiting for the standard - // resolving process as adding attributes to the project below can actually introduce - // ambiguity that was not present before. - (newOrdering, missingInProject.toSeq) + private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { + if (missingAttrs.isEmpty) { + return plan + } + plan match { + case p: Project => + val missing = missingAttrs -- p.child.outputSet + Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) + case a: Aggregate => + // all the missing attributes should be grouping expressions + // TODO: push down AggregateExpression + missingAttrs.foreach { attr => + if (!a.groupingExpressions.exists(_.semanticEquals(attr))) { + throw new AnalysisException(s"Can't add $attr to ${a.simpleString}") + } + } + val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs + a.copy(aggregateExpressions = newAggregateExpressions) + case g: Generate => + // If join is false, we will convert it to true for getting from the child the missing + // attributes that its child might have or could have. + val missing = missingAttrs -- g.child.outputSet + g.copy(join = true, child = addMissingAttr(g.child, missing)) + case u: UnaryNode => + u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) + case other => + throw new AnalysisException(s"Can't add $missingAttrs to $other") + } + } + + /** + * Resolve the expression on a specified logical plan and it's child (recursively), until + * the expression is resolved or meet a non-unary node or Subquery. + */ + @tailrec + private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { + val resolved = resolveExpression(expr, plan) + if (resolved.resolved) { + resolved + } else { + plan match { + case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => + resolveExpressionRecursively(resolved, u.child) + case other => resolved + } + } } } @@ -522,24 +820,30 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. + case u @ UnresolvedGenerator(name, children) => + withPosition(u) { + catalog.lookupFunction(name, children) match { + case generator: Generator => generator + case other => + failAnalysis(s"$name is expected to be a generator. However, " + + s"its class is ${other.getClass.getCanonicalName}, which is not a generator.") + } + } case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { - registry.lookupFunction(name, children) match { - // We get an aggregate function built based on AggregateFunction2 interface. - // So, we wrap it in AggregateExpression2. - case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) - // Currently, our old aggregate function interface supports SUM(DISTINCT ...) - // and COUTN(DISTINCT ...). - case sumDistinct: SumDistinct => sumDistinct - case countDistinct: CountDistinct => countDistinct - // DISTINCT is not meaningful with Max and Min. - case max: Max if isDistinct => max - case min: Min if isDistinct => min - // For other aggregate functions, DISTINCT keyword is not supported for now. - // Once we converted to the new code path, we will allow using DISTINCT keyword. - case other: AggregateExpression1 if isDistinct => - failAnalysis(s"$name does not support DISTINCT keyword.") - // If it does not have DISTINCT keyword, we will return it as is. + catalog.lookupFunction(name, children) match { + // DISTINCT is not meaningful for a Max or a Min. + case max: Max if isDistinct => + AggregateExpression(max, Complete, isDistinct = false) + case min: Min if isDistinct => + AggregateExpression(min, Complete, isDistinct = false) + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => wf + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) + // This function is not an aggregate function, just return the resolved one. case other => other } } @@ -547,6 +851,30 @@ class Analyzer( } } + /** + * This rule resolve subqueries inside expressions. + * + * Note: CTE are handled in CTESubstitution. + */ + object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper { + + private def hasSubquery(e: Expression): Boolean = { + e.find(_.isInstanceOf[SubqueryExpression]).isDefined + } + + private def hasSubquery(q: LogicalPlan): Boolean = { + q.expressions.exists(hasSubquery) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case q: LogicalPlan if q.childrenResolved && hasSubquery(q) => + q transformExpressions { + case e: SubqueryExpression if !e.query.resolved => + e.withNewPlan(execute(e.query)) + } + } + } + /** * Turns projections that contain aggregate expressions into aggregations. */ @@ -557,11 +885,17 @@ class Analyzer( } def containsAggregates(exprs: Seq[Expression]): Boolean = { - exprs.foreach(_.foreach { - case agg: AggregateExpression => return true - case _ => - }) - false + // Collect all Windowed Aggregate Expressions. + val windowedAggExprs = exprs.flatMap { expr => + expr.collect { + case WindowExpression(ae: AggregateExpression, _) => ae + } + }.toSet + + // Find the first Aggregate Expression that is not Windowed. + exprs.exists(_.collectFirst { + case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae + }.isDefined) } } @@ -577,32 +911,42 @@ class Analyzer( if aggregate.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause - val aggregatedCondition = - Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) - def resolvedAggregateFilter = - resolvedOperator - .asInstanceOf[Aggregate] - .aggregateExpressions.head - - // If resolution was successful and we see the filter has an aggregate in it, add it to - // the original aggregate operator. - if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { - val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs - - Project(aggregate.output, - Filter(resolvedAggregateFilter.toAttribute, - aggregate.copy(aggregateExpressions = aggExprsWithHaving))) - } else { - filter + try { + val aggregatedCondition = + Aggregate( + grouping, + Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + child) + val resolvedOperator = execute(aggregatedCondition) + def resolvedAggregateFilter = + resolvedOperator + .asInstanceOf[Aggregate] + .aggregateExpressions.head + + // If resolution was successful and we see the filter has an aggregate in it, add it to + // the original aggregate operator. + if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { + val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs + + Project(aggregate.output, + Filter(resolvedAggregateFilter.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } else { + filter + } + } catch { + // Attempting to resolve in the aggregate can result in ambiguity. When this happens, + // just return the original plan. + case ae: AnalysisException => filter } - case sort @ Sort(sortOrder, global, aggregate: Aggregate) - if aggregate.resolved => + case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { - val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + val aliasedOrdering = + unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true)) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -635,13 +979,19 @@ class Analyzer( } } + val sortOrdersMap = unresolvedSortOrders + .map(new TreeNodeRef(_)) + .zip(evaluatedOrderings) + .toMap + val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) + // Since we don't rely on sort.resolved as the stop condition for this rule, // we need to check this and prevent applying this rule multiple times - if (sortOrder == evaluatedOrderings) { + if (sortOrder == finalSortOrders) { sort } else { Project(aggregate.output, - Sort(evaluatedOrderings, global, + Sort(finalSortOrders, global, aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } } catch { @@ -651,7 +1001,7 @@ class Analyzer( } } - protected def containsAggregate(condition: Expression): Boolean = { + def containsAggregate(condition: Expression): Boolean = { condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } @@ -668,8 +1018,6 @@ class Analyzer( */ object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case g: Generate if ResolveReferences.containsStar(g.generator.children) => - failAnalysis("Cannot explode *, explode can only be applied on a specific column.") case p: Generate if !p.child.resolved || !p.generator.resolved => p case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -815,12 +1163,13 @@ class Analyzer( if (missingExpr.nonEmpty) { extractedExprBuffer += ne } - ne.toAttribute + // alias will be cleaned in the rule CleanupAliases + ne case e: Expression if e.foldable => e // No need to create an attribute reference if it will be evaluated as a Literal. case e: Expression => // For other expressions, we extract it and replace it with an AttributeReference (with - // an interal column name, e.g. "_w0"). + // an internal column name, e.g. "_w0"). val withName = Alias(e, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute @@ -828,26 +1177,37 @@ class Analyzer( // Now, we extract regular expressions from expressionsWithWindowFunctions // by using extractExpr. + val seenWindowAggregates = new ArrayBuffer[AggregateExpression] val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). - case wf : WindowFunction => - val newChildren = wf.children.map(extractExpr(_)) + case wf: WindowFunction => + val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) // Extracts expressions from the partition spec and order spec. case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) => - val newPartitionSpec = partitionSpec.map(extractExpr(_)) + val newPartitionSpec = partitionSpec.map(extractExpr) val newOrderSpec = orderSpec.map { so => val newChild = extractExpr(so.child) so.copy(child = newChild) } wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) + // Extract Windowed AggregateExpression + case we @ WindowExpression( + ae @ AggregateExpression(function, _, _, _), + spec: WindowSpecDefinition) => + val newChildren = function.children.map(extractExpr) + val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val newAgg = ae.copy(aggregateFunction = newFunction) + seenWindowAggregates += newAgg + WindowExpression(newAgg, spec) + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). - case agg: AggregateExpression => + case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => val withName = Alias(agg, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute @@ -925,7 +1285,6 @@ class Analyzer( // Set currentChild to the newly created Window operator. currentChild = Window( - currentChild.output, windowExpressions, partitionSpec, orderSpec, @@ -958,7 +1317,7 @@ class Analyzer( val withWindow = addWindow(windowExpressions, withFilter) // Finally, generate output columns according to the original projectList. - val finalProjectList = aggregateExprs.map (_.toAttribute) + val finalProjectList = aggregateExprs.map(_.toAttribute) Project(finalProjectList, withWindow) case p: LogicalPlan if !p.childrenResolved => p @@ -974,7 +1333,7 @@ class Analyzer( val withWindow = addWindow(windowExpressions, withAggregate) // Finally, generate output columns according to the original projectList. - val finalProjectList = aggregateExprs.map (_.toAttribute) + val finalProjectList = aggregateExprs.map(_.toAttribute) Project(finalProjectList, withWindow) // We only extract Window Expressions after all expressions of the Project @@ -989,7 +1348,7 @@ class Analyzer( val withWindow = addWindow(windowExpressions, withProject) // Finally, generate output columns according to the original projectList. - val finalProjectList = projectList.map (_.toAttribute) + val finalProjectList = projectList.map(_.toAttribute) Project(finalProjectList, withWindow) } } @@ -1015,7 +1374,7 @@ class Analyzer( leafNondeterministic.map { e => val ne = e match { case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")() + case _ => Alias(e, "_nondeterministic")(isGenerated = true) } new TreeNodeRef(e) -> ne } @@ -1027,15 +1386,240 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the + * null check. When user defines a UDF with primitive parameters, there is no way to tell if the + * primitive parameter is null or not, so here we assume the primitive input is null-propagatable + * and we should return null if the input is null. + */ + object HandleNullInputsForUDF extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. + + case p => p transformExpressionsUp { + + case udf @ ScalaUDF(func, _, inputs, _) => + val parameterTypes = ScalaReflection.getParameterTypes(func) + assert(parameterTypes.length == inputs.length) + + val inputsNullCheck = parameterTypes.zip(inputs) + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } + .filter { case (cls, _) => cls.isPrimitive } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + } + } + } + + /** + * Check and add proper window frames for all window functions. + */ + object ResolveWindowFrame extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, + WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + if wf.frame != UnspecifiedFrame && wf.frame != f => + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") + case WindowExpression(wf: WindowFunction, + s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + if wf.frame != UnspecifiedFrame => + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => + val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + we.copy(windowSpec = s.copy(frameSpecification = frame)) + } + } + } + + /** + * Check and add order to [[AggregateWindowFunction]]s. + */ + object ResolveWindowOrder extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"WindowFunction $wf requires window to be ordered") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) + } + } + } + + /** + * Removes natural or using joins by calculating output columns based on output from two sides, + * Then apply a Project on a normal Join to eliminate natural or using join. + */ + object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) + if left.resolved && right.resolved && j.duplicateResolved => + // Resolve the column names referenced in using clause from both the legs of join. + val lCols = usingCols.flatMap(col => left.resolveQuoted(col.name, resolver)) + val rCols = usingCols.flatMap(col => right.resolveQuoted(col.name, resolver)) + if ((lCols.length == usingCols.length) && (rCols.length == usingCols.length)) { + val joinNames = lCols.map(exp => exp.name) + commonNaturalJoinProcessing(left, right, joinType, joinNames, None) + } else { + j + } + case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => + // find common column names from both sides + val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) + commonNaturalJoinProcessing(left, right, joinType, joinNames, condition) + } + } + + private def commonNaturalJoinProcessing( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + joinNames: Seq[String], + condition: Option[Expression]) = { + val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) + val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) + val joinPairs = leftKeys.zip(rightKeys) + + val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And) + + // columns not in joinPairs + val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) + val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) + + // the output list looks like: join keys, columns from left, columns from right + val projectList = joinType match { + case LeftOuter => + leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) + case LeftExistence(_) => + leftKeys ++ lUniqueOutput + case RightOuter => + rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput + case FullOuter => + // in full outer join, joinCols should be non-null if there is. + val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() } + joinedCols ++ + lUniqueOutput.map(_.withNullability(true)) ++ + rUniqueOutput.map(_.withNullability(true)) + case Inner => + leftKeys ++ lUniqueOutput ++ rUniqueOutput + case _ => + sys.error("Unsupported natural join type " + joinType) + } + // use Project to trim unnecessary fields + Project(projectList, Join(left, right, joinType, newCondition)) + } + + /** + * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved + * to the given input attributes. + */ + object ResolveDeserializer extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + + case p => p transformExpressions { + case UnresolvedDeserializer(deserializer, inputAttributes) => + val inputs = if (inputAttributes.isEmpty) { + p.children.flatMap(_.output) + } else { + inputAttributes + } + val unbound = deserializer transform { + case b: BoundReference => inputs(b.ordinal) + } + resolveExpression(unbound, LocalRelation(inputs), throws = true) + } + } + } + + /** + * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being + * constructed is an inner class. + */ + object ResolveNewInstance extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + + case p => p transformExpressions { + case n: NewInstance if n.childrenResolved && !n.resolved => + val outer = OuterScopes.getOuterScope(n.cls) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + + "access to the scope that this class was defined in.\n" + + "Try moving this class out of its parent class.") + } + n.copy(outerPointer = Some(outer)) + } + } + } + + /** + * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate. + */ + object ResolveUpCast extends Rule[LogicalPlan] { + private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { + throw new AnalysisException(s"Cannot up cast ${from.sql} from " + + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + toPrecedence > 0 && fromPrecedence > toPrecedence + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + + case p => p transformExpressions { + case u @ UpCast(child, _, _) if !child.resolved => u + + case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => + fail(child, to, walkedTypePath) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => + fail(child, to, walkedTypePath) + case (from, to) if illegalNumericPrecedence(from, to) => + fail(child, to, walkedTypePath) + case (TimestampType, DateType) => + fail(child, DateType, walkedTypePath) + case (StringType, to: NumericType) => + fail(child, to, walkedTypePath) + case _ => Cast(child, dataType.asNullable) + } + } + } + } } /** - * Removes [[Subquery]] operators from the plan. Subqueries are only required to provide + * Removes [[SubqueryAlias]] operators from the plan. Subqueries are only required to provide * scoping information for attributes and can be removed once analysis is complete. */ -object EliminateSubQueries extends Rule[LogicalPlan] { +object EliminateSubqueryAliases extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case SubqueryAlias(_, child) => child + } +} + +/** + * Removes [[Union]] operators from the plan if it just has one child. + */ +object EliminateUnions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Subquery(_, child) => child + case Union(children) if children.size == 1 => children.head } } @@ -1063,7 +1647,7 @@ object CleanupAliases extends Rule[LogicalPlan] { def trimNonTopLevelAliases(e: Expression): Expression = e match { case a: Alias => - Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata) + a.withNewChildren(trimAliases(a.child) :: Nil) case other => trimAliases(other) } @@ -1077,12 +1661,18 @@ object CleanupAliases extends Rule[LogicalPlan] { val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) Aggregate(grouping.map(trimAliases), cleanedAggs, child) - case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + case w @ Window(windowExprs, partitionSpec, orderSpec, child) => val cleanedWindowExprs = windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) - Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), + Window(cleanedWindowExprs, partitionSpec.map(trimAliases), orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) + // Operators that operate on objects should only have expressions from encoders, which should + // never have extra aliases. + case o: ObjectOperator => o + case d: DeserializeToObject => d + case s: SerializeFromObject => s + case other => var stop = false other transformExpressionsDown { @@ -1096,3 +1686,92 @@ object CleanupAliases extends Rule[LogicalPlan] { } } } + +/** + * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to + * figure out how many windows a time column can map to, we over-estimate the number of windows and + * filter out the rows where the time column is not inside the time window. + */ +object TimeWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val WINDOW_START = "start" + private final val WINDOW_END = "end" + + /** + * Generates the logical plan for generating window ranges on a timestamp column. Without + * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many + * window ranges a timestamp will map to given all possible combinations of a window duration, + * slide duration and start time (offset). Therefore, we express and over-estimate the number of + * windows there may be, and filter the valid windows. We use last Project operator to group + * the window columns into a struct so they can be accessed as `window.start` and `window.end`. + * + * The windows are calculated as below: + * maxNumOverlapping <- ceil(windowDuration / slideDuration) + * for (i <- 0 until maxNumOverlapping) + * windowId <- ceil((timestamp - startTime) / slideDuration) + * windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime + * windowEnd <- windowStart + windowDuration + * return windowStart, windowEnd + * + * This behaves as follows for the given parameters for the time: 12:05. The valid windows are + * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the + * Filter operator. + * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m + * 11:55 - 12:07 + 11:52 - 12:04 x + * 12:00 - 12:12 + 11:57 - 12:09 + + * 12:05 - 12:17 + 12:02 - 12:14 + + * + * @param plan The logical plan + * @return the logical plan that will generate the time windows using the Expand operator, with + * the Filter operator for correctness and Project for usability. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val windowExpressions = + p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct. + + // Only support a single window expression for now + if (windowExpressions.size == 1 && + windowExpressions.head.timeColumn.resolved && + windowExpressions.head.checkInputDataTypes().isSuccess) { + val window = windowExpressions.head + val windowAttr = AttributeReference("window", window.dataType)() + + val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = Seq.tabulate(maxNumOverlapping + 1) { i => + val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / + window.slideDuration) + val windowStart = (windowId + i - maxNumOverlapping) * + window.slideDuration + window.startTime + val windowEnd = windowStart + window.windowDuration + + CreateNamedStruct( + Literal(WINDOW_START) :: windowStart :: + Literal(WINDOW_END) :: windowEnd :: Nil) + } + + val projections = windows.map(_ +: p.children.head.output) + + val filterExpr = + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) + + val expandedPlan = + Filter(filterExpr, + Expand(projections, windowAttr +: child.output, child)) + + val substitutedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + substitutedPlan.withNewChildren(expandedPlan :: Nil) + } else if (windowExpressions.size > 1) { + p.failAnalysis("Multiple time window expressions would result in a cartesian product " + + "of rows, therefore they are not currently not supported.") + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala deleted file mode 100644 index 8f4ce74a2ea38..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} - -/** - * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception - * as an AnalysisException with the correct position information. - */ -class NoSuchTableException extends Exception - -class NoSuchDatabaseException extends Exception - -/** - * An interface for looking up relations by name. Used by an [[Analyzer]]. - */ -trait Catalog { - - val conf: CatalystConf - - def tableExists(tableIdent: TableIdentifier): Boolean - - def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan - - /** - * Returns tuples of (tableName, isTemporary) for all tables in the given database. - * isTemporary is a Boolean value indicates if a table is a temporary or not. - */ - def getTables(databaseName: Option[String]): Seq[(String, Boolean)] - - def refreshTable(tableIdent: TableIdentifier): Unit - - def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit - - def unregisterTable(tableIdent: TableIdentifier): Unit - - def unregisterAllTables(): Unit - - /** - * Get the table name of TableIdentifier for temporary tables. - */ - protected def getTableName(tableIdent: TableIdentifier): String = { - // It is not allowed to specify database name for temporary tables. - // We check it here and throw exception if database is defined. - if (tableIdent.database.isDefined) { - throw new AnalysisException("Specifying database name or other qualifiers are not allowed " + - "for temporary tables. If the table name has dots (.) in it, please quote the " + - "table name with backticks (`).") - } - if (conf.caseSensitiveAnalysis) { - tableIdent.table - } else { - tableIdent.table.toLowerCase - } - } -} - -class SimpleCatalog(val conf: CatalystConf) extends Catalog { - private[this] val tables = new ConcurrentHashMap[String, LogicalPlan] - - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { - tables.put(getTableName(tableIdent), plan) - } - - override def unregisterTable(tableIdent: TableIdentifier): Unit = { - tables.remove(getTableName(tableIdent)) - } - - override def unregisterAllTables(): Unit = { - tables.clear() - } - - override def tableExists(tableIdent: TableIdentifier): Boolean = { - tables.containsKey(getTableName(tableIdent)) - } - - override def lookupRelation( - tableIdent: TableIdentifier, - alias: Option[String] = None): LogicalPlan = { - val tableName = getTableName(tableIdent) - val table = tables.get(tableName) - if (table == null) { - throw new NoSuchTableException - } - val tableWithQualifiers = Subquery(tableName, table) - - // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are - // properly qualified with this alias. - alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) - } - - override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - tables.keySet().asScala.map(_ -> true).toSeq - } - - override def refreshTable(tableIdent: TableIdentifier): Unit = { - throw new UnsupportedOperationException - } -} - -/** - * A trait that can be mixed in with other Catalogs allowing specific tables to be overridden with - * new logical plans. This can be used to bind query result to virtual tables, or replace tables - * with in-memory cached versions. Note that the set of overrides is stored in memory and thus - * lost when the JVM exits. - */ -trait OverrideCatalog extends Catalog { - private[this] val overrides = new ConcurrentHashMap[String, LogicalPlan] - - private def getOverriddenTable(tableIdent: TableIdentifier): Option[LogicalPlan] = { - if (tableIdent.database.isDefined) { - None - } else { - Option(overrides.get(getTableName(tableIdent))) - } - } - - abstract override def tableExists(tableIdent: TableIdentifier): Boolean = { - getOverriddenTable(tableIdent) match { - case Some(_) => true - case None => super.tableExists(tableIdent) - } - } - - abstract override def lookupRelation( - tableIdent: TableIdentifier, - alias: Option[String] = None): LogicalPlan = { - getOverriddenTable(tableIdent) match { - case Some(table) => - val tableName = getTableName(tableIdent) - val tableWithQualifiers = Subquery(tableName, table) - - // If an alias was specified by the lookup, wrap the plan in a sub-query so that attributes - // are properly qualified with this alias. - alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) - - case None => super.lookupRelation(tableIdent, alias) - } - } - - abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - overrides.keySet().asScala.map(_ -> true).toSeq ++ super.getTables(databaseName) - } - - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { - overrides.put(getTableName(tableIdent), plan) - } - - override def unregisterTable(tableIdent: TableIdentifier): Unit = { - if (tableIdent.database.isEmpty) { - overrides.remove(getTableName(tableIdent)) - } - } - - override def unregisterAllTables(): Unit = { - overrides.clear() - } -} - -/** - * A trivial catalog that returns an error when a relation is requested. Used for testing when all - * relations are already filled in and the analyzer needs only to resolve attribute references. - */ -object EmptyCatalog extends Catalog { - - override val conf: CatalystConf = EmptyConf - - override def tableExists(tableIdent: TableIdentifier): Boolean = { - throw new UnsupportedOperationException - } - - override def lookupRelation( - tableIdent: TableIdentifier, - alias: Option[String] = None): LogicalPlan = { - throw new UnsupportedOperationException - } - - override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - throw new UnsupportedOperationException - } - - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { - throw new UnsupportedOperationException - } - - override def unregisterTable(tableIdent: TableIdentifier): Unit = { - throw new UnsupportedOperationException - } - - override def unregisterAllTables(): Unit = { - throw new UnsupportedOperationException - } - - override def refreshTable(tableIdent: TableIdentifier): Unit = { - throw new UnsupportedOperationException - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 98d6637c0601b..d6a8c3eec81aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -50,45 +52,73 @@ trait CheckAnalysis { case p if p.analyzed => // Skip already analyzed sub-plans case u: UnresolvedRelation => - u.failAnalysis(s"Table not found: ${u.tableIdentifier}") + u.failAnalysis(s"Table or View not found: ${u.tableIdentifier}") case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => e.failAnalysis( - s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + s"cannot resolve '${e.sql}' due to data type mismatch: $message") } case c: Cast if !c.resolved => failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => - failAnalysis( - s"Could not resolve window function '$name'. " + - "Note that, using window functions currently requires a HiveContext") + case g: Grouping => + failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup") + case g: GroupingID => + failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") + + case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => + failAnalysis(s"Distinct window functions are not supported: $w") + + case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, + SpecifiedWindowFrame(frame, + FrameBoundary(l), + FrameBoundary(h)))) + if order.isEmpty || frame != RowFrame || l != h => + failAnalysis("An offset window function can only be evaluated in an ordered " + + s"row-based window frame with a single offset: $w") + + case w @ WindowExpression(e, s) => + // Only allow window functions with an aggregate expression or an offset window + // function. + e match { + case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => + case _ => + failAnalysis(s"Expression '$e' not supported within a window function.") + } + // Make sure the window specification is valid. + s.validate match { + case Some(m) => + failAnalysis(s"Window specification $s is not valid because $m") + case None => w + } - case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => - // The window spec is not valid. - val reason = windowSpec.validate.get - failAnalysis(s"Window specification $windowSpec is not valid because $reason") } operator match { case f: Filter if f.condition.dataType != BooleanType => failAnalysis( - s"filter expression '${f.condition.prettyString}' " + + s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, UsingJoin(_, cols), _) => + val from = operator.inputSet.map(_.name).mkString(", ") + failAnalysis( + s"using columns [${cols.mkString(",")}] " + + s"can not be resolved given input columns: [$from] ") + case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( - s"join condition '${condition.prettyString}' " + + s"join condition '${condition.sql}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") case j @ Join(_, _, _, Some(condition)) => @@ -96,10 +126,10 @@ trait CheckAnalysis { case p: Predicate => p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) case e if e.dataType.isInstanceOf[BinaryType] => - failAnalysis(s"binary type expression ${e.prettyString} cannot be used " + + failAnalysis(s"binary type expression ${e.sql} cannot be used " + "in join conditions") case e if e.dataType.isInstanceOf[MapType] => - failAnalysis(s"map type expression ${e.prettyString} cannot be used " + + failAnalysis(s"map type expression ${e.sql} cannot be used " + "in join conditions") case _ => // OK } @@ -108,10 +138,26 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK + case aggExpr: AggregateExpression => + aggExpr.aggregateFunction.children.foreach { child => + child.foreach { + case agg: AggregateExpression => + failAnalysis( + s"It is not allowed to use an aggregate function in the argument of " + + s"another aggregate function. Please use the inner aggregate function " + + s"in a sub-query.") + case other => // OK + } + + if (!child.deterministic) { + failAnalysis( + s"nondeterministic expression ${expr.sql} should not " + + s"appear in the arguments of an aggregate function.") + } + } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( - s"expression '${e.prettyString}' is neither present in the group by, " + + s"expression '${e.sql}' is neither present in the group by, " + s"nor is it an aggregate function. " + "Add to group by or wrap in first() (or first_value) if you don't care " + "which value you get.") @@ -120,14 +166,22 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { - case BinaryType => - failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case m: MapType => - failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case _ => // OK + def checkValidGroupingExprs(expr: Expression): Unit = { + // Check if the data type of expr is orderable. + if (!RowOrdering.isOrderable(expr.dataType)) { + failAnalysis( + s"expression ${expr.sql} cannot be used as a grouping expression " + + s"because its data type ${expr.dataType.simpleString} is not a orderable " + + s"data type.") + } + + if (!expr.deterministic) { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. + failAnalysis(s"nondeterministic expression ${expr.sql} should not " + + s"appear in grouping expression.") + } } aggregateExprs.foreach(checkValidAggregateExpression) @@ -147,6 +201,14 @@ trait CheckAnalysis { s"but the left table has ${left.output.length} columns and the right has " + s"${right.output.length}") + case s: Union if s.children.exists(_.output.length != s.children.head.output.length) => + val firstError = s.children.find(_.output.length != s.children.head.output.length).get + failAnalysis( + s""" + |Unions can only be performed on tables with the same number of columns, + | but one table has '${firstError.output.length}' columns and another table has + | '${s.children.head.output.length}' columns""".stripMargin) + case _ => // Fallbacks to the following checks } @@ -162,11 +224,10 @@ trait CheckAnalysis { case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( s"""Only a single table generating function is allowed in a SELECT clause, found: - | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + | ${exprs.map(_.sql).mkString(",")}""".stripMargin) - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) + case j: Join if !j.duplicateResolved => + val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) failAnalysis( s""" |Failure when resolving conflicting references in Join: @@ -174,17 +235,29 @@ trait CheckAnalysis { |Conflicting attributes: ${conflictingAttributes.mkString(",")} |""".stripMargin) + case i: Intersect if !i.duplicateResolved => + val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Intersect: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && + !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => + // The rule above is used to check Aggregate operator. failAnalysis( - s"""nondeterministic expressions are only allowed in Project or Filter, found: - | ${o.expressions.map(_.prettyString).mkString(",")} + s"""nondeterministic expressions are only allowed in + |Project, Filter, Aggregate or Window, found: + | ${o.expressions.map(_.sql).mkString(",")} |in operator ${operator.simpleString} - """.stripMargin) + """.stripMargin) case _ => // Analysis successful! } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala new file mode 100644 index 0000000000000..9c38dd2ee4e53 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + + +// scalastyle:off +/** + * Calculates and propagates precision for fixed-precision decimals. Hive has a number of + * rules for this based on the SQL standard and MS SQL: + * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * https://msdn.microsoft.com/en-us/library/ms190476.aspx + * + * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 + * respectively, then the following operations have the following precision / scale: + * + * Operation Result Precision Result Scale + * ------------------------------------------------------------------------ + * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 * e2 p1 + p2 + 1 s1 + s2 + * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) + * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) + * sum(e1) p1 + 10 s1 + * avg(e1) p1 + 4 s1 + 4 + * + * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited + * precision, do the math on unlimited-precision numbers, then introduce casts back to the + * required fixed precision. This allows us to do all rounding and overflow handling in the + * cast-to-fixed-precision operator. + * + * In addition, when mixing non-decimal types with decimals, we use the following rules: + * - BYTE gets turned into DECIMAL(3, 0) + * - SHORT gets turned into DECIMAL(5, 0) + * - INT gets turned into DECIMAL(10, 0) + * - LONG gets turned into DECIMAL(20, 0) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + */ +// scalastyle:on +object DecimalPrecision extends Rule[LogicalPlan] { + import scala.math.{max, min} + + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + + // Returns the wider decimal type that's wider than both of them + def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { + widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) + } + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { + val scale = max(s1, s2) + val range = max(p1 - s1, p2 - s2) + DecimalType.bounded(range + scale, scale) + } + + private def promotePrecision(e: Expression, dataType: DataType): Expression = { + PromotePrecision(Cast(e, dataType)) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // fix decimal precision for expressions + case q => q.transformExpressions( + decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) + } + + /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ + private val decimalAndDecimal: PartialFunction[Expression, Expression] = { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + // Skip nodes who is already promoted + case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + val resultType = DecimalType.bounded(intDig + decDig, decDig) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) + + case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) + + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = widerDecimalType(p1, s1, p2, s2) + b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) + + // TODO: MaxOf, MinOf, etc might want other rules + // SUM and AVERAGE are handled by the implementations of those expressions + } + + /** + * Strength reduction for comparing integral expressions with decimal literals. + * 1. int_col > decimal_literal => int_col > floor(decimal_literal) + * 2. int_col >= decimal_literal => int_col >= ceil(decimal_literal) + * 3. int_col < decimal_literal => int_col < ceil(decimal_literal) + * 4. int_col <= decimal_literal => int_col <= floor(decimal_literal) + * 5. decimal_literal > int_col => ceil(decimal_literal) > int_col + * 6. decimal_literal >= int_col => floor(decimal_literal) >= int_col + * 7. decimal_literal < int_col => floor(decimal_literal) < int_col + * 8. decimal_literal <= int_col => ceil(decimal_literal) <= int_col + * + * Note that technically this is an "optimization" and should go into the optimizer. However, + * by the time the optimizer runs, these comparison expressions would be pretty hard to pattern + * match because there are multiple (at least 2) levels of casts involved. + * + * There are a lot more possible rules we can implement, but we don't do them + * because we are not sure how common they are. + */ + private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = { + + case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + GreaterThan(i, Literal(value.floor.toLong)) + } + + case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + GreaterThanOrEqual(i, Literal(value.ceil.toLong)) + } + + case LessThan(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + LessThan(i, Literal(value.ceil.toLong)) + } + + case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + LessThanOrEqual(i, Literal(value.floor.toLong)) + } + + case GreaterThan(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + GreaterThan(Literal(value.ceil.toLong), i) + } + + case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + FalseLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + TrueLiteral + } else { + GreaterThanOrEqual(Literal(value.floor.toLong), i) + } + + case LessThan(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + LessThan(Literal(value.floor.toLong), i) + } + + case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + if (DecimalLiteral.smallerThanSmallestLong(value)) { + TrueLiteral + } else if (DecimalLiteral.largerThanLargestLong(value)) { + FalseLiteral + } else { + LessThanOrEqual(Literal(value.ceil.toLong), i) + } + } + + /** + * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other + * side is a decimal. + */ + private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = { + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + (left.dataType, right.dataType) match { + case (t: IntegralType, DecimalType.Fixed(p, s)) => + b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) + case (DecimalType.Fixed(p, s), t: IntegralType) => + b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(left, Cast(right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(left, DoubleType), right)) + case _ => + b + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala new file mode 100644 index 0000000000000..2e30d83a60970 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.IntegerType + +/** + * This rule rewrites an aggregate query with distinct aggregations into an expanded double + * aggregation in which the regular aggregation expressions and every distinct clause is aggregated + * in a separate group. The results are then combined in a second aggregate. + * + * For example (in scala): + * {{{ + * val data = Seq( + * ("a", "ca1", "cb1", 10), + * ("a", "ca1", "cb2", 5), + * ("b", "ca1", "cb1", 13)) + * .toDF("key", "cat1", "cat2", "value") + * data.registerTempTable("data") + * + * val agg = data.groupBy($"key") + * .agg( + * countDistinct($"cat1").as("cat1_cnt"), + * countDistinct($"cat2").as("cat2_cnt"), + * sum($"value").as("total")) + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1), + * COUNT(DISTINCT 'cat2), + * sum('value)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] + * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) 'cat1 else null), + * count(if (('gid = 2)) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [sum('value)] + * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, cast('value as bigint)), + * ('key, 'cat1, null, 1, null), + * ('key, null, 'cat2, 2, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'value]) + * LocalTableScan [...] + * }}} + * + * The rule does the following things here: + * 1. Expand the data. There are three aggregation groups in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat2 group. + * An expand operator is inserted to expand the child data for each group. The expand will null + * out all unused columns for the given group; this must be done in order to ensure correctness + * later on. Groups can by identified by a group id (gid) column added by the expand operator. + * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of + * this aggregate consists of the original group by clause, all the requested distinct columns + * and the group id. Both de-duplication of distinct column and the aggregation of the + * non-distinct group take advantage of the fact that we group by the group id (gid) and that we + * have nulled out all non-relevant columns the given group. + * 3. Aggregating the distinct groups and combining this with the results of the non-distinct + * aggregation. In this step we use the group id to filter the inputs for the aggregate + * functions. The result of the non-distinct group are 'aggregated' by using the first operator, + * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * + * This rule duplicates the input data by two or more times (# distinct groups + an optional + * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and + * exchange operators. Keeping the number of distinct groups as low a possible should be priority, + * we could improve this in the current rule by applying more advanced expression canonicalization + * techniques. + */ +object DistinctAggregationRewriter extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case a: Aggregate => rewrite(a) + } + + def rewrite(a: Aggregate): Aggregate = { + + // Collect all aggregate expressions. + val aggExpressions = a.aggregateExpressions.flatMap { e => + e.collect { + case ae: AggregateExpression => ae + } + } + + // Extract distinct aggregate expressions. + val distinctAggGroups = aggExpressions + .filter(_.isDistinct) + .groupBy(_.aggregateFunction.children.toSet) + + // Aggregation strategy can handle the query with single distinct + if (distinctAggGroups.size > 1) { + // Create the attributes for the grouping id and the group by clause. + val gid = + new AttributeReference("gid", IntegerType, false)(isGenerated = true) + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)() + } + val groupByAttrs = groupByMap.map(_._2) + + // Functions used to modify aggregate functions and their inputs. + def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) + def patchAggregateFunctionChildren( + af: AggregateFunction)( + attrs: Expression => Expression): AggregateFunction = { + af.withNewChildren(af.children.map { + case afc => attrs(afc) + }).asInstanceOf[AggregateFunction] + } + + // Setup unique distinct aggregate children. + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) + + // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap + val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { + case ((group, expressions), i) => + val id = Literal(i + 1) + + // Expand projection + val projection = distinctAggChildren.map { + case e if group.contains(e) => e + case e => nullify(e) + } :+ id + + // Final aggregate + val operators = expressions.map { e => + val af = e.aggregateFunction + val naf = patchAggregateFunctionChildren(af) { x => + evalWithinGroup(id, distinctAggChildAttrLookup(x)) + } + (e, e.copy(aggregateFunction = naf, isDistinct = false)) + } + + (projection, operators) + } + + // Setup expand for the 'regular' aggregate expressions. + val regularAggExprs = aggExpressions.filter(!_.isDistinct) + val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + + // Setup aggregates for 'regular' aggregate expressions. + val regularGroupId = Literal(0) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + val regularAggOperatorMap = regularAggExprs.map { e => + // Perform the actual aggregation in the initial aggregate. + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) + val operator = Alias(e.copy(aggregateFunction = af), e.sql)() + + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), + mode = Complete, + isDistinct = false) + + // Some aggregate functions (COUNT) have the special property that they can return a + // non-null result without any input. We need to make sure we return a result in this case. + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result + } + + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) + } + + // Construct the regular aggregate input projection only if we need one. + val regularAggProjection = if (regularAggExprs.nonEmpty) { + Seq(a.groupingExpressions ++ + distinctAggChildren.map(nullify) ++ + Seq(regularGroupId) ++ + regularAggChildren) + } else { + Seq.empty[Seq[Expression]] + } + + // Construct the distinct aggregate input projections. + val regularAggNulls = regularAggChildren.map(nullify) + val distinctAggProjections = distinctAggOperatorMap.map { + case (projection, _) => + a.groupingExpressions ++ + projection ++ + regularAggNulls + } + + // Construct the expand operator. + val expand = Expand( + regularAggProjection ++ distinctAggProjections, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), + a.child) + + // Construct the first aggregate operator. This de-duplicates the all the children of + // distinct operators, and applies the regular aggregate operators. + val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid + val firstAggregate = Aggregate( + firstAggregateGroupBy, + firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), + expand) + + // Construct the second aggregate + val transformations: Map[Expression, Expression] = + (distinctAggOperatorMap.flatMap(_._2) ++ + regularAggOperatorMap.map(e => (e._1, e._3))).toMap + + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case e: Expression => + // The same GROUP BY clauses can have different forms (different names for instance) in + // the groupBy and aggregate expressions of an aggregate. This makes a map lookup + // tricky. So we do a linear search for a semantically equal group by expression. + groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + }.asInstanceOf[NamedExpression] + } + Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) + } else { + a + } + } + + private def nullify(e: Expression) = Literal.create(null, e.dataType) + + private def expressionAttributePair(e: Expression) = + // We are creating a new reference here instead of reusing the attribute in case of a + // NamedExpression. This is done to prevent collisions between distinct and regular aggregate + // children, in this case attribute reuse causes the input of the regular aggregate to bound to + // the (nulled out) input of the distinct aggregate. + e -> new AttributeReference(e.sql, e.dataType, true)() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d4334d16289a5..f2abf136da685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap @@ -44,11 +45,24 @@ trait FunctionRegistry { /* Get the class of the registered function by specified name. */ def lookupFunction(name: String): Option[ExpressionInfo] + + /* Get the builder of the registered function by specified name. */ + def lookupFunctionBuilder(name: String): Option[FunctionBuilder] + + /** Drop a function and return whether the function existed. */ + def dropFunction(name: String): Boolean + + /** Checks if a function with a given name exists. */ + def functionExists(name: String): Boolean = lookupFunction(name).isDefined + + /** Clear all registered functions. */ + def clear(): Unit + } class SimpleFunctionRegistry extends FunctionRegistry { - private val functionBuilders = + private[sql] val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) override def registerFunction( @@ -75,6 +89,18 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.get(name).map(_._1) } + override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized { + functionBuilders.get(name).map(_._2) + } + + override def dropFunction(name: String): Boolean = synchronized { + functionBuilders.remove(name).isDefined + } + + override def clear(): Unit = { + functionBuilders.clear() + } + def copy(): SimpleFunctionRegistry = synchronized { val registry = new SimpleFunctionRegistry functionBuilders.iterator.foreach { case (name, (info, builder)) => @@ -105,6 +131,19 @@ object EmptyFunctionRegistry extends FunctionRegistry { override def lookupFunction(name: String): Option[ExpressionInfo] = { throw new UnsupportedOperationException } + + override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = { + throw new UnsupportedOperationException + } + + override def dropFunction(name: String): Boolean = { + throw new UnsupportedOperationException + } + + override def clear(): Unit = { + throw new UnsupportedOperationException + } + } @@ -112,6 +151,7 @@ object FunctionRegistry { type FunctionBuilder = Seq[Expression] => Expression + // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions expression[Abs]("abs"), @@ -124,13 +164,14 @@ object FunctionRegistry { expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), + expression[CreateMap]("map"), + expression[CreateNamedStruct]("named_struct"), + expression[NaNvl]("nanvl"), expression[Coalesce]("nvl"), expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), - expression[CreateNamedStruct]("named_struct"), - expression[Sqrt]("sqrt"), - expression[NaNvl]("nanvl"), + expression[CaseWhen]("when"), // math functions expression[Acos]("acos"), @@ -144,24 +185,26 @@ object FunctionRegistry { expression[Cos]("cos"), expression[Cosh]("cosh"), expression[Conv]("conv"), + expression[ToDegrees]("degrees"), expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Factorial]("factorial"), - expression[Hypot]("hypot"), expression[Hex]("hex"), + expression[Hypot]("hypot"), expression[Logarithm]("log"), - expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), expression[Log2]("log2"), + expression[Log]("ln"), expression[UnaryMinus]("negative"), expression[Pi]("pi"), - expression[Pow]("pow"), - expression[Pow]("power"), expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), + expression[Pow]("pow"), + expression[Pow]("power"), + expression[ToRadians]("radians"), expression[Rint]("rint"), expression[Round]("round"), expression[ShiftLeft]("shiftleft"), @@ -171,22 +214,32 @@ object FunctionRegistry { expression[Signum]("signum"), expression[Sin]("sin"), expression[Sinh]("sinh"), + expression[Sqrt]("sqrt"), expression[Tan]("tan"), expression[Tanh]("tanh"), - expression[ToDegrees]("degrees"), - expression[ToRadians]("radians"), + + expression[Add]("+"), + expression[Subtract]("-"), + expression[Multiply]("*"), + expression[Divide]("/"), + expression[Remainder]("%"), // aggregate functions + expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), expression[Corr]("corr"), expression[Count]("count"), + expression[CovPopulation]("covar_pop"), + expression[CovSample]("covar_samp"), expression[First]("first"), expression[First]("first_value"), + expression[Kurtosis]("kurtosis"), expression[Last]("last"), expression[Last]("last_value"), expression[Max]("max"), expression[Average]("mean"), expression[Min]("min"), + expression[Skewness]("skewness"), expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), @@ -194,36 +247,36 @@ object FunctionRegistry { expression[VarianceSamp]("variance"), expression[VariancePop]("var_pop"), expression[VarianceSamp]("var_samp"), - expression[Skewness]("skewness"), - expression[Kurtosis]("kurtosis"), // string functions expression[Ascii]("ascii"), expression[Base64]("base64"), expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), - expression[Encode]("encode"), expression[Decode]("decode"), + expression[Encode]("encode"), expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), + expression[FormatString]("format_string"), expression[GetJsonObject]("get_json_object"), expression[InitCap]("initcap"), - expression[JsonTuple]("json_tuple"), + expression[StringInstr]("instr"), expression[Lower]("lcase"), - expression[Lower]("lower"), expression[Length]("length"), expression[Levenshtein]("levenshtein"), - expression[RegExpExtract]("regexp_extract"), - expression[RegExpReplace]("regexp_replace"), - expression[StringInstr]("instr"), + expression[Like]("like"), + expression[Lower]("lower"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), - expression[FormatString]("format_string"), + expression[JsonTuple]("json_tuple"), expression[FormatString]("printf"), - expression[StringRPad]("rpad"), + expression[RegExpExtract]("regexp_extract"), + expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReverse]("reverse"), + expression[RLike]("rlike"), + expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), expression[SoundEx]("soundex"), expression[StringSpace]("space"), @@ -233,8 +286,8 @@ object FunctionRegistry { expression[SubstringIndex]("substring_index"), expression[StringTranslate]("translate"), expression[StringTrim]("trim"), - expression[UnBase64]("unbase64"), expression[Upper]("ucase"), + expression[UnBase64]("unbase64"), expression[Unhex]("unhex"), expression[Upper]("upper"), @@ -257,28 +310,71 @@ object FunctionRegistry { expression[Month]("month"), expression[MonthsBetween]("months_between"), expression[NextDay]("next_day"), + expression[CurrentTimestamp]("now"), expression[Quarter]("quarter"), expression[Second]("second"), expression[ToDate]("to_date"), + expression[ToUnixTimestamp]("to_unix_timestamp"), expression[ToUTCTimestamp]("to_utc_timestamp"), expression[TruncDate]("trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), + expression[TimeWindow]("window"), // collection functions + expression[ArrayContains]("array_contains"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[ArrayContains]("array_contains"), // misc functions expression[Crc32]("crc32"), expression[Md5]("md5"), + expression[Murmur3Hash]("hash"), expression[Sha1]("sha"), expression[Sha1]("sha1"), expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), - expression[InputFileName]("input_file_name") + expression[InputFileName]("input_file_name"), + expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), + + // grouping sets + expression[Cube]("cube"), + expression[Rollup]("rollup"), + expression[Grouping]("grouping"), + expression[GroupingID]("grouping_id"), + + // window functions + expression[Lead]("lead"), + expression[Lag]("lag"), + expression[RowNumber]("row_number"), + expression[CumeDist]("cume_dist"), + expression[NTile]("ntile"), + expression[Rank]("rank"), + expression[DenseRank]("dense_rank"), + expression[PercentRank]("percent_rank"), + + // predicates + expression[And]("and"), + expression[In]("in"), + expression[Not]("not"), + expression[Or]("or"), + + expression[EqualNullSafe]("<=>"), + expression[EqualTo]("="), + expression[EqualTo]("=="), + expression[GreaterThan](">"), + expression[GreaterThanOrEqual](">="), + expression[LessThan]("<"), + expression[LessThanOrEqual]("<="), + expression[Not]("!"), + + // bitwise + expression[BitwiseAnd]("&"), + expression[BitwiseNot]("~"), + expression[BitwiseOr]("|"), + expression[BitwiseXor]("^") + ) val builtin: SimpleFunctionRegistry = { @@ -311,7 +407,10 @@ object FunctionRegistry { } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e - case Failure(e) => throw new AnalysisException(e.getMessage) + case Failure(e) => + // the exception is an invocation exception. To get a meaningful message, we need the + // cause. + throw new AnalysisException(e.getCause.getMessage) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 84e2b1366f626..823d2495fad80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -19,17 +19,32 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable +import scala.annotation.tailrec +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ /** - * A collection of [[Rule Rules]] that can be used to coerce differing types that - * participate in operations into compatible ones. Most of these rules are based on Hive semantics, - * but they do not introduce any dependencies on the hive codebase. For this reason they remain in - * Catalyst until we have a more standard set of coercions. + * A collection of [[Rule]] that can be used to coerce differing types that participate in + * operations into compatible ones. + * + * Most of these rules are based on Hive semantics, but they do not introduce any dependencies on + * the hive codebase. + * + * Notes about type widening / tightest common types: Broadly, there are two cases when we need + * to widen data types (e.g. union, binary comparison). In case 1, we are looking for a common + * data type for two or more data types, and in this case no loss of precision is allowed. Examples + * include type inference in JSON (e.g. what's the column's data type if one row is an integer + * while the other row is a long?). In case 2, we are looking for a widened data type with + * some acceptable loss of precision (e.g. there is no common type for double and decimal because + * double's range is larger than decimal, and yet decimal is more precise than double, but in + * union we would cast the decimal into double). */ object HiveTypeCoercion { @@ -52,7 +67,7 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - private val numericPrecedence = + private[sql] val numericPrecedence = IndexedSeq( ByteType, ShortType, @@ -62,10 +77,12 @@ object HiveTypeCoercion { DoubleType) /** + * Case 1 type widening (see the classdoc comment above for HiveTypeCoercion). + * * Find the tightest common type of two types that might be used in a binary expression. * This handles all numeric types except fixed-precision decimals interacting with each other or * with primitive types, because in that case the precision and scale of the result depends on - * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. + * the operation. Those rules are implemented in [[DecimalPrecision]]. */ val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) @@ -117,6 +134,12 @@ object HiveTypeCoercion { }) } + /** + * Case 2 type widening (see the classdoc comment above for HiveTypeCoercion). + * + * i.e. the main difference with [[findTightestCommonTypeOfTwo]] is that here we allow some + * loss of precision when widening decimal and double. + */ private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match { case (t1: DecimalType, t2: DecimalType) => Some(DecimalPrecision.widerDecimalType(t1, t2)) @@ -124,9 +147,7 @@ object HiveTypeCoercion { Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) case (d: DecimalType, t: IntegralType) => Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (t: FractionalType, d: DecimalType) => - Some(DoubleType) - case (d: DecimalType, t: FractionalType) => + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => Some(DoubleType) case _ => findTightestCommonTypeToString(t1, t2) @@ -139,6 +160,9 @@ object HiveTypeCoercion { }) } + private def haveSameType(exprs: Seq[Expression]): Boolean = + exprs.map(_.dataType).distinct.length == 1 + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -199,41 +223,65 @@ object HiveTypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - private[this] def widenOutputTypes( - planName: String, - left: LogicalPlan, - right: LogicalPlan): (LogicalPlan, LogicalPlan) = { - require(left.output.length == right.output.length) + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if p.analyzed => p - val castedTypes = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - findWiderTypeForTwo(lhs.dataType, rhs.dataType) - case other => None - } + case s @ SetOperation(left, right) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + s.makeCopy(Array(newChildren.head, newChildren.last)) - def castOutput(plan: LogicalPlan): LogicalPlan = { - val casted = plan.output.zip(castedTypes).map { - case (e, Some(dt)) if e.dataType != dt => - Alias(Cast(e, dt), e.name)() - case (e, _) => e - } - Project(casted, plan) - } + case s: Union if s.childrenResolved && + s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) + s.makeCopy(Array(newChildren)) + } - if (castedTypes.exists(_.isDefined)) { - (castOutput(left), castOutput(right)) + /** Build new children with the widest types for each attribute among all the children */ + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + require(children.forall(_.output.length == children.head.output.length)) + + // Get a sequence of data types, each of which is the widest type of this specific attribute + // in all the children + val targetTypes: Seq[DataType] = + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) + + if (targetTypes.nonEmpty) { + // Add an extra Project if the targetTypes are different from the original types. + children.map(widenTypes(_, targetTypes)) } else { - (left, right) + // Unable to find a target type to widen, then just return the original set. + children } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case p if p.analyzed => p + /** Get the widest type for each attribute in all the children */ + @tailrec private def getWidestTypes( + children: Seq[LogicalPlan], + attrIndex: Int, + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { + // Return the result after the widen data types have been found for all the children + if (attrIndex >= children.head.output.length) return castedTypes.toSeq + + // For the attrIndex-th attribute, find the widest type + findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { + // If unable to find an appropriate widen type for this column, return an empty Seq + case None => Seq.empty[DataType] + // Otherwise, record the result in the queue and find the type for the next column + case Some(widenType) => + castedTypes.enqueue(widenType) + getWidestTypes(children, attrIndex + 1, castedTypes) + } + } - case s @ SetOperation(left, right) if s.childrenResolved - && left.output.length == right.output.length && !s.resolved => - val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) - s.makeCopy(Array(newLeft, newRight)) + /** Given a plan, add an extra project on top to widen some columns' data types. */ + private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = { + val casted = plan.output.zip(targetTypes).map { + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, _) => e + } + Project(casted, plan) } } @@ -262,7 +310,7 @@ object HiveTypeCoercion { case p @ Equality(left @ TimestampType(), right @ StringType()) => p.makeCopy(Array(left, Cast(right, TimestampType))) - // We should cast all relative timestamp/date/string comparison into string comparisions + // We should cast all relative timestamp/date/string comparison into string comparisons // This behaves as a user would expect because timestamp strings sort lexicographically. // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true case p @ BinaryComparison(left @ StringType(), right @ DateType()) => @@ -280,6 +328,12 @@ object HiveTypeCoercion { case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + // Checking NullType + case p @ BinaryComparison(left @ StringType(), right @ NullType()) => + p.makeCopy(Array(left, Literal.create(null, StringType))) + case p @ BinaryComparison(left @ NullType(), right @ StringType()) => + p.makeCopy(Array(Literal.create(null, StringType), right)) + case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType => p.makeCopy(Array(Cast(left, DoubleType), right)) case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType => @@ -295,7 +349,6 @@ object HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) @@ -326,143 +379,6 @@ object HiveTypeCoercion { } } - // scalastyle:off - /** - * Calculates and propagates precision for fixed-precision decimals. Hive has a number of - * rules for this based on the SQL standard and MS SQL: - * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf - * https://msdn.microsoft.com/en-us/library/ms190476.aspx - * - * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 - * respectively, then the following operations have the following precision / scale: - * - * Operation Result Precision Result Scale - * ------------------------------------------------------------------------ - * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) - * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) - * e1 * e2 p1 + p2 + 1 s1 + s2 - * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) - * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) - * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) - * sum(e1) p1 + 10 s1 - * avg(e1) p1 + 4 s1 + 4 - * - * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision. - * - * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited - * precision, do the math on unlimited-precision numbers, then introduce casts back to the - * required fixed precision. This allows us to do all rounding and overflow handling in the - * cast-to-fixed-precision operator. - * - * In addition, when mixing non-decimal types with decimals, we use the following rules: - * - BYTE gets turned into DECIMAL(3, 0) - * - SHORT gets turned into DECIMAL(5, 0) - * - INT gets turned into DECIMAL(10, 0) - * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE - * - * Note: Union/Except/Interact is handled by WidenTypes - */ - // scalastyle:on - object DecimalPrecision extends Rule[LogicalPlan] { - import scala.math.{max, min} - - private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - - // Returns the wider decimal type that's wider than both of them - def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { - widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) - } - // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) - def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { - val scale = max(s1, s2) - val range = max(p1 - s1, p2 - s2) - DecimalType.bounded(range + scale, scale) - } - - private def promotePrecision(e: Expression, dataType: DataType): Expression = { - PromotePrecision(Cast(e, dataType)) - } - - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - - // fix decimal precision for expressions - case q => q.transformExpressions { - // Skip nodes whose children have not been resolved yet - case e if !e.childrenResolved => e - - // Skip nodes who is already promoted - case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e - - case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) - - case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) - - case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) - val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) - - case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) - var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) - val diff = (intDig + decDig) - DecimalType.MAX_SCALE - if (diff > 0) { - decDig -= diff / 2 + 1 - intDig = DecimalType.MAX_SCALE - decDig - } - val resultType = DecimalType.bounded(intDig + decDig, decDig) - val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) - - case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - // resultType may have lower precision, so we cast them into wider type first. - val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) - - case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - // resultType may have lower precision, so we cast them into wider type first. - val widerType = widerDecimalType(p1, s1, p2, s2) - CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), - resultType) - - case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - val resultType = widerDecimalType(p1, s1, p2, s2) - b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => - b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) - case _ => - b - } - - // TODO: MaxOf, MinOf, etc might want other rules - - // SUM and AVERAGE are handled by the implementations of those expressions - } - } - } - /** * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ @@ -470,27 +386,6 @@ object HiveTypeCoercion { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { - CaseKeyWhen(numericExpr, Seq( - Literal(trueValues.head), booleanExpr, - Literal(falseValues.head), Not(booleanExpr), - Literal(false))) - } - - private def transform(booleanExpr: Expression, numericExpr: Expression) = { - If(Or(IsNull(booleanExpr), IsNull(numericExpr)), - Literal.create(null, BooleanType), - buildCaseKeyWhen(booleanExpr, numericExpr)) - } - - private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { - CaseWhen(Seq( - And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), - Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), - buildCaseKeyWhen(booleanExpr, numericExpr) - )) - } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -499,6 +394,7 @@ object HiveTypeCoercion { // all other cases are considered as false. // We may simplify the expression if one side is literal numeric values + // TODO: Maybe these rules should go into the optimizer. case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) if trueValues.contains(value) => bool case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) @@ -517,13 +413,13 @@ object HiveTypeCoercion { if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) case EqualTo(left @ BooleanType(), right @ NumericType()) => - transform(left , right) + EqualTo(Cast(left, right.dataType), right) case EqualTo(left @ NumericType(), right @ BooleanType()) => - transform(right, left) + EqualTo(left, Cast(right, left.dataType)) case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => - transformNullSafe(left, right) + EqualNullSafe(Cast(left, right.dataType), right) case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => - transformNullSafe(right, left) + EqualNullSafe(left, Cast(right, left.dataType)) } } @@ -550,24 +446,42 @@ object HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 => + case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) findTightestCommonTypeAndPromoteToString(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } + case m @ CreateMap(children) if m.keys.length == m.values.length && + (!haveSameType(m.keys) || !haveSameType(m.values)) => + val newKeys = if (haveSameType(m.keys)) { + m.keys + } else { + val types = m.keys.map(_.dataType) + findTightestCommonTypeAndPromoteToString(types) match { + case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) + case None => m.keys + } + } + + val newValues = if (haveSameType(m.values)) { + m.values + } else { + val types = m.values.map(_.dataType) + findTightestCommonTypeAndPromoteToString(types) match { + case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) + case None => m.values + } + } + + CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) + // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest. - case SumDistinct(e @ IntegralType()) if e.dataType != LongType => - SumDistinct(Cast(e, LongType)) - case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType => - SumDistinct(Cast(e, DoubleType)) - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. case Average(e @ IntegralType()) if e.dataType != LongType => Average(Cast(e, LongType)) @@ -581,13 +495,27 @@ object HiveTypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => + case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } + case g @ Greatest(children) if !haveSameType(children) => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case None => g + } + + case l @ Least(children) if !haveSameType(children) => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case None => l + } + case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType => NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => @@ -618,33 +546,27 @@ object HiveTypeCoercion { */ object CaseWhenCoercion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => - logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") + case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => - val castedBranches = c.branches.grouped(2).map { - case Seq(when, value) if value.dataType != commonType => - Seq(when, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case other => other - }.reduce(_ ++ _) - c match { - case _: CaseWhen => CaseWhen(castedBranches) - case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches) + var changed = false + val newBranches = c.branches.map { case (condition, value) => + if (value.dataType.sameType(commonType)) { + (condition, value) + } else { + changed = true + (condition, Cast(value, commonType)) + } } - }.getOrElse(c) - - case c: CaseKeyWhen if c.childrenResolved && !c.resolved => - val maybeCommonType = - findWiderCommonType((c.key +: c.whenList).map(_.dataType)) - maybeCommonType.map { commonType => - val castedBranches = c.branches.grouped(2).map { - case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => - Seq(Cast(whenExpr, commonType), thenExpr) - case other => other - }.reduce(_ ++ _) - CaseKeyWhen(Cast(c.key, commonType), castedBranches) + val newElseValue = c.elseValue.map { value => + if (value.dataType.sameType(commonType)) { + value + } else { + changed = true + Cast(value, commonType) + } + } + if (changed) CaseWhen(newBranches, newElseValue) else c }.getOrElse(c) } } @@ -657,12 +579,11 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => + findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. - // Convert If(null literal, _, _) into boolean type. // In the optimizer, we should short-circuit this directly into false value. case If(pred, left, right) if pred.dataType == NullType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala new file mode 100644 index 0000000000000..5e18316c94bfd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec + + +/** + * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception + * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. + */ +class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database $db not found") + +class NoSuchTableException(db: String, table: String) + extends AnalysisException(s"Table or View $table not found in database $db") + +class NoSuchPartitionException(db: String, table: String, spec: TablePartitionSpec) extends + AnalysisException(s"Partition not found in table $table database $db:\n" + spec.mkString("\n")) + +class NoSuchFunctionException(db: String, func: String) + extends AnalysisException(s"Function $func not found in database $db") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index eae17c86ddc7a..4ec43aba02d66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{errors, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.{TableIdentifier, errors} +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types.{DataType, StructType} /** @@ -33,7 +34,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str errors.TreeNodeException(tree, s"Invalid call to $function on unresolved object", null) /** - * Holds the name of a relation that has yet to be looked up in a [[Catalog]]. + * Holds the name of a relation that has yet to be looked up in a catalog. */ case class UnresolvedRelation( tableIdentifier: TableIdentifier, @@ -58,15 +59,17 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") override lazy val resolved = false override def newInstance(): UnresolvedAttribute = this override def withNullability(newNullability: Boolean): UnresolvedAttribute = this - override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this + override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) override def toString: String = s"'$name" + + override def sql: String = quoteIdentifier(name) } object UnresolvedAttribute { @@ -130,6 +133,33 @@ object UnresolvedAttribute { } } +/** + * Represents an unresolved generator, which will be created by the parser for + * the [[org.apache.spark.sql.catalyst.plans.logical.Generate]] operator. + * The analyzer will resolve this generator. + */ +case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends Generator { + + override def elementTypes: Seq[(DataType, Boolean, String)] = + throw new UnresolvedException(this, "elementTypes") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false + + override def prettyName: String = name + override def toString: String = s"'$name(${children.mkString(", ")})" + + override def eval(input: InternalRow = null): TraversableOnce[InternalRow] = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override def terminate(): TraversableOnce[InternalRow] = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + case class UnresolvedFunction( name: String, children: Seq[Expression], @@ -141,7 +171,8 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def toString: String = s"'$name(${children.mkString(",")})" + override def prettyName: String = name + override def toString: String = s"'$name(${children.mkString(", ")})" } /** @@ -154,8 +185,9 @@ abstract class Star extends LeafExpression with NamedExpression { override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") override lazy val resolved = false def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] @@ -183,7 +215,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu case None => input.output // If there is a table, pick out attributes that are part of this table. case Some(t) => if (t.size == 1) { - input.output.filter(_.qualifiers.exists(resolver(_, t.head))) + input.output.filter(_.qualifier.exists(resolver(_, t.head))) } else { List() } @@ -197,16 +229,15 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. attribute.get.dataType match { - case s: StructType => { - s.fields.map( f => { - val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get) - Alias(extract, target.get + "." + f.name)() - }) + case s: StructType => s.zipWithIndex.map { + case (f, i) => + val extract = GetStructField(attribute.get, i) + Alias(extract, f.name)() } - case _ => { + + case _ => throw new AnalysisException("Can only star expand struct data types. Attribute: `" + target.get + "`") - } } } else { val from = input.inputSet.map(_.name).mkString(", ") @@ -223,6 +254,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu * For example the SQL expression "stack(2, key, value, key, value) as (a, b)" could be represented * as follows: * MultiAlias(stack_function, Seq(a, b)) + * * @param child the computation being performed * @param names the names to be associated with each output of computing [[child]]. @@ -238,10 +270,12 @@ case class MultiAlias(child: Expression, names: Seq[String]) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") + override lazy val resolved = false override def toString: String = s"$child AS $names" @@ -255,6 +289,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) * @param expressions Expressions to expand. */ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable { + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } @@ -276,20 +311,48 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override lazy val resolved = false override def toString: String = s"$child[$extraction]" + override def sql: String = s"${child.sql}[${extraction.sql}]" } /** * Holds the expression that has yet to be aliased. + * + * @param child The computation that is needs to be resolved during analysis. + * @param aliasName The name if specified to be associated with the result of computing [[child]] + * */ -case class UnresolvedAlias(child: Expression) +case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") - override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier") override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def name: String = throw new UnresolvedException(this, "name") + override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance") + + override lazy val resolved = false +} +/** + * Holds the deserializer expression and the attributes that are available during the resolution + * for it. Deserializer expression is a special kind of expression that is not always resolved by + * children output, but by given attributes, e.g. the `keyDeserializer` in `MapGroups` should be + * resolved by `groupingAttributes` instead of children output. + * + * @param deserializer The unresolved deserializer expression + * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty + * if we want to resolve deserializer by children output. + */ +case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil) + extends UnaryExpression with Unevaluable with NonSQLExpression { + // The input attributes used to resolve deserializer expression must be all resolved. + require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.") + + override def child: Expression = deserializer + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala new file mode 100644 index 0000000000000..f8a6fb74cc87d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.util.StringUtils + +/** + * An in-memory (ephemeral) implementation of the system catalog. + * + * This is a dummy implementation that does not require setting up external systems. + * It is intended for testing or exploration purposes only and should not be used + * in production. + * + * All public methods should be synchronized for thread-safety. + */ +class InMemoryCatalog extends ExternalCatalog { + import ExternalCatalog._ + + private class TableDesc(var table: CatalogTable) { + val partitions = new mutable.HashMap[TablePartitionSpec, CatalogTablePartition] + } + + private class DatabaseDesc(var db: CatalogDatabase) { + val tables = new mutable.HashMap[String, TableDesc] + val functions = new mutable.HashMap[String, CatalogFunction] + } + + // Database name -> description + private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] + + private def partitionExists(db: String, table: String, spec: TablePartitionSpec): Boolean = { + requireTableExists(db, table) + catalog(db).tables(table).partitions.contains(spec) + } + + private def requireFunctionExists(db: String, funcName: String): Unit = { + if (!functionExists(db, funcName)) { + throw new AnalysisException( + s"Function not found: '$funcName' does not exist in database '$db'") + } + } + + private def requireTableExists(db: String, table: String): Unit = { + if (!tableExists(db, table)) { + throw new AnalysisException( + s"Table or View not found: '$table' does not exist in database '$db'") + } + } + + private def requirePartitionExists(db: String, table: String, spec: TablePartitionSpec): Unit = { + if (!partitionExists(db, table, spec)) { + throw new AnalysisException( + s"Partition not found: database '$db' table '$table' does not contain: '$spec'") + } + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase( + dbDefinition: CatalogDatabase, + ignoreIfExists: Boolean): Unit = synchronized { + if (catalog.contains(dbDefinition.name)) { + if (!ignoreIfExists) { + throw new AnalysisException(s"Database '${dbDefinition.name}' already exists.") + } + } else { + catalog.put(dbDefinition.name, new DatabaseDesc(dbDefinition)) + } + } + + override def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = synchronized { + if (catalog.contains(db)) { + if (!cascade) { + // If cascade is false, make sure the database is empty. + if (catalog(db).tables.nonEmpty) { + throw new AnalysisException(s"Database '$db' is not empty. One or more tables exist.") + } + if (catalog(db).functions.nonEmpty) { + throw new AnalysisException(s"Database '$db' is not empty. One or more functions exist.") + } + } + // Remove the database. + catalog.remove(db) + } else { + if (!ignoreIfNotExists) { + throw new AnalysisException(s"Database '$db' does not exist") + } + } + } + + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + requireDbExists(dbDefinition.name) + catalog(dbDefinition.name).db = dbDefinition + } + + override def getDatabase(db: String): CatalogDatabase = synchronized { + requireDbExists(db) + catalog(db).db + } + + override def databaseExists(db: String): Boolean = synchronized { + catalog.contains(db) + } + + override def listDatabases(): Seq[String] = synchronized { + catalog.keySet.toSeq + } + + override def listDatabases(pattern: String): Seq[String] = synchronized { + StringUtils.filterPattern(listDatabases(), pattern) + } + + override def setCurrentDatabase(db: String): Unit = { /* no-op */ } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable( + db: String, + tableDefinition: CatalogTable, + ignoreIfExists: Boolean): Unit = synchronized { + requireDbExists(db) + val table = tableDefinition.identifier.table + if (tableExists(db, table)) { + if (!ignoreIfExists) { + throw new AnalysisException(s"Table '$table' already exists in database '$db'") + } + } else { + catalog(db).tables.put(table, new TableDesc(tableDefinition)) + } + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean): Unit = synchronized { + requireDbExists(db) + if (tableExists(db, table)) { + catalog(db).tables.remove(table) + } else { + if (!ignoreIfNotExists) { + throw new AnalysisException(s"Table or View '$table' does not exist in database '$db'") + } + } + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + requireTableExists(db, oldName) + val oldDesc = catalog(db).tables(oldName) + oldDesc.table = oldDesc.table.copy(identifier = TableIdentifier(newName, Some(db))) + catalog(db).tables.put(newName, oldDesc) + catalog(db).tables.remove(oldName) + } + + override def alterTable(db: String, tableDefinition: CatalogTable): Unit = synchronized { + requireTableExists(db, tableDefinition.identifier.table) + catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition + } + + override def getTable(db: String, table: String): CatalogTable = synchronized { + requireTableExists(db, table) + catalog(db).tables(table).table + } + + override def getTableOption(db: String, table: String): Option[CatalogTable] = synchronized { + if (!tableExists(db, table)) None else Option(catalog(db).tables(table).table) + } + + override def tableExists(db: String, table: String): Boolean = synchronized { + requireDbExists(db) + catalog(db).tables.contains(table) + } + + override def listTables(db: String): Seq[String] = synchronized { + requireDbExists(db) + catalog(db).tables.keySet.toSeq + } + + override def listTables(db: String, pattern: String): Seq[String] = synchronized { + StringUtils.filterPattern(listTables(db), pattern) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = synchronized { + requireTableExists(db, table) + val existingParts = catalog(db).tables(table).partitions + if (!ignoreIfExists) { + val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec } + if (dupSpecs.nonEmpty) { + val dupSpecsStr = dupSpecs.mkString("\n===\n") + throw new AnalysisException("The following partitions already exist in database " + + s"'$db' table '$table':\n$dupSpecsStr") + } + } + parts.foreach { p => existingParts.put(p.spec, p) } + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = synchronized { + requireTableExists(db, table) + val existingParts = catalog(db).tables(table).partitions + if (!ignoreIfNotExists) { + val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s } + if (missingSpecs.nonEmpty) { + val missingSpecsStr = missingSpecs.mkString("\n===\n") + throw new AnalysisException("The following partitions do not exist in database " + + s"'$db' table '$table':\n$missingSpecsStr") + } + } + partSpecs.foreach(existingParts.remove) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = synchronized { + require(specs.size == newSpecs.size, "number of old and new partition specs differ") + specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => + val newPart = getPartition(db, table, oldSpec).copy(spec = newSpec) + val existingParts = catalog(db).tables(table).partitions + existingParts.remove(oldSpec) + existingParts.put(newSpec, newPart) + } + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = synchronized { + parts.foreach { p => + requirePartitionExists(db, table, p.spec) + catalog(db).tables(table).partitions.put(p.spec, p) + } + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = synchronized { + requirePartitionExists(db, table, spec) + catalog(db).tables(table).partitions(spec) + } + + override def listPartitions( + db: String, + table: String): Seq[CatalogTablePartition] = synchronized { + requireTableExists(db, table) + catalog(db).tables(table).partitions.values.toSeq + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { + requireDbExists(db) + if (functionExists(db, func.identifier.funcName)) { + throw new AnalysisException(s"Function '$func' already exists in '$db' database") + } else { + catalog(db).functions.put(func.identifier.funcName, func) + } + } + + override def dropFunction(db: String, funcName: String): Unit = synchronized { + requireFunctionExists(db, funcName) + catalog(db).functions.remove(funcName) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { + requireFunctionExists(db, oldName) + val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db))) + catalog(db).functions.remove(oldName) + catalog(db).functions.put(newName, newFunc) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = synchronized { + requireFunctionExists(db, funcName) + catalog(db).functions(funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = { + requireDbExists(db) + catalog(db).functions.contains(funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { + requireDbExists(db) + StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala new file mode 100644 index 0000000000000..34e1cb7315a9c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -0,0 +1,665 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import java.io.File + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException, SimpleFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.util.StringUtils + +/** + * An internal catalog that is used by a Spark Session. This internal catalog serves as a + * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary + * tables and functions of the Spark Session that it belongs to. + * + * This class is not thread-safe. + */ +class SessionCatalog( + externalCatalog: ExternalCatalog, + functionResourceLoader: FunctionResourceLoader, + functionRegistry: FunctionRegistry, + conf: CatalystConf) extends Logging { + import ExternalCatalog._ + + def this( + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + conf: CatalystConf) { + this(externalCatalog, DummyFunctionResourceLoader, functionRegistry, conf) + } + + // For testing only. + def this(externalCatalog: ExternalCatalog) { + this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true)) + } + + protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan] + + // Note: we track current database here because certain operations do not explicitly + // specify the database (e.g. DROP TABLE my_table). In these cases we must first + // check whether the temporary table or function exists, then, if not, operate on + // the corresponding item in the current database. + protected[this] var currentDb = { + val defaultName = "default" + val defaultDbDefinition = CatalogDatabase(defaultName, "default database", "", Map()) + // Initialize default database if it doesn't already exist + createDatabase(defaultDbDefinition, ignoreIfExists = true) + defaultName + } + + /** + * Format table name, taking into account case sensitivity. + */ + protected[this] def formatTableName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase + } + + // ---------------------------------------------------------------------------- + // Databases + // ---------------------------------------------------------------------------- + // All methods in this category interact directly with the underlying catalog. + // ---------------------------------------------------------------------------- + + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + externalCatalog.createDatabase(dbDefinition, ignoreIfExists) + } + + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + externalCatalog.dropDatabase(db, ignoreIfNotExists, cascade) + } + + def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + externalCatalog.alterDatabase(dbDefinition) + } + + def getDatabaseMetadata(db: String): CatalogDatabase = { + externalCatalog.getDatabase(db) + } + + def databaseExists(db: String): Boolean = { + externalCatalog.databaseExists(db) + } + + def listDatabases(): Seq[String] = { + externalCatalog.listDatabases() + } + + def listDatabases(pattern: String): Seq[String] = { + externalCatalog.listDatabases(pattern) + } + + def getCurrentDatabase: String = currentDb + + def setCurrentDatabase(db: String): Unit = { + if (!databaseExists(db)) { + throw new AnalysisException(s"cannot set current database to non-existent '$db'") + } + currentDb = db + } + + def getDefaultDBPath(db: String): String = { + System.getProperty("java.io.tmpdir") + File.separator + db + ".db" + } + + // ---------------------------------------------------------------------------- + // Tables + // ---------------------------------------------------------------------------- + // There are two kinds of tables, temporary tables and metastore tables. + // Temporary tables are isolated across sessions and do not belong to any + // particular database. Metastore tables can be used across multiple + // sessions as their metadata is persisted in the underlying catalog. + // ---------------------------------------------------------------------------- + + // ---------------------------------------------------- + // | Methods that interact with metastore tables only | + // ---------------------------------------------------- + + /** + * Create a metastore table in the database specified in `tableDefinition`. + * If no such database is specified, create it in the current database. + */ + def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.identifier.database.getOrElse(currentDb) + val table = formatTableName(tableDefinition.identifier.table) + val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + externalCatalog.createTable(db, newTableDefinition, ignoreIfExists) + } + + /** + * Alter the metadata of an existing metastore table identified by `tableDefinition`. + * + * If no database is specified in `tableDefinition`, assume the table is in the + * current database. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. + */ + def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.identifier.database.getOrElse(currentDb) + val table = formatTableName(tableDefinition.identifier.table) + val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + externalCatalog.alterTable(db, newTableDefinition) + } + + /** + * Retrieve the metadata of an existing metastore table. + * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then an [[AnalysisException]] is thrown. + */ + def getTableMetadata(name: TableIdentifier): CatalogTable = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + externalCatalog.getTable(db, table) + } + + /** + * Retrieve the metadata of an existing metastore table. + * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then return None if it doesn't exist. + */ + def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + externalCatalog.getTableOption(db, table) + } + + // ------------------------------------------------------------- + // | Methods that interact with temporary and metastore tables | + // ------------------------------------------------------------- + + /** + * Create a temporary table. + */ + def createTempTable( + name: String, + tableDefinition: LogicalPlan, + overrideIfExists: Boolean): Unit = { + val table = formatTableName(name) + if (tempTables.contains(table) && !overrideIfExists) { + throw new AnalysisException(s"Temporary table '$name' already exists.") + } + tempTables.put(table, tableDefinition) + } + + /** + * Rename a table. + * + * If a database is specified in `oldName`, this will rename the table in that database. + * If no database is specified, this will first attempt to rename a temporary table with + * the same name, then, if that does not exist, rename the table in the current database. + * + * This assumes the database specified in `oldName` matches the one specified in `newName`. + */ + def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = { + if (oldName.database != newName.database) { + throw new AnalysisException("rename does not support moving tables across databases") + } + val db = oldName.database.getOrElse(currentDb) + val oldTableName = formatTableName(oldName.table) + val newTableName = formatTableName(newName.table) + if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { + externalCatalog.renameTable(db, oldTableName, newTableName) + } else { + val table = tempTables(oldTableName) + tempTables.remove(oldTableName) + tempTables.put(newTableName, table) + } + } + + /** + * Drop a table. + * + * If a database is specified in `name`, this will drop the table from that database. + * If no database is specified, this will first attempt to drop a temporary table with + * the same name, then, if that does not exist, drop the table from the current database. + */ + def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + if (name.database.isDefined || !tempTables.contains(table)) { + // When ignoreIfNotExists is false, no exception is issued when the table does not exist. + // Instead, log it as an error message. + if (externalCatalog.tableExists(db, table)) { + externalCatalog.dropTable(db, table, ignoreIfNotExists = true) + } else if (!ignoreIfNotExists) { + logError(s"Table or View '${name.quotedString}' does not exist") + } + } else { + tempTables.remove(table) + } + } + + /** + * Return a [[LogicalPlan]] that represents the given table. + * + * If a database is specified in `name`, this will return the table from that database. + * If no database is specified, this will first attempt to return a temporary table with + * the same name, then, if that does not exist, return the table from the current database. + */ + def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + val relation = + if (name.database.isDefined || !tempTables.contains(table)) { + val metadata = externalCatalog.getTable(db, table) + CatalogRelation(db, metadata, alias) + } else { + tempTables(table) + } + val qualifiedTable = SubqueryAlias(table, relation) + // If an alias was specified by the lookup, wrap the plan in a subquery so that + // attributes are properly qualified with this alias. + alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable) + } + + /** + * Return whether a table with the specified name exists. + * + * Note: If a database is explicitly specified, then this will return whether the table + * exists in that particular database instead. In that case, even if there is a temporary + * table with the same name, we will return false if the specified database does not + * contain the table. + */ + def tableExists(name: TableIdentifier): Boolean = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + if (name.database.isDefined || !tempTables.contains(table)) { + externalCatalog.tableExists(db, table) + } else { + true // it's a temporary table + } + } + + /** + * Return whether a table with the specified name is a temporary table. + * + * Note: The temporary table cache is checked only when database is not + * explicitly specified. + */ + def isTemporaryTable(name: TableIdentifier): Boolean = { + name.database.isEmpty && tempTables.contains(formatTableName(name.table)) + } + + /** + * List all tables in the specified database, including temporary tables. + */ + def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*") + + /** + * List all matching tables in the specified database, including temporary tables. + */ + def listTables(db: String, pattern: String): Seq[TableIdentifier] = { + val dbTables = + externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) } + val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) + .map { t => TableIdentifier(t) } + dbTables ++ _tempTables + } + + // TODO: It's strange that we have both refresh and invalidate here. + + /** + * Refresh the cache entry for a metastore table, if any. + */ + def refreshTable(name: TableIdentifier): Unit = { /* no-op */ } + + /** + * Invalidate the cache entry for a metastore table, if any. + */ + def invalidateTable(name: TableIdentifier): Unit = { /* no-op */ } + + /** + * Drop all existing temporary tables. + * For testing only. + */ + def clearTempTables(): Unit = { + tempTables.clear() + } + + /** + * Return a temporary table exactly as it was stored. + * For testing only. + */ + private[catalog] def getTempTable(name: String): Option[LogicalPlan] = { + tempTables.get(name) + } + + // ---------------------------------------------------------------------------- + // Partitions + // ---------------------------------------------------------------------------- + // All methods in this category interact directly with the underlying catalog. + // These methods are concerned with only metastore tables. + // ---------------------------------------------------------------------------- + + // TODO: We need to figure out how these methods interact with our data source + // tables. For such tables, we do not store values of partitioning columns in + // the metastore. For now, partition values of a data source table will be + // automatically discovered when we load the table. + + /** + * Create partitions in an existing table, assuming it exists. + * If no database is specified, assume the table is in the current database. + */ + def createPartitions( + tableName: TableIdentifier, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + val db = tableName.database.getOrElse(currentDb) + val table = formatTableName(tableName.table) + externalCatalog.createPartitions(db, table, parts, ignoreIfExists) + } + + /** + * Drop partitions from a table, assuming they exist. + * If no database is specified, assume the table is in the current database. + */ + def dropPartitions( + tableName: TableIdentifier, + parts: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = { + val db = tableName.database.getOrElse(currentDb) + val table = formatTableName(tableName.table) + externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists) + } + + /** + * Override the specs of one or many existing table partitions, assuming they exist. + * + * This assumes index i of `specs` corresponds to index i of `newSpecs`. + * If no database is specified, assume the table is in the current database. + */ + def renamePartitions( + tableName: TableIdentifier, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + val db = tableName.database.getOrElse(currentDb) + val table = formatTableName(tableName.table) + externalCatalog.renamePartitions(db, table, specs, newSpecs) + } + + /** + * Alter one or many table partitions whose specs that match those specified in `parts`, + * assuming the partitions exist. + * + * If no database is specified, assume the table is in the current database. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. + */ + def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = { + val db = tableName.database.getOrElse(currentDb) + val table = formatTableName(tableName.table) + externalCatalog.alterPartitions(db, table, parts) + } + + /** + * Retrieve the metadata of a table partition, assuming it exists. + * If no database is specified, assume the table is in the current database. + */ + def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = { + val db = tableName.database.getOrElse(currentDb) + val table = formatTableName(tableName.table) + externalCatalog.getPartition(db, table, spec) + } + + /** + * List all partitions in a table, assuming it exists. + * If no database is specified, assume the table is in the current database. + */ + def listPartitions(tableName: TableIdentifier): Seq[CatalogTablePartition] = { + val db = tableName.database.getOrElse(currentDb) + val table = formatTableName(tableName.table) + externalCatalog.listPartitions(db, table) + } + + // ---------------------------------------------------------------------------- + // Functions + // ---------------------------------------------------------------------------- + // There are two kinds of functions, temporary functions and metastore + // functions (permanent UDFs). Temporary functions are isolated across + // sessions. Metastore functions can be used across multiple sessions as + // their metadata is persisted in the underlying catalog. + // ---------------------------------------------------------------------------- + + // ------------------------------------------------------- + // | Methods that interact with metastore functions only | + // ------------------------------------------------------- + + /** + * Create a metastore function in the database specified in `funcDefinition`. + * If no such database is specified, create it in the current database. + */ + def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { + val db = funcDefinition.identifier.database.getOrElse(currentDb) + val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)) + val newFuncDefinition = funcDefinition.copy(identifier = identifier) + if (!functionExists(identifier)) { + externalCatalog.createFunction(db, newFuncDefinition) + } else if (!ignoreIfExists) { + throw new AnalysisException(s"function '$identifier' already exists in database '$db'") + } + } + + /** + * Drop a metastore function. + * If no database is specified, assume the function is in the current database. + */ + def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = { + val db = name.database.getOrElse(currentDb) + val identifier = name.copy(database = Some(db)) + if (functionExists(identifier)) { + // TODO: registry should just take in FunctionIdentifier for type safety + if (functionRegistry.functionExists(identifier.unquotedString)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(identifier.unquotedString) + } + externalCatalog.dropFunction(db, name.funcName) + } else if (!ignoreIfNotExists) { + throw new AnalysisException(s"function '$identifier' does not exist in database '$db'") + } + } + + /** + * Retrieve the metadata of a metastore function. + * + * If a database is specified in `name`, this will return the function in that database. + * If no database is specified, this will return the function in the current database. + */ + def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = { + val db = name.database.getOrElse(currentDb) + externalCatalog.getFunction(db, name.funcName) + } + + /** + * Check if the specified function exists. + */ + def functionExists(name: FunctionIdentifier): Boolean = { + val db = name.database.getOrElse(currentDb) + functionRegistry.functionExists(name.unquotedString) || + externalCatalog.functionExists(db, name.funcName) + } + + // ---------------------------------------------------------------- + // | Methods that interact with temporary and metastore functions | + // ---------------------------------------------------------------- + + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * + * This performs reflection to decide what type of [[Expression]] to return in the builder. + */ + private[sql] def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + // TODO: at least support UDAFs here + throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") + } + + /** + * Loads resources such as JARs and Files for a function. Every resource is represented + * by a tuple (resource type, resource uri). + */ + def loadFunctionResources(resources: Seq[(String, String)]): Unit = { + resources.foreach { case (resourceType, uri) => + val functionResource = + FunctionResource(FunctionResourceType.fromString(resourceType.toLowerCase), uri) + functionResourceLoader.loadResource(functionResource) + } + } + + /** + * Create a temporary function. + * This assumes no database is specified in `funcDefinition`. + */ + def createTempFunction( + name: String, + info: ExpressionInfo, + funcDefinition: FunctionBuilder, + ignoreIfExists: Boolean): Unit = { + if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { + throw new AnalysisException(s"Temporary function '$name' already exists.") + } + functionRegistry.registerFunction(name, info, funcDefinition) + } + + /** + * Drop a temporary function. + */ + // TODO: The reason that we distinguish dropFunction and dropTempFunction is that + // Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate + // dropFunction and dropTempFunction. + def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { + if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { + throw new AnalysisException( + s"Temporary function '$name' cannot be dropped because it does not exist!") + } + } + + protected def failFunctionLookup(name: String): Nothing = { + throw new AnalysisException(s"Undefined function: $name. This function is " + + s"neither a registered temporary function nor " + + s"a permanent function registered in the database $currentDb.") + } + + /** + * Return an [[Expression]] that represents the specified function, assuming it exists. + * + * For a temporary function or a permanent function that has been loaded, + * this method will simply lookup the function through the + * FunctionRegistry and create an expression based on the builder. + * + * For a permanent function that has not been loaded, we will first fetch its metadata + * from the underlying external catalog. Then, we will load all resources associated + * with this function (i.e. jars and files). Finally, we create a function builder + * based on the function class and put the builder into the FunctionRegistry. + * The name of this function in the FunctionRegistry will be `databaseName.functionName`. + */ + def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // TODO: Right now, the name can be qualified or not qualified. + // It will be better to get a FunctionIdentifier. + // TODO: Right now, we assume that name is not qualified! + val qualifiedName = FunctionIdentifier(name, Some(currentDb)).unquotedString + if (functionRegistry.functionExists(name)) { + // This function has been already loaded into the function registry. + functionRegistry.lookupFunction(name, children) + } else if (functionRegistry.functionExists(qualifiedName)) { + // This function has been already loaded into the function registry. + // Unlike the above block, we find this function by using the qualified name. + functionRegistry.lookupFunction(qualifiedName, children) + } else { + // The function has not been loaded to the function registry, which means + // that the function is a permanent function (if it actually has been registered + // in the metastore). We need to first put the function in the FunctionRegistry. + val catalogFunction = try { + externalCatalog.getFunction(currentDb, name) + } catch { + case e: AnalysisException => failFunctionLookup(name) + case e: NoSuchFunctionException => failFunctionLookup(name) + } + loadFunctionResources(catalogFunction.resources) + // Please note that qualifiedName is provided by the user. However, + // catalogFunction.identifier.unquotedString is returned by the underlying + // catalog. So, it is possible that qualifiedName is not exactly the same as + // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). + // At here, we preserve the input from the user. + val info = new ExpressionInfo(catalogFunction.className, qualifiedName) + val builder = makeFunctionBuilder(qualifiedName, catalogFunction.className) + createTempFunction(qualifiedName, info, builder, ignoreIfExists = false) + // Now, we need to create the Expression. + functionRegistry.lookupFunction(qualifiedName, children) + } + } + + /** + * List all functions in the specified database, including temporary functions. + */ + def listFunctions(db: String): Seq[FunctionIdentifier] = listFunctions(db, "*") + + /** + * List all matching functions in the specified database, including temporary functions. + */ + def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = { + val dbFunctions = + externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } + val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) + .map { f => FunctionIdentifier(f) } + // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry. + // So, the returned list may have two entries for the same function. + dbFunctions ++ loadedFunctions + } + + + // ----------------- + // | Other methods | + // ----------------- + + /** + * Drop all existing databases (except "default") along with all associated tables, + * partitions and functions, and set the current database to "default". + * + * This is mainly used for tests. + */ + private[sql] def reset(): Unit = { + val default = "default" + listDatabases().filter(_ != default).foreach { db => + dropDatabase(db, ignoreIfNotExists = false, cascade = true) + } + tempTables.clear() + functionRegistry.clear() + // restore built-in functions + FunctionRegistry.builtin.listFunction().foreach { f => + val expressionInfo = FunctionRegistry.builtin.lookupFunction(f) + val functionBuilder = FunctionRegistry.builtin.lookupFunctionBuilder(f) + require(expressionInfo.isDefined, s"built-in function '$f' is missing expression info") + require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder") + functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get) + } + setCurrentDatabase(default) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala new file mode 100644 index 0000000000000..5adcc892cf682 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.AnalysisException + +/** An trait that represents the type of a resourced needed by a function. */ +sealed trait FunctionResourceType + +object JarResource extends FunctionResourceType + +object FileResource extends FunctionResourceType + +// We do not allow users to specify a archive because it is YARN specific. +// When loading resources, we will throw an exception and ask users to +// use --archive with spark submit. +object ArchiveResource extends FunctionResourceType + +object FunctionResourceType { + def fromString(resourceType: String): FunctionResourceType = { + resourceType.toLowerCase match { + case "jar" => JarResource + case "file" => FileResource + case "archive" => ArchiveResource + case other => + throw new AnalysisException(s"Resource Type '$resourceType' is not supported.") + } + } +} + +case class FunctionResource(resourceType: FunctionResourceType, uri: String) + +/** + * A simple trait representing a class that can be used to load resources used by + * a function. Because only a SQLContext can load resources, we create this trait + * to avoid of explicitly passing SQLContext around. + */ +trait FunctionResourceLoader { + def loadResource(resource: FunctionResource): Unit +} + +object DummyFunctionResourceLoader extends FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + throw new UnsupportedOperationException + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala new file mode 100644 index 0000000000000..ad989a97e4afa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import javax.annotation.Nullable + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} + + +/** + * Interface for the system catalog (of columns, partitions, tables, and databases). + * + * This is only used for non-temporary items, and implementations must be thread-safe as they + * can be accessed in multiple threads. This is an external catalog because it is expected to + * interact with external systems. + * + * Implementations should throw [[AnalysisException]] when table or database don't exist. + */ +abstract class ExternalCatalog { + import ExternalCatalog._ + + protected def requireDbExists(db: String): Unit = { + if (!databaseExists(db)) { + throw new AnalysisException(s"Database '$db' does not exist") + } + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + + /** + * Alter a database whose name matches the one specified in `dbDefinition`, + * assuming the database exists. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. + */ + def alterDatabase(dbDefinition: CatalogDatabase): Unit + + def getDatabase(db: String): CatalogDatabase + + def databaseExists(db: String): Boolean + + def listDatabases(): Seq[String] + + def listDatabases(pattern: String): Seq[String] + + def setCurrentDatabase(db: String): Unit + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + def createTable(db: String, tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit + + def dropTable(db: String, table: String, ignoreIfNotExists: Boolean): Unit + + def renameTable(db: String, oldName: String, newName: String): Unit + + /** + * Alter a table whose name that matches the one specified in `tableDefinition`, + * assuming the table exists. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. + */ + def alterTable(db: String, tableDefinition: CatalogTable): Unit + + def getTable(db: String, table: String): CatalogTable + + def getTableOption(db: String, table: String): Option[CatalogTable] + + def tableExists(db: String, table: String): Boolean + + def listTables(db: String): Seq[String] + + def listTables(db: String, pattern: String): Seq[String] + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit + + def dropPartitions( + db: String, + table: String, + parts: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit + + /** + * Override the specs of one or many existing table partitions, assuming they exist. + * This assumes index i of `specs` corresponds to index i of `newSpecs`. + */ + def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit + + /** + * Alter one or many table partitions whose specs that match those specified in `parts`, + * assuming the partitions exist. + * + * Note: If the underlying implementation does not support altering a certain field, + * this becomes a no-op. + */ + def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit + + def getPartition(db: String, table: String, spec: TablePartitionSpec): CatalogTablePartition + + // TODO: support listing by pattern + def listPartitions(db: String, table: String): Seq[CatalogTablePartition] + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + def createFunction(db: String, funcDefinition: CatalogFunction): Unit + + def dropFunction(db: String, funcName: String): Unit + + def renameFunction(db: String, oldName: String, newName: String): Unit + + def getFunction(db: String, funcName: String): CatalogFunction + + def functionExists(db: String, funcName: String): Boolean + + def listFunctions(db: String, pattern: String): Seq[String] + +} + + +/** + * A function defined in the catalog. + * + * @param identifier name of the function + * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" + * @param resources resource types and Uris used by the function + */ +// TODO: Use FunctionResource instead of (String, String) as the element type of resources. +case class CatalogFunction( + identifier: FunctionIdentifier, + className: String, + resources: Seq[(String, String)]) + + +/** + * Storage format, used to describe how a partition or a table is stored. + */ +case class CatalogStorageFormat( + locationUri: Option[String], + inputFormat: Option[String], + outputFormat: Option[String], + serde: Option[String], + serdeProperties: Map[String, String]) + + +/** + * A column in a table. + */ +case class CatalogColumn( + name: String, + // This may be null when used to create views. TODO: make this type-safe; this is left + // as a string due to issues in converting Hive varchars to and from SparkSQL strings. + @Nullable dataType: String, + nullable: Boolean = true, + comment: Option[String] = None) + + +/** + * A partition (Hive style) defined in the catalog. + * + * @param spec partition spec values indexed by column name + * @param storage storage format of the partition + */ +case class CatalogTablePartition( + spec: ExternalCatalog.TablePartitionSpec, + storage: CatalogStorageFormat) + + +/** + * A table defined in the catalog. + * + * Note that Hive's metastore also tracks skewed columns. We should consider adding that in the + * future once we have a better understanding of how we want to handle skewed columns. + */ +case class CatalogTable( + identifier: TableIdentifier, + tableType: CatalogTableType, + storage: CatalogStorageFormat, + schema: Seq[CatalogColumn], + partitionColumnNames: Seq[String] = Seq.empty, + sortColumnNames: Seq[String] = Seq.empty, + bucketColumnNames: Seq[String] = Seq.empty, + numBuckets: Int = -1, + createTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, + properties: Map[String, String] = Map.empty, + viewOriginalText: Option[String] = None, + viewText: Option[String] = None, + comment: Option[String] = None) { + + // Verify that the provided columns are part of the schema + private val colNames = schema.map(_.name).toSet + private def requireSubsetOfSchema(cols: Seq[String], colType: String): Unit = { + require(cols.toSet.subsetOf(colNames), s"$colType columns (${cols.mkString(", ")}) " + + s"must be a subset of schema (${colNames.mkString(", ")}) in table '$identifier'") + } + requireSubsetOfSchema(partitionColumnNames, "partition") + requireSubsetOfSchema(sortColumnNames, "sort") + requireSubsetOfSchema(bucketColumnNames, "bucket") + + /** Columns this table is partitioned by. */ + def partitionColumns: Seq[CatalogColumn] = + schema.filter { c => partitionColumnNames.contains(c.name) } + + /** Return the database this table was specified to belong to, assuming it exists. */ + def database: String = identifier.database.getOrElse { + throw new AnalysisException(s"table $identifier did not specify database") + } + + /** Return the fully qualified name of this table, assuming the database was specified. */ + def qualifiedName: String = identifier.unquotedString + + /** Syntactic sugar to update a field in `storage`. */ + def withNewStorage( + locationUri: Option[String] = storage.locationUri, + inputFormat: Option[String] = storage.inputFormat, + outputFormat: Option[String] = storage.outputFormat, + serde: Option[String] = storage.serde, + serdeProperties: Map[String, String] = storage.serdeProperties): CatalogTable = { + copy(storage = CatalogStorageFormat( + locationUri, inputFormat, outputFormat, serde, serdeProperties)) + } + +} + + +case class CatalogTableType private(name: String) +object CatalogTableType { + val EXTERNAL_TABLE = new CatalogTableType("EXTERNAL_TABLE") + val MANAGED_TABLE = new CatalogTableType("MANAGED_TABLE") + val INDEX_TABLE = new CatalogTableType("INDEX_TABLE") + val VIRTUAL_VIEW = new CatalogTableType("VIRTUAL_VIEW") +} + + +/** + * A database defined in the catalog. + */ +case class CatalogDatabase( + name: String, + description: String, + locationUri: String, + properties: Map[String, String]) + + +object ExternalCatalog { + /** + * Specifications of a table partition. Mapping column name to column value. + */ + type TablePartitionSpec = Map[String, String] +} + + +/** + * A [[LogicalPlan]] that wraps [[CatalogTable]]. + */ +case class CatalogRelation( + db: String, + metadata: CatalogTable, + alias: Option[String] = None) + extends LeafNode { + + // TODO: implement this + override def output: Seq[Attribute] = Seq.empty + + require(metadata.identifier.database == Some(db), + "provided database does not match the one specified in the table definition") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d8df66430a695..1e7296664bb25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,10 +21,12 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ /** @@ -82,7 +84,7 @@ package object dsl { def >= (other: Expression): Predicate = GreaterThanOrEqual(expr, other) def === (other: Expression): Predicate = EqualTo(expr, other) def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) - def !== (other: Expression): Predicate = Not(EqualTo(expr, other)) + def =!= (other: Expression): Predicate = Not(EqualTo(expr, other)) def in(list: Expression*): Expression = In(expr, list) @@ -144,21 +146,34 @@ package object dsl { } } - def sum(e: Expression): Expression = Sum(e) - def sumDistinct(e: Expression): Expression = SumDistinct(e) - def count(e: Expression): Expression = Count(e) - def countDistinct(e: Expression*): Expression = CountDistinct(e) + def sum(e: Expression): Expression = Sum(e).toAggregateExpression() + def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) + def count(e: Expression): Expression = Count(e).toAggregateExpression() + def countDistinct(e: Expression*): Expression = + Count(e).toAggregateExpression(isDistinct = true) def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = - ApproxCountDistinct(e, rsd) - def avg(e: Expression): Expression = Average(e) - def first(e: Expression): Expression = First(e) - def last(e: Expression): Expression = Last(e) - def min(e: Expression): Expression = Min(e) - def max(e: Expression): Expression = Max(e) + HyperLogLogPlusPlus(e, rsd).toAggregateExpression() + def avg(e: Expression): Expression = Average(e).toAggregateExpression() + def first(e: Expression): Expression = new First(e).toAggregateExpression() + def last(e: Expression): Expression = new Last(e).toAggregateExpression() + def min(e: Expression): Expression = Min(e).toAggregateExpression() + def max(e: Expression): Expression = Max(e).toAggregateExpression() def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def star(names: String*): Expression = names match { + case Seq() => UnresolvedStar(None) + case target => UnresolvedStar(Option(target)) + } + + def callFunction[T, U]( + func: T => U, + returnType: DataType, + argument: Expression): Expression = { + val function = Literal.create(func, ObjectType(classOf[T => U])) + Invoke(function, "apply", returnType, argument :: Nil) + } implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? @@ -225,14 +240,21 @@ package object dsl { AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ - def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) def struct(structType: StructType): AttributeReference = AttributeReference(s, structType, nullable = true)() + def struct(attrs: AttributeReference*): AttributeReference = + struct(StructType.fromAttributes(attrs)) + + /** Create a function. */ + def function(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = false) + def distinctFunction(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = true) } implicit class DslAttribute(a: AttributeReference) { def notNull: AttributeReference = a.withNullability(false) - def nullable: AttributeReference = a.withNullability(true) + def canBeNull: AttributeReference = a.withNullability(true) def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } @@ -240,11 +262,33 @@ package object dsl { object expressions extends ExpressionConversions // scalastyle:ignore object plans { // scalastyle:ignore + def table(ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref), None) + + def table(db: String, ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref, Option(db)), None) + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + def select(exprs: Expression*): LogicalPlan = { + val namedExpressions = exprs.map { + case e: NamedExpression => e + case e => UnresolvedAlias(e) + } + Project(namedExpressions, logicalPlan) + } def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def filter[T : Encoder](func: T => Boolean): LogicalPlan = { + val deserialized = logicalPlan.deserialize[T] + val condition = expressions.callFunction(func, BooleanType, deserialized.output.head) + Filter(condition, deserialized).serialize[T] + } + + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) + + def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( @@ -265,28 +309,44 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) + def window( + windowExpressions: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder]): LogicalPlan = + Window(windowExpressions, partitionSpec, orderSpec, logicalPlan) + + def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) - def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) + def union(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + alias: Option[String] = None, + outputNames: Seq[String] = Nil): LogicalPlan = + Generate(generator, join = join, outer = outer, alias, + outputNames.map(UnresolvedAttribute(_)), logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) + def as(alias: String): LogicalPlan = logicalPlan match { + case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) + case plan => SubqueryAlias(alias, plan) + } + + def distribute(exprs: Expression*): LogicalPlan = + RepartitionByExpression(exprs, logicalPlan) + + def analyze: LogicalPlan = + EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala deleted file mode 100644 index 329a132d3d8b2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - - - -import scala.reflect.ClassTag - -import org.apache.spark.sql.types.StructType - -/** - * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. - * - * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking - * and reuse internal buffers to improve performance. - */ -trait Encoder[T] extends Serializable { - - /** Returns the schema of encoding this type of object as a Row. */ - def schema: StructType - - /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ - def clsTag: ClassTag[T] -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index c287aebeeee05..56d29cfbe1f66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,21 +17,23 @@ package org.apache.spark.sql.catalyst.encoders -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.util.Utils +import java.util.concurrent.ConcurrentMap import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} +import org.apache.spark.sql.{AnalysisException, Encoder} +import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType} +import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types.{ObjectType, StructField, StructType} +import org.apache.spark.util.Utils /** - * A factory for constructing encoders that convert objects and primitves to and from the + * A factory for constructing encoders that convert objects and primitives to and from the * internal row format using catalyst expressions and code generation. By default, the * expressions used to retrieve values from an input row when producing an object will be created as * follows: @@ -42,88 +44,191 @@ import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) + val flat = !classOf[Product].isAssignableFrom(cls) - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) + val serializer = ScalaReflection.serializerFor[T](inputObject) + val deserializer = ScalaReflection.deserializerFor[T] + + val schema = ScalaReflection.schemaFor[T] match { + case ScalaReflection.Schema(s: StructType, _) => s + case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } new ExpressionEncoder[T]( - extractExpression.dataType, + schema, flat, - extractExpression.flatten, - constructExpression, + serializer.flatten, + deserializer, ClassTag[T](cls)) } + // TODO: improve error message for java bean encoder. + def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { + val schema = JavaTypeInference.inferDataType(beanClass)._1 + assert(schema.isInstanceOf[StructType]) + + val serializer = JavaTypeInference.serializerFor(beanClass) + val deserializer = JavaTypeInference.deserializerFor(beanClass) + + new ExpressionEncoder[T]( + schema.asInstanceOf[StructType], + flat = false, + serializer.flatten, + deserializer, + ClassTag[T](beanClass)) + } + /** * Given a set of N encoders, constructs a new encoder that produce objects as items in an - * N-tuple. Note that these encoders should first be bound correctly to the combined input - * schema. + * N-tuple. Note that these encoders should be unresolved so that information about + * name/positional binding is preserved. */ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { - val schema = - StructType( - encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)}) + encoders.foreach(_.assertUnresolved()) + + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => + val (dataType, nullable) = if (e.flat) { + e.schema.head.dataType -> e.schema.head.nullable + } else { + e.schema -> true + } + StructField(s"_${i + 1}", dataType, nullable) + }) + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val extractExpressions = encoders.map { - case e if e.flat => e.extractExpressions.head - case other => CreateStruct(other.extractExpressions) + + val serializer = encoders.map { + case e if e.flat => e.serializer.head + case other => CreateStruct(other.serializer) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t, _) => + Invoke( + BoundReference(0, ObjectType(cls), nullable = true), + s"_${index + 1}", + t) + } + } + + val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.deserializer.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + val input = BoundReference(index, enc.schema, nullable = true) + enc.deserializer.transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue(input, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) + } + } } - val constructExpression = - NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls)) + + val deserializer = + NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( schema, - false, - extractExpressions, - constructExpression, - ClassTag.apply(cls)) + flat = false, + serializer, + deserializer, + ClassTag(cls)) } - /** A helper for producing encoders of Tuple2 from other encoders. */ def tuple[T1, T2]( e1: ExpressionEncoder[T1], e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = - tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]] + tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]] + + def tuple[T1, T2, T3]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] = + tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + + def tuple[T1, T2, T3, T4]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3], + e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] = + tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + + def tuple[T1, T2, T3, T4, T5]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3], + e4: ExpressionEncoder[T4], + e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = + tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] } /** * A generic encoder for JVM objects. * * @param schema The schema after converting `T` to a Spark SQL row. - * @param extractExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object. + * @param serializer A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]. + * @param deserializer An expression that will construct an object given an [[InternalRow]]. * @param clsTag A classtag for `T`. */ case class ExpressionEncoder[T]( schema: StructType, flat: Boolean, - extractExpressions: Seq[Expression], - constructExpression: Expression, + serializer: Seq[Expression], + deserializer: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(extractExpressions.size == 1) + if (flat) require(serializer.size == 1) @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) - private val inputRow = new GenericMutableRow(1) + private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) @transient - private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + private lazy val inputRow = new GenericMutableRow(1) + + @transient + private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) + + /** + * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns + * is performed). + */ + def defaultBinding: ExpressionEncoder[T] = { + val attrs = schema.toAttributes + resolve(attrs, OuterScopes.outerScopes).bind(attrs) + } + + + /** + * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form + * of this object. + */ + def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map { + case (_, ne: NamedExpression) => ne.newInstance() + case (name, e) => Alias(e, name)() + } /** * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should * copy the result before making another call if required. */ - def toRow(t: T): InternalRow = { + def toRow(t: T): InternalRow = try { inputRow(0) = t extractProjection(inputRow) + } catch { + case e: Exception => + throw new RuntimeException( + s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e) } /** @@ -135,74 +240,111 @@ case class ExpressionEncoder[T]( constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e) + throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e) } /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the - * given schema. + * The process of resolution to a given schema throws away information about where a given field + * is being bound by ordinal instead of by name. This method checks to make sure this process + * has not been done already in places where we plan to do later composition of encoders. */ - def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { - val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(constructExpression = analyzedPlan.expressions.head.children.head) + def assertUnresolved(): Unit = { + (deserializer +: serializer).foreach(_.foreach { + case a: AttributeReference if a.name != "loopVar" => + sys.error(s"Unresolved encoder expected, but $a was found.") + case _ => + }) } /** - * Returns a copy of this encoder where the expressions used to construct an object from an input - * row have been bound to the ordinals of the given schema. Note that you need to first call - * resolve before bind. + * Validates `deserializer` to make sure it can be resolved by given schema, and produce + * friendly error messages to explain why it fails to resolve if there is something wrong. */ - def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(constructExpression = BindReferences.bindReference(constructExpression, schema)) + def validate(schema: Seq[Attribute]): Unit = { + def fail(st: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" + + " - Target schema: " + this.schema.simpleString) + } + + // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all + // `BoundReference`, make sure their ordinals are all valid. + var maxOrdinal = -1 + deserializer.foreach { + case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal + case _ => + } + if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { + fail(StructType.fromAttributes(schema), maxOrdinal) + } + + // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of + // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid. + // Note that, `BoundReference` contains the expected type, but here we need the actual type, so + // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after + // we resolve the `fromRowExpression`. + val resolved = SimpleAnalyzer.resolveExpression( + deserializer, + LocalRelation(schema), + throws = true) + + val unbound = resolved transform { + case b: BoundReference => schema(b.ordinal) + } + + val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] + unbound.foreach { + case g: GetStructField => + val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) + if (maxOrdinal < g.ordinal) { + exprToMaxOrdinal.update(g.child, g.ordinal) + } + case _ => + } + exprToMaxOrdinal.foreach { + case (expr, maxOrdinal) => + val schema = expr.dataType.asInstanceOf[StructType] + if (maxOrdinal != schema.length - 1) { + fail(schema, maxOrdinal) + } + } } /** - * Replaces any bound references in the schema with the attributes at the corresponding ordinal - * in the provided schema. This can be used to "relocate" a given encoder to pull values from - * a different schema than it was initially bound to. It can also be used to assign attributes - * to ordinal based extraction (i.e. because the input data was a tuple). + * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema. */ - def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) - copy(constructExpression = constructExpression transform { - case b: BoundReference => positionToAttribute(b.ordinal) - }) + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { + // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check + // analysis, go through optimizer, etc. + val plan = Project( + Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil, + LocalRelation(schema)) + val analyzedPlan = SimpleAnalyzer.execute(plan) + SimpleAnalyzer.checkAnalysis(analyzedPlan) + copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) } /** - * Given an encoder that has already been bound to a given schema, returns a new encoder - * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, - * when you are trying to use an encoder on grouping keys that were originally part of a larger - * row, but now you have projected out only the key expressions. + * Returns a copy of this encoder where the `deserializer` has been bound to the + * ordinals of the given schema. Note that you need to first call resolve before bind. */ - def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(oldSchema) - val attributeToNewPosition = AttributeMap.byIndex(newSchema) - copy(constructExpression = constructExpression transform { - case r: BoundReference => - r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) - }) + def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { + copy(deserializer = BindReferences.bindReference(deserializer, schema)) } /** - * Returns a copy of this encoder where the expressions used to create an object given an - * input row have been modified to pull the object out from a nested struct, instead of the - * top level fields. + * Returns a new encoder with input columns shifted by `delta` ordinals */ - def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = { - copy(constructExpression = constructExpression transform { - case u: Attribute if u != input => - UnresolvedExtractValue(input, Literal(u.name)) - case b: BoundReference if b != input => - GetStructField( - input, - StructField(s"i[${b.ordinal}]", b.dataType), - b.ordinal) + def shift(delta: Int): ExpressionEncoder[T] = { + copy(deserializer = deserializer transform { + case r: BoundReference => r.copy(ordinal = r.ordinal + delta) }) } - protected val attrs = extractExpressions.flatMap(_.collect { + protected val attrs = serializer.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" case b: BoundReference => s"[${b.ordinal}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala new file mode 100644 index 0000000000000..a1f0312bd853c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker + +import org.apache.spark.util.Utils + +object OuterScopes { + @transient + lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + /** + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. + */ + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } + + /** + * Returns a function which can get the outer scope for the given inner class. By using function + * as return type, we can delay the process of getting outer pointer to execution time, which is + * useful for inner class defined in REPL. + */ + def getOuterScope(innerCls: Class[_]): () => AnyRef = { + assert(innerCls.isMemberClass) + val outerClassName = innerCls.getDeclaringClass.getName + val outer = outerScopes.get(outerClassName) + if (outer == null) { + outerClassName match { + // If the outer class is generated by REPL, users don't need to register it as it has + // only one instance and there is a way to retrieve it: get the `$read` object, call the + // `INSTANCE()` method to get the single instance of class `$read`. Then call `$iw()` + // method multiply times to get the single instance of the inner most `$iw` class. + case REPLClass(baseClassName) => + () => { + val objClass = Utils.classForName(baseClassName + "$") + val objInstance = objClass.getField("MODULE$").get(null) + val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance) + val baseClass = Utils.classForName(baseClassName) + + var getter = iwGetter(baseClass) + var obj = baseInstance + while (getter != null) { + obj = getter.invoke(obj) + getter = iwGetter(getter.getReturnType) + } + + if (obj == null) { + throw new RuntimeException(s"Failed to get outer pointer for ${innerCls.getName}") + } + + outerScopes.putIfAbsent(outerClassName, obj) + obj + } + case _ => null + } + } else { + () => outer + } + } + + private def iwGetter(cls: Class[_]) = { + try { + cls.getMethod("$iw") + } catch { + case _: NoSuchMethodException => null + } + } + + // The format of REPL generated wrapper class's name, e.g. `$line12.$read$$iw$$iw` + private[this] val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0b42130a013b2..a8397aa5e5c26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -34,41 +35,51 @@ object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpressions = extractorsFor(inputObject, schema) - val constructExpression = constructorFor(schema) + // We use an If expression to wrap extractorsFor result of StructType + val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue + val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, flat = false, - extractExpressions.asInstanceOf[CreateStruct].children, - constructExpression, + serializer.asInstanceOf[CreateStruct].children, + deserializer, ClassTag(cls)) } - private def extractorsFor( + private def serializerFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => inputObject + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject + + case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) + + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case _: DecimalType => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, - "apply", + "fromDecimal", inputObject :: Nil) case StringType => @@ -84,7 +95,7 @@ object RowEncoder { classOf[GenericArrayData], inputObject :: Nil, dataType = t) - case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et)) + case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et)) } case t @ MapType(kt, vt, valueNullable) => @@ -93,14 +104,14 @@ object RowEncoder { Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = extractorsFor(keys, ArrayType(kt, false)) + val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable)) + val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) NewInstance( classOf[ArrayBasedMapData], @@ -109,19 +120,41 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => + val method = if (f.dataType.isInstanceOf[StructType]) { + "getStruct" + } else { + "get" + } If( Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), - extractorsFor( - Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil), + serializerFor( + Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil), f.dataType)) } - CreateStruct(convertedFields) + If(IsNull(inputObject), + Literal.create(null, inputType), + CreateStruct(convertedFields)) + } + + /** + * Returns the `DataType` that can be used when generating code that converts input data + * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned + * by this function can be more permissive since multiple external types may map to a single + * internal type. For example, for an input with DecimalType in external row, its external types + * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or + * `org.apache.spark.sql.types.Decimal`. + */ + private def externalDataTypeForInput(dt: DataType): DataType = dt match { + // In order to support both Decimal and java BigDecimal in external row, we make this + // as java.lang.Object. + case _: DecimalType => ObjectType(classOf[java.lang.Object]) + case _ => externalDataTypeFor(dt) } private def externalDataTypeFor(dt: DataType): DataType = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => dt + case _ if ScalaReflection.isNativeType(dt) => dt + case CalendarIntervalType => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -129,34 +162,47 @@ object RowEncoder { case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) + case udt: UserDefinedType[_] => ObjectType(udt.userClass) + case _: NullType => ObjectType(classOf[java.lang.Object]) } - private def constructorFor(schema: StructType): Expression = { + private def deserializerFor(schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => - val field = BoundReference(i, f.dataType, f.nullable) + val dt = f.dataType match { + case p: PythonUserDefinedType => p.sqlType + case other => other + } + val field = BoundReference(i, dt, f.nullable) If( IsNull(field), - Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType) + Literal.create(null, externalDataTypeFor(dt)), + deserializerFor(field) ) } - CreateExternalRow(fields) + CreateExternalRow(fields, schema) } - private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => input + private def deserializerFor(input: Expression): Expression = input.dataType match { + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType | CalendarIntervalType => input + + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", input :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", input :: Nil) @@ -170,69 +216,37 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(constructorFor(_, et), input, et), + MapObjects(deserializerFor(_), input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( - scala.collection.mutable.WrappedArray, + scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) - val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType) + val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) val valueArrayType = ArrayType(vt, valueNullable) - val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType) + val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) - case StructType(fields) => + case schema @ StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(getField(input, i, f.dataType), f.dataType)) + deserializerFor(GetStructField(input, i))) } - CreateExternalRow(convertedFields) - } - - private def getField( - row: Expression, - ordinal: Int, - dataType: DataType): Expression = dataType match { - case BooleanType => - Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil) - case ByteType => - Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil) - case ShortType => - Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil) - case IntegerType | DateType => - Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil) - case LongType | TimestampType => - Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil) - case FloatType => - Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil) - case DoubleType => - Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil) - case t: DecimalType => - Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_))) - case StringType => - Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil) - case BinaryType => - Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil) - case CalendarIntervalType => - Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil) - case t: StructType => - Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil) - case _: ArrayType => - Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil) - case _: MapType => - Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil) + If(IsNull(input), + Literal.create(null, externalDataTypeFor(input.dataType)), + CreateExternalRow(convertedFields, schema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index d4642a500672e..03708fb7afd44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -17,10 +17,19 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.Encoder + package object encoders { + /** + * Returns an internal encoder object that can be used to serialize / deserialize JVM objects + * into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute + * references from a specific schema.) This requirement allows us to preserve whether a given + * object type is being bound by name or by ordinal when doing resolution. + */ private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { - case e: ExpressionEncoder[A] => e + case e: ExpressionEncoder[A] => + e.assertUnresolved() + e case _ => sys.error(s"Only expression encoders are supported today") } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index d2a90a50c89f4..0420b4b5387c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -25,21 +25,22 @@ import org.apache.spark.sql.catalyst.trees.TreeNode package object errors { class TreeNodeException[TreeType <: TreeNode[_]]( - tree: TreeType, msg: String, cause: Throwable) + @transient val tree: TreeType, + msg: String, + cause: Throwable) extends Exception(msg, cause) { + val treeString = tree.toString + // Yes, this is the same as a default parameter, but... those don't seem to work with SBT // external project dependencies for some reason. def this(tree: TreeType, msg: String) = this(tree, msg, null) override def getMessage: String = { - val treeString = tree.toString s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree" } } - class DialectException(msg: String, cause: Throwable) extends Exception(msg, cause) - /** * Wraps any exceptions that are thrown while executing `f` in a * [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 3831535574205..8bdf9b29c9641 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -53,7 +53,7 @@ object AttributeSet { * cosmetically (e.g., the names have different capitalizations). * * Note that we do not override equality for Attribute references as it is really weird when - * `AttributeReference("a"...) == AttrributeReference("b", ...)`. This tactic leads to broken tests, + * `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic leads to broken tests, * and also makes doing transformations hard (we always try keep older trees instead of new ones * when the transformation was a no-op). */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index ff1f28ddbbf35..c1fd23f28d6b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ /** @@ -29,9 +29,9 @@ import org.apache.spark.sql.types._ * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends LeafExpression with NamedExpression { + extends LeafExpression { - override def toString: String = s"input[$ordinal, $dataType]" + override def toString: String = s"input[$ordinal, ${dataType.simpleString}]" // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { @@ -58,21 +58,27 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } } - override def name: String = s"i[$ordinal]" - - override def toAttribute: Attribute = throw new UnsupportedOperationException - - override def qualifiers: Seq[String] = throw new UnsupportedOperationException - - override def exprId: ExprId = throw new UnsupportedOperationException - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ + if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { + val oev = ctx.currentVars(ordinal) + ev.isNull = oev.isNull + ev.value = oev.value + val code = oev.code + oev.code = "" + code + } else if (nullable) { + s""" + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """ + } else { + ev.isNull = "false" + s""" + $javaType ${ev.value} = $value; + """ + } } } @@ -92,7 +98,7 @@ object BindReferences extends Logging { sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } } else { - BoundReference(ordinal, a.dataType, a.nullable) + BoundReference(ordinal, a.dataType, input(ordinal).nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala new file mode 100644 index 0000000000000..07ba7d5e4a849 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +/** + * Rewrites an expression using rules that are guaranteed preserve the result while attempting + * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization + * will always return the same answer given the same input (i.e. false positives should not be + * possible). However, it is possible that two canonical expressions that are not equal will in fact + * return the same answer given any input (i.e. false negatives are possible). + * + * The following rules are applied: + * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped. + * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered + * by `hashCode`. + * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. + * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + */ +object Canonicalize extends { + def execute(e: Expression): Expression = { + expressionReorder(ignoreNamesTypes(e)) + } + + /** Remove names and nullability from types. */ + private def ignoreNamesTypes(e: Expression): Expression = e match { + case a: AttributeReference => + AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) + case _ => e + } + + /** Collects adjacent commutative operations. */ + private def gatherCommutative( + e: Expression, + f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match { + case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f)) + case other => other :: Nil + } + + /** Orders a set of commutative operations by their hash code. */ + private def orderCommutative( + e: Expression, + f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = + gatherCommutative(e, f).sortBy(_.hashCode()) + + /** Rearrange expressions that are commutative or associative. */ + private def expressionReorder(e: Expression): Expression = e match { + case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add) + case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply) + + case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l) + case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l) + + case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) + case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + + case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) + case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) + + case Not(GreaterThan(l, r)) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + case Not(GreaterThan(l, r)) => LessThanOrEqual(l, r) + case Not(LessThan(l, r)) if l.hashCode() > r.hashCode() => LessThan(r, l) + case Not(LessThan(l, r)) => GreaterThanOrEqual(l, r) + case Not(GreaterThanOrEqual(l, r)) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) + case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) + case Not(LessThanOrEqual(l, r)) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) + case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) + + case _ => e + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5564e242b0472..0f8876a9e6881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -81,31 +82,37 @@ object Cast { toField.nullable) } + case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass => + true + case _ => false } private def resolvableNullability(from: Boolean, to: Boolean) = !from || to private def forceNullable(from: DataType, to: DataType) = (from, to) match { - case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true - case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true - case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null + case (NullType, _) => true + case (_, _) if from == to => false + + case (StringType, BinaryType) => false + case (StringType, _) => true + case (_, StringType) => false + + case (FloatType | DoubleType, TimestampType) => true + case (TimestampType, DateType) => false + case (_, DateType) => true + case (DateType, TimestampType) => false + case (DateType, _) => true + case (_, CalendarIntervalType) => true + + case (_, _: DecimalType) => true // overflow + case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) - extends UnaryExpression with CodegenFallback { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant { override def toString: String = s"cast($child as ${dataType.simpleString})" @@ -204,8 +211,8 @@ case class Cast(child: Expression, dataType: DataType) if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong } - // converting milliseconds to us - private[this] def longToTimestamp(t: Long): Long = t * 1000L + // converting seconds to us + private[this] def longToTimestamp(t: Long): Long = t * 1000000L // converting us to seconds private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 1000000L).toLong // converting us to seconds in double @@ -428,13 +435,18 @@ case class Cast(child: Expression, dataType: DataType) case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) case map: MapType => castMap(from.asInstanceOf[MapType], map) case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + identity[Any] + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) protected override def nullSafeEval(input: Any): Any = cast(input) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) eval.code + @@ -448,7 +460,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def nullSafeCastFunction( from: DataType, to: DataType, - ctx: CodeGenContext): CastFunction = to match { + ctx: CodegenContext): CastFunction = to match { case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" @@ -470,11 +482,16 @@ case class Cast(child: Expression, dataType: DataType) castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + (c, evPrim, evNull) => s"$evPrim = $c;" + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String, + private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String, resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultNull = $childNull; @@ -485,7 +502,7 @@ case class Cast(child: Expression, dataType: DataType) """ } - private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = { + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" @@ -507,7 +524,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToDateCode( from: DataType, - ctx: CodeGenContext): CastFunction = from match { + ctx: CodegenContext): CastFunction = from match { case StringType => val intOpt = ctx.freshName("intOpt") (c, evPrim, evNull) => s""" @@ -539,7 +556,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToDecimalCode( from: DataType, target: DecimalType, - ctx: CodeGenContext): CastFunction = { + ctx: CodegenContext): CastFunction = { val tmp = ctx.freshName("tmpDecimal") from match { case StringType => @@ -597,7 +614,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToTimestampCode( from: DataType, - ctx: CodeGenContext): CastFunction = from match { + ctx: CodegenContext): CastFunction = from match { case StringType => val longOpt = ctx.freshName("longOpt") (c, evPrim, evNull) => @@ -647,7 +664,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def decimalToTimestampCode(d: String): String = s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" private[this] def timestampToIntegerCode(ts: String): String = s"java.lang.Math.floor((double) $ts / 1000000L)" private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" @@ -809,7 +826,7 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def castArrayCode( - fromType: DataType, toType: DataType, ctx: CodeGenContext): CastFunction = { + fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") @@ -844,7 +861,7 @@ case class Cast(child: Expression, dataType: DataType) """ } - private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = { + private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = { val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) @@ -872,7 +889,7 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def castStructCode( - from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = { + from: StructType, to: StructType, ctx: CodegenContext): CastFunction = { val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) @@ -881,7 +898,7 @@ case class Cast(child: Expression, dataType: DataType) val result = ctx.freshName("result") val tmpRow = ctx.freshName("tmpRow") - val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => { + val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => val fromFieldPrim = ctx.freshName("ffp") val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") @@ -903,15 +920,31 @@ case class Cast(child: Expression, dataType: DataType) } } """ - } }.mkString("\n") (c, evPrim, evNull) => s""" - final $rowClass $result = new $rowClass(${fieldsCasts.size}); + final $rowClass $result = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpRow = $c; $fieldsEvalCode $evPrim = $result.copy(); """ } + + override def sql: String = dataType match { + // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this + // type of casting can only be introduced by the analyzer, and can be omitted when converting + // back to SQL query string. + case _: ArrayType | _: MapType | _: StructType => child.sql + case _ => s"CAST(${child.sql} AS ${dataType.sql})" + } +} + +/** + * Cast the child expression to the target data type, but will throw error if the cast might + * truncate, e.g. long -> int, timestamp -> data. + */ +case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String]) + extends UnaryExpression with Unevaluable { + override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala new file mode 100644 index 0000000000000..8d8cc152ff29c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback + +/** + * This class is used to compute equality of (sub)expression trees. Expressions can be added + * to this class and they subsequently query for expression equality. Expression trees are + * considered equal if for the same input(s), the same result is produced. + */ +class EquivalentExpressions { + /** + * Wrapper around an Expression that provides semantic equality. + */ + case class Expr(e: Expression) { + override def equals(o: Any): Boolean = o match { + case other: Expr => e.semanticEquals(other.e) + case _ => false + } + override val hashCode: Int = e.semanticHash() + } + + // For each expression, the set of equivalent expressions. + private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.MutableList[Expression]] + + /** + * Adds each expression to this data structure, grouping them with existing equivalent + * expressions. Non-recursive. + * Returns true if there was already a matching expression. + */ + def addExpr(expr: Expression): Boolean = { + if (expr.deterministic) { + val e: Expr = Expr(expr) + val f = equivalenceMap.get(e) + if (f.isDefined) { + f.get += expr + true + } else { + equivalenceMap.put(e, mutable.MutableList(expr)) + false + } + } else { + false + } + } + + /** + * Adds the expression to this data structure recursively. Stops if a matching expression + * is found. That is, if `expr` has already been added, its children are not added. + * If ignoreLeaf is true, leaf nodes are ignored. + */ + def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { + val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf + // the children of CodegenFallback will not be used to generate code (call eval() instead) + if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) { + root.children.foreach(addExprTree(_, ignoreLeaf)) + } + } + + /** + * Returns all of the expression trees that are equivalent to `e`. Returns + * an empty collection if there are none. + */ + def getEquivalentExprs(e: Expression): Seq[Expression] = { + equivalenceMap.getOrElse(Expr(e), mutable.MutableList()) + } + + /** + * Returns all the equivalent sets of expressions. + */ + def getAllEquivalentExprs: Seq[Seq[Expression]] = { + equivalenceMap.values.map(_.toSeq).toSeq + } + + /** + * Returns the state of the data structure as a string. If `all` is false, skips sets of + * equivalent expressions with cardinality 1. + */ + def debugString(all: Boolean = false): String = { + val sb: mutable.StringBuilder = new StringBuilder() + sb.append("Equivalent expressions:\n") + equivalenceMap.foreach { case (k, v) => + if (all || v.length > 1) { + sb.append(" " + v.mkString(", ")).append("\n") + } + } + sb.toString() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 2dcbd4eb15031..b3dfac806f7fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types.AbstractDataType -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts /** * An trait that gets mixin to define the expected input types of an expression. @@ -45,7 +44,7 @@ trait ExpectsInputTypes extends Expression { val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type." + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." } if (mismatches.isEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 96fcc799e537a..718bb4b118cea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -85,19 +86,30 @@ abstract class Expression extends TreeNode[Expression] { def eval(input: InternalRow = null): Any /** - * Returns an [[GeneratedExpressionCode]], which contains Java source code that + * Returns an [[ExprCode]], which contains Java source code that * can be used to generate the result of evaluating the expression on an input row. * - * @param ctx a [[CodeGenContext]] - * @return [[GeneratedExpressionCode]] + * @param ctx a [[CodegenContext]] + * @return [[ExprCode]] */ - def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - val isNull = ctx.freshName("isNull") - val primitive = ctx.freshName("primitive") - val ve = GeneratedExpressionCode("", isNull, primitive) - ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code) + def gen(ctx: CodegenContext): ExprCode = { + ctx.subExprEliminationExprs.get(this).map { subExprState => + // This expression is repeated meaning the code to evaluated has already been added + // as a function and called in advance. Just use it. + val code = s"/* ${toCommentSafeString(this.toString)} */" + ExprCode(code, subExprState.isNull, subExprState.value) + }.getOrElse { + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val ve = ExprCode("", isNull, value) + ve.code = genCode(ctx, ve) + if (ve.code != "") { + // Add `this` in the comment. + ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim) + } else { + ve + } + } } /** @@ -105,11 +117,11 @@ abstract class Expression extends TreeNode[Expression] { * The default behavior is to call the eval method of the expression. Concrete expression * implementations should override this to do actual code generation. * - * @param ctx a [[CodeGenContext]] - * @param ev an [[GeneratedExpressionCode]] with unique terms. + * @param ctx a [[CodegenContext]] + * @param ev an [[ExprCode]] with unique terms. * @return Java source code */ - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String + protected def genCode(ctx: CodegenContext, ev: ExprCode): String /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -132,23 +144,35 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) + /** + * Returns an expression where a best effort attempt has been made to transform `this` in a way + * that preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, etc.) See [[Canonicalize]] for more details. + * + * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always + * evaluate to the same result. + */ + lazy val canonicalized: Expression = { + val canonicalizedChildren = children.map(_.canonicalized) + Canonicalize.execute(withNewChildren(canonicalizedChildren)) + } + /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). + * + * See [[Canonicalize]] for more details. */ - def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { - def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { - elements1.length == elements2.length && elements1.zip(elements2).forall { - case (e1: Expression, e2: Expression) => e1 semanticEquals e2 - case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 - case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) - case (i1, i2) => i1 == i2 - } - } - val elements1 = this.productIterator.toSeq - val elements2 = other.asInstanceOf[Product].productIterator.toSeq - checkSemantic(elements1, elements2) - } + def semanticEquals(other: Expression): Boolean = + deterministic && other.deterministic && canonicalized == other.canonicalized + + /** + * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. + * + * See [[Canonicalize]] for more details. + */ + def semanticHash(): Int = canonicalized.hashCode() /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, @@ -161,26 +185,25 @@ abstract class Expression extends TreeNode[Expression] { * Returns a user-facing string representation of this expression's name. * This should usually match the name of the function in SQL. */ - def prettyName: String = getClass.getSimpleName.toLowerCase - - /** - * Returns a user-facing string representation of this expression, i.e. does not have developer - * centric debugging information like the expression id. - */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - + def prettyName: String = nodeName.toLowerCase private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil } - override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") + override def simpleString: String = toString + + override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")") + + /** + * Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]], + * this method may return an arbitrary user facing string. + */ + def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } } @@ -193,11 +216,24 @@ trait Unevaluable extends Expression { final override def eval(input: InternalRow = null): Any = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - final override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + final override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } +/** + * Expressions that don't have SQL representation should extend this trait. Examples are + * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`. + */ +trait NonSQLExpression extends Expression { + final override def sql: String = { + transform { + case a: Attribute => new PrettyAttribute(a) + }.toString + } +} + + /** * An expression that is nondeterministic. */ @@ -268,7 +304,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * As an example, the following does a boolean inversion (i.e. NOT). * {{{ @@ -278,8 +314,8 @@ abstract class UnaryExpression extends Expression { * @param f function that accepts a variable name and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: String => String): String = { nullSafeCodeGen(ctx, ev, eval => { s"${ev.value} = ${f(eval)};" @@ -288,28 +324,37 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * @param f function that accepts the non-null evaluation result name of child and returns Java * code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: String => String): String = { - val eval = child.gen(ctx) - val resultCode = f(eval.value) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { + val childGen = child.gen(ctx) + val resultCode = f(childGen.value) + + if (nullable) { + val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) + s""" + ${childGen.code} + boolean ${ev.isNull} = ${childGen.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $nullSafeEval + """ + } else { + ev.isNull = "false" + s""" + ${childGen.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode - } - """ + """ + } } } - /** * An expression with two inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. @@ -359,8 +404,8 @@ abstract class BinaryExpression extends Expression { * @param f accepts two variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s"${ev.value} = ${f(eval1, eval2)};" @@ -376,25 +421,38 @@ abstract class BinaryExpression extends Expression { * and returns Java code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(eval1.value, eval2.value) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } + val leftGen = left.gen(ctx) + val rightGen = right.gen(ctx) + val resultCode = f(leftGen.value, rightGen.value) + + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { + rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } } - """ + + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $nullSafeEval + """ + } else { + ev.isNull = "false" + s""" + ${leftGen.code} + ${rightGen.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } } @@ -416,6 +474,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { def symbol: String + def sqlOperator: String = symbol + override def toString: String = s"($left $symbol $right)" override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) @@ -423,15 +483,17 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def checkInputDataTypes(): TypeCheckResult = { // First check whether left and right have the same type, then check if the type is acceptable. if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") } else if (!inputType.acceptsType(left.dataType)) { - TypeCheckResult.TypeCheckFailure(s"'$prettyString' requires ${inputType.simpleString} type," + + TypeCheckResult.TypeCheckFailure(s"'$sql' requires ${inputType.simpleString} type," + s" not ${left.dataType.simpleString}") } else { TypeCheckResult.TypeCheckSuccess } } + + override def sql: String = s"(${left.sql} $sqlOperator ${right.sql})" } @@ -451,7 +513,7 @@ abstract class TernaryExpression extends Expression { /** * Default behavior of evaluation according to the default nullability of TernaryExpression. - * If subclass of BinaryExpression override nullable, probably should also override this. + * If subclass of TernaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { val exprs = children @@ -477,15 +539,15 @@ abstract class TernaryExpression extends Expression { sys.error(s"BinaryExpressions must override either eval or nullSafeEval") /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f accepts two variable names and returns Java code to compute the output. + * @param f accepts three variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { s"${ev.value} = ${f(eval1, eval2, eval3)};" @@ -493,33 +555,49 @@ abstract class TernaryExpression extends Expression { } /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f function that accepts the 2 non-null evaluation result names of children + * @param f function that accepts the 3 non-null evaluation result names of children * and returns Java code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String, String) => String): String = { - val evals = children.map(_.gen(ctx)) - val resultCode = f(evals(0).value, evals(1).value, evals(2).value) - s""" - ${evals(0).code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode + val leftGen = children(0).gen(ctx) + val midGen = children(1).gen(ctx) + val rightGen = children(2).gen(ctx) + val resultCode = f(leftGen.value, midGen.value, rightGen.value) + + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) { + midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) { + rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } } - } } - """ + + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $nullSafeEval + """ + } else { + ev.isNull = "false" + s""" + ${leftGen.code} + ${midGen.code} + ${rightGen.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala new file mode 100644 index 0000000000000..644a5b28a2151 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +object ExpressionSet { + /** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */ + def apply(expressions: TraversableOnce[Expression]): ExpressionSet = { + val set = new ExpressionSet() + expressions.foreach(set.add) + set + } +} + +/** + * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]] + * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details. + * + * Internally this set uses the canonical representation, but keeps also track of the original + * expressions to ease debugging. Since different expressions can share the same canonical + * representation, this means that operations that extract expressions from this set are only + * guaranteed to see at least one such expression. For example: + * + * {{{ + * val set = AttributeSet(a + 1, 1 + a) + * + * set.iterator => Iterator(a + 1) + * set.contains(a + 1) => true + * set.contains(1 + a) => true + * set.contains(a + 2) => false + * }}} + */ +class ExpressionSet protected( + protected val baseSet: mutable.Set[Expression] = new mutable.HashSet, + protected val originals: mutable.Buffer[Expression] = new ArrayBuffer) + extends Set[Expression] { + + protected def add(e: Expression): Unit = { + if (!baseSet.contains(e.canonicalized)) { + baseSet.add(e.canonicalized) + originals.append(e) + } + } + + override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized) + + override def +(elem: Expression): ExpressionSet = { + val newSet = new ExpressionSet(baseSet.clone(), originals.clone()) + newSet.add(elem) + newSet + } + + override def -(elem: Expression): ExpressionSet = { + val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) + val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) + new ExpressionSet(newBaseSet, newOriginals) + } + + override def iterator: Iterator[Expression] = originals.iterator + + /** + * Returns a string containing both the post [[Canonicalize]] expressions and the original + * expressions in this set. + */ + def toDebugString: String = + s""" + |baseSet: ${baseSet.mkString(", ")} + |originals: ${originals.mkString(", ")} + """.stripMargin +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index d809877817a5b..2ed6fc0d3824f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,33 +17,35 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.rdd.InputFileNameHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String /** - * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] + * Expression that returns the name of the current file being read. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the name of the current file being read if available", + extended = "> SELECT _FUNC_();\n ''") case class InputFileName() extends LeafExpression with Nondeterministic { override def nullable: Boolean = true override def dataType: DataType = StringType - override val prettyName = "INPUT_FILE_NAME" + override def prettyName: String = "input_file_name" override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDD.getInputFileName() + InputFileNameHolder.getInputFileName() } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();" } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index 935c3aa28c999..ed894f6d6e10e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - /** * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 2d7679fdfe043..5d28f8fbde8be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.types.{LongType, DataType} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{DataType, LongType} /** * Returns monotonically increasing 64-bit integers. @@ -32,6 +32,14 @@ import org.apache.spark.sql.types.{LongType, DataType} * * Since this expression is stateful, it cannot be a case object. */ +@ExpressionDescription( + usage = + """_FUNC_() - Returns monotonically increasing 64-bit integers. + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + represent the record number within each partition. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records.""", + extended = "> SELECT _FUNC_();\n 0") private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { /** @@ -57,7 +65,7 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with partitionMask + currentCount } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") @@ -70,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with $countTerm++; """ } + + override def prettyName: String = "monotonically_increasing_id" + + override def sql: String = s"$prettyName()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a6fe730f6dad4..354311c5e7449 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -102,16 +102,6 @@ abstract class UnsafeProjection extends Projection { object UnsafeProjection { - /* - * Returns whether UnsafeProjection can support given StructType, Array[DataType] or - * Seq[Expression]. - */ - def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) - def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) - private def canSupport(types: Array[DataType]): Boolean = { - types.forall(GenerateUnsafeProjection.canSupport) - } - /** * Returns an UnsafeProjection for given StructType. */ @@ -128,7 +118,11 @@ object UnsafeProjection { * Returns an UnsafeProjection for given sequence of Expressions (bounded). */ def create(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + val unsafeExprs = exprs.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -140,6 +134,22 @@ object UnsafeProjection { def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + + /** + * Same as other create()'s but allowing enabling/disabling subexpression elimination. + * TODO: refactor the plumbing and clean this up. + */ + def create( + exprs: Seq[Expression], + inputSchema: Seq[Attribute], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val e = exprs.map(BindReferences.bindReference(_, inputSchema)) + .map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala new file mode 100644 index 0000000000000..22645c952e722 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.DataType + +/** + * A special expression that evaluates [[BoundReference]]s by given expressions instead of the + * input row. + * + * @param result The expression that contains [[BoundReference]] and produces the final output. + * @param children The expressions that used as input values for [[BoundReference]]. + */ +case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) + extends Expression { + + override def nullable: Boolean = result.nullable + override def dataType: DataType = result.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (result.references.nonEmpty) { + return TypeCheckFailure("The result expression cannot reference to any attributes.") + } + + var maxOrdinal = -1 + result foreach { + case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal + } + if (maxOrdinal > children.length) { + return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " + + s"there are only ${children.length} inputs.") + } + + TypeCheckSuccess + } + + private lazy val projection = UnsafeProjection.create(children) + + override def eval(input: InternalRow): Any = { + result.eval(projection(input)) + } + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val childrenGen = children.map(_.gen(ctx)) + val childrenVars = childrenGen.zip(children).map { + case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType) + } + + val resultGen = result.transform { + case b: BoundReference => childrenVars(b.ordinal) + }.gen(ctx) + + ev.value = resultGen.value + ev.isNull = resultGen.isNull + + childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index a04af7f1dd877..500ff447a9754 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -17,32 +17,35 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType /** * User-defined function. + * @param function The user defined scala function to run. + * Note that if you use primitive parameters, you are not able to check if it is + * null or not, and the UDF will return null for you if the primitive input is + * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. + * @param children The input expressions of this UDF. + * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do + * not want to perform coercion, simply use "Nil". Note that it would've been + * better to use Option of Seq[DataType] so we can use "None" as the case for no + * type coercion. However, that would require more refactoring of the codebase. */ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil, - isDeterministic: Boolean = true) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + inputTypes: Seq[DataType] = Nil) + extends Expression with ImplicitCastInputTypes with NonSQLExpression { override def nullable: Boolean = true - override def toString: String = s"UDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(", ")})" - override def foldable: Boolean = deterministic && children.forall(_.foldable) - - override def deterministic: Boolean = isDeterministic && children.forall(_.deterministic) - - // scalastyle:off + // scalastyle:off line.size.limit /** This method has been generated by this script @@ -65,6 +68,10 @@ case class ScalaUDF( */ + // Accessors used in genCode + def userDefinedFunc(): AnyRef = function + def getChildren(): Seq[Expression] = children + private[this] val f = children.size match { case 0 => val func = function.asInstanceOf[() => Any] @@ -964,7 +971,91 @@ case class ScalaUDF( } } - // scalastyle:on + // scalastyle:on line.size.limit + + // Generate codes used to convert the arguments to Scala type for user-defined functions + private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = { + val converterClassName = classOf[Any => Any].getName + val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + val expressionClassName = classOf[Expression].getName + val scalaUDFClassName = classOf[ScalaUDF].getName + + val converterTerm = ctx.freshName("converter") + val expressionIdx = ctx.references.size - 1 + ctx.addMutableState(converterClassName, converterTerm, + s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") + converterTerm + } + + override def genCode( + ctx: CodegenContext, + ev: ExprCode): String = { + + ctx.references += this + + val scalaUDFClassName = classOf[ScalaUDF].getName + val converterClassName = classOf[Any => Any].getName + val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + val expressionClassName = classOf[Expression].getName + + // Generate codes used to convert the returned value of user-defined functions to Catalyst type + val catalystConverterTerm = ctx.freshName("catalystConverter") + val catalystConverterTermIdx = ctx.references.size - 1 + ctx.addMutableState(converterClassName, catalystConverterTerm, + s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToCatalystConverter((($scalaUDFClassName)references" + + s"[$catalystConverterTermIdx]).dataType());") + + val resultTerm = ctx.freshName("result") + + // This must be called before children expressions' codegen + // because ctx.references is used in genCodeForConverter + val converterTerms = children.indices.map(genCodeForConverter(ctx, _)) + + // Initialize user-defined function + val funcClassName = s"scala.Function${children.size}" + + val funcTerm = ctx.freshName("udf") + val funcExpressionIdx = ctx.references.size - 1 + ctx.addMutableState(funcClassName, funcTerm, + s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" + + s"[$funcExpressionIdx]).userDefinedFunc());") + + // codegen for children expressions + val evals = children.map(_.gen(ctx)) + + // Generate the codes for expressions and calling user-defined function + // We need to get the boxedType of dataType's javaType here. Because for the dataType + // such as IntegerType, its javaType is `int` and the returned type of user-defined + // function is Object. Trying to convert an Object to `int` will cause casting exception. + val evalCode = evals.map(_.code).mkString + val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" + (convert, argTerm) + }.unzip + + val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + + s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + + s""" + $evalCode + ${converters.mkString("\n")} + $callFunc + + boolean ${ev.isNull} = $resultTerm == null; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $resultTerm; + } + """ + } + private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + override def eval(input: InternalRow): Any = converter(f(input)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 290c128d65b30..b739361937b6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,14 +19,22 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator -abstract sealed class SortDirection -case object Ascending extends SortDirection -case object Descending extends SortDirection +abstract sealed class SortDirection { + def sql: String +} + +case object Ascending extends SortDirection { + override def sql: String = "ASC" +} + +case object Descending extends SortDirection { + override def sql: String = "DESC" +} /** * An expression that can be used to sort a tuple. This class extends expression primarily so that @@ -49,7 +57,8 @@ case class SortOrder(child: Expression, direction: SortDirection) override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable - override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + override def toString: String = s"$child ${direction.sql}" + override def sql: String = child.sql + " " + direction.sql def isAscending: Boolean = direction == Ascending } @@ -61,7 +70,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val childCode = child.child.gen(ctx) val input = childCode.value val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 8bff173d64eb9..377f08eb105fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.types.{IntegerType, DataType} - +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Expression that returns the current partition id of the Spark task. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current partition id of the Spark task", + extended = "> SELECT _FUNC_();\n 0") private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false @@ -42,7 +44,7 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm override protected def evalInternal(input: InternalRow): Int = partitionId - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val idTerm = ctx.freshName("partitionId") ctx.addMutableState(ctx.JAVA_INT, idTerm, s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 475cbe005a6ee..61ca7272dfa61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * A parent class for mutable container objects that are reused when the values are changed, @@ -63,7 +62,7 @@ import org.apache.spark.unsafe.types.UTF8String abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any - def update(v: Any) + def update(v: Any): Unit def copy(): MutableValue } @@ -212,6 +211,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) def this() = this(Seq.empty) + def this(schema: StructType) = this(schema.fields.map(_.dataType)) + override def numFields: Int = values.length override def setNullAt(i: Int): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala new file mode 100644 index 0000000000000..daf3de95dd9ea --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.lang.StringUtils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +case class TimeWindow( + timeColumn: Expression, + windowDuration: Long, + slideDuration: Long, + startTime: Long) extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + ////////////////////////// + // SQL Constructors + ////////////////////////// + + def this( + timeColumn: Expression, + windowDuration: Expression, + slideDuration: Expression, + startTime: Expression) = { + this(timeColumn, TimeWindow.parseExpression(windowDuration), + TimeWindow.parseExpression(windowDuration), TimeWindow.parseExpression(startTime)) + } + + def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = { + this(timeColumn, TimeWindow.parseExpression(windowDuration), + TimeWindow.parseExpression(windowDuration), 0) + } + + def this(timeColumn: Expression, windowDuration: Expression) = { + this(timeColumn, windowDuration, windowDuration) + } + + override def child: Expression = timeColumn + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)) + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + /** + * Validate the inputs for the window duration, slide duration, and start time in addition to + * the input data type. + */ + override def checkInputDataTypes(): TypeCheckResult = { + val dataTypeCheck = super.checkInputDataTypes() + if (dataTypeCheck.isSuccess) { + if (windowDuration <= 0) { + return TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.") + } + if (slideDuration <= 0) { + return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.") + } + if (startTime < 0) { + return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.") + } + if (slideDuration > windowDuration) { + return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" + + s" to the windowDuration ($windowDuration).") + } + if (startTime >= slideDuration) { + return TypeCheckFailure(s"The start time ($startTime) must be less than the " + + s"slideDuration ($slideDuration).") + } + } + dataTypeCheck + } +} + +object TimeWindow { + /** + * Parses the interval string for a valid time duration. CalendarInterval expects interval + * strings to start with the string `interval`. For usability, we prepend `interval` to the string + * if the user omitted it. + * + * @param interval The interval string + * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond + * precision. + */ + private def getIntervalInMicroSeconds(interval: String): Long = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "The window duration, slide duration and start time cannot be null or blank.") + } + val intervalString = if (interval.startsWith("interval")) { + interval + } else { + "interval " + interval + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"The provided interval ($interval) did not correspond to a valid interval string.") + } + if (cal.months > 0) { + throw new IllegalArgumentException( + s"Intervals greater than a month is not supported ($interval).") + } + cal.microseconds + } + + /** + * Parses the duration expression to generate the long value for the original constructor so + * that we can use `window` in SQL. + */ + private def parseExpression(expr: Expression): Long = expr match { + case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) + case IntegerLiteral(i) => i.toLong + case NonNullLiteral(l, LongType) => l.toString.toLong + case _ => throw new AnalysisException("The duration and time inputs to window must be " + + "an integer, long or string literal.") + } + + def apply( + timeColumn: Expression, + windowDuration: String, + slideDuration: String, + startTime: String): TimeWindow = { + TimeWindow(timeColumn, + getIntervalInMicroSeconds(windowDuration), + getIntervalInMicroSeconds(slideDuration), + getIntervalInMicroSeconds(startTime)) + } +} + +/** + * Expression used internally to convert the TimestampType to Long without losing + * precision, i.e. in microseconds. Used in time windowing. + */ +case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = LongType + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val eval = child.gen(ctx) + eval.code + + s"""boolean ${ev.isNull} = ${eval.isNull}; + |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + """.stripMargin + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c8c20ada5fbc7..ff70774847830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the mean calculated from values of a group.") case class Average(child: Expression) extends DeclarativeAggregate { override def prettyName: String = "avg" @@ -32,36 +36,33 @@ case class Average(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) case _ => DoubleType } - private val sumDataType = child.dataType match { + private lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _ => DoubleType } - private val sum = AttributeReference("sum", sumDataType)() - private val count = AttributeReference("count", LongType)() + private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = sum :: count :: Nil + override lazy val aggBufferAttributes = sum :: count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* sum = */ Cast(Literal(0), sumDataType), /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* sum = */ Add( sum, @@ -69,13 +70,13 @@ case class Average(child: Expression) extends DeclarativeAggregate { /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right ) // If all input are nulls, count will be 0 and we will get null after the division. - override val evaluateExpression = child.dataType match { + override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index ef08b025ff556..17a7c6dce89ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.types._ * * @param child to compute central moments of. */ -abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { +abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate { /** * The central moment order to be computed. @@ -50,181 +50,175 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w protected def momentOrder: Int override def children: Seq[Expression] = Seq(child) + override def nullable: Boolean = true + override def dataType: DataType = DoubleType + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) - override def nullable: Boolean = false + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val avg = AttributeReference("avg", DoubleType, nullable = false)() + protected val m2 = AttributeReference("m2", DoubleType, nullable = false)() + protected val m3 = AttributeReference("m3", DoubleType, nullable = false)() + protected val m4 = AttributeReference("m4", DoubleType, nullable = false)() - override def dataType: DataType = DoubleType + private def trimHigherOrder[T](expressions: Seq[T]) = expressions.take(momentOrder + 1) - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4)) - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - /** - * Size of aggregation buffer. - */ - private[this] val bufferSize = 5 + override val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } - override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => - AttributeReference(s"M$i", DoubleType)() + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) } - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - // buffer offsets - private[this] val nOffset = mutableAggBufferOffset - private[this] val meanOffset = mutableAggBufferOffset + 1 - private[this] val secondMomentOffset = mutableAggBufferOffset + 2 - private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 - private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 - - // frequently used values for online updates - private[this] var delta = 0.0 - private[this] var deltaN = 0.0 - private[this] var delta2 = 0.0 - private[this] var deltaN2 = 0.0 - private[this] var n = 0.0 - private[this] var mean = 0.0 - private[this] var m2 = 0.0 - private[this] var m3 = 0.0 - private[this] var m4 = 0.0 + override val mergeExpressions: Seq[Expression] = { - /** - * Initialize all moments to zero. - */ - override def initialize(buffer: MutableRow): Unit = { - for (aggIndex <- 0 until bufferSize) { - buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val delta = avg.right - avg.left + val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN) + val newAvg = avg.left + deltaN * n2 + + // higher order moments computed according to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + val newM2 = m2.left + m2.right + delta * deltaN * n1 * n2 + // `m3.right` is not available if momentOrder < 3 + val newM3 = if (momentOrder >= 3) { + m3.left + m3.right + deltaN * deltaN * delta * n1 * n2 * (n1 - n2) + + Literal(3.0) * deltaN * (n1 * m2.right - n2 * m2.left) + } else { + Literal(0.0) + } + // `m4.right` is not available if momentOrder < 4 + val newM4 = if (momentOrder >= 4) { + m4.left + m4.right + + deltaN * deltaN * deltaN * delta * n1 * n2 * (n1 * n1 - n1 * n2 + n2 * n2) + + Literal(6.0) * deltaN * deltaN * (n1 * n1 * m2.right + n2 * n2 * m2.left) + + Literal(4.0) * deltaN * (n1 * m3.right - n2 * m3.left) + } else { + Literal(0.0) } + + trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } +} - /** - * Update the central moments buffer. - */ - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val v = Cast(child, DoubleType).eval(input) - if (v != null) { - val updateValue = v match { - case d: Double => d - } - - n = buffer.getDouble(nOffset) - mean = buffer.getDouble(meanOffset) - - n += 1.0 - buffer.setDouble(nOffset, n) - delta = updateValue - mean - deltaN = delta / n - mean += deltaN - buffer.setDouble(meanOffset, mean) - - if (momentOrder >= 2) { - m2 = buffer.getDouble(secondMomentOffset) - m2 += delta * (delta - deltaN) - buffer.setDouble(secondMomentOffset, m2) - } - - if (momentOrder >= 3) { - delta2 = delta * delta - deltaN2 = deltaN * deltaN - m3 = buffer.getDouble(thirdMomentOffset) - m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) - buffer.setDouble(thirdMomentOffset, m3) - } - - if (momentOrder >= 4) { - m4 = buffer.getDouble(fourthMomentOffset) - m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + - delta * (delta * delta2 - deltaN * deltaN2) - buffer.setDouble(fourthMomentOffset, m4) - } - } +// Compute the population standard deviation of a column +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the population standard deviation calculated from values of a group.") +// scalastyle:on line.size.limit +case class StddevPop(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + Sqrt(m2 / n)) } - /** - * Merge two central moment buffers. - */ - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val n1 = buffer1.getDouble(nOffset) - val n2 = buffer2.getDouble(inputAggBufferOffset) - val mean1 = buffer1.getDouble(meanOffset) - val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) + override def prettyName: String = "stddev_pop" +} - var secondMoment1 = 0.0 - var secondMoment2 = 0.0 +// Compute the sample standard deviation of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sample standard deviation calculated from values of a group.") +case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { - var thirdMoment1 = 0.0 - var thirdMoment2 = 0.0 + override protected def momentOrder = 2 - var fourthMoment1 = 0.0 - var fourthMoment2 = 0.0 + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + Sqrt(m2 / (n - Literal(1.0))))) + } - n = n1 + n2 - buffer1.setDouble(nOffset, n) - delta = mean2 - mean1 - deltaN = if (n == 0.0) 0.0 else delta / n - mean = mean1 + deltaN * n2 - buffer1.setDouble(mutableAggBufferOffset + 1, mean) + override def prettyName: String = "stddev_samp" +} - // higher order moments computed according to: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics - if (momentOrder >= 2) { - secondMoment1 = buffer1.getDouble(secondMomentOffset) - secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) - m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 - buffer1.setDouble(secondMomentOffset, m2) - } +// Compute the population variance of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the population variance calculated from values of a group.") +case class VariancePop(child: Expression) extends CentralMomentAgg(child) { - if (momentOrder >= 3) { - thirdMoment1 = buffer1.getDouble(thirdMomentOffset) - thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) - m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * - (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) - buffer1.setDouble(thirdMomentOffset, m3) - } + override protected def momentOrder = 2 - if (momentOrder >= 4) { - fourthMoment1 = buffer1.getDouble(fourthMomentOffset) - fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) - m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * - n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * - (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + - 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) - buffer1.setDouble(fourthMomentOffset, m4) - } + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + m2 / n) } - /** - * Compute aggregate statistic from sufficient moments. - * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) - * needed to compute the aggregate stat. - */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double - - override final def eval(buffer: InternalRow): Any = { - val n = buffer.getDouble(nOffset) - val mean = buffer.getDouble(meanOffset) - val moments = Array.ofDim[Double](momentOrder + 1) - moments(0) = 1.0 - moments(1) = 0.0 - if (momentOrder >= 2) { - moments(2) = buffer.getDouble(secondMomentOffset) - } - if (momentOrder >= 3) { - moments(3) = buffer.getDouble(thirdMomentOffset) - } - if (momentOrder >= 4) { - moments(4) = buffer.getDouble(fourthMomentOffset) - } + override def prettyName: String = "var_pop" +} - getStatistic(n, mean, moments) +// Compute the sample variance of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sample variance calculated from values of a group.") +case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 2 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + m2 / (n - Literal(1.0)))) } + + override def prettyName: String = "var_samp" +} + +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the Skewness value calculated from values of a group.") +case class Skewness(child: Expression) extends CentralMomentAgg(child) { + + override def prettyName: String = "skewness" + + override protected def momentOrder = 3 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(m2 === Literal(0.0), Literal(Double.NaN), + Sqrt(n) * m3 / Sqrt(m2 * m2 * m2))) + } +} + +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the Kurtosis value calculated from values of a group.") +case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { + + override protected def momentOrder = 4 + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(m2 === Literal(0.0), Literal(Double.NaN), + n * m4 / (m2 * m2) - Literal(3.0))) + } + + override def prettyName: String = "kurtosis" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 832338378fb38..e29265e2f41e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -28,152 +28,72 @@ import org.apache.spark.sql.types._ * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -case class Corr( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends ImperativeAggregate { - - override def children: Seq[Expression] = Seq(left, right) - - override def nullable: Boolean = false +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns Pearson coefficient of correlation between a set of number pairs.") +case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate { + override def children: Seq[Expression] = Seq(x, y) + override def nullable: Boolean = true override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - override def inputAggBufferAttributes: Seq[AttributeReference] = { - aggBufferAttributes.map(_.newInstance()) - } - - override val aggBufferAttributes: Seq[AttributeReference] = Seq( - AttributeReference("xAvg", DoubleType)(), - AttributeReference("yAvg", DoubleType)(), - AttributeReference("Ck", DoubleType)(), - AttributeReference("MkX", DoubleType)(), - AttributeReference("MkY", DoubleType)(), - AttributeReference("count", LongType)()) - - // Local cache of mutableAggBufferOffset(s) that will be used in update and merge - private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 - private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 - private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 - private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 - private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5 - - // Local cache of inputAggBufferOffset(s) that will be used in update and merge - private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 - private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 - private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 - private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 - private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def initialize(buffer: MutableRow): Unit = { - buffer.setDouble(mutableAggBufferOffset, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) - buffer.setLong(mutableAggBufferOffsetPlus5, 0L) + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + protected val xMk = AttributeReference("xMk", DoubleType, nullable = false)() + protected val yMk = AttributeReference("yMk", DoubleType, nullable = false)() + + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck, xMk, yMk) + + override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) + + override val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dxN = dx / newN + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dxN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + val newXMk = xMk + dx * (x - newXAvg) + val newYMk = yMk + dy * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk), + If(isNull, xMk, newXMk), + If(isNull, yMk, newYMk) + ) } - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - - if (leftEval != null && rightEval != null) { - val x = leftEval.asInstanceOf[Double] - val y = rightEval.asInstanceOf[Double] - - var xAvg = buffer.getDouble(mutableAggBufferOffset) - var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer.getLong(mutableAggBufferOffsetPlus5) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - MkX += deltaX * (x - xAvg) - MkY += deltaY * (y - yAvg) - - buffer.setDouble(mutableAggBufferOffset, xAvg) - buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer.setLong(mutableAggBufferOffsetPlus5, count) - } + override val mergeExpressions: Seq[Expression] = { + + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) } - // Merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputAggBufferOffsetPlus5) - - // We only go to merge two buffers if there is at least one record aggregated in buffer2. - // We don't need to check count in buffer1 because if count2 is more than zero, totalCount - // is more than zero too, then we won't get a divide by zero exception. - if (count2 > 0) { - var xAvg = buffer1.getDouble(mutableAggBufferOffset) - var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer1.getLong(mutableAggBufferOffsetPlus5) - - val xAvg2 = buffer2.getDouble(inputAggBufferOffset) - val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1) - val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2) - val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3) - val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4) - - val totalCount = count + count2 - val deltaX = xAvg - xAvg2 - val deltaY = yAvg - yAvg2 - Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 - xAvg = (xAvg * count + xAvg2 * count2) / totalCount - yAvg = (yAvg * count + yAvg2 * count2) / totalCount - MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 - MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 - count = totalCount - - buffer1.setDouble(mutableAggBufferOffset, xAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer1.setLong(mutableAggBufferOffsetPlus5, count) - } + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + ck / Sqrt(xMk * yMk))) } - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(mutableAggBufferOffsetPlus5) - if (count > 0) { - val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - val corr = Ck / math.sqrt(MkX * MkY) - if (corr.isNaN) { - null - } else { - corr - } - } else { - null - } - } + override def prettyName: String = "corr" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 54df96cd2446a..17ae012af79be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,8 +21,13 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -case class Count(child: Expression) extends DeclarativeAggregate { - override def children: Seq[Expression] = child :: Nil +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(*) - Returns the total number of retrieved rows, including rows containing NULL values. + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-NULL. + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-NULL.""") +// scalastyle:on line.size.limit +case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false @@ -30,23 +35,38 @@ case class Count(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = LongType // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) - private val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType, nullable = false)() - override val aggBufferAttributes = count :: Nil + override lazy val aggBufferAttributes = count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* count = */ Literal(0L) ) - override val updateExpressions = Seq( - /* count = */ If(IsNull(child), count, count + 1L) - ) + override lazy val updateExpressions = { + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + 1L + ) + } else { + Seq( + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + ) + } + } - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = count + + override def defaultResult: Option[Literal] = Option(Literal(0L)) +} + +object Count { + def apply(child: Expression): Count = Count(child :: Nil) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala new file mode 100644 index 0000000000000..d80afbebf7404 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Compute the covariance between two expressions. + * When applied on empty data (i.e., count is zero), it returns NULL. + */ +abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = Seq(x, y) + override def nullable: Boolean = true + override def dataType: DataType = DoubleType + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + protected val n = AttributeReference("n", DoubleType, nullable = false)() + protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck) + + override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) + + override lazy val updateExpressions: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) + } + + override val mergeExpressions: Seq[Expression] = { + + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk) + } +} + +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the population covariance of a set of number pairs.") +case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + ck / n) + } + override def prettyName: String = "covar_pop" +} + + +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the sample covariance of a set of number pairs.") +case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), + If(n === Literal(1.0), Literal(Double.NaN), + ck / (n - Literal(1.0)))) + } + override def prettyName: String = "covar_samp" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 9028143015853..b8ab0364dd8f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -28,6 +28,11 @@ import org.apache.spark.sql.types._ * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). */ +@ExpressionDescription( + usage = """_FUNC_(expr) - Returns the first value of `child` for a group of rows. + _FUNC_(expr,isIgnoreNull=false) - Returns the first value of `child` for a group of rows. + If isIgnoreNull is true, returns only non-null values. + """) case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) @@ -51,18 +56,18 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val first = AttributeReference("first", child.dataType)() + private lazy val first = AttributeReference("first", child.dataType)() - private val valueSet = AttributeReference("valueSet", BooleanType)() + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() - override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* first = */ Literal.create(null, child.dataType), /* valueSet = */ Literal.create(false, BooleanType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* first = */ If(Or(valueSet, IsNull(child)), first, child), @@ -76,7 +81,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { // For first, we can just check if valueSet.left is set to true. If it is set // to true, we use first.right. If not, we use first.right (even if valueSet.right is // false, we are safe to do so because first.right will be null in this case). @@ -86,7 +91,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara ) } - override val evaluateExpression: AttributeReference = first + override lazy val evaluateExpression: AttributeReference = first override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 8d341ee630bdb..1d218da6db806 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.lang.{Long => JLong} import java.util -import com.clearspring.analytics.hash.MurmurHash - +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -47,6 +46,11 @@ import org.apache.spark.sql.types._ * @param relativeSD the maximum estimation error allowed. */ // scalastyle:on +@ExpressionDescription( + usage = """_FUNC_(expr) - Returns the estimated cardinality by HyperLogLog++. + _FUNC_(expr, relativeSD=0.05) - Returns the estimated cardinality by HyperLogLog++ + with relativeSD, the maximum estimation error allowed. + """) case class HyperLogLogPlusPlus( child: Expression, relativeSD: Double = 0.05, @@ -55,6 +59,20 @@ case class HyperLogLogPlusPlus( extends ImperativeAggregate { import HyperLogLogPlusPlus._ + def this(child: Expression) = { + this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + } + + def this(child: Expression, relativeSD: Expression) = { + this( + child = child, + relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD), + mutableAggBufferOffset = 0, + inputAggBufferOffset = 0) + } + + override def prettyName: String = "approx_count_distinct" + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -154,7 +172,7 @@ case class HyperLogLogPlusPlus( val v = child.eval(input) if (v != null) { // Create the hashed value 'x'. - val x = MurmurHash.hash64(v) + val x = XxHash64Function.hash(v, child.dataType, 42L) // Determine the index of the register we are going to use. val idx = (x >>> idxShift).toInt @@ -225,7 +243,7 @@ case class HyperLogLogPlusPlus( diff * diff } - // Keep moving bounds as long as the the (exclusive) high bound is closer to the estimate than + // Keep moving bounds as long as the (exclusive) high bound is closer to the estimate than // the lower (inclusive) bound. var low = math.max(nearestEstimateIndex - K + 1, 0) var high = math.min(low + K, numEstimates) @@ -431,4 +449,12 @@ object HyperLogLogPlusPlus { Array(189083, 185696.913, 182348.774, 179035.946, 175762.762, 172526.444, 169329.754, 166166.099, 163043.269, 159958.91, 156907.912, 153906.845, 150924.199, 147996.568, 145093.457, 142239.233, 139421.475, 136632.27, 133889.588, 131174.2, 128511.619, 125868.621, 123265.385, 120721.061, 118181.769, 115709.456, 113252.446, 110840.198, 108465.099, 106126.164, 103823.469, 101556.618, 99308.004, 97124.508, 94937.803, 92833.731, 90745.061, 88677.627, 86617.47, 84650.442, 82697.833, 80769.132, 78879.629, 77014.432, 75215.626, 73384.587, 71652.482, 69895.93, 68209.301, 66553.669, 64921.981, 63310.323, 61742.115, 60205.018, 58698.658, 57190.657, 55760.865, 54331.169, 52908.167, 51550.273, 50225.254, 48922.421, 47614.533, 46362.049, 45098.569, 43926.083, 42736.03, 41593.473, 40425.26, 39316.237, 38243.651, 37170.617, 36114.609, 35084.19, 34117.233, 33206.509, 32231.505, 31318.728, 30403.404, 29540.0550000001, 28679.236, 27825.862, 26965.216, 26179.148, 25462.08, 24645.952, 23922.523, 23198.144, 22529.128, 21762.4179999999, 21134.779, 20459.117, 19840.818, 19187.04, 18636.3689999999, 17982.831, 17439.7389999999, 16874.547, 16358.2169999999, 15835.684, 15352.914, 14823.681, 14329.313, 13816.897, 13342.874, 12880.882, 12491.648, 12021.254, 11625.392, 11293.7610000001, 10813.697, 10456.209, 10099.074, 9755.39000000001, 9393.18500000006, 9047.57900000003, 8657.98499999999, 8395.85900000005, 8033, 7736.95900000003, 7430.59699999995, 7258.47699999996, 6924.58200000005, 6691.29399999999, 6357.92500000005, 6202.05700000003, 5921.19700000004, 5628.28399999999, 5404.96799999999, 5226.71100000001, 4990.75600000005, 4799.77399999998, 4622.93099999998, 4472.478, 4171.78700000001, 3957.46299999999, 3868.95200000005, 3691.14300000004, 3474.63100000005, 3341.67200000002, 3109.14000000001, 3071.97400000005, 2796.40399999998, 2756.17799999996, 2611.46999999997, 2471.93000000005, 2382.26399999997, 2209.22400000005, 2142.28399999999, 2013.96100000001, 1911.18999999994, 1818.27099999995, 1668.47900000005, 1519.65800000005, 1469.67599999998, 1367.13800000004, 1248.52899999998, 1181.23600000003, 1022.71900000004, 1088.20700000005, 959.03600000008, 876.095999999903, 791.183999999892, 703.337000000058, 731.949999999953, 586.86400000006, 526.024999999907, 323.004999999888, 320.448000000091, 340.672999999952, 309.638999999966, 216.601999999955, 102.922999999952, 19.2399999999907, -0.114000000059605, -32.6240000000689, -89.3179999999702, -153.497999999905, -64.2970000000205, -143.695999999996, -259.497999999905, -253.017999999924, -213.948000000091, -397.590000000084, -434.006000000052, -403.475000000093, -297.958000000101, -404.317000000039, -528.898999999976, -506.621000000043, -513.205000000075, -479.351000000024, -596.139999999898, -527.016999999993, -664.681000000099, -680.306000000099, -704.050000000047, -850.486000000034, -757.43200000003, -713.308999999892) ) // scalastyle:on + + private def validateDoubleLiteral(exp: Expression): Double = exp match { + case Literal(d: Double, DoubleType) => d + case Literal(dec: Decimal, _) => dec.toDouble + case _ => + throw new AnalysisException("The second argument should be a double literal.") + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala deleted file mode 100644 index 6da39e7143447..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class Kurtosis(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "kurtosis" - - override protected val momentOrder = 4 - - // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { - Double.NaN - } else { - n * m4 / (m2 * m2) - 3.0 - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 8636bfe8d07aa..b05d74b49b591 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). */ +@ExpressionDescription( + usage = "_FUNC_(expr,isIgnoreNull) - Returns the last value of `child` for a group of rows.") case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) @@ -51,15 +53,15 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val last = AttributeReference("last", child.dataType)() + private lazy val last = AttributeReference("last", child.dataType)() - override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* last = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(child), last, child) @@ -71,7 +73,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(last.right), last.left, last.right) @@ -83,7 +85,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val evaluateExpression: AttributeReference = last + override lazy val evaluateExpression: AttributeReference = last override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index b9d75ad452838..c534fe495fc13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the maximum value of expr.") case class Max(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil @@ -32,24 +36,26 @@ case class Max(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val max = AttributeReference("max", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") - override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + private lazy val max = AttributeReference("max", child.dataType)() - override val initialValues: Seq[Literal] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + + override lazy val initialValues: Seq[Literal] = Seq( /* max = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( - /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + override lazy val updateExpressions: Seq[Expression] = Seq( + /* max = */ Greatest(Seq(max, child)) ) - override val mergeExpressions: Seq[Expression] = { - val greatest = Greatest(Seq(max.left, max.right)) + override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + /* max = */ Greatest(Seq(max.left, max.right)) ) } - override val evaluateExpression: AttributeReference = max + override lazy val evaluateExpression: AttributeReference = max } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 5ed9cd348daba..35289b468183c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the minimum value of expr.") case class Min(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil @@ -33,24 +36,26 @@ case class Min(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val min = AttributeReference("min", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") + + private lazy val min = AttributeReference("min", child.dataType)() - override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil - override val initialValues: Seq[Expression] = Seq( + override lazy val initialValues: Seq[Expression] = Seq( /* min = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( - /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + override lazy val updateExpressions: Seq[Expression] = Seq( + /* min = */ Least(Seq(min, child)) ) - override val mergeExpressions: Seq[Expression] = { - val least = Least(Seq(min.left, min.right)) + override lazy val mergeExpressions: Seq[Expression] = { Seq( - /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + /* min = */ Least(Seq(min.left, min.right)) ) } - override val evaluateExpression: AttributeReference = min + override lazy val evaluateExpression: AttributeReference = min } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala deleted file mode 100644 index 0def7ddfd9d3d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class Skewness(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "skewness" - - override protected val momentOrder = 3 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { - Double.NaN - } else { - math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala deleted file mode 100644 index 3f47ffe13cbc8..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - - -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg(child) { - override def isSample: Boolean = false - override def prettyName: String = "stddev_pop" -} - - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg(child) { - override def isSample: Boolean = true - override def prettyName: String = "stddev_samp" -} - - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { - - def isSample: Boolean - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select stddev(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - private val resultType = DoubleType - - private val count = AttributeReference("count", resultType)() - private val avg = AttributeReference("avg", resultType)() - private val mk = AttributeReference("mk", resultType)() - - override val aggBufferAttributes = count :: avg :: mk :: Nil - - override val initialValues: Seq[Expression] = Seq( - /* count = */ Cast(Literal(0), resultType), - /* avg = */ Cast(Literal(0), resultType), - /* mk = */ Cast(Literal(0), resultType) - ) - - override val updateExpressions: Seq[Expression] = { - val value = Cast(child, resultType) - val newCount = count + Cast(Literal(1), resultType) - - // update average - // avg = avg + (value - avg)/count - val newAvg = avg + (value - avg) / newCount - - // update sum ofference from mean - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - val newMk = mk + (value - avg) * (value - newAvg) - - Seq( - /* count = */ If(IsNull(child), count, newCount), - /* avg = */ If(IsNull(child), avg, newAvg), - /* mk = */ If(IsNull(child), mk, newMk) - ) - } - - override val mergeExpressions: Seq[Expression] = { - - // count merge - val newCount = count.left + count.right - - // average merge - val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount - - // update sum of square differences - val newMk = { - val avgDelta = avg.right - avg.left - val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount - mk.left + mk.right + mkDelta - } - - Seq( - /* count = */ If(IsNull(count.left), count.right, - If(IsNull(count.right), count.left, newCount)), - /* avg = */ If(IsNull(avg.left), avg.right, - If(IsNull(avg.right), avg.left, newAvg)), - /* mk = */ If(IsNull(mk.left), mk.right, - If(IsNull(mk.right), mk.left, newMk)) - ) - } - - override val evaluateExpression: Expression = { - // when count == 0, return null - // when count == 1, return 0 - // when count >1 - // stddev_samp = sqrt (mk/(count -1)) - // stddev_pop = sqrt (mk/count) - val varCol = - if (isSample) { - mk / Cast(count - Cast(Literal(1), resultType), resultType) - } else { - mk / count - } - - If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(varCol), resultType))) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 7f8adbc56ad1d..ad217f25b5a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sum calculated from values of a group.") case class Sum(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil @@ -29,47 +33,50 @@ case class Sum(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select sum(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + Seq(TypeCollection(LongType, DoubleType, DecimalType)) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) - // TODO: Remove this line once we remove the NullType from inputTypes. - case NullType => IntegerType case _ => child.dataType } - private val sumDataType = resultType + private lazy val sumDataType = resultType - private val sum = AttributeReference("sum", sumDataType)() + private lazy val sum = AttributeReference("sum", sumDataType)() - private val zero = Cast(Literal(0), sumDataType) + private lazy val zero = Cast(Literal(0), sumDataType) - override val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = sum :: Nil - override val initialValues: Seq[Expression] = Seq( + override lazy val initialValues: Seq[Expression] = Seq( /* sum = */ Literal.create(null, sumDataType) ) - override val updateExpressions: Seq[Expression] = Seq( - /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) - ) + override lazy val updateExpressions: Seq[Expression] = { + if (child.nullable) { + Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + ) + } else { + Seq( + /* sum = */ + Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + ) + } + } - override val mergeExpressions: Seq[Expression] = { - val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) + override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - Coalesce(Seq(add, sum.left)) + Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) ) } - override val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala deleted file mode 100644 index 644c6211d5f31..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.types.{StructType, MapType, ArrayType} - -/** - * Utility functions used by the query planner to convert our plan to new aggregation code path. - */ -object Utils { - // Right now, we do not support complex types in the grouping key schema. - private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - - !hasComplexTypes - } - - private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - // We do not support multiple COUNT DISTINCT columns for now. - case expressions.CountDistinct(children) if children.length == 1 => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child, ignoreNulls) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child, ignoreNulls), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Kurtosis(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Kurtosis(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child, ignoreNulls) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child, ignoreNulls), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Skewness(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Skewness(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.StddevPop(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.StddevPop(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.StddevSamp(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.StddevSamp(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.Corr(left, right) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Corr(left, right), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.ApproxCountDistinct(child, rsd) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.VariancePop(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.VariancePop(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.VarianceSamp(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.VarianceSamp(child), - mode = aggregate.Complete, - isDistinct = false) - } - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to see if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. - val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." - throw new AnalysisException(errorMessage) - } - } - - def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate => - val converted = doConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case other => None - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala deleted file mode 100644 index ec63534e5290a..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.catalyst.expressions._ - -case class VarianceSamp(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "var_samp" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) - } -} - -case class VariancePop(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "var_pop" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) Double.NaN else moments(2) / n - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index a2fab258fcac3..d31ccf9985360 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ -/** The mode of an [[AggregateFunction2]]. */ +/** The mode of an [[AggregateFunction]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -41,7 +42,7 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. @@ -49,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode private[sql] case object Final extends AggregateMode /** - * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly + * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly * from original input rows without any partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. @@ -66,14 +67,50 @@ private[sql] case object NoOp extends Expression with Unevaluable { override def children: Seq[Expression] = Nil } +object AggregateExpression { + def apply( + aggregateFunction: AggregateFunction, + mode: AggregateMode, + isDistinct: Boolean): AggregateExpression = { + AggregateExpression( + aggregateFunction, + mode, + isDistinct, + NamedExpression.newExprId) + } +} + /** - * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. */ -private[sql] case class AggregateExpression2( - aggregateFunction: AggregateFunction2, +private[sql] case class AggregateExpression( + aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) extends AggregateExpression { + isDistinct: Boolean, + resultId: ExprId) + extends Expression + with Unevaluable { + + lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) { + AttributeReference( + aggregateFunction.toString, + aggregateFunction.dataType, + aggregateFunction.nullable)(exprId = resultId) + } else { + // This is a bit of a hack. Really we should not be constructing this container and reasoning + // about datatypes / aggregation mode until after we have finished analysis and made it to + // planning. + UnresolvedAttribute(aggregateFunction.toString) + } + + // We compute the same thing regardless of our final result. + override lazy val canonicalized: Expression = + AggregateExpression( + aggregateFunction.canonicalized.asInstanceOf[AggregateFunction], + mode, + isDistinct, + ExprId(0)) override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType @@ -89,11 +126,13 @@ private[sql] case class AggregateExpression2( AttributeSet(childReferences) } - override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" + override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" + + override def sql: String = aggregateFunction.sql(isDistinct) } /** - * AggregateFunction2 is the superclass of two aggregation function interfaces: + * AggregateFunction is the superclass of two aggregation function interfaces: * * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. @@ -106,10 +145,10 @@ private[sql] case class AggregateExpression2( * combined aggregation buffer which concatenates the aggregation buffers of the individual * aggregate functions. * - * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of + * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of * aggregate functions. */ -sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { +sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false @@ -133,8 +172,37 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp */ def supportsPartial: Boolean = true - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + /** + * Result of the aggregate function when the input is empty. This is currently only used for the + * proper rewriting of distinct aggregate functions. + */ + def defaultResult: Option[Literal] = None + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct + * field of the [[AggregateExpression]] to the given value because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { + AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) + } + + def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else "" + s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" + } } /** @@ -155,7 +223,7 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` * and `inputAggBufferAttributes`. */ -abstract class ImperativeAggregate extends AggregateFunction2 { +abstract class ImperativeAggregate extends AggregateFunction with CodegenFallback { /** * The offset of this function's first buffer value in the underlying shared mutable aggregation @@ -164,7 +232,7 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)` * will be 0 and the position of the first buffer value of `avg(y)` will be 2: - * + * {{{ * avg(x) mutableAggBufferOffset = 0 * | * v @@ -174,7 +242,7 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * ^ * | * avg(y) mutableAggBufferOffset = 2 - * + * }}} */ protected val mutableAggBufferOffset: Int @@ -197,7 +265,7 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the first * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` * will be 3 (position 0 is used for the value of `key`): - * + * {{{ * avg(x) inputAggBufferOffset = 1 * | * v @@ -207,7 +275,7 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * ^ * | * avg(y) inputAggBufferOffset = 3 - * + * }}} */ protected val inputAggBufferOffset: Int @@ -252,9 +320,14 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and * `evaluateExpressions`. + * + * Please note that children of an aggregate function can be unresolved (it will happen when + * we create this function in DataFrame API). So, if there is any fields in + * the implemented class that need to access fields of its children, please make + * those fields `lazy val`s. */ abstract class DeclarativeAggregate - extends AggregateFunction2 + extends AggregateFunction with Serializable with Unevaluable { @@ -303,4 +376,3 @@ abstract class DeclarativeAggregate def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a)) } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala deleted file mode 100644 index 3dcf7915d77b3..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import com.clearspring.analytics.stream.cardinality.HyperLogLog - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData, TypeUtils} -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - - -trait AggregateExpression extends Expression with Unevaluable - -trait AggregateExpression1 extends AggregateExpression { - - /** - * Aggregate expressions should not be foldable. - */ - override def foldable: Boolean = false - - /** - * Creates a new instance that can be used to compute this aggregate expression for a group - * of input rows/ - */ - def newInstance(): AggregateFunction1 -} - -/** - * Represents an aggregation that has been rewritten to be performed in two steps. - * - * @param finalEvaluation an aggregate expression that evaluates to same final result as the - * original aggregation. - * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial - * data sets and are required to compute the `finalEvaluation`. - */ -case class SplitEvaluation( - finalEvaluation: Expression, - partialEvaluations: Seq[NamedExpression]) - -/** - * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. - * These partial evaluations can then be combined to compute the actual answer. - */ -trait PartialAggregate1 extends AggregateExpression1 { - - /** - * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. - */ - def asPartial: SplitEvaluation -} - -/** - * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. - */ -abstract class AggregateFunction1 extends LeafExpression with Serializable { - - /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression1 - - override def nullable: Boolean = base.nullable - override def dataType: DataType = base.dataType - - def update(input: InternalRow): Unit - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - throw new UnsupportedOperationException( - "AggregateFunction1 should not be used for generated aggregates") - } -} - -case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMin = Alias(Min(child), "PartialMin")() - SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) - } - - override def newInstance(): MinFunction = new MinFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function min") -} - -case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = GreaterThan(currentMin, expr) - - override def update(input: InternalRow): Unit = { - if (currentMin.value == null) { - currentMin.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMin.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMin.value -} - -case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMax = Alias(Max(child), "PartialMax")() - SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) - } - - override def newInstance(): MaxFunction = new MaxFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function max") -} - -case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = LessThan(currentMax, expr) - - override def update(input: InternalRow): Unit = { - if (currentMax.value == null) { - currentMax.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMax.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMax.value -} - -case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - - override def asPartial: SplitEvaluation = { - val partialCount = Alias(Count(child), "PartialCount")() - SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) - } - - override def newInstance(): CountFunction = new CountFunction(child, this) -} - -case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - var count: Long = _ - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } - } - - override def eval(input: InternalRow): Any = count -} - -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(expressions), "partialSets")() - SplitEvaluation( - CombineSetsAndCount(partialSet.toAttribute), - partialSet :: Nil) - } -} - -case class CountDistinctFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) - override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = - new CollectHashSetFunction(expressions, this) -} - -case class CollectHashSetFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = { - seen - } -} - -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"CombineAndCount($inputSet)" - override def newInstance(): CombineSetsAndCountFunction = { - new CombineSetsAndCountFunction(inputSet, this) - } -} - -case class CombineSetsAndCountFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ -private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { - - override def sqlType: DataType = BinaryType - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def serialize(obj: Any): Array[Byte] = - obj.asInstanceOf[HyperLogLog].getBytes - - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def deserialize(datum: Any): HyperLogLog = - HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) - - override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] -} - -case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: DataType = HyperLogLogUDT - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctPartitionFunction = { - new ApproxCountDistinctPartitionFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = hyperLogLog -} - -case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctMergeFunction = { - new ApproxCountDistinctMergeFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } - - override def eval(input: InternalRow): Any = hyperLogLog.cardinality() -} - -case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - - override def asPartial: SplitEvaluation = { - val partialCount = - Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")() - - SplitEvaluation( - ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD), - partialCount :: Nil) - } - - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) -} - -case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def prettyName: String = "avg" - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 4 digits after decimal point, like Hive - DecimalType.bounded(precision + 4, scale + 4) - case _ => - DoubleType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(precision, scale) => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - // partialSum already increase the precision by 10 - val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) - SplitEvaluation( - Cast(Divide(castedSum, castedCount), dataType), - partialCount :: partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) - } - } - - override def newInstance(): AverageFunction = new AverageFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") -} - -case class AverageFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private var count: Long = _ - private val sum = MutableLiteral(zero.eval(null), calcType) - - private def addFunction(value: Any) = Add(sum, - Cast(Literal.create(value, expr.dataType), calcType)) - - override def eval(input: InternalRow): Any = { - if (count == 0L) { - null - } else { - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - val dt = DecimalType.bounded(precision + 14, scale + 4) - Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null) - case _ => - Divide( - Cast(sum, dataType), - Cast(Literal(count), dataType)).eval(null) - } - } - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1 - sum.update(addFunction(evaluatedExpr), input) - } - } -} - -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Cast(Sum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Sum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sum") -} - -case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private val sum = MutableLiteral(null, calcType) - - private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) - - override def update(input: InternalRow): Unit = { - sum.update(addFunction, input) - } - - override def eval(input: InternalRow): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } - } -} - -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - override def toString: String = s"sum(distinct $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") -} - -case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val seen = new scala.collection.mutable.HashSet[Any]() - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - seen += evaluatedExpr - } - } - - override def eval(input: InternalRow): Any = { - if (seen.size == 0) { - null - } else { - Cast(Literal( - seen.reduceLeft( - dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - dataType).eval(null) - } - } -} - -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next()) - } - } - - override def eval(input: InternalRow): Any = { - val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.get(0, null)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First( - child: Expression, - ignoreNullsExpr: Expression) - extends UnaryExpression with PartialAggregate1 { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute, ignoreNulls), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this) -} - -object First { - def apply(child: Expression): First = First(child, ignoreNulls = false) - - def apply(child: Expression, ignoreNulls: Boolean): First = - First(child, Literal.create(ignoreNulls, BooleanType)) -} - -case class FirstFunction( - expr: Expression, - ignoreNulls: Boolean, - base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. - - private[this] var result: Any = null - - private[this] var valueSet: Boolean = false - - override def update(input: InternalRow): Unit = { - if (!valueSet) { - val value = expr.eval(input) - // When we have not set the result, we will set the result if we respect nulls - // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null. - if (!ignoreNulls || (ignoreNulls && value != null)) { - result = value - valueSet = true - } - } - } - - override def eval(input: InternalRow): Any = result -} - -case class Last( - child: Expression, - ignoreNullsExpr: Expression) - extends UnaryExpression with PartialAggregate1 { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute, ignoreNulls), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this) -} - -object Last { - def apply(child: Expression): Last = Last(child, ignoreNulls = false) - - def apply(child: Expression, ignoreNulls: Boolean): Last = - Last(child, Literal.create(ignoreNulls, BooleanType)) -} - -case class LastFunction( - expr: Expression, - ignoreNulls: Boolean, - base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. - - var result: Any = null - - override def update(input: InternalRow): Unit = { - val value = expr.eval(input) - if (!ignoreNulls || (ignoreNulls && value != null)) { - result = value - } - } - - override def eval(input: InternalRow): Any = { - result - } -} - -/** - * Calculate Pearson Correlation Coefficient for the given columns. - * Only support AggregateExpression2. - * - */ -case class Corr(left: Expression, right: Expression) - extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes { - override def nullable: Boolean = false - override def dataType: DoubleType.type = DoubleType - override def toString: String = s"corr($left, $right)" - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException( - "Corr only supports the new AggregateExpression2 and can only be used " + - "when spark.sql.useAggregate2 = true") - } -} - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { - override def nullable: Boolean = true - override def dataType: DataType = DoubleType - - def isSample: Boolean - - override def asPartial: SplitEvaluation = { - val partialStd = Alias(ComputePartialStd(child), "PartialStddev")() - SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil) - } - - override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function stddev") - -} - -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"stddev_pop($child)" - override def isSample: Boolean = false -} - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"stddev_samp($child)" - override def isSample: Boolean = true -} - -case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = false - override def dataType: DataType = ArrayType(DoubleType) - override def toString: String = s"computePartialStddev($child)" - override def newInstance(): ComputePartialStdFunction = - new ComputePartialStdFunction(child, this) -} - -case class ComputePartialStdFunction ( - expr: Expression, - base: AggregateExpression1 - ) extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization - - private val computeType = DoubleType - private val zero = Cast(Literal(0), computeType) - private var partialCount: Long = 0L - - // the mean of data processed so far - private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) - - // update average based on this formula: - // avg = avg + (value - avg)/count - private def avgAddFunction (value: Literal): Expression = { - val delta = Subtract(Cast(value, computeType), partialAvg) - Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) - } - - // the sum of squares of difference from mean - private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) - - // update sum of square of difference from mean based on following formula: - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { - val delta1 = Subtract(Cast(value, computeType), prePartialAvg) - val delta2 = Subtract(Cast(value, computeType), partialAvg) - Add(partialMk, Multiply(delta1, delta2)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - val exprValue = Literal.create(evaluatedExpr, expr.dataType) - val prePartialAvg = partialAvg.copy() - partialCount += 1 - partialAvg.update(avgAddFunction(exprValue), input) - partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) - } - } - - override def eval(input: InternalRow): Any = { - new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), - partialAvg.eval(null), - partialMk.eval(null))) - } -} - -case class MergePartialStd( - child: Expression, - isSample: Boolean -) extends UnaryExpression with AggregateExpression1 { - def this() = this(null, false) // required for serialization - - override def children: Seq[Expression] = child:: Nil - override def nullable: Boolean = false - override def dataType: DataType = DoubleType - override def toString: String = s"MergePartialStd($child)" - override def newInstance(): MergePartialStdFunction = { - new MergePartialStdFunction(child, this, isSample) - } -} - -case class MergePartialStdFunction( - expr: Expression, - base: AggregateExpression1, - isSample: Boolean -) extends AggregateFunction1 { - def this() = this (null, null, false) // Required for serialization - - private val computeType = DoubleType - private val zero = Cast(Literal(0), computeType) - private val combineCount = MutableLiteral(zero.eval(null), computeType) - private val combineAvg = MutableLiteral(zero.eval(null), computeType) - private val combineMk = MutableLiteral(zero.eval(null), computeType) - - private def avgUpdateFunction(preCount: Expression, - partialCount: Expression, - partialAvg: Expression): Expression = { - Divide(Add(Multiply(combineAvg, preCount), - Multiply(partialAvg, partialCount)), - Add(preCount, partialCount)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] - - if (evaluatedExpr != null) { - val exprValue = evaluatedExpr.toArray(computeType) - val (partialCount, partialAvg, partialMk) = - (Literal.create(exprValue(0), computeType), - Literal.create(exprValue(1), computeType), - Literal.create(exprValue(2), computeType)) - - if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) { - val preCount = combineCount.copy() - combineCount.update(Add(combineCount, partialCount), input) - - val preAvg = combineAvg.copy() - val avgDelta = Subtract(partialAvg, preAvg) - val mkDelta = Multiply(Multiply(avgDelta, avgDelta), - Divide(Multiply(preCount, partialCount), - combineCount)) - - // update average based on following formula - // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount) - combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input) - - // update sum of square differences from mean based on following formula - // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount) - combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input) - } - } - } - - override def eval(input: InternalRow): Any = { - val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long] - - if (count == 0) null - else if (count < 2) zero.eval(null) - else { - // when total count > 2 - // stddev_samp = sqrt (combineMk/(combineCount -1)) - // stddev_pop = sqrt (combineMk/combineCount) - val varCol = { - if (isSample) { - Divide(combineMk, Cast(Literal(count - 1), computeType)) - } - else { - Divide(combineMk, Cast(Literal(count), computeType)) - } - } - Sqrt(varCol).eval(null) - } - } -} - -case class StddevFunction( - expr: Expression, - base: AggregateExpression1, - isSample: Boolean -) extends AggregateFunction1 { - - def this() = this(null, null, false) // Required for serialization - - private val computeType = DoubleType - private var curCount: Long = 0L - private val zero = Cast(Literal(0), computeType) - private val curAvg = MutableLiteral(zero.eval(null), computeType) - private val curMk = MutableLiteral(zero.eval(null), computeType) - - private def curAvgAddFunction(value: Literal): Expression = { - val delta = Subtract(Cast(value, computeType), curAvg) - Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType))) - } - private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = { - val delta1 = Subtract(Cast(value, computeType), preAvg) - val delta2 = Subtract(Cast(value, computeType), curAvg) - Add(curMk, Multiply(delta1, delta2)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - val preAvg: MutableLiteral = curAvg.copy() - val exprValue = Literal.create(evaluatedExpr, expr.dataType) - curCount += 1L - curAvg.update(curAvgAddFunction(exprValue), input) - curMk.update(curMkAddFunction(exprValue, preAvg), input) - } - } - - override def eval(input: InternalRow): Any = { - if (curCount == 0) null - else if (curCount < 2) zero.eval(null) - else { - // when total count > 2, - // stddev_samp = sqrt(curMk/(curCount - 1)) - // stddev_pop = sqrt(curMk/curCount) - val varCol = { - if (isSample) { - Divide(curMk, Cast(Literal(curCount - 1), computeType)) - } - else { - Divide(curMk, Cast(Literal(curCount), computeType)) - } - } - Sqrt(varCol).eval(null) - } - } -} - -// placeholder -case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "kurtosis" -} - -// placeholder -case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "skewness" -} - -// placeholder -case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "var_pop" -} - -// placeholder -case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "var_samp" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 61a17fd7db0fe..f3d42fc0b2164 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval - -case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { +@ExpressionDescription( + usage = "_FUNC_(a) - Returns -a.") +case class UnaryMinus(child: Expression) extends UnaryExpression + with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -34,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") @@ -54,28 +56,36 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp numeric.negate(input) } } + + override def sql: String = s"(-${child.sql})" } -case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a.") +case class UnaryPositive(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def prettyName: String = "positive" override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + override def genCode(ctx: CodegenContext, ev: ExprCode): String = defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input + + override def sql: String = s"(+${child.sql})" } /** * A function that get the absolute value of the numeric value. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", - extended = "> SELECT _FUNC_('-1');\n1") -case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { + usage = "_FUNC_(expr) - Returns the absolute value of the numeric value.", + extended = "> SELECT _FUNC_('-1');\n 1") +case class Abs(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) @@ -83,7 +93,7 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => @@ -103,7 +113,7 @@ abstract class BinaryArithmetic extends BinaryOperator { def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide @@ -119,7 +129,9 @@ private[sql] object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } -case class Add(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a+b.") +case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -135,7 +147,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => @@ -148,7 +160,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } } -case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a-b.") +case class Subtract(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -164,7 +179,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => @@ -177,7 +192,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } } -case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Multiplies a by b.") +case class Multiply(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -189,7 +207,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Divides a by b.", + extended = "> SELECT 3 _FUNC_ 2;\n 1.5") +case class Divide(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -219,7 +241,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic /** * Special case handling due to division by 0 => null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { @@ -233,25 +255,42 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { + if (!left.nullable && !right.nullable) { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if ($isZero) { ${ev.isNull} = true; } else { + ${eval1.code} ${ev.value} = $divide; } - } - """ + """ + } else { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = $divide; + } + } + """ + } } } -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns the remainder when dividing a by b.") +case class Remainder(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -281,7 +320,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet /** * Special case handling for x % 0 ==> null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { @@ -295,25 +334,41 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { + if (!left.nullable && !right.nullable) { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if ($isZero) { ${ev.isNull} = true; } else { + ${eval1.code} ${ev.value} = $remainder; } - } - """ + """ + } else { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = $remainder; + } + } + """ + } } } -case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { +case class MaxOf(left: Expression, right: Expression) + extends BinaryArithmetic with NonSQLExpression { + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. override def inputType: AbstractDataType = TypeCollection.Ordered @@ -338,7 +393,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) @@ -367,7 +422,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "max" } -case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { +case class MinOf(left: Expression, right: Expression) + extends BinaryArithmetic with NonSQLExpression { + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. override def inputType: AbstractDataType = TypeCollection.Ordered @@ -392,7 +449,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) @@ -421,7 +478,10 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "min" } -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns the positive modulo", + extended = "> SELECT _FUNC_(10,3);\n 1") +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def toString: String = s"pmod($left, $right)" @@ -443,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { dataType match { case dt: DecimalType => @@ -513,4 +573,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { val r = a % n if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index a1e48c4210877..a7e1cd66f24aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -26,6 +26,9 @@ import org.apache.spark.sql.types._ * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise AND.", + extended = "> SELECT 3 _FUNC_ 5; 1") case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -51,6 +54,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise OR.", + extended = "> SELECT 3 _FUNC_ 5; 7") case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -76,6 +82,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise exclusive OR.", + extended = "> SELECT 3 _FUNC_ 5; 2") case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -99,6 +108,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise not(~) of a number. */ +@ExpressionDescription( + usage = "_FUNC_ b - Bitwise NOT.", + extended = "> SELECT _FUNC_ 0; -1") case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) @@ -118,9 +130,11 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") } protected override def nullSafeEval(input: Any): Any = not(input) + + override def sql: String = s"~${child.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 9b8b6382d753d..ab4831f7abdd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -25,19 +25,63 @@ package org.apache.spark.sql.catalyst.expressions.codegen */ object CodeFormatter { def format(code: String): String = new CodeFormatter().addLines(code).result() + def stripExtraNewLines(input: String): String = { + val code = new StringBuilder + var lastLine: String = "dummy" + input.split('\n').foreach { l => + val line = l.trim() + val skip = line == "" && (lastLine == "" || lastLine.endsWith("{")) + if (!skip) { + code.append(line) + code.append("\n") + } + lastLine = line + } + code.result() + } } private class CodeFormatter { private val code = new StringBuilder - private var indentLevel = 0 private val indentSize = 2 + + // Tracks the level of indentation in the current line. + private var indentLevel = 0 private var indentString = "" private var currentLine = 1 + // Tracks the level of indentation in multi-line comment blocks. + private var inCommentBlock = false + private var indentLevelOutsideCommentBlock = indentLevel + private def addLine(line: String): Unit = { - val indentChange = - line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) - val newIndentLevel = math.max(0, indentLevel + indentChange) + + // We currently infer the level of indentation of a given line based on a simple heuristic that + // examines the number of parenthesis and braces in that line. This isn't the most robust + // implementation but works for all code that we generate. + val indentChange = line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) + var newIndentLevel = math.max(0, indentLevel + indentChange) + + // Please note that while we try to format the comment blocks in exactly the same way as the + // rest of the code, once the block ends, we reset the next line's indentation level to what it + // was immediately before entering the comment block. + if (!inCommentBlock) { + if (line.startsWith("/*")) { + // Handle multi-line comments + inCommentBlock = true + indentLevelOutsideCommentBlock = indentLevel + } else if (line.startsWith("//")) { + // Handle single line comments + newIndentLevel = indentLevel + } + } + if (inCommentBlock) { + if (line.endsWith("*/")) { + inCommentBlock = false + newIndentLevel = indentLevelOutsideCommentBlock + } + } + // Lines starting with '}' should be de-indented even if they contain '{' after; // in addition, lines ending with ':' are typically labels val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f0f7a6cf0cc4d..f43626ca814a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,41 +24,68 @@ import scala.language.existentials import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ - - -// These classes are here to avoid issues with serialization and integration with quasiquotes. -class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] -class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] +import org.apache.spark.util.Utils /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. * * @param code The sequence of statements required to evaluate the expression. + * It should be empty string, if `isNull` and `value` are already existed, or no code + * needed to evaluate them (literals). * @param isNull A term that holds a boolean value representing whether the expression evaluated * to null. * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class GeneratedExpressionCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: String, var value: String) /** - * A context for codegen, which is used to bookkeeping the expressions those are not supported - * by codegen, then they are evaluated directly. The unsupported expression is appended at the - * end of `references`, the position of it is kept in the code, used to access and evaluate it. + * A context for codegen, tracking a list of objects that could be passed into generated Java + * function. */ -class CodeGenContext { +class CodegenContext { + + /** + * Holding a list of objects that could be used passed into generated class. + */ + val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() /** - * Holding all the expressions those do not support codegen, will be evaluated directly. + * Add an object to `references`, create a class member to access it. + * + * Returns the name of class member. + */ + def addReferenceObj(name: String, obj: Any, className: String = null): String = { + val term = freshName(name) + val idx = references.length + references += obj + val clsName = Option(className).getOrElse(obj.getClass.getName) + addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + term + } + + /** + * Holding a list of generated columns as input of current operator, will be used by + * BoundReference to generate code. + */ + var currentVars: Seq[ExprCode] = null + + /** + * Whether should we copy the result rows or not. + * + * If any operator inside WholeStageCodegen generate multiple rows from a single row (for + * example, Join), this should be true. + * + * If an operator starts a new pipeline, this should be reset to false before calling `consume()`. */ - val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + var copyResult: Boolean = false /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a @@ -82,6 +109,16 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initCode)) } + def declareMutableStates(): String = { + mutableStates.map { case (javaType, variableName, _) => + s"private $javaType $variableName;" + }.mkString("\n") + } + + def initMutableStates(): String = { + mutableStates.map(_._3).mkString("\n") + } + /** * Holding all the functions those will be added into generated class. */ @@ -92,6 +129,34 @@ class CodeGenContext { addedFunctions += ((funcName, funcCode)) } + /** + * Holds expressions that are equivalent. Used to perform subexpression elimination + * during codegen. + * + * For expressions that appear more than once, generate additional code to prevent + * recomputing the value. + * + * For example, consider two expression generated from this SQL statement: + * SELECT (col1 + col2), (col1 + col2) / col3. + * + * equivalentExpressions will match the tree containing `col1 + col2` and it will only + * be evaluated once. + */ + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + + // State used for subexpression elimination. + case class SubExprEliminationState(isNull: String, value: String) + + // Foreach expression that is participating in subexpression elimination, the state to use. + val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + + // The collection of sub-expression result resetting methods that need to be called on each row. + val subexprFunctions = mutable.ArrayBuffer.empty[String] + + def declareAddedFunctions(): String = { + addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + } + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -101,18 +166,36 @@ class CodeGenContext { final val JAVA_DOUBLE = "double" /** The variable name of the input row in generated code. */ - final val INPUT_ROW = "i" + final var INPUT_ROW = "i" + + /** + * The map from a variable name to it's next ID. + */ + private val freshNameIds = new mutable.HashMap[String, Int] + freshNameIds += INPUT_ROW -> 1 - private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * A prefix used to generate fresh name. + */ + var freshNamePrefix = "" /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) + * Returns a term name that is unique within this instance of a `CodegenContext`. */ - def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" + def freshName(name: String): String = synchronized { + val fullName = if (freshNamePrefix == "") { + name + } else { + s"${freshNamePrefix}_$name" + } + if (freshNameIds.contains(fullName)) { + val id = freshNameIds(fullName) + freshNameIds(fullName) = id + 1 + s"$fullName$id" + } else { + freshNameIds += fullName -> 1 + fullName + } } /** @@ -150,6 +233,39 @@ class CodeGenContext { } } + /** + * Update a column in MutableRow from ExprCode. + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (dataType.isInstanceOf[DecimalType]) { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + ${setColumn(row, dataType, ordinal, "null")}; + } + """ + } else { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + $row.setNullAt($ordinal); + } + """ + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + /** * Returns the name used in accessor and setter for a Java primitive type. */ @@ -178,8 +294,6 @@ class CodeGenContext { case _: StructType => "InternalRow" case _: ArrayType => "ArrayData" case _: MapType => "MapData" - case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName - case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName @@ -246,6 +360,49 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" + case array: ArrayType => + val elementType = array.elementType + val elementA = freshName("elementA") + val isNullA = freshName("isNullA") + val elementB = freshName("elementB") + val isNullB = freshName("isNullB") + val compareFunc = freshName("compareArray") + val minLength = freshName("minLength") + val funcCode: String = + s""" + public int $compareFunc(ArrayData a, ArrayData b) { + int lengthA = a.numElements(); + int lengthB = b.numElements(); + int $minLength = (lengthA > lengthB) ? lengthB : lengthA; + for (int i = 0; i < $minLength; i++) { + boolean $isNullA = a.isNullAt(i); + boolean $isNullB = b.isNullAt(i); + if ($isNullA && $isNullB) { + // Nothing + } else if ($isNullA) { + return -1; + } else if ($isNullB) { + return 1; + } else { + ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; + ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + int comp = ${genComp(elementType, elementA, elementB)}; + if (comp != 0) { + return comp; + } + } + } + + if (lengthA < lengthB) { + return -1; + } else if (lengthA > lengthB) { + return 1; + } + return 0; + } + """ + addNewFunction(compareFunc, funcCode) + s"this.$compareFunc($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") @@ -265,6 +422,38 @@ class CodeGenContext { throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } + /** + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ + def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { + case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" + case _ => s"(${genComp(dataType, c1, c2)}) > 0" + } + + /** + * Generates code to do null safe execution, i.e. only execute the code when the input is not + * null by adding null check if necessary. + * + * @param nullable used to decide whether we should add null check or not. + * @param isNull the code to check if the input is null. + * @param execute the code that should only be executed when the input is not null. + */ + def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execute + } + """ + } else { + "\n" + execute + } + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. */ @@ -317,6 +506,74 @@ class CodeGenContext { functions.map(name => s"$name($row);").mkString("\n") } } + + /** + * Checks and sets up the state and codegen for subexpression elimination. This finds the + * common subexpressions, generates the functions that evaluate those expressions and populates + * the mapping of common subexpressions to the generated functions. + */ + private def subexpressionElimination(expressions: Seq[Expression]) = { + // Add each expression tree and compute the common subexpressions. + expressions.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the expressions that appear at least twice and set up the state for subexpression + // elimination. + val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + commonExprs.foreach { e => + val expr = e.head + val fnName = freshName("evalExpr") + val isNull = s"${fnName}IsNull" + val value = s"${fnName}Value" + + // Generate the code for this expression tree and wrap it in a function. + val code = expr.gen(this) + val fn = + s""" + |private void $fnName(InternalRow $INPUT_ROW) { + | ${code.code.trim} + | $isNull = ${code.isNull}; + | $value = ${code.value}; + |} + """.stripMargin + + addNewFunction(fnName, fn) + + // Add a state and a mapping of the common subexpressions that are associate with this + // state. Adding this expression to subExprEliminationExprMap means it will call `fn` + // when it is code generated. This decision should be a cost based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly + // very often. The reason it is not loaded is because of a prior branch. + // 3. Extra store into isLoaded. + // The benefit doing subexpression elimination is: + // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 + // above. + // 2. Less code. + // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState(javaType(expr.dataType), value, + s"$value = ${defaultValue(expr.dataType)};") + + subexprFunctions += s"$fnName($INPUT_ROW);" + val state = SubExprEliminationState(isNull, value) + e.foreach(subExprEliminationExprs.put(_, state)) + } + } + + /** + * Generates code for expressions. If doSubexpressionElimination is true, subexpression + * elimination will be performed. Subexpression elimination assumes that the code will for each + * expression will be combined in the `expressions` order. + */ + def generateExpressions(expressions: Seq[Expression], + doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { + if (doSubexpressionElimination) subexpressionElimination(expressions) + expressions.map(e => e.gen(this)) + } } /** @@ -324,7 +581,7 @@ class CodeGenContext { * into generated class. */ abstract class GeneratedClass { - def generate(expressions: Array[Expression]): Any + def generate(references: Array[Any]): Any } /** @@ -334,24 +591,8 @@ abstract class GeneratedClass { */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected val exprType: String = classOf[Expression].getName - protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def declareMutableStates(ctx: CodeGenContext): String = { - ctx.mutableStates.map { case (javaType, variableName, _) => - s"private $javaType $variableName;" - }.mkString("\n") - } - - protected def initMutableStates(ctx: CodeGenContext): String = { - ctx.mutableStates.map(_._3).mkString("\n") - } - - protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") - } - /** * Generates a class for a given input expression. Called when there is not cached code * already available. @@ -367,10 +608,27 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Binds an input expression to a given input schema */ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def generate(expressions: InType): OutType = create(canonicalize(expressions)) + + /** + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen + */ + def newCodeGenContext(): CodegenContext = { + new CodegenContext + } +} + +object CodeGenerator extends Logging { /** * Compile the Java source code into a Java class, using Janino. */ - protected def compile(code: String): GeneratedClass = { + def compile(code: String): GeneratedClass = { cache.get(code) } @@ -379,7 +637,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( @@ -393,7 +651,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[UnsafeArrayData].getName, classOf[MapData].getName, classOf[UnsafeMapData].getName, - classOf[MutableRow].getName + classOf[MutableRow].getName, + classOf[Expression].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) @@ -402,7 +661,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin logDebug({ // Only add extra debugging info to byte code when we are going to print the source code. evaluator.setDebuggingInformation(true, true, false) - formatted + s"\n$formatted" }) try { @@ -438,19 +697,4 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin result } }) - - /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = - generate(bind(expressions, inputSchema)) - - /** Generates the requested evaluator given already bound expression(s). */ - def generate(expressions: InType): OutType = create(canonicalize(expressions)) - - /** - * Create a new codegen context for expression evaluator, used to store those - * expressions that don't support codegen - */ - def newCodeGenContext(): CodeGenContext = { - new CodeGenContext - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index d51a8dede7f34..1365ee4b55634 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,29 +17,42 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression} +import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.util.toCommentSafeString /** * A trait that can be used to provide a fallback mode for expression code generation. */ trait CodegenFallback extends Expression { - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { foreach { case n: Nondeterministic => n.setInitialValues() case _ => } + // LeafNode does not need `input` + val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW + val idx = ctx.references.length ctx.references += this val objectTerm = ctx.freshName("obj") - s""" - /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; - if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ + if (nullable) { + s""" + /* expression: ${toCommentSafeString(this.toString)} */ + Object $objectTerm = ((Expression) references[$idx]).eval($input); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } else { + ev.isNull = "false" + s""" + /* expression: ${toCommentSafeString(this.toString)} */ + Object $objectTerm = ((Expression) references[$idx]).eval($input); + ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b66069b5f55a..7f840890f8ae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.types.DecimalType // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -38,68 +37,82 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean): (() => MutableProjection) = { + create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) + } + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + create(expressions, false) + } + + private def create( + expressions: Seq[Expression], + useSubexprElimination: Boolean): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCodes = expressions.zipWithIndex.map { - case (NoOp, _) => "" - case (e, i) => - val evaluationCode = e.gen(ctx) - val isNull = s"isNull_$i" - val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${evaluationCode.code} - this.$isNull = ${evaluationCode.isNull}; - this.$value = ${evaluationCode.value}; - """ - } - val updates = expressions.zipWithIndex.map { - case (NoOp, _) => "" - case (e, i) => - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset + val (validExpr, index) = expressions.zipWithIndex.filter { + case (NoOp, _) => false + case _ => true + }.unzip + val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + val projectionCodes = exprVals.zip(index).map { + case (ev, i) => + val e = expressions(i) + if (e.nullable) { + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ + ${ev.code} + this.$isNull = ${ev.isNull}; + this.$value = ${ev.value}; + """ } else { + val value = s"value_$i" + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ + ${ev.code} + this.$value = ${ev.value}; + """ } } + // Evaluate all the subexpressions. + val evalSubexpr = ctx.subexprFunctions.mkString("\n") + + val updates = validExpr.zip(index).map { + case (e, i) => + val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) + } + val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public Object generate($exprType[] expr) { - return new SpecificMutableProjection(expr); + public java.lang.Object generate(Object[] references) { + return new SpecificMutableProjection(references); } class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { - private $exprType[] expressions; - private $mutableRowType mutableRow; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + private Object[] references; + private MutableRow mutableRow; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public SpecificMutableProjection($exprType[] expr) { - expressions = expr; + public SpecificMutableProjection(Object[] references) { + this.references = references; mutableRow = new $genericMutableRowType(${expressions.size}); - ${initMutableStates(ctx)} + ${ctx.initMutableStates()} } - public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { + public ${classOf[BaseMutableProjection].getName} target(MutableRow row) { mutableRow = row; return this; } @@ -109,8 +122,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; + $evalSubexpr $allProjections // copy all the results into MutableRow $allUpdates @@ -121,7 +135,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) () => { c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 1af7c73cd4bf5..908c32de4d896 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.Logging +import java.io.ObjectInputStream + +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Inherits some default implementation for Java from `Ordering[Row]` @@ -55,7 +58,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR * Generates the code for comparing a struct type according to its natural ordering * (i.e. ascending order by field 1, then field 2, ..., then field n. */ - def genComparisons(ctx: CodeGenContext, schema: StructType): String = { + def genComparisons(ctx: CodegenContext, schema: StructType): String = { val ordering = schema.fields.map(_.dataType).zipWithIndex.map { case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) } @@ -65,7 +68,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR /** * Generates the code for ordering based on the given order. */ - def genComparisons(ctx: CodeGenContext, ordering: Seq[SortOrder]): String = { + def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val comparisons = ordering.map { order => val eval = order.child.gen(ctx) val asc = order.direction == Ascending @@ -111,19 +114,19 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR val ctx = newCodeGenContext() val comparisons = genComparisons(ctx, ordering) val code = s""" - public SpecificOrdering generate($exprType[] expr) { - return new SpecificOrdering(expr); + public SpecificOrdering generate(Object[] references) { + return new SpecificOrdering(references); } class SpecificOrdering extends ${classOf[BaseOrdering].getName} { - private $exprType[] expressions; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public SpecificOrdering($exprType[] expr) { - expressions = expr; - ${initMutableStates(ctx)} + public SpecificOrdering(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} } public int compare(InternalRow a, InternalRow b) { @@ -135,6 +138,40 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}") - compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + } +} + +/** + * A lazily generated row ordering comparator. + */ +class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) extends Ordering[InternalRow] { + + def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = + this(ordering.map(BindReferences.bindReference(_, inputSchema))) + + @transient + private[this] var generatedOrdering = GenerateOrdering.generate(ordering) + + def compare(a: InternalRow, b: InternalRow): Int = { + generatedOrdering.compare(a, b) + } + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + in.defaultReadObject() + generatedOrdering = GenerateOrdering.generate(ordering) + } +} + +object LazilyGeneratedOrdering { + + /** + * Creates a [[LazilyGeneratedOrdering]] for the given schema, in natural ascending order. + */ + def forSchema(schema: StructType): LazilyGeneratedOrdering = { + new LazilyGeneratedOrdering(schema.zipWithIndex.map { + case (field, ordinal) => + SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending) + }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 457b4f08424a6..58065d956f072 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -41,18 +41,18 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool val ctx = newCodeGenContext() val eval = predicate.gen(ctx) val code = s""" - public SpecificPredicate generate($exprType[] expr) { - return new SpecificPredicate(expr); + public SpecificPredicate generate(Object[] references) { + return new SpecificPredicate(references); } class SpecificPredicate extends ${classOf[Predicate].getName} { - private final $exprType[] expressions; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + private final Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public SpecificPredicate($exprType[] expr) { - expressions = expr; - ${initMutableStates(ctx)} + public SpecificPredicate(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} } public boolean eval(InternalRow ${ctx.INPUT_ROW}) { @@ -63,7 +63,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] (r: InternalRow) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala deleted file mode 100644 index c0d313b2e1301..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/** - * Java can not access Projection (in package object) - */ -abstract class BaseProjection extends Projection {} - -abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow - -/** - * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input - * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] - * object is custom generated based on the output types of the [[Expression]] to avoid boxing of - * primitive values. - */ -object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - - protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer.execute) - - protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) - - // Make Mutablility optional... - protected def create(expressions: Seq[Expression]): Projection = { - val ctx = newCodeGenContext() - val columns = expressions.zipWithIndex.map { - case (e, i) => - s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" - }.mkString("\n") - - val initColumns = expressions.zipWithIndex.map { - case (e, i) => - val eval = e.gen(ctx) - s""" - { - // column$i - ${eval.code} - nullBits[$i] = ${eval.isNull}; - if (!${eval.isNull}) { - c$i = ${eval.value}; - } - } - """ - }.mkString("\n") - - val getCases = (0 until expressions.size).map { i => - s"case $i: return c$i;" - }.mkString("\n") - - val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" - }.mkString("\n") - - val specificAccessorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: return c$i;") - case _ => None - }.mkString("\n") - if (cases.length > 0) { - val getter = "get" + ctx.primitiveTypeName(jt) - s""" - public $jt $getter(int i) { - if (isNullAt(i)) { - return ${ctx.defaultValue(jt)}; - } - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i - + " in $getter"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val specificMutatorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: { c$i = value; return; }") - case _ => None - }.mkString("\n") - if (cases.length > 0) { - val setter = "set" + ctx.primitiveTypeName(jt) - s""" - public void $setter(int i, $jt value) { - nullBits[i] = false; - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i + - " in $setter}"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = s"c$i" - val nonNull = e.dataType match { - case BooleanType => s"$col ? 0 : 1" - case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType | TimestampType => s"$col ^ ($col >>> 32)" - case FloatType => s"Float.floatToIntBits($col)" - case DoubleType => - s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" - case BinaryType => s"java.util.Arrays.hashCode($col)" - case _ => s"$col.hashCode()" - } - s"isNullAt($i) ? 0 : ($nonNull)" - } - - val hashUpdates: String = hashValues.map( v => - s""" - result *= 37; result += $v;""" - ).mkString("\n") - - val columnChecks = expressions.zipWithIndex.map { case (e, i) => - s""" - if (nullBits[$i] != row.nullBits[$i] || - (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) { - return false; - } - """ - }.mkString("\n") - - val copyColumns = expressions.zipWithIndex.map { case (e, i) => - s"""if (!nullBits[$i]) arr[$i] = c$i;""" - }.mkString("\n") - - val code = s""" - public SpecificProjection generate($exprType[] expr) { - return new SpecificProjection(expr); - } - - class SpecificProjection extends ${classOf[BaseProjection].getName} { - private $exprType[] expressions; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} - - public SpecificProjection($exprType[] expr) { - expressions = expr; - ${initMutableStates(ctx)} - } - - public Object apply(Object r) { - // GenerateProjection does not work with UnsafeRows. - assert(!(r instanceof ${classOf[UnsafeRow].getName})); - return new SpecificRow((InternalRow) r); - } - - final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { - - $columns - - public SpecificRow(InternalRow ${ctx.INPUT_ROW}) { - $initColumns - } - - public int numFields() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } - - public Object genericGet(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases - } - return null; - } - public void update(int i, Object value) { - if (value == null) { - setNullAt(i); - return; - } - nullBits[i] = false; - switch (i) { - $updateCases - } - } - $specificAccessorFunctions - $specificMutatorFunctions - - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - - public boolean equals(Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; - } - return super.equals(other); - } - - public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); - } - } - } - """ - - logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" + - CodeFormatter.format(code)) - - compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f0ed8645d923f..cf73e36d227c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -17,11 +17,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import scala.annotation.tailrec + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProjection extends Projection {} /** * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update @@ -36,9 +42,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] in.map(BindReferences.bindReference(_, inputSchema)) private def createCodeForStruct( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - schema: StructType): GeneratedExpressionCode = { + schema: StructType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") @@ -64,13 +70,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final InternalRow $output = new $rowClass($values); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } private def createCodeForArray( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - elementType: DataType): GeneratedExpressionCode = { + elementType: DataType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeArray") val values = ctx.freshName("values") @@ -92,14 +98,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } private def createCodeForMap( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, keyType: DataType, - valueType: DataType): GeneratedExpressionCode = { + valueType: DataType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeMap") val mapClass = classOf[ArrayBasedMapData].getName @@ -113,20 +119,21 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } + @tailrec private def convertToSafe( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - dataType: DataType): GeneratedExpressionCode = dataType match { + dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. - case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case StringType => ExprCode("", "false", s"$input.clone()") case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => GeneratedExpressionCode("", "false", input) + case _ => ExprCode("", "false", input) } protected def create(expressions: Seq[Expression]): Projection = { @@ -148,24 +155,24 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public Object generate($exprType[] expr) { - return new SpecificSafeProjection(expr); + public java.lang.Object generate(Object[] references) { + return new SpecificSafeProjection(references); } class SpecificSafeProjection extends ${classOf[BaseProjection].getName} { - private $exprType[] expressions; - private $mutableRowType mutableRow; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + private Object[] references; + private MutableRow mutableRow; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public SpecificSafeProjection($exprType[] expr) { - expressions = expr; - mutableRow = new $genericMutableRowType(${expressions.size}); - ${initMutableStates(ctx)} + public SpecificSafeProjection(Object[] references) { + this.references = references; + mutableRow = (MutableRow) references[references.length - 1]; + ${ctx.initMutableStates()} } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; @@ -175,7 +182,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) - c.generate(ctx.references.toArray).asInstanceOf[Projection] + val c = CodeGenerator.compile(code) + val resultRow = new SpecificMutableRow(expressions.map(_.dataType)) + c.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2136f82ba4752..6aa9cbf08bdb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -39,17 +39,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case dt: OpenHashSetUDT => false // it's not a standard UDT case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } - private val rowWriterClass = classOf[UnsafeRowWriter].getName - private val arrayWriterClass = classOf[UnsafeArrayWriter].getName - // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], bufferHolder: String): String = { @@ -57,7 +53,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldName = ctx.freshName("fieldName") val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};" val isNull = s"$input.isNullAt($i)" - GeneratedExpressionCode(code, isNull, fieldName) + ExprCode(code, isNull, fieldName) } s""" @@ -70,13 +66,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } private def writeExpressionsToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, row: String, - inputs: Seq[GeneratedExpressionCode], + inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String): String = { + bufferHolder: String, + isTopLevel: Boolean = false): String = { + val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") + ctx.addMutableState(rowWriterClass, rowWriter, + s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + + val resetWriter = if (isTopLevel) { + // For top level row writer, it always writes to the beginning of the global buffer holder, + // which means its fixed-size region always in the same position, so we don't need to call + // `reset` to set up its fixed-size region every time. + if (inputs.map(_.isNull).forall(_ == "false")) { + // If all fields are not nullable, which means the null bits never changes, then we don't + // need to clear it out every time. + "" + } else { + s"$rowWriter.zeroOutNullBytes();" + } + } else { + s"$rowWriter.reset();" + } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => @@ -123,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); """ - case _ if ctx.isPrimitiveType(dt) => - s""" - $rowWriter.write($index, ${input.value}); - """ - case t: DecimalType => s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" @@ -136,28 +145,36 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$rowWriter.write($index, ${input.value});" } - s""" - ${input.code} - if (${input.isNull}) { - $setNull - } else { - $writeField - } - """ + if (input.isNull == "false") { + s""" + ${input.code} + ${writeField.trim} + """ + } else { + s""" + ${input.code} + if (${input.isNull}) { + ${setNull.trim} + } else { + ${writeField.trim} + } + """ + } } s""" - $rowWriter.initialize($bufferHolder, ${inputs.length}); + $resetWriter ${ctx.splitExpressions(row, writeFields)} - """ + """.trim } // TODO: if the nullability of array element is correct, we can use it to save null check. private def writeArrayToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, elementType: DataType, bufferHolder: String): String = { + val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, s"this.$arrayWriter = new $arrayWriterClass();") @@ -226,7 +243,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, keyType: DataType, valueType: DataType, @@ -264,7 +281,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); @@ -275,23 +292,52 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ } - def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { - val exprEvals = expressions.map(e => e.gen(ctx)) + def createCode( + ctx: CodegenContext, + expressions: Seq[Expression], + useSubexprElimination: Boolean = false): ExprCode = { + val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) + val numVarLenFields = exprTypes.count { + case dt if UnsafeRow.isFixedLength(dt) => false + // TODO: consider large decimal and interval type + case _ => true + } + val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();") - val bufferHolder = ctx.freshName("bufferHolder") + ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") + + val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName - ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") + ctx.addMutableState(holderClass, holder, + s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + + val resetBufferHolder = if (numVarLenFields == 0) { + "" + } else { + s"$holder.reset();" + } + val updateRowSize = if (numVarLenFields == 0) { + "" + } else { + s"$result.setTotalSize($holder.totalSize());" + } + + // Evaluate all the subexpression. + val evalSubexpr = ctx.subexprFunctions.mkString("\n") + + val writeExpressions = + writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) val code = s""" - $bufferHolder.reset(); - ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} - $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize()); + $resetBufferHolder + $evalSubexpr + $writeExpressions + $updateRowSize """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -300,35 +346,45 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) - protected def create(expressions: Seq[Expression]): UnsafeProjection = { - val ctx = newCodeGenContext() + def generate( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + create(canonicalize(expressions), subexpressionEliminationEnabled) + } - val eval = createCode(ctx, expressions) + protected def create(references: Seq[Expression]): UnsafeProjection = { + create(references, subexpressionEliminationEnabled = false) + } + + private def create( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val ctx = newCodeGenContext() + val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public Object generate($exprType[] exprs) { - return new SpecificUnsafeProjection(exprs); + public java.lang.Object generate(Object[] references) { + return new SpecificUnsafeProjection(references); } class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { - private $exprType[] expressions; - - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} + private Object[] references; + ${ctx.declareMutableStates()} + ${ctx.declareAddedFunctions()} - public SpecificUnsafeProjection($exprType[] expressions) { - this.expressions = expressions; - ${initMutableStates(ctx)} + public SpecificUnsafeProjection(Object[] references) { + this.references = references; + ${ctx.initMutableStates()} } // Scala.Function1 need this - public Object apply(Object row) { + public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); } public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code} + ${eval.code.trim} return ${eval.value}; } } @@ -336,7 +392,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da91ff29537b3..b1ffbaa3e94ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform - abstract class UnsafeRowJoiner { def join(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow } @@ -61,9 +60,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val outputBitsetWords = (schema1.size + schema2.size + 63) / 64 val bitset1Remainder = schema1.size % 64 - // The number of words we can reduce when we concat two rows together. + // The number of bytes we can reduce when we concat two rows together. // The only reduction comes from merging the bitset portion of the two rows, saving 1 word. - val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords + val sizeReduction = (bitset1Words + bitset2Words - outputBitsetWords) * 8 // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => @@ -159,26 +158,26 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public Object generate($exprType[] exprs) { + |public java.lang.Object generate(Object[] references) { | return new SpecificUnsafeRowJoiner(); |} | |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} { | private byte[] buf = new byte[64]; - | private UnsafeRow out = new UnsafeRow(); + | private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size}); | | public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) { | // row1: ${schema1.size} fields, $bitset1Words words in bitset | // row2: ${schema2.size}, $bitset2Words words in bitset | // output: ${schema1.size + schema2.size} fields, $outputBitsetWords words in bitset - | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes(); + | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes() - $sizeReduction; | if (sizeInBytes > buf.length) { | buf = new byte[sizeInBytes]; | } | - | final Object obj1 = row1.getBaseObject(); + | final java.lang.Object obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final Object obj2 = row2.getBaseObject(); + | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset @@ -188,7 +187,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyVariableLengthRow2 | $updateOffset | - | out.pointTo(buf, ${schema1.size + schema2.size}, sizeInBytes - $sizeReduction); + | out.pointTo(buf, sizeInBytes); | | return out; | } @@ -197,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - val c = compile(code) + val c = CodeGenerator.compile(code) c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2cf19b939f734..ab790cf372d9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** * Given an array or map, returns its size. */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the size of an array or a map.") case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) @@ -35,7 +37,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType case _: MapType => value.asInstanceOf[MapData].numElements() } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();") } } @@ -44,6 +46,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.", + extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'));\n 'a', 'b', 'c', 'd'") +// scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { @@ -68,6 +75,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val lt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } @@ -90,6 +98,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val gt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } @@ -123,6 +132,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) /** * Checks if the array (left) has the element (right) */ +@ExpressionDescription( + usage = "_FUNC_(array, value) - Returns TRUE if the array contains value.", + extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true") case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -168,7 +180,7 @@ case class ArrayContains(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") val getValue = ctx.getValue(arr, right.dataType, i) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1854dfaa7db35..74de4a776de89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. */ +@ExpressionDescription( + usage = "_FUNC_(n0, ...) - Returns an array with the given elements.") case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -46,7 +48,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { new GenericArrayData(children.map(_.eval(input)).toArray) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName val values = ctx.freshName("values") s""" @@ -69,9 +71,94 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def prettyName: String = "array" } +/** + * Returns a catalyst Map containing the evaluation of all children expressions as keys and values. + * The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...) + */ +@ExpressionDescription( + usage = "_FUNC_(key0, value0, key1, value1...) - Creates a map with the given key/value pairs.") +case class CreateMap(children: Seq[Expression]) extends Expression { + private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children) + private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children) + + override def foldable: Boolean = children.forall(_.foldable) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure(s"$prettyName expects an positive even number of arguments.") + } else if (keys.map(_.dataType).distinct.length > 1) { + TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " + + "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else if (values.map(_.dataType).distinct.length > 1) { + TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " + + "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def dataType: DataType = { + MapType( + keyType = keys.headOption.map(_.dataType).getOrElse(NullType), + valueType = values.headOption.map(_.dataType).getOrElse(NullType), + valueContainsNull = values.exists(_.nullable)) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + val keyArray = keys.map(_.eval(input)).toArray + if (keyArray.contains(null)) { + throw new RuntimeException("Cannot use null as map key!") + } + val valueArray = values.map(_.eval(input)).toArray + new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) + } + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val arrayClass = classOf[GenericArrayData].getName + val mapClass = classOf[ArrayBasedMapData].getName + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val keyData = s"new $arrayClass($keyArray)" + val valueData = s"new $arrayClass($valueArray)" + s""" + final boolean ${ev.isNull} = false; + final Object[] $keyArray = new Object[${keys.size}]; + final Object[] $valueArray = new Object[${values.size}]; + """ + keys.zipWithIndex.map { + case (key, i) => + val eval = key.gen(ctx) + s""" + ${eval.code} + if (${eval.isNull}) { + throw new RuntimeException("Cannot use null as map key!"); + } else { + $keyArray[$i] = ${eval.value}; + } + """ + }.mkString("\n") + values.zipWithIndex.map { + case (value, i) => + val eval = value.gen(ctx) + s""" + ${eval.code} + if (${eval.isNull}) { + $valueArray[$i] = null; + } else { + $valueArray[$i] = ${eval.value}; + } + """ + }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);" + } + + override def prettyName: String = "map" +} + /** * Returns a Row containing the evaluation of all children expressions. */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -94,7 +181,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") s""" @@ -123,10 +210,14 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit case class CreateNamedStruct(children: Seq[Expression]) extends Expression { /** - * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this * StructType. */ def flatten: Seq[NamedExpression] = valExprs.zip(names).map { @@ -159,7 +250,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { TypeCheckResult.TypeCheckFailure( s"Only foldable StringType expressions are allowed to appear at odd position , got :" + s" ${invalidNames.mkString(",")}") - } else if (names.forall(_ != null)){ + } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure("Field name should not be null") @@ -171,7 +262,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { InternalRow(valExprs.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") s""" @@ -223,7 +314,7 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, children) ev.isNull = eval.isNull ev.value = eval.value @@ -263,7 +354,7 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression InternalRow(valExprs.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ev.isNull = eval.isNull ev.value = eval.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 41cd0a104a1f5..c06dcc98674fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -51,7 +51,7 @@ object ExtractValue { case (StructType(fields), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + GetStructField(child, ordinal, Some(fieldName)) case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString @@ -93,36 +93,57 @@ object ExtractValue { } } +trait ExtractValue extends Expression + /** * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. + * + * Note that we can pass in the field name directly to keep case preserving in `toString`. + * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ -case class GetStructField(child: Expression, field: StructField, ordinal: Int) - extends UnaryExpression { +case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) + extends UnaryExpression with ExtractValue { - override def dataType: DataType = field.dataType - override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${field.name}" + private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] + + override def dataType: DataType = childSchema(ordinal).dataType + override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + + override def toString: String = { + val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal" + s"$child.${name.getOrElse(fieldName)}" + } + + override def sql: String = + child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow].get(ordinal, field.dataType) + input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { + if (nullable) { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + } else { + s""" ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ + """ + } }) } } /** - * Returns the array of value of fields in the Array of Struct `child`. + * For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array + * elements, and returns them as a new array. * * No need to do type checking since it is handled by [[ExtractValue]]. */ @@ -131,11 +152,11 @@ case class GetArrayStructFields( field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression { + containsNull: Boolean) extends UnaryExpression with ExtractValue { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable || containsNull || field.nullable override def toString: String = s"$child.${field.name}" + override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" protected override def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData] @@ -158,25 +179,29 @@ case class GetArrayStructFields( new GenericArrayData(result) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { + val n = ctx.freshName("n") + val values = ctx.freshName("values") + val j = ctx.freshName("j") + val row = ctx.freshName("row") s""" - final int n = $eval.numElements(); - final Object[] values = new Object[n]; - for (int j = 0; j < n; j++) { - if ($eval.isNullAt(j)) { - values[j] = null; + final int $n = $eval.numElements(); + final Object[] $values = new Object[$n]; + for (int $j = 0; $j < $n; $j++) { + if ($eval.isNullAt($j)) { + $values[$j] = null; } else { - final InternalRow row = $eval.getStruct(j, $numFields); - if (row.isNullAt($ordinal)) { - values[j] = null; + final InternalRow $row = $eval.getStruct($j, $numFields); + if ($row.isNullAt($ordinal)) { + $values[$j] = null; } else { - values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; } } } - ${ev.value} = new $arrayClass(values); + ${ev.value} = new $arrayClass($values); """ }) } @@ -188,12 +213,13 @@ case class GetArrayStructFields( * We need to do type checking here as `ordinal` expression maybe unresolved. */ case class GetArrayItem(child: Expression, ordinal: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with ExtractValue { // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) override def toString: String = s"$child[$ordinal]" + override def sql: String = s"${child.sql}[${ordinal.sql}]" override def left: Expression = child override def right: Expression = ordinal @@ -206,21 +232,22 @@ case class GetArrayItem(child: Expression, ordinal: Expression) protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.numElements() || index < 0) { + if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) { null } else { baseValue.get(index, dataType) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("index") s""" - final int index = (int) $eval2; - if (index >= $eval1.numElements() || index < 0) { + final int $index = (int) $eval2; + if ($index >= $eval1.numElements() || $index < 0 || $eval1.isNullAt($index)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, "index")}; + ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; } """ }) @@ -233,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) * We need to do type checking here as `key` expression maybe unresolved. */ case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with ExtractValue { private def keyType = child.dataType.asInstanceOf[MapType].keyType @@ -241,6 +268,7 @@ case class GetMapValue(child: Expression, key: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" override def left: Expression = child override def right: Expression = key @@ -255,6 +283,7 @@ case class GetMapValue(child: Expression, key: Expression) val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() + val values = map.valueArray() var i = 0 var found = false @@ -266,23 +295,25 @@ case class GetMapValue(child: Expression, key: Expression) } } - if (!found) { + if (!found || values.isNullAt(i)) { null } else { - map.valueArray().get(i, dataType) + values.get(i, dataType) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") + val values = ctx.freshName("values") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); final ArrayData $keys = $eval1.keyArray(); + final ArrayData $values = $eval1.valueArray(); int $index = 0; boolean $found = false; @@ -295,10 +326,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if ($found) { - ${ev.value} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)}; - } else { + if (!$found || $values.isNullAt($index)) { ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(values, dataType, index)}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d532629984bec..ae6a94842f7d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -21,9 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{NullType, BooleanType, DataType} - +import org.apache.spark.sql.types._ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1,expr2,expr3) - If expr1 is TRUE then IF() returns expr2; otherwise it returns expr3.") +// scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) extends Expression { @@ -34,8 +37,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") - } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + } else if (trueValue.dataType.asNullable != falseValue.dataType.asNullable) { + TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { TypeCheckResult.TypeCheckSuccess @@ -45,14 +48,14 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def dataType: DataType = trueValue.dataType override def eval(input: InternalRow): Any = { - if (true == predicate.eval(input)) { + if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) } else { falseValue.eval(input) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) @@ -74,239 +77,185 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } override def toString: String = s"if ($predicate) $trueValue else $falseValue" -} -trait CaseWhenLike extends Expression { + override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" +} - // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last - // element is the value for the default catch-all case (if provided). - // Hence, `branches` consists of at least two elements, and can have an odd or even length. - def branches: Seq[Expression] +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * When a = true, returns b; when c = true, returns d; else returns e. + * + * @param branches seq of (branch condition, branch value) + * @param elseValue optional value for the else branch + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") +// scalastyle:on line.size.limit +case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) + extends Expression with CodegenFallback { - @transient lazy val whenList = - branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq - @transient lazy val thenList = - branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq - val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) + override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 + def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType) - override def checkInputDataTypes(): TypeCheckResult = { - if (valueTypesEqual) { - checkTypesInternal() - } else { - TypeCheckResult.TypeCheckFailure( - "THEN and ELSE expressions should all be same type or coercible to a common type") - } + def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { + case Seq(dt1, dt2) => dt1.sameType(dt2) } - protected def checkTypesInternal(): TypeCheckResult - - override def dataType: DataType = thenList.head.dataType + override def dataType: DataType = branches.head._2.dataType override def nullable: Boolean = { - // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + // Result is nullable if any of the branch is nullable, or if the else value is nullable + branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } -} -// scalastyle:off -/** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = branches - - override protected def checkTypesInternal(): TypeCheckResult = { - if (whenList.forall(_.dataType == BooleanType)) { - TypeCheckResult.TypeCheckSuccess + override def checkInputDataTypes(): TypeCheckResult = { + // Make sure all branch conditions are boolean types. + if (valueTypesEqual) { + if (branches.forall(_._1.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = branches.indexWhere(_._1.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${branches(index)._1}") + } } else { - val index = whenList.indexWhere(_.dataType != BooleanType) TypeCheckResult.TypeCheckFailure( - s"WHEN expressions in CaseWhen should all be boolean type, " + - s"but the ${index + 1}th when expression's type is ${whenList(index)}") + "THEN and ELSE expressions should all be same type or coercible to a common type") } } - /** Written in imperative fashion for performance considerations. */ override def eval(input: InternalRow): Any = { - val len = branchesArr.length var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. - while (i < len - 1) { - if (branchesArr(i).eval(input) == true) { - return branchesArr(i + 1).eval(input) + while (i < branches.size) { + if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) { + return branches(i)._2.eval(input) } - i += 2 + i += 1 } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) + if (elseValue.isDefined) { + return elseValue.get.eval(input) + } else { + return null } - return res } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val len = branchesArr.length - val got = ctx.freshName("got") + def shouldCodegen: Boolean = { + branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN + } - val cases = (0 until len/2).map { i => - val cond = branchesArr(i * 2).gen(ctx) - val res = branchesArr(i * 2 + 1).gen(ctx) + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + if (!shouldCodegen) { + // Fallback to interpreted mode if there are too many branches, as it may reach the + // 64K limit (limit on bytecode size for a single function). + return super[CodegenFallback].genCode(ctx, ev) + } + // Generate code that looks like: + // + // condA = ... + // if (condA) { + // valueA + // } else { + // condB = ... + // if (condB) { + // valueB + // } else { + // condC = ... + // if (condC) { + // valueC + // } else { + // elseValue + // } + // } + // } + val cases = branches.map { case (condExpr, valueExpr) => + val cond = condExpr.gen(ctx) + val res = valueExpr.gen(ctx) s""" - if (!$got) { - ${cond.code} - if (!${cond.isNull} && ${cond.value}) { - $got = true; - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - } + ${cond.code} + if (!${cond.isNull} && ${cond.value}) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.value} = ${res.value}; } """ - }.mkString("\n") + } - val other = if (len % 2 == 1) { - val res = branchesArr(len - 1).gen(ctx) - s""" - if (!$got) { + var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") + + elseValue.foreach { elseExpr => + val res = elseExpr.gen(ctx) + generatedCode += + s""" ${res.code} ${ev.isNull} = ${res.isNull}; ${ev.value} = ${res.value}; - } - """ - } else { - "" + """ } + generatedCode += "}\n" * cases.size + s""" - boolean $got = false; boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $cases - $other + $generatedCode """ } override def toString: String = { - "CASE" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString + val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString + val elseCase = elseValue.map(" ELSE " + _).getOrElse("") + "CASE" + cases + elseCase + " END" } -} - -// scalastyle:off -/** - * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray + override def sql: String = { + val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString + val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") + "CASE" + cases + elseCase + " END" + } +} - override def children: Seq[Expression] = key +: branches +/** Factory methods for CaseWhen. */ +object CaseWhen { - override protected def checkTypesInternal(): TypeCheckResult = { - if ((key +: whenList).map(_.dataType).distinct.size > 1) { - TypeCheckResult.TypeCheckFailure( - "key and WHEN expressions should all be same type or coercible to a common type") - } else { - TypeCheckResult.TypeCheckSuccess - } - } + // The maximum number of switches supported with codegen. + val MAX_NUM_CASES_FOR_CODEGEN = 20 - private def evalElse(input: InternalRow): Any = { - if (branchesArr.length % 2 == 0) { - null - } else { - branchesArr(branchesArr.length - 1).eval(input) - } + def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = { + CaseWhen(branches, Option(elseValue)) } - /** Written in imperative fashion for performance considerations. */ - override def eval(input: InternalRow): Any = { - val evaluatedKey = key.eval(input) - // If key is null, we can just return the else part or null if there is no else. - // If key is not null but doesn't match any when part, we need to return - // the else part or null if there is no else, according to Hive's semantics. - if (evaluatedKey != null) { - val len = branchesArr.length - var i = 0 - while (i < len - 1) { - if (evaluatedKey == branchesArr(i).eval(input)) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - } - evalElse(input) + /** + * A factory method to facilitate the creation of this expression when used in parsers. + * @param branches Expressions at even position are the branch conditions, and expressions at odd + * position are branch values. + */ + def createFromParser(branches: Seq[Expression]): CaseWhen = { + val cases = branches.grouped(2).flatMap { + case cond :: value :: Nil => Some((cond, value)) + case value :: Nil => None + }.toArray.toSeq // force materialization to make the seq serializable + val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + CaseWhen(cases, elseValue) } +} - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val keyEval = key.gen(ctx) - val len = branchesArr.length - val got = ctx.freshName("got") - - val cases = (0 until len/2).map { i => - val cond = branchesArr(i * 2).gen(ctx) - val res = branchesArr(i * 2 + 1).gen(ctx) - s""" - if (!$got) { - ${cond.code} - if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.value, cond.value)}) { - $got = true; - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - } - } - """ - }.mkString("\n") - - val other = if (len % 2 == 1) { - val res = branchesArr(len - 1).gen(ctx) - s""" - if (!$got) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - } - """ - } else { - "" - } - - s""" - boolean $got = false; - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${keyEval.code} - if (!${keyEval.isNull}) { - $cases - } - $other - """ - } - override def toString: String = { - s"CASE $key" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString +/** + * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". + * When a = b, returns c; when a = d, returns e; else returns f. + */ +object CaseKeyWhen { + def apply(key: Expression, branches: Seq[Expression]): CaseWhen = { + val cases = branches.grouped(2).flatMap { + case cond :: value :: Nil => Some((EqualTo(key, cond), value)) + case value :: Nil => None + }.toArray.toSeq // force materialization to make the seq serializable + val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + CaseWhen(cases, elseValue) } } @@ -314,6 +263,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW * A function that returns the least value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. */ +@ExpressionDescription( + usage = "_FUNC_(n1, ...) - Returns the least value of all parameters, skipping null values.") case class Least(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) @@ -346,21 +297,25 @@ case class Least(children: Seq[Expression]) extends Expression { }) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: ExprCode): String = { s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} < 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, ev.value, eval.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ + } s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } @@ -369,6 +324,8 @@ case class Least(children: Seq[Expression]) extends Expression { * A function that returns the greatest value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. */ +@ExpressionDescription( + usage = "_FUNC_(n1, ...) - Returns the greatest value of all parameters, skipping null values.") case class Greatest(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) @@ -401,21 +358,26 @@ case class Greatest(children: Seq[Expression]) extends Expression { }) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: ExprCode): String = { s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} > 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, eval.value, ev.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ + } s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 13cc6bb6f27b8..9135753041f92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -20,21 +20,23 @@ package org.apache.spark.sql.catalyst.expressions import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import scala.util.Try + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, + ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -import scala.util.Try - /** * Returns the current date at the start of query evaluation. * All calls of current_date within the same query return the same value. * * There is no code generation since this expression should get constant folded by the optimizer. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current date at the start of query evaluation.") case class CurrentDate() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -44,6 +46,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { DateTimeUtils.millisToDays(System.currentTimeMillis()) } + + override def prettyName: String = "current_date" } /** @@ -52,6 +56,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { * * There is no code generation since this expression should get constant folded by the optimizer. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current timestamp at the start of query evaluation.") case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -61,11 +67,16 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { System.currentTimeMillis() * 1000L } + + override def prettyName: String = "current_timestamp" } /** * Adds a number of days to startdate. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days after start_date.", + extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-31'") case class DateAdd(startDate: Expression, days: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -80,16 +91,21 @@ case class DateAdd(startDate: Expression, days: Expression) start.asInstanceOf[Int] + d.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd + $d;""" }) } + + override def prettyName: String = "date_add" } /** * Subtracts a number of days to startdate. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days before start_date.", + extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'") case class DateSub(startDate: Expression, days: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = startDate @@ -103,13 +119,18 @@ case class DateSub(startDate: Expression, days: Expression) start.asInstanceOf[Int] - d.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd - $d;""" }) } + + override def prettyName: String = "date_sub" } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the hour component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 12") case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -120,12 +141,15 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the minute component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 58") case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -136,12 +160,15 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the second component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 59") case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -152,12 +179,15 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the day of year of date/timestamp.", + extended = "> SELECT _FUNC_('2016-04-09');\n 100") case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -168,13 +198,15 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") } } - +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the year component of the date/timestamp/interval.", + extended = "> SELECT _FUNC_('2016-07-30');\n 2016") case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -185,12 +217,14 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getYear(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the quarter of the year for date, in the range 1 to 4.") case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -201,12 +235,15 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI DateTimeUtils.getQuarter(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the month component of the date/timestamp/interval", + extended = "> SELECT _FUNC_('2016-07-30');\n 7") case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -217,12 +254,15 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp DateTimeUtils.getMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the day of month of date/timestamp, or the day of interval.", + extended = "> SELECT _FUNC_('2009-07-30');\n 30") case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -233,12 +273,15 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the week of the year of the given date.", + extended = "> SELECT _FUNC_('2008-02-20');\n 8") case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -257,7 +300,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa c.get(Calendar.WEEK_OF_YEAR) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") @@ -275,6 +318,11 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date/timestamp/string, fmt) - Converts a date/timestamp/string to a value of string in the format specified by the date format fmt.", + extended = "> SELECT _FUNC_('2016-04-08', 'y')\n '2016'") +// scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -287,7 +335,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val sdf = classOf[SimpleDateFormat].getName defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString((new $sdf($format.toString())) @@ -299,7 +347,24 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx } /** - * Converts time string with given pattern + * Converts time string with given pattern. + * Deterministic version of [[UnixTimestamp]], must have at least one parameter. + */ +@ExpressionDescription( + usage = "_FUNC_(date[, pattern]) - Returns the UNIX timestamp of the give time.") +case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { + override def left: Expression = timeExp + override def right: Expression = format + + def this(time: Expression) = { + this(time, Literal("yyyy-MM-dd HH:mm:ss")) + } + + override def prettyName: String = "to_unix_timestamp" +} + +/** + * Converts time string with given pattern. * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) * to Unix time stamp (in seconds), returns null if fail. * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. @@ -308,9 +373,9 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx * If the first parameter is a Date or Timestamp instead of String, we will ignore the * second parameter. */ -case class UnixTimestamp(timeExp: Expression, format: Expression) - extends BinaryExpression with ExpectsInputTypes { - +@ExpressionDescription( + usage = "_FUNC_([date[, pattern]]) - Returns the UNIX timestamp of current or specified time.") +case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -322,10 +387,16 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) this(CurrentTimestamp()) } + override def prettyName: String = "unix_timestamp" +} + +abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, DateType, TimestampType), StringType) override def dataType: DataType = LongType + override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] @@ -347,7 +418,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) null } case StringType => - val f = format.eval(input) + val f = right.eval(input) if (f == null) { null } else { @@ -359,7 +430,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { left.dataType match { case StringType if right.foldable => val sdf = classOf[SimpleDateFormat].getName @@ -422,6 +493,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) """ } } + + override def prettyName: String = "unix_time" } /** @@ -430,17 +503,23 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) * format. If the format is missing, using format like "1970-01-01 00:00:00". * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. */ +@ExpressionDescription( + usage = "_FUNC_(unix_time, format) - Returns unix_time in the specified format", + extended = "> SELECT _FUNC_(0, 'yyyy-MM-dd HH:mm:ss');\n '1970-01-01 00:00:00'") case class FromUnixTime(sec: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = sec override def right: Expression = format + override def prettyName: String = "from_unixtime" + def this(unix: Expression) = { this(unix, Literal("yyyy-MM-dd HH:mm:ss")) } override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) @@ -471,7 +550,7 @@ case class FromUnixTime(sec: Expression, format: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val sdf = classOf[SimpleDateFormat].getName if (format.foldable) { if (constFormat == null) { @@ -512,6 +591,9 @@ case class FromUnixTime(sec: Expression, format: Expression) /** * Returns the last day of the month which the date belongs to. */ +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the last day of the month which the date belongs to.", + extended = "> SELECT _FUNC_('2009-01-12');\n '2009-01-31'") case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def child: Expression = startDate @@ -523,7 +605,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") } @@ -538,6 +620,11 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC * * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(start_date, day_of_week) - Returns the first date which is later than start_date and named as indicated.", + extended = "> SELECT _FUNC_('2015-01-14', 'TU');\n '2015-01-20'") +// scalastyle:on line.size.limit case class NextDay(startDate: Expression, dayOfWeek: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -547,6 +634,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def nullSafeEval(start: Any, dayOfW: Any): Any = { val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) @@ -558,7 +646,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, dowS) => { val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") val dayOfWeekTerm = ctx.freshName("dayOfWeek") @@ -610,7 +698,7 @@ case class TimeAdd(start: Expression, interval: Expression) start.asInstanceOf[Long], itvl.months, itvl.microseconds) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" @@ -621,6 +709,10 @@ case class TimeAdd(start: Expression, interval: Expression) /** * Assumes given timestamp is UTC and converts to given timezone. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is UTC and converts to given timezone.") +// scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -633,7 +725,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() @@ -685,7 +777,7 @@ case class TimeSub(start: Expression, interval: Expression) start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" @@ -696,6 +788,9 @@ case class TimeSub(start: Expression, interval: Expression) /** * Returns the date that is num_months after start_date. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_months) - Returns the date that is num_months after start_date.", + extended = "> SELECT _FUNC_('2016-08-31', 1);\n '2016-09-30'") case class AddMonths(startDate: Expression, numMonths: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -710,17 +805,22 @@ case class AddMonths(startDate: Expression, numMonths: Expression) DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, m) => { s"""$dtu.dateAddMonths($sd, $m)""" }) } + + override def prettyName: String = "add_months" } /** * Returns number of months between dates date1 and date2. */ +@ExpressionDescription( + usage = "_FUNC_(date1, date2) - returns number of months between dates date1 and date2.", + extended = "> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');\n 3.94959677") case class MonthsBetween(date1: Expression, date2: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -735,17 +835,23 @@ case class MonthsBetween(date1: Expression, date2: Expression) DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (l, r) => { s"""$dtu.monthsBetween($l, $r)""" }) } + + override def prettyName: String = "months_between" } /** * Assumes given timestamp is in given timezone and converts to UTC. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is in given timezone and converts to UTC.") +// scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -758,7 +864,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() @@ -793,6 +899,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) /** * Returns the date part of a timestamp or string. */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Extracts the date part of the date or datetime expression expr.", + extended = "> SELECT _FUNC_('2009-07-30 04:17:52');\n '2009-07-30'") case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // Implicit casting of spark will accept string in both date and timestamp format, as @@ -803,14 +912,21 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def eval(input: InternalRow): Any = child.eval(input) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, d => d) } + + override def prettyName: String = "to_date" } /** * Returns date truncated to the unit specified by the format. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date, fmt) - Returns returns date with the time portion of the day truncated to the unit specified by the format model fmt.", + extended = "> SELECT _FUNC_('2009-02-12', 'MM')\n '2009-02-01'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01'") +// scalastyle:on line.size.limit case class TruncDate(date: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = date @@ -818,6 +934,7 @@ case class TruncDate(date: Expression, format: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def prettyName: String = "trunc" private lazy val truncLevel: Int = @@ -842,7 +959,7 @@ case class TruncDate(date: Expression, format: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { @@ -881,6 +998,9 @@ case class TruncDate(date: Expression, format: Expression) /** * Returns the number of days from startDate to endDate. */ +@ExpressionDescription( + usage = "_FUNC_(date1, date2) - Returns the number of days between date1 and date2.", + extended = "> SELECT _FUNC_('2009-07-30', '2009-07-31');\n 1") case class DateDiff(endDate: Expression, startDate: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -893,7 +1013,7 @@ case class DateDiff(endDate: Expression, startDate: Expression) end.asInstanceOf[Int] - start.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (end, start) => s"$end - $start") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 78f6631e46474..74e86f40c0364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ /** @@ -34,7 +34,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[Decimal].toUnscaledLong - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } @@ -47,12 +47,13 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) + override def nullable: Boolean = true override def toString: String = s"MakeDecimal($child,$precision,$scale)" protected override def nullSafeEval(input: Any): Any = Decimal(input.asInstanceOf[Long], precision, scale) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); @@ -69,9 +70,10 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def gen(ctx: CodegenContext): ExprCode = child.gen(ctx) + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = "" override def prettyName: String = "promote_precision" + override def sql: String = child.sql } /** @@ -91,7 +93,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { val tmp = ctx.freshName("tmp") s""" @@ -106,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } override def toString: String = s"CheckOverflow($child, $dataType)" + + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 894a0730d1c2a..65d7a1d5a0904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -99,6 +99,10 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.") +// scalastyle:on line.size.limit case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala new file mode 100644 index 0000000000000..3be761c8676c9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ + +/** + * A placeholder expression for cube/rollup, which will be replaced by analyzer + */ +trait GroupingSet extends Expression with CodegenFallback { + + def groupByExprs: Seq[Expression] + override def children: Seq[Expression] = groupByExprs + + // this should be replaced first + override lazy val resolved: Boolean = false + + override def dataType: DataType = throw new UnsupportedOperationException + override def foldable: Boolean = false + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException +} + +case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {} + +case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {} + +/** + * Indicates whether a specified column expression in a GROUP BY list is aggregated or not. + * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set. + */ +case class Grouping(child: Expression) extends Expression with Unevaluable { + override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) + override def children: Seq[Expression] = child :: Nil + override def dataType: DataType = ByteType + override def nullable: Boolean = false +} + +/** + * GroupingID is a function that computes the level of grouping. + * + * If groupByExprs is empty, it means all grouping expressions in GroupingSets. + */ +case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable { + override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) + override def children: Seq[Expression] = groupByExprs + override def dataType: DataType = IntegerType + override def nullable: Boolean = false + override def prettyName: String = "grouping_id" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8c9853e628d2c..ecd09b7083f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{StringWriter, ByteArrayOutputStream} +import java.io.{ByteArrayOutputStream, StringWriter} + +import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType} +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import scala.util.parsing.combinator.RegexParsers - private[this] sealed trait PathInstruction private[this] object PathInstruction { private[expressions] case object Subscript extends PathInstruction @@ -105,18 +106,22 @@ private[this] object SharedFactory { * Extracts json object from a json string based on json path specified, and returns json string * of the extracted json object. It will return null if the input json string is invalid. */ +@ExpressionDescription( + usage = "_FUNC_(json_txt, path) - Extract a json object from path") case class GetJsonObject(json: Expression, path: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { - import SharedFactory._ + import com.fasterxml.jackson.core.JsonToken._ + import PathInstruction._ + import SharedFactory._ import WriteStyle._ - import com.fasterxml.jackson.core.JsonToken._ override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def prettyName: String = "get_json_object" @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String]) @@ -298,8 +303,11 @@ case class GetJsonObject(json: Expression, path: Expression) case (FIELD_NAME, Named(name) :: xs) if p.getCurrentName == name => // exact field match - p.nextToken() - evaluatePath(p, g, style, xs) + if (p.nextToken() != JsonToken.VALUE_NULL) { + evaluatePath(p, g, style, xs) + } else { + false + } case (FIELD_NAME, Wildcard :: xs) => // wildcard field match @@ -313,8 +321,12 @@ case class GetJsonObject(json: Expression, path: Expression) } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - like get_json_object, but it takes multiple names and return a tuple. All the input parameters and output column types are string.") +// scalastyle:on line.size.limit case class JsonTuple(children: Seq[Expression]) - extends Expression with CodegenFallback { + extends Generator with CodegenFallback { import SharedFactory._ @@ -324,8 +336,8 @@ case class JsonTuple(children: Seq[Expression]) } // if processing fails this shared value will be returned - @transient private lazy val nullRow: InternalRow = - new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) + @transient private lazy val nullRow: Seq[InternalRow] = + new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) :: Nil // the json body is the first child @transient private lazy val jsonExpr: Expression = children.head @@ -344,15 +356,8 @@ case class JsonTuple(children: Seq[Expression]) // and count the number of foldable fields, we'll use this later to optimize evaluation @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null) - override lazy val dataType: StructType = { - val fields = fieldExpressions.zipWithIndex.map { - case (_, idx) => StructField( - name = s"c$idx", // mirroring GenericUDTFJSONTuple.initialize - dataType = StringType, - nullable = true) - } - - StructType(fields) + override def elementTypes: Seq[(DataType, Boolean, String)] = fieldExpressions.zipWithIndex.map { + case (_, idx) => (StringType, true, s"c$idx") } override def prettyName: String = "json_tuple" @@ -367,7 +372,7 @@ case class JsonTuple(children: Seq[Expression]) } } - override def eval(input: InternalRow): InternalRow = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val json = jsonExpr.eval(input).asInstanceOf[UTF8String] if (json == null) { return nullRow @@ -383,7 +388,7 @@ case class JsonTuple(children: Seq[Expression]) } } - private def parseRow(parser: JsonParser, input: InternalRow): InternalRow = { + private def parseRow(parser: JsonParser, input: InternalRow): Seq[InternalRow] = { // only objects are supported if (parser.nextToken() != JsonToken.START_OBJECT) { return nullRow @@ -433,7 +438,7 @@ case class JsonTuple(children: Seq[Expression]) parser.skipChildren() } - new GenericInternalRow(row) + new GenericInternalRow(row) :: Nil } private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 455fa2427c26d..7fd4bc3066cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import org.json4s.JsonAST._ + import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -26,6 +29,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ object Literal { + val TrueLiteral: Literal = Literal(true, BooleanType) + + val FalseLiteral: Literal = Literal(false, BooleanType) + def apply(v: Any): Literal = v match { case i: Int => Literal(i, IntegerType) case l: Long => Literal(l, LongType) @@ -44,10 +51,46 @@ object Literal { case a: Array[Byte] => Literal(a, BinaryType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) + case v: Literal => v case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object + * into code generation. + */ + def fromObject(obj: Any, objType: DataType): Literal = new Literal(obj, objType) + def fromObject(obj: Any): Literal = new Literal(obj, ObjectType(obj.getClass)) + + def fromJSON(json: JValue): Literal = { + val dataType = DataType.parseDataType(json \ "dataType") + json \ "value" match { + case JNull => Literal.create(null, dataType) + case JString(str) => + val value = dataType match { + case BooleanType => str.toBoolean + case ByteType => str.toByte + case ShortType => str.toShort + case IntegerType => str.toInt + case LongType => str.toLong + case FloatType => str.toFloat + case DoubleType => str.toDouble + case StringType => UTF8String.fromString(str) + case DateType => java.sql.Date.valueOf(str) + case TimestampType => java.sql.Timestamp.valueOf(str) + case CalendarIntervalType => CalendarInterval.fromString(str) + case t: DecimalType => + val d = Decimal(str) + assert(d.changePrecision(t.precision, t.scale)) + d + case _ => null + } + Literal.create(value, dataType) + case other => sys.error(s"$other is not a valid Literal json value") + } + } + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } @@ -68,7 +111,7 @@ object Literal { case DateType => create(0, DateType) case TimestampType => create(0L, TimestampType) case StringType => Literal("") - case BinaryType => Literal("".getBytes) + case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8)) case CalendarIntervalType => Literal(new CalendarInterval(0, 0)) case arr: ArrayType => create(Array(), arr) case map: MapType => create(Map(), map) @@ -98,6 +141,24 @@ object IntegerLiteral { } } +/** + * Extractor for and other utility methods for decimal literals. + */ +object DecimalLiteral { + def apply(v: Long): Literal = Literal(Decimal(v)) + + def apply(v: Double): Literal = Literal(Decimal(v)) + + def unapply(e: Expression): Option[Decimal] = e match { + case Literal(v, _: DecimalType) => Some(v.asInstanceOf[Decimal]) + case _ => None + } + + def largerThanLargestLong(v: Decimal): Boolean = v > Decimal(Long.MaxValue) + + def smallerThanSmallestLong(v: Decimal): Boolean = v < Decimal(Long.MinValue) +} + /** * In order to do type checking, use Literal.create() instead of constructor */ @@ -116,9 +177,21 @@ case class Literal protected (value: Any, dataType: DataType) case _ => false } + override protected def jsonFields: List[JField] = { + // Turns all kinds of literal values to string in json field, as the type info is hard to + // retain in json format, e.g. {"a": 123} can be a int, or double, or decimal, etc. + val jsonValue = (value, dataType) match { + case (null, _) => JNull + case (i: Int, DateType) => JString(DateTimeUtils.toJavaDate(i).toString) + case (l: Long, TimestampType) => JString(DateTimeUtils.toJavaTimestamp(l).toString) + case (other, _) => JString(other.toString) + } + ("value" -> jsonValue) :: ("dataType" -> dataType.jsonValue) :: Nil + } + override def eval(input: InternalRow): Any = value - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" @@ -132,7 +205,7 @@ case class Literal protected (value: Any, dataType: DataType) case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { - super.genCode(ctx, ev) + super[CodegenFallback].genCode(ctx, ev) } else { ev.isNull = "false" ev.value = s"${value}f" @@ -141,7 +214,7 @@ case class Literal protected (value: Any, dataType: DataType) case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { - super.genCode(ctx, ev) + super[CodegenFallback].genCode(ctx, ev) } else { ev.isNull = "false" ev.value = s"${value}D" @@ -161,19 +234,26 @@ case class Literal protected (value: Any, dataType: DataType) "" // eval() version may be faster for non-primitive types case other => - super.genCode(ctx, ev) + super[CodegenFallback].genCode(ctx, ev) } } } -} - -// TODO: Specialize -case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) - extends LeafExpression with CodegenFallback { - def update(expression: Expression, input: InternalRow): Unit = { - value = expression.eval(input) + override def sql: String = (value, dataType) match { + case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" + case _ if value == null => s"CAST(NULL AS ${dataType.sql})" + case (v: UTF8String, StringType) => + // Escapes all backslashes and double quotes. + "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" + case (v: Byte, ByteType) => v + "Y" + case (v: Short, ShortType) => v + "S" + case (v: Long, LongType) => v + "L" + // Float type doesn't have a suffix + case (v: Float, FloatType) => s"CAST($v AS ${FloatType.sql})" + case (v: Double, DoubleType) => v + "D" + case (v: Decimal, t: DecimalType) => s"CAST($v AS ${t.sql})" + case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'" + case (v: Long, TimestampType) => s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" + case _ => value.toString } - - override def eval(input: InternalRow): Any = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 28f616fbb9ca5..c8a28e847745c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.NumberConverter @@ -42,6 +42,7 @@ abstract class LeafMathExpression(c: Double, name: String) override def foldable: Boolean = true override def nullable: Boolean = false override def toString: String = s"$name()" + override def prettyName: String = name override def eval(input: InternalRow): Any = c } @@ -49,6 +50,7 @@ abstract class LeafMathExpression(c: Double, name: String) /** * A unary expression specifically for math functions. Math Functions expect a specific type of * input format, therefore these functions extend `ExpectsInputTypes`. + * * @param f The math function. * @param name The short name of the function */ @@ -59,6 +61,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" + override def prettyName: String = name protected override def nullSafeEval(input: Any): Any = { f(input.asInstanceOf[Double]) @@ -67,7 +70,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) // name of function in java.lang.Math def funcName: String = name.toLowerCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } } @@ -75,6 +78,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) abstract class UnaryLogExpression(f: Double => Double, name: String) extends UnaryMathExpression(f, name) { + override def nullable: Boolean = true + // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity protected val yAsymptote: Double = 0.0 @@ -83,7 +88,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) if (d <= yAsymptote) null else f(d) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -99,6 +104,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) /** * A binary expression specifically for math functions that take two `Double`s as input and returns * a `Double`. + * * @param f The math function. * @param name The short name of the function */ @@ -109,13 +115,15 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def toString: String = s"$name($left, $right)" + override def prettyName: String = name + override def dataType: DataType = DoubleType protected override def nullSafeEval(input1: Any, input2: Any): Any = { f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") } } @@ -130,12 +138,18 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) * Euler's number. Note that there is no code generation because this is only * evaluated by the optimizer during constant folding. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns Euler's number, E.", + extended = "> SELECT _FUNC_();\n 2.718281828459045") case class EulerNumber() extends LeafMathExpression(math.E, "E") /** * Pi. Note that there is no code generation because this is only * evaluated by the optimizer during constant folding. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns PI.", + extended = "> SELECT _FUNC_();\n 3.141592653589793") case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -144,14 +158,29 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc cosine of x if -1<=x<=1 or NaN otherwise.", + extended = "> SELECT _FUNC_(1);\n 0.0\n> SELECT _FUNC_(2);\n NaN") case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc sin of x if -1<=x<=1 or NaN otherwise.", + extended = "> SELECT _FUNC_(0);\n 0.0\n> SELECT _FUNC_(2);\n NaN") case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc tangent.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the cube root of a double value.", + extended = "> SELECT _FUNC_(27.0);\n 3.0") case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the smallest integer not smaller than x.", + extended = "> SELECT _FUNC_(-0.1);\n 0\n> SELECT _FUNC_(5);\n 5") case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -168,7 +197,7 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -178,22 +207,33 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the cosine of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic cosine of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") /** * Convert a num from one base to another + * * @param numExpr the number to be converted * @param fromBaseExpr from which base * @param toBaseExpr to which base */ +@ExpressionDescription( + usage = "_FUNC_(num, from_base, to_base) - Convert num from from_base to to_base.", + extended = "> SELECT _FUNC_('100', 2, 10);\n '4'\n> SELECT _FUNC_(-10, 16, -10);\n '16'") case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { NumberConverter.convert( @@ -202,7 +242,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre toBase.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val numconv = NumberConverter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (num, from, to) => s""" @@ -215,10 +255,19 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns e to the power of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns exp(x) - 1.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the largest integer not greater than x.", + extended = "> SELECT _FUNC_(-0.1);\n -1\n> SELECT _FUNC_(5);\n 5") case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -235,7 +284,7 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -276,6 +325,9 @@ object Factorial { ) } +@ExpressionDescription( + usage = "_FUNC_(n) - Returns n factorial for n is [0..20]. Otherwise, NULL.", + extended = "> SELECT _FUNC_(5);\n 120") case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -294,7 +346,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" if ($eval > 20 || $eval < 0) { @@ -308,11 +360,17 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the natural logarithm of x with base e.", + extended = "> SELECT _FUNC_(1);\n 0.0") case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the logarithm of x with base 2.", + extended = "> SELECT _FUNC_(2);\n 1.0") case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -325,36 +383,72 @@ case class Log2(child: Expression) } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the logarithm of x with base 10.", + extended = "> SELECT _FUNC_(10);\n 1.0") case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns log(1 + x).", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") { protected override val yAsymptote: Double = -1.0 } +@ExpressionDescription( + usage = "_FUNC_(x, d) - Return the rounded x at d decimal places.", + extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sign of x.", + extended = "> SELECT _FUNC_(40);\n 1.0") case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sine of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic sine of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the square root of x.", + extended = "> SELECT _FUNC_(4);\n 2.0") case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the tangent of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic tangent of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") +@ExpressionDescription( + usage = "_FUNC_(x) - Converts radians to degrees.", + extended = "> SELECT _FUNC_(3.141592653589793);\n 180.0") case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { override def funcName: String = "toDegrees" } +@ExpressionDescription( + usage = "_FUNC_(x) - Converts degrees to radians.", + extended = "> SELECT _FUNC_(180);\n 3.141592653589793") case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { override def funcName: String = "toRadians" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns x in binary.", + extended = "> SELECT _FUNC_(13);\n '1101'") case class Bin(child: Expression) extends UnaryExpression with Serializable with ImplicitCastInputTypes { @@ -364,7 +458,7 @@ case class Bin(child: Expression) protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c) => s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } @@ -446,6 +540,9 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ +@ExpressionDescription( + usage = "_FUNC_(x) - Convert the argument to hexadecimal.", + extended = "> SELECT _FUNC_(17);\n '11'\n> SELECT _FUNC_('Spark SQL');\n '537061726B2053514C'") case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = @@ -459,7 +556,7 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s"${ev.value} = " + (child.dataType match { @@ -474,6 +571,9 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ +@ExpressionDescription( + usage = "_FUNC_(x) - Converts hexadecimal argument to binary.", + extended = "> SELECT decode(_FUNC_('537061726B2053514C'),'UTF-8');\n 'Spark SQL'") case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -484,7 +584,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp protected override def nullSafeEval(num: Any): Any = Hex.unhex(num.asInstanceOf[UTF8String].getBytes) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s""" @@ -502,7 +602,9 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// - +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the arc tangent2.", + extended = "> SELECT _FUNC_(0, 0);\n 0.0") case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { @@ -511,24 +613,31 @@ case class Atan2(left: Expression, right: Expression) math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } +@ExpressionDescription( + usage = "_FUNC_(x1, x2) - Raise x1 to the power of x2.", + extended = "> SELECT _FUNC_(2, 3);\n 8.0") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } /** - * Bitwise unsigned left shift. + * Bitwise left shift. + * * @param left the base number to shift. * @param right number of bits to left shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise left shift.", + extended = "> SELECT _FUNC_(2, 1);\n 4") case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -544,17 +653,21 @@ case class ShiftLeft(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left << $right") } } /** - * Bitwise unsigned left shift. + * Bitwise right shift. + * * @param left the base number to shift. - * @param right number of bits to left shift. + * @param right number of bits to right shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise right shift.", + extended = "> SELECT _FUNC_(4, 1);\n 2") case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -570,7 +683,7 @@ case class ShiftRight(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") } } @@ -578,9 +691,13 @@ case class ShiftRight(left: Expression, right: Expression) /** * Bitwise unsigned right shift, for integer and long data type. + * * @param left the base number. * @param right the number of bits to right shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise unsigned right shift.", + extended = "> SELECT _FUNC_(4, 1);\n 2") case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -596,21 +713,27 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") } } - +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns sqrt(a**2 + b**2).", + extended = "> SELECT _FUNC_(3, 4);\n 5.0") case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") /** * Computes the logarithm of a number. + * * @param left the logarithm base, default to e. * @param right the number to compute the logarithm of. */ +@ExpressionDescription( + usage = "_FUNC_(b, x) - Returns the logarithm of x with base b.", + extended = "> SELECT _FUNC_(10, 100);\n 2.0") case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { @@ -621,6 +744,8 @@ case class Logarithm(left: Expression, right: Expression) this(EulerNumber(), child) } + override def nullable: Boolean = true + protected override def nullSafeEval(input1: Any, input2: Any): Any = { val dLeft = input1.asInstanceOf[Double] val dRight = input2.asInstanceOf[Double] @@ -628,7 +753,7 @@ case class Logarithm(left: Expression, right: Expression) if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (left.isInstanceOf[EulerNumber]) { nullSafeCodeGen(ctx, ev, (c1, c2) => s""" @@ -665,6 +790,9 @@ case class Logarithm(left: Expression, right: Expression) * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime */ +@ExpressionDescription( + usage = "_FUNC_(x, d) - Round x to d decimal places.", + extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") case class Round(child: Expression, scale: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -739,7 +867,7 @@ case class Round(child: Expression, scale: Expression) if (f.isNaN || f.isInfinite) { f } else { - BigDecimal(f).setScale(_scale, HALF_UP).toFloat + BigDecimal(f.toDouble).setScale(_scale, HALF_UP).toFloat } case DoubleType => val d = input1.asInstanceOf[Double] @@ -751,7 +879,7 @@ case class Round(child: Expression, scale: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val ce = child.gen(ctx) val evaluationCode = child.dataType match { @@ -795,39 +923,21 @@ case class Round(child: Expression, scale: Expression) s"${ev.value} = ${ce.value};" } case FloatType => // if child eval to NaN or Infinity, just return it. - if (_scale == 0) { - s""" - if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})){ - ${ev.value} = ${ce.value}; - } else { - ${ev.value} = Math.round(${ce.value}); - }""" - } else { - s""" - if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})){ - ${ev.value} = ${ce.value}; - } else { - ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); - }""" - } + s""" + if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) { + ${ev.value} = ${ce.value}; + } else { + ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" case DoubleType => // if child eval to NaN or Infinity, just return it. - if (_scale == 0) { - s""" - if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})){ - ${ev.value} = ${ce.value}; - } else { - ${ev.value} = Math.round(${ce.value}); - }""" - } else { - s""" - if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})){ - ${ev.value} = ${ce.value}; - } else { - ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); - }""" - } + s""" + if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) { + ${ev.value} = ${ce.value}; + } else { + ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" } if (scaleV == null) { // if scale is null, no need to eval its child at all diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0f6d02f2e00c2..4bd918ed01ae2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -20,16 +20,26 @@ package org.apache.spark.sql.catalyst.expressions import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 +import scala.annotation.tailrec + import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.Platform /** * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns an MD5 128-bit checksum as a hex string of the input", + extended = "> SELECT _FUNC_('Spark');\n '8cde774d6f7333752ed72cacddb05126'") case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -39,7 +49,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } @@ -53,10 +63,18 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or * the hash length is not one of the permitted values, the return value is NULL. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(input, bitLength) - Returns a checksum of SHA-2 family as a hex string of the input. + SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.""", + extended = """> SELECT _FUNC_('Spark', 0); + '529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b'""") +// scalastyle:on line.size.limit case class Sha2(left: Expression, right: Expression) extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) @@ -84,7 +102,7 @@ case class Sha2(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val digestUtils = "org.apache.commons.codec.digest.DigestUtils" nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" @@ -117,6 +135,9 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns a sha1 hash value as a hex string of the input", + extended = "> SELECT _FUNC_('Spark');\n '85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c'") case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -124,11 +145,11 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu override def inputTypes: Seq[DataType] = Seq(BinaryType) protected override def nullSafeEval(input: Any): Any = - UTF8String.fromString(DigestUtils.shaHex(input.asInstanceOf[Array[Byte]])) + UTF8String.fromString(DigestUtils.sha1Hex(input.asInstanceOf[Array[Byte]])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => - s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))" ) } } @@ -137,6 +158,9 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns a cyclic redundancy check value as a bigint of the input", + extended = "> SELECT _FUNC_('Spark');\n '1557323817'") case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType @@ -149,7 +173,7 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp checksum.getValue } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val CRC32 = "java.util.zip.CRC32" nullSafeCodeGen(ctx, ev, value => { s""" @@ -160,3 +184,331 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } + + +/** + * A function that calculates hash value for a group of expressions. Note that the `seed` argument + * is not exposed to users and should only be set inside spark SQL. + * + * The hash value for an expression depends on its type and seed: + * - null: seed + * - boolean: turn boolean into int, 1 for true, 0 for false, and then use murmur3 to + * hash this int with seed. + * - byte, short, int: use murmur3 to hash the input as int with seed. + * - long: use murmur3 to hash the long input with seed. + * - float: turn it into int: java.lang.Float.floatToIntBits(input), and hash it. + * - double: turn it into long: java.lang.Double.doubleToLongBits(input), and hash it. + * - decimal: if it's a small decimal, i.e. precision <= 18, turn it into long and hash + * it. Else, turn it into bytes and hash it. + * - calendar interval: hash `microseconds` first, and use the result as seed to hash `months`. + * - binary: use murmur3 to hash the bytes with seed. + * - string: get the bytes of string and hash it. + * - array: The `result` starts with seed, then use `result` as seed, recursively + * calculate hash value for each element, and assign the element hash value + * to `result`. + * - map: The `result` starts with seed, then use `result` as seed, recursively + * calculate hash value for each key-value, and assign the key-value hash + * value to `result`. + * - struct: The `result` starts with seed, then use `result` as seed, recursively + * calculate hash value for each field, and assign the field hash value to + * `result`. + * + * Finally we aggregate the hash values for each expression by the same way of struct. + */ +abstract class HashExpression[E] extends Expression { + /** Seed of the HashExpression. */ + val seed: E + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckFailure("function hash requires at least one argument") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def eval(input: InternalRow): Any = { + var hash = seed + var i = 0 + val len = children.length + while (i < len) { + hash = computeHash(children(i).eval(input), children(i).dataType, hash) + i += 1 + } + hash + } + + protected def computeHash(value: Any, dataType: DataType, seed: E): E + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + ev.isNull = "false" + val childrenHash = children.map { child => + val childGen = child.gen(ctx) + childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { + computeHash(childGen.value, child.dataType, ev.value, ctx) + } + }.mkString("\n") + + s""" + ${ctx.javaType(dataType)} ${ev.value} = $seed; + $childrenHash + """ + } + + private def nullSafeElementHash( + input: String, + index: String, + nullable: Boolean, + elementType: DataType, + result: String, + ctx: CodegenContext): String = { + val element = ctx.freshName("element") + + ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { + s""" + final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + ${computeHash(element, elementType, result, ctx)} + """ + } + } + + @tailrec + private def computeHash( + input: String, + dataType: DataType, + result: String, + ctx: CodegenContext): String = { + val hasher = hasherClassName + + def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" + def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" + def hashBytes(b: String): String = + s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" + + dataType match { + case NullType => "" + case BooleanType => hashInt(s"$input ? 1 : 0") + case ByteType | ShortType | IntegerType | DateType => hashInt(input) + case LongType | TimestampType => hashLong(input) + case FloatType => hashInt(s"Float.floatToIntBits($input)") + case DoubleType => hashLong(s"Double.doubleToLongBits($input)") + case d: DecimalType => + if (d.precision <= Decimal.MAX_LONG_DIGITS) { + hashLong(s"$input.toUnscaledLong()") + } else { + val bytes = ctx.freshName("bytes") + s""" + final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + ${hashBytes(bytes)} + """ + } + case CalendarIntervalType => + val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" + s"$result = $hasher.hashInt($input.months, $microsecondsHash);" + case BinaryType => hashBytes(input) + case StringType => + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + + case ArrayType(et, containsNull) => + val index = ctx.freshName("index") + s""" + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} + } + """ + + case MapType(kt, vt, valueContainsNull) => + val index = ctx.freshName("index") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(keys, index, false, kt, result, ctx)} + ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} + } + """ + + case StructType(fields) => + fields.zipWithIndex.map { case (field, index) => + nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + }.mkString("\n") + + case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) + } + } + + protected def hasherClassName: String +} + +/** + * Base class for interpreted hash functions. + */ +abstract class InterpretedHashFunction { + protected def hashInt(i: Int, seed: Long): Long + + protected def hashLong(l: Long, seed: Long): Long + + protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + + def hash(value: Any, dataType: DataType, seed: Long): Long = { + value match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0, seed) + case b: Byte => hashInt(b, seed) + case s: Short => hashInt(s, seed) + case i: Int => hashInt(i, seed) + case l: Long => hashLong(l, seed) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) + case d: Decimal => + val precision = dataType.asInstanceOf[DecimalType].precision + if (precision <= Decimal.MAX_LONG_DIGITS) { + hashLong(d.toUnscaledLong, seed) + } else { + val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray + hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed) + } + case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) + case a: Array[Byte] => + hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) + case s: UTF8String => + hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + + case array: ArrayData => + val elementType = dataType match { + case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType + case ArrayType(et, _) => et + } + var result = seed + var i = 0 + while (i < array.numElements()) { + result = hash(array.get(i, elementType), elementType, result) + i += 1 + } + result + + case map: MapData => + val (kt, vt) = dataType match { + case udt: UserDefinedType[_] => + val mapType = udt.sqlType.asInstanceOf[MapType] + mapType.keyType -> mapType.valueType + case MapType(kt, vt, _) => kt -> vt + } + val keys = map.keyArray() + val values = map.valueArray() + var result = seed + var i = 0 + while (i < map.numElements()) { + result = hash(keys.get(i, kt), kt, result) + result = hash(values.get(i, vt), vt, result) + i += 1 + } + result + + case struct: InternalRow => + val types: Array[DataType] = dataType match { + case udt: UserDefinedType[_] => + udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray + case StructType(fields) => fields.map(_.dataType) + } + var result = seed + var i = 0 + val len = struct.numFields + while (i < len) { + result = hash(struct.get(i, types(i)), types(i), result) + i += 1 + } + result + } + } +} + +/** + * A MurMur3 Hash expression. + * + * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle + * and bucketing have same data distribution. + */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns a hash value of the arguments.") +case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] { + def this(arguments: Seq[Expression]) = this(arguments, 42) + + override def dataType: DataType = IntegerType + + override def prettyName: String = "hash" + + override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName + + override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = { + Murmur3HashFunction.hash(value, dataType, seed).toInt + } +} + +object Murmur3HashFunction extends InterpretedHashFunction { + override protected def hashInt(i: Int, seed: Long): Long = { + Murmur3_x86_32.hashInt(i, seed.toInt) + } + + override protected def hashLong(l: Long, seed: Long): Long = { + Murmur3_x86_32.hashLong(l, seed.toInt) + } + + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) + } +} + +/** + * Print the result of an expression to stderr (used for debugging codegen). + */ +case class PrintToStderr(child: Expression) extends UnaryExpression { + + override def dataType: DataType = child.dataType + + protected override def nullSafeEval(input: Any): Any = input + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + nullSafeCodeGen(ctx, ev, c => + s""" + | System.err.println("Result of ${child.simpleString} is " + $c); + | ${ev.value} = $c; + """.stripMargin) + } +} + +/** + * A xxHash64 64-bit hash expression. + */ +case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpression[Long] { + def this(arguments: Seq[Expression]) = this(arguments, 42L) + + override def dataType: DataType = LongType + + override def prettyName: String = "xxHash" + + override protected def hasherClassName: String = classOf[XXH64].getName + + override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = { + XxHash64Function.hash(value, dataType, seed) + } +} + +object XxHash64Function extends InterpretedHashFunction { + override protected def hashInt(i: Int, seed: Long): Long = XXH64.hashInt(i, seed) + + override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) + + override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + XXH64.hashUnsafeBytes(base, offset, len, seed) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8957df0be6814..78310fb2f1539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -22,6 +22,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types._ object NamedExpression { @@ -60,10 +61,10 @@ trait NamedExpression extends Expression { * multiple qualifiers, it is possible that there are other possible way to refer to this * attribute. */ - def qualifiedName: String = (qualifiers.headOption.toSeq :+ name).mkString(".") + def qualifiedName: String = (qualifier.toSeq :+ name).mkString(".") /** - * All possible qualifiers for the expression. + * Optional qualifier for the expression. * * For now, since we do not allow using original table name to qualify a column name once the * table is aliased, this can only be: @@ -72,13 +73,19 @@ trait NamedExpression extends Expression { * e.g. top level attributes aliased in the SELECT clause, or column from a LocalRelation. * 2. Single element: either the table name or the alias name of the table. */ - def qualifiers: Seq[String] + def qualifier: Option[String] def toAttribute: Attribute /** Returns the metadata when an expression is a reference to another expression with metadata. */ def metadata: Metadata = Metadata.empty + /** Returns true if the expression is generated by Catalyst */ + def isGenerated: java.lang.Boolean = false + + /** Returns a copy of this expression with a new `exprId`. */ + def newInstance(): NamedExpression + protected def typeSuffix = if (resolved) { dataType match { @@ -90,12 +97,12 @@ trait NamedExpression extends Expression { } } -abstract class Attribute extends LeafExpression with NamedExpression { +abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant { override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute - def withQualifiers(newQualifiers: Seq[String]): Attribute + def withQualifier(newQualifier: Option[String]): Attribute def withName(newName: String): Attribute override def toAttribute: Attribute = this @@ -111,16 +118,21 @@ abstract class Attribute extends LeafExpression with NamedExpression { * Note that exprId and qualifiers are in a separate parameter list because * we only pattern match on child and name. * - * @param child the computation being performed - * @param name the name to be associated with the result of computing [[child]]. + * @param child The computation being performed + * @param name The name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. + * @param qualifier An optional string that can be used to referred to this attribute in a fully + * qualified way. Consider the examples tableName.name, subQueryAlias.name. + * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. + * @param isGenerated A flag to indicate if this alias is generated by Catalyst */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, - val qualifiers: Seq[String] = Nil, - val explicitMetadata: Option[Metadata] = None) + val qualifier: Option[String] = None, + val explicitMetadata: Option[Metadata] = None, + override val isGenerated: java.lang.Boolean = false) extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) @@ -130,8 +142,8 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple passthrough for code generation. */ - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def gen(ctx: CodegenContext): ExprCode = child.gen(ctx) + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = "" override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable @@ -144,9 +156,14 @@ case class Alias(child: Expression, name: String)( } } + def newInstance(): NamedExpression = + Alias(child, name)( + qualifier = qualifier, explicitMetadata = explicitMetadata, isGenerated = isGenerated) + override def toAttribute: Attribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) + AttributeReference(name, child.dataType, child.nullable, metadata)( + exprId, qualifier, isGenerated) } else { UnresolvedAttribute(name) } @@ -155,15 +172,20 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifiers :: explicitMetadata :: Nil + exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil } override def equals(other: Any): Boolean = other match { case a: Alias => - name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers && + name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier && explicitMetadata == a.explicitMetadata case _ => false } + + override def sql: String = { + val qualifierPrefix = qualifier.map(_ + ".").getOrElse("") + s"${child.sql} AS $qualifierPrefix${quoteIdentifier(name)}" + } } /** @@ -175,9 +197,10 @@ case class Alias(child: Expression, name: String)( * @param metadata The metadata of this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. - * @param qualifiers a list of strings that can be used to referred to this attribute in a fully - * qualified way. Consider the examples tableName.name, subQueryAlias.name. - * tableName and subQueryAlias are possible qualifiers. + * @param qualifier An optional string that can be used to referred to this attribute in a fully + * qualified way. Consider the examples tableName.name, subQueryAlias.name. + * tableName and subQueryAlias are possible qualifiers. + * @param isGenerated A flag to indicate if this reference is generated by Catalyst */ case class AttributeReference( name: String, @@ -185,7 +208,8 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifiers: Seq[String] = Nil) + val qualifier: Option[String] = None, + override val isGenerated: java.lang.Boolean = false) extends Attribute with Unevaluable { /** @@ -194,7 +218,9 @@ case class AttributeReference( def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId override def equals(other: Any): Boolean = other match { - case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType + case ar: AttributeReference => + name == ar.name && dataType == ar.dataType && nullable == ar.nullable && + metadata == ar.metadata && exprId == ar.exprId && qualifier == ar.qualifier case _ => false } @@ -203,17 +229,25 @@ case class AttributeReference( case _ => false } + override def semanticHash(): Int = { + this.exprId.hashCode() + } + override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 - h = h * 37 + exprId.hashCode() + h = h * 37 + name.hashCode() h = h * 37 + dataType.hashCode() + h = h * 37 + nullable.hashCode() h = h * 37 + metadata.hashCode() + h = h * 37 + exprId.hashCode() + h = h * 37 + qualifier.hashCode() h } override def newInstance(): AttributeReference = - AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) + AttributeReference(name, dataType, nullable, metadata)( + qualifier = qualifier, isGenerated = isGenerated) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -222,7 +256,7 @@ case class AttributeReference( if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier, isGenerated) } } @@ -230,18 +264,18 @@ case class AttributeReference( if (name == newName) { this } else { - AttributeReference(newName, dataType, nullable)(exprId, qualifiers) + AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier, isGenerated) } } /** - * Returns a copy of this [[AttributeReference]] with new qualifiers. + * Returns a copy of this [[AttributeReference]] with new qualifier. */ - override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = { - if (newQualifiers.toSet == qualifiers.toSet) { + override def withQualifier(newQualifier: Option[String]): AttributeReference = { + if (newQualifier == qualifier) { this } else { - AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier, isGenerated) } } @@ -249,34 +283,58 @@ case class AttributeReference( if (exprId == newExprId) { this } else { - AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers) + AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier, isGenerated) } } + override protected final def otherCopyArgs: Seq[AnyRef] = { + exprId :: qualifier :: isGenerated :: Nil + } + override def toString: String = s"$name#${exprId.id}$typeSuffix" + + // Since the expression id is not in the first constructor it is missing from the default + // tree string. + override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" + + override def sql: String = { + val qualifierPrefix = qualifier.map(_ + ".").getOrElse("") + s"$qualifierPrefix${quoteIdentifier(name)}" + } } /** * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String) extends Attribute with Unevaluable { +case class PrettyAttribute( + name: String, + dataType: DataType = NullType) + extends Attribute with Unevaluable { + + def this(attribute: Attribute) = this(attribute.name, attribute match { + case a: AttributeReference => a.dataType + case a: PrettyAttribute => a.dataType + case _ => NullType + }) override def toString: String = name + override def sql: String = toString override def withNullability(newNullability: Boolean): Attribute = throw new UnsupportedOperationException override def newInstance(): Attribute = throw new UnsupportedOperationException - override def withQualifiers(newQualifiers: Seq[String]): Attribute = + override def withQualifier(newQualifier: Option[String]): Attribute = throw new UnsupportedOperationException override def withName(newName: String): Attribute = throw new UnsupportedOperationException - override def qualifiers: Seq[String] = throw new UnsupportedOperationException + override def qualifier: Option[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def nullable: Boolean = throw new UnsupportedOperationException - override def dataType: DataType = NullType + override def nullable: Boolean = true } object VirtualColumn { - val groupingIdName: String = "grouping__id" + // The attribute name used by Hive, which has different result than Spark, deprecated. + val hiveGroupingIdName: String = "grouping__id" + val groupingIdName: String = "spark_grouping_id" val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 94deafb75b69c..6a452499430c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -34,6 +34,9 @@ import org.apache.spark.sql.types._ * coalesce(null, null, null) => null * }}} */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns the first non-null argument if exists. Otherwise, NULL.", + extended = "> SELECT _FUNC_(NULL, 1, NULL);\n 1") case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ @@ -61,12 +64,16 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val first = children(0) + val rest = children.drop(1) + val firstEval = first.gen(ctx) s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${firstEval.code} + boolean ${ev.isNull} = ${firstEval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value}; """ + - children.map { e => + rest.map { e => val eval = e.gen(ctx) s""" if (${ev.isNull}) { @@ -85,6 +92,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { /** * Evaluates to `true` iff it's NaN. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is NaN and false otherwise.") case class IsNaN(child: Expression) extends UnaryExpression with Predicate with ImplicitCastInputTypes { @@ -104,7 +113,7 @@ case class IsNaN(child: Expression) extends UnaryExpression } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) child.dataType match { case DoubleType | FloatType => @@ -122,6 +131,8 @@ case class IsNaN(child: Expression) extends UnaryExpression * An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise. * This Expression is useful for mapping NaN values to null. */ +@ExpressionDescription( + usage = "_FUNC_(a,b) - Returns a iff it's not NaN, or b otherwise.") case class NaNvl(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -144,7 +155,7 @@ case class NaNvl(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val leftGen = left.gen(ctx) val rightGen = right.gen(ctx) left.dataType match { @@ -176,6 +187,8 @@ case class NaNvl(left: Expression, right: Expression) /** * An expression that is evaluated to true if the input is null. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is NULL and false otherwise.") case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false @@ -183,18 +196,22 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) == null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.value = eval.isNull eval.code } + + override def sql: String = s"(${child.sql} IS NULL)" } /** * An expression that is evaluated to true if the input is not null. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is not NULL and false otherwise.") case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false @@ -202,12 +219,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) != null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.value = s"(!(${eval.isNull}))" eval.code } + + override def sql: String = s"(${child.sql} IS NOT NULL)" } @@ -240,7 +259,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 81855289762c6..26b1ff39b3e9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -17,15 +17,18 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} -import org.apache.spark.sql.catalyst.util.GenericArrayData +import java.lang.reflect.Modifier +import scala.annotation.tailrec import scala.language.existentials +import scala.reflect.ClassTag +import org.apache.spark.SparkConf +import org.apache.spark.serializer._ +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -42,23 +45,21 @@ import org.apache.spark.sql.types._ * of calling the function. */ case class StaticInvoke( - staticObject: Any, + staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression { + propagateNull: Boolean = true) extends Expression with NonSQLExpression { + + val objectName = staticObject.getName.stripSuffix("$") - val objectName = staticObject match { - case c: Class[_] => c.getName - case other => other.getClass.getName.stripSuffix("$") - } override def nullable: Boolean = true override def children: Seq[Expression] = arguments override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") @@ -110,26 +111,26 @@ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, - arguments: Seq[Expression] = Nil) extends Expression { + arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { override def nullable: Boolean = true - override def children: Seq[Expression] = targetObject :: Nil + override def children: Seq[Expression] = targetObject +: arguments override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - lazy val method = targetObject.dataType match { + @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - cls - .getMethods - .find(_.getName == functionName) - .getOrElse(sys.error(s"Couldn't find $functionName on $cls")) - .getReturnType - .getName - case _ => "" + val m = cls.getMethods.find(_.getName == functionName) + if (m.isEmpty) { + sys.error(s"Couldn't find $functionName on $cls") + } else { + m + } + case _ => None } - lazy val unboxer = (dataType, method) match { + lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match { case (IntegerType, "java.lang.Object") => (s: String) => s"((java.lang.Integer)$s).intValue()" case (LongType, "java.lang.Object") => (s: String) => @@ -147,7 +148,7 @@ case class Invoke( case _ => identity[String] _ } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val obj = targetObject.gen(ctx) val argGen = arguments.map(_.gen(ctx)) @@ -156,24 +157,45 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" + s"boolean ${ev.isNull} = ${ev.value} == null;" } else { + ev.isNull = obj.isNull "" } val value = unboxer(s"${obj.value}.$functionName($argString)") + val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { + s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;" + } else { + s""" + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + try { + ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value; + } catch (Exception e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + } + s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} - - boolean ${ev.isNull} = ${obj.value} == null; - $javaType ${ev.value} = - ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : ($javaType) $value; + $evaluate $objNullCheck """ } + + override def toString: String = s"$targetObject.$functionName" +} + +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + dataType: DataType, + propagateNull: Boolean = true): NewInstance = + new NewInstance(cls, arguments, propagateNull, dataType, None) } /** @@ -187,54 +209,79 @@ case class Invoke( * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you * to manually specify the type when the object in question is a valid internal * representation (i.e. ArrayData) instead of an object. + * @param outerPointer If the object being constructed is an inner class, the outerPointer for the + * containing class must be specified. This parameter is defined as an optional + * function, which allows us to get the outer pointer lazily,and it's useful if + * the inner class is defined in REPL. */ case class NewInstance( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = true, - dataType: DataType) extends Expression { + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression { private val className = cls.getName override def nullable: Boolean = propagateNull override def children: Seq[Expression] = arguments + override lazy val resolved: Boolean = { + // If the class to construct is an inner class, we need to get its outer pointer, or this + // expression should be regarded as unresolved. + // Note that static inner classes (e.g., inner classes within Scala objects) don't need + // outer pointer registration. + val needOuterPointer = + outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers) + childrenResolved && !needOuterPointer + } + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") - if (propagateNull) { - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } + val outer = outerPointer.map(func => Literal.fromObject(func()).gen(ctx)) + val setup = + s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code).getOrElse("")} + """.stripMargin + + val constructorCall = outer.map { gen => + s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + }.getOrElse { + s"new $className($argString)" + } + + if (propagateNull && argGen.nonEmpty) { val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" - ${argGen.map(_.code).mkString("\n")} + $setup boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = new $className($argString); + ${ev.value} = $constructorCall; ${ev.isNull} = false; } """ } else { s""" - ${argGen.map(_.code).mkString("\n")} + $setup - $javaType ${ev.value} = new $className($argString); - final boolean ${ev.isNull} = ${ev.value} == null; + final $javaType ${ev.value} = $constructorCall; + final boolean ${ev.isNull} = false; """ } } + + override def toString: String = s"newInstance($cls)" } /** @@ -246,7 +293,7 @@ case class NewInstance( */ case class UnwrapOption( dataType: DataType, - child: Expression) extends UnaryExpression with ExpectsInputTypes { + child: Expression) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { override def nullable: Boolean = true @@ -255,7 +302,7 @@ case class UnwrapOption( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val inputObject = child.gen(ctx) @@ -272,30 +319,29 @@ case class UnwrapOption( /** * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. - * @param optionType The datatype to be held inside of the Option. * @param child The expression to evaluate and wrap. + * @param optType The type of this option. */ -case class WrapOption(optionType: DataType, child: Expression) - extends UnaryExpression with ExpectsInputTypes { +case class WrapOption(child: Expression, optType: DataType) + extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil + override def inputTypes: Seq[AbstractDataType] = optType :: Nil override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val javaType = ctx.javaType(optionType) + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val inputObject = child.gen(ctx) s""" ${inputObject.code} boolean ${ev.isNull} = false; - scala.Option<$javaType> ${ev.value} = + scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ @@ -306,19 +352,28 @@ case class WrapOption(optionType: DataType, child: Expression) * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. */ -case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression { +case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression + with Unevaluable with NonSQLExpression { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException("Only calling gen() is supported.") - - override def children: Seq[Expression] = Nil - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = - GeneratedExpressionCode(code = "", value = value, isNull = isNull) + override def nullable: Boolean = true - override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + override def gen(ctx: CodegenContext): ExprCode = { + ExprCode(code = "", value = value, isNull = isNull) + } +} +object MapObjects { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + def apply( + function: Expression => Expression, + inputData: Expression, + elementType: DataType): MapObjects = { + val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() + val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + MapObjects(loopVar, function(loopVar), inputData) + } } /** @@ -326,85 +381,80 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * as an ArrayType. This is similar to a typical map operation, but where the lambda function * is expressed using catalyst expressions. * - * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData + * The following collection ObjectTypes are currently supported: + * Seq, Array, ArrayData, java.util.List * - * @param function A function that returns an expression, given an attribute that can be used - * to access the current value. This is does as a lambda function so that - * a unique attribute reference can be provided for each expression (thus allowing - * us to nest multiple MapObject calls). - * @param inputData An expression that when evaluted returns a collection object. - * @param elementType The type of element in the collection, expressed as a DataType. + * @param loopVar A place holder that used as the loop variable when iterate the collection, and + * used as input for the `lambdaFunction`. It also carries the element type info. + * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function + * to handle collection elements. + * @param inputData An expression that when evaluated returns a collection object. */ -case class MapObjects( - function: AttributeReference => Expression, - inputData: Expression, - elementType: DataType) extends Expression { - - private lazy val loopAttribute = AttributeReference("loopVar", elementType)() - private lazy val completeFunction = function(loopAttribute) +case class MapObjects private( + loopVar: LambdaVariable, + lambdaFunction: Expression, + inputData: Expression) extends Expression with NonSQLExpression { + + @tailrec + private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case NullType => + val nullTypeClassName = NullType.getClass.getName + ".MODULE$" + (i: String) => s".get($i, $nullTypeClassName)" + case IntegerType => (i: String) => s".getInt($i)" + case LongType => (i: String) => s".getLong($i)" + case FloatType => (i: String) => s".getFloat($i)" + case DoubleType => (i: String) => s".getDouble($i)" + case ByteType => (i: String) => s".getByte($i)" + case ShortType => (i: String) => s".getShort($i)" + case BooleanType => (i: String) => s".getBoolean($i)" + case StringType => (i: String) => s".getUTF8String($i)" + case s: StructType => (i: String) => s".getStruct($i, ${s.size})" + case a: ArrayType => (i: String) => s".getArray($i)" + case _: MapType => (i: String) => s".getMap($i)" + case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType) + case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)" + case DateType => (i: String) => s".getInt($i)" + } private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => (".length", (i: String) => s"[$i]", false) - case ArrayType(s: StructType, _) => - (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false) - case ArrayType(a: ArrayType, _) => - (".numElements()", (i: String) => s".getArray($i)", true) - case ArrayType(IntegerType, _) => - (".numElements()", (i: String) => s".getInt($i)", true) - case ArrayType(LongType, _) => - (".numElements()", (i: String) => s".getLong($i)", true) - case ArrayType(FloatType, _) => - (".numElements()", (i: String) => s".getFloat($i)", true) - case ArrayType(DoubleType, _) => - (".numElements()", (i: String) => s".getDouble($i)", true) - case ArrayType(ByteType, _) => - (".numElements()", (i: String) => s".getByte($i)", true) - case ArrayType(ShortType, _) => - (".numElements()", (i: String) => s".getShort($i)", true) - case ArrayType(BooleanType, _) => - (".numElements()", (i: String) => s".getBoolean($i)", true) - case ArrayType(StringType, _) => - (".numElements()", (i: String) => s".getUTF8String($i)", false) - case ArrayType(_: MapType, _) => - (".numElements()", (i: String) => s".getMap($i)", false) + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".get($i)", false) + case ArrayType(t, _) => + val (sqlType, primitiveElement) = t match { + case m: MapType => (m, false) + case s: StructType => (s, false) + case s: StringType => (s, false) + case udt: UserDefinedType[_] => (udt.sqlType, false) + case o => (o, true) + } + (".numElements()", itemAccessorMethod(sqlType), primitiveElement) } override def nullable: Boolean = true - override def children: Seq[Expression] = completeFunction :: inputData :: Nil + override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def dataType: DataType = ArrayType(completeFunction.dataType) + override def dataType: DataType = ArrayType(lambdaFunction.dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) - val elementJavaType = ctx.javaType(elementType) + val elementJavaType = ctx.javaType(loopVar.dataType) + ctx.addMutableState("boolean", loopVar.isNull, "") + ctx.addMutableState(elementJavaType, loopVar.value, "") val genInputData = inputData.gen(ctx) - - // Variables to hold the element that is currently being processed. - val loopValue = ctx.freshName("loopValue") - val loopIsNull = ctx.freshName("loopIsNull") - - val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType) - val substitutedFunction = completeFunction transform { - case a: AttributeReference if a == loopAttribute => loopVariable - } - // A hack to run this through the analyzer (to bind extractions). - val boundFunction = - SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil))) - .expressions.head.children.head - - val genFunction = boundFunction.gen(ctx) + val genFunction = lambdaFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(boundFunction.dataType) + val convertedType = ctx.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -418,9 +468,9 @@ case class MapObjects( } val loopNullCheck = if (primitiveElement) { - s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" } else { - s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;" + s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } s""" @@ -436,14 +486,14 @@ case class MapObjects( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $elementJavaType $loopValue = + ${loopVar.value} = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck - if ($loopIsNull) { + ${genFunction.code} + if (${genFunction.isNull}) { $convertedArray[$loopIndex] = null; } else { - ${genFunction.code} $convertedArray[$loopIndex] = ${genFunction.value}; } @@ -463,7 +513,9 @@ case class MapObjects( * * @param children A list of expression to use as content of the external row. */ -case class CreateExternalRow(children: Seq[Expression]) extends Expression { +case class CreateExternalRow(children: Seq[Expression], schema: StructType) + extends Expression with NonSQLExpression { + override def dataType: DataType = ObjectType(classOf[Row]) override def nullable: Boolean = false @@ -471,23 +523,188 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericRow].getName + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") - s""" - boolean ${ev.isNull} = false; - final Object[] $values = new Object[${children.size}]; - """ + - children.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) - eval.code + s""" + ctx.addMutableState("Object[]", values, "") + + val childrenCodes = children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" if (${eval.isNull}) { $values[$i] = null; } else { $values[$i] = ${eval.value}; } """ - }.mkString("\n") + - s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" + } + val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) + val schemaField = ctx.addReferenceObj("schema", schema) + s""" + boolean ${ev.isNull} = false; + $values = new Object[${children.size}]; + $childrenCode + final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); + """ + } +} + +/** + * Serializes an input object using a generic serializer (Kryo or Java). + * @param kryo if true, use Kryo. Otherwise, use Java. + */ +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) + extends UnaryExpression with NonSQLExpression { + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" + ctx.addMutableState( + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + + // Code to serialize. + val input = child.gen(ctx) + s""" + ${input.code} + final boolean ${ev.isNull} = ${input.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $serializer.serialize(${input.value}, null).array(); + } + """ + } + + override def dataType: DataType = BinaryType +} + +/** + * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag + * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * @param kryo if true, use Kryo. Otherwise, use Java. + */ +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression with NonSQLExpression { + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" + ctx.addMutableState( + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + + // Code to serialize. + val input = child.gen(ctx) + s""" + ${input.code} + final boolean ${ev.isNull} = ${input.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.javaType(dataType)}) + $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); + } + """ + } + + override def dataType: DataType = ObjectType(tag.runtimeClass) +} + +/** + * Initialize a Java Bean instance by setting its field values via setters. + */ +case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) + extends Expression with NonSQLExpression { + + override def nullable: Boolean = beanInstance.nullable + override def children: Seq[Expression] = beanInstance +: setters.values.toSeq + override def dataType: DataType = beanInstance.dataType + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val instanceGen = beanInstance.gen(ctx) + + val initialize = setters.map { + case (setterMethod, fieldValue) => + val fieldGen = fieldValue.gen(ctx) + s""" + ${fieldGen.code} + ${instanceGen.value}.$setterMethod(${fieldGen.value}); + """ + } + + ev.isNull = instanceGen.isNull + ev.value = instanceGen.value + + s""" + ${instanceGen.code} + if (!${instanceGen.isNull}) { + ${initialize.mkString("\n")} + } + """ + } +} + +/** + * Asserts that input values of a non-nullable child expression are not null. + * + * Note that there are cases where `child.nullable == true`, while we still needs to add this + * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable + * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all + * non-null `s`, `s.i` can't be null. + */ +case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) + extends UnaryExpression with NonSQLExpression { + + override def dataType: DataType = child.dataType + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val childGen = child.gen(ctx) + + val errMsg = "Null value appeared in non-nullable field:" + + walkedTypePath.mkString("\n", "\n", "\n") + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + val idx = ctx.references.length + ctx.references += errMsg + + ev.isNull = "false" + ev.value = childGen.value + + s""" + ${childGen.code} + + if (${childGen.isNull}) { + throw new RuntimeException((String) references[$idx]); + } + """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 6407c73bc97d9..6112259fed619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -48,6 +48,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case dt: AtomicType if order.direction == Descending => dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case a: ArrayType if order.direction == Ascending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) + case a: ArrayType if order.direction == Descending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case s: StructType if order.direction == Ascending => s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Descending => @@ -86,6 +90,8 @@ object RowOrdering { case NullType => true case dt: AtomicType => true case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) + case array: ArrayType => isOrderable(array.elementType) + case udt: UserDefinedType[_] => isOrderable(udt.sqlType) case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index f1fa13daa77eb..23baa6f7837fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -92,4 +92,11 @@ package object expressions { StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) } } + + /** + * When an expression inherits this, meaning the expression is null intolerant (i.e. any null + * input will result in null output). We will use this information during constructing IsNotNull + * constraints. + */ + trait NullIntolerant } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 68557479a9591..38f1210a4edb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -65,6 +65,15 @@ trait PredicateHelper { } } + // Substitute any known alias from a map. + protected def replaceAlias( + condition: Expression, + aliases: AttributeMap[Expression]): Expression = { + condition.transform { + case a: Attribute => aliases.getOrElse(a, a) + } + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when it is acceptable to move expression evaluation within a query @@ -79,9 +88,10 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - +@ExpressionDescription( + usage = "_FUNC_ a - Logical not") case class Not(child: Expression) - extends UnaryExpression with Predicate with ImplicitCastInputTypes { + extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { override def toString: String = s"NOT $child" @@ -89,15 +99,19 @@ case class Not(child: Expression) protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean] - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } + + override def sql: String = s"(NOT ${child.sql})" } /** * Evaluates to `true` if `list` contains `value`. */ +@ExpressionDescription( + usage = "expr _FUNC_(val1, val2, ...) - Returns true if expr equals to any valN.") case class In(value: Expression, list: Seq[Expression]) extends Predicate with ImplicitCastInputTypes { @@ -143,7 +157,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val valueGen = value.gen(ctx) val listGen = list.map(_.gen(ctx)) val listCode = listGen.map(x => @@ -167,6 +181,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } """ } + + override def sql: String = { + val childrenSQL = children.map(_.sql) + val valueSQL = childrenSQL.head + val listSQL = childrenSQL.tail.mkString(", ") + s"($valueSQL IN ($listSQL))" + } } /** @@ -195,7 +216,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with def getHSet(): Set[Any] = hset - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val setName = classOf[Set[Any]].getName val InSetName = classOf[InSet].getName val childGen = child.gen(ctx) @@ -203,7 +224,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with val hsetTerm = ctx.freshName("hset") val hasNullTerm = ctx.freshName("hasNull") ctx.addMutableState(setName, hsetTerm, - s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();") + s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") s""" ${childGen.code} @@ -217,14 +238,24 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } """ } + + override def sql: String = { + val valueSQL = child.sql + val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") + s"($valueSQL IN ($listSQL))" + } } +@ExpressionDescription( + usage = "a _FUNC_ b - Logical AND.") case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType override def symbol: String = "&&" + override def sqlOperator: String = "AND" + override def eval(input: InternalRow): Any = { val input1 = left.eval(input) if (input1 == false) { @@ -243,37 +274,53 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) // The result should be `false`, if any of them is `false` whenever the other is null or not. - s""" - ${eval1.code} - boolean ${ev.isNull} = false; - boolean ${ev.value} = false; + if (!left.nullable && !right.nullable) { + ev.isNull = "false" + s""" + ${eval1.code} + boolean ${ev.value} = false; - if (!${eval1.isNull} && !${eval1.value}) { - } else { - ${eval2.code} - if (!${eval2.isNull} && !${eval2.value}) { - } else if (!${eval1.isNull} && !${eval2.isNull}) { - ${ev.value} = true; + if (${eval1.value}) { + ${eval2.code} + ${ev.value} = ${eval2.value}; + } + """ + } else { + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.value} = false; + + if (!${eval1.isNull} && !${eval1.value}) { } else { - ${ev.isNull} = true; + ${eval2.code} + if (!${eval2.isNull} && !${eval2.value}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.value} = true; + } else { + ${ev.isNull} = true; + } } - } - """ + """ + } } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Logical OR.") case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType override def symbol: String = "||" + override def sqlOperator: String = "OR" + override def eval(input: InternalRow): Any = { val input1 = left.eval(input) if (input1 == true) { @@ -292,34 +339,47 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. - s""" - ${eval1.code} - boolean ${ev.isNull} = false; - boolean ${ev.value} = true; + if (!left.nullable && !right.nullable) { + ev.isNull = "false" + s""" + ${eval1.code} + boolean ${ev.value} = true; - if (!${eval1.isNull} && ${eval1.value}) { - } else { - ${eval2.code} - if (!${eval2.isNull} && ${eval2.value}) { - } else if (!${eval1.isNull} && !${eval2.isNull}) { - ${ev.value} = false; + if (!${eval1.value}) { + ${eval2.code} + ${ev.value} = ${eval2.value}; + } + """ + } else { + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.value} = true; + + if (!${eval1.isNull} && ${eval1.value}) { } else { - ${ev.isNull} = true; + ${eval2.code} + if (!${eval2.isNull} && ${eval2.value}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.value} = false; + } else { + ${ev.isNull} = true; + } } - } - """ + """ + } } } abstract class BinaryComparison extends BinaryOperator with Predicate { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (ctx.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType @@ -347,8 +407,10 @@ private[sql] object Equality { } } - -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a equals b and false otherwise.") +case class EqualTo(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = AnyDataType @@ -366,12 +428,14 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } } - +@ExpressionDescription( + usage = """a _FUNC_ b - Returns same result with EQUAL(=) operator for non-null operands, + but returns TRUE if both are NULL, FALSE if one of the them is NULL.""") case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { override def inputType: AbstractDataType = AnyDataType @@ -400,7 +464,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) @@ -412,8 +476,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - -case class LessThan(left: Expression, right: Expression) extends BinaryComparison { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is less than b.") +case class LessThan(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -424,8 +490,10 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } - -case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is not greater than b.") +case class LessThanOrEqual(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -436,8 +504,10 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } - -case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is greater than b.") +case class GreaterThan(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -448,8 +518,10 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } - -case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is not smaller than b.") +case class GreaterThanOrEqual(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bde8cb9fe876..1ec092a5be965 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -49,9 +49,14 @@ abstract class RDG extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = DoubleType + + // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. + override def sql: String = s"$prettyName($seed)" } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a random column with i.i.d. uniformly distributed values in [0, 1).") case class Rand(seed: Long) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() @@ -62,7 +67,7 @@ case class Rand(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, @@ -75,6 +80,8 @@ case class Rand(seed: Long) extends RDG { } /** Generate a random column with i.i.d. gaussian random distribution. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.") case class Randn(seed: Long) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() @@ -82,10 +89,10 @@ case class Randn(seed: Long) extends RDG { def this(seed: Expression) = this(seed match { case IntegerLiteral(s) => s - case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + case _ => throw new AnalysisException("Input argument to randn must be an integer literal.") }) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 9e484c5ed83bf..85a54292639d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -59,14 +59,18 @@ trait StringRegexExpression extends ImplicitCastInputTypes { matches(regex, input1.asInstanceOf[UTF8String].toString) } } + + override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" } /** * Simple RegEx pattern matching function */ +@ExpressionDescription( + usage = "str _FUNC_ pattern - Returns true if str matches pattern and false otherwise.") case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { + extends BinaryExpression with StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -74,7 +78,7 @@ case class Like(left: Expression, right: Expression) override def toString: String = s"$left LIKE $right" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" val pattern = ctx.freshName("pattern") @@ -115,15 +119,16 @@ case class Like(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = "str _FUNC_ regexp - Returns true if str matches regexp and false otherwise.") case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { + extends BinaryExpression with StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val patternClass = classOf[Pattern].getName val pattern = ctx.freshName("pattern") @@ -167,6 +172,9 @@ case class RLike(left: Expression, right: Expression) /** * Splits str around pat (pattern is a regular expression). */ +@ExpressionDescription( + usage = "_FUNC_(str, regex) - Splits str around occurrences that match regex", + extended = "> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');\n ['one', 'two', 'three']") case class StringSplit(str: Expression, pattern: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -180,7 +188,7 @@ case class StringSplit(str: Expression, pattern: Expression) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. @@ -196,6 +204,9 @@ case class StringSplit(str: Expression, pattern: Expression) * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ +@ExpressionDescription( + usage = "_FUNC_(str, regexp, rep) - replace all substrings of str that match regexp with rep.", + extended = "> SELECT _FUNC_('100-200', '(\\d+)', 'num');\n 'num-num'") case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -236,7 +247,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def children: Seq[Expression] = subject :: regexp :: rep :: Nil override def prettyName: String = "regexp_replace" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") @@ -287,6 +298,9 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ +@ExpressionDescription( + usage = "_FUNC_(str, regexp[, idx]) - extracts a group that matches regexp.", + extended = "> SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1);\n '100'") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) extends TernaryExpression with ImplicitCastInputTypes { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) @@ -316,7 +330,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def children: Seq[Expression] = subject :: regexp :: idx :: Nil override def prettyName: String = "regexp_extract" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") val classNamePattern = classOf[Pattern].getCanonicalName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index cfc68fc00bea8..93a8278528697 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -164,7 +164,7 @@ trait BaseGenericInternalRow extends InternalRow { abstract class MutableRow extends InternalRow { def setNullAt(i: Int): Unit - def update(i: Int, value: Any) + def update(i: Int, value: Any): Unit // default implementation (slow) def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } @@ -199,9 +199,9 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { override def get(i: Int): Any = values(i) - override def toSeq: Seq[Any] = values.toSeq + override def toSeq: Seq[Any] = values.clone() - override def copy(): Row = this + override def copy(): GenericRow = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) @@ -226,23 +226,11 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values.clone() override def numFields: Int = values.length - override def copy(): InternalRow = new GenericInternalRow(values.clone()) -} - -/** - * This is used for serialization of Python DataFrame - */ -class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) - extends GenericInternalRow(values) { - - /** No-arg constructor for serialization. */ - protected def this() = this(null, null) - - def fieldIndex(name: String): Int = schema.fieldIndex(name) + override def copy(): GenericInternalRow = this } class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala deleted file mode 100644 index d124d29d534b8..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - -/** The data type for expressions returning an OpenHashSet as the result. */ -private[sql] class OpenHashSetUDT( - val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] { - - override def sqlType: DataType = ArrayType(elementType) - - /** Since we are using OpenHashSet internally, usually it will not be called. */ - override def serialize(obj: Any): Seq[Any] = { - obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq - } - - /** Since we are using OpenHashSet internally, usually it will not be called. */ - override def deserialize(datum: Any): OpenHashSet[Any] = { - val iterator = datum.asInstanceOf[Seq[Any]].iterator - val set = new OpenHashSet[Any] - while(iterator.hasNext) { - set.add(iterator.next()) - } - - set - } - - override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]] - - private[spark] override def asNullable: OpenHashSetUDT = this -} - -/** - * Creates a new set of the specified type - */ -case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback { - - override def nullable: Boolean = false - - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) - - override def eval(input: InternalRow): Any = { - new OpenHashSet[Any]() - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - elementType match { - case IntegerType | LongType => - ev.isNull = "false" - s""" - ${ctx.javaType(dataType)} ${ev.value} = new ${ctx.javaType(dataType)}(); - """ - case _ => super.genCode(ctx, ev) - } - } - - override def toString: String = s"new Set($dataType)" -} - -/** - * Adds an item to a set. - * For performance, this expression mutates its input during evaluation. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class AddItemToSet(item: Expression, set: Expression) - extends Expression with CodegenFallback { - - override def children: Seq[Expression] = item :: set :: Nil - - override def nullable: Boolean = set.nullable - - override def dataType: DataType = set.dataType - - override def eval(input: InternalRow): Any = { - val itemEval = item.eval(input) - val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] - - if (itemEval != null) { - if (setEval != null) { - setEval.add(itemEval) - setEval - } else { - null - } - } else { - setEval - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - elementType match { - case IntegerType | LongType => - val itemEval = item.gen(ctx) - val setEval = set.gen(ctx) - val htype = ctx.javaType(dataType) - - ev.isNull = "false" - ev.value = setEval.value - itemEval.code + setEval.code + s""" - if (!${itemEval.isNull} && !${setEval.isNull}) { - (($htype)${setEval.value}).add(${itemEval.value}); - } - """ - case _ => super.genCode(ctx, ev) - } - } - - override def toString: String = s"$set += $item" -} - -/** - * Combines the elements of two sets. - * For performance, this expression mutates its left input set during evaluation. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class CombineSets(left: Expression, right: Expression) - extends BinaryExpression with CodegenFallback { - - override def nullable: Boolean = left.nullable - override def dataType: DataType = left.dataType - - override def eval(input: InternalRow): Any = { - val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] - if(leftEval != null) { - val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] - if (rightEval != null) { - val iterator = rightEval.iterator - while(iterator.hasNext) { - val rightValue = iterator.next() - leftEval.add(rightValue) - } - } - leftEval - } else { - null - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - elementType match { - case IntegerType | LongType => - val leftEval = left.gen(ctx) - val rightEval = right.gen(ctx) - val htype = ctx.javaType(dataType) - - ev.isNull = leftEval.isNull - ev.value = leftEval.value - leftEval.code + rightEval.code + s""" - if (!${leftEval.isNull} && !${rightEval.isNull}) { - ${leftEval.value}.union((${htype})${rightEval.value}); - } - """ - case _ => super.genCode(ctx, ev) - } - } -} - -/** - * Returns the number of elements in the input set. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback { - - override def dataType: DataType = LongType - - protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[OpenHashSet[Any]].size.toLong - - override def toString: String = s"$child.count()" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8770c4b76c2e5..a17482697d906 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.DecimalFormat +import java.text.{DecimalFormat, DecimalFormatSymbols} import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow @@ -35,6 +35,9 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} * An expression that concatenates multiple input strings into a single string. * If any input is null, concat returns null. */ +@ExpressionDescription( + usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN", + extended = "> SELECT _FUNC_('Spark','SQL');\n 'SparkSQL'") case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) @@ -48,7 +51,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas UTF8String.concat(inputs : _*) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.value}" @@ -70,6 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas * * Returns null if the separator is null. Otherwise, concat_ws skips all null values. */ +@ExpressionDescription( + usage = + "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by sep.", + extended = "> SELECT _FUNC_(' ', Spark', 'SQL');\n 'Spark SQL'") case class ConcatWs(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { @@ -99,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (children.forall(_.dataType == StringType)) { // All children are strings. In that case we can construct a fixed size array. val evals = children.map(_.gen(ctx)) @@ -178,7 +185,7 @@ case class Upper(child: Expression) override def convert(v: UTF8String): UTF8String = v.toUpperCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } } @@ -188,12 +195,12 @@ case class Upper(child: Expression) */ @ExpressionDescription( usage = "_FUNC_(str) - Returns str with all characters changed to lowercase", - extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'") + extended = "> SELECT _FUNC_('SparkSql');\n 'sparksql'") case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } } @@ -218,7 +225,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } } @@ -229,7 +236,7 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } } @@ -240,7 +247,7 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } } @@ -270,6 +277,11 @@ object StringTranslate { * The translate will happen when any character in the string matching with the character * in the `matchingExpr`. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(input, from, to) - Translates the input string by replacing the characters present in the from string with the corresponding characters in the to string""", + extended = "> SELECT _FUNC_('AaBbCc', 'abc', '123');\n 'A1B2C3'") +// scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -286,30 +298,30 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac srcEval.asInstanceOf[UTF8String].translate(dict) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastMatching = ctx.freshName("lastMatching") val termLastReplace = ctx.freshName("lastReplace") val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") - ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") + ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"${termDict} == null" + s"$termDict == null" } else { - s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" } s"""if ($check) { // Not all of them is literal or matching or replace value changed - ${termLastMatching} = ${matching}.clone(); - ${termLastReplace} = ${replace}.clone(); - ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict(${termLastMatching}, ${termLastReplace}); + $termLastMatching = $matching.clone(); + $termLastReplace = $replace.clone(); + $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict($termLastMatching, $termLastReplace); } - ${ev.value} = ${src}.translate(${termDict}); + ${ev.value} = $src.translate($termDict); """ }) } @@ -325,6 +337,12 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac * delimited list (right). Returns 0, if the string wasn't found or if the given * string (left) contains a comma. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(str, str_array) - Returns the index (1-based) of the given string (left) in the comma-delimited list (right). + Returns 0, if the string wasn't found or if the given string (left) contains a comma.""", + extended = "> SELECT _FUNC_('ab','abc,b,ab,c,def');\n 3") +// scalastyle:on case class FindInSet(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -333,18 +351,23 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override protected def nullSafeEval(word: Any, set: Any): Any = set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);" ) } override def dataType: DataType = IntegerType + + override def prettyName: String = "find_in_set" } /** * A function that trim the spaces from both ends for the specified string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the leading and trailing space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL'") case class StringTrim(child: Expression) extends UnaryExpression with String2StringExpression { @@ -352,7 +375,7 @@ case class StringTrim(child: Expression) override def prettyName: String = "trim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trim()") } } @@ -360,6 +383,9 @@ case class StringTrim(child: Expression) /** * A function that trim the spaces from left end for given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the leading space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL '") case class StringTrimLeft(child: Expression) extends UnaryExpression with String2StringExpression { @@ -367,7 +393,7 @@ case class StringTrimLeft(child: Expression) override def prettyName: String = "ltrim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trimLeft()") } } @@ -375,6 +401,9 @@ case class StringTrimLeft(child: Expression) /** * A function that trim the spaces from right end for given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the trailing space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n ' SparkSQL'") case class StringTrimRight(child: Expression) extends UnaryExpression with String2StringExpression { @@ -382,7 +411,7 @@ case class StringTrimRight(child: Expression) override def prettyName: String = "rtrim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trimRight()") } } @@ -394,6 +423,9 @@ case class StringTrimRight(child: Expression) * * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ +@ExpressionDescription( + usage = "_FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of substr in str.", + extended = "> SELECT _FUNC_('SparkSQL', 'SQL');\n 6") case class StringInstr(str: Expression, substr: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -408,7 +440,7 @@ case class StringInstr(str: Expression, substr: Expression) override def prettyName: String = "instr" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } @@ -420,6 +452,15 @@ case class StringInstr(str: Expression, substr: Expression) * returned. If count is negative, every to the right of the final delimiter (counting from the * right) is returned. substring_index performs a case-sensitive match when searching for delim. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(str, delim, count) - Returns the substring from str before count occurrences of the delimiter delim. + If count is positive, everything to the left of the final delimiter (counting from the + left) is returned. If count is negative, everything to the right of the final delimiter + (counting from the right) is returned. Substring_index performs a case-sensitive match + when searching for delim.""", + extended = "> SELECT _FUNC_('www.apache.org', '.', 2);\n 'www.apache'") +// scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -434,7 +475,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: count.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") } } @@ -443,6 +484,12 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: * A function that returns the position of the first occurrence of substr * in given string after position pos. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos. + The given pos and return value are 1-based.""", + extended = "> SELECT _FUNC_('bar', 'foobarbar', 5);\n 7") +// scalastyle:on line.size.limit case class StringLocate(substr: Expression, str: Expression, start: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -477,7 +524,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val substrGen = substr.gen(ctx) val strGen = str.gen(ctx) val startGen = start.gen(ctx) @@ -508,6 +555,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) /** * Returns str, left-padded with pad to a length of len. */ +@ExpressionDescription( + usage = """_FUNC_(str, len, pad) - Returns str, left-padded with pad to a length of len. + If str is longer than len, the return value is shortened to len characters.""", + extended = "> SELECT _FUNC_('hi', 5, '??');\n '???hi'\n" + + "> SELECT _FUNC_('hi', 1, '??');\n 'h'") case class StringLPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -519,7 +571,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } @@ -529,6 +581,11 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) /** * Returns str, right-padded with pad to a length of len. */ +@ExpressionDescription( + usage = """_FUNC_(str, len, pad) - Returns str, right-padded with pad to a length of len. + If str is longer than len, the return value is shortened to len characters.""", + extended = "> SELECT _FUNC_('hi', 5, '??');\n 'hi???'\n" + + "> SELECT _FUNC_('hi', 1, '??');\n 'h'") case class StringRPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -540,7 +597,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } @@ -550,6 +607,11 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(String format, Obj... args) - Returns a formatted string from printf-style format strings.", + extended = "> SELECT _FUNC_(\"Hello World %d %s\", 100, \"days\");\n 'Hello World 100 days'") +// scalastyle:on line.size.limit case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, "format_string() should take at least 1 argument") @@ -576,7 +638,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val pattern = children.head.gen(ctx) val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) @@ -616,25 +678,33 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } /** - * Returns string, with the first letter of each word in uppercase. + * Returns string, with the first letter of each word in uppercase, all other letters in lowercase. * Words are delimited by whitespace. */ +@ExpressionDescription( + usage = + """_FUNC_(str) - Returns str with the first letter of each word in uppercase. + All other letters are in lowercase. Words are delimited by white space.""", + extended = "> SELECT initcap('sPark sql');\n 'Spark Sql'") case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(StringType) override def dataType: DataType = StringType override def nullSafeEval(string: Any): Any = { - string.asInstanceOf[UTF8String].toTitleCase + string.asInstanceOf[UTF8String].toLowerCase.toTitleCase } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, str => s"$str.toTitleCase()") + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") } } /** * Returns the string which repeat the given string value n times. */ +@ExpressionDescription( + usage = "_FUNC_(str, n) - Returns the string which repeat the given string value n times.", + extended = "> SELECT _FUNC_('123', 2);\n '123123'") case class StringRepeat(str: Expression, times: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -649,7 +719,7 @@ case class StringRepeat(str: Expression, times: Expression) override def prettyName: String = "repeat" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") } } @@ -657,12 +727,15 @@ case class StringRepeat(str: Expression, times: Expression) /** * Returns the reversed given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns the reversed given string.", + extended = "> SELECT _FUNC_('Spark SQL');\n 'LQS krapS'") case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.reverse() override def prettyName: String = "reverse" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).reverse()") } } @@ -670,6 +743,9 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. */ +@ExpressionDescription( + usage = "_FUNC_(n) - Returns a n spaces string.", + extended = "> SELECT _FUNC_(2);\n ' '") case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -681,7 +757,7 @@ case class StringSpace(child: Expression) UTF8String.blankString(if (length < 0) 0 else length) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (length) => s"""${ev.value} = UTF8String.blankString(($length < 0) ? 0 : $length);""") } @@ -692,7 +768,14 @@ case class StringSpace(child: Expression) /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. + * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str, pos[, len]) - Returns the substring of str that starts at pos and is of length len or the slice of byte array that starts at pos and is of length len.", + extended = "> SELECT _FUNC_('Spark SQL', 5);\n 'k SQL'\n> SELECT _FUNC_('Spark SQL', -3);\n 'SQL'\n> SELECT _FUNC_('Spark SQL', 5, 1);\n 'k'") +// scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -716,7 +799,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (string, pos, len) => { str.dataType match { @@ -730,6 +813,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string or binary expression. */ +@ExpressionDescription( + usage = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data.", + extended = "> SELECT _FUNC_('Spark SQL');\n 9") case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -739,7 +825,7 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => value.asInstanceOf[Array[Byte]].length } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") @@ -750,6 +836,9 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy /** * A function that return the Levenshtein distance between the two given strings. */ +@ExpressionDescription( + usage = "_FUNC_(str1, str2) - Returns the Levenshtein distance between the two given strings.", + extended = "> SELECT _FUNC_('kitten', 'sitting');\n 3") case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -759,7 +848,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (left, right) => s"${ev.value} = $left.levenshteinDistance($right);") } @@ -768,6 +857,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * A function that return soundex code of the given string expression. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns soundex code of the string.", + extended = "> SELECT _FUNC_('Miller');\n 'M460'") case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = StringType @@ -776,7 +868,7 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"$c.soundex()") } } @@ -784,6 +876,10 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT /** * Returns the numeric value of the first character of str. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns the numeric value of the first character of str.", + extended = "> SELECT _FUNC_('222');\n 50\n" + + "> SELECT _FUNC_(2);\n 50") case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType @@ -798,7 +894,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { val bytes = ctx.freshName("bytes") s""" @@ -815,6 +911,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp /** * Converts the argument from binary to a base 64 string. */ +@ExpressionDescription( + usage = "_FUNC_(bin) - Convert the argument from binary to a base 64 string.") case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -826,18 +924,19 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn bytes.asInstanceOf[Array[Byte]])) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { s"""${ev.value} = UTF8String.fromBytes( org.apache.commons.codec.binary.Base64.encodeBase64($child)); """}) } - } /** * Converts the argument from a base 64 string to BINARY. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Convert the argument from a base 64 string to binary.") case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BinaryType @@ -846,7 +945,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast protected override def nullSafeEval(string: Any): Any = org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { s""" ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); @@ -859,6 +958,8 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. */ +@ExpressionDescription( + usage = "_FUNC_(bin, str) - Decode the first argument using the second argument character set.") case class Decode(bin: Expression, charset: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -872,7 +973,7 @@ case class Decode(bin: Expression, charset: Expression) UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset)) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (bytes, charset) => s""" try { @@ -888,7 +989,9 @@ case class Decode(bin: Expression, charset: Expression) * Encodes the first argument into a BINARY using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. -*/ + */ +@ExpressionDescription( + usage = "_FUNC_(str, str) - Encode the first argument using the second argument character set.") case class Encode(value: Expression, charset: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -902,7 +1005,7 @@ case class Encode(value: Expression, charset: Expression) input1.asInstanceOf[UTF8String].toString.getBytes(toCharset) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (string, charset) => s""" try { @@ -918,12 +1021,18 @@ case class Encode(value: Expression, charset: Expression) * and returns the result as a string. If D is 0, the result has no decimal point or * fractional part. */ +@ExpressionDescription( + usage = """_FUNC_(X, D) - Formats the number X like '#,###,###.##', rounded to D decimal places. + If D is 0, the result has no decimal point or fractional part. + This is supposed to function like MySQL's FORMAT.""", + extended = "> SELECT _FUNC_(12332.123456, 4);\n '12,332.1235'") case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { override def left: Expression = x override def right: Expression = d override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) // Associated with the pattern, for the last d value, and we will update the @@ -936,8 +1045,10 @@ case class FormatNumber(x: Expression, d: Expression) @transient private val pattern: StringBuffer = new StringBuffer() + // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') + // as a decimal separator. @transient - private val numberFormat: DecimalFormat = new DecimalFormat("") + private val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { val dValue = dObject.asInstanceOf[Int] @@ -960,10 +1071,9 @@ case class FormatNumber(x: Expression, d: Expression) pattern.append("0") } } - val dFormat = new DecimalFormat(pattern.toString) lastDValue = dValue - numberFormat.applyPattern(dFormat.toPattern) + numberFormat.applyLocalizedPattern(pattern.toString) } x.dataType match { @@ -978,7 +1088,7 @@ case class FormatNumber(x: Expression, d: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (num, d) => { def typeHelper(p: String): String = { @@ -990,6 +1100,11 @@ case class FormatNumber(x: Expression, d: Expression) val sb = classOf[StringBuffer].getName val df = classOf[DecimalFormat].getName + val dfs = classOf[DecimalFormatSymbols].getName + val l = classOf[Locale].getName + // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') + // as a decimal separator. + val usLocale = "US" val lastDValue = ctx.freshName("lastDValue") val pattern = ctx.freshName("pattern") val numberFormat = ctx.freshName("numberFormat") @@ -997,7 +1112,8 @@ case class FormatNumber(x: Expression, d: Expression) val dFormat = ctx.freshName("dFormat") ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") - ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""") + ctx.addMutableState(df, numberFormat, + s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") s""" if ($d >= 0) { @@ -1011,9 +1127,8 @@ case class FormatNumber(x: Expression, d: Expression) $pattern.append("0"); } } - $df $dFormat = new $df($pattern.toString()); $lastDValue = $d; - $numberFormat.applyPattern($dFormat.toPattern()); + $numberFormat.applyLocalizedPattern($pattern.toString()); } ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala new file mode 100644 index 0000000000000..968bbdb1a5f03 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.types.DataType + +/** + * An interface for subquery that is used in expressions. + */ +abstract class SubqueryExpression extends LeafExpression { + + /** + * The logical plan of the query. + */ + def query: LogicalPlan + + /** + * Either a logical plan or a physical plan. The generated tree string (explain output) uses this + * field to explain the subquery. + */ + def plan: QueryPlan[_] + + /** + * Updates the query with new logical plan. + */ + def withNewPlan(plan: LogicalPlan): SubqueryExpression +} + +/** + * A subquery that will return only one row and one column. This will be converted into a physical + * scalar subquery during planning. + * + * Note: `exprId` is used to have unique name in explain string output. + */ +case class ScalarSubquery( + query: LogicalPlan, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression with Unevaluable { + + override def plan: LogicalPlan = SubqueryAlias(toString, query) + + override lazy val resolved: Boolean = query.resolved + + override def dataType: DataType = query.schema.fields.head.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (query.schema.length != 1) { + TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + + query.schema.length.toString) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def foldable: Boolean = false + override def nullable: Boolean = true + + override def withNewPlan(plan: LogicalPlan): ScalarSubquery = ScalarSubquery(plan, exprId) + + override def toString: String = s"subquery#${exprId.id}" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 09ec0e333aa44..c0b453dccf5e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{DataType, NumericType} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} +import org.apache.spark.sql.types._ /** * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for @@ -29,6 +31,7 @@ sealed trait WindowSpec /** * The specification for a window function. + * * @param partitionSpec It defines the way that input rows are partitioned. * @param orderSpec It defines the ordering of rows in a partition. * @param frameSpecification It defines the window frame in a partition. @@ -71,12 +74,25 @@ case class WindowSpecDefinition( childrenResolved && checkInputDataTypes().isSuccess && frameSpecification.isInstanceOf[SpecifiedWindowFrame] - - override def toString: String = simpleString - override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException + + override def sql: String = { + val partition = if (partitionSpec.isEmpty) { + "" + } else { + "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ") + } + + val order = if (orderSpec.isEmpty) { + "" + } else { + "ORDER BY " + orderSpec.map(_.sql).mkString(", ") + } + + s"($partition $order ${frameSpecification.toString})" + } } /** @@ -120,6 +136,19 @@ sealed trait FrameBoundary { def notFollows(other: FrameBoundary): Boolean } +/** + * Extractor for making working with frame boundaries easier. + */ +object FrameBoundary { + def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) + def unapply(boundary: FrameBoundary): Option[Int] = boundary match { + case CurrentRow => Some(0) + case ValuePreceding(offset) => Some(-offset) + case ValueFollowing(offset) => Some(offset) + case _ => None + } +} + /** UNBOUNDED PRECEDING boundary. */ case object UnboundedPreceding extends FrameBoundary { def notFollows(other: FrameBoundary): Boolean = other match { @@ -246,85 +275,428 @@ object SpecifiedWindowFrame { } } +case class UnresolvedWindowExpression( + child: Expression, + windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { + + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} + +case class WindowExpression( + windowFunction: Expression, + windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { + + override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil + + override def dataType: DataType = windowFunction.dataType + override def foldable: Boolean = windowFunction.foldable + override def nullable: Boolean = windowFunction.nullable + + override def toString: String = s"$windowFunction $windowSpec" + override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql +} + /** - * Every window function needs to maintain a output buffer for its output. - * It should expect that for a n-row window frame, it will be called n times - * to retrieve value corresponding with these n rows. + * A window function is a function that can only be evaluated in the context of a window operator. */ trait WindowFunction extends Expression { - def init(): Unit + /** Frame in which the window operator must be executed. */ + def frame: WindowFrame = UnspecifiedFrame +} + +/** + * An offset window function is a window function that returns the value of the input column offset + * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with + * offset -2, will get the value of x 2 rows back in the partition. + */ +abstract class OffsetWindowFunction + extends Expression with WindowFunction with Unevaluable with ImplicitCastInputTypes { + /** + * Input expression to evaluate against a row which a number of rows below or above (depending on + * the value and sign of the offset) the current row. + */ + val input: Expression - def reset(): Unit + /** + * Default result value for the function when the input expression returns NULL. The default will + * evaluated against the current row instead of the offset row. + */ + val default: Expression - def prepareInputParameters(input: InternalRow): AnyRef + /** + * (Foldable) expression that contains the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val offset: Expression - def update(input: AnyRef): Unit + /** + * Direction of the number of rows between the current row and the row where the input expression + * is evaluated. + */ + val direction: SortDirection - def batchUpdate(inputs: Array[AnyRef]): Unit + override def children: Seq[Expression] = Seq(input, offset, default) - def evaluate(): Unit + /* + * The result of an OffsetWindowFunction is dependent on the frame in which the + * OffsetWindowFunction is executed, the input expression and the default expression. Even when + * both the input and the default expression are foldable, the result is still not foldable due to + * the frame. + */ + override def foldable: Boolean = false - def get(index: Int): Any + override def nullable: Boolean = default == null || default.nullable - def newInstance(): WindowFunction + override lazy val frame = { + // This will be triggered by the Analyzer. + val offsetValue = offset.eval() match { + case o: Int => o + case x => throw new AnalysisException( + s"Offset expression must be a foldable integer expression: $x") + } + val boundary = direction match { + case Ascending => ValueFollowing(offsetValue) + case Descending => ValuePreceding(offsetValue) + } + SpecifiedWindowFrame(RowFrame, boundary, boundary) + } + + override def dataType: DataType = input.dataType + + override def inputTypes: Seq[AbstractDataType] = + Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType)) + + override def toString: String = s"$prettyName($input, $offset, $default)" } -case class UnresolvedWindowFunction( - name: String, - children: Seq[Expression]) - extends Expression with WindowFunction with Unevaluable { +/** + * The Lead function returns the value of 'x' at 'offset' rows after the current row in the window. + * Offsets start at 0, which is the current row. The offset must be constant integer value. The + * default offset is 1. When the value of 'x' is null at the offset, or when the offset is larger + * than the window, the default expression is evaluated. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param input expression to evaluate 'offset' rows after the current row. + * @param offset rows to jump ahead in the partition. + * @param default to use when the input value is null or when the offset is larger than the window. + */ +@ExpressionDescription(usage = + """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows + after the current row in the window""") +case class Lead(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) - override def init(): Unit = throw new UnresolvedException(this, "init") - override def reset(): Unit = throw new UnresolvedException(this, "reset") - override def prepareInputParameters(input: InternalRow): AnyRef = - throw new UnresolvedException(this, "prepareInputParameters") - override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") - override def batchUpdate(inputs: Array[AnyRef]): Unit = - throw new UnresolvedException(this, "batchUpdate") - override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate") - override def get(index: Int): Any = throw new UnresolvedException(this, "get") + def this(input: Expression) = this(input, Literal(1)) - override def toString: String = s"'$name(${children.mkString(",")})" + def this() = this(Literal(null)) - override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance") + override val direction = Ascending } -case class UnresolvedWindowExpression( - child: UnresolvedWindowFunction, - windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { +/** + * The Lag function returns the value of 'x' at 'offset' rows before the current row in the window. + * Offsets start at 0, which is the current row. The offset must be constant integer value. The + * default offset is 1. When the value of 'x' is null at the offset, or when the offset is smaller + * than the window, the default expression is evaluated. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param input expression to evaluate 'offset' rows before the current row. + * @param offset rows to jump back in the partition. + * @param default to use when the input value is null or when the offset is smaller than the window. + */ +@ExpressionDescription(usage = + """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows + before the current row in the window""") +case class Lag(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) + + def this(input: Expression) = this(input, Literal(1)) + + def this() = this(Literal(null)) + + override val direction = Descending } -case class WindowExpression( - windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { +abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction { + self: Product => + override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) + override def dataType: DataType = IntegerType + override def nullable: Boolean = true + override def supportsPartial: Boolean = false + override lazy val mergeExpressions = + throw new UnsupportedOperationException("Window Functions do not support merging.") +} - override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil +abstract class RowNumberLike extends AggregateWindowFunction { + override def children: Seq[Expression] = Nil + override def inputTypes: Seq[AbstractDataType] = Nil + protected val zero = Literal(0) + protected val one = Literal(1) + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil + override val initialValues: Seq[Expression] = zero :: Nil + override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil +} - override def dataType: DataType = windowFunction.dataType - override def foldable: Boolean = windowFunction.foldable - override def nullable: Boolean = windowFunction.nullable +/** + * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation. + */ +trait SizeBasedWindowFunction extends AggregateWindowFunction { + // It's made a val so that the attribute created on driver side is serialized to executor side. + // Otherwise, if it's defined as a function, when it's called on executor side, it actually + // returns the singleton value instantiated on executor side, which has different expression ID + // from the one created on driver side. + val n: AttributeReference = SizeBasedWindowFunction.n +} - override def toString: String = s"$windowFunction $windowSpec" +object SizeBasedWindowFunction { + val n = AttributeReference("window__partition__size", IntegerType, nullable = false)() } /** - * Extractor for making working with frame boundaries easier. + * The RowNumber function computes a unique, sequential number to each row, starting with one, + * according to the ordering of rows within the window partition. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. */ -object FrameBoundaryExtractor { - def unapply(boundary: FrameBoundary): Option[Int] = boundary match { - case CurrentRow => Some(0) - case ValuePreceding(offset) => Some(-offset) - case ValueFollowing(offset) => Some(offset) - case _ => None +@ExpressionDescription(usage = + """_FUNC_() - The ROW_NUMBER() function assigns a unique, sequential number to + each row, starting with one, according to the ordering of rows within + the window partition.""") +case class RowNumber() extends RowNumberLike { + override val evaluateExpression = rowNumber + override def sql: String = "ROW_NUMBER()" +} + +/** + * The CumeDist function computes the position of a value relative to a all values in the partition. + * The result is the number of rows preceding or equal to the current row in the ordering of the + * partition divided by the total number of rows in the window partition. Any tie values in the + * ordering will evaluate to the same position. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + */ +@ExpressionDescription(usage = + """_FUNC_() - The CUME_DIST() function computes the position of a value relative to + a all values in the partition.""") +case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { + override def dataType: DataType = DoubleType + // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must + // return the same value for equal values in the partition. + override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) + override def sql: String = "CUME_DIST()" +} + +/** + * The NTile function divides the rows for each window partition into 'n' buckets ranging from 1 to + * at most 'n'. Bucket values will differ by at most 1. If the number of rows in the partition does + * not divide evenly into the number of buckets, then the remainder values are distributed one per + * bucket, starting with the first bucket. + * + * The NTile function is particularly useful for the calculation of tertiles, quartiles, deciles and + * other common summary statistics + * + * The function calculates two variables during initialization: The size of a regular bucket, and + * the number of buckets that will have one extra row added to it (when the rows do not evenly fit + * into the number of buckets); both variables are based on the size of the current partition. + * During the calculation process the function keeps track of the current row number, the current + * bucket number, and the row number at which the bucket will change (bucketThreshold). When the + * current row number reaches bucket threshold, the bucket value is increased by one and the the + * threshold is increased by the bucket size (plus one extra if the current bucket is padded). + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param buckets number of buckets to divide the rows in. Default value is 1. + */ +@ExpressionDescription(usage = + """_FUNC_(x) - The NTILE(n) function divides the rows for each window partition + into 'n' buckets ranging from 1 to at most 'n'.""") +case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { + def this() = this(Literal(1)) + + override def children: Seq[Expression] = Seq(buckets) + + // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant + // for each partition. + override def checkInputDataTypes(): TypeCheckResult = { + if (!buckets.foldable) { + return TypeCheckFailure(s"Buckets expression must be foldable, but got $buckets") + } + + if (buckets.dataType != IntegerType) { + return TypeCheckFailure(s"Buckets expression must be integer type, but got $buckets") + } + + val i = buckets.eval().asInstanceOf[Int] + if (i > 0) { + TypeCheckSuccess + } else { + TypeCheckFailure(s"Buckets expression must be positive, but got: $i") + } + } + + private val bucket = AttributeReference("bucket", IntegerType, nullable = false)() + private val bucketThreshold = + AttributeReference("bucketThreshold", IntegerType, nullable = false)() + private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)() + private val bucketsWithPadding = + AttributeReference("bucketsWithPadding", IntegerType, nullable = false)() + private def bucketOverflow(e: Expression) = + If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero) + + override val aggBufferAttributes = Seq( + rowNumber, + bucket, + bucketThreshold, + bucketSize, + bucketsWithPadding + ) + + override val initialValues = Seq( + zero, + zero, + zero, + Cast(Divide(n, buckets), IntegerType), + Cast(Remainder(n, buckets), IntegerType) + ) + + override val updateExpressions = Seq( + Add(rowNumber, one), + Add(bucket, bucketOverflow(one)), + Add(bucketThreshold, bucketOverflow( + Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))), + NoOp, + NoOp + ) + + override val evaluateExpression = bucket +} + +/** + * A RankLike function is a WindowFunction that changes its value based on a change in the value of + * the order of the window in which is processed. For instance, when the value of 'x' changes in a + * window ordered by 'x' the rank function also changes. The size of the change of the rank function + * is (typically) not dependent on the size of the change in 'x'. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + */ +abstract class RankLike extends AggregateWindowFunction { + override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) + + /** Store the values of the window 'order' expressions. */ + protected val orderAttrs = children.map { expr => + AttributeReference(expr.sql, expr.dataType)() } + + /** Predicate that detects if the order attributes have changed. */ + protected val orderEquals = children.zip(orderAttrs) + .map(EqualNullSafe.tupled) + .reduceOption(And) + .getOrElse(Literal(true)) + + protected val orderInit = children.map(e => Literal.create(null, e.dataType)) + protected val rank = AttributeReference("rank", IntegerType, nullable = false)() + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + protected val zero = Literal(0) + protected val one = Literal(1) + protected val increaseRowNumber = Add(rowNumber, one) + + /** + * Different RankLike implementations use different source expressions to update their rank value. + * Rank for instance uses the number of rows seen, whereas DenseRank uses the number of changes. + */ + protected def rankSource: Expression = rowNumber + + /** Increase the rank when the current rank == 0 or when the one of order attributes changes. */ + protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource) + + override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs + override val initialValues = zero +: one +: orderInit + override val updateExpressions = increaseRank +: increaseRowNumber +: children + override val evaluateExpression: Expression = rank + + def withOrder(order: Seq[Expression]): RankLike +} + +/** + * The Rank function computes the rank of a value in a group of values. The result is one plus the + * number of rows preceding or equal to the current row in the ordering of the partition. Tie values + * will produce gaps in the sequence. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - RANK() computes the rank of a value in a group of values. The result + is one plus the number of rows preceding or equal to the current row in the + ordering of the partition. Tie values will produce gaps in the sequence.""") +case class Rank(children: Seq[Expression]) extends RankLike { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): Rank = Rank(order) + override def sql: String = "RANK()" +} + +/** + * The DenseRank function computes the rank of a value in a group of values. The result is one plus + * the previously assigned rank value. Unlike Rank, DenseRank will not produce gaps in the ranking + * sequence. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - The DENSE_RANK() function computes the rank of a value in a group of + values. The result is one plus the previously assigned rank value. Unlike Rank, + DenseRank will not produce gaps in the ranking sequence.""") +case class DenseRank(children: Seq[Expression]) extends RankLike { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) + override protected def rankSource = Add(rank, one) + override val updateExpressions = increaseRank +: children + override val aggBufferAttributes = rank +: orderAttrs + override val initialValues = zero +: orderInit + override def sql: String = "DENSE_RANK()" +} + +/** + * The PercentRank function computes the percentage ranking of a value in a group of values. The + * result the rank of the minus one divided by the total number of rows in the partition minus one: + * (r - 1) / (n - 1). If a partition only contains one row, the function will return 0. + * + * The PercentRank function is similar to the CumeDist function, but it uses rank values instead of + * row counts in the its numerator. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. + * + * @param children to base the rank on; a change in the value of one the children will trigger a + * change in rank. This is an internal parameter and will be assigned by the + * Analyser. + */ +@ExpressionDescription(usage = + """_FUNC_() - PERCENT_RANK() The PercentRank function computes the percentage + ranking of a value in a group of values.""") +case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) + override def dataType: DataType = DoubleType + override val evaluateExpression = If(GreaterThan(n, one), + Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), + Literal(0.0d)) + override def sql: String = "PERCENT_RANK()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala new file mode 100644 index 0000000000000..aae75956ea61a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + + +/** + * An identifier that optionally specifies a database. + * + * Format (unquoted): "name" or "db.name" + * Format (quoted): "`name`" or "`db`.`name`" + */ +sealed trait IdentifierWithDatabase { + val identifier: String + def database: Option[String] + def quotedString: String = database.map(db => s"`$db`.`$identifier`").getOrElse(s"`$identifier`") + def unquotedString: String = database.map(db => s"$db.$identifier").getOrElse(identifier) + override def toString: String = quotedString +} + + +/** + * Identifies a table in a database. + * If `database` is not defined, the current database is used. + * When we register a permenent function in the FunctionRegistry, we use + * unquotedString as the function name. + */ +case class TableIdentifier(table: String, database: Option[String]) + extends IdentifierWithDatabase { + + override val identifier: String = table + + def this(table: String) = this(table, None) + +} + +object TableIdentifier { + def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName) +} + + +/** + * Identifies a function in a database. + * If `database` is not defined, the current database is used. + */ +case class FunctionIdentifier(funcName: String, database: Option[String]) + extends IdentifierWithDatabase { + + override val identifier: String = funcName + + def this(funcName: String) = this(funcName, None) +} + +object FunctionIdentifier { + def apply(funcName: String): FunctionIdentifier = new FunctionIdentifier(funcName) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 338c5193cb7a2..f5172b213a74b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,57 +17,108 @@ package org.apache.spark.sql.catalyst.optimizer +import scala.annotation.tailrec import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} + +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.FullOuter -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ -abstract class Optimizer extends RuleExecutor[LogicalPlan] - -object DefaultOptimizer extends Optimizer { - val batches = - // SubQueries are only needed for analysis and can be removed before execution. - Batch("Remove SubQueries", FixedPoint(100), - EliminateSubQueries) :: +/** + * Abstract class all optimizers should inherit of, contains the standard batches (extending + * Optimizers can override this. + */ +abstract class Optimizer extends RuleExecutor[LogicalPlan] { + def batches: Seq[Batch] = { + // Technically some of the rules in Finish Analysis are not optimizer rules and belong more + // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). + // However, because we also use the analyzer to canonicalized queries (for view definition), + // we do not eliminate subqueries or compute current time in the analyzer. + Batch("Finish Analysis", Once, + EliminateSubqueryAliases, + ComputeCurrentTime, + DistinctAggregationRewriter) :: + ////////////////////////////////////////////////////////////////////////////////////////// + // Optimizer rules start here + ////////////////////////////////////////////////////////////////////////////////////////// + // - Do the first call of CombineUnions before starting the major Optimizer rules, + // since it can reduce the number of iteration and the other rules could add/move + // extra operators between two adjacent Union operators. + // - Call CombineUnions again in Batch("Operator Optimizations"), + // since the other rules might make two separate Unions operators adjacent. + Batch("Union", Once, + CombineUnions) :: + Batch("Replace Operators", FixedPoint(100), + ReplaceIntersectWithSemiJoin, + ReplaceDistinctWithAggregate) :: Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down SetOperationPushDown, SamplePushDown, + ReorderJoin, + OuterJoinElimination, PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, + PushDownPredicate, + LimitPushDown, ColumnPruning, + InferFiltersFromConstraints, // Operator combine - ProjectCollapsing, + CollapseRepartition, + CollapseProject, CombineFilters, CombineLimits, - // Constant folding + CombineUnions, + // Constant folding and strength reduction NullPropagation, OptimizeIn, ConstantFolding, LikeSimplification, BooleanSimplification, - RemoveDispensable, - SimplifyFilters, + SimplifyConditionals, + RemoveDispensableExpressions, + BinaryComparisonSimplification, + PruneFilters, + EliminateSorts, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions, + EliminateSerialization) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: + Batch("Typed Filter Optimization", FixedPoint(100), + EmbedSerializerInFilter) :: Batch("LocalRelation", FixedPoint(100), - ConvertToLocalRelation) :: Nil + ConvertToLocalRelation) :: + Batch("Subquery", Once, + OptimizeSubqueries) :: Nil + } + + /** + * Optimize all the subqueries inside expression. + */ + object OptimizeSubqueries extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case subquery: SubqueryExpression => + subquery.withNewPlan(Optimizer.this.execute(subquery.query)) + } + } } +/** + * Non-abstract representation of the standard Spark optimizing strategies + * + * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while + * specific rules go to the subclasses + */ +object DefaultOptimizer extends Optimizer + /** * Pushes operations down into a Sample. */ @@ -77,23 +128,112 @@ object SamplePushDown extends Rule[LogicalPlan] { // Push down projection into sample case Project(projectList, s @ Sample(lb, up, replace, seed, child)) => Sample(lb, up, replace, seed, - Project(projectList, child)) + Project(projectList, child))() + } +} + +/** + * Removes cases where we are unnecessarily going between the object and serialized (InternalRow) + * representation of data item. For example back to back map operations. + */ +object EliminateSerialization extends Rule[LogicalPlan] { + // TODO: find a more general way to do this optimization. + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => + val childWithoutSerialization = child.withObjectOutput + m.copy( + deserializer = childWithoutSerialization.output.head, + child = childWithoutSerialization) + + case m @ MapElements(_, deserializer, _, child: ObjectOperator) + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => + val childWithoutSerialization = child.withObjectOutput + m.copy( + deserializer = childWithoutSerialization.output.head, + child = childWithoutSerialization) + + case d @ DeserializeToObject(_, s: SerializeFromObject) + if d.outputObjectType == s.inputObjectType => + // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. + val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) + Project(objAttr :: Nil, s.child) + } +} + +/** + * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. + */ +object LimitPushDown extends Rule[LogicalPlan] { + + private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { + plan match { + case GlobalLimit(expr, child) => child + case _ => plan + } + } + + private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { + (limitExp, plan.maxRows) match { + case (IntegerLiteral(maxRow), Some(childMaxRows)) if maxRow < childMaxRows => + LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case (_, None) => + LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case _ => plan + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Adding extra Limits below UNION ALL for children which are not Limit or do not have Limit + // descendants whose maxRow is larger. This heuristic is valid assuming there does not exist any + // Limit push-down rule that is unable to infer the value of maxRows. + // Note: right now Union means UNION ALL, which does not de-duplicate rows, so it is safe to + // pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to + // pushdown Limit. + case LocalLimit(exp, Union(children)) => + LocalLimit(exp, Union(children.map(maybePushLimit(exp, _)))) + // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the + // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side + // because we need to ensure that rows from the limited side still have an opportunity to match + // against all candidates from the non-limited side. We also need to ensure that this limit + // pushdown rule will not eventually introduce limits on both sides if it is applied multiple + // times. Therefore: + // - If one side is already limited, stack another limit on top if the new limit is smaller. + // The redundant limit will be collapsed by the CombineLimits rule. + // - If neither side is limited, limit the side that is estimated to be bigger. + case LocalLimit(exp, join @ Join(left, right, joinType, condition)) => + val newJoin = joinType match { + case RightOuter => join.copy(right = maybePushLimit(exp, right)) + case LeftOuter => join.copy(left = maybePushLimit(exp, left)) + case FullOuter => + (left.maxRows, right.maxRows) match { + case (None, None) => + if (left.statistics.sizeInBytes >= right.statistics.sizeInBytes) { + join.copy(left = maybePushLimit(exp, left)) + } else { + join.copy(right = maybePushLimit(exp, right)) + } + case (Some(_), Some(_)) => join + case (Some(_), None) => join.copy(left = maybePushLimit(exp, left)) + case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right)) + + } + case _ => join + } + LocalLimit(exp, newJoin) } } /** - * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Pushes certain operations to both sides of a Union or Except operator. * Operations that are safe to pushdown are listed as follows. * Union: * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, * we will not be able to pushdown Projections. * - * Intersect: - * It is not safe to pushdown Projections through it because we need to get the - * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * with deterministic condition. - * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters @@ -104,16 +244,14 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. */ - private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = { - assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) - assert(bn.left.output.size == bn.right.output.size) - - AttributeMap(bn.left.output.zip(bn.right.output)) + private def buildRewrites(left: LogicalPlan, right: LogicalPlan): AttributeMap[Attribute] = { + assert(left.output.size == right.output.size) + AttributeMap(left.output.zip(right.output)) } /** * Rewrites an expression so that it can be pushed to the right side of a - * Union, Intersect or Except operator. This method relies on the fact that the output attributes + * Union or Except operator. This method relies on the fact that the output attributes * of a union/intersect/except are always equal to the left child's output. */ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { @@ -142,43 +280,38 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down filter into union - case Filter(condition, u @ Union(left, right)) => - val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(u) - Filter(nondeterministic, - Union( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) // Push down deterministic projection through UNION ALL - case p @ Project(projectList, u @ Union(left, right)) => + case p @ Project(projectList, Union(children)) => + assert(children.nonEmpty) if (projectList.forall(_.deterministic)) { - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + val newFirstChild = Project(projectList, children.head) + val newOtherChildren = children.tail.map { child => + val rewrites = buildRewrites(children.head, child) + Project(projectList.map(pushToRight(_, rewrites)), child) + } + Union(newFirstChild +: newOtherChildren) } else { p } - // Push down filter through INTERSECT - case Filter(condition, i @ Intersect(left, right)) => + // Push down filter into union + case Filter(condition, Union(children)) => + assert(children.nonEmpty) val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(i) - Filter(nondeterministic, - Intersect( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) + val newFirstChild = Filter(deterministic, children.head) + val newOtherChildren = children.tail.map { + child => { + val rewrites = buildRewrites(children.head, child) + Filter(pushToRight(deterministic, rewrites), child) + } + } + Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) // Push down filter through EXCEPT - case Filter(condition, e @ Except(left, right)) => + case Filter(condition, Except(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(e) + val rewrites = buildRewrites(left, right) Filter(nondeterministic, Except( Filter(deterministic, left), @@ -189,140 +322,189 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } /** - * Attempts to eliminate the reading of unneeded columns from the query plan using the following - * transformations: + * Attempts to eliminate the reading of unneeded columns from the query plan. + * + * Since adding Project before Filter conflicts with PushPredicatesThroughProject, this rule will + * remove the Project p2 in the following pattern: + * + * p1 @ Project(_, Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(p2.inputSet) * - * - Inserting Projections beneath the following operators: - * - Aggregate - * - Generate - * - Project <- Join - * - LeftSemiJoin + * p2 is usually inserted by this rule and useless, p1 could prune the columns anyway. */ object ColumnPruning extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child)) - if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references))) + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + + def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform { + // Prunes the unused columns from project list of Project/Aggregate/Expand + case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty => + p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) + case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty => + p.copy( + child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains))) + case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => + val newOutput = e.output.filter(a.references.contains(_)) + val newProjects = e.projections.map { proj => + proj.zip(e.output).filter { case (e, a) => + newOutput.contains(a) + }.unzip._1 + } + a.copy(child = Expand(newProjects, newOutput, grandChild)) - // Eliminate attributes that are not needed to calculate the specified aggregates. - case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => - a.copy(child = Project(a.references.toSeq, child)) + // Prunes the unused columns from child of MapPartitions + case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty => + mp.copy(child = prunedChild(child, mp.references)) - // Eliminate attributes that are not needed to calculate the Generate. + // Prunes the unused columns from child of Aggregate/Expand/Generate + case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => + a.copy(child = prunedChild(child, a.references)) + case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => + e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => - g.copy(child = Project(g.references.toSeq, g.child)) + g.copy(child = prunedChild(g.child, g.references)) + // Turn off `join` for Generate if no column from it's child is used case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - case p @ Project(projectList, g: Generate) if g.join => - val neededChildOutput = p.references -- g.generatorOutput ++ g.references - if (neededChildOutput == g.child.outputSet) { - p + // Eliminate unneeded attributes from right side of a Left Existence Join. + case j @ Join(left, right, LeftExistence(_), condition) => + j.copy(right = prunedChild(right, j.references)) + + // all the columns will be used to compare, so we can't prune them + case p @ Project(_, _: SetOperation) => p + case p @ Project(_, _: Distinct) => p + // Eliminate unneeded attributes from children of Union. + case p @ Project(_, u: Union) => + if ((u.outputSet -- p.references).nonEmpty) { + val firstChild = u.children.head + val newOutput = prunedChild(firstChild, p.references).output + // pruning the columns of all children based on the pruned first child. + val newChildren = u.children.map { p => + val selected = p.output.zipWithIndex.filter { case (a, i) => + newOutput.contains(firstChild.output(i)) + }.map(_._1) + Project(selected, p) + } + p.copy(child = u.withNewChildren(newChildren)) } else { - Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) + p } - case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) - if (a.outputSet -- p.references).nonEmpty => - Project( - projectList, - Aggregate( - groupingExpressions, - aggregateExpressions.filter(e => p.references.contains(e)), - child)) - - // Eliminate unneeded attributes from either side of a Join. - case Project(projectList, Join(left, right, joinType, condition)) => - // Collect the list of all references required either above or to evaluate the condition. - val allReferences: AttributeSet = - AttributeSet( - projectList.flatMap(_.references.iterator)) ++ - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) - - Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) - - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case Join(left, right, LeftSemi, condition) => - // Collect the list of all references required to evaluate the condition. - val allReferences: AttributeSet = - condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - - Join(left, prunedChild(right, allReferences), LeftSemi, condition) - - // Push down project through limit, so that we may have chance to push it further. - case Project(projectList, Limit(exp, child)) => - Limit(exp, Project(projectList, child)) - - // Push down project if possible when the child is sort. - case p @ Project(projectList, s @ Sort(_, _, grandChild)) => - if (s.references.subsetOf(p.outputSet)) { - s.copy(child = Project(projectList, grandChild)) - } else { - val neededReferences = s.references ++ p.references - if (neededReferences == grandChild.outputSet) { - // No column we can prune, return the original plan. - p - } else { - // Do not use neededReferences.toSeq directly, should respect grandChild's output order. - val newProjectList = grandChild.output.filter(neededReferences.contains) - p.copy(child = s.copy(child = Project(newProjectList, grandChild))) - } - } + // Prune unnecessary window expressions + case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty => + p.copy(child = w.copy( + windowExpressions = w.windowExpressions.filter(p.references.contains))) + + // Eliminate no-op Window + case w: Window if w.windowExpressions.isEmpty => w.child // Eliminate no-op Projects - case Project(projectList, child) if child.output == projectList => child - } + case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + + // Can't prune the columns on LeafNode + case p @ Project(_, l: LeafNode) => p + + // for all other logical plans that inherits the output from it's children + case p @ Project(_, child) => + val required = child.references ++ p.references + if ((child.inputSet -- required).nonEmpty) { + val newChildren = child.children.map(c => prunedChild(c, required)) + p.copy(child = child.withNewChildren(newChildren)) + } else { + p + } + }) /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - Project(allReferences.filter(c.outputSet.contains).toSeq, c) + Project(c.output.filter(allReferences.contains), c) } else { c } + + /** + * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject, + * so remove it. + */ + private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { + case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) + if p2.outputSet.subsetOf(child.outputSet) => + p1.copy(child = f.copy(child = child)) + } } /** * Combines two adjacent [[Project]] operators into one and perform alias substitution, * merging the expressions into one single expression. */ -object ProjectCollapsing extends Rule[LogicalPlan] { +object CollapseProject extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p @ Project(projectList1, Project(projectList2, child)) => - // Create a map of Aliases to their values from the child projection. - // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliasMap = AttributeMap(projectList2.collect { - case a: Alias => (a.toAttribute, a) - }) - - // We only collapse these two Projects if their overlapped expressions are all - // deterministic. - val hasNondeterministic = projectList1.exists(_.collect { - case a: Attribute if aliasMap.contains(a) => aliasMap(a).child - }.exists(!_.deterministic)) - - if (hasNondeterministic) { + case p1 @ Project(_, p2: Project) => + if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) { + p1 + } else { + p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) + } + case p @ Project(_, agg: Aggregate) => + if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { p } else { - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - // TODO: Fix TransformBase to avoid the cast below. - val substitutedProjection = projectList1.map(_.transform { - case a: Attribute => aliasMap.getOrElse(a, a) - }).asInstanceOf[Seq[NamedExpression]] - // collapse 2 projects may introduce unnecessary Aliases, trim them here. - val cleanedProjection = substitutedProjection.map(p => - CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] - ) - Project(cleanedProjection, child) + agg.copy(aggregateExpressions = buildCleanedProjectList( + p.projectList, agg.aggregateExpressions)) } } + + private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = { + AttributeMap(projectList.collect { + case a: Alias => a.toAttribute -> a + }) + } + + private def haveCommonNonDeterministicOutput( + upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { + // Create a map of Aliases to their values from the lower projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliases = collectAliases(lower) + + // Collapse upper and lower Projects if and only if their overlapped expressions are all + // deterministic. + upper.exists(_.collect { + case a: Attribute if aliases.contains(a) => aliases(a).child + }.exists(!_.deterministic)) + } + + private def buildCleanedProjectList( + upper: Seq[NamedExpression], + lower: Seq[NamedExpression]): Seq[NamedExpression] = { + // Create a map of Aliases to their values from the lower projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliases = collectAliases(lower) + + // Substitute any attributes that are produced by the lower projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + val rewrittenUpper = upper.map(_.transform { + case a: Attribute => aliases.getOrElse(a, a) + }) + // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here. + rewrittenUpper.map { p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + } + } +} + +/** + * Combines adjacent [[Repartition]] operators by keeping only the last one. + */ +object CollapseRepartition extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case r @ Repartition(numPartitions, shuffle, Repartition(_, _, child)) => + Repartition(numPartitions, shuffle, child) + } } /** @@ -335,22 +517,28 @@ object LikeSimplification extends Rule[LogicalPlan] { // Cases like "something\%" are not optimized, but this does not affect correctness. private val startsWith = "([^_%]+)%".r private val endsWith = "%([^_%]+)".r + private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r private val contains = "%([^_%]+)%".r private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(utf, StringType)) => - utf.toString match { - case startsWith(pattern) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case endsWith(pattern) => - EndsWith(l, Literal(pattern)) - case contains(pattern) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case equalTo(pattern) => - EqualTo(l, Literal(pattern)) + case Like(input, Literal(pattern, StringType)) => + pattern.toString match { + case startsWith(prefix) if !prefix.endsWith("\\") => + StartsWith(input, Literal(prefix)) + case endsWith(postfix) => + EndsWith(input, Literal(postfix)) + // 'a%a' pattern is basically same with 'a%' && '%a'. + // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. + case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => + And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)), + And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) + case contains(infix) if !infix.endsWith("\\") => + Contains(input, Literal(infix)) + case equalTo(str) => + EqualTo(input, Literal(str)) case _ => - Like(l, Literal.create(utf, StringType)) + Like(input, Literal.create(pattern, StringType)) } } } @@ -361,9 +549,15 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { + private def nonNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => false + case _ => true + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) + case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => + Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) @@ -375,14 +569,13 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ Count(expr) if !expr.nullable => Count(Literal(1)) + case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => + // This rule should be only triggered when isDistinct field is false. + ae.copy(aggregateFunction = Count(Literal(1))) // For Coalesce, remove null literals. case e @ Coalesce(children) => - val newChildren = children.filter { - case Literal(null, _) => false - case _ => true - } + val newChildren = children.filter(nonNullLiteral) if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { @@ -426,6 +619,44 @@ object NullPropagation extends Rule[LogicalPlan] { } } +/** + * Generate a list of additional filters from an operator's existing constraint but remove those + * that are either already part of the operator's condition or are part of the operator's child + * constraints. These filters are currently inserted to the existing conditions in the Filter + * operators and on either side of Join operators. + * + * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and + * LeftSemi joins. + */ +object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case filter @ Filter(condition, child) => + val newFilters = filter.constraints -- + (child.constraints ++ splitConjunctivePredicates(condition)) + if (newFilters.nonEmpty) { + Filter(And(newFilters.reduce(And), condition), child) + } else { + filter + } + + case join @ Join(left, right, joinType, conditionOpt) => + // Only consider constraints that can be pushed down completely to either the left or the + // right child + val constraints = join.constraints.filter { c => + c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)} + // Remove those constraints that are already enforced by either the left or the right child + val additionalConstraints = constraints -- (left.constraints ++ right.constraints) + val newConditionOpt = conditionOpt match { + case Some(condition) => + val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) + if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else None + case None => + additionalConstraints.reduceOption(And) + } + if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join + } +} + /** * Replaces [[Expression Expressions]] that can be statically evaluated with * equivalent [[Literal]] values. @@ -468,132 +699,202 @@ object OptimizeIn extends Rule[LogicalPlan] { object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case and @ And(left, right) => (left, right) match { - // true && r => r - case (Literal(true, BooleanType), r) => r - // l && true => l - case (l, Literal(true, BooleanType)) => l - // false && r => false - case (Literal(false, BooleanType), _) => Literal(false) - // l && false => false - case (_, Literal(false, BooleanType)) => Literal(false) - // a && a => a - case (l, r) if l fastEquals r => l - // a && (not(a) || b) => a && b - case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r) - case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r) - case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r) - case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r) - // (a || b) && (a || c) => a || (b && c) - case _ => - // 1. Split left and right to get the disjunctive predicates, - // i.e. lhs = (a, b), rhs = (a, c) - // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) - val lhs = splitDisjunctivePredicates(left) - val rhs = splitDisjunctivePredicates(right) - val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) - if (common.isEmpty) { - // No common factors, return the original predicate - and + case TrueLiteral And e => e + case e And TrueLiteral => e + case FalseLiteral Or e => e + case e Or FalseLiteral => e + + case FalseLiteral And _ => FalseLiteral + case _ And FalseLiteral => FalseLiteral + case TrueLiteral Or _ => TrueLiteral + case _ Or TrueLiteral => TrueLiteral + + case a And b if a.semanticEquals(b) => a + case a Or b if a.semanticEquals(b) => a + + case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c) + case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b) + case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c) + case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c) + + case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c) + case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b) + case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c) + case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c) + + // Common factor elimination for conjunction + case and @ (left And right) => + // 1. Split left and right to get the disjunctive predicates, + // i.e. lhs = (a, b), rhs = (a, c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) + val lhs = splitDisjunctivePredicates(left) + val rhs = splitDisjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals)) + if (common.isEmpty) { + // No common factors, return the original predicate + and + } else { + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals)) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) + if (ldiff.isEmpty || rdiff.isEmpty) { + // (a || b || c || ...) && (a || b) => (a || b) + common.reduce(Or) } else { - val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_))) - val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_))) - if (ldiff.isEmpty || rdiff.isEmpty) { - // (a || b || c || ...) && (a || b) => (a || b) - common.reduce(Or) - } else { - // (a || b || c || ...) && (a || b || d || ...) => - // ((c || ...) && (d || ...)) || a || b - (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) - } + // (a || b || c || ...) && (a || b || d || ...) => + // ((c || ...) && (d || ...)) || a || b + (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) } - } // end of And(left, right) - - case or @ Or(left, right) => (left, right) match { - // true || r => true - case (Literal(true, BooleanType), _) => Literal(true) - // r || true => true - case (_, Literal(true, BooleanType)) => Literal(true) - // false || r => r - case (Literal(false, BooleanType), r) => r - // l || false => l - case (l, Literal(false, BooleanType)) => l - // a || a => a - case (l, r) if l fastEquals r => l - // (a && b) || (a && c) => a && (b || c) - case _ => - // 1. Split left and right to get the conjunctive predicates, - // i.e. lhs = (a, b), rhs = (a, c) - // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) - // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) - // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) - val lhs = splitConjunctivePredicates(left) - val rhs = splitConjunctivePredicates(right) - val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) - if (common.isEmpty) { - // No common factors, return the original predicate - or + } + + // Common factor elimination for disjunction + case or @ (left Or right) => + // 1. Split left and right to get the conjunctive predicates, + // i.e. lhs = (a, b), rhs = (a, c) + // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a) + // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) + // 4. Apply the formula, get the optimized predicate: common && (ldiff || rdiff) + val lhs = splitConjunctivePredicates(left) + val rhs = splitConjunctivePredicates(right) + val common = lhs.filter(e => rhs.exists(e.semanticEquals)) + if (common.isEmpty) { + // No common factors, return the original predicate + or + } else { + val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals)) + val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals)) + if (ldiff.isEmpty || rdiff.isEmpty) { + // (a && b) || (a && b && c && ...) => a && b + common.reduce(And) } else { - val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals(_))) - val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals(_))) - if (ldiff.isEmpty || rdiff.isEmpty) { - // (a && b) || (a && b && c && ...) => a && b - common.reduce(And) - } else { - // (a && b && c && ...) || (a && b && d && ...) => - // ((c && ...) || (d && ...)) && a && b - (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) - } + // (a && b && c && ...) || (a && b && d && ...) => + // ((c && ...) || (d && ...)) && a && b + (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And) } - } // end of Or(left, right) - - case not @ Not(exp) => exp match { - // not(true) => false - case Literal(true, BooleanType) => Literal(false) - // not(false) => true - case Literal(false, BooleanType) => Literal(true) - // not(l > r) => l <= r - case GreaterThan(l, r) => LessThanOrEqual(l, r) - // not(l >= r) => l < r - case GreaterThanOrEqual(l, r) => LessThan(l, r) - // not(l < r) => l >= r - case LessThan(l, r) => GreaterThanOrEqual(l, r) - // not(l <= r) => l > r - case LessThanOrEqual(l, r) => GreaterThan(l, r) - // not(l || r) => not(l) && not(r) - case Or(l, r) => And(Not(l), Not(r)) - // not(l && r) => not(l) or not(r) - case And(l, r) => Or(Not(l), Not(r)) - // not(not(e)) => e - case Not(e) => e - case _ => not - } // end of Not(exp) - - // if (true) a else b => a - // if (false) a else b => b - case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue + } + + case Not(TrueLiteral) => FalseLiteral + case Not(FalseLiteral) => TrueLiteral + + case Not(a GreaterThan b) => LessThanOrEqual(a, b) + case Not(a GreaterThanOrEqual b) => LessThan(a, b) + + case Not(a LessThan b) => GreaterThanOrEqual(a, b) + case Not(a LessThanOrEqual b) => GreaterThan(a, b) + + case Not(a Or b) => And(Not(a), Not(b)) + case Not(a And b) => Or(Not(a), Not(b)) + + case Not(Not(e)) => e } } } /** - * Combines two adjacent [[Filter]] operators into one, merging the - * conditions into one conjunctive predicate. + * Simplifies binary comparisons with semantically-equal expressions: + * 1) Replace '<=>' with 'true' literal. + * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable. + * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable. */ -object CombineFilters extends Rule[LogicalPlan] { +object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) + case q: LogicalPlan => q transformExpressionsUp { + // True with equality + case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral + case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => + TrueLiteral + case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + + // False with inequality + case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + } } } /** - * Removes filters that can be evaluated trivially. This is done either by eliding the filter for - * cases where it will always evaluate to `true`, or substituting a dummy empty relation when the - * filter will always evaluate to `false`. + * Simplifies conditional expressions (if / case). */ -object SimplifyFilters extends Rule[LogicalPlan] { +object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { + private def falseOrNullLiteral(e: Expression): Boolean = e match { + case FalseLiteral => true + case Literal(null, _) => true + case _ => false + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case If(TrueLiteral, trueValue, _) => trueValue + case If(FalseLiteral, _, falseValue) => falseValue + case If(Literal(null, _), _, falseValue) => falseValue + + case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => + // If there are branches that are always false, remove them. + // If there are no more branches left, just use the else value. + // Note that these two are handled together here in a single case statement because + // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null). + val newBranches = branches.filter(x => !falseOrNullLiteral(x._1)) + if (newBranches.isEmpty) { + elseValue.getOrElse(Literal.create(null, e.dataType)) + } else { + e.copy(branches = newBranches) + } + + case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) => + // If the first branch is a true literal, remove the entire CaseWhen and use the value + // from that. Note that CaseWhen.branches should never be empty, and as a result the + // headOption (rather than head) added above is just a extra (and unnecessary) safeguard. + branches.head._2 + } + } +} + +/** + * Combines all adjacent [[Union]] operators into a single [[Union]]. + */ +object CombineUnions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Unions(children) => Union(children) + } +} + +/** + * Combines two adjacent [[Filter]] operators into one, merging the non-redundant conditions into + * one conjunctive predicate. + */ +object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => + (ExpressionSet(splitConjunctivePredicates(fc)) -- + ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match { + case Some(ac) => + Filter(And(ac, nc), grandChild) + case None => + nf + } + } +} + +/** + * Removes no-op SortOrder from Sort + */ +object EliminateSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) => + val newOrders = orders.filterNot(_.child.foldable) + if (newOrders.isEmpty) child else s.copy(order = newOrders) + } +} + +/** + * Removes filters that can be evaluated trivially. This can be done through the following ways: + * 1) by eliding the filter for cases where it will always evaluate to `true`. + * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`. + * 3) by eliminating the always-true conditions given the constraints on the child's output. + */ +object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child @@ -601,101 +902,244 @@ object SimplifyFilters extends Rule[LogicalPlan] { // replace the input with an empty relation. case Filter(Literal(null, _), child) => LocalRelation(child.output, data = Seq.empty) case Filter(Literal(false, BooleanType), child) => LocalRelation(child.output, data = Seq.empty) + // If any deterministic condition is guaranteed to be true given the constraints on the child's + // output, remove the condition + case f @ Filter(fc, p: LogicalPlan) => + val (prunedPredicates, remainingPredicates) = + splitConjunctivePredicates(fc).partition { cond => + cond.deterministic && p.constraints.contains(cond) + } + if (prunedPredicates.isEmpty) { + f + } else if (remainingPredicates.isEmpty) { + p + } else { + val newCond = remainingPredicates.reduce(And) + Filter(newCond, p) + } } } /** - * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]] - * that were defined in the projection. + * Pushes [[Filter]] operators through many operators iff: + * 1) the operator is deterministic + * 2) the predicate is deterministic and the operator will not change any of rows. * * This heuristic is valid assuming the expression evaluation cost is minimal. */ -object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { +object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, project @ Project(fields, grandChild)) => + // SPARK-13473: We can't push the predicate down when the underlying projection output non- + // deterministic field(s). Non-deterministic expressions are essentially stateful. This + // implies that, for a given input row, the output are determined by the expression's initial + // state and all the input rows processed before. In another word, the order of input rows + // matters for non-deterministic expressions, while pushing down predicates changes the order. + case filter @ Filter(condition, project @ Project(fields, grandChild)) + if fields.forall(_.deterministic) => + // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). val aliasMap = AttributeMap(fields.collect { case a: Alias => (a.toAttribute, a.child) }) - // Split the condition into small conditions by `And`, so that we can push down part of this - // condition without nondeterministic expressions. - val andConditions = splitConjunctivePredicates(condition) + project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) + + case filter @ Filter(condition, aggregate: Aggregate) => + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) + }) - val (deterministic, nondeterministic) = andConditions.partition(_.collect { - case a: Attribute if aliasMap.contains(a) => aliasMap(a) - }.forall(_.deterministic)) + // For each filter, expand the alias and check if the filter can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + val replaced = replaceAlias(cond, aliasMap) + replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic + } - // If there is no nondeterministic conditions, push down the whole condition. - if (nondeterministic.isEmpty) { - project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child)) + // If there is no more filter to stay up, just eliminate the filter. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)". + if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) } else { - // If they are all nondeterministic conditions, leave it un-changed. - if (deterministic.isEmpty) { - filter + filter + } + + case filter @ Filter(condition, child) + if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] => + // Union/Intersect could change the rows, so non-deterministic predicate can't be pushed down + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.deterministic + } + if (pushDown.nonEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = child.output + val newGrandChildren = child.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet)) + Filter(newCond, grandchild) + } + val newChild = child.withNewChildren(newGrandChildren) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) } else { - // Push down the small conditions without nondeterministic expressions. - val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And) - Filter(nondeterministic.reduce(And), - project.copy(child = Filter(pushedCondition, grandChild))) + newChild } + } else { + filter + } + + case filter @ Filter(condition, e @ Except(left, _)) => + pushDownPredicate(filter, e.left) { predicate => + e.copy(left = Filter(predicate, left)) + } + + // two filters should be combine together by other rules + case filter @ Filter(_, f: Filter) => filter + // should not push predicates through sample, or will generate different results. + case filter @ Filter(_, s: Sample) => filter + // TODO: push predicates through expand + case filter @ Filter(_, e: Expand) => filter + + case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => + pushDownPredicate(filter, u.child) { predicate => + u.withNewChildren(Seq(Filter(predicate, u.child))) } } - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. - private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { - condition.transform { - case a: Attribute => sourceAliases.getOrElse(a, a) + private def pushDownPredicate( + filter: Filter, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + // Only push down the predicates that is deterministic and all the referenced attributes + // come from grandchild. + // TODO: non-deterministic predicates could be pushed through some operators that do not change + // the rows. + val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond => + cond.deterministic && cond.references.subsetOf(grandchild.outputSet) + } + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter } } } /** - * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference - * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath. + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. */ -object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, g: Generate) => - // Predicates that reference attributes produced by the `Generate` operator cannot - // be pushed below the operator. - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf g.child.outputSet - } - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, - g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) - } else { - filter + /** + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ + @tailrec + def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + Join(input(0), input(1), Inner, conditions.reduceLeftOption(And)) + } else { + val left :: rest = input.toList + // find out the first join that have at least one join condition + val conditionalJoin = rest.find { plan => + val refs = left.outputSet ++ plan.outputSet + conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) + .exists(_.references.subsetOf(refs)) } + // pick the next one if no condition left + val right = conditionalJoin.getOrElse(rest.head) + + val joinedRefs = left.outputSet ++ right.outputSet + val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs)) + val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + createOrderedJoin(input, conditions) } } /** - * Push [[Filter]] operators through [[Aggregate]] operators. Parts of the predicate that reference - * attributes which are subset of group by attribute set of [[Aggregate]] will be pushed beneath, - * and the rest should remain above. + * Elimination of outer joins, if the predicates can restrict the result sets so that + * all null-supplying rows are eliminated + * + * - full outer -> inner if both sides have such predicates + * - left outer -> inner if the right side has such predicates + * - right outer -> inner if the left side has such predicates + * - full outer -> left outer if only the left side has such predicates + * - full outer -> right outer if only the right side has such predicates + * + * This rule should be executed before pushing down the Filter */ -object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { +object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Returns whether the expression returns null or false when all inputs are nulls. + */ + private def canFilterOutNull(e: Expression): Boolean = { + if (!e.deterministic) return false + val attributes = e.references.toSeq + val emptyRow = new GenericInternalRow(attributes.length) + val v = BindReferences.bindReference(e, attributes).eval(emptyRow) + v == null || v == false + } + + private def buildNewJoinType(filter: Filter, join: Join): JoinType = { + val splitConjunctiveConditions: Seq[Expression] = splitConjunctivePredicates(filter.condition) + val leftConditions = splitConjunctiveConditions + .filter(_.references.subsetOf(join.left.outputSet)) + val rightConditions = splitConjunctiveConditions + .filter(_.references.subsetOf(join.right.outputSet)) + + val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) || + filter.constraints.filter(_.isInstanceOf[IsNotNull]) + .exists(expr => join.left.outputSet.intersect(expr.references).nonEmpty) + val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) || + filter.constraints.filter(_.isInstanceOf[IsNotNull]) + .exists(expr => join.right.outputSet.intersect(expr.references).nonEmpty) + + join.joinType match { + case RightOuter if leftHasNonNullPredicate => Inner + case LeftOuter if rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate => LeftOuter + case FullOuter if rightHasNonNullPredicate => RightOuter + case o => o + } + } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, - aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) => - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf AttributeSet(groupingExpressions) - } - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val withPushdown = aggregate.copy(child = Filter(pushDownPredicate, grandChild)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) - } else { - filter - } + case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => + val newJoinType = buildNewJoinType(f, j) + if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) } } @@ -704,7 +1148,7 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel * evaluated using only the attributes of the left or right side of a join. Other * [[Filter]] conditions are moved into the `condition` of the [[Join]]. * - * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * And also pushes down the join filter, where the `condition` can be evaluated using only the * attributes of the left or right side of sub query when applicable. * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details @@ -713,6 +1157,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions into three categories based on the attributes required * to evaluate them. + * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { @@ -750,7 +1195,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) - case _ @ (LeftOuter | LeftSemi) => + case LeftOuter | LeftExistence(_) => // push down the left side only `where` condition val newLeft = leftFilterConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -761,6 +1206,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (rightFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) case FullOuter => f // DO Nothing for Full Outer Join + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") + case UsingJoin(_, _) => sys.error("Untransformed Using join node") } // push down the join filter into sub query scanning if applicable @@ -769,7 +1216,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case _ @ (Inner | LeftSemi) => + case Inner | LeftExistence(_) => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -795,6 +1242,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { Join(newLeft, newRight, LeftOuter, newJoinCond) case FullOuter => f + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") + case UsingJoin(_, _) => sys.error("Untransformed Using join node") } } } @@ -811,7 +1260,7 @@ object SimplifyCasts extends Rule[LogicalPlan] { /** * Removes nodes that are not necessary. */ -object RemoveDispensable extends Rule[LogicalPlan] { +object RemoveDispensableExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case UnaryPositive(child) => child case PromotePrecision(child) => child @@ -824,8 +1273,12 @@ object RemoveDispensable extends Rule[LogicalPlan] { */ object CombineLimits extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ll @ GlobalLimit(le, nl @ GlobalLimit(ne, grandChild)) => + GlobalLimit(Least(Seq(ne, le)), grandChild) + case ll @ LocalLimit(le, nl @ LocalLimit(ne, grandChild)) => + LocalLimit(Least(Seq(ne, le)), grandChild) case ll @ Limit(le, nl @ Limit(ne, grandChild)) => - Limit(If(LessThan(ne, le), ne, le), grandChild) + Limit(Least(Seq(ne, le)), grandChild) } } @@ -848,7 +1301,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values. * * This uses the same rules for increasing the precision and scale of the output as - * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]]. + * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ object DecimalAggregates extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS @@ -857,12 +1310,15 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) + if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) + if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( - Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } @@ -893,6 +1349,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator. + * {{{ + * SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2 + * ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2 + * }}} + * + * Note: + * 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL. + * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated + * join conditions will be incorrect. + */ +object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right) => + assert(left.output.size == right.output.size) + val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } + Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. @@ -904,3 +1381,47 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { a.copy(groupingExpressions = newGrouping) } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} + +/** + * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a + * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed + * the deserializer in filter condition to save the extra serialization at last. + */ +object EmbedSerializerInFilter extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) => + val numObjects = condition.collect { + case a: Attribute if a == d.output.head => a + }.length + + if (numObjects > 1) { + // If the filter condition references the object more than one times, we should not embed + // deserializer in it as the deserialization will happen many times and slow down the + // execution. + // TODO: we can still embed it if we can make sure subexpression elimination works here. + s + } else { + val newCondition = condition transform { + case a: Attribute if a == d.output.head => d.deserializer.child + } + Filter(newCondition, d.child) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala new file mode 100644 index 0000000000000..aa59f3fb2a4a4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -0,0 +1,1455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or + * TableIdentifier. + */ +class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { + import ParserUtils._ + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + /** + * Override the default behavior for all visit methods. This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. + */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + visit(ctx.dataType).asInstanceOf[DataType] + } + + /* ******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Make sure we do not try to create a plan for a native command. + */ + override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null + + /** + * Create a plan for a SHOW FUNCTIONS command. + */ + override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { + import ctx._ + if (qualifiedName != null) { + val names = qualifiedName().identifier().asScala.map(_.getText).toList + names match { + case db :: name :: Nil => + ShowFunctions(Some(db), Some(name)) + case name :: Nil => + ShowFunctions(None, Some(name)) + case _ => + throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) + } + } else if (pattern != null) { + ShowFunctions(None, Some(string(pattern))) + } else { + ShowFunctions(None, None) + } + } + + /** + * Create a plan for a DESCRIBE FUNCTION command. + */ + override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { + val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".") + DescribeFunction(functionName, ctx.EXTENDED != null) + } + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryNoWith) + + // Apply CTEs + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { + case nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + + // Check for duplicate names. + ctes.groupBy(_._1).filter(_._2.size > 1).foreach { + case (name, _) => + throw new ParseException( + s"Name '$name' is used for multiple common table expressions", ctx) + } + + With(query, ctes.toMap) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith)) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { + body => + assert(body.querySpecification.fromClause == null, + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", + body) + + withQuerySpecification(body.querySpecification, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) + } + + // If there are multiple INSERTS just UNION them together into one query. + inserts match { + case Seq(query) => query + case queries => Union(queries) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) + } + + /** + * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan. + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + InsertIntoTable( + UnresolvedRelation(tableIdent, None), + partitionKeys, + query, + ctx.OVERWRITE != null, + ctx.EXISTS != null) + } + + /** + * Create a partition specification map. + */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText.toLowerCase + val value = Option(pVal.constant).map(visitStringConstant) + name -> value + }.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. + */ + protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { + ctx match { + case s: StringLiteralContext => createString(s) + case o => o.getText + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem), global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem), global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + RepartitionByExpression(expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem), + global = false, + RepartitionByExpression(expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + RepartitionByExpression(expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windows)(withWindows) + + // LIMIT + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a logical plan using a query specification. + */ + override def visitQuerySpecification( + ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation.optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withQuerySpecification(ctx, from) + } + + /** + * Add a query specification to a logical plan. The query specification is the core of the logical + * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), + * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withQuerySpecification( + ctx: QuerySpecificationContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // WHERE + def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx), plan) + } + + // Expressions. + val expressions = Option(namedExpressionSeq).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + + // Create either a transform or a regular query. + val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT) + specType match { + case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM => + // Transform + + // Add where. + val withFilter = relation.optionalMap(where)(filter) + + // Create the attributes. + val (attributes, schemaLess) = if (colTypeList != null) { + // Typed return columns. + (createStructType(colTypeList).toAttributes, false) + } else if (identifierSeq != null) { + // Untyped return columns. + val attrs = visitIdentifierSeq(identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + // Create the transform. + ScriptTransformation( + expressions, + string(script), + attributes, + withFilter, + withScriptIOSchema( + ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + + case SqlBaseParser.SELECT => + // Regular select + + // Add lateral views. + val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(where)(filter) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + val withProject = if (aggregation != null) { + withAggregation(aggregation, namedExpressions, withFilter) + } else if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + // Having + val withHaving = withProject.optional(having) { + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(expression(having), BooleanType), withProject) + } + + // Distinct + val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { + Distinct(withHaving) + } else { + withHaving + } + + // Window + withDistinct.optionalMap(windows)(withWindows) + } + } + + /** + * Create a (Hive based) [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + ctx: QuerySpecificationContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { + throw new ParseException("Script Transform is not supported", ctx) + } + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. + */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + + /** + * Connect two queries by a Set operator. + * + * Supported Set operators are: + * - UNION [DISTINCT] + * - UNION ALL + * - EXCEPT [DISTINCT] + * - INTERSECT [DISTINCT] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case SqlBaseParser.UNION if all => + Union(left, right) + case SqlBaseParser.UNION => + Distinct(Union(left, right)) + case SqlBaseParser.INTERSECT if all => + throw new ParseException("INTERSECT ALL is not supported.", ctx) + case SqlBaseParser.INTERSECT => + Intersect(left, right) + case SqlBaseParser.EXCEPT if all => + throw new ParseException("EXCEPT ALL is not supported.", ctx) + case SqlBaseParser.EXCEPT => + Except(left, right) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindows( + ctx: WindowsContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowMap = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + }.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity), query) + } + + /** + * Add an [[Aggregate]] to a logical plan. + */ + private def withAggregation( + ctx: AggregationContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + val groupByExpressions = expressionList(groupingExpressions) + + if (GROUPING != null) { + // GROUP BY .... GROUPING SETS (...) + val expressionMap = groupByExpressions.zipWithIndex.toMap + val numExpressions = expressionMap.size + val mask = (1 << numExpressions) - 1 + val masks = ctx.groupingSet.asScala.map { + _.expression.asScala.foldLeft(mask) { + case (bitmap, eCtx) => + // Find the index of the expression. + val e = typedVisit[Expression](eCtx) + val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( + throw new ParseException( + s"$e doesn't show up in the GROUP BY list", ctx)) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + bitmap & ~(1 << (numExpressions - 1 - index)) + } + } + GroupingSets(masks, groupByExpressions, query, selectExpressions) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (CUBE != null) { + Seq(Cube(groupByExpressions)) + } else if (ROLLUP != null) { + Seq(Rollup(groupByExpressions)) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. + */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + + // Create the generator. + val generator = ctx.qualifiedName.getText.toLowerCase match { + case "explode" if expressions.size == 1 => + Explode(expressions.head) + case "json_tuple" => + JsonTuple(expressions) + case name => + UnresolvedGenerator(name, expressions) + } + + Generate( + generator, + join = true, + outer = ctx.OUTER != null, + Some(ctx.tblName.getText.toLowerCase), + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), + query) + } + + /** + * Create a joins between two or more logical plans. + */ + override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { + /** Build a join between two plans. */ + def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { + val baseJoinType = ctx.joinType match { + case null => Inner + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if ctx.NATURAL != null => + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, right, joinType, condition) + } + + // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the + // first join clause is at the top. However fields of previously referenced tables can be used + // in following join clauses. The tree needs to be reversed in order to make this work. + var result = plan(ctx.left) + var current = ctx + while (current != null) { + current.right match { + case right: JoinRelationContext => + result = join(current, result, plan(right.left)) + current = right + case right => + result = join(current, result, plan(right)) + current = null + } + } + result + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + val eps = RandomSampler.roundingEpsilon + assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + } + + ctx.sampleType.getType match { + case SqlBaseParser.ROWS => + Limit(expression(ctx.expression), query) + + case SqlBaseParser.PERCENTLIT => + val fraction = ctx.percentage.getText.toDouble + sample(fraction / 100.0d) + + case SqlBaseParser.BUCKET if ctx.ON != null => + throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) + + case SqlBaseParser.BUCKET => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val table = UnresolvedRelation( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.identifier).map(_.getText)) + table.optionalMap(ctx.sample)(withSample) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val expressions = ctx.expression.asScala.map { eCtx => + val e = expression(eCtx) + assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) + e + } + + // Validate and evaluate the rows. + val (structType, structConstructor) = expressions.head.dataType match { + case st: StructType => + (st, (e: Expression) => e) + case dt => + val st = CreateStruct(Seq(expressions.head)).dataType + (st, (e: Expression) => CreateStruct(Seq(e))) + } + val rows = expressions.map { + case expression => + val safe = Cast(structConstructor(expression), structType) + safe.eval().asInstanceOf[InternalRow] + } + + // Construct attributes. + val baseAttributes = structType.toAttributes.map(_.withNullability(true)) + val attributes = if (ctx.identifierList != null) { + val aliases = visitIdentifierList(ctx.identifierList) + assert(aliases.size == baseAttributes.size, + "Number of aliases must match the number of fields in an inline table.", ctx) + baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + } else { + baseAttributes + } + + // Create plan and add an alias if a name has been defined. + LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a LogicalPlan. + */ + private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. + */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText) + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create an expression from the given context. This method just passes the context on to the + * vistor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression) + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case SqlBaseParser.AND => And.apply _ + case SqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. + val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. + val expressions = contexts.reverse.map(expression) + + // Create a balanced tree. + def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query. This is not supported yet. + */ + override def visitExists(ctx: ExistsContext): Expression = { + throw new ParseException("EXISTS clauses are not supported.", ctx) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case SqlBaseParser.EQ => + EqualTo(left, right) + case SqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case SqlBaseParser.NEQ | SqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case SqlBaseParser.LT => + LessThan(left, right) + case SqlBaseParser.LTE => + LessThanOrEqual(left, right) + case SqlBaseParser.GT => + GreaterThan(left, right) + case SqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a predicated expression. A predicated expression is a normal expression with a + * predicate attached to it, for example: + * {{{ + * a + 1 IS NULL + * }}} + */ + override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + if (ctx.predicate != null) { + withPredicate(e, ctx.predicate) + } else { + e + } + } + + /** + * Add a predicate to the given expression. Supported expressions are: + * - (NOT) BETWEEN + * - (NOT) IN + * - (NOT) LIKE + * - (NOT) RLIKE + * - IS (NOT) NULL. + */ + private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { + // Invert a predicate if it has a valid NOT clause. + def invertIfNotDefined(e: Expression): Expression = ctx.NOT match { + case null => e + case not => Not(e) + } + + // Create the predicate. + ctx.kind.getType match { + case SqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case SqlBaseParser.IN if ctx.query != null => + throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + case SqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) + case SqlBaseParser.LIKE => + invertIfNotDefined(Like(e, expression(ctx.pattern))) + case SqlBaseParser.RLIKE => + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case SqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case SqlBaseParser.NULL => + IsNull(e) + } + } + + /** + * Create a binary arithmetic expression. The following arithmetic operators are supported: + * - Multiplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case SqlBaseParser.ASTERISK => + Multiply(left, right) + case SqlBaseParser.SLASH => + Divide(left, right) + case SqlBaseParser.PERCENT => + Remainder(left, right) + case SqlBaseParser.DIV => + Cast(Divide(left, right), LongType) + case SqlBaseParser.PLUS => + Add(left, right) + case SqlBaseParser.MINUS => + Subtract(left, right) + case SqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case SqlBaseParser.HAT => + BitwiseXor(left, right) + case SqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case SqlBaseParser.PLUS => + value + case SqlBaseParser.MINUS => + UnaryMinus(value) + case SqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + Cast(expression(ctx.expression), typedVisit(ctx.dataType)) + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + // Create the function call. + val name = ctx.qualifiedName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + val arguments = ctx.expression().asScala.map(expression) match { + case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). Move this to analysis? + Seq(Literal(1)) + case expressions => + expressions + } + val function = UnresolvedFunction(name, arguments, isDistinct) + + // Check if the function is evaluated in a windowed context. + ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + /** + * Create a reference to a window frame, i.e. [[WindowSpecReference]]. + */ + override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { + WindowSpecReference(ctx.identifier.getText) + } + + /** + * Create a window definition, i.e. [[WindowSpecDefinition]]. + */ + override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { + // CLUSTER BY ... | PARTITION BY ... ORDER BY ... + val partition = ctx.partition.asScala.map(expression) + val order = ctx.sortItem.asScala.map(visitSortItem) + + // RANGE/ROWS BETWEEN ... + val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case SqlBaseParser.RANGE => RangeFrame + case SqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition, + order, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value + * Preceding/Following boundaries. These expressions must be constant (foldable) and return an + * integer value. + */ + override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { + // We currently only allow foldable integers. + def value: Int = { + val e = expression(ctx.expression) + assert(e.resolved && e.foldable && e.dataType == IntegerType, + "Frame bound value must be a constant integer.", + ctx) + e.eval().asInstanceOf[Int] + } + + // Create the FrameBoundary + ctx.boundType.getType match { + case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case SqlBaseParser.PRECEDING => + ValuePreceding(value) + case SqlBaseParser.CURRENT => + CurrentRow + case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case SqlBaseParser.FOLLOWING => + ValueFollowing(value) + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.expression.asScala.map(expression)) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a dereference expression. The return type depends on the type of the parent, this can + * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an + * [[UnresolvedExtractValue]] if the parent is some expression. + */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ attr) + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression. + */ + override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { + UnresolvedAttribute.quoted(ctx.getText) + } + + /** + * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. + */ + override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { + UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) + } + + /** + * Create an expression for an expression between parentheses. This is need because the ANTLR + * visitor cannot automatically convert the nested context into an expression. + */ + override def visitParenthesizedExpression( + ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { + expression(ctx.expression) + } + + /** + * Create a [[SortOrder]] expression. + */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + if (ctx.DESC != null) { + SortOrder(expression(ctx.expression), Descending) + } else { + SortOrder(expression(ctx.expression), Ascending) + } + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date and Timestamp typed literals are supported. + * + * TODO what the added value of this over casting? + */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + ctx.identifier.getText.toUpperCase match { + case "DATE" => + Literal(Date.valueOf(value)) + case "TIMESTAMP" => + Literal(Timestamp.valueOf(value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. + */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue()) + case v if v.isValidLong => + Literal(v.longValue()) + case v => Literal(v.underlying()) + } + } + + /** + * Create a double literal for a number denoted in scientific notation. + */ + override def visitScientificDecimalLiteral( + ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(ctx.getText.toDouble) + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** Create a numeric literal expression. */ + private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { + val raw = ctx.getText + try { + Literal(f(raw.substring(0, raw.length - 1))) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { + _.toByte + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { + _.toShort + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { + _.toLong + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { + _.toDouble + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + ctx.STRING().asScala.map(string).mkString + } + + /** + * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple + * unit value pairs, for instance: interval 2 months 2 days. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val intervals = ctx.intervalField.asScala.map(visitIntervalField) + assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + Literal(intervals.reduce(_.add(_))) + } + + /** + * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are + * supported: + * - Single unit. + * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). + */ + override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { + import ctx._ + val s = value.getText + try { + val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + case (u, None) if u.endsWith("s") => + // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... + CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) + case (u, None) => + CalendarInterval.fromSingleUnitString(u, s) + case ("year", Some("month")) => + CalendarInterval.fromYearMonthString(s) + case ("day", Some("second")) => + CalendarInterval.fromDayTimeString(s) + case (from, Some(t)) => + throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) + } + assert(interval != null, "No interval can be constructed", ctx) + interval + } catch { + // Handle Exceptions thrown by CalendarInterval + case e: IllegalArgumentException => + val pe = new ParseException(e.getMessage, ctx) + pe.setStackTrace(e.getStackTrace) + throw pe + } + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + /** + * Resolve/create a primitive type. + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => TimestampType + case ("char" | "varchar" | "string", Nil) => StringType + case ("char" | "varchar", _ :: Nil) => StringType + case ("binary", Nil) => BinaryType + case ("decimal", Nil) => DecimalType.USER_DEFAULT + case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) + case ("decimal", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case (dt, params) => + throw new ParseException( + s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case SqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case SqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case SqlBaseParser.STRUCT => + createStructType(ctx.colTypeList()) + } + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType) + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + // Add the comment to the metadata. + val builder = new MetadataBuilder + if (STRING != null) { + builder.putString("comment", string(STRING)) + } + + StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala new file mode 100644 index 0000000000000..0b570c9e4212d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser + +import scala.language.implicitConversions +import scala.util.matching.Regex +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.input.CharArrayReader._ + +import org.apache.spark.sql.types._ + +/** + * This is a data type parser that can be used to parse string representations of data types + * provided in SQL queries. This parser is mixed in with DDLParser and SqlParser. + */ +private[sql] trait DataTypeParser extends StandardTokenParsers { + + // This is used to create a parser from a regex. We are using regexes for data type strings + // since these strings can be also used as column names or field names. + import lexical.Identifier + implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", + { case Identifier(str) if regex.unapplySeq(str).isDefined => str } + ) + + protected lazy val primitiveType: Parser[DataType] = + "(?i)string".r ^^^ StringType | + "(?i)float".r ^^^ FloatType | + "(?i)(?:int|integer)".r ^^^ IntegerType | + "(?i)tinyint".r ^^^ ByteType | + "(?i)smallint".r ^^^ ShortType | + "(?i)double".r ^^^ DoubleType | + "(?i)(?:bigint|long)".r ^^^ LongType | + "(?i)binary".r ^^^ BinaryType | + "(?i)boolean".r ^^^ BooleanType | + fixedDecimalType | + "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | + "(?i)date".r ^^^ DateType | + "(?i)timestamp".r ^^^ TimestampType | + varchar | + char + + protected lazy val fixedDecimalType: Parser[DataType] = + ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => + DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val char: Parser[DataType] = + "(?i)char".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + + protected lazy val varchar: Parser[DataType] = + "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + + protected lazy val arrayType: Parser[DataType] = + "(?i)array".r ~> "<" ~> dataType <~ ">" ^^ { + case tpe => ArrayType(tpe) + } + + protected lazy val mapType: Parser[DataType] = + "(?i)map".r ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + ident ~ ":" ~ dataType ^^ { + case name ~ _ ~ tpe => StructField(name, tpe, nullable = true) + } + + protected lazy val structType: Parser[DataType] = + ("(?i)struct".r ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { + case fields => new StructType(fields.toArray) + }) | + ("(?i)struct".r ~ "<>" ^^^ StructType(Nil)) + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + def toDataType(dataTypeString: String): DataType = synchronized { + phrase(dataType)(new lexical.Scanner(dataTypeString)) match { + case Success(result, _) => result + case failure: NoSuccess => throw new DataTypeException(failMessage(dataTypeString)) + } + } + + private def failMessage(dataTypeString: String): String = { + s"Unsupported dataType: $dataTypeString. If you have a struct and a field name of it has " + + "any special characters, please use backticks (`) to quote that field name, e.g. `x+y`. " + + "Please note that backtick itself is not supported in a field name." + } +} + +private[sql] object DataTypeParser { + lazy val dataTypeParser = new DataTypeParser { + override val lexical = new SqlLexical + } + + def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) +} + +/** The exception thrown from the [[DataTypeParser]]. */ +private[sql] class DataTypeException(message: String) extends Exception(message) + +class SqlLexical extends scala.util.parsing.combinator.lexical.StdLexical { + case class DecimalLit(chars: String) extends Token { + override def toString: String = chars + } + + /* This is a work around to support the lazy setting */ + def initialize(keywords: Seq[String]): Unit = { + reserved.clear() + reserved ++= keywords + } + + /* Normal the keyword string */ + def normalizeKeyword(str: String): String = str.toLowerCase + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" + ) + + protected override def processIdent(name: String) = { + val token = normalizeKeyword(name) + if (reserved contains token) Keyword(token) else Identifier(name) + } + + override lazy val token: Parser[Token] = + ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } + | '.' ~> (rep1(digit) ~ scientificNotation) ^^ + { case i ~ s => DecimalLit("0." + i.mkString + s) } + | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ + { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) } + | digit.* ~ identChar ~ (identChar | digit).* ^^ + { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } + | rep1(digit) ~ ('.' ~> digit.*).? ^^ { + case i ~ None => NumericLit(i.mkString) + case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) + } + | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ + { case chars => StringLit(chars mkString "") } + | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ + { case chars => StringLit(chars mkString "") } + | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^ + { case chars => Identifier(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar: Parser[Elem] = letter | elem('_') + + private lazy val scientificNotation: Parser[String] = + (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { + case s ~ rest => "e" + s.mkString + rest.mkString + } + + override def whitespace: Parser[Any] = + ( whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ chrExcept(EofCh, '\n').* + | '#' ~ chrExcept(EofCh, '\n').* + | '-' ~ '-' ~ chrExcept(EofCh, '\n').* + | '/' ~ '*' ~ failure("unclosed comment") + ).* +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala new file mode 100644 index 0000000000000..60d7361242c69 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.sql.types._ + +/** + * Parser that turns case class strings into datatypes. This is only here to maintain compatibility + * with Parquet files written by Spark 1.1 and below. + */ +object LegacyTypeStringParser extends RegexParsers { + + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.USER_DEFAULT + | fixedDecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + */ + def parse(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala new file mode 100644 index 0000000000000..d0132529f18ea --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType + +/** + * Base SQL parsing infrastructure. + */ +abstract class AbstractSqlParser extends ParserInterface with Logging { + + /** Creates/Resolves DataType for a given SQL string. */ + def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => + // TODO add this to the parser interface. + astBuilder.visitSingleDataType(parser.singleDataType()) + } + + /** Creates Expression for a given SQL string. */ + override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => + astBuilder.visitSingleExpression(parser.singleExpression()) + } + + /** Creates TableIdentifier for a given SQL string. */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) + } + + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visitSingleStatement(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => nativeCommand(sqlText) + } + } + + /** Get the builder (visitor) which converts a ParseTree into a AST. */ + protected def astBuilder: AstBuilder + + /** Create a native command, or fail when this is not supported. */ + protected def nativeCommand(sqlText: String): LogicalPlan = { + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) + } + + protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + logInfo(s"Parsing command: $command") + + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.reset() // rewind input stream + parser.reset() + + // Try Again. + parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) + } + } +} + +/** + * Concrete SQL parser for Catalyst-only SQL statements. + */ +object CatalystSqlParser extends AbstractSqlParser { + val astBuilder = new AstBuilder +} + +/** + * This string stream provides the lexer with upper case characters only. This greatly simplifies + * lexing the stream, while we can maintain the original command. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream + * + * The comment below (taken from the original class) describes the rationale for doing this: + * + * This class provides and implementation for a case insensitive token checker for the lexical + * analysis part of antlr. By converting the token stream into upper case at the time when lexical + * rules are checked, this class ensures that the lexical rules need to just match the token with + * upper case letters as opposed to combination of upper case and lower case characters. This is + * purely used for matching lexical rules. The actual token text is stored in the same way as the + * user input without actually converting it into an upper case. The token values are generated by + * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead + * function and is purely used for matching lexical rules. This also means that the grammar will + * only accept capitalized tokens in case it is run from other tools like antlrworks which do not + * have the ANTLRNoCaseStringStream implementation. + */ + +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { + override def LA(i: Int): Int = { + val la = super.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +/** + * The ParseErrorListener converts parse errors into AnalysisExceptions. + */ +case object ParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) + } +} + +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. + */ +class ParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(ParserUtils.command(ctx)), + message, + ParserUtils.position(ctx.getStart), + ParserUtils.position(ctx.getStop)) + } + + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString + } + + def withCommand(cmd: String): ParseException = { + new ParseException(Option(cmd), message, start, stop) + } +} + +/** + * The post-processor validates & cleans-up the parse tree during the parse process. + */ +case object PostProcessor extends SqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + parent.addChild(f(new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + SqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins))) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala new file mode 100644 index 0000000000000..7f35d650b9571 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * Interface for a parser. + */ +trait ParserInterface { + /** Creates LogicalPlan for a given SQL string. */ + def parsePlan(sqlText: String): LogicalPlan + + /** Creates Expression for a given SQL string. */ + def parseExpression(sqlText: String): Expression + + /** Creates TableIdentifier for a given SQL string. */ + def parseTableIdentifier(sqlText: String): TableIdentifier +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala new file mode 100644 index 0000000000000..cb9fefec8f482 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import scala.collection.mutable.StringBuilder + +import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.TerminalNode + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} + +/** + * A collection of utility methods for use during the parsing process. + */ +object ParserUtils { + /** Get the command which created the token. */ + def command(ctx: ParserRuleContext): String = { + command(ctx.getStart.getInputStream) + } + + /** Get the command which created the token. */ + def command(stream: CharStream): String = { + stream.getText(Interval.of(0, stream.size())) + } + + /** Get the code that creates the given node. */ + def source(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) + } + + /** Get all the text which comes after the given rule. */ + def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) + + /** Get all the text which comes after the given token. */ + def remainder(token: Token): String = { + val stream = token.getInputStream + val interval = Interval.of(token.getStopIndex + 1, stream.size()) + stream.getText(interval) + } + + /** Convert a string token into a string. */ + def string(token: Token): String = unescapeSQLString(token.getText) + + /** Convert a string node into a string. */ + def string(node: TerminalNode): String = unescapeSQLString(node.getText) + + /** Get the origin (line and position) of the token. */ + def position(token: Token): Origin = { + Origin(Option(token.getLine), Option(token.getCharPositionInLine)) + } + + /** Assert if a condition holds. If it doesn't throw a parse exception. */ + def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + if (!f) { + throw new ParseException(message, ctx) + } + } + + /** + * Register the origin of the context. Any TreeNode created in the closure will be assigned the + * registered origin. This method restores the previously set origin after completion of the + * closure. + */ + def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } + } + + /** Unescape baskslash-escaped string enclosed by quotes. */ + def unescapeSQLString(b: String): String = { + var enclosure: Character = null + val sb = new StringBuilder(b.length()) + + def appendEscapedChar(n: Char) { + n match { + case '0' => sb.append('\u0000') + case '\'' => sb.append('\'') + case '"' => sb.append('\"') + case 'b' => sb.append('\b') + case 'n' => sb.append('\n') + case 'r' => sb.append('\r') + case 't' => sb.append('\t') + case 'Z' => sb.append('\u001A') + case '\\' => sb.append('\\') + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%' => sb.append("\\%") + case '_' => sb.append("\\_") + case _ => sb.append(n) + } + } + + var i = 0 + val strLength = b.length + while (i < strLength) { + val currentChar = b.charAt(i) + if (enclosure == null) { + if (currentChar == '\'' || currentChar == '\"') { + enclosure = currentChar + } + } else if (enclosure == currentChar) { + enclosure = null + } else if (currentChar == '\\') { + + if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') { + // \u0000 style character literals. + + val base = i + 2 + val code = (0 until 4).foldLeft(0) { (mid, j) => + val digit = Character.digit(b.charAt(j + base), 16) + (mid << 4) + digit + } + sb.append(code.asInstanceOf[Char]) + i += 5 + } else if (i + 4 < strLength) { + // \000 style character literals. + + val i1 = b.charAt(i + 1) + val i2 = b.charAt(i + 2) + val i3 = b.charAt(i + 3) + + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) { + val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char] + sb.append(tmp) + i += 3 + } else { + appendEscapedChar(i1) + i += 1 + } + } else if (i + 2 < strLength) { + // escaped character literals. + val n = b.charAt(i + 1) + appendEscapedChar(n) + i += 1 + } + } else { + // non-escaped character literals. + sb.append(currentChar) + } + i += 1 + } + sb.toString() + } + + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ + implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { + /** + * Create a plan using the block of code when the given context exists. Otherwise return the + * original plan. + */ + def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f + } else { + plan + } + } + + /** + * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the + * passed function. The original plan is returned when the context does not exist. + */ + def optionalMap[C <: ParserRuleContext]( + ctx: C)( + f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f(ctx, plan) + } else { + plan + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 56a3dd02f9ba3..516b41cb138b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 3b975b904a332..00656191354f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.Logging +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.types.IntegerType /** * A pattern that matches any number of project or filter operations on top of another relational @@ -65,6 +69,9 @@ object PhysicalOperation extends PredicateHelper { val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + case BroadcastHint(child) => + collectProjectsAndFilters(child) + case other => (None, Nil, other, Map.empty) } @@ -76,88 +83,17 @@ object PhysicalOperation extends PredicateHelper { private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { expr.transform { case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + aliases.get(ref) + .map(Alias(_, name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)) + .getOrElse(a) case a: AttributeReference => - aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + aliases.get(a) + .map(Alias(_, a.name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)).getOrElse(a) } } } -/** - * Matches a logical aggregation that can be performed on distributed data in two steps. The first - * operates on the data in each partition performing partial aggregation for each group. The second - * occurs after the shuffle and completes the aggregation. - * - * This pattern will only match if all aggregate expressions can be computed partially and will - * return the rewritten aggregation expressions for both phases. - * - * The returned values for this match are as follows: - * - Grouping attributes for the final aggregation. - * - Aggregates for the final aggregation. - * - Grouping expressions for the partial aggregation. - * - Partial aggregate expressions. - * - Input to the aggregation. - */ -object PartialAggregation { - type ReturnType = - (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] = - partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - } - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { - case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => - partialEvaluations(new TreeNodeRef(e)).finalEvaluation - - case e: Expression => - namedGroupingExpressions.collectFirst { - case (expr, ne) if expr semanticEquals e => ne.toAttribute - }.getOrElse(e) - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = namedGroupingExpressions.map(_._2) ++ - partialEvaluations.values.flatMap(_.partialEvaluations) - - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - - Some( - (namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child)) - } else { - None - } - case _ => None - } -} - - /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. * @@ -206,17 +142,153 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } } +/** + * A pattern that collects the filter and inner joins. + * + * Filter + * | + * inner Join + * / \ ----> (Seq(plan0, plan1, plan2), conditions) + * Filter plan2 + * | + * inner join + * / \ + * plan0 plan1 + * + * Note: This pattern currently only works for left-deep trees. + */ +object ExtractFiltersAndInnerJoins extends PredicateHelper { + + // flatten all inner joins, which are next to each other + def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match { + case Join(left, right, Inner, cond) => + val (plans, conditions) = flattenJoin(left) + (plans ++ Seq(right), conditions ++ cond.toSeq) + + case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + + case _ => (Seq(plan), Seq()) + } + + def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match { + case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) => + Some(flattenJoin(f)) + case j @ Join(_, _, Inner, _) => + Some(flattenJoin(j)) + case _ => None + } +} + + /** * A pattern that collects all adjacent unions and returns their children as a Seq. */ object Unions { def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { - case u: Union => Some(collectUnionChildren(u)) + case u: Union => Some(collectUnionChildren(mutable.Stack(u), Seq.empty[LogicalPlan])) + case _ => None + } + + // Doing a depth-first tree traversal to combine all the union children. + @tailrec + private def collectUnionChildren( + plans: mutable.Stack[LogicalPlan], + children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + if (plans.isEmpty) children + else { + plans.pop match { + case Union(grandchildren) => + grandchildren.reverseMap(plans.push(_)) + collectUnionChildren(plans, children) + case other => collectUnionChildren(plans, children :+ other) + } + } + } +} + +/** + * Extractor for retrieving Int value. + */ +object IntegerIndex { + def unapply(a: Any): Option[Int] = a match { + case Literal(a: Int, IntegerType) => Some(a) + // When resolving ordinal in Sort and Group By, negative values are extracted + // for issuing error messages. + case UnaryMinus(IntegerLiteral(v)) => Some(-v) case _ => None } +} + +/** + * An extractor used when planning the physical execution of an aggregation. Compared with a logical + * aggregation, the following transformations are performed: + * - Unnamed grouping expressions are named so that they can be referred to across phases of + * aggregation + * - Aggregations that appear multiple times are deduplicated. + * - The compution of the aggregations themselves is separated from the final result. For example, + * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final + * computation that computes `count.resultAttribute + 1`. + */ +object PhysicalAggregation { + // groupingExpressions, aggregateExpressions, resultExpressions, child + type ReturnType = + (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct - private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { - case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r) - case other => other :: Nil + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case ae: AggregateExpression => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + ae.resultAttribute + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + Some(( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + rewrittenResultExpressions, + child)) + + case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0ec9f08571082..d4447ca32d5a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,15 +17,93 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} -abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { +abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => def output: Seq[Attribute] + /** + * Extracts the relevant constraints from a given set of constraints based on the attributes that + * appear in the [[outputSet]]. + */ + protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { + constraints + .union(inferAdditionalConstraints(constraints)) + .union(constructIsNotNullConstraints(constraints)) + .filter(constraint => + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + } + + /** + * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions as + * well as non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this + * returns a constraint of the form `isNotNull(a)` + */ + private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + // First, we propagate constraints from the null intolerant expressions. + var isNotNullConstraints: Set[Expression] = + constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_)) + + // Second, we infer additional constraints from non-nullable attributes that are part of the + // operator's output + val nonNullableAttributes = output.filterNot(_.nullable) + isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet + + isNotNullConstraints -- constraints + } + + /** + * Recursively explores the expressions which are null intolerant and returns all attributes + * in these expressions. + */ + private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match { + case a: Attribute => Seq(a) + case _: NullIntolerant | IsNotNull(_: NullIntolerant) => + expr.children.flatMap(scanNullIntolerantExpr) + case _ => Seq.empty[Attribute] + } + + /** + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5` + */ + private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + inferredConstraints ++= (constraints - eq).map(_ transform { + case a: Attribute if a.semanticEquals(l) => r + }) + inferredConstraints ++= (constraints - eq).map(_ transform { + case a: Attribute if a.semanticEquals(r) => l + }) + case _ => // No inference + } + inferredConstraints -- constraints + } + + /** + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. + */ + lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) + + /** + * This method can be overridden by any child class of QueryPlan to specify a set of constraints + * based on the given operator's constraint propagation logic. These constraints are then + * canonicalized and filtered automatically to contain only those attributes that appear in the + * [[outputSet]]. + * + * See [[Canonicalize]] for more details. + */ + protected def validConstraints: Set[Expression] = Set.empty + /** * Returns the set of attributes that are output by this node. */ @@ -43,21 +121,23 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def inputSet: AttributeSet = AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + /** + * The set of all attributes that are produced by this node. + */ + def producedAttributes: AttributeSet = AttributeSet.empty + /** * Attributes that are referenced by expressions but not provided by this nodes children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. - * - * Note that virtual columns should be excluded. Currently, we only support the grouping ID - * virtual column. */ - def missingInput: AttributeSet = - (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName) + def missingInput: AttributeSet = references -- inputSet -- producedAttributes /** * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { @@ -66,6 +146,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. */ def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { @@ -88,6 +169,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other + case null => null } val newArgs = productIterator.map(recursiveTransform).toArray @@ -97,6 +179,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * * @param rule the rule to be applied to every expression in this operator. * @return */ @@ -120,6 +203,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other + case null => null } val newArgs = productIterator.map(recursiveTransform).toArray @@ -127,8 +211,10 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } - /** Returns the result of running [[transformExpressions]] on this node - * and all its children. */ + /** + * Returns the result of running [[transformExpressions]] on this node + * and all its children. + */ def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { transform { case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType] @@ -136,14 +222,18 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } /** Returns all of the expressions present in this query plan operator. */ - def expressions: Seq[Expression] = { + final def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Traversable[_] => seqToExpressions(s) + case other => Nil + } + productIterator.flatMap { case e: Expression => e :: Nil case Some(e: Expression) => e :: Nil - case seq: Traversable[_] => seq.flatMap { - case e: Expression => e :: Nil - case other => Nil - } + case seq: Traversable[_] => seqToExpressions(seq) case other => Nil }.toSeq } @@ -166,4 +256,73 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" override def simpleString: String = statePrefix + super.simpleString + + /** + * All the subqueries of current plan. + */ + def subqueries: Seq[PlanType] = { + expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]}) + } + + override def innerChildren: Seq[PlanType] = subqueries + + /** + * Canonicalized copy of this query plan. + */ + protected lazy val canonicalized: PlanType = this + + /** + * Returns true when the given query plan will return the same results as this query plan. + * + * Since its likely undecidable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually + * the same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * By default this function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and or expression id differences. Operators that + * can do better should override this function. + */ + def sameResult(plan: PlanType): Boolean = { + val left = this.canonicalized + val right = plan.canonicalized + left.getClass == right.getClass && + left.children.size == right.children.size && + left.cleanArgs == right.cleanArgs && + (left.children, right.children).zipped.forall(_ sameResult _) + } + + /** + * All the attributes that are used for this plan. + */ + lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output) + + private def cleanExpression(e: Expression): Expression = e match { + case a: Alias => + // As the root of the expression, Alias will always take an arbitrary exprId, we need + // to erase that for equality testing. + val cleanedExprId = + Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated) + BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true) + case other => + BindReferences.bindReference(other, allAttributes, allowFailures = true) + } + + /** Args that have cleaned such that differences in expression id should not affect equality */ + protected lazy val cleanArgs: Seq[Any] = { + def cleanArg(arg: Any): Any = arg match { + // Children are checked using sameResult above. + case tn: TreeNode[_] if containsChild(tn) => null + case e: Expression => cleanExpression(e).canonicalized + case other => other + } + + productIterator.map { + case s: Option[_] => s.map(cleanArg) + case s: Seq[_] => s.map(cleanArg) + case m: Map[_, _] => m.mapValues(cleanArg) + case other => cleanArg(other) + }.toSeq + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 77dec7ca6e2b5..13f57c54a5623 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute + object JoinType { def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { case "inner" => Inner @@ -24,27 +26,64 @@ object JoinType { case "leftouter" | "left" => LeftOuter case "rightouter" | "right" => RightOuter case "leftsemi" => LeftSemi + case "leftanti" => LeftAnti case _ => val supported = Seq( "inner", "outer", "full", "fullouter", "leftouter", "left", "rightouter", "right", - "leftsemi") + "leftsemi", + "leftanti") throw new IllegalArgumentException(s"Unsupported join type '$typ'. " + "Supported join types include: " + supported.mkString("'", "', '", "'") + ".") } } -sealed abstract class JoinType +sealed abstract class JoinType { + def sql: String +} + +case object Inner extends JoinType { + override def sql: String = "INNER" +} + +case object LeftOuter extends JoinType { + override def sql: String = "LEFT OUTER" +} + +case object RightOuter extends JoinType { + override def sql: String = "RIGHT OUTER" +} + +case object FullOuter extends JoinType { + override def sql: String = "FULL OUTER" +} -case object Inner extends JoinType +case object LeftSemi extends JoinType { + override def sql: String = "LEFT SEMI" +} -case object LeftOuter extends JoinType +case object LeftAnti extends JoinType { + override def sql: String = "LEFT ANTI" +} -case object RightOuter extends JoinType +case class NaturalJoin(tpe: JoinType) extends JoinType { + require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), + "Unsupported natural join type " + tpe) + override def sql: String = "NATURAL " + tpe.sql +} -case object FullOuter extends JoinType +case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType { + require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter, LeftAnti).contains(tpe), + "Unsupported using join type " + tpe) + override def sql: String = "USING " + tpe.sql +} -case object LeftSemi extends JoinType +object LeftExistence { + def unapply(joinType: JoinType): Option[JoinType] = joinType match { + case LeftSemi | LeftAnti => Some(joinType) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index e3e7a11dba973..5813b74c770d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -45,6 +45,9 @@ object LocalRelation { case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { + // A local relation must have resolved output. + require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.") + /** * Returns an identical copy of this relation with new exprIds for all attributes. Different * attributes are required when a relation is going to be included multiple times in the same diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8f8747e105932..aceeb8aadcf68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.types.StructType abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -78,18 +79,26 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** * Computes [[Statistics]] for this plan. The default implementation assumes the output - * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * cardinality is the product of all child plan's cardinality, i.e. applies in the case * of cartesian joins. * * [[LeafNode]]s must override this. */ def statistics: Statistics = { - if (children.size == 0) { + if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product) } + /** + * Returns the maximum number of rows that this plan may compute. + * + * Any operator that a Limit can be pushed passed should override this function (e.g., Union). + * Any operator that can push through a Limit should override this function (e.g., Project). + */ + def maxRows: Option[Long] = None + /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan @@ -106,58 +115,23 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = children.forall(_.resolved) + override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) + /** - * Returns true when the given logical plan will return the same results as this logical plan. - * - * Since its likely undecidable to generally determine if two given plans will produce the same - * results, it is okay for this function to return false, even if the results are actually - * the same. Such behavior will not affect correctness, only the application of performance - * enhancements like caching. However, it is not acceptable to return true if the results could - * possibly be different. - * - * By default this function performs a modified version of equality that is tolerant of cosmetic - * differences like attribute naming and or expression id differences. Logical operators that - * can do better should override this function. + * Resolves a given schema to concrete [[Attribute]] references in this query plan. This function + * should only be called on analyzed plans since it will throw [[AnalysisException]] for + * unresolved [[Attribute]]s. */ - def sameResult(plan: LogicalPlan): Boolean = { - val cleanLeft = EliminateSubQueries(this) - val cleanRight = EliminateSubQueries(plan) - - cleanLeft.getClass == cleanRight.getClass && - cleanLeft.children.size == cleanRight.children.size && { - logDebug( - s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") - cleanRight.cleanArgs == cleanLeft.cleanArgs - } && - (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) - } - - /** Args that have cleaned such that differences in expression id should not affect equality */ - protected lazy val cleanArgs: Seq[Any] = { - val input = children.flatMap(_.output) - def cleanExpression(e: Expression) = e match { - case a: Alias => - // As the root of the expression, Alias will always take an arbitrary exprId, we need - // to erase that for equality testing. - val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers) - BindReferences.bindReference(cleanedExprId, input, allowFailures = true) - case other => BindReferences.bindReference(other, input, allowFailures = true) - } - - productIterator.map { - // Children are checked using sameResult above. - case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e) - case s: Option[_] => s.map { - case e: Expression => cleanExpression(e) - case other => other - } - case s: Seq[_] => s.map { - case e: Expression => cleanExpression(e) - case other => other + def resolve(schema: StructType, resolver: Resolver): Seq[Attribute] = { + schema.map { field => + resolveQuoted(field.name, resolver).map { + case a: AttributeReference => a + case other => sys.error(s"can not handle nested schema yet... plan $this") + }.getOrElse { + throw new AnalysisException( + s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") } - case other => other - }.toSeq + } } /** @@ -203,7 +177,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { assert(nameParts.length > 1) - if (attribute.qualifiers.exists(resolver(_, nameParts.head))) { + if (attribute.qualifier.exists(resolver(_, nameParts.head))) { // At least one qualifier matches. See if remaining parts match. val remainingParts = nameParts.tail resolveAsColumn(remainingParts, resolver, attribute) @@ -222,7 +196,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { + if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) { Option((attribute.withName(nameParts.head), nameParts.tail.toList)) } else { None @@ -295,6 +269,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ abstract class LeafNode extends LogicalPlan { override def children: Seq[LogicalPlan] = Nil + override def producedAttributes: AttributeSet = outputSet } /** @@ -304,6 +279,39 @@ abstract class UnaryNode extends LogicalPlan { def child: LogicalPlan override def children: Seq[LogicalPlan] = child :: Nil + + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias + */ + protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { + projectList.flatMap { + case a @ Alias(e, _) => + child.constraints.map(_ transform { + case expr: Expression if expr.semanticEquals(e) => + a.toAttribute + }).union(Set(EqualNullSafe(e, a.toAttribute))) + case _ => + Set.empty[Expression] + }.toSet + } + + override protected def validConstraints: Set[Expression] = child.constraints + + override def statistics: Statistics = { + // There should be some overhead in Row object, the size should not be zero when there is + // no columns, this help to prevent divide-by-zero error. + val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 + val outputRowSize = output.map(_.dataType.defaultSize).sum + 8 + // Assume there will be the same number of rows as child has. + var sizeInBytes = (child.statistics.sizeInBytes * outputRowSize) / childRowSize + if (sizeInBytes == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children). + sizeInBytes = 1 + } + Statistics(sizeInBytes = sizeInBytes) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index ccf5291219add..578027da776e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} /** * Transforms the input by forking and running the specified script. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 4cb67aacf33ee..d4fc9e4da944a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,15 +17,28 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.encoders._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Utils +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet + +/** + * When planning take() or collect() operations, this special node that is inserted at the top of + * the logical plan before invoking the query planner. + * + * Rules can pattern-match on this node in order to apply transformations that only take effect + * at the top of the logical query plan. + */ +case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override def maxRows: Option[Long] = child.maxRows override lazy val resolved: Boolean = { val hasSpecialExpressions = projectList.exists ( _.collect { @@ -37,6 +50,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } + + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(projectList)) } /** @@ -44,6 +60,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. @@ -72,59 +89,184 @@ case class Generate( generatorOutput.forall(_.resolved) } - // we don't want the gOutput to be taken as part of the expressions - // as that will cause exceptions like unresolved attributes etc. - override def expressions: Seq[Expression] = generator :: Nil + override def producedAttributes: AttributeSet = AttributeSet(generatorOutput) def output: Seq[Attribute] = { val qualified = qualifier.map(q => // prepend the new qualifier to the existed one - generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers)) + generatorOutput.map(a => a.withQualifier(Some(q))) ).getOrElse(generatorOutput) if (join) child.output ++ qualified else qualified } } -case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { +case class Filter(condition: Expression, child: LogicalPlan) + extends UnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output + + override def maxRows: Option[Long] = child.maxRows + + override protected def validConstraints: Set[Expression] = + child.constraints.union(splitConjunctivePredicates(condition).toSet) } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - // TODO: These aren't really the same attributes as nullability etc might change. - final override def output: Seq[Attribute] = left.output - final override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } + protected def leftConstraints: Set[Expression] = left.constraints + + protected def rightConstraints: Set[Expression] = { + require(left.output.size == right.output.size) + val attributeRewrites = AttributeMap(right.output.zip(left.output)) + right.constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } } private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } -case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { +case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) + } + + override protected def validConstraints: Set[Expression] = + leftConstraints.union(rightConstraints) + + // Intersect are only resolved if they don't introduce ambiguous expression ids, + // since the Optimizer will convert Intersect to Join. + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && + duplicateResolved + + override def maxRows: Option[Long] = { + if (children.exists(_.maxRows.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRows).min) + } + } override def statistics: Statistics = { - val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes + val leftSize = left.statistics.sizeInBytes + val rightSize = right.statistics.sizeInBytes + val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize Statistics(sizeInBytes = sizeInBytes) } } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + /** We don't use right.output because those rows get excluded from the set. */ + override def output: Seq[Attribute] = left.output -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + override protected def validConstraints: Set[Expression] = leftConstraints + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } + + override def statistics: Statistics = { + Statistics(sizeInBytes = left.statistics.sizeInBytes) + } +} + +/** Factory for constructing new `Union` nodes. */ +object Union { + def apply(left: LogicalPlan, right: LogicalPlan): Union = { + Union (left :: right :: Nil) + } +} + +case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { + override def maxRows: Option[Long] = { + if (children.exists(_.maxRows.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRows).sum) + } + } + + // updating nullability to make all the children consistent + override def output: Seq[Attribute] = + children.map(_.output).transpose.map(attrs => + attrs.head.withNullability(attrs.exists(_.nullable))) + + override lazy val resolved: Boolean = { + // allChildrenCompatible needs to be evaluated after childrenResolved + def allChildrenCompatible: Boolean = + children.tail.forall( child => + // compare the attribute number with the first child + child.output.length == children.head.output.length && + // compare the data types with the first child + child.output.zip(children.head.output).forall { + case (l, r) => l.dataType == r.dataType } + ) + + children.length > 1 && childrenResolved && allChildrenCompatible + } + + override def statistics: Statistics = { + val sizeInBytes = children.map(_.statistics.sizeInBytes).sum + Statistics(sizeInBytes = sizeInBytes) + } + + /** + * Maps the constraints containing a given (original) sequence of attributes to those with a + * given (reference) sequence of attributes. Given the nature of union, we expect that the + * mapping between the original and reference sequences are symmetric. + */ + private def rewriteConstraints( + reference: Seq[Attribute], + original: Seq[Attribute], + constraints: Set[Expression]): Set[Expression] = { + require(reference.size == original.size) + val attributeRewrites = AttributeMap(original.zip(reference)) + constraints.map(_ transform { + case a: Attribute => attributeRewrites(a) + }) + } + + private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = { + val common = a.intersect(b) + // The constraint with only one reference could be easily inferred as predicate + // Grouping the constraints by it's references so we can combine the constraints with same + // reference together + val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + // loose the constraints by: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2) + val others = (othera.keySet intersect otherb.keySet).map { attr => + Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And)) + } + common ++ others + } + + override protected def validConstraints: Set[Expression] = { + children + .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) + .reduce(merge(_, _)) + } +} case class Join( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryNode with PredicateHelper { override def output: Seq[Attribute] = { joinType match { - case LeftSemi => + case LeftExistence(_) => left.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -137,15 +279,46 @@ case class Join( } } - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + override protected def validConstraints: Set[Expression] = { + joinType match { + case Inner if condition.isDefined => + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(condition.get).toSet) + case LeftSemi if condition.isDefined => + left.constraints + .union(splitConjunctivePredicates(condition.get).toSet) + case Inner => + left.constraints.union(right.constraints) + case LeftExistence(_) => + left.constraints + case LeftOuter => + left.constraints + case RightOuter => + right.constraints + case FullOuter => + Set.empty[Expression] + } + } + + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. - override lazy val resolved: Boolean = { + // NaturalJoin should be ready for resolution only if everything else is resolved here + lazy val resolvedExceptNatural: Boolean = { childrenResolved && expressions.forall(_.resolved) && - selfJoinResolved && + duplicateResolved && condition.forall(_.dataType == BooleanType) } + + // if not a natural join, use `resolvedExceptNatural`. if it is a natural join or + // using join, we still need to eliminate natural or using before we mark it resolved. + override lazy val resolved: Boolean = joinType match { + case NaturalJoin(_) => false + case UsingJoin(_, _) => false + case _ => resolvedExceptNatural + } } /** @@ -153,6 +326,10 @@ case class Join( */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + // We manually set statistics of BroadcastHint to smallest value to make sure + // the plan wrapped by BroadcastHint will be considered to broadcast later. + override def statistics: Statistics = Statistics(sizeInBytes = 1) } case class InsertIntoTable( @@ -176,12 +353,13 @@ case class InsertIntoTable( /** * A container for holding named common table expressions (CTEs) and a query plan. * This operator will be removed during analysis and the relations will be substituted into child. + * * @param child The final query of this CTE. * @param cteRelations Queries that this CTE defined, * key is the alias of the CTE definition, * value is the CTE definition. */ -case class With(child: LogicalPlan, cteRelations: Map[String, Subquery]) extends UnaryNode { +case class With(child: LogicalPlan, cteRelations: Map[String, SubqueryAlias]) extends UnaryNode { override def output: Seq[Attribute] = child.output } @@ -202,6 +380,42 @@ case class Sort( global: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = child.maxRows +} + +/** Factory for constructing new `Range` nodes. */ +object Range { + def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { + val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes + new Range(start, end, step, numSlices, output) + } +} + +case class Range( + start: Long, + end: Long, + step: Long, + numSlices: Int, + output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { + require(step != 0, "step cannot be 0") + val numElements: BigInt = { + val safeStart = BigInt(start) + val safeEnd = BigInt(end) + if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) { + (safeEnd - safeStart) / step + } else { + // the remainder has the same sign with range, could add 1 more + (safeEnd - safeStart) / step + 1 + } + } + + override def newInstance(): Range = + Range(start, end, step, numSlices, output.map(_.newInstance())) + + override def statistics: Statistics = { + val sizeInBytes = LongType.defaultSize * numElements + Statistics( sizeInBytes = sizeInBytes ) + } } case class Aggregate( @@ -219,101 +433,117 @@ case class Aggregate( !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions } - lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) + override def maxRows: Option[Long] = child.maxRows + + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(aggregateExpressions)) + + override def statistics: Statistics = { + if (groupingExpressions.isEmpty) { + Statistics(sizeInBytes = 1) + } else { + super.statistics + } + } } case class Window( - projectList: Seq[Attribute], windowExpressions: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = - projectList ++ windowExpressions.map(_.toAttribute) -} + child.output ++ windowExpressions.map(_.toAttribute) -/** - * Apply the all of the GroupExpressions to every input row, hence we will get - * multiple output rows for a input row. - * @param bitmasks The bitmask set represents the grouping sets - * @param groupByExprs The grouping by expressions - * @param child Child operator - */ -case class Expand( - bitmasks: Seq[Int], - groupByExprs: Seq[Expression], - gid: Attribute, - child: LogicalPlan) extends UnaryNode { - override def statistics: Statistics = { - val sizeInBytes = child.statistics.sizeInBytes * projections.length - Statistics(sizeInBytes = sizeInBytes) - } - - val projections: Seq[Seq[Expression]] = expand() + def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute)) +} +private[sql] object Expand { /** - * Extract attribute set according to the grouping id + * Extract attribute set according to the grouping id. + * * @param bitmask bitmask to represent the selected of the attribute sequence - * @param exprs the attributes in sequence + * @param attrs the attributes in sequence * @return the attributes of non selected specified via bitmask (with the bit set to 1) */ - private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) - : OpenHashSet[Expression] = { - val set = new OpenHashSet[Expression](2) + private def buildNonSelectAttrSet( + bitmask: Int, + attrs: Seq[Attribute]): AttributeSet = { + val nonSelect = new ArrayBuffer[Attribute]() - var bit = exprs.length - 1 + var bit = attrs.length - 1 while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1) bit -= 1 } - set + AttributeSet(nonSelect) } /** - * Create an array of Projections for the child projection, and replace the projections' - * expressions which equal GroupBy expressions with Literal(null), if those expressions - * are not set for this grouping set (according to the bit mask). + * Apply the all of the GroupExpressions to every input row, hence we will get + * multiple output rows for a input row. + * + * @param bitmasks The bitmask set represents the grouping sets + * @param groupByAliases The aliased original group by expressions + * @param groupByAttrs The attributes of aliased group by expressions + * @param gid Attribute of the grouping id + * @param child Child operator */ - private[this] def expand(): Seq[Seq[Expression]] = { - val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] - - bitmasks.foreach { bitmask => + def apply( + bitmasks: Seq[Int], + groupByAliases: Seq[Alias], + groupByAttrs: Seq[Attribute], + gid: Attribute, + child: LogicalPlan): Expand = { + // Create an array of Projections for the child projection, and replace the projections' + // expressions which equal GroupBy expressions with Literal(null), if those expressions + // are not set for this grouping set (according to the bit mask). + val projections = bitmasks.map { bitmask => // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) + val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs) - val substitution = (child.output :+ gid).map(expr => expr transformDown { - case x: Expression if nonSelectedGroupExprSet.contains(x) => + child.output ++ groupByAttrs.map { attr => + if (nonSelectedGroupAttrSet.contains(attr)) { // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null - Literal.create(null, expr.dataType) - case x if x == gid => - // replace the groupingId with concrete value (the bit mask) - Literal.create(bitmask, IntegerType) - }) - - result += substitution + Literal.create(null, attr.dataType) + } else { + attr + } + // groupingId is the last output, here we use the bit mask as the concrete value for it. + } :+ Literal.create(bitmask, IntegerType) } - - result.toSeq - } - - override def output: Seq[Attribute] = { - child.output :+ gid + val output = child.output ++ groupByAttrs :+ gid + Expand(projections, output, Project(child.output ++ groupByAliases, child)) } } -trait GroupingAnalytics extends UnaryNode { - - def groupByExprs: Seq[Expression] - def aggregations: Seq[NamedExpression] +/** + * Apply a number of projections to every input row, hence we will get multiple output rows for + * a input row. + * + * @param projections to apply + * @param output of all projections. + * @param child operator. + */ +case class Expand( + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) - override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + override def statistics: Statistics = { + val sizeInBytes = super.statistics.sizeInBytes * projections.length + Statistics(sizeInBytes = sizeInBytes) + } - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics + // This operator can reuse attributes (for example making them null when doing a roll up) so + // the contraints of the child may no longer be valid. + override protected def validConstraints: Set[Expression] = Set.empty[Expression] } /** @@ -321,6 +551,7 @@ trait GroupingAnalytics extends UnaryNode { * to generated by a UNION ALL of multiple simple GROUP BY clauses. * * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer + * * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected * GroupBy expressions * @param groupByExprs The Group By expressions candidates, take effective only if the @@ -333,52 +564,65 @@ case class GroupingSets( bitmasks: Seq[Int], groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends UnaryNode { - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) -} + override def output: Seq[Attribute] = aggregations.map(_.toAttribute) -/** - * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, - * and eventually will be transformed to Aggregate(.., Expand) in Analyzer - * - * @param groupByExprs The Group By expressions candidates. - * @param child Child operator - * @param aggregations The Aggregation expressions, those non selected group by expressions - * will be considered as constant null if it appears in the expressions - */ -case class Cube( - groupByExprs: Seq[Expression], - child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + // Needs to be unresolved before its translated to Aggregate + Expand because output attributes + // will change in analysis. + override lazy val resolved: Boolean = false +} - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) +case class Pivot( + groupByExprs: Seq[NamedExpression], + pivotColumn: Expression, + pivotValues: Seq[Literal], + aggregates: Seq[Expression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { + case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => pivotValues.flatMap{ value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } + } } -/** - * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, - * and eventually will be transformed to Aggregate(.., Expand) in Analyzer - * - * @param groupByExprs The Group By expressions candidates, take effective only if the - * associated bit in the bitmask set to 1. - * @param child Child operator - * @param aggregations The Aggregation expressions, those non selected group by expressions - * will be considered as constant null if it appears in the expressions - */ -case class Rollup( - groupByExprs: Seq[Expression], - child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { +object Limit { + def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = { + GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) + } - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) + def unapply(p: GlobalLimit): Option[(Expression, LogicalPlan)] = { + p match { + case GlobalLimit(le1, LocalLimit(le2, child)) if le1 == le2 => Some((le1, child)) + case _ => None + } + } } -case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + limitExpr match { + case IntegerLiteral(limit) => Some(limit) + case _ => None + } + } + override lazy val statistics: Statistics = { + val limit = limitExpr.eval().asInstanceOf[Int] + val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum + Statistics(sizeInBytes = sizeInBytes) + } +} +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + limitExpr match { + case IntegerLiteral(limit) => Some(limit) + case _ => None + } + } override lazy val statistics: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum @@ -386,8 +630,9 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } } -case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) +case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) } /** @@ -399,21 +644,36 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { * @param withReplacement Whether to sample with replacement. * @param seed the random seed * @param child the LogicalPlan + * @param isTableSample Is created from TABLESAMPLE in the parser. */ case class Sample( lowerBound: Double, upperBound: Double, withReplacement: Boolean, seed: Long, - child: LogicalPlan) extends UnaryNode { + child: LogicalPlan)( + val isTableSample: java.lang.Boolean = false) extends UnaryNode { override def output: Seq[Attribute] = child.output + + override def statistics: Statistics = { + val ratio = upperBound - lowerBound + // BigInt can't multiply with Double + var sizeInBytes = child.statistics.sizeInBytes * (ratio * 100).toInt / 100 + if (sizeInBytes == 0) { + sizeInBytes = 1 + } + Statistics(sizeInBytes = sizeInBytes) + } + + override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil } /** * Returns a new logical plan that dedups input rows. */ case class Distinct(child: LogicalPlan) extends UnaryNode { + override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output } @@ -432,6 +692,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) * A relation with one row. This is used in "SELECT ..." without a from clause. */ case object OneRowRelation extends LeafNode { + override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil /** @@ -443,112 +704,3 @@ case object OneRowRelation extends LeafNode { */ override def statistics: Statistics = Statistics(sizeInBytes = 1) } - -/** - * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are - * used respectively to decode/encode from the JVM object representation expected by `func.` - */ -case class MapPartitions[T, U]( - func: Iterator[T] => Iterator[U], - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - override def missingInput: AttributeSet = AttributeSet.empty -} - -/** Factory for constructing new `AppendColumn` nodes. */ -object AppendColumn { - def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = { - val attrs = encoderFor[U].schema.toAttributes - new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child) - } -} - -/** - * A relation produced by applying `func` to each partition of the `child`, concatenating the - * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to - * decode/encode from the JVM object representation expected by `func.` - */ -case class AppendColumn[T, U]( - func: T => U, - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - newColumns: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output ++ newColumns - override def missingInput: AttributeSet = super.missingInput -- newColumns -} - -/** Factory for constructing new `MapGroups` nodes. */ -object MapGroups { - def apply[K : Encoder, T : Encoder, U : Encoder]( - func: (K, Iterator[T]) => Iterator[U], - groupingAttributes: Seq[Attribute], - child: LogicalPlan): MapGroups[K, T, U] = { - new MapGroups( - func, - encoderFor[K], - encoderFor[T], - encoderFor[U], - groupingAttributes, - encoderFor[U].schema.toAttributes, - child) - } -} - -/** - * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. - * Func is invoked with an object representation of the grouping key an iterator containing the - * object representation of all the rows with that key. - */ -case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], - kEncoder: ExpressionEncoder[K], - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - groupingAttributes: Seq[Attribute], - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { - override def missingInput: AttributeSet = AttributeSet.empty -} - -/** Factory for constructing new `CoGroup` nodes. */ -object CoGroup { - def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], - leftGroup: Seq[Attribute], - rightGroup: Seq[Attribute], - left: LogicalPlan, - right: LogicalPlan): CoGroup[K, Left, Right, R] = { - CoGroup( - func, - encoderFor[K], - encoderFor[Left], - encoderFor[Right], - encoderFor[R], - encoderFor[R].schema.toAttributes, - leftGroup, - rightGroup, - left, - right) - } -} - -/** - * A relation produced by applying `func` to each grouping key and associated values from left and - * right children. - */ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], - kEncoder: ExpressionEncoder[K], - leftEnc: ExpressionEncoder[Left], - rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], - output: Seq[Attribute], - leftGroup: Seq[Attribute], - rightGroup: Seq[Attribute], - left: LogicalPlan, - right: LogicalPlan) extends BinaryNode { - override def missingInput: AttributeSet = AttributeSet.empty -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index e6621e0f50a9e..47b34d1fa2e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types.StringType /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala new file mode 100644 index 0000000000000..6df46189b627c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{DataType, ObjectType, StructType} + +object CatalystSerde { + def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) + DeserializeToObject(Alias(deserializer, "obj")(), child) + } + + def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { + SerializeFromObject(encoderFor[T].namedExpressions, child) + } +} + +/** + * Takes the input row from child and turns it into object using the given deserializer expression. + * The output of this operator is a single-field safe row containing the deserialized object. + */ +case class DeserializeToObject( + deserializer: Alias, + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + def outputObjectType: DataType = deserializer.dataType +} + +/** + * Takes the input object from child and turns in into unsafe row using the given serializer + * expression. The output of its child must be a single-field row containing the input object. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + def inputObjectType: DataType = child.output.head.dataType +} + +/** + * A trait for logical operators that apply user defined functions to domain objects. + */ +trait ObjectOperator extends LogicalPlan { + + /** The serializer that is used to produce the output of this operator. */ + def serializer: Seq[NamedExpression] + + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + /** + * The object type that is produced by the user defined function. Note that the return type here + * is the same whether or not the operator is output serialized data. + */ + def outputObject: NamedExpression = + Alias(serializer.head.collect { case b: BoundReference => b }.head, "obj")() + + /** + * Returns a copy of this operator that will produce an object instead of an encoded row. + * Used in the optimizer when transforming plans to remove unneeded serialization. + */ + def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) { + this + } else { + withNewSerializer(outputObject :: Nil) + } + + /** Returns a copy of this operator with a different serializer. */ + def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy { + productIterator.map { + case c if c == serializer => newSerializer + case other: AnyRef => other + }.toArray + } +} + +object MapPartitions { + def apply[T : Encoder, U : Encoder]( + func: Iterator[T] => Iterator[U], + child: LogicalPlan): MapPartitions = { + MapPartitions( + func.asInstanceOf[Iterator[Any] => Iterator[Any]], + UnresolvedDeserializer(encoderFor[T].deserializer), + encoderFor[U].namedExpressions, + child) + } +} + +/** + * A relation produced by applying `func` to each partition of the `child`. + * + * @param deserializer used to extract the input to `func` from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class MapPartitions( + func: Iterator[Any] => Iterator[Any], + deserializer: Expression, + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectOperator + +object MapElements { + def apply[T : Encoder, U : Encoder]( + func: AnyRef, + child: LogicalPlan): MapElements = { + MapElements( + func, + UnresolvedDeserializer(encoderFor[T].deserializer), + encoderFor[U].namedExpressions, + child) + } +} + +/** + * A relation produced by applying `func` to each element of the `child`. + * + * @param deserializer used to extract the input to `func` from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class MapElements( + func: AnyRef, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectOperator + +/** Factory for constructing new `AppendColumn` nodes. */ +object AppendColumns { + def apply[T : Encoder, U : Encoder]( + func: T => U, + child: LogicalPlan): AppendColumns = { + new AppendColumns( + func.asInstanceOf[Any => Any], + UnresolvedDeserializer(encoderFor[T].deserializer), + encoderFor[U].namedExpressions, + child) + } +} + +/** + * A relation produced by applying `func` to each partition of the `child`, concatenating the + * resulting columns at the end of the input row. + * + * @param deserializer used to extract the input to `func` from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class AppendColumns( + func: Any => Any, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectOperator { + + override def output: Seq[Attribute] = child.output ++ newColumns + + def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) +} + +/** Factory for constructing new `MapGroups` nodes. */ +object MapGroups { + def apply[K : Encoder, T : Encoder, U : Encoder]( + func: (K, Iterator[T]) => TraversableOnce[U], + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan): MapGroups = { + new MapGroups( + func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), + encoderFor[U].namedExpressions, + groupingAttributes, + dataAttributes, + child) + } +} + +/** + * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. + * Func is invoked with an object representation of the grouping key an iterator containing the + * object representation of all the rows with that key. + * + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class MapGroups( + func: (Any, Iterator[Any]) => TraversableOnce[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + serializer: Seq[NamedExpression], + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan) extends UnaryNode with ObjectOperator + +/** Factory for constructing new `CoGroup` nodes. */ +object CoGroup { + def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan): CoGroup = { + require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) + + CoGroup( + func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], + // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to + // resolve the `keyDeserializer` based on either of them, here we pick the left one. + UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup), + UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr), + UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr), + encoderFor[Result].namedExpressions, + leftGroup, + rightGroup, + leftAttr, + rightAttr, + left, + right) + } +} + +/** + * A relation produced by applying `func` to each grouping key and associated values from left and + * right children. + */ +case class CoGroup( + func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], + keyDeserializer: Expression, + leftDeserializer: Expression, + rightDeserializer: Expression, + serializer: Seq[NamedExpression], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode with ObjectOperator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala new file mode 100644 index 0000000000000..9dfdf4da78ff6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.physical + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are + * identity (tuples remain unchanged) or hashed (tuples are converted into some hash index). + */ +trait BroadcastMode { + def transform(rows: Array[InternalRow]): Any + + /** + * Returns true iff this [[BroadcastMode]] generates the same result as `other`. + */ + def compatibleWith(other: BroadcastMode): Boolean +} + +/** + * IdentityBroadcastMode requires that rows are broadcasted in their original form. + */ +case object IdentityBroadcastMode extends BroadcastMode { + // TODO: pack the UnsafeRows into single bytes array. + override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows + + override def compatibleWith(other: BroadcastMode): Boolean = { + this eq other + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 9312c8123e92e..d449088498c8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.expressions.{Unevaluable, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -75,6 +75,12 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } +/** + * Represents data where tuples are broadcasted to every node. It is quite common that the + * entire set of tuples is transformed into different data structure. + */ +case class BroadcastDistribution(mode: BroadcastMode) extends Distribution + /** * Describes how an operator's output is split across partitions. The `compatibleWith`, * `guarantees`, and `satisfies` methods describe relationships between child partitionings, @@ -165,11 +171,6 @@ sealed trait Partitioning { * produced by `A` could have also been produced by `B`. */ def guarantees(other: Partitioning): Boolean = this == other - - def withNumPartitions(newNumPartitions: Int): Partitioning = { - throw new IllegalStateException( - s"It is not allowed to call withNumPartitions method of a ${this.getClass.getSimpleName}") - } } object Partitioning { @@ -218,7 +219,10 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = true + override def satisfies(required: Distribution): Boolean = required match { + case _: BroadcastDistribution => false + case _ => true + } override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 @@ -240,23 +244,25 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - expressions.toSet.subsetOf(requiredClustering.toSet) + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this == o + case o: HashPartitioning => this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this == o + case o: HashPartitioning => this.semanticEquals(o) case _ => false } - override def withNumPartitions(newNumPartitions: Int): HashPartitioning = { - HashPartitioning(expressions, newNumPartitions) - } + /** + * Returns an expression that will produce a valid partition ID(i.e. non-negative and is less + * than numPartitions) based on hashing expressions. + */ + def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions)) } /** @@ -284,17 +290,17 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this == o + case o: RangePartitioning => this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this == o + case o: RangePartitioning => this.semanticEquals(o) case _ => false } } @@ -354,3 +360,21 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) partitionings.map(_.toString).mkString("(", " or ", ")") } } + +/** + * Represents a partitioning where rows are collected, transformed and broadcasted to each + * node in the cluster. + */ +case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { + override val numPartitions: Int = 1 + + override def satisfies(required: Distribution): Boolean = required match { + case BroadcastDistribution(m) if m == mode => true + case _ => false + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case BroadcastPartitioning(m) if m == mode => true + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index 03414b2301e81..7eb72724d7663 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index f80d2a93241d1..6fc828f63f152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -21,9 +21,11 @@ import scala.collection.JavaConverters._ import com.google.common.util.concurrent.AtomicLongMap -import org.apache.spark.Logging +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.util.Utils object RuleExecutor { protected val timeMap = AtomicLongMap.create[String]() @@ -37,7 +39,7 @@ object RuleExecutor { val maxSize = map.keys.map(_.toString.length).max map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n") + }.mkString("\n", "\n", "") } } @@ -59,7 +61,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) /** Defines a sequence of rule batches, to be overridden by the implementation. */ - protected val batches: Seq[Batch] + protected def batches: Seq[Batch] /** @@ -98,7 +100,12 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}" + if (Utils.isTesting) { + throw new TreeNodeException(curPlan, message, null) + } else { + logWarning(message) + } } continue = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 35f087baccdee..232ca4358865a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,8 +17,25 @@ package org.apache.spark.sql.catalyst.trees +import java.util.UUID + +import scala.collection.Map +import scala.collection.mutable.Stack + +import org.apache.commons.lang.ClassUtils +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.sql.catalyst.ScalaReflection._ +import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -191,6 +208,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case nonChild: AnyRef => nonChild case null => null } + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.view.force // `mapValues` is lazy and we need to force it to materialize case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) @@ -212,6 +242,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * When `rule` does not apply to a given node it is left unchanged. * Users should not expect a specific directionality. If a specific directionality is needed, * transformDown or transformUp should be used. + * * @param rule the function use to transform this nodes children */ def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = { @@ -221,6 +252,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns a copy of this node where `rule` has been recursively applied to it and all of its * children (pre-order). When `rule` does not apply to a given node it is left unchanged. + * * @param rule the function used to transform this nodes children */ def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = { @@ -236,6 +268,26 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order). When `rule` does not apply to a given node, it is left + * unchanged. + * + * @param rule the function use to transform this nodes children + */ + def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { + val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[BaseType]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + } + } + } + /** * Returns a copy of this node where `rule` has been recursively applied to all the children of * this node. When `rule` does not apply to a given node it is left unchanged. @@ -262,7 +314,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { Some(arg) } - case m: Map[_, _] => m + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => @@ -273,6 +335,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { arg } + case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule) + val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule) + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } case other => other } case nonChild: AnyRef => nonChild @@ -281,25 +352,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { if (changed) makeCopy(newArgs) else this } - /** - * Returns a copy of this node where `rule` has been recursively applied first to all of its - * children and then itself (post-order). When `rule` does not apply to a given node, it is left - * unchanged. - * @param rule the function use to transform this nodes children - */ - def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r)) - if (this fastEquals afterRuleOnChildren) { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[BaseType]) - } - } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) - } - } - } - /** * Args to the constructor that should be copied, but not transformed. * These are appended to the transformed args automatically by makeCopy @@ -314,20 +366,32 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") { + // Skip no-arg constructors that are just there for kryo. val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) if (ctors.isEmpty) { sys.error(s"No valid constructor for $nodeName") } - val defaultCtor = ctors.maxBy(_.getParameterTypes.size) + val allArgs: Array[AnyRef] = if (otherCopyArgs.isEmpty) { + newArgs + } else { + newArgs ++ otherCopyArgs + } + val defaultCtor = ctors.find { ctor => + if (ctor.getParameterTypes.length != allArgs.length) { + false + } else if (allArgs.contains(null)) { + // if there is a `null`, we can't figure out the class, therefore we should just fallback + // to older heuristic + false + } else { + val argsArray: Array[Class[_]] = allArgs.map(_.getClass) + ClassUtils.isAssignable(argsArray, ctor.getParameterTypes, true /* autoboxing */) + } + }.getOrElse(ctors.maxBy(_.getParameterTypes.length)) // fall back to older heuristic try { CurrentOrigin.withOrigin(origin) { - // Skip no-arg constructors that are just there for kryo. - if (otherCopyArgs.isEmpty) { - defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType] - } else { - defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType] - } + defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] } } catch { case e: java.lang.IllegalArgumentException => @@ -355,7 +419,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case tn: TreeNode[_] => s"${tn.simpleString}" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil @@ -368,7 +432,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString: String = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, Nil, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. @@ -394,12 +458,87 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } - /** Appends the string represent of this node and its children to the given StringBuilder. */ - protected def generateTreeString(depth: Int, builder: StringBuilder): StringBuilder = { - builder.append(" " * depth) + /** + * All the nodes that will be used to generate tree string. + * + * For example: + * + * WholeStageCodegen + * +-- SortMergeJoin + * |-- InputAdapter + * | +-- Sort + * +-- InputAdapter + * +-- Sort + * + * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string + * like this: + * + * WholeStageCodegen + * : +- SortMergeJoin + * : :- INPUT + * : :- INPUT + * :- Sort + * :- Sort + */ + protected def treeChildren: Seq[BaseType] = children + + /** + * All the nodes that are parts of this node. + * + * For example: + * + * WholeStageCodegen + * +- SortMergeJoin + * |-- InputAdapter + * | +-- Sort + * +-- InputAdapter + * +-- Sort + * + * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree + * string like this: + * + * WholeStageCodegen + * : +- SortMergeJoin + * : :- INPUT + * : :- INPUT + * :- Sort + * :- Sort + */ + protected def innerChildren: Seq[BaseType] = Nil + + /** + * Appends the string represent of this node and its children to the given StringBuilder. + * + * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at + * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and + * `lastChildren` for the root node should be empty. + */ + def generateTreeString( + depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = { + if (depth > 0) { + lastChildren.init.foreach { isLast => + val prefixFragment = if (isLast) " " else ": " + builder.append(prefixFragment) + } + + val branch = if (lastChildren.last) "+- " else ":- " + builder.append(branch) + } + builder.append(simpleString) builder.append("\n") - children.foreach(_.generateTreeString(depth + 1, builder)) + + if (innerChildren.nonEmpty) { + innerChildren.init.foreach(_.generateTreeString( + depth + 2, lastChildren :+ false :+ false, builder)) + innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) + } + + if (treeChildren.nonEmpty) { + treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + } + builder } @@ -417,4 +556,246 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } s"$nodeName(${args.mkString(",")})" } + + def toJSON: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) + + private def jsonValue: JValue = { + val jsonValues = scala.collection.mutable.ArrayBuffer.empty[JValue] + + def collectJsonValue(tn: BaseType): Unit = { + val jsonFields = ("class" -> JString(tn.getClass.getName)) :: + ("num-children" -> JInt(tn.children.length)) :: tn.jsonFields + jsonValues += JObject(jsonFields) + tn.children.foreach(collectJsonValue) + } + + collectJsonValue(this) + jsonValues + } + + protected def jsonFields: List[JField] = { + val fieldNames = getConstructorParameterNames(getClass) + val fieldValues = productIterator.toSeq ++ otherCopyArgs + assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " + + fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", ")) + + fieldNames.zip(fieldValues).map { + // If the field value is a child, then use an int to encode it, represents the index of + // this child in all children. + case (name, value: TreeNode[_]) if containsChild(value) => + name -> JInt(children.indexOf(value)) + case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) => + name -> JArray( + value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList + ) + case (name, value) => name -> parseToJson(value) + }.toList + } + + private def parseToJson(obj: Any): JValue = obj match { + case b: Boolean => JBool(b) + case b: Byte => JInt(b.toInt) + case s: Short => JInt(s.toInt) + case i: Int => JInt(i) + case l: Long => JInt(l) + case f: Float => JDouble(f) + case d: Double => JDouble(d) + case b: BigInt => JInt(b) + case null => JNull + case s: String => JString(s) + case u: UUID => JString(u.toString) + case dt: DataType => dt.jsonValue + case m: Metadata => m.jsonValue + case s: StorageLevel => + ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~ + ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication) + case n: TreeNode[_] => n.jsonValue + case o: Option[_] => o.map(parseToJson) + case t: Seq[_] => JArray(t.map(parseToJson).toList) + case m: Map[_, _] => + val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) } + JObject(fields) + case r: RDD[_] => JNothing + // if it's a scala object, we can simply keep the full class path. + // TODO: currently if the class name ends with "$", we think it's a scala object, there is + // probably a better way to check it. + case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName + // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper + case p: Product => try { + val fieldNames = getConstructorParameterNames(p.getClass) + val fieldValues = p.productIterator.toSeq + assert(fieldNames.length == fieldValues.length) + ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { + case (name, value) => name -> parseToJson(value) + }.toList + } catch { + case _: RuntimeException => null + } + case _ => JNull + } +} + +object TreeNode { + def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = { + val jsonAST = parse(json) + assert(jsonAST.isInstanceOf[JArray]) + reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType] + } + + private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = { + assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject])) + val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*) + + def parseNextNode(): TreeNode[_] = { + val nextNode = jsonNodes.pop() + + val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s) + if (cls == classOf[Literal]) { + Literal.fromJSON(nextNode) + } else if (cls.getName.endsWith("$")) { + cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]] + } else { + val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt + + val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode()) + val fields = getConstructorParameters(cls) + + val parameters: Array[AnyRef] = fields.map { + case (fieldName, fieldType) => + parseFromJson(nextNode \ fieldName, fieldType, children, sc) + }.toArray + + val maybeCtor = cls.getConstructors.find { p => + val expectedTypes = p.getParameterTypes + expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall { + case (cls, tpe) => cls == getClassFromType(tpe) + } + } + if (maybeCtor.isEmpty) { + sys.error(s"No valid constructor for ${cls.getName}") + } else { + try { + maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]] + } catch { + case e: java.lang.IllegalArgumentException => + throw new RuntimeException( + s""" + |Failed to construct tree node: ${cls.getName} + |ctor: ${maybeCtor.get} + |types: ${parameters.map(_.getClass).mkString(", ")} + |args: ${parameters.mkString(", ")} + """.stripMargin, e) + } + } + } + } + + parseNextNode() + } + + import universe._ + + private def parseFromJson( + value: JValue, + expectedType: Type, + children: Seq[TreeNode[_]], + sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized { + if (value == JNull) return null + + expectedType match { + case t if t <:< definitions.BooleanTpe => + value.asInstanceOf[JBool].value: java.lang.Boolean + case t if t <:< definitions.ByteTpe => + value.asInstanceOf[JInt].num.toByte: java.lang.Byte + case t if t <:< definitions.ShortTpe => + value.asInstanceOf[JInt].num.toShort: java.lang.Short + case t if t <:< definitions.IntTpe => + value.asInstanceOf[JInt].num.toInt: java.lang.Integer + case t if t <:< definitions.LongTpe => + value.asInstanceOf[JInt].num.toLong: java.lang.Long + case t if t <:< definitions.FloatTpe => + value.asInstanceOf[JDouble].num.toFloat: java.lang.Float + case t if t <:< definitions.DoubleTpe => + value.asInstanceOf[JDouble].num: java.lang.Double + + case t if t <:< localTypeOf[java.lang.Boolean] => + value.asInstanceOf[JBool].value: java.lang.Boolean + case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num + case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s + case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s) + case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value) + case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject]) + case t if t <:< localTypeOf[StorageLevel] => + val JBool(useDisk) = value \ "useDisk" + val JBool(useMemory) = value \ "useMemory" + val JBool(useOffHeap) = value \ "useOffHeap" + val JBool(deserialized) = value \ "deserialized" + val JInt(replication) = value \ "replication" + StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt) + case t if t <:< localTypeOf[TreeNode[_]] => value match { + case JInt(i) => children(i.toInt) + case arr: JArray => reconstruct(arr, sc) + case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.") + } + case t if t <:< localTypeOf[Option[_]] => + if (value == JNothing) { + None + } else { + val TypeRef(_, _, Seq(optType)) = t + Option(parseFromJson(value, optType, children, sc)) + } + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val JArray(elements) = value + elements.map(parseFromJson(_, elementType, children, sc)).toSeq + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val JObject(fields) = value + fields.map { + case (name, value) => name -> parseFromJson(value, valueType, children, sc) + }.toMap + case t if t <:< localTypeOf[RDD[_]] => + new EmptyRDD[Any](sc) + case _ if isScalaObject(value) => + val JString(clsName) = value \ "object" + val cls = Utils.classForName(clsName) + cls.getField("MODULE$").get(cls) + case t if t <:< localTypeOf[Product] => + val fields = getConstructorParameters(t) + val clsName = getClassNameFromType(t) + parseToProduct(clsName, fields, value, children, sc) + // There maybe some cases that the parameter type signature is not Product but the value is, + // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle it here. + case _ if isScalaProduct(value) => + val JString(clsName) = value \ "product-class" + val fields = getConstructorParameters(Utils.classForName(clsName)) + parseToProduct(clsName, fields, value, children, sc) + case _ => sys.error(s"Do not support type $expectedType with json $value.") + } + } + + private def parseToProduct( + clsName: String, + fields: Seq[(String, Type)], + value: JValue, + children: Seq[TreeNode[_]], + sc: SparkContext): AnyRef = { + val parameters: Array[AnyRef] = fields.map { + case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc) + }.toArray + val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size) + ctor.newInstance(parameters: _*).asInstanceOf[AnyRef] + } + + private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match { + case JString(str) if str.endsWith("$") => true + case _ => false + } + + private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match { + case _: JString => true + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index ea6aa1850db4c..3646c70ad2c41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 70b028d2b3f7c..d46f03ad8fbb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -24,7 +24,6 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy()) - // We need to check equality of map type in tests. override def equals(o: Any): Boolean = { if (!o.isInstanceOf[ArrayBasedMapData]) { return false @@ -35,11 +34,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte return false } - ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other) + this.keyArray == other.keyArray && this.valueArray == other.valueArray } override def hashCode: Int = { - ArrayBasedMapData.toScalaMap(this).hashCode() + keyArray.hashCode() * 37 + valueArray.hashCode() } override def toString: String = { @@ -70,4 +69,9 @@ object ArrayBasedMapData { def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { keys.zip(values).toMap } + + def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = { + import scala.collection.JavaConverters._ + keys.zip(values).toMap.asJava + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala deleted file mode 100644 index 2b83651f9086d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import scala.language.implicitConversions -import scala.util.matching.Regex -import scala.util.parsing.combinator.syntactical.StandardTokenParsers - -import org.apache.spark.sql.catalyst.SqlLexical -import org.apache.spark.sql.types._ - -/** - * This is a data type parser that can be used to parse string representations of data types - * provided in SQL queries. This parser is mixed in with DDLParser and SqlParser. - */ -private[sql] trait DataTypeParser extends StandardTokenParsers { - - // This is used to create a parser from a regex. We are using regexes for data type strings - // since these strings can be also used as column names or field names. - import lexical.Identifier - implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex ${regex}", - { case Identifier(str) if regex.unapplySeq(str).isDefined => str } - ) - - protected lazy val primitiveType: Parser[DataType] = - "(?i)string".r ^^^ StringType | - "(?i)float".r ^^^ FloatType | - "(?i)(?:int|integer)".r ^^^ IntegerType | - "(?i)tinyint".r ^^^ ByteType | - "(?i)smallint".r ^^^ ShortType | - "(?i)double".r ^^^ DoubleType | - "(?i)(?:bigint|long)".r ^^^ LongType | - "(?i)binary".r ^^^ BinaryType | - "(?i)boolean".r ^^^ BooleanType | - fixedDecimalType | - "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | - "(?i)date".r ^^^ DateType | - "(?i)timestamp".r ^^^ TimestampType | - varchar - - protected lazy val fixedDecimalType: Parser[DataType] = - ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { - case precision ~ scale => - DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val varchar: Parser[DataType] = - "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType - - protected lazy val arrayType: Parser[DataType] = - "(?i)array".r ~> "<" ~> dataType <~ ">" ^^ { - case tpe => ArrayType(tpe) - } - - protected lazy val mapType: Parser[DataType] = - "(?i)map".r ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) - } - - protected lazy val structField: Parser[StructField] = - ident ~ ":" ~ dataType ^^ { - case name ~ _ ~ tpe => StructField(name, tpe, nullable = true) - } - - protected lazy val structType: Parser[DataType] = - ("(?i)struct".r ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { - case fields => new StructType(fields.toArray) - }) | - ("(?i)struct".r ~ "<>" ^^^ StructType(Nil)) - - protected lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType - - def toDataType(dataTypeString: String): DataType = synchronized { - phrase(dataType)(new lexical.Scanner(dataTypeString)) match { - case Success(result, _) => result - case failure: NoSuccess => throw new DataTypeException(failMessage(dataTypeString)) - } - } - - private def failMessage(dataTypeString: String): String = { - s"Unsupported dataType: $dataTypeString. If you have a struct and a field name of it has " + - "any special characters, please use backticks (`) to quote that field name, e.g. `x+y`. " + - "Please note that backtick itself is not supported in a field name." - } -} - -private[sql] object DataTypeParser { - lazy val dataTypeParser = new DataTypeParser { - override val lexical = new SqlLexical - } - - def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) -} - -/** The exception thrown from the [[DataTypeParser]]. */ -private[sql] class DataTypeException(message: String) extends Exception(message) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 781ed1688a327..5393cb8ab35e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} -import java.util.{TimeZone, Calendar} +import java.util.{Calendar, TimeZone} import javax.xml.bind.DatatypeConverter +import scala.annotation.tailrec + import org.apache.spark.unsafe.types.UTF8String /** @@ -55,9 +57,17 @@ object DateTimeUtils { // this is year -17999, calculation: 50 * daysIn400Year final val YearZero = -17999 final val toYearZero = to2001 + 7304850 + final val TimeZoneGMT = TimeZone.getTimeZone("GMT") @transient lazy val defaultTimeZone = TimeZone.getDefault + // Reuse the Calendar object in each thread as it is expensive to create in each method call. + private val threadLocalGmtCalendar = new ThreadLocal[Calendar] { + override protected def initialValue: Calendar = { + Calendar.getInstance(TimeZoneGMT) + } + } + // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { override protected def initialValue: TimeZone = { @@ -109,8 +119,9 @@ object DateTimeUtils { } } + @tailrec def stringToTime(s: String): java.util.Date = { - var indexOfGMT = s.indexOf("GMT"); + val indexOfGMT = s.indexOf("GMT") if (indexOfGMT != -1) { // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) val s0 = s.substring(0, indexOfGMT) @@ -241,6 +252,10 @@ object DateTimeUtils { i += 3 } else if (i < 2) { if (b == '-') { + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue currentSegmentValue = 0 i += 1 @@ -308,17 +323,26 @@ object DateTimeUtils { } segments(i) = currentSegmentValue + if (!justTime && i == 0 && j != 4) { + // year should have exact four digits + return None + } while (digitsMilli < 6) { segments(6) *= 10 digitsMilli += 1 } - if (!justTime && (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || + if (!justTime && (segments(0) < 0 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || segments(2) < 1 || segments(2) > 31)) { return None } + // Instead of return None, we truncate the fractional seconds to prevent inserting NULL + if (segments(6) > 999999) { + segments(6) = segments(6).toString.take(6).toInt + } + if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || segments(5) < 0 || segments(5) > 59 || segments(6) < 0 || segments(6) > 999999 || segments(7) < 0 || segments(7) > 23 || segments(8) < 0 || segments(8) > 59) { @@ -368,6 +392,10 @@ object DateTimeUtils { while (j < bytes.length && (i < 3 && !(bytes(j) == ' ' || bytes(j) == 'T'))) { val b = bytes(j) if (i < 2 && b == '-') { + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue currentSegmentValue = 0 i += 1 @@ -381,40 +409,54 @@ object DateTimeUtils { } j += 1 } + if (i == 0 && j != 4) { + // year should have exact four digits + return None + } segments(i) = currentSegmentValue - if (segments(0) < 1000 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || + if (segments(0) < 0 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || segments(2) < 1 || segments(2) > 31) { return None } - val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) + val c = threadLocalGmtCalendar.get() + c.clear() c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) c.set(Calendar.MILLISECOND, 0) Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) } + /** + * Returns the microseconds since year zero (-17999) from microseconds since epoch. + */ + private def absoluteMicroSecond(microsec: SQLTimestamp): SQLTimestamp = { + microsec + toYearZero * MICROS_PER_DAY + } + + private def localTimestamp(microsec: SQLTimestamp): SQLTimestamp = { + absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + } + /** * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. */ - def getHours(timestamp: SQLTimestamp): Int = { - val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) - ((localTs / 1000 / 3600) % 24).toInt + def getHours(microsec: SQLTimestamp): Int = { + ((localTimestamp(microsec) / MICROS_PER_SECOND / 3600) % 24).toInt } /** * Returns the minute value of a given timestamp value. The timestamp is expressed in * microseconds. */ - def getMinutes(timestamp: SQLTimestamp): Int = { - val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) - ((localTs / 1000 / 60) % 60).toInt + def getMinutes(microsec: SQLTimestamp): Int = { + ((localTimestamp(microsec) / MICROS_PER_SECOND / 60) % 60).toInt } /** * Returns the second value of a given timestamp value. The timestamp is expressed in * microseconds. */ - def getSeconds(timestamp: SQLTimestamp): Int = { - ((timestamp / 1000 / 1000) % 60).toInt + def getSeconds(microsec: SQLTimestamp): Int = { + ((localTimestamp(microsec) / MICROS_PER_SECOND) % 60).toInt } private[this] def isLeapYear(year: Int): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index e9bf7b33e35be..2b8cdc1e23ab3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -17,13 +17,16 @@ package org.apache.spark.sql.catalyst.util +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { - def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray) + def this(seq: Seq[Any]) = this(seq.toArray) + def this(list: java.util.List[Any]) = this(list.asScala) // TODO: This is boxing. We should specialize. def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 9fefc5656aac0..da90ddbd63afb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -66,7 +66,7 @@ object NumberConverter { * negative digit is found, ignore the suffix starting there. * * @param radix must be between MIN_RADIX and MAX_RADIX - * @param fromPos is the first element that should be conisdered + * @param fromPos is the first element that should be considered * @return the result should be treated as an unsigned 64-bit integer. */ private def encode(radix: Int, fromPos: Int): Long = { @@ -122,7 +122,7 @@ object NumberConverter { * unsigned, otherwise it is signed. * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv */ - def convert(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { + def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala index 191d5e6399fc9..d5d151a5802f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -41,4 +41,6 @@ class StringKeyHashMap[T](normalizer: (String) => String) { def remove(key: String): Option[T] = base.remove(normalizer(key)) def iterator: Iterator[(String, T)] = base.toIterator + + def clear(): Unit = base.clear() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index c2eeb3c5650ab..cde8bd5b9614c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import java.util.regex.Pattern +import java.util.regex.{Pattern, PatternSyntaxException} import org.apache.spark.unsafe.types.UTF8String @@ -52,4 +52,25 @@ object StringUtils { def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) + + /** + * This utility can be used for filtering pattern in the "Like" of "Show Tables / Functions" DDL + * @param names the names list to be filtered + * @param pattern the filter pattern, only '*' and '|' are allowed as wildcards, others will + * follow regular expression convention, case insensitive match and white spaces + * on both ends will be ignored + * @return the filtered names list in order + */ + def filterPattern(names: Seq[String], pattern: String): Seq[String] = { + val funcNames = scala.collection.mutable.SortedSet.empty[String] + pattern.trim().split("\\|").foreach { subPattern => + try { + val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r + funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() } + } catch { + case _: PatternSyntaxException => + } + } + funcNames.toSeq + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index bcf4d78fb9371..f603cbfb0cc21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -57,6 +57,7 @@ object TypeUtils { def getInterpretedOrdering(t: DataType): Ordering[Any] = { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 71293475ca0f9..f879b34358a9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -18,7 +18,11 @@ package org.apache.spark.sql.catalyst import java.io._ +import java.nio.charset.StandardCharsets +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{NumericType, StringType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils package object util { @@ -101,12 +105,12 @@ package object util { } def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { - val maxLeftSize = left.map(_.size).max + val maxLeftSize = left.map(_.length).max val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") leftPadded.zip(rightPadded).map { - case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.size) + 3)) + r + case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r } } @@ -115,7 +119,7 @@ package object util { val writer = new PrintWriter(out) t.printStackTrace(writer) writer.flush() - new String(out.toByteArray) + new String(out.toByteArray, StandardCharsets.UTF_8) } def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString @@ -130,6 +134,35 @@ package object util { ret } + // Replaces attributes, string literals, complex type extractors with their pretty form so that + // generated column names don't contain back-ticks or double-quotes. + def usePrettyExpression(e: Expression): Expression = e transform { + case a: Attribute => new PrettyAttribute(a) + case Literal(s: UTF8String, StringType) => PrettyAttribute(s.toString, StringType) + case Literal(v, t: NumericType) if v != null => PrettyAttribute(v.toString, t) + case e: GetStructField => + val name = e.name.getOrElse(e.childSchema(e.ordinal).name) + PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType) + case e: GetArrayStructFields => + PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name, e.dataType) + } + + def quoteIdentifier(name: String): String = { + // Escapes back-ticks within the identifier name with double-back-ticks, and then quote the + // identifier with back-ticks. + "`" + name.replace("`", "``") + "`" + } + + /** + * Returns the string representation of this expression that is safe to be put in + * code comments of generated code. The length is capped at 128 characters. + */ + def toCommentSafeString(str: String): String = { + val len = math.min(str.length, 128) + val suffix = if (str.length > len) "..." else "" + str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 1d2d007c2b4d2..90af10f7a6b1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.types import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} +import scala.reflect.runtime.universe.{runtimeMirror, TypeTag} import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression @@ -84,6 +84,7 @@ private[sql] object TypeCollection { * Types that can be ordered/compared. In the long run we should probably make this a trait * that can be mixed into each data type, and perhaps create an [[AbstractDataType]]. */ + // TODO: Should we consolidate this with RowOrdering.isOrderable? val Ordered = TypeCollection( BooleanType, ByteType, ShortType, IntegerType, LongType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5770f59b53077..520e344361625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.types +import scala.math.Ordering + import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi - +import org.apache.spark.sql.catalyst.util.ArrayData object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ @@ -75,10 +77,57 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" + override def sql: String = s"ARRAY<${elementType.sql}>" + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || elementType.existsRecursively(f) } + + @transient + private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] { + private[this] val elementOrdering: Ordering[Any] = elementType match { + case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]] + case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case other => + throw new IllegalArgumentException(s"Type $other does not support ordered operations") + } + + def compare(x: ArrayData, y: ArrayData): Int = { + val leftArray = x + val rightArray = y + val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements()) + var i = 0 + while (i < minLength) { + val isNullLeft = leftArray.isNullAt(i) + val isNullRight = rightArray.isNullAt(i) + if (isNullLeft && isNullRight) { + // Do nothing. + } else if (isNullLeft) { + return -1 + } else if (isNullRight) { + return 1 + } else { + val comp = + elementOrdering.compare( + leftArray.get(i, elementType), + rightArray.get(i, elementType)) + if (comp != 0) { + return comp + } + } + i += 1 + } + if (leftArray.numElements() < rightArray.numElements()) { + return -1 + } else if (leftArray.numElements() > rightArray.numElements()) { + return 1 + } else { + return 0 + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index f2c6f34ea51c7..c40e140e8c5c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -47,9 +47,9 @@ class BinaryType private() extends AtomicType { } /** - * The default size of a value of the BinaryType is 4096 bytes. + * The default size of a value of the BinaryType is 100 bytes. */ - override def defaultSize: Int = 4096 + override def defaultSize: Int = 100 private[spark] override def asNullable: BinaryType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 2ca427975a1cf..d37130e27ba5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock - /** * :: DeveloperApi :: * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4b54c31dcc27a..3d4a02b0ffebd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,19 +17,15 @@ package org.apache.spark.sql.types -import scala.util.Try -import scala.util.parsing.combinator.RegexParsers - +import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.util.Utils - /** * :: DeveloperApi :: * The base type of all Spark SQL data types. @@ -66,6 +62,11 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type. */ def simpleString: String = typeName + /** Readable string representation for the type with truncation */ + private[sql] def simpleString(maxNumberFields: Int): String = simpleString + + def sql: String = simpleString.toUpperCase + /** * Check if `this` and `other` are the same data type when ignoring nullability * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). @@ -91,27 +92,18 @@ abstract class DataType extends AbstractDataType { object DataType { - private[sql] def fromString(raw: String): DataType = { - Try(DataType.fromJson(raw)).getOrElse(DataType.fromCaseClassString(raw)) - } def fromJson(json: String): DataType = parseDataType(parse(json)) - /** - * @deprecated As of 1.2.0, replaced by `DataType.fromJson()` - */ - @deprecated("Use DataType.fromJson instead", "1.2.0") - def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) - private val nonDecimalNameToType = { - Seq(NullType, DateType, TimestampType, BinaryType, - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, + DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType) .map(t => t.typeName -> t).toMap } /** Given the string representation of a type, return its DataType */ private def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) @@ -127,7 +119,7 @@ object DataType { } // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. - private def parseDataType(json: JValue): DataType = json match { + private[sql] def parseDataType(json: JValue): DataType = json match { case JString(name) => nameToType(name) @@ -181,73 +173,6 @@ object DataType { StructField(name, parseDataType(dataType), nullable) } - private object CaseClassStringParser extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - ( "StringType" ^^^ StringType - | "FloatType" ^^^ FloatType - | "IntegerType" ^^^ IntegerType - | "ByteType" ^^^ ByteType - | "ShortType" ^^^ ShortType - | "DoubleType" ^^^ DoubleType - | "LongType" ^^^ LongType - | "BinaryType" ^^^ BinaryType - | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.USER_DEFAULT - | fixedDecimalType - | "TimestampType" ^^^ TimestampType - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) - } - - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => - StructField(name, tpe, nullable = nullable) - } - - protected lazy val boolVal: Parser[Boolean] = - ( "true" ^^^ true - | "false" ^^^ false - ) - - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => StructType(fields) - } - - protected lazy val dataType: Parser[DataType] = - ( arrayType - | mapType - | structType - | primitiveType - ) - - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => - throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") - } - } - protected[types] def buildFormattedString( dataType: DataType, prefix: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index c7a1a2e7469ee..a30a3926bb86e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import java.math.{RoundingMode, MathContext} +import java.math.{MathContext, RoundingMode} import org.apache.spark.annotation.DeveloperApi @@ -340,6 +340,9 @@ object Decimal { val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + /** Maximum number of decimal digits a Int can represent */ + val MAX_INT_DIGITS = 9 + /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 @@ -373,6 +376,25 @@ object Decimal { def apply(value: String): Decimal = new Decimal().set(BigDecimal(value)) + // This is used for RowEncoder to handle Decimal inside external row. + def fromDecimal(value: Any): Decimal = { + value match { + case j: java.math.BigDecimal => apply(j) + case d: Decimal => d + } + } + + /** + * Creates a decimal from unscaled, precision and scale without checking the bounds. + */ + def createUnsafe(unscaled: Long, precision: Int, scale: Int): Decimal = { + val dec = new Decimal() + dec.longVal = unscaled + dec._precision = precision + dec._scale = scale + dec + } + // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two // parameters inheriting from a common trait since both traits define mkNumericOps. // See scala.math's Numeric.scala for examples for Scala's built-in types. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0cd352d0fa928..9c1319c1c5e6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -25,20 +25,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression -/** Precision parameters for a Decimal */ -@deprecated("Use DecimalType(precision, scale) directly", "1.5") -case class PrecisionInfo(precision: Int, scale: Int) { - if (scale > precision) { - throw new AnalysisException( - s"Decimal scale ($scale) cannot be greater than precision ($precision).") - } - if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException( - s"DecimalType can only support precision up to 38" - ) - } -} - /** * :: DeveloperApi :: * The data type representing `java.math.BigDecimal` values. @@ -54,18 +40,18 @@ case class PrecisionInfo(precision: Int, scale: Int) { @DeveloperApi case class DecimalType(precision: Int, scale: Int) extends FractionalType { - // default constructor for Java - def this(precision: Int) = this(precision, 0) - def this() = this(10) + if (scale > precision) { + throw new AnalysisException( + s"Decimal scale ($scale) cannot be greater than precision ($precision).") + } - @deprecated("Use DecimalType(precision, scale) instead", "1.5") - def this(precisionInfo: Option[PrecisionInfo]) { - this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, - precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) + if (precision > DecimalType.MAX_PRECISION) { + throw new AnalysisException(s"DecimalType can only support precision up to 38") } - @deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5") - val precisionInfo = Some(PrecisionInfo(precision, scale)) + // default constructor for Java + def this(precision: Int) = this(precision, 0) + def this() = this(10) private[sql] type InternalType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } @@ -78,6 +64,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def toString: String = s"DecimalType($precision,$scale)" + override def sql: String = typeName.toUpperCase + /** * Returns whether this DecimalType is wider than `other`. If yes, it means `other` * can be casted into `this` safely without losing any precision or range. @@ -91,9 +79,21 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } /** - * The default size of a value of the DecimalType is 4096 bytes. + * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` + * can be casted into `other` safely without losing any precision or range. */ - override def defaultSize: Int = 4096 + private[sql] def isTighterThan(other: DataType): Boolean = other match { + case dt: DecimalType => + (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale + case dt: IntegralType => + isTighterThan(DecimalType.forType(dt)) + case _ => false + } + + /** + * The default size of a value of the DecimalType is 8 bytes (precision <= 18) or 16 bytes. + */ + override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16 override def simpleString: String = s"decimal($precision,$scale)" @@ -110,9 +110,6 @@ object DecimalType extends AbstractDataType { val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) - @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5") - val Unlimited: DecimalType = SYSTEM_DEFAULT - // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) private[sql] val ShortDecimal = DecimalType(5, 0) @@ -130,15 +127,6 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } - @deprecated("please specify precision and scale", "1.5") - def apply(): DecimalType = USER_DEFAULT - - @deprecated("Use DecimalType(precision, scale) instead", "1.5") - def apply(precisionInfo: Option[PrecisionInfo]) { - this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, - precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) - } - private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } @@ -162,6 +150,39 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside a int + */ + def is32BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_INT_DIGITS + case _ => false + } + } + + /** + * Returns if dt is a DecimalType that fits inside a long + */ + def is64BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_LONG_DIGITS + case _ => false + } + } + + /** + * Returns if dt is a DecimalType that doesn't fit inside a long + */ + def isByteArrayDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision > Decimal.MAX_LONG_DIGITS + case _ => false + } + } + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 2a1bf0938e5a8..e553f65f3c99d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Fractional, Numeric} +import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.DoubleAsIfIntegral import scala.reflect.runtime.universe.typeTag diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 08e22252aef82..ae9aa9eefaf2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.types +import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.FloatAsIfIntegral -import scala.math.{Ordering, Fractional, Numeric} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index a2c6e19b05b3c..38a7b8ee52651 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 2b3adf6ade83b..88aff0c87755c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 00461e529ca0a..5474954af70e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,6 +62,8 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 6ee24ee0c1913..66f123682e117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -156,7 +156,9 @@ object Metadata { throw new RuntimeException(s"Do not support array of type ${other.getClass}.") } } - case other => + case (key, JNull) => + builder.putNull(key) + case (key, other) => throw new RuntimeException(s"Do not support type ${other.getClass}.") } builder.build() @@ -229,6 +231,9 @@ class MetadataBuilder { this } + /** Puts a null. */ + def putNull(key: String): this.type = put(key, null) + /** Puts a Long. */ def putLong(key: String, value: Long): this.type = put(key, value) @@ -268,4 +273,9 @@ class MetadataBuilder { map.put(key, value) this } + + def remove(key: String): this.type = { + map.remove(key) + this + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index fca0b799eb809..b7b1acc58242e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -23,8 +23,10 @@ private[sql] object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException("null literals can't be casted to ObjectType") - // No casting or comparison is supported. - override private[sql] def acceptsType(other: DataType): Boolean = false + override private[sql] def acceptsType(other: DataType): Boolean = other match { + case ObjectType(_) => true + case _ => false + } override private[sql] def simpleString: String = "Object" } @@ -39,4 +41,6 @@ private[sql] case class ObjectType(cls: Class[_]) extends DataType { throw new UnsupportedOperationException("No size estimation available for objects.") def asNullable: DataType = this + + override def simpleString: String = cls.getName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index a13119e659064..486cf585284df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index a7627a2de1611..44a25361f31c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -38,9 +38,9 @@ class StringType private() extends AtomicType { private[sql] val ordering = implicitly[Ordering[InternalType]] /** - * The default size of a value of the StringType is 4096 bytes. + * The default size of a value of the StringType is 20 bytes. */ - override def defaultSize: Int = 4096 + override def defaultSize: Int = 20 private[spark] override def asNullable: StringType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 11fce4beaf55f..1238eefcb6062 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -18,14 +18,15 @@ package org.apache.spark.sql.types import scala.collection.mutable.ArrayBuffer +import scala.util.Try import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} -import org.apache.spark.sql.catalyst.util.DataTypeParser - +import org.apache.spark.sql.catalyst.parser.{DataTypeParser, LegacyTypeStringParser} +import org.apache.spark.sql.catalyst.util.quoteIdentifier /** * :: DeveloperApi :: @@ -40,6 +41,7 @@ import org.apache.spark.sql.catalyst.util.DataTypeParser * Example: * {{{ * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ * * val struct = * StructType( @@ -278,6 +280,28 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"struct<${fieldTypes.mkString(",")}>" } + override def sql: String = { + val fieldTypes = fields.map(f => s"${quoteIdentifier(f.name)}: ${f.dataType.sql}") + s"STRUCT<${fieldTypes.mkString(", ")}>" + } + + private[sql] override def simpleString(maxNumberFields: Int): String = { + val builder = new StringBuilder + val fieldTypes = fields.take(maxNumberFields).map { + case f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" + } + builder.append("struct<") + builder.append(fieldTypes.mkString(", ")) + if (fields.length > 2) { + if (fields.length - fieldTypes.length == 1) { + builder.append(" ... 1 more field") + } else { + builder.append(" ... " + (fields.length - 2) + " more fields") + } + } + builder.append(">").toString() + } + /** * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field * B from `that`, @@ -312,6 +336,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru object StructType extends AbstractDataType { + private[sql] val metadataKeyForOptionalField = "_OPTIONAL_" + override private[sql] def defaultConcreteType: DataType = new StructType override private[sql] def acceptsType(other: DataType): Boolean = { @@ -320,20 +346,35 @@ object StructType extends AbstractDataType { override private[sql] def simpleString: String = "struct" - private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match { - case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + private[sql] def fromString(raw: String): StructType = { + Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { + case t: StructType => t + case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + } } def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { - StructType(fields.toArray.asInstanceOf[Array[StructField]]) + import scala.collection.JavaConverters._ + StructType(fields.asScala) } protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + def removeMetadata(key: String, dt: DataType): DataType = + dt match { + case StructType(fields) => + val newFields = fields.map { f => + val mb = new MetadataBuilder() + f.copy(dataType = removeMetadata(key, f.dataType), + metadata = mb.withMetadata(f.metadata).remove(key).build()) + } + StructType(newFields) + case _ => dt + } + private[sql] def merge(left: DataType, right: DataType): DataType = (left, right) match { case (ArrayType(leftElementType, leftContainsNull), @@ -351,24 +392,32 @@ object StructType extends AbstractDataType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + // This metadata will record the fields that only exist in one of two StructTypes + val optionalMeta = new MetadataBuilder() val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) - } - .orElse(Some(leftField)) + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse { + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + Some(leftField.copy(metadata = optionalMeta.build())) + } .foreach(newFields += _) } val leftMapped = fieldsMap(leftFields) rightFields .filterNot(f => leftMapped.get(f.name).nonEmpty) - .foreach(newFields += _) + .foreach { f => + optionalMeta.putBoolean(metadataKeyForOptionalField, true) + newFields += f.copy(metadata = optionalMeta.build()) + } StructType(newFields) @@ -377,13 +426,13 @@ object StructType extends AbstractDataType { if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { DecimalType(leftPrecision, leftScale) } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") } else if (leftPrecision != rightPrecision) { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"precision $leftPrecision and $rightPrecision") } else { - throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + throw new SparkException("Failed to merge decimal types with incompatible " + s"scala $leftScale and $rightScale") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 4305903616bd9..fb7251d71b9b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -23,7 +23,6 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi /** - * ::DeveloperApi:: * The data type for User Defined Types (UDTs). * * This interface allows a user to make their own classes more interoperable with SparkSQL; @@ -35,9 +34,12 @@ import org.apache.spark.annotation.DeveloperApi * * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. * The conversion via `deserialize` occurs when reading from a `DataFrame`. + * + * Note: This was previously a developer API in Spark 1.x. We are making this private in Spark 2.0 + * because we will very likely create a new version of this that works better with Datasets. */ -@DeveloperApi -abstract class UserDefinedType[UserType] extends DataType with Serializable { +private[spark] +abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable { /** Underlying storage type for this UDT */ def sqlType: DataType @@ -50,11 +52,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** * Convert the user type to a SQL datum - * - * TODO: Can we make this take obj: UserType? The issue is in - * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. */ - def serialize(obj: Any): Any + def serialize(obj: UserType): Any /** Convert a SQL datum to the user type */ def deserialize(datum: Any): UserType @@ -71,10 +70,7 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { */ def userClass: java.lang.Class[UserType] - /** - * The default size of a value of the UserDefinedType is 4096 bytes. - */ - override def defaultSize: Int = 4096 + override def defaultSize: Int = sqlType.defaultSize /** * For UDT, asNullable will not change the nullability of its internal sqlType and just returns @@ -84,6 +80,13 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass + + override def sql: String = sqlType.sql + + override def equals(other: Any): Boolean = other match { + case that: UserDefinedType[_] => this.acceptsType(that) + case _ => false + } } /** @@ -110,4 +113,9 @@ private[sql] class PythonUserDefinedType( ("serializedClass" -> serializedPyClass) ~ ("sqlType" -> sqlType.jsonValue) } + + override def equals(other: Any): Boolean = other match { + case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT) + case _ => false + } } diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java new file mode 100644 index 0000000000000..711887f02832a --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +import org.apache.spark.unsafe.Platform; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test the XXH64 function. + *

    + * Test constants were taken from the original implementation and the airlift/slice implementation. + */ +public class XXH64Suite { + + private static final XXH64 hasher = new XXH64(0); + + private static final int SIZE = 101; + private static final long PRIME = 2654435761L; + private static final byte[] BUFFER = new byte[SIZE]; + private static final int TEST_INT = 0x4B1FFF9E; // First 4 bytes in the buffer + private static final long TEST_LONG = 0xDD2F535E4B1FFF9EL; // First 8 bytes in the buffer + + /* Create the test data. */ + static { + long seed = PRIME; + for (int i = 0; i < SIZE; i++) { + BUFFER[i] = (byte) (seed >> 24); + seed *= seed; + } + } + + @Test + public void testKnownIntegerInputs() { + Assert.assertEquals(0x9256E58AA397AEF1L, hasher.hashInt(TEST_INT)); + Assert.assertEquals(0x9D5FFDFB928AB4BL, XXH64.hashInt(TEST_INT, PRIME)); + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(0xF74CB1451B32B8CFL, hasher.hashLong(TEST_LONG)); + Assert.assertEquals(0x9C44B77FBCC302C5L, XXH64.hashLong(TEST_LONG, PRIME)); + } + + @Test + public void testKnownByteArrayInputs() { + Assert.assertEquals(0xEF46DB3751D8E999L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 0)); + Assert.assertEquals(0xAC75FDA2929B17EFL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 0, PRIME)); + Assert.assertEquals(0x4FCE394CC88952D8L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1)); + Assert.assertEquals(0x739840CB819FA723L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1, PRIME)); + + // These tests currently fail in a big endian environment because the test data and expected + // answers are generated with little endian the assumptions. We could revisit this when Platform + // becomes endian aware. + if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { + Assert.assertEquals(0x9256E58AA397AEF1L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); + Assert.assertEquals(0x9D5FFDFB928AB4BL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4, PRIME)); + Assert.assertEquals(0xF74CB1451B32B8CFL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8)); + Assert.assertEquals(0x9C44B77FBCC302C5L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8, PRIME)); + Assert.assertEquals(0xCFFA8DB881BC3A3DL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14)); + Assert.assertEquals(0x5B9611585EFCC9CBL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14, PRIME)); + Assert.assertEquals(0x0EAB543384F878ADL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); + Assert.assertEquals(0xCAA65939306F1E21L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); + } + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint)); + Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint)); + + hashcodes.add(hasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95d); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95d); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. + Set hashcodes = new HashSet<>(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95d); + } +} diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties index eb3b1999eb996..3706a6e361307 100644 --- a/sql/catalyst/src/test/resources/log4j.properties +++ b/sql/catalyst/src/test/resources/log4j.properties @@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala new file mode 100644 index 0000000000000..c6a1a2be0d071 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection +import org.apache.spark.sql.types._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark for the previous interpreted hash function(InternalRow.hashCode) vs codegened + * hash expressions (Murmur3Hash/xxHash64). + */ +object HashBenchmark { + + def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = { + val generator = RandomDataGenerator.forType(schema, nullable = false).get + val encoder = RowEncoder(schema) + val attrs = schema.toAttributes + val safeProjection = GenerateSafeProjection.generate(attrs, attrs) + + val rows = (1 to numRows).map(_ => + // The output of encoder is UnsafeRow, use safeProjection to turn in into safe format. + safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() + ).toArray + + val benchmark = new Benchmark("Hash For " + name, iters * numRows) + benchmark.addCase("interpreted version") { _: Int => + for (_ <- 0L until iters) { + var sum = 0 + var i = 0 + while (i < numRows) { + sum += rows(i).hashCode() + i += 1 + } + } + } + + val getHashCode = UnsafeProjection.create(new Murmur3Hash(attrs) :: Nil, attrs) + benchmark.addCase("codegen version") { _: Int => + for (_ <- 0L until iters) { + var sum = 0 + var i = 0 + while (i < numRows) { + sum += getHashCode(rows(i)).getInt(0) + i += 1 + } + } + } + + val getHashCode64b = UnsafeProjection.create(new XxHash64(attrs) :: Nil, attrs) + benchmark.addCase("codegen version 64-bit") { _: Int => + for (_ <- 0L until iters) { + var sum = 0 + var i = 0 + while (i < numRows) { + sum += getHashCode64b(rows(i)).getInt(0) + i += 1 + } + } + } + + benchmark.run() + } + + def main(args: Array[String]): Unit = { + val singleInt = new StructType().add("i", IntegerType) + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash For single ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 1006 / 1011 133.4 7.5 1.0X + codegen version 1835 / 1839 73.1 13.7 0.5X + codegen version 64-bit 1627 / 1628 82.5 12.1 0.6X + */ + test("single ints", singleInt, 1 << 15, 1 << 14) + + val singleLong = new StructType().add("i", LongType) + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash For single longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 1196 / 1209 112.2 8.9 1.0X + codegen version 2178 / 2181 61.6 16.2 0.5X + codegen version 64-bit 1752 / 1753 76.6 13.1 0.7X + */ + test("single longs", singleLong, 1 << 15, 1 << 14) + + val normal = new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("bigDecimal", DecimalType.SYSTEM_DEFAULT) + .add("smallDecimal", DecimalType.USER_DEFAULT) + .add("string", StringType) + .add("binary", BinaryType) + .add("date", DateType) + .add("timestamp", TimestampType) + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash For normal: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 2713 / 2715 0.8 1293.5 1.0X + codegen version 2015 / 2018 1.0 960.9 1.3X + codegen version 64-bit 735 / 738 2.9 350.7 3.7X + */ + test("normal", normal, 1 << 10, 1 << 11) + + val arrayOfInt = ArrayType(IntegerType) + val array = new StructType() + .add("array", arrayOfInt) + .add("arrayOfArray", ArrayType(arrayOfInt)) + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash For array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 1498 / 1499 0.1 11432.1 1.0X + codegen version 2642 / 2643 0.0 20158.4 0.6X + codegen version 64-bit 2421 / 2424 0.1 18472.5 0.6X + */ + test("array", array, 1 << 8, 1 << 9) + + val mapOfInt = MapType(IntegerType, IntegerType) + val map = new StructType() + .add("map", mapOfInt) + .add("mapOfMap", MapType(IntegerType, mapOfInt)) + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash For map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + interpreted version 1612 / 1618 0.0 393553.4 1.0X + codegen version 149 / 150 0.0 36381.2 10.8X + codegen version 64-bit 144 / 145 0.0 35122.1 11.2X + */ + test("map", map, 1 << 6, 1 << 6) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala new file mode 100644 index 0000000000000..53f21a8442429 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.Random + +import org.apache.spark.sql.catalyst.expressions.XXH64 +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.util.Benchmark + +/** + * Synthetic benchmark for MurMurHash 3 and xxHash64. + */ +object HashByteArrayBenchmark { + def test(length: Int, seed: Long, numArrays: Int, iters: Int): Unit = { + val random = new Random(seed) + val arrays = Array.fill[Array[Byte]](numArrays) { + val bytes = new Array[Byte](length) + random.nextBytes(bytes) + bytes + } + + val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) + benchmark.addCase("Murmur3_x86_32") { _: Int => + for (_ <- 0L until iters) { + var sum = 0 + var i = 0 + while (i < numArrays) { + sum += Murmur3_x86_32.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length, 42) + i += 1 + } + } + } + + benchmark.addCase("xxHash 64-bit") { _: Int => + for (_ <- 0L until iters) { + var sum = 0L + var i = 0 + while (i < numArrays) { + sum += XXH64.hashUnsafeBytes(arrays(i), Platform.BYTE_ARRAY_OFFSET, length, 42) + i += 1 + } + } + } + + benchmark.run() + } + + def main(args: Array[String]): Unit = { + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 8: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 11 / 12 185.1 5.4 1.0X + xxHash 64-bit 17 / 18 120.0 8.3 0.6X + */ + test(8, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 16: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 18 / 18 118.6 8.4 1.0X + xxHash 64-bit 20 / 21 102.5 9.8 0.9X + */ + test(16, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 24: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 24 / 24 86.6 11.5 1.0X + xxHash 64-bit 23 / 23 93.2 10.7 1.1X + */ + test(24, 42L, 1 << 10, 1 << 11) + + // Add 31 to all arrays to create worse case alignment for xxHash. + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 31: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 38 / 39 54.7 18.3 1.0X + xxHash 64-bit 33 / 33 64.4 15.5 1.2X + */ + test(31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 95: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 91 / 94 22.9 43.6 1.0X + xxHash 64-bit 68 / 69 30.6 32.7 1.3X + */ + test(64 + 31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 287: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 268 / 268 7.8 127.6 1.0X + xxHash 64-bit 108 / 109 19.4 51.6 2.5X + */ + test(256 + 31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 1055: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 942 / 945 2.2 449.4 1.0X + xxHash 64-bit 276 / 276 7.6 131.4 3.4X + */ + test(1024 + 31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 2079: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 1839 / 1843 1.1 876.8 1.0X + xxHash 64-bit 445 / 448 4.7 212.1 4.1X + */ + test(2048 + 31, 42L, 1 << 10, 1 << 11) + + /* + Intel(R) Core(TM) i7-4750HQ CPU @ 2.00GHz + Hash byte arrays with length 8223: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Murmur3_x86_32 7307 / 7310 0.3 3484.4 1.0X + xxHash 64-bit 1487 / 1488 1.4 709.1 4.9X + */ + test(8192 + 31, 42L, 1 << 10, 1 << 11) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 7614f055e9c04..711e8707116cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -21,6 +21,7 @@ import java.lang.Double.longBitsToDouble import java.lang.Float.intBitsToFloat import java.math.MathContext +import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -46,9 +47,9 @@ object RandomDataGenerator { */ private val PROBABILITY_OF_NULL: Float = 0.1f - private val MAX_STR_LEN: Int = 1024 - private val MAX_ARR_SIZE: Int = 128 - private val MAX_MAP_SIZE: Int = 128 + final val MAX_STR_LEN: Int = 1024 + final val MAX_ARR_SIZE: Int = 128 + final val MAX_MAP_SIZE: Int = 128 /** * Helper function for constructing a biased random number generator which returns "interesting" @@ -74,13 +75,47 @@ object RandomDataGenerator { * @param numFields the number of fields in this schema * @param acceptedTypes types to draw from. */ - def randomSchema(numFields: Int, acceptedTypes: Seq[DataType]): StructType = { + def randomSchema(rand: Random, numFields: Int, acceptedTypes: Seq[DataType]): StructType = { StructType(Seq.tabulate(numFields) { i => - val dt = acceptedTypes(Random.nextInt(acceptedTypes.size)) - StructField("col_" + i, dt, nullable = true) + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + StructField("col_" + i, dt, nullable = rand.nextBoolean()) }) } + /** + * Returns a random nested schema. This will randomly generate structs and arrays drawn from + * acceptedTypes. + */ + def randomNestedSchema(rand: Random, totalFields: Int, acceptedTypes: Seq[DataType]): + StructType = { + val fields = mutable.ArrayBuffer.empty[StructField] + var i = 0 + var numFields = totalFields + while (numFields > 0) { + val v = rand.nextInt(3) + if (v == 0) { + // Simple type: + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + fields += new StructField("col_" + i, dt, rand.nextBoolean()) + numFields -= 1 + } else if (v == 1) { + // Array + val dt = acceptedTypes(rand.nextInt(acceptedTypes.size)) + fields += new StructField("col_" + i, ArrayType(dt), rand.nextBoolean()) + numFields -= 1 + } else { + // Struct + // TODO: do empty structs make sense? + val n = Math.max(rand.nextInt(numFields), 1) + val nested = randomNestedSchema(rand, n, acceptedTypes) + fields += new StructField("col_" + i, nested, rand.nextBoolean()) + numFields -= n + } + i += 1 + } + StructType(fields) + } + /** * Returns a function which generates random values for the given [[DataType]], or `None` if no * random data generator is defined for that data type. The generated values will use an external @@ -90,16 +125,13 @@ object RandomDataGenerator { * * @param dataType the type to generate values for * @param nullable whether null values should be generated - * @param seed an optional seed for the random number generator + * @param rand an optional random number generator * @return a function which can be called to generate random values. */ def forType( dataType: DataType, nullable: Boolean = true, - seed: Option[Long] = None): Option[() => Any] = { - val rand = new Random() - seed.foreach(rand.setSeed) - + rand: Random = new Random): Option[() => Any] = { val valueGenerator: Option[() => Any] = dataType match { case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN))) case BinaryType => Some(() => { @@ -116,7 +148,7 @@ object RandomDataGenerator { // for "0001-01-01 00:00:00.000000". We need to find a // number that is greater or equals to this number as a valid timestamp value. while (milliseconds < -62135740800000L) { - // 253402329599999L is the the number of milliseconds since + // 253402329599999L is the number of milliseconds since // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". milliseconds = rand.nextLong() % 253402329599999L } @@ -131,7 +163,7 @@ object RandomDataGenerator { // for "0001-01-01 00:00:00.000000". We need to find a // number that is greater or equals to this number as a valid timestamp value. while (milliseconds < -62135740800000L) { - // 253402329599999L is the the number of milliseconds since + // 253402329599999L is the number of milliseconds since // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". milliseconds = rand.nextLong() % 253402329599999L } @@ -164,25 +196,33 @@ object RandomDataGenerator { case ShortType => randomNumeric[Short]( rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) case NullType => Some(() => null) - case ArrayType(elementType, containsNull) => { - forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { + case ArrayType(elementType, containsNull) => + forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } - } - case MapType(keyType, valueType, valueContainsNull) => { + case MapType(keyType, valueType, valueContainsNull) => for ( - keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong())); + keyGenerator <- forType(keyType, nullable = false, rand); valueGenerator <- - forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong())) + forType(valueType, nullable = valueContainsNull, rand) ) yield { () => { - Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap + val length = rand.nextInt(MAX_MAP_SIZE) + val keys = scala.collection.mutable.HashSet(Seq.fill(length)(keyGenerator()): _*) + // In case the number of different keys is not enough, set a max iteration to avoid + // infinite loop. + var count = 0 + while (keys.size < length && count < MAX_MAP_SIZE) { + keys += keyGenerator() + count += 1 + } + val values = Seq.fill(keys.size)(valueGenerator()) + keys.zip(values).toMap } } - } - case StructType(fields) => { + case StructType(fields) => val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => - forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong())) + forType(field.dataType, nullable = field.nullable, rand) } if (maybeFieldGenerators.forall(_.isDefined)) { val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get) @@ -190,9 +230,8 @@ object RandomDataGenerator { } else { None } - } - case udt: UserDefinedType[_] => { - val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, seed) + case udt: UserDefinedType[_] => + val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand) // Because random data generator at here returns scala value, we need to // convert it to catalyst value to call udt's deserialize. val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType) @@ -211,7 +250,6 @@ object RandomDataGenerator { } else { None } - } case unsupportedType => None } // Handle nullability by wrapping the non-null value generator: @@ -229,4 +267,38 @@ object RandomDataGenerator { } } } + + // Generates a random row for `schema`. + def randomRow(rand: Random, schema: StructType): Row = { + val fields = mutable.ArrayBuffer.empty[Any] + schema.fields.foreach { f => + f.dataType match { + case ArrayType(childType, nullable) => + val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) { + null + } else { + val arr = mutable.ArrayBuffer.empty[Any] + val n = 1// rand.nextInt(10) + var i = 0 + val generator = RandomDataGenerator.forType(childType, nullable, rand) + assert(generator.isDefined, "Unsupported type") + val gen = generator.get + while (i < n) { + arr += gen() + i += 1 + } + arr + } + fields += data + case StructType(children) => + fields += randomRow(rand, StructType(children)) + case _ => + val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand) + assert(generator.isDefined, "Unsupported type") + val gen = generator.get + fields += gen() + } + } + Row.fromSeq(fields) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index cccac7efa09e9..3c2f8a28875f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ @@ -32,13 +34,13 @@ class RandomDataGeneratorSuite extends SparkFunSuite { */ def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) - val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse { + val generator = RandomDataGenerator.forType(dataType, nullable, new Random(33)).getOrElse { fail(s"Random data generator was not defined for $dataType") } if (nullable) { assert(Iterator.fill(100)(generator()).contains(null)) } else { - assert(Iterator.fill(100)(generator()).forall(_ != null)) + assert(!Iterator.fill(100)(generator()).contains(null)) } for (_ <- 1 to 10) { val generatedValue = generator() @@ -93,4 +95,15 @@ class RandomDataGeneratorSuite extends SparkFunSuite { } } + test("check size of generated map") { + val mapType = MapType(IntegerType, IntegerType) + for (seed <- 1 to 1000) { + val generator = RandomDataGenerator.forType( + mapType, nullable = false, rand = new Random(seed)).get + val maps = Seq.fill(100)(generator().asInstanceOf[Map[Int, Int]]) + val expectedTotalElements = 100 / 2 * RandomDataGenerator.MAX_MAP_SIZE + val deviation = math.abs(maps.map(_.size).sum - expectedTotalElements) + assert(deviation.toDouble / expectedTotalElements < 2e-1) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 01ff84cb56054..c9c9599e7f463 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql +import org.scalatest.{FunSpec, Matchers} + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} import org.apache.spark.sql.types._ -import org.scalatest.{Matchers, FunSpec} class RowTest extends FunSpec with Matchers { @@ -29,8 +30,10 @@ class RowTest extends FunSpec with Matchers { StructField("col2", StringType) :: StructField("col3", IntegerType) :: Nil) val values = Array("value1", "value2", 1) + val valuesWithoutCol3 = Array[Any](null, "value2", null) val sampleRow: Row = new GenericRowWithSchema(values, schema) + val sampleRowWithoutCol3: Row = new GenericRowWithSchema(valuesWithoutCol3, schema) val noSchemaRow: Row = new GenericRow(values) describe("Row (without schema)") { @@ -68,6 +71,24 @@ class RowTest extends FunSpec with Matchers { ) sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected } + + it("getValuesMap() retrieves null value on non AnyVal Type") { + val expected = Map( + "col1" -> null, + "col2" -> "value2" + ) + sampleRowWithoutCol3.getValuesMap[String](List("col1", "col2")) shouldBe expected + } + + it("getAs() on type extending AnyVal throws an exception when accessing field that is null") { + intercept[NullPointerException] { + sampleRowWithoutCol3.getInt(sampleRowWithoutCol3.fieldIndex("col3")) + } + } + + it("getAs() on type extending AnyVal does not throw exception when value is null") { + sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null + } } describe("row equals") { @@ -84,4 +105,34 @@ class RowTest extends FunSpec with Matchers { internalRow shouldEqual internalRow2 } } + + describe("row immutability") { + val values = Seq(1, 2, "3", "IV", 6L) + val externalRow = Row.fromSeq(values) + val internalRow = InternalRow.fromSeq(values) + + def modifyValues(values: Seq[Any]): Seq[Any] = { + val array = values.toArray + array(2) = "42" + array + } + + it("copy should return same ref for external rows") { + externalRow should be theSameInstanceAs externalRow.copy() + } + + it("copy should return same ref for internal rows") { + internalRow should be theSameInstanceAs internalRow.copy() + } + + it("toSeq should not expose internal state for external rows") { + val modifiedValues = modifyValues(externalRow.toSeq) + externalRow.toSeq should not equal modifiedValues + } + + it("toSeq should not expose internal state for internal rows") { + val modifiedValues = modifyValues(internalRow.toSeq(Seq.empty)) + internalRow.toSeq(Seq.empty) should not equal modifiedValues + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala new file mode 100644 index 0000000000000..a6d90409382e5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.types._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields. + */ +object UnsafeProjectionBenchmark { + + def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = { + val generator = RandomDataGenerator.forType(schema, nullable = false).get + val encoder = RowEncoder(schema) + (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray + } + + def main(args: Array[String]) { + val iters = 1024 * 16 + val numRows = 1024 * 16 + + val benchmark = new Benchmark("unsafe projection", iters * numRows) + + + val schema1 = new StructType().add("l", LongType, false) + val attrs1 = schema1.toAttributes + val rows1 = generateRows(schema1, numRows) + val projection1 = UnsafeProjection.create(attrs1, attrs1) + + benchmark.addCase("single long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection1(rows1(i)).getLong(0) + i += 1 + } + } + } + + val schema2 = new StructType().add("l", LongType, true) + val attrs2 = schema2.toAttributes + val rows2 = generateRows(schema2, numRows) + val projection2 = UnsafeProjection.create(attrs2, attrs2) + + benchmark.addCase("single nullable long") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection2(rows2(i)).getLong(0) + i += 1 + } + } + } + + + val schema3 = new StructType() + .add("boolean", BooleanType, false) + .add("byte", ByteType, false) + .add("short", ShortType, false) + .add("int", IntegerType, false) + .add("long", LongType, false) + .add("float", FloatType, false) + .add("double", DoubleType, false) + val attrs3 = schema3.toAttributes + val rows3 = generateRows(schema3, numRows) + val projection3 = UnsafeProjection.create(attrs3, attrs3) + + benchmark.addCase("7 primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection3(rows3(i)).getLong(0) + i += 1 + } + } + } + + + val schema4 = new StructType() + .add("boolean", BooleanType, true) + .add("byte", ByteType, true) + .add("short", ShortType, true) + .add("int", IntegerType, true) + .add("long", LongType, true) + .add("float", FloatType, true) + .add("double", DoubleType, true) + val attrs4 = schema4.toAttributes + val rows4 = generateRows(schema4, numRows) + val projection4 = UnsafeProjection.create(attrs4, attrs4) + + benchmark.addCase("7 nullable primitive types") { _ => + for (_ <- 1 to iters) { + var sum = 0L + var i = 0 + while (i < numRows) { + sum += projection4(rows4(i)).getLong(0) + i += 1 + } + } + } + + + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + unsafe projection: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + single long 1533.34 175.07 1.00 X + single nullable long 2306.73 116.37 0.66 X + primitive types 8403.93 31.94 0.18 X + nullable primitive types 12448.39 21.56 0.12 X + */ + benchmark.run() + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 827f7ce692712..b47b8adfe5d55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.plans.physical._ - /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ class DistributionSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 3b848cfdf737f..5ca5a72512a29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.catalyst -import java.math.BigInteger +import java.net.URLClassLoader import java.sql.{Date, Timestamp} +import scala.reflect.runtime.universe.typeOf + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils case class PrimitiveData( intField: Int, @@ -69,6 +73,10 @@ case class ComplexData( case class GenericData[A]( genericField: A) +object GenericData { + type IntData = GenericData[Int] +} + case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } @@ -186,72 +194,8 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } - test("get data type of a value") { - // BooleanType - assert(BooleanType === typeOfObject(true)) - assert(BooleanType === typeOfObject(false)) - - // BinaryType - assert(BinaryType === typeOfObject("string".getBytes)) - - // StringType - assert(StringType === typeOfObject("string")) - - // ByteType - assert(ByteType === typeOfObject(127.toByte)) - - // ShortType - assert(ShortType === typeOfObject(32767.toShort)) - - // IntegerType - assert(IntegerType === typeOfObject(2147483647)) - - // LongType - assert(LongType === typeOfObject(9223372036854775807L)) - - // FloatType - assert(FloatType === typeOfObject(3.4028235E38.toFloat)) - - // DoubleType - assert(DoubleType === typeOfObject(1.7976931348623157E308)) - - // DecimalType - assert(DecimalType.SYSTEM_DEFAULT === - typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) - - // DateType - assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) - - // TimestampType - assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00"))) - - // NullType - assert(NullType === typeOfObject(null)) - - def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case _ => StringType - } - - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new BigInteger("92233720368547758070"))) - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new java.math.BigDecimal("1.7976931348623157E318"))) - assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) - - def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - } - - intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) - - def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { - case c: Seq[_] => ArrayType(typeOfObject3(c.head)) - } - - assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) + test("type-aliased data") { + assert(schemaFor[GenericData[Int]] == schemaFor[GenericData.IntData]) } test("convert PrimitiveData to catalyst") { @@ -280,4 +224,57 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType)) } } + + test("get parameter type from a function object") { + val primitiveFunc = (i: Int, j: Long) => "x" + val primitiveTypes = getParameterTypes(primitiveFunc) + assert(primitiveTypes.forall(_.isPrimitive)) + assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) + + val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" + val boxedTypes = getParameterTypes(boxedFunc) + assert(boxedTypes.forall(!_.isPrimitive)) + assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) + + val anyFunc = (i: Any, j: AnyRef) => "x" + val anyTypes = getParameterTypes(anyFunc) + assert(anyTypes.forall(!_.isPrimitive)) + assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) + } + + private val dataTypeForComplexData = dataTypeFor[ComplexData] + private val typeOfComplexData = typeOf[ComplexData] + + Seq( + ("mirror", () => mirror), + ("dataTypeFor", () => dataTypeFor[ComplexData]), + ("constructorFor", () => deserializerFor[ComplexData]), + ("extractorsFor", { + val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false) + () => serializerFor[ComplexData](inputObject) + }), + ("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])), + ("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])), + ("getClassFromType", () => getClassFromType(typeOfComplexData)), + ("schemaFor", () => schemaFor[ComplexData]), + ("localTypeOf", () => localTypeOf[ComplexData]), + ("getClassNameFromType", () => getClassNameFromType(typeOfComplexData)), + ("getParameterTypes", () => getParameterTypes(() => ())), + ("getConstructorParameters(tpe)", () => getClassNameFromType(typeOfComplexData))).foreach { + case (name, exec) => + test(s"SPARK-13640: thread safety of ${name}") { + (0 until 100).foreach { _ => + val loader = new URLClassLoader(Array.empty, Utils.getContextOrSparkClassLoader) + (0 until 10).par.foreach { _ => + val cl = Thread.currentThread.getContextClassLoader + try { + Thread.currentThread.setContextClassLoader(loader) + exec() + } finally { + Thread.currentThread.setContextClassLoader(cl) + } + } + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala deleted file mode 100644 index ea28bfa021bed..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute} -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command} -import org.apache.spark.unsafe.types.CalendarInterval - -private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty -} - -private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") - - override protected lazy val start: Parser[LogicalPlan] = set - - private lazy val set: Parser[LogicalPlan] = - EXECUTE ~> ident ^^ { - case fileName => TestCommand(fileName) - } -} - -private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("EXECUTE") - - override protected lazy val start: Parser[LogicalPlan] = set - - private lazy val set: Parser[LogicalPlan] = - EXECUTE ~> ident ^^ { - case fileName => TestCommand(fileName) - } -} - -class SqlParserSuite extends PlanTest { - - test("test long keyword") { - val parser = new SuperLongKeywordTestParser - assert(TestCommand("NotRealCommand") === - parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand")) - } - - test("test case insensitive") { - val parser = new CaseInsensitiveTestParser - assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand")) - } - - test("test NOT operator with comparison operations") { - val parsed = SqlParser.parse("SELECT NOT TRUE > TRUE") - val expected = Project( - UnresolvedAlias( - Not( - GreaterThan(Literal(true), Literal(true))) - ) :: Nil, - OneRowRelation) - comparePlans(parsed, expected) - } - - test("support hive interval literal") { - def checkInterval(sql: String, result: CalendarInterval): Unit = { - val parsed = SqlParser.parse(sql) - val expected = Project( - UnresolvedAlias( - Literal(result) - ) :: Nil, - OneRowRelation) - comparePlans(parsed, expected) - } - - def checkYearMonth(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' YEAR TO MONTH", - CalendarInterval.fromYearMonthString(lit)) - } - - def checkDayTime(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' DAY TO SECOND", - CalendarInterval.fromDayTimeString(lit)) - } - - def checkSingleUnit(lit: String, unit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' $unit", - CalendarInterval.fromSingleUnitString(unit, lit)) - } - - checkYearMonth("123-10") - checkYearMonth("496-0") - checkYearMonth("-2-3") - checkYearMonth("-123-0") - - checkDayTime("99 11:22:33.123456789") - checkDayTime("-99 11:22:33.123456789") - checkDayTime("10 9:8:7.123456789") - checkDayTime("1 0:0:0") - checkDayTime("-1 0:0:0") - checkDayTime("1 0:0:1") - - for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { - checkSingleUnit("7", unit) - checkSingleUnit("-7", unit) - checkSingleUnit("0", unit) - } - - checkSingleUnit("13.123456789", "second") - checkSingleUnit("-13.123456789", "second") - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index fbdd3a7776f50..ad101d1c406b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,14 +17,67 @@ package org.apache.spark.sql.catalyst.analysis +import scala.beans.{BeanInfo, BeanProperty} + import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +@BeanInfo +private[sql] case class GroupableData(@BeanProperty data: Int) + +private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { + + override def sqlType: DataType = IntegerType + + override def serialize(groupableData: GroupableData): Int = groupableData.data + + override def deserialize(datum: Any): GroupableData = { + datum match { + case data: Int => GroupableData(data) + } + } + + override def userClass: Class[GroupableData] = classOf[GroupableData] + + private[spark] override def asNullable: GroupableUDT = this +} + +@BeanInfo +private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int]) + +private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { + + override def sqlType: DataType = MapType(IntegerType, IntegerType) + + override def serialize(ungroupableData: UngroupableData): MapData = { + val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq) + val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq) + new ArrayBasedMapData(keyArray, valueArray) + } + + override def deserialize(datum: Any): UngroupableData = { + datum match { + case data: MapData => + val keyArray = data.keyArray().array + val valueArray = data.valueArray().array + assert(keyArray.length == valueArray.length) + val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]] + UngroupableData(mapData) + } + } + + override def userClass: Class[UngroupableData] = classOf[UngroupableData] + + private[spark] override def asNullable: UngroupableUDT = this +} + case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) @@ -53,38 +106,71 @@ class AnalysisErrorSuite extends AnalysisTest { val dateLit = Literal.create(null, DateType) + errorTest( + "scalar subquery with 2 columns", + testRelation.select( + (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), + "Scalar subquery must return only one column, but got 2" :: Nil) + + errorTest( + "scalar subquery with no column", + testRelation.select(ScalarSubquery(LocalRelation()).as('a)), + "Scalar subquery must return only one column, but got 0" :: Nil) + errorTest( "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" :: - "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE))" :: "argument 1" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" :: - "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" :: + "argument 2" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), - "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "requires int type" :: "'null' is of date type" :: Nil) + "cannot resolve" :: "testfunction(CAST(NULL AS DATE), CAST(NULL AS DATE))" :: + "argument 1" :: "argument 2" :: "requires int type" :: + "'CAST(NULL AS DATE)' is of date type" :: Nil) + + errorTest( + "invalid window function", + testRelation2.select( + WindowExpression( + Literal(0), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "not supported within a window function" :: Nil) errorTest( - "unresolved window function", + "distinct window function", testRelation2.select( WindowExpression( - UnresolvedWindowFunction( - "lead", - UnresolvedAttribute("c") :: Nil), + AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as('window)), - "lead" :: "window functions currently requires a HiveContext" :: Nil) + "Distinct window functions are not supported" :: Nil) + + errorTest( + "offset window function", + testRelation2.select( + WindowExpression( + new Lead(UnresolvedAttribute("b")), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + SpecifiedWindowFrame(RangeFrame, ValueFollowing(1), ValueFollowing(2)))).as('window)), + "window frame" :: "must match the required frame" :: Nil) errorTest( "too many generators", @@ -96,6 +182,18 @@ class AnalysisErrorSuite extends AnalysisTest { testRelation.select('abcd), "cannot resolve" :: "abcd" :: Nil) + errorTest( + "unresolved attributes with a generated name", + testRelation2.groupBy('a)(max('b)) + .where(sum('b) > 0) + .orderBy('havingCondition.asc), + "cannot resolve" :: "havingCondition" :: Nil) + + errorTest( + "unresolved star expansion in max", + testRelation2.groupBy('a)(sum(UnresolvedStar(None))), + "Invalid usage of '*'" :: "in expression 'sum'" :: Nil) + errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), @@ -103,8 +201,13 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "sorting by unsupported column types", - listRelation.orderBy('list.asc), - "sort" :: "type" :: "array" :: Nil) + mapRelation.orderBy('map.asc), + "sort" :: "type" :: "map" :: Nil) + + errorTest( + "sorting by attributes are not from grouping expressions", + testRelation2.groupBy('a, 'c)('a, 'c, count('a).as("a3")).orderBy('b.asc), + "cannot resolve" :: "'`b`'" :: "given input columns" :: "[a, c, a3]" :: Nil) errorTest( "non-boolean filters", @@ -119,7 +222,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "missing group by", testRelation2.groupBy('a)('b), - "'b'" :: "group by" :: Nil + "'`b`'" :: "group by" :: Nil ) errorTest( @@ -147,7 +250,7 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "union with unequal number of columns", - testRelation.unionAll(testRelation2), + testRelation.union(testRelation2), "union" :: "number of columns" :: testRelation2.output.length.toString :: testRelation.output.length.toString :: Nil) @@ -167,20 +270,78 @@ class AnalysisErrorSuite extends AnalysisTest { "SPARK-9955: correct error message for aggregate", // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias. testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))), - "cannot resolve 'bad_column'" :: Nil) + "cannot resolve '`bad_column`'" :: Nil) + + errorTest( + "slide duration greater than window in time window", + testRelation2.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second").as("window")), + s"The slide duration " :: " must be less than or equal to the windowDuration " :: Nil + ) + + errorTest( + "start time greater than slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), + "The start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "start time equal to slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), + "The start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "negative window duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "-1 second", "1 second", "0 second").as("window")), + "The window duration " :: " must be greater than 0." :: Nil + ) + + errorTest( + "zero window duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "0 second", "1 second", "0 second").as("window")), + "The window duration " :: " must be greater than 0." :: Nil + ) + + errorTest( + "negative slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "-1 second", "0 second").as("window")), + "The slide duration " :: " must be greater than 0." :: Nil + ) + + errorTest( + "zero slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "0 second", "0 second").as("window")), + "The slide duration" :: " must be greater than 0." :: Nil + ) + + errorTest( + "negative start time in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")), + "The start time" :: "must be greater than or equal to 0." :: Nil + ) test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + // Since we manually construct the logical plan at here and Sum only accept + // LongType, DoubleType, and DecimalType. We use LongType as the type of a. val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)))) + AttributeReference("a", LongType)(exprId = ExprId(2)))) assert(plan.resolved) - assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil) + assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) } test("error test for self-join") { @@ -192,28 +353,66 @@ class AnalysisErrorSuite extends AnalysisTest { assert(error.message.contains("Conflicting attributes")) } - test("aggregation can't work on binary and map types") { - val plan = - Aggregate( - AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, - LocalRelation( - AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + test("check grouping expression data types") { + def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = { + val plan = + Aggregate( + AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + LocalRelation( + AttributeReference("a", dataType)(exprId = ExprId(2)), + AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + + shouldSuccess match { + case true => + assertAnalysisSuccess(plan, true) + case false => + assertAnalysisError(plan, "expression `a` cannot be used as a grouping expression" :: Nil) + } + } + + val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", StringType, nullable = true), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new GroupableUDT()) + supportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = true) + } - assertAnalysisError(plan, - "binary type expression a cannot be used in grouping expression" :: Nil) + val unsupportedDataTypes = Seq( + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", MapType(StringType, LongType), nullable = true), + new UngroupableUDT()) + unsupportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = false) + } + } - val plan2 = + test("we should fail analysis when we find nested aggregate functions") { + val plan = Aggregate( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil, + Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil, LocalRelation( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), + AttributeReference("a", IntegerType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - assertAnalysisError(plan2, - "map type expression a cannot be used in grouping expression" :: Nil) + assertAnalysisError( + plan, + "It is not allowed to use an aggregate function in the argument of " + + "another aggregate function." :: Nil) } test("Join can't work on binary and map types") { @@ -229,7 +428,7 @@ class AnalysisErrorSuite extends AnalysisTest { Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil) + assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil) val plan2 = Join( @@ -243,6 +442,6 @@ class AnalysisErrorSuite extends AnalysisTest { Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil) + assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 71d2939ecffe6..a63d1770f3255 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -28,10 +29,10 @@ class AnalysisSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { - val plan = (1 to 100) + val plan = (1 to 120) .map(_ => testRelation) .fold[LogicalPlan](testRelation) { (a, b) => - a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) + a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None))) } assertAnalysisSuccess(plan) @@ -45,7 +46,7 @@ class AnalysisSuite extends AnalysisTest { val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) - assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) + assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved) } test("analyze project") { @@ -76,15 +77,94 @@ class AnalysisSuite extends AnalysisTest { caseSensitive = false) } - test("resolve relations") { - assertAnalysisError( - UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe")) + test("resolve sort references - filter/limit") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) - checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation) + // Case 1: one missing attribute is in the leaf node and another is in the unary node + val plan1 = testRelation2 + .where('a > "str").select('a, 'b) + .where('b > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected1 = testRelation2 + .where(a > "str").select(a, b, c) + .where(b > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a) + checkAnalysis(plan1, expected1) + + // Case 2: all the missing attributes are in the leaf node + val plan2 = testRelation2 + .where('a > "str").select('a) + .where('a > "str").select('a) + .sortBy('b.asc, 'c.desc) + val expected2 = testRelation2 + .where(a > "str").select(a, b, c) + .where(a > "str").select(a, b, c) + .sortBy(b.asc, c.desc) + .select(a) + checkAnalysis(plan2, expected2) + } + test("resolve sort references - join") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val h = testRelation3.output(3) + + // Case: join itself can resolve all the missing attributes + val plan = testRelation2.join(testRelation3) + .where('a > "str").select('a, 'b) + .sortBy('c.desc, 'h.asc) + val expected = testRelation2.join(testRelation3) + .where(a > "str").select(a, b, c, h) + .sortBy(c.desc, h.asc) + .select(a, b) + checkAnalysis(plan, expected) + } + + test("resolve sort references - aggregate") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val c = testRelation2.output(2) + val alias_a3 = count(a).as("a3") + val alias_b = b.as("aggOrder") + + // Case 1: when the child of Sort is not Aggregate, + // the sort reference is handled by the rule ResolveSortReferences + val plan1 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .select('a, 'c, 'a3) + .orderBy('b.asc) + + val expected1 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, b) + .select(a, c, alias_a3.toAttribute, b) + .orderBy(b.asc) + .select(a, c, alias_a3.toAttribute) + + checkAnalysis(plan1, expected1) + + // Case 2: when the child of Sort is Aggregate, + // the sort reference is handled by the rule ResolveAggregateFunctions + val plan2 = testRelation2 + .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3")) + .orderBy('b.asc) + + val expected2 = testRelation2 + .groupBy(a, c, b)(a, c, alias_a3, alias_b) + .orderBy(alias_b.toAttribute.asc) + .select(a, c, alias_a3.toAttribute) + + checkAnalysis(plan2, expected2) + } + + test("resolve relations") { + assertAnalysisError(UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq()) + checkAnalysis(UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation) checkAnalysis( UnresolvedRelation(TableIdentifier("tAbLe"), None), testRelation, caseSensitive = false) - checkAnalysis( UnresolvedRelation(TableIdentifier("TaBlE"), None), testRelation, caseSensitive = false) } @@ -154,6 +234,11 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } + test("self intersect should resolve duplicate expression IDs") { + val plan = testRelation.intersect(testRelation) + assertAnalysisSuccess(plan) + } + test("SPARK-8654: invalid CAST in NULL IN(...) expression") { val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() @@ -161,7 +246,7 @@ class AnalysisSuite extends AnalysisTest { assertAnalysisSuccess(plan) } - test("SPARK-8654: different types in inlist but can be converted to a commmon type") { + test("SPARK-8654: different types in inlist but can be converted to a common type") { val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(1.2345))), "a")() :: Nil, LocalRelation() ) @@ -174,4 +259,91 @@ class AnalysisSuite extends AnalysisTest { ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val string = testRelation2.output(0) + val double = testRelation2.output(2) + val short = testRelation2.output(4) + val nullResult = Literal.create(null, StringType) + + def checkUDF(udf: Expression, transformed: Expression): Unit = { + checkAnalysis( + Project(Alias(udf, "")() :: Nil, testRelation2), + Project(Alias(transformed, "")() :: Nil, testRelation2) + ) + } + + // non-primitive parameters do not need special null handling + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil) + val expected1 = udf1 + checkUDF(udf1, expected1) + + // only primitive parameter needs special null handling + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) + val expected2 = If(IsNull(double), nullResult, udf2) + checkUDF(udf2, expected2) + + // special null handling should apply to all primitive parameters + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val expected3 = If( + IsNull(short) || IsNull(double), + nullResult, + udf3) + checkUDF(udf3, expected3) + + // we can skip special null handling for primitive parameters that are not nullable + // TODO: this is disabled for now as we can not completely trust `nullable`. + val udf4 = ScalaUDF( + (s: Short, d: Double) => "x", + StringType, + short :: double.withNullability(false) :: Nil) + val expected4 = If( + IsNull(short), + nullResult, + udf4) + // checkUDF(udf4, expected4) + } + + test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") { + val a = testRelation2.output(0) + val c = testRelation2.output(2) + val alias1 = a.as("a1") + val alias2 = c.as("a2") + val alias3 = count(a).as("a3") + + val plan = testRelation2 + .groupBy('a, 'c)('a.as("a1"), 'c.as("a2"), count('a).as("a3")) + .orderBy('a1.asc, 'c.asc) + + val expected = testRelation2 + .groupBy(a, c)(alias1, alias2, alias3) + .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc) + .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute) + checkAnalysis(plan, expected) + } + + test("Eliminate the unnecessary union") { + val plan = Union(testRelation :: Nil) + val expected = testRelation + checkAnalysis(plan, expected) + } + + test("SPARK-12102: Ignore nullablity when comparing two sides of case") { + val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) + val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val")) + assertAnalysisSuccess(plan) + } + + test("Keep attribute qualifiers after dedup") { + val input = LocalRelation('key.int, 'value.string) + + val query = + Project(Seq($"x.key", $"y.key"), + Join( + Project(Seq($"x.key"), SubqueryAlias("x", input)), + Project(Seq($"y.key"), SubqueryAlias("y", input)), + Inner, None)) + + assertAnalysisSuccess(query) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 23861ed15da61..b1fcf011f43e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -18,27 +18,22 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} trait AnalysisTest extends PlanTest { - val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { - val caseSensitiveConf = new SimpleCatalystConf(true) - val caseInsensitiveConf = new SimpleCatalystConf(false) + protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true) + protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false) - val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) - val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - - caseSensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) - caseInsensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation) - - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil - } -> - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil + private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { + val conf = new SimpleCatalystConf(caseSensitive) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true) + new Analyzer(catalog, conf) { + override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } } @@ -60,7 +55,18 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - analyzer.checkAnalysis(analyzer.execute(inputPlan)) + val analysisAttempt = analyzer.execute(inputPlan) + try analyzer.checkAnalysis(analysisAttempt) catch { + case a: AnalysisException => + fail( + s""" + |Failed to Analyze Plan + |$inputPlan + | + |Partial Analysis + |$analysisAttempt + """.stripMargin, a) + } } protected def assertAnalysisError( @@ -71,8 +77,17 @@ trait AnalysisTest extends PlanTest { val e = intercept[AnalysisException] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains), - s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " + - s"actually we get ${e.getMessage}") + + if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { + fail( + s"""Exception message should contain the following substrings: + | + | ${expectedErrors.mkString("\n ")} + | + |Actual exception message: + | + | ${e.getMessage} + """.stripMargin) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 40c4ae7920918..b3b1f5b920a53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -19,18 +19,23 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union} import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} -class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { - val conf = new SimpleCatalystConf(true) - val catalog = new SimpleCatalog(conf) - val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) - val relation = LocalRelation( +class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { + private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) + private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + private val analyzer = new Analyzer(catalog, conf) + + private val relation = LocalRelation( AttributeReference("i", IntegerType)(), AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), @@ -39,15 +44,15 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("b", DoubleType)() ) - val i: Expression = UnresolvedAttribute("i") - val d1: Expression = UnresolvedAttribute("d1") - val d2: Expression = UnresolvedAttribute("d2") - val u: Expression = UnresolvedAttribute("u") - val f: Expression = UnresolvedAttribute("f") - val b: Expression = UnresolvedAttribute("b") + private val i: Expression = UnresolvedAttribute("i") + private val d1: Expression = UnresolvedAttribute("d1") + private val d2: Expression = UnresolvedAttribute("d2") + private val u: Expression = UnresolvedAttribute("u") + private val f: Expression = UnresolvedAttribute("f") + private val b: Expression = UnresolvedAttribute("b") before { - catalog.registerTable(TableIdentifier("table"), relation) + catalog.createTempTable("table", relation, overrideIfExists = true) } private def checkType(expression: Expression, expectedType: DataType): Unit = { @@ -69,7 +74,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) val (l, r) = analyzer.execute(plan).collect { - case Union(left, right) => (left.output.head, right.output.head) + case Union(Seq(child1, child2)) => (child1.output.head, child2.output.head) }.head assert(l.dataType === expectedType) assert(r.dataType === expectedType) @@ -180,4 +185,94 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { assert(d4.isWiderThan(FloatType) === false) assert(d4.isWiderThan(DoubleType) === false) } + + test("strength reduction for integer/decimal comparisons - basic test") { + Seq(ByteType, ShortType, IntegerType, LongType).foreach { dt => + val int = AttributeReference("a", dt)() + + ruleTest(int > Literal(Decimal(4)), int > Literal(4L)) + ruleTest(int > Literal(Decimal(4.7)), int > Literal(4L)) + + ruleTest(int >= Literal(Decimal(4)), int >= Literal(4L)) + ruleTest(int >= Literal(Decimal(4.7)), int >= Literal(5L)) + + ruleTest(int < Literal(Decimal(4)), int < Literal(4L)) + ruleTest(int < Literal(Decimal(4.7)), int < Literal(5L)) + + ruleTest(int <= Literal(Decimal(4)), int <= Literal(4L)) + ruleTest(int <= Literal(Decimal(4.7)), int <= Literal(4L)) + + ruleTest(Literal(Decimal(4)) > int, Literal(4L) > int) + ruleTest(Literal(Decimal(4.7)) > int, Literal(5L) > int) + + ruleTest(Literal(Decimal(4)) >= int, Literal(4L) >= int) + ruleTest(Literal(Decimal(4.7)) >= int, Literal(4L) >= int) + + ruleTest(Literal(Decimal(4)) < int, Literal(4L) < int) + ruleTest(Literal(Decimal(4.7)) < int, Literal(4L) < int) + + ruleTest(Literal(Decimal(4)) <= int, Literal(4L) <= int) + ruleTest(Literal(Decimal(4.7)) <= int, Literal(5L) <= int) + + } + } + + test("strength reduction for integer/decimal comparisons - overflow test") { + val maxValue = Literal(Decimal(Long.MaxValue)) + val overflow = Literal(Decimal(Long.MaxValue) + Decimal(0.1)) + val minValue = Literal(Decimal(Long.MinValue)) + val underflow = Literal(Decimal(Long.MinValue) - Decimal(0.1)) + + Seq(ByteType, ShortType, IntegerType, LongType).foreach { dt => + val int = AttributeReference("a", dt)() + + ruleTest(int > maxValue, int > Literal(Long.MaxValue)) + ruleTest(int > overflow, FalseLiteral) + ruleTest(int > minValue, int > Literal(Long.MinValue)) + ruleTest(int > underflow, TrueLiteral) + + ruleTest(int >= maxValue, int >= Literal(Long.MaxValue)) + ruleTest(int >= overflow, FalseLiteral) + ruleTest(int >= minValue, int >= Literal(Long.MinValue)) + ruleTest(int >= underflow, TrueLiteral) + + ruleTest(int < maxValue, int < Literal(Long.MaxValue)) + ruleTest(int < overflow, TrueLiteral) + ruleTest(int < minValue, int < Literal(Long.MinValue)) + ruleTest(int < underflow, FalseLiteral) + + ruleTest(int <= maxValue, int <= Literal(Long.MaxValue)) + ruleTest(int <= overflow, TrueLiteral) + ruleTest(int <= minValue, int <= Literal(Long.MinValue)) + ruleTest(int <= underflow, FalseLiteral) + + ruleTest(maxValue > int, Literal(Long.MaxValue) > int) + ruleTest(overflow > int, TrueLiteral) + ruleTest(minValue > int, Literal(Long.MinValue) > int) + ruleTest(underflow > int, FalseLiteral) + + ruleTest(maxValue >= int, Literal(Long.MaxValue) >= int) + ruleTest(overflow >= int, TrueLiteral) + ruleTest(minValue >= int, Literal(Long.MinValue) >= int) + ruleTest(underflow >= int, FalseLiteral) + + ruleTest(maxValue < int, Literal(Long.MaxValue) < int) + ruleTest(overflow < int, FalseLiteral) + ruleTest(minValue < int, Literal(Long.MinValue) < int) + ruleTest(underflow < int, TrueLiteral) + + ruleTest(maxValue <= int, Literal(Long.MaxValue) <= int) + ruleTest(overflow <= int, FalseLiteral) + ruleTest(minValue <= int, Literal(Long.MinValue) <= int) + ruleTest(underflow <= int, TrueLiteral) + } + } + + /** strength reduction for integer/decimal comparisons */ + def ruleTest(initial: Expression, transformed: Expression): Unit = { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + DecimalPrecision(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index c9bcc68f02030..ace6e10c6ec30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -22,8 +22,9 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{TypeCollection, StringType} +import org.apache.spark.sql.types.{LongType, StringType, TypeCollection} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -31,14 +32,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { 'intField.int, 'stringField.string, 'booleanField.boolean, - 'complexField.array(StringType)) + 'decimalField.decimal(8, 0), + 'arrayField.array(StringType), + 'mapField.map(StringType, LongType)) def assertError(expr: Expression, errorMessage: String): Unit = { val e = intercept[AnalysisException] { assertSuccess(expr) } assert(e.getMessage.contains( - s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + s"cannot resolve '${expr.sql}' due to data type mismatch:")) assert(e.getMessage.contains(errorMessage)) } @@ -49,7 +52,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in '${expr.prettyString}'") + s"differing types in '${expr.sql}'") } test("check types for unary arithmetic") { @@ -89,9 +92,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type") assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type") - assertError(MaxOf('complexField, 'complexField), + assertError(MaxOf('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(MinOf('complexField, 'complexField), + assertError(MinOf('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") } @@ -108,20 +111,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) - assertErrorForDifferingTypes(EqualTo('intField, 'complexField)) - assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField)) + assertErrorForDifferingTypes(EqualTo('intField, 'mapField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(LessThan('complexField, 'complexField), + assertError(LessThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(LessThanOrEqual('complexField, 'complexField), + assertError(LessThanOrEqual('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(GreaterThan('complexField, 'complexField), + assertError(GreaterThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(GreaterThanOrEqual('complexField, 'complexField), + assertError(GreaterThanOrEqual('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") assertError(If('intField, 'stringField, 'stringField), @@ -129,26 +132,28 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) assertError( - CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('booleanField.attr, 'mapField.attr))), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( - CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( - CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), + CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('intField.attr, 'intField.attr))), "WHEN expressions in CaseWhen should all be boolean type") } test("check types for aggregates") { + // We use AggregateFunction directly at here because the error will be thrown from it + // instead of from AggregateExpression, which is the wrapper of an AggregateFunction. + // We will cast String to Double for sum and average assertSuccess(Sum('stringField)) - assertSuccess(SumDistinct('stringField)) assertSuccess(Average('stringField)) + assertSuccess(Min('arrayField)) - assertError(Min('complexField), "min does not support ordering on type") - assertError(Max('complexField), "max does not support ordering on type") + assertError(Min('mapField), "min does not support ordering on type") + assertError(Max('mapField), "max does not support ordering on type") assertError(Sum('booleanField), "function sum requires numeric type") - assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type") assertError(Average('booleanField), "function average requires numeric type") } @@ -158,6 +163,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Coalesce(Seq('intField, 'booleanField)), "input to function coalesce should all be the same type") assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(Explode('intField), "input to function explode should be array or map type") } @@ -167,13 +173,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), - "Field name should not be null") + "Field name should not be null") + } + + test("check types for CreateMap") { + assertError(CreateMap(Seq("a", "b", 2.0)), "even number of arguments") + assertError( + CreateMap(Seq('intField, 'stringField, 'booleanField, 'stringField)), + "keys of function map should all be the same type") + assertError( + CreateMap(Seq('stringField, 'intField, 'stringField, 'booleanField)), + "values of function map should all be the same type") } test("check types for ROUND") { @@ -182,7 +198,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'intField), "Only foldable Expression is allowed") assertError(Round('intField, 'booleanField), "requires int type") - assertError(Round('intField, 'complexField), "requires int type") + assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") } + + test("check types for Greatest/Least") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('intField, 'stringField)), "should all have the same type") + assertError(operator(Seq('intField, 'decimalField)), "should all have the same type") + assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index d3fafaae89938..883ef48984d79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp -import org.apache.spark.sql.catalyst.plans.PlanTest - import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -206,7 +205,7 @@ class HiveTypeCoercionSuite extends PlanTest { Project(Seq(Alias(transformed, "a")()), testRelation)) } - test("cast NullType for expresions that implement ExpectsInputTypes") { + test("cast NullType for expressions that implement ExpectsInputTypes") { import HiveTypeCoercionSuite._ ruleTest(HiveTypeCoercion.ImplicitTypeCasts, @@ -251,6 +250,90 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("CreateArray casts") { + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + CreateArray(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0, FloatType) + :: Nil), + CreateArray(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) + :: Nil)) + + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + CreateArray(Literal(1.0) + :: Literal(1) + :: Literal("a") + :: Nil), + CreateArray(Cast(Literal(1.0), StringType) + :: Cast(Literal(1), StringType) + :: Cast(Literal("a"), StringType) + :: Nil)) + } + + test("CreateMap casts") { + // type coercion for map keys + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal("a") + :: Literal.create(2.0, FloatType) + :: Literal("b") + :: Nil), + CreateMap(Cast(Literal(1), FloatType) + :: Literal("a") + :: Cast(Literal.create(2.0, FloatType), FloatType) + :: Literal("b") + :: Nil)) + // type coercion for map values + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal("a") + :: Literal(2) + :: Literal(3.0) + :: Nil), + CreateMap(Literal(1) + :: Cast(Literal("a"), StringType) + :: Literal(2) + :: Cast(Literal(3.0), StringType) + :: Nil)) + // type coercion for both map keys and values + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal("a") + :: Literal(2.0) + :: Literal(3.0) + :: Nil), + CreateMap(Cast(Literal(1), DoubleType) + :: Cast(Literal("a"), StringType) + :: Cast(Literal(2.0), DoubleType) + :: Cast(Literal(3.0), StringType) + :: Nil)) + } + + test("greatest/least cast") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0, FloatType) + :: Nil), + operator(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + operator(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Nil)) + } + } + test("nanvl casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), @@ -277,7 +360,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for CaseKeyWhen") { - ruleTest(HiveTypeCoercion.CaseWhenCoercion, + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -286,19 +369,44 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) ruleTest(HiveTypeCoercion.CaseWhenCoercion, - CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))), - CaseWhen(Seq( - Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))) + CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), + CaseWhen(Seq((Literal(true), Literal(1.2))), + Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) ruleTest(HiveTypeCoercion.CaseWhenCoercion, - CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))), - CaseWhen(Seq( - Literal(true), Cast(Literal(100L), DecimalType(22, 2)), - Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))) + CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), + CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), + Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) ) } - test("type coercion simplification for equal to") { + test("BooleanEquality type cast") { + val be = HiveTypeCoercion.BooleanEquality + // Use something more than a literal to avoid triggering the simplification rules. + val one = Add(Literal(Decimal(1)), Literal(Decimal(0))) + + ruleTest(be, + EqualTo(Literal(true), one), + EqualTo(Cast(Literal(true), one.dataType), one) + ) + + ruleTest(be, + EqualTo(one, Literal(true)), + EqualTo(one, Cast(Literal(true), one.dataType)) + ) + + ruleTest(be, + EqualNullSafe(Literal(true), one), + EqualNullSafe(Cast(Literal(true), one.dataType), one) + ) + + ruleTest(be, + EqualNullSafe(one, Literal(true)), + EqualNullSafe(one, Cast(Literal(true), one.dataType)) + ) + } + + test("BooleanEquality simplification") { val be = HiveTypeCoercion.BooleanEquality ruleTest(be, @@ -340,19 +448,19 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("WidenSetOperationTypes for union except and intersect") { - def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { - logical.output.zip(expectTypes).foreach { case (attr, dt) => - assert(attr.dataType === dt) - } + private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { + logical.output.zip(expectTypes).foreach { case (attr, dt) => + assert(attr.dataType === dt) } + } - val left = LocalRelation( + test("WidenSetOperationTypes for except and intersect") { + val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("b", ByteType)(), AttributeReference("d", DoubleType)()) - val right = LocalRelation( + val secondTable = LocalRelation( AttributeReference("s", StringType)(), AttributeReference("d", DecimalType(2, 1))(), AttributeReference("f", FloatType)(), @@ -361,15 +469,65 @@ class HiveTypeCoercionSuite extends PlanTest { val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Union(left, right)).asInstanceOf[Union] - val r2 = wt(Except(left, right)).asInstanceOf[Except] - val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect] + val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) checkOutput(r2.right, expectedTypes) - checkOutput(r3.left, expectedTypes) - checkOutput(r3.right, expectedTypes) + + // Check if a Project is added + assert(r1.left.isInstanceOf[Project]) + assert(r1.right.isInstanceOf[Project]) + assert(r2.left.isInstanceOf[Project]) + assert(r2.right.isInstanceOf[Project]) + + val r3 = wt(Except(firstTable, firstTable)).asInstanceOf[Except] + checkOutput(r3.left, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType)) + checkOutput(r3.right, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType)) + + // Check if no Project is added + assert(r3.left.isInstanceOf[LocalRelation]) + assert(r3.right.isInstanceOf[LocalRelation]) + } + + test("WidenSetOperationTypes for union") { + val firstTable = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val secondTable = LocalRelation( + AttributeReference("s", StringType)(), + AttributeReference("d", DecimalType(2, 1))(), + AttributeReference("f", FloatType)(), + AttributeReference("l", LongType)()) + val thirdTable = LocalRelation( + AttributeReference("m", StringType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", FloatType)(), + AttributeReference("q", DoubleType)()) + val forthTable = LocalRelation( + AttributeReference("m", StringType)(), + AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("p", ByteType)(), + AttributeReference("q", DoubleType)()) + + val wt = HiveTypeCoercion.WidenSetOperationTypes + val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) + + val unionRelation = wt( + Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] + assert(unionRelation.children.length == 4) + checkOutput(unionRelation.children.head, expectedTypes) + checkOutput(unionRelation.children(1), expectedTypes) + checkOutput(unionRelation.children(2), expectedTypes) + checkOutput(unionRelation.children(3), expectedTypes) + + assert(unionRelation.children.head.isInstanceOf[Project]) + assert(unionRelation.children(1).isInstanceOf[Project]) + assert(unionRelation.children(2).isInstanceOf[Project]) + assert(unionRelation.children(3).isInstanceOf[Project]) } test("Transform Decimal precision/scale for union except and intersect") { @@ -391,8 +549,8 @@ class HiveTypeCoercionSuite extends PlanTest { val r2 = dp(Except(left1, right1)).asInstanceOf[Except] val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] - checkOutput(r1.left, expectedType1) - checkOutput(r1.right, expectedType1) + checkOutput(r1.children.head, expectedType1) + checkOutput(r1.children.last, expectedType1) checkOutput(r2.left, expectedType1) checkOutput(r2.right, expectedType1) checkOutput(r3.left, expectedType1) @@ -404,7 +562,7 @@ class HiveTypeCoercionSuite extends PlanTest { val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5), DecimalType(25, 5), DoubleType, DoubleType) - rightTypes.zip(expectedTypes).map { case (rType, expectedType) => + rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) => val plan2 = LocalRelation( AttributeReference("r", rType)()) @@ -412,7 +570,7 @@ class HiveTypeCoercionSuite extends PlanTest { val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] - checkOutput(r1.right, Seq(expectedType)) + checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) @@ -420,7 +578,7 @@ class HiveTypeCoercionSuite extends PlanTest { val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] - checkOutput(r4.left, Seq(expectedType)) + checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) checkOutput(r6.left, Seq(expectedType)) } @@ -452,7 +610,6 @@ class HiveTypeCoercionSuite extends PlanTest { ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval)) } - /** * There are rules that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala new file mode 100644 index 0000000000000..1423a8705af27 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +class ResolveNaturalJoinSuite extends AnalysisTest { + lazy val a = 'a.string + lazy val b = 'b.string + lazy val c = 'c.string + lazy val aNotNull = a.notNull + lazy val bNotNull = b.notNull + lazy val cNotNull = c.notNull + lazy val r1 = LocalRelation(b, a) + lazy val r2 = LocalRelation(c, a) + lazy val r3 = LocalRelation(aNotNull, bNotNull) + lazy val r4 = LocalRelation(cNotNull, bNotNull) + + test("natural/using inner join") { + val naturalPlan = r1.join(r2, NaturalJoin(Inner), None) + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None) + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(naturalPlan, expected) + checkAnalysis(usingPlan, expected) + } + + test("natural/using left join") { + val naturalPlan = r1.join(r2, NaturalJoin(LeftOuter), None) + val usingPlan = r1.join(r2, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("a"))), None) + val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(naturalPlan, expected) + checkAnalysis(usingPlan, expected) + } + + test("natural/using right join") { + val naturalPlan = r1.join(r2, NaturalJoin(RightOuter), None) + val usingPlan = r1.join(r2, UsingJoin(RightOuter, Seq(UnresolvedAttribute("a"))), None) + val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c) + checkAnalysis(naturalPlan, expected) + checkAnalysis(usingPlan, expected) + } + + test("natural/using full outer join") { + val naturalPlan = r1.join(r2, NaturalJoin(FullOuter), None) + val usingPlan = r1.join(r2, UsingJoin(FullOuter, Seq(UnresolvedAttribute("a"))), None) + val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select( + Alias(Coalesce(Seq(a, a)), "a")(), b, c) + checkAnalysis(naturalPlan, expected) + checkAnalysis(usingPlan, expected) + } + + test("natural/using inner join with no nullability") { + val naturalPlan = r3.join(r4, NaturalJoin(Inner), None) + val usingPlan = r3.join(r4, UsingJoin(Inner, Seq(UnresolvedAttribute("b"))), None) + val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, cNotNull) + checkAnalysis(naturalPlan, expected) + checkAnalysis(usingPlan, expected) + } + + test("natural/using left join with no nullability") { + val naturalPlan = r3.join(r4, NaturalJoin(LeftOuter), None) + val usingPlan = r3.join(r4, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("b"))), None) + val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, aNotNull, c) + checkAnalysis(naturalPlan, expected) + checkAnalysis(usingPlan, expected) + } + + test("natural/using right join with no nullability") { + val naturalPlan = r3.join(r4, NaturalJoin(RightOuter), None) + val usingPlan = r3.join(r4, UsingJoin(RightOuter, Seq(UnresolvedAttribute("b"))), None) + val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select( + bNotNull, a, cNotNull) + checkAnalysis(naturalPlan, expected) + checkAnalysis(usingPlan, expected) + } + + test("natural/using full outer join with no nullability") { + val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None) + val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None) + val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select( + Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c) + checkAnalysis(naturalPlan, expected) + checkAnalysis(usingPlan, expected) + } + + test("using unresolved attribute") { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("d"))), None) + val error = intercept[AnalysisException] { + SimpleAnalyzer.checkAnalysis(usingPlan) + } + assert(error.message.contains( + "using columns ['d] can not be resolved given input columns: [b, a, c]")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index 05b870705e7ea..3741a6ba95a86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -31,6 +31,12 @@ object TestRelations { AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) + val testRelation3 = LocalRelation( + AttributeReference("e", ShortType)(), + AttributeReference("f", StringType)(), + AttributeReference("g", DoubleType)(), + AttributeReference("h", DecimalType(10, 2))()) + val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: @@ -48,4 +54,7 @@ object TestRelations { val listRelation = LocalRelation( AttributeReference("list", ArrayType(IntegerType))()) + + val mapRelation = LocalRelation( + AttributeReference("map", MapType(IntegerType, IntegerType))()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala new file mode 100644 index 0000000000000..f961fe3292be3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.util.Utils + + +/** + * A reasonable complete test suite (i.e. behaviors) for a [[ExternalCatalog]]. + * + * Implementations of the [[ExternalCatalog]] interface can create test suites by extending this. + */ +abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { + protected val utils: CatalogTestUtils + import utils._ + + protected def resetState(): Unit = { } + + // Clear all state after each test + override def afterEach(): Unit = { + try { + resetState() + } finally { + super.afterEach() + } + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + test("basic create and list databases") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + assert(catalog.databaseExists("default")) + assert(!catalog.databaseExists("testing")) + assert(!catalog.databaseExists("testing2")) + catalog.createDatabase(newDb("testing"), ignoreIfExists = false) + assert(catalog.databaseExists("testing")) + assert(catalog.listDatabases().toSet == Set("default", "testing")) + catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) + assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) + assert(catalog.databaseExists("testing2")) + assert(!catalog.databaseExists("does_not_exist")) + } + + test("get database when a database exists") { + val db1 = newBasicCatalog().getDatabase("db1") + assert(db1.name == "db1") + assert(db1.description.contains("db1")) + } + + test("get database should throw exception when the database does not exist") { + intercept[AnalysisException] { newBasicCatalog().getDatabase("db_that_does_not_exist") } + } + + test("list databases without pattern") { + val catalog = newBasicCatalog() + assert(catalog.listDatabases().toSet == Set("default", "db1", "db2")) + } + + test("list databases with pattern") { + val catalog = newBasicCatalog() + assert(catalog.listDatabases("db").toSet == Set.empty) + assert(catalog.listDatabases("db*").toSet == Set("db1", "db2")) + assert(catalog.listDatabases("*1").toSet == Set("db1")) + assert(catalog.listDatabases("db2").toSet == Set("db2")) + } + + test("drop database") { + val catalog = newBasicCatalog() + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) + assert(catalog.listDatabases().toSet == Set("default", "db2")) + } + + test("drop database when the database is not empty") { + // Throw exception if there are functions left + val catalog1 = newBasicCatalog() + catalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false) + catalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false) + intercept[AnalysisException] { + catalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + resetState() + + // Throw exception if there are tables left + val catalog2 = newBasicCatalog() + catalog2.dropFunction("db2", "func1") + intercept[AnalysisException] { + catalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + resetState() + + // When cascade is true, it should drop them + val catalog3 = newBasicCatalog() + catalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) + assert(catalog3.listDatabases().toSet == Set("default", "db1")) + } + + test("drop database when the database does not exist") { + val catalog = newBasicCatalog() + + intercept[AnalysisException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + } + + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) + } + + test("alter database") { + val catalog = newBasicCatalog() + val db1 = catalog.getDatabase("db1") + // Note: alter properties here because Hive does not support altering other fields + catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) + val newDb1 = catalog.getDatabase("db1") + assert(db1.properties.isEmpty) + assert(newDb1.properties.size == 2) + assert(newDb1.properties.get("k") == Some("v3")) + assert(newDb1.properties.get("good") == Some("true")) + } + + test("alter database should throw exception when the database does not exist") { + intercept[AnalysisException] { + newBasicCatalog().alterDatabase(newDb("does_not_exist")) + } + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + test("the table type of an external table should be EXTERNAL_TABLE") { + val catalog = newBasicCatalog() + val table = + newTable("external_table1", "db2").copy(tableType = CatalogTableType.EXTERNAL_TABLE) + catalog.createTable("db2", table, ignoreIfExists = false) + val actual = catalog.getTable("db2", "external_table1") + assert(actual.tableType === CatalogTableType.EXTERNAL_TABLE) + } + + test("drop table") { + val catalog = newBasicCatalog() + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.dropTable("db2", "tbl1", ignoreIfNotExists = false) + assert(catalog.listTables("db2").toSet == Set("tbl2")) + } + + test("drop table when database/table does not exist") { + val catalog = newBasicCatalog() + // Should always throw exception when the database does not exist + intercept[AnalysisException] { + catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = false) + } + intercept[AnalysisException] { + catalog.dropTable("unknown_db", "unknown_table", ignoreIfNotExists = true) + } + // Should throw exception when the table does not exist, if ignoreIfNotExists is false + intercept[AnalysisException] { + catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = false) + } + catalog.dropTable("db2", "unknown_table", ignoreIfNotExists = true) + } + + test("rename table") { + val catalog = newBasicCatalog() + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.renameTable("db2", "tbl1", "tblone") + assert(catalog.listTables("db2").toSet == Set("tblone", "tbl2")) + } + + test("rename table when database/table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.renameTable("unknown_db", "unknown_table", "unknown_table") + } + intercept[AnalysisException] { + catalog.renameTable("db2", "unknown_table", "unknown_table") + } + } + + test("alter table") { + val catalog = newBasicCatalog() + val tbl1 = catalog.getTable("db2", "tbl1") + catalog.alterTable("db2", tbl1.copy(properties = Map("toh" -> "frem"))) + val newTbl1 = catalog.getTable("db2", "tbl1") + assert(!tbl1.properties.contains("toh")) + assert(newTbl1.properties.size == tbl1.properties.size + 1) + assert(newTbl1.properties.get("toh") == Some("frem")) + } + + test("alter table when database/table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.alterTable("unknown_db", newTable("tbl1", "unknown_db")) + } + intercept[AnalysisException] { + catalog.alterTable("db2", newTable("unknown_table", "db2")) + } + } + + test("get table") { + assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1") + } + + test("get table when database/table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getTable("unknown_db", "unknown_table") + } + intercept[AnalysisException] { + catalog.getTable("db2", "unknown_table") + } + } + + test("list tables without pattern") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { catalog.listTables("unknown_db") } + assert(catalog.listTables("db1").toSet == Set.empty) + assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + } + + test("list tables with pattern") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { catalog.listTables("unknown_db", "*") } + assert(catalog.listTables("db1", "*").toSet == Set.empty) + assert(catalog.listTables("db2", "*").toSet == Set("tbl1", "tbl2")) + assert(catalog.listTables("db2", "tbl*").toSet == Set("tbl1", "tbl2")) + assert(catalog.listTables("db2", "*1").toSet == Set("tbl1")) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + test("basic create and list partitions") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createTable("mydb", newTable("tbl", "mydb"), ignoreIfExists = false) + catalog.createPartitions("mydb", "tbl", Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual(catalog, "mydb", "tbl", Seq(part1, part2))) + } + + test("create partitions when database/table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createPartitions("does_not_exist", "tbl1", Seq(), ignoreIfExists = false) + } + intercept[AnalysisException] { + catalog.createPartitions("db2", "does_not_exist", Seq(), ignoreIfExists = false) + } + } + + test("create partitions that already exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = false) + } + catalog.createPartitions("db2", "tbl2", Seq(part1), ignoreIfExists = true) + } + + test("drop partitions") { + val catalog = newBasicCatalog() + assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2))) + catalog.dropPartitions( + "db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) + assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part2))) + resetState() + val catalog2 = newBasicCatalog() + assert(catalogPartitionsEqual(catalog2, "db2", "tbl2", Seq(part1, part2))) + catalog2.dropPartitions( + "db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + assert(catalog2.listPartitions("db2", "tbl2").isEmpty) + } + + test("drop partitions when database/table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropPartitions( + "does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) + } + intercept[AnalysisException] { + catalog.dropPartitions( + "db2", "does_not_exist", Seq(), ignoreIfNotExists = false) + } + } + + test("drop partitions that do not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropPartitions( + "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) + } + catalog.dropPartitions( + "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) + } + + test("get partition") { + val catalog = newBasicCatalog() + assert(catalog.getPartition("db2", "tbl2", part1.spec).spec == part1.spec) + assert(catalog.getPartition("db2", "tbl2", part2.spec).spec == part2.spec) + intercept[AnalysisException] { + catalog.getPartition("db2", "tbl1", part3.spec) + } + } + + test("get partition when database/table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getPartition("does_not_exist", "tbl1", part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition("db2", "does_not_exist", part1.spec) + } + } + + test("rename partitions") { + val catalog = newBasicCatalog() + val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) + val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) + val newSpecs = Seq(newPart1.spec, newPart2.spec) + catalog.renamePartitions("db2", "tbl2", Seq(part1.spec, part2.spec), newSpecs) + assert(catalog.getPartition("db2", "tbl2", newPart1.spec).spec === newPart1.spec) + assert(catalog.getPartition("db2", "tbl2", newPart2.spec).spec === newPart2.spec) + // The old partitions should no longer exist + intercept[AnalysisException] { catalog.getPartition("db2", "tbl2", part1.spec) } + intercept[AnalysisException] { catalog.getPartition("db2", "tbl2", part2.spec) } + } + + test("rename partitions when database/table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.renamePartitions("does_not_exist", "tbl1", Seq(part1.spec), Seq(part2.spec)) + } + intercept[AnalysisException] { + catalog.renamePartitions("db2", "does_not_exist", Seq(part1.spec), Seq(part2.spec)) + } + } + + test("alter partitions") { + val catalog = newBasicCatalog() + try { + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + catalog.setCurrentDatabase("db2") + val newLocation = newUriForDatabase() + // alter but keep spec the same + val oldPart1 = catalog.getPartition("db2", "tbl2", part1.spec) + val oldPart2 = catalog.getPartition("db2", "tbl2", part2.spec) + catalog.alterPartitions("db2", "tbl2", Seq( + oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), + oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) + val newPart1 = catalog.getPartition("db2", "tbl2", part1.spec) + val newPart2 = catalog.getPartition("db2", "tbl2", part2.spec) + assert(newPart1.storage.locationUri == Some(newLocation)) + assert(newPart2.storage.locationUri == Some(newLocation)) + assert(oldPart1.storage.locationUri != Some(newLocation)) + assert(oldPart2.storage.locationUri != Some(newLocation)) + // alter but change spec, should fail because new partition specs do not exist yet + val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) + val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) + intercept[AnalysisException] { + catalog.alterPartitions("db2", "tbl2", Seq(badPart1, badPart2)) + } + } finally { + // Remember to restore the original current database, which we assume to be "default" + catalog.setCurrentDatabase("default") + } + } + + test("alter partitions when database/table does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.alterPartitions("does_not_exist", "tbl1", Seq(part1)) + } + intercept[AnalysisException] { + catalog.alterPartitions("db2", "does_not_exist", Seq(part1)) + } + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + test("basic create and list functions") { + val catalog = newEmptyCatalog() + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createFunction("mydb", newFunc("myfunc")) + assert(catalog.listFunctions("mydb", "*").toSet == Set("myfunc")) + } + + test("create function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createFunction("does_not_exist", newFunc()) + } + } + + test("create function that already exists") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.createFunction("db2", newFunc("func1")) + } + } + + test("drop function") { + val catalog = newBasicCatalog() + assert(catalog.listFunctions("db2", "*").toSet == Set("func1")) + catalog.dropFunction("db2", "func1") + assert(catalog.listFunctions("db2", "*").isEmpty) + } + + test("drop function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropFunction("does_not_exist", "something") + } + } + + test("drop function that does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.dropFunction("db2", "does_not_exist") + } + } + + test("get function") { + val catalog = newBasicCatalog() + assert(catalog.getFunction("db2", "func1") == + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[(String, String)])) + intercept[AnalysisException] { + catalog.getFunction("db2", "does_not_exist") + } + } + + test("get function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.getFunction("does_not_exist", "func1") + } + } + + test("rename function") { + val catalog = newBasicCatalog() + val newName = "funcky" + assert(catalog.getFunction("db2", "func1").className == funcClass) + catalog.renameFunction("db2", "func1", newName) + intercept[AnalysisException] { catalog.getFunction("db2", "func1") } + assert(catalog.getFunction("db2", newName).identifier.funcName == newName) + assert(catalog.getFunction("db2", newName).className == funcClass) + intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") } + } + + test("rename function when database does not exist") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.renameFunction("does_not_exist", "func1", "func5") + } + } + + test("list functions") { + val catalog = newBasicCatalog() + catalog.createFunction("db2", newFunc("func2")) + catalog.createFunction("db2", newFunc("not_me")) + assert(catalog.listFunctions("db2", "*").toSet == Set("func1", "func2", "not_me")) + assert(catalog.listFunctions("db2", "func*").toSet == Set("func1", "func2")) + } + +} + + +/** + * A collection of utility fields and methods for tests related to the [[ExternalCatalog]]. + */ +abstract class CatalogTestUtils { + + // Unimplemented methods + val tableInputFormat: String + val tableOutputFormat: String + def newEmptyCatalog(): ExternalCatalog + + // These fields must be lazy because they rely on fields that are not implemented yet + lazy val storageFormat = CatalogStorageFormat( + locationUri = None, + inputFormat = Some(tableInputFormat), + outputFormat = Some(tableOutputFormat), + serde = None, + serdeProperties = Map.empty) + lazy val part1 = CatalogTablePartition(Map("a" -> "1", "b" -> "2"), storageFormat) + lazy val part2 = CatalogTablePartition(Map("a" -> "3", "b" -> "4"), storageFormat) + lazy val part3 = CatalogTablePartition(Map("a" -> "5", "b" -> "6"), storageFormat) + lazy val funcClass = "org.apache.spark.myFunc" + + /** + * Creates a basic catalog, with the following structure: + * + * default + * db1 + * db2 + * - tbl1 + * - tbl2 + * - part1 + * - part2 + * - func1 + */ + def newBasicCatalog(): ExternalCatalog = { + val catalog = newEmptyCatalog() + // When testing against a real catalog, the default database may already exist + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + catalog.createDatabase(newDb("db1"), ignoreIfExists = false) + catalog.createDatabase(newDb("db2"), ignoreIfExists = false) + catalog.createTable("db2", newTable("tbl1", "db2"), ignoreIfExists = false) + catalog.createTable("db2", newTable("tbl2", "db2"), ignoreIfExists = false) + catalog.createPartitions("db2", "tbl2", Seq(part1, part2), ignoreIfExists = false) + catalog.createFunction("db2", newFunc("func1", Some("db2"))) + catalog + } + + def newFunc(): CatalogFunction = newFunc("funcName") + + def newUriForDatabase(): String = Utils.createTempDir().getAbsolutePath + + def newDb(name: String): CatalogDatabase = { + CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) + } + + def newTable(name: String, db: String): CatalogTable = newTable(name, Some(db)) + + def newTable(name: String, database: Option[String] = None): CatalogTable = { + CatalogTable( + identifier = TableIdentifier(name, database), + tableType = CatalogTableType.EXTERNAL_TABLE, + storage = storageFormat, + schema = Seq( + CatalogColumn("col1", "int"), + CatalogColumn("col2", "string"), + CatalogColumn("a", "int"), + CatalogColumn("b", "string")), + partitionColumnNames = Seq("a", "b")) + } + + def newFunc(name: String, database: Option[String] = None): CatalogFunction = { + CatalogFunction(FunctionIdentifier(name, database), funcClass, Seq.empty[(String, String)]) + } + + /** + * Whether the catalog's table partitions equal the ones given. + * Note: Hive sets some random serde things, so we just compare the specs here. + */ + def catalogPartitionsEqual( + catalog: ExternalCatalog, + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Boolean = { + catalog.listPartitions(db, table).map(_.spec).toSet == parts.map(_.spec).toSet + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala new file mode 100644 index 0000000000000..63a7b2c661ecb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalogSuite.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + + +/** Test suite for the [[InMemoryCatalog]]. */ +class InMemoryCatalogSuite extends CatalogTestCases { + + protected override val utils: CatalogTestUtils = new CatalogTestUtils { + override val tableInputFormat: String = "org.apache.park.SequenceFileInputFormat" + override val tableOutputFormat: String = "org.apache.park.SequenceFileOutputFormat" + override def newEmptyCatalog(): ExternalCatalog = new InMemoryCatalog + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala new file mode 100644 index 0000000000000..426273e1e3e6c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -0,0 +1,827 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias} + + +/** + * Tests for [[SessionCatalog]] that assume that [[InMemoryCatalog]] is correctly implemented. + * + * Note: many of the methods here are very similar to the ones in [[CatalogTestCases]]. + * This is because [[SessionCatalog]] and [[ExternalCatalog]] share many similar method + * signatures but do not extend a common parent. This is largely by design but + * unfortunately leads to very similar test code in two places. + */ +class SessionCatalogSuite extends SparkFunSuite { + private val utils = new CatalogTestUtils { + override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" + override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" + override def newEmptyCatalog(): ExternalCatalog = new InMemoryCatalog + } + + import utils._ + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + test("basic create and list databases") { + val catalog = new SessionCatalog(newEmptyCatalog()) + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + assert(catalog.databaseExists("default")) + assert(!catalog.databaseExists("testing")) + assert(!catalog.databaseExists("testing2")) + catalog.createDatabase(newDb("testing"), ignoreIfExists = false) + assert(catalog.databaseExists("testing")) + assert(catalog.listDatabases().toSet == Set("default", "testing")) + catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) + assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) + assert(catalog.databaseExists("testing2")) + assert(!catalog.databaseExists("does_not_exist")) + } + + test("get database when a database exists") { + val catalog = new SessionCatalog(newBasicCatalog()) + val db1 = catalog.getDatabaseMetadata("db1") + assert(db1.name == "db1") + assert(db1.description.contains("db1")) + } + + test("get database should throw exception when the database does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.getDatabaseMetadata("db_that_does_not_exist") + } + } + + test("list databases without pattern") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert(catalog.listDatabases().toSet == Set("default", "db1", "db2")) + } + + test("list databases with pattern") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert(catalog.listDatabases("db").toSet == Set.empty) + assert(catalog.listDatabases("db*").toSet == Set("db1", "db2")) + assert(catalog.listDatabases("*1").toSet == Set("db1")) + assert(catalog.listDatabases("db2").toSet == Set("db2")) + } + + test("drop database") { + val catalog = new SessionCatalog(newBasicCatalog()) + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) + assert(catalog.listDatabases().toSet == Set("default", "db2")) + } + + test("drop database when the database is not empty") { + // Throw exception if there are functions left + val externalCatalog1 = newBasicCatalog() + val sessionCatalog1 = new SessionCatalog(externalCatalog1) + externalCatalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false) + externalCatalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false) + intercept[AnalysisException] { + sessionCatalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + + // Throw exception if there are tables left + val externalCatalog2 = newBasicCatalog() + val sessionCatalog2 = new SessionCatalog(externalCatalog2) + externalCatalog2.dropFunction("db2", "func1") + intercept[AnalysisException] { + sessionCatalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } + + // When cascade is true, it should drop them + val externalCatalog3 = newBasicCatalog() + val sessionCatalog3 = new SessionCatalog(externalCatalog3) + externalCatalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) + assert(sessionCatalog3.listDatabases().toSet == Set("default", "db1")) + } + + test("drop database when the database does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + } + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) + } + + test("alter database") { + val catalog = new SessionCatalog(newBasicCatalog()) + val db1 = catalog.getDatabaseMetadata("db1") + // Note: alter properties here because Hive does not support altering other fields + catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) + val newDb1 = catalog.getDatabaseMetadata("db1") + assert(db1.properties.isEmpty) + assert(newDb1.properties.size == 2) + assert(newDb1.properties.get("k") == Some("v3")) + assert(newDb1.properties.get("good") == Some("true")) + } + + test("alter database should throw exception when the database does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.alterDatabase(newDb("does_not_exist")) + } + } + + test("get/set current database") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert(catalog.getCurrentDatabase == "default") + catalog.setCurrentDatabase("db2") + assert(catalog.getCurrentDatabase == "db2") + intercept[AnalysisException] { + catalog.setCurrentDatabase("deebo") + } + catalog.createDatabase(newDb("deebo"), ignoreIfExists = false) + catalog.setCurrentDatabase("deebo") + assert(catalog.getCurrentDatabase == "deebo") + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + test("create table") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + assert(externalCatalog.listTables("db1").isEmpty) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + sessionCatalog.createTable(newTable("tbl3", "db1"), ignoreIfExists = false) + sessionCatalog.createTable(newTable("tbl3", "db2"), ignoreIfExists = false) + assert(externalCatalog.listTables("db1").toSet == Set("tbl3")) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + // Create table without explicitly specifying database + sessionCatalog.setCurrentDatabase("db1") + sessionCatalog.createTable(newTable("tbl4"), ignoreIfExists = false) + assert(externalCatalog.listTables("db1").toSet == Set("tbl3", "tbl4")) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + } + + test("create table when database does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + // Creating table in non-existent database should always fail + intercept[AnalysisException] { + catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = false) + } + intercept[AnalysisException] { + catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = true) + } + // Table already exists + intercept[AnalysisException] { + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + } + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = true) + } + + test("create temp table") { + val catalog = new SessionCatalog(newBasicCatalog()) + val tempTable1 = Range(1, 10, 1, 10, Seq()) + val tempTable2 = Range(1, 20, 2, 10, Seq()) + catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) + catalog.createTempTable("tbl2", tempTable2, overrideIfExists = false) + assert(catalog.getTempTable("tbl1") == Some(tempTable1)) + assert(catalog.getTempTable("tbl2") == Some(tempTable2)) + assert(catalog.getTempTable("tbl3") == None) + // Temporary table already exists + intercept[AnalysisException] { + catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) + } + // Temporary table already exists but we override it + catalog.createTempTable("tbl1", tempTable2, overrideIfExists = true) + assert(catalog.getTempTable("tbl1") == Some(tempTable2)) + } + + test("drop table") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false) + assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) + // Drop table without explicitly specifying database + sessionCatalog.setCurrentDatabase("db2") + sessionCatalog.dropTable(TableIdentifier("tbl2"), ignoreIfNotExists = false) + assert(externalCatalog.listTables("db2").isEmpty) + } + + test("drop table when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + // Should always throw exception when the database does not exist + intercept[AnalysisException] { + catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = false) + } + intercept[AnalysisException] { + catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true) + } + // If the table does not exist, we do not issue an exception. Instead, we output an error log + // message to console when ignoreIfNotExists is set to false. + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true) + } + + test("drop temp table") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + val tempTable = Range(1, 10, 2, 10, Seq()) + sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + sessionCatalog.setCurrentDatabase("db2") + assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is not specified, temp table should be dropped first + sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false) + assert(sessionCatalog.getTempTable("tbl1") == None) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If temp table does not exist, the table in the current database should be dropped + sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false) + assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) + // If database is specified, temp tables are never dropped + sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false) + assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) + } + + test("rename table") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + sessionCatalog.renameTable( + TableIdentifier("tbl1", Some("db2")), TableIdentifier("tblone", Some("db2"))) + assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbl2")) + sessionCatalog.renameTable( + TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbltwo", Some("db2"))) + assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbltwo")) + // Rename table without explicitly specifying database + sessionCatalog.setCurrentDatabase("db2") + sessionCatalog.renameTable(TableIdentifier("tbltwo"), TableIdentifier("table_two")) + assert(externalCatalog.listTables("db2").toSet == Set("tblone", "table_two")) + // Renaming "db2.tblone" to "db1.tblones" should fail because databases don't match + intercept[AnalysisException] { + sessionCatalog.renameTable( + TableIdentifier("tblone", Some("db2")), TableIdentifier("tblones", Some("db1"))) + } + } + + test("rename table when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.renameTable( + TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2", Some("unknown_db"))) + } + intercept[AnalysisException] { + catalog.renameTable( + TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2", Some("db2"))) + } + } + + test("rename temp table") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + val tempTable = Range(1, 10, 2, 10, Seq()) + sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + sessionCatalog.setCurrentDatabase("db2") + assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is not specified, temp table should be renamed first + sessionCatalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) + assert(sessionCatalog.getTempTable("tbl1") == None) + assert(sessionCatalog.getTempTable("tbl3") == Some(tempTable)) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is specified, temp tables are never renamed + sessionCatalog.renameTable( + TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4", Some("db2"))) + assert(sessionCatalog.getTempTable("tbl3") == Some(tempTable)) + assert(sessionCatalog.getTempTable("tbl4") == None) + assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) + } + + test("alter table") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + val tbl1 = externalCatalog.getTable("db2", "tbl1") + sessionCatalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem"))) + val newTbl1 = externalCatalog.getTable("db2", "tbl1") + assert(!tbl1.properties.contains("toh")) + assert(newTbl1.properties.size == tbl1.properties.size + 1) + assert(newTbl1.properties.get("toh") == Some("frem")) + // Alter table without explicitly specifying database + sessionCatalog.setCurrentDatabase("db2") + sessionCatalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1"))) + val newestTbl1 = externalCatalog.getTable("db2", "tbl1") + assert(newestTbl1 == tbl1) + } + + test("alter table when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.alterTable(newTable("tbl1", "unknown_db")) + } + intercept[AnalysisException] { + catalog.alterTable(newTable("unknown_table", "db2")) + } + } + + test("get table") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) + == externalCatalog.getTable("db2", "tbl1")) + // Get table without explicitly specifying database + sessionCatalog.setCurrentDatabase("db2") + assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1")) + == externalCatalog.getTable("db2", "tbl1")) + } + + test("get table when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) + } + intercept[AnalysisException] { + catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) + } + } + + test("lookup table relation") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + val tempTable1 = Range(1, 10, 1, 10, Seq()) + val metastoreTable1 = externalCatalog.getTable("db2", "tbl1") + sessionCatalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) + sessionCatalog.setCurrentDatabase("db2") + // If we explicitly specify the database, we'll look up the relation in that database + assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))) + == SubqueryAlias("tbl1", CatalogRelation("db2", metastoreTable1))) + // Otherwise, we'll first look up a temporary table with the same name + assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) + == SubqueryAlias("tbl1", tempTable1)) + // Then, if that does not exist, look up the relation in the current database + sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false) + assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) + == SubqueryAlias("tbl1", CatalogRelation("db2", metastoreTable1))) + } + + test("lookup table relation with alias") { + val catalog = new SessionCatalog(newBasicCatalog()) + val alias = "monster" + val tableMetadata = catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) + val relation = SubqueryAlias("tbl1", CatalogRelation("db2", tableMetadata)) + val relationWithAlias = + SubqueryAlias(alias, + SubqueryAlias("tbl1", + CatalogRelation("db2", tableMetadata, Some(alias)))) + assert(catalog.lookupRelation( + TableIdentifier("tbl1", Some("db2")), alias = None) == relation) + assert(catalog.lookupRelation( + TableIdentifier("tbl1", Some("db2")), alias = Some(alias)) == relationWithAlias) + } + + test("table exists") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2")))) + assert(catalog.tableExists(TableIdentifier("tbl2", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) + assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) + // If database is explicitly specified, do not check temporary tables + val tempTable = Range(1, 10, 1, 10, Seq()) + catalog.createTempTable("tbl3", tempTable, overrideIfExists = false) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + // If database is not explicitly specified, check the current database + catalog.setCurrentDatabase("db2") + assert(catalog.tableExists(TableIdentifier("tbl1"))) + assert(catalog.tableExists(TableIdentifier("tbl2"))) + assert(catalog.tableExists(TableIdentifier("tbl3"))) + } + + test("list tables without pattern") { + val catalog = new SessionCatalog(newBasicCatalog()) + val tempTable = Range(1, 10, 2, 10, Seq()) + catalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + catalog.createTempTable("tbl4", tempTable, overrideIfExists = false) + assert(catalog.listTables("db1").toSet == + Set(TableIdentifier("tbl1"), TableIdentifier("tbl4"))) + assert(catalog.listTables("db2").toSet == + Set(TableIdentifier("tbl1"), + TableIdentifier("tbl4"), + TableIdentifier("tbl1", Some("db2")), + TableIdentifier("tbl2", Some("db2")))) + intercept[AnalysisException] { + catalog.listTables("unknown_db") + } + } + + test("list tables with pattern") { + val catalog = new SessionCatalog(newBasicCatalog()) + val tempTable = Range(1, 10, 2, 10, Seq()) + catalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + catalog.createTempTable("tbl4", tempTable, overrideIfExists = false) + assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) + assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) + assert(catalog.listTables("db2", "tbl*").toSet == + Set(TableIdentifier("tbl1"), + TableIdentifier("tbl4"), + TableIdentifier("tbl1", Some("db2")), + TableIdentifier("tbl2", Some("db2")))) + assert(catalog.listTables("db2", "*1").toSet == + Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) + intercept[AnalysisException] { + catalog.listTables("unknown_db", "*") + } + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + test("basic create and list partitions") { + val externalCatalog = newEmptyCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + sessionCatalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) + sessionCatalog.createPartitions( + TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual(externalCatalog, "mydb", "tbl", Seq(part1, part2))) + // Create partitions without explicitly specifying database + sessionCatalog.setCurrentDatabase("mydb") + sessionCatalog.createPartitions(TableIdentifier("tbl"), Seq(part3), ignoreIfExists = false) + assert(catalogPartitionsEqual(externalCatalog, "mydb", "tbl", Seq(part1, part2, part3))) + } + + test("create partitions when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl1", Some("does_not_exist")), Seq(), ignoreIfExists = false) + } + intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfExists = false) + } + } + + test("create partitions that already exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = false) + } + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = true) + } + + test("drop partitions") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) + sessionCatalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec), + ignoreIfNotExists = false) + assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2))) + // Drop partitions without explicitly specifying database + sessionCatalog.setCurrentDatabase("db2") + sessionCatalog.dropPartitions( + TableIdentifier("tbl2"), + Seq(part2.spec), + ignoreIfNotExists = false) + assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) + // Drop multiple partitions at once + sessionCatalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) + sessionCatalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec, part2.spec), + ignoreIfNotExists = false) + assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) + } + + test("drop partitions when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl1", Some("does_not_exist")), + Seq(), + ignoreIfNotExists = false) + } + intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("does_not_exist", Some("db2")), + Seq(), + ignoreIfNotExists = false) + } + } + + test("drop partitions that do not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = false) + } + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = true) + } + + test("get partition") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), part1.spec).spec == part1.spec) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), part2.spec).spec == part2.spec) + // Get partition without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec == part1.spec) + assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec == part2.spec) + // Get non-existent partition + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), part3.spec) + } + } + + test("get partition when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("does_not_exist")), part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("does_not_exist", Some("db2")), part1.spec) + } + } + + test("rename partitions") { + val catalog = new SessionCatalog(newBasicCatalog()) + val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) + val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) + val newSpecs = Seq(newPart1.spec, newPart2.spec) + catalog.renamePartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), newSpecs) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), newPart1.spec).spec === newPart1.spec) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), newPart2.spec).spec === newPart2.spec) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + } + // Rename partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.renamePartitions(TableIdentifier("tbl2"), newSpecs, Seq(part1.spec, part2.spec)) + assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec === part1.spec) + assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec === part2.spec) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), newPart1.spec) + } + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), newPart2.spec) + } + } + + test("rename partitions when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("does_not_exist")), Seq(part1.spec), Seq(part2.spec)) + } + intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("does_not_exist", Some("db2")), Seq(part1.spec), Seq(part2.spec)) + } + } + + test("alter partitions") { + val catalog = new SessionCatalog(newBasicCatalog()) + val newLocation = newUriForDatabase() + // Alter but keep spec the same + val oldPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + val oldPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq( + oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), + oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) + val newPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + val newPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + assert(newPart1.storage.locationUri == Some(newLocation)) + assert(newPart2.storage.locationUri == Some(newLocation)) + assert(oldPart1.storage.locationUri != Some(newLocation)) + assert(oldPart2.storage.locationUri != Some(newLocation)) + // Alter partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.alterPartitions(TableIdentifier("tbl2"), Seq(oldPart1, oldPart2)) + val newerPart1 = catalog.getPartition(TableIdentifier("tbl2"), part1.spec) + val newerPart2 = catalog.getPartition(TableIdentifier("tbl2"), part2.spec) + assert(oldPart1.storage.locationUri == newerPart1.storage.locationUri) + assert(oldPart2.storage.locationUri == newerPart2.storage.locationUri) + // Alter but change spec, should fail because new partition specs do not exist yet + val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) + val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) + intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq(badPart1, badPart2)) + } + } + + test("alter partitions when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("does_not_exist")), Seq(part1)) + } + intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("does_not_exist", Some("db2")), Seq(part1)) + } + } + + test("list partitions") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert(catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))).toSet == Set(part1, part2)) + // List partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.listPartitions(TableIdentifier("tbl2")).toSet == Set(part1, part2)) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + test("basic create and list functions") { + val externalCatalog = newEmptyCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) + assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) + // Create function without explicitly specifying database + sessionCatalog.setCurrentDatabase("mydb") + sessionCatalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) + assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) + } + + test("create function when database does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.createFunction( + newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) + } + } + + test("create function that already exists") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) + } + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) + } + + test("create temp function") { + val catalog = new SessionCatalog(newBasicCatalog()) + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + val info1 = new ExpressionInfo("tempFunc1", "temp1") + val info2 = new ExpressionInfo("tempFunc2", "temp2") + catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction("temp1", arguments) === Literal(1)) + assert(catalog.lookupFunction("temp2", arguments) === Literal(3)) + // Temporary function does not exist. + intercept[AnalysisException] { + catalog.lookupFunction("temp3", arguments) + } + val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) + val info3 = new ExpressionInfo("tempFunc3", "temp1") + // Temporary function already exists + intercept[AnalysisException] { + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) + } + // Temporary function is overridden + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) + assert(catalog.lookupFunction("temp1", arguments) === Literal(arguments.length)) + } + + test("drop function") { + val externalCatalog = newBasicCatalog() + val sessionCatalog = new SessionCatalog(externalCatalog) + assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) + sessionCatalog.dropFunction( + FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) + assert(externalCatalog.listFunctions("db2", "*").isEmpty) + // Drop function without explicitly specifying database + sessionCatalog.setCurrentDatabase("db2") + sessionCatalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) + sessionCatalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) + assert(externalCatalog.listFunctions("db2", "*").isEmpty) + } + + test("drop function when database/function does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.dropFunction( + FunctionIdentifier("something", Some("does_not_exist")), ignoreIfNotExists = false) + } + intercept[AnalysisException] { + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) + } + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) + } + + test("drop temp function") { + val catalog = new SessionCatalog(newBasicCatalog()) + val info = new ExpressionInfo("tempFunc", "func1") + val tempFunc = (e: Seq[Expression]) => e.head + catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction("func1", arguments) === Literal(1)) + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[AnalysisException] { + catalog.lookupFunction("func1", arguments) + } + intercept[AnalysisException] { + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + } + catalog.dropTempFunction("func1", ignoreIfNotExists = true) + } + + test("get function") { + val catalog = new SessionCatalog(newBasicCatalog()) + val expected = + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[(String, String)]) + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) + // Get function without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) + } + + test("get function when database/function does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("does_not_exist"))) + } + intercept[AnalysisException] { + catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) + } + } + + test("lookup temp function") { + val catalog = new SessionCatalog(newBasicCatalog()) + val info1 = new ExpressionInfo("tempFunc1", "func1") + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[AnalysisException] { + catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) + } + } + + test("list functions") { + val catalog = new SessionCatalog(newBasicCatalog()) + val info1 = new ExpressionInfo("tempFunc1", "func1") + val info2 = new ExpressionInfo("tempFunc2", "yes_me") + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) + assert(catalog.listFunctions("db1", "*").toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("yes_me"))) + assert(catalog.listFunctions("db2", "*").toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("yes_me"), + FunctionIdentifier("func1", Some("db2")), + FunctionIdentifier("func2", Some("db2")), + FunctionIdentifier("not_me", Some("db2")))) + assert(catalog.listFunctions("db2", "func*").toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("func1", Some("db2")), + FunctionIdentifier("func2", Some("db2")))) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala new file mode 100644 index 0000000000000..8c766ef829923 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders + +class NonEncodable(i: Int) + +case class ComplexNonEncodable1(name1: NonEncodable) + +case class ComplexNonEncodable2(name2: ComplexNonEncodable1) + +case class ComplexNonEncodable3(name3: Option[NonEncodable]) + +case class ComplexNonEncodable4(name4: Array[NonEncodable]) + +case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]]) + +class EncoderErrorMessageSuite extends SparkFunSuite { + + // Note: we also test error messages for encoders for private classes in JavaDatasetSuite. + // That is done in Java because Scala cannot create truly private classes. + + test("primitive types in encoders using Kryo serialization") { + intercept[UnsupportedOperationException] { Encoders.kryo[Int] } + intercept[UnsupportedOperationException] { Encoders.kryo[Long] } + intercept[UnsupportedOperationException] { Encoders.kryo[Char] } + } + + test("primitive types in encoders using Java serialization") { + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } + } + + test("nice error message for missing encoder") { + val errorMsg1 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage + assert(errorMsg1.contains( + s"""root class: "${clsName[ComplexNonEncodable1]}"""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg2 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage + assert(errorMsg2.contains( + s"""root class: "${clsName[ComplexNonEncodable2]}"""")) + assert(errorMsg2.contains( + s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg3 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage + assert(errorMsg3.contains( + s"""root class: "${clsName[ComplexNonEncodable3]}"""")) + assert(errorMsg3.contains( + s"""field (class: "scala.Option", name: "name3")""")) + assert(errorMsg3.contains( + s"""option value class: "${clsName[NonEncodable]}"""")) + + val errorMsg4 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage + assert(errorMsg4.contains( + s"""root class: "${clsName[ComplexNonEncodable4]}"""")) + assert(errorMsg4.contains( + s"""field (class: "scala.Array", name: "name4")""")) + assert(errorMsg4.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + + val errorMsg5 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage + assert(errorMsg5.contains( + s"""root class: "${clsName[ComplexNonEncodable5]}"""")) + assert(errorMsg5.contains( + s"""field (class: "scala.Option", name: "name5")""")) + assert(errorMsg5.contains( + s"""option value class: "scala.Array"""")) + assert(errorMsg5.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + } + + private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala new file mode 100644 index 0000000000000..3ad0dae767be3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +case class StringLongClass(a: String, b: Long) + +case class StringIntClass(a: String, b: Int) + +case class ComplexClass(a: Long, b: StringLongClass) + +class EncoderResolutionSuite extends PlanTest { + private val str = UTF8String.fromString("hello") + + test("real type doesn't match encoder schema but they are compatible: product") { + val encoder = ExpressionEncoder[StringLongClass] + + // int type can be up cast to long type + val attrs1 = Seq('a.string, 'b.int) + encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1)) + + // int type can be up cast to string type + val attrs2 = Seq('a.int, 'b.long) + encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L)) + } + + test("real type doesn't match encoder schema but they are compatible: nested product") { + val encoder = ExpressionEncoder[ComplexClass] + val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) + encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) + } + + test("real type doesn't match encoder schema but they are compatible: tupled encoder") { + val encoder = ExpressionEncoder.tuple( + ExpressionEncoder[StringLongClass], + ExpressionEncoder[Long]) + val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) + encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) + } + + test("nullability of array type element should not fail analysis") { + val encoder = ExpressionEncoder[Seq[Int]] + val attrs = 'a.array(IntegerType) :: Nil + + // It should pass analysis + val bound = encoder.resolve(attrs, null).bind(attrs) + + // If no null values appear, it should works fine + bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) + + // If there is null value, it should throw runtime exception + val e = intercept[RuntimeException] { + bound.fromRow(InternalRow(new GenericArrayData(Array(1, null)))) + } + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + } + + test("the real number of fields doesn't match encoder schema: tuple encoder") { + val encoder = ExpressionEncoder[(String, Long)] + + { + val attrs = Seq('a.string, 'b.long, 'c.int) + assert(intercept[AnalysisException](encoder.validate(attrs)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + + { + val attrs = Seq('a.string) + assert(intercept[AnalysisException](encoder.validate(attrs)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + } + + test("the real number of fields doesn't match encoder schema: nested tuple encoder") { + val encoder = ExpressionEncoder[(String, (Long, String))] + + { + val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) + assert(intercept[AnalysisException](encoder.validate(attrs)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + + { + val attrs = Seq('a.string, 'b.struct('x.long)) + assert(intercept[AnalysisException](encoder.validate(attrs)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + } + + test("throw exception if real type is not compatible with encoder schema") { + val msg1 = intercept[AnalysisException] { + ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + }.message + assert(msg1 == + s""" + |Cannot up cast `b` from bigint to int as it may truncate + |The type path of the target object is: + |- field (class: "scala.Int", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + + val msg2 = intercept[AnalysisException] { + val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) + ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + }.message + assert(msg2 == + s""" + |Cannot up cast `b`.`b` from decimal(38,18) to bigint as it may truncate + |The type path of the target object is: + |- field (class: "scala.Long", name: "b") + |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + + // test for leaf types + castSuccess[Int, Long] + castSuccess[java.sql.Date, java.sql.Timestamp] + castSuccess[Long, String] + castSuccess[Int, java.math.BigDecimal] + castSuccess[Long, java.math.BigDecimal] + + castFail[Long, Int] + castFail[java.sql.Timestamp, java.sql.Date] + castFail[java.math.BigDecimal, Double] + castFail[Double, java.math.BigDecimal] + castFail[java.math.BigDecimal, Int] + castFail[String, Long] + + + private def castSuccess[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") { + to.resolve(from.schema.toAttributes, null) + } + } + + private def castFail[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { + intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index b0dacf7f555e0..18752014ea908 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,17 +17,32 @@ package org.apache.spark.sql.catalyst.encoders -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe._ +import java.sql.{Date, Timestamp} +import java.util.Arrays -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst._ +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{StructField, ArrayType} +import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType} case class RepeatedStruct(s: Seq[PrimitiveData]) -case class NestedArray(a: Array[Array[Int]]) +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} case class BoxedData( intField: java.lang.Integer, @@ -47,40 +62,120 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) -class ExpressionEncoderSuite extends SparkFunSuite { +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} - encodeDecodeTest(1) - encodeDecodeTest(1L) - encodeDecodeTest(1.toDouble) - encodeDecodeTest(1.toFloat) - encodeDecodeTest(true) - encodeDecodeTest(false) - encodeDecodeTest(1.toShort) - encodeDecodeTest(1.toByte) - encodeDecodeTest("hello") +/** For testing Java serialization based encoder. */ +class JavaSerializable(val value: Int) extends Serializable { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[JavaSerializable].value + } +} - encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) +class ExpressionEncoderSuite extends PlanTest with AnalysisTest { + OuterScopes.addOuterScope(this) + + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + + // test flat encoders + encodeDecodeTest(false, "primitive boolean") + encodeDecodeTest(-3.toByte, "primitive byte") + encodeDecodeTest(-3.toShort, "primitive short") + encodeDecodeTest(-3, "primitive int") + encodeDecodeTest(-3L, "primitive long") + encodeDecodeTest(-3.7f, "primitive float") + encodeDecodeTest(-3.7, "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") + // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + + encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal") + + encodeDecodeTest("hello", "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), "binary") + + encodeDecodeTest(Seq(31, -123, 4), "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), "seq of string with null") + encodeDecodeTest(Seq.empty[Int], "empty seq of int") + encodeDecodeTest(Seq.empty[String], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), "array of int") + encodeDecodeTest(Array("abc", "xyz"), "array of string") + encodeDecodeTest(Array("a", null, "x"), "array of string with null") + encodeDecodeTest(Array.empty[Int], "empty array of int") + encodeDecodeTest(Array.empty[String], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + + encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple") + encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple") + + // Kryo encoders + encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) + encodeDecodeTest(new KryoSerializable(15), "kryo object")( + encoderFor(Encoders.kryo[KryoSerializable])) + + // Java encoders + encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String])) + encodeDecodeTest(new JavaSerializable(15), "java object")( + encoderFor(Encoders.javaSerialization[JavaSerializable])) + + // test product encoders + private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { + encodeDecodeTest(input, input.getClass.getSimpleName) + } - // TODO: Support creating specific subclasses of Seq. - ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) } + case class InnerClass(i: Int) + productTest(InnerClass(1)) + encodeDecodeTest(Array(InnerClass(1)), "array of inner class") - encodeDecodeTest( - OptionalData( - Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + encodeDecodeTest(Array(Option(InnerClass(1))), "array of optional inner class") + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None)) + productTest(OptionalData(None, None, None, None, None, None, None, None)) - encodeDecodeTest( - BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + encodeDecodeTest(Seq(Some(1), None), "Option in array") + encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map") - encodeDecodeTest( - BoxedData(null, null, null, null, null, null, null)) + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - encodeDecodeTest( - RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + productTest(BoxedData(null, null, null, null, null, null, null)) - encodeDecodeTest( + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( RepeatedData( Seq(1, 2), Seq(new Integer(1), null, new Integer(2)), @@ -88,161 +183,125 @@ class ExpressionEncoderSuite extends SparkFunSuite { Map(1 -> null), PrimitiveData(1, 1, 1, 1, 1, 1, true))) - encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null))) + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) - encodeDecodeTest(("Seq[(String, String)]", + productTest(("Seq[(String, String)]", Seq(("a", "b")))) - encodeDecodeTest(("Seq[(Int, Int)]", + productTest(("Seq[(Int, Int)]", Seq((1, 2)))) - encodeDecodeTest(("Seq[(Long, Long)]", + productTest(("Seq[(Long, Long)]", Seq((1L, 2L)))) - encodeDecodeTest(("Seq[(Float, Float)]", + productTest(("Seq[(Float, Float)]", Seq((1.toFloat, 2.toFloat)))) - encodeDecodeTest(("Seq[(Double, Double)]", + productTest(("Seq[(Double, Double)]", Seq((1.toDouble, 2.toDouble)))) - encodeDecodeTest(("Seq[(Short, Short)]", + productTest(("Seq[(Short, Short)]", Seq((1.toShort, 2.toShort)))) - encodeDecodeTest(("Seq[(Byte, Byte)]", + productTest(("Seq[(Byte, Byte)]", Seq((1.toByte, 2.toByte)))) - encodeDecodeTest(("Seq[(Boolean, Boolean)]", + productTest(("Seq[(Boolean, Boolean)]", Seq((true, false)))) - // TODO: Decoding/encoding of complex maps. - ignore("complex maps") { - encodeDecodeTest(("Map[Int, (String, String)]", - Map(1 ->("a", "b")))) - } - - encodeDecodeTest(("ArrayBuffer[(String, String)]", + productTest(("ArrayBuffer[(String, String)]", ArrayBuffer(("a", "b")))) - encodeDecodeTest(("ArrayBuffer[(Int, Int)]", + productTest(("ArrayBuffer[(Int, Int)]", ArrayBuffer((1, 2)))) - encodeDecodeTest(("ArrayBuffer[(Long, Long)]", + productTest(("ArrayBuffer[(Long, Long)]", ArrayBuffer((1L, 2L)))) - encodeDecodeTest(("ArrayBuffer[(Float, Float)]", + productTest(("ArrayBuffer[(Float, Float)]", ArrayBuffer((1.toFloat, 2.toFloat)))) - encodeDecodeTest(("ArrayBuffer[(Double, Double)]", + productTest(("ArrayBuffer[(Double, Double)]", ArrayBuffer((1.toDouble, 2.toDouble)))) - encodeDecodeTest(("ArrayBuffer[(Short, Short)]", + productTest(("ArrayBuffer[(Short, Short)]", ArrayBuffer((1.toShort, 2.toShort)))) - encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]", + productTest(("ArrayBuffer[(Byte, Byte)]", ArrayBuffer((1.toByte, 2.toByte)))) - encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]", + productTest(("ArrayBuffer[(Boolean, Boolean)]", ArrayBuffer((true, false)))) - encodeDecodeTest(("Seq[Seq[(Int, Int)]]", + productTest(("Seq[Seq[(Int, Int)]]", Seq(Seq((1, 2))))) - encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", - Array(Array((1, 2))))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", - Array(Array(Array((1, 2)))))) - { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]", - Array(Array(Array(Array((1, 2))))))) - { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]", - Array(Array(Array(Array(Array((1, 2)))))))) - { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } - - - encodeDecodeTestCustom(("Array[Array[Integer]]", - Array(Array[Integer](1)))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Int]]", - Array(Array(1)))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Int]]", - Array(Array(Array(1))))) - { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Int]]]", - Array(Array(Array(Array(1)))))) - { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]", - Array(Array(Array(Array(Array(1))))))) - { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } - - encodeDecodeTest(("Array[Byte] null", - null: Array[Byte])) - encodeDecodeTestCustom(("Array[Byte]", - Array[Byte](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Int] null", - null: Array[Int])) - encodeDecodeTestCustom(("Array[Int]", - Array[Int](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Long] null", - null: Array[Long])) - encodeDecodeTestCustom(("Array[Long]", - Array[Long](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Double] null", - null: Array[Double])) - encodeDecodeTestCustom(("Array[Double]", - Array[Double](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Float] null", - null: Array[Float])) - encodeDecodeTestCustom(("Array[Float]", - Array[Float](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Boolean] null", - null: Array[Boolean])) - encodeDecodeTestCustom(("Array[Boolean]", - Array[Boolean](true, false))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Short] null", - null: Array[Short])) - encodeDecodeTestCustom(("Array[Short]", - Array[Short](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTestCustom(("java.sql.Timestamp", - new java.sql.Timestamp(1))) - { (l, r) => l._2.toString == r._2.toString } - - encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1))) - { (l, r) => l._2.toString == r._2.toString } - - /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ - protected def encodeDecodeTest[T : TypeTag](inputData: T) = - encodeDecodeTestCustom[T](inputData)((l, r) => l == r) - - /** - * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it - * matches the original. - */ - protected def encodeDecodeTestCustom[T : TypeTag]( - inputData: T)( - c: (T, T) => Boolean) = { - test(s"encode/decode: $inputData - ${inputData.getClass.getName}") { - val encoder = try ExpressionEncoder[T]() catch { - case e: Exception => - fail(s"Exception thrown generating encoder", e) - } - val convertedData = encoder.toRow(inputData) + // test for ExpressionEncoder.tuple + encodeDecodeTest( + 1 -> 10L, + "tuple with 2 flat encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[Long])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + "tuple with 2 product encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[(Int, Long)])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + "tuple with flat encoder and product encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[Int])) + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + "tuple with product encoder and flat encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[PrimitiveData])) + + encodeDecodeTest( + (1, (10, 100L)), + "nested tuple encoder") { + val intEnc = ExpressionEncoder[Int] + val longEnc = ExpressionEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + } + + productTest(("UDT", new ExamplePoint(0.1, 0.2))) + + test("nullable of encoder schema") { + def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { + assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + checkNullable[java.lang.Integer](true) + checkNullable[String](true) + + // test for product encoders + checkNullable[(String, Int)](true, false) + checkNullable[(Int, java.lang.Long)](false, true) + + // test for nested product encoders + { + val schema = ExpressionEncoder[(Int, (String, Int))].schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + + // test for tupled encoders + { + val schema = ExpressionEncoder.tuple( + ExpressionEncoder[Int], + ExpressionEncoder[(String, Int)]).schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + } + + private def encodeDecodeTest[T : ExpressionEncoder]( + input: T, + testName: String): Unit = { + test(s"encode/decode for $testName: $input") { + val encoder = implicitly[ExpressionEncoder[T]] + val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema).bind(schema) - val convertedBack = try boundEncoder.fromRow(convertedData) catch { + val boundEncoder = encoder.defaultBinding + val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( s"""Exception thrown while decoding - |Converted: $convertedData + |Converted: $row |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | @@ -252,18 +311,36 @@ class ExpressionEncoderSuite extends SparkFunSuite { """.stripMargin, e) } - if (!c(inputData, convertedBack)) { + // Test the correct resolution of serialization / deserialization. + val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))() + val inputPlan = LocalRelation(attr) + val plan = + Project(Alias(encoder.deserializer, "obj")() :: Nil, + Project(encoder.namedExpressions, + inputPlan)) + assertAnalysisSuccess(plan) + + val isCorrect = (input, convertedBack) match { + case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2) + case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2) + case (b1: Array[Array[_]], b2: Array[Array[_]]) => + Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case (b1: Array[_], b2: Array[_]) => + Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case _ => input == convertedBack + } + + if (!isCorrect) { val types = convertedBack match { case c: Product => c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") case other => other.getClass.getName } - val encodedData = try { - convertedData.toSeq(encoder.schema).zip(encoder.schema).map { - case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => - a.toArray[Any](at.elementType).toSeq + row.toSeq(encoder.schema).zip(schema).map { + case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) => + a.toArray[Any](et).toSeq case (other, _) => other }.mkString("[", ",", "]") @@ -274,7 +351,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { fail( s"""Encoded/Decoded data does not match input data | - |in: $inputData + |in: $input |out: $convertedBack |types: $types | @@ -282,11 +359,10 @@ class ExpressionEncoderSuite extends SparkFunSuite { |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | - |Extract Expressions: - |$boundEncoder + |fromRow Expressions: + |${boundEncoder.deserializer.treeString} """.stripMargin) - } } - + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index e8301e8e06b52..a8fa372b1ee3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -17,19 +17,72 @@ package org.apache.spark.sql.catalyst.encoders +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String + +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt + override def equals(that: Any): Boolean = { + if (that.isInstanceOf[ExamplePoint]) { + val e = that.asInstanceOf[ExamplePoint] + (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) && + (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity)) + } else { + false + } + } +} + +/** + * User-defined type for [[ExamplePoint]]. + */ +class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" + + override def serialize(p: ExamplePoint): GenericArrayData = { + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: ArrayData => + if (values.numElements() > 1) { + new ExamplePoint(values.getDouble(0), values.getDouble(1)) + } else { + val random = new Random() + new ExamplePoint(random.nextDouble(), random.nextDouble()) + } + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] + + private[spark] override def asNullable: ExamplePointUDT = this +} class RowEncoderSuite extends SparkFunSuite { private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) encodeDecodeTest( new StructType() + .add("null", NullType) .add("boolean", BooleanType) .add("byte", ByteType) .add("short", ShortType) @@ -41,15 +94,18 @@ class RowEncoderSuite extends SparkFunSuite { .add("string", StringType) .add("binary", BinaryType) .add("date", DateType) - .add("timestamp", TimestampType)) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT)) encodeDecodeTest( new StructType() + .add("arrayOfNull", arrayOfNull) .add("arrayOfString", arrayOfString) .add("arrayOfArrayOfString", ArrayType(arrayOfString)) .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) .add("arrayOfMap", ArrayType(mapOfString)) - .add("arrayOfStruct", ArrayType(structOfString))) + .add("arrayOfStruct", ArrayType(structOfString)) + .add("arrayOfUDT", arrayOfUDT)) encodeDecodeTest( new StructType() @@ -68,7 +124,41 @@ class RowEncoderSuite extends SparkFunSuite { .add("structOfArray", new StructType().add("array", arrayOfString)) .add("structOfMap", new StructType().add("map", mapOfString)) .add("structOfArrayAndMap", - new StructType().add("array", arrayOfString).add("map", mapOfString))) + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + test(s"encode/decode: Product") { + val schema = new StructType() + .add("structAsProduct", + new StructType() + .add("int", IntegerType) + .add("string", StringType) + .add("double", DoubleType)) + + val encoder = RowEncoder(schema) + + val input: Row = Row((100, "test", 0.123)) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getStruct(0) == convertedBack.getStruct(0)) + } + + test("encode/decode Decimal") { + val schema = new StructType() + .add("int", IntegerType) + .add("string", StringType) + .add("double", DoubleType) + .add("decimal", DecimalType.SYSTEM_DEFAULT) + + val encoder = RowEncoder(schema) + + val input: Row = Row(100, "test", 0.123, Decimal(1234.5678)) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + // Decimal inside external row will be converted back to Java BigDecimal when decoding. + assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal + .compareTo(convertedBack.getDecimal(3)) == 0) + } private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index f4db4da7646f8..43af3592070fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} -import java.util.{TimeZone, Calendar} +import java.sql.{Date, Timestamp} +import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row @@ -258,8 +258,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from int 2") { checkEvaluation(cast(1, LongType), 1.toLong) - checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) - checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) @@ -297,7 +297,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from string") { assert(cast("abcdef", StringType).nullable === false) assert(cast("abcdef", BinaryType).nullable === false) - assert(cast("abcdef", BooleanType).nullable === false) + assert(cast("abcdef", BooleanType).nullable === true) assert(cast("abcdef", TimestampType).nullable === true) assert(cast("abcdef", LongType).nullable === true) assert(cast("abcdef", IntegerType).nullable === true) @@ -348,14 +348,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) @@ -479,10 +479,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.003f) checkEvaluation(cast(ts, DoubleType), 15.003) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation(cast(cast(tss, IntegerType), TimestampType), - DateTimeUtils.fromJavaTimestamp(ts)) - checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + DateTimeUtils.fromJavaTimestamp(ts) * 1000) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation( cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) @@ -545,7 +547,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Seq(null, true, false)) } @@ -604,7 +606,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { @@ -711,7 +713,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("a", BooleanType, nullable = true), StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, InternalRow(null, true, false)) } @@ -732,7 +734,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val complex = Literal.create( Row( Seq("123", "true", "f"), - Map("a" ->"123", "b" -> "true", "c" -> "f"), + Map("a" -> "123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", @@ -752,7 +754,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructType(Seq( StructField("l", LongType, nullable = true))))))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Row( Seq(123, null, null), Map("a" -> null, "b" -> true, "c" -> false), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e323467af5f4a..260dfb3f42244 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import scala.math._ - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -38,9 +36,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { import scala.concurrent.duration._ val futures = (1 to 20).map { _ => - future { + Future { GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) - GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) } @@ -49,40 +46,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { futures.foreach(Await.result(_, 10.seconds)) } - // Test GenerateOrdering for all common types. For each type, we construct random input rows that - // contain two columns of that type, then for pairs of randomly-generated rows we check that - // GenerateOrdering agrees with RowOrdering. - (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => - test(s"GenerateOrdering with $dataType") { - val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) - val genOrdering = GenerateOrdering.generate( - BoundReference(0, dataType, nullable = true).asc :: - BoundReference(1, dataType, nullable = true).asc :: Nil) - val rowType = StructType( - StructField("a", dataType, nullable = true) :: - StructField("b", dataType, nullable = true) :: Nil) - val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) - assume(maybeDataGenerator.isDefined) - val randGenerator = maybeDataGenerator.get - val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) - for (_ <- 1 to 50) { - val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - withClue(s"a = $a, b = $b") { - assert(genOrdering.compare(a, a) === 0) - assert(genOrdering.compare(b, b) === 0) - assert(rowOrdering.compare(a, a) === 0) - assert(rowOrdering.compare(b, b) === 0) - assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) - assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) - assert( - signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), - "Generated and non-generated orderings should agree") - } - } - } - } - test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) @@ -95,6 +58,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-13242: case-when expression with large number of branches (or cases)") { + val cases = 50 + val clauses = 20 + + // Generate an individual case + def generateCase(n: Int): (Expression, Expression) = { + val condition = (1 to clauses) + .map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n"))) + .reduceLeft[Expression]((l, r) => Or(l, r)) + (condition, Literal(n)) + } + + val expression = CaseWhen((1 to cases).map(generateCase(_))) + + val plan = GenerateMutableProjection.generate(Seq(expression))() + val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) + val actual = plan(input).toSeq(Seq(expression.dataType)) + + assert(actual(0) == cases) + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true), @@ -134,4 +118,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4) assert(internalRow === internalRow2) } + + test("*/ in the data") { + // When */ appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape */ to \*\/. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("*/", StringType)), + true, + InternalRow(UTF8String.fromString("*/"))) + } + + test("\\u in the data") { + // When \ u appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape \ u to \\u. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("\\u", StringType)), + true, + InternalRow(UTF8String.fromString("\\u"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e60990aeb423f..7c009a7360b6f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -79,8 +78,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) + val index = fields.indexWhere(_.name == fieldName) + GetStructField(expr, index) } } @@ -135,6 +134,46 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) } + test("CreateMap") { + def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = { + keys.zip(values).flatMap { case (k, v) => Seq(k, v) } + } + + def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. + scala.collection.immutable.ListMap(keys.zip(values): _*) + } + + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateMap(Nil), Map.empty) + checkEvaluation( + CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))), + createMap(intSeq, longSeq)) + checkEvaluation( + CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))), + createMap(strSeq, longSeq)) + checkEvaluation( + CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))), + createMap(longSeq, strSeq)) + + val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType) + checkEvaluation( + CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)), + createMap(intSeq, strWithNull.map(_.value))) + intercept[RuntimeException] { + checkEvaluationWithoutCodegen( + CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), + null, null) + } + intercept[RuntimeException] { + checkEvalutionWithUnsafeProjection( + CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), + null, null) + } + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) @@ -165,7 +204,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { "b", create_row(Map("a" -> "b"))) checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), "b", create_row(Seq("a", "b"))) - checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")), 1, create_row(create_row(1))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 0df673bb9fa02..3c581ecdaf068 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ - class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("if") { @@ -81,38 +80,39 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val c5 = 'a.string.at(4) val c6 = 'a.string.at(5) - checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq((c1, c4)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((c2, c4)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((c3, c4)), c6), "a", row) + checkEvaluation(CaseWhen(Seq((Literal.create(null, BooleanType), c4)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((Literal.create(false, BooleanType), c4)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((Literal.create(true, BooleanType), c4)), c6), "a", row) - checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) + checkEvaluation(CaseWhen(Seq((c3, c4), (c2, c5)), c6), "a", row) + checkEvaluation(CaseWhen(Seq((c2, c4), (c3, c5)), c6), "b", row) + checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5)), c6), "c", row) + checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5))), null, row) - assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + assert(CaseWhen(Seq((c2, c4)), c6).nullable === true) + assert(CaseWhen(Seq((c2, c4), (c3, c5)), c6).nullable === true) + assert(CaseWhen(Seq((c2, c4), (c3, c5))).nullable === true) val c4_notNull = 'a.boolean.notNull.at(3) val c5_notNull = 'a.boolean.notNull.at(4) val c6_notNull = 'a.boolean.notNull.at(5) - assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull)), c6_notNull).nullable === false) + assert(CaseWhen(Seq((c2, c4)), c6_notNull).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull))).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull)), c6).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6_notNull).nullable === false) + assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull)), c6_notNull).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5)), c6_notNull).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull))).nullable === true) + assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull))).nullable === true) + assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true) } test("case key when") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 610d39e8493cd..53c66d8a754ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -465,6 +465,42 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) } + test("to_unix_timestamp") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3) + val date1 = Date.valueOf("2015-07-24") + checkEvaluation( + ToUnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) + val t1 = ToUnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = ToUnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + ToUnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) + checkEvaluation( + ToUnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(ToUnixTimestamp( + Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L) + checkEvaluation( + ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) + } + test("datediff") { checkEvaluation( DateDiff(Literal(Date.valueOf("2015-07-24")), Literal(Date.valueOf("2015-07-21"))), 3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index 511f0307901df..a8f758d625a02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{LongType, DecimalType, Decimal} - +import org.apache.spark.sql.types.{Decimal, DecimalType, LongType} class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 465f7d08aa142..cf26d4843d84f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types.DataType +import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. @@ -43,7 +44,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) - checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expression.dataType)) { checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) } @@ -58,8 +58,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) - case (result: Double, expected: Spread[Double]) => - expected.isWithin(result) + case (result: Double, expected: Spread[Double @unchecked]) => + expected.asInstanceOf[Spread[Double]].isWithin(result) case _ => result == expected } } @@ -83,7 +83,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { s""" |Code generation of $expression failed: |$e - |${e.getStackTraceString} + |${Utils.exceptionString(e)} """.stripMargin) } } @@ -120,42 +120,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - protected def checkEvaluationWithGeneratedProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { - - val plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) - - val actual = plan(inputRow) - val expectedRow = InternalRow(expected) - - // We reimplement hashCode in generated `SpecificRow`, make sure it's consistent with our - // interpreted version. - if (actual.hashCode() != expectedRow.hashCode()) { - val ctx = new CodeGenContext - val evaluated = expression.gen(ctx) - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |Expressions: $expression - |Code: $evaluated - """.stripMargin) - } - - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail("Incorrect Evaluation in codegen mode: " + - s"$expression, actual: $actual, expected: $expectedRow$input") - } - if (actual.copy() != expectedRow) { - fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") - } - } - protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, @@ -202,7 +166,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithOptimization(expression, expected) var plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) @@ -312,8 +276,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) - case (result: Double, expected: Spread[Double]) => - expected.isWithin(result) + case (result: Double, expected: Spread[Double @unchecked]) => + expected.asInstanceOf[Spread[Double]].isWithin(result) case (result: Double, expected: Double) if result.isNaN && expected.isNaN => true case (result: Float, expected: Float) if result.isNaN && expected.isNaN => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala new file mode 100644 index 0000000000000..60939ee0eda5d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.IntegerType + +class ExpressionSetSuite extends SparkFunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + def setTest(size: Int, exprs: Expression*): Unit = { + test(s"expect $size: ${exprs.mkString(", ")}") { + val set = ExpressionSet(exprs) + if (set.size != size) { + fail(set.toDebugString) + } + } + } + + def setTestIgnore(size: Int, exprs: Expression*): Unit = + ignore(s"expect $size: ${exprs.mkString(", ")}") {} + + // Commutative + setTest(1, aUpper + 1, aLower + 1) + setTest(2, aUpper + 1, aLower + 2) + setTest(2, aUpper + 1, fakeA + 1) + setTest(2, aUpper + 1, bUpper + 1) + + setTest(1, aUpper + aLower, aLower + aUpper) + setTest(1, aUpper + bUpper, bUpper + aUpper) + setTest(1, + aUpper + bUpper + 3, + bUpper + 3 + aUpper, + bUpper + aUpper + 3, + Literal(3) + aUpper + bUpper) + setTest(1, + aUpper * bUpper * 3, + bUpper * 3 * aUpper, + bUpper * aUpper * 3, + Literal(3) * aUpper * bUpper) + setTest(1, aUpper === bUpper, bUpper === aUpper) + + setTest(1, aUpper + 1 === bUpper, bUpper === Literal(1) + aUpper) + + + // Not commutative + setTest(2, aUpper - bUpper, bUpper - aUpper) + + // Reversible + setTest(1, aUpper > bUpper, bUpper < aUpper) + setTest(1, aUpper >= bUpper, bUpper <= aUpper) + + // `Not` canonicalization + setTest(1, Not(aUpper > 1), aUpper <= 1, Not(Literal(1) < aUpper), Literal(1) >= aUpper) + setTest(1, Not(aUpper < 1), aUpper >= 1, Not(Literal(1) > aUpper), Literal(1) <= aUpper) + setTest(1, Not(aUpper >= 1), aUpper < 1, Not(Literal(1) <= aUpper), Literal(1) > aUpper) + setTest(1, Not(aUpper <= 1), aUpper > 1, Not(Literal(1) >= aUpper), Literal(1) < aUpper) + + test("add to / remove from set") { + val initialSet = ExpressionSet(aUpper + 1 :: Nil) + + assert((initialSet + (aUpper + 1)).size == 1) + assert((initialSet + (aUpper + 2)).size == 2) + assert((initialSet - (aUpper + 1)).size == 0) + assert((initialSet - (aUpper + 2)).size == 1) + + assert((initialSet + (aLower + 1)).size == 1) + assert((initialSet - (aLower + 1)).size == 0) + + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index f33125f463e14..7b754091f4714 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -209,8 +209,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal("f5") :: Nil + private def checkJsonTuple(jt: JsonTuple, expected: InternalRow): Unit = { + assert(jt.eval(null).toSeq.head === expected) + } + test("json_tuple - hive key 1") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: jsonTupleQuery), @@ -218,7 +222,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 2") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: jsonTupleQuery), @@ -226,7 +230,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 2 (mix of foldable fields)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: Literal("f1") :: NonFoldableLiteral("f2") :: @@ -238,7 +242,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") :: jsonTupleQuery), @@ -247,7 +251,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3 (nonfoldable json)") { - checkEvaluation( + checkJsonTuple( JsonTuple( NonFoldableLiteral( """{"f1": "value13", "f4": "value44", @@ -258,7 +262,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3 (nonfoldable fields)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal( """{"f1": "value13", "f4": "value44", | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) :: @@ -273,43 +277,43 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 4 - null json") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal(null) :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - hive key 5 - null and empty fields") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"f1": "", "f5": null}""") :: jsonTupleQuery), InternalRow.fromSeq(Seq(UTF8String.fromString(""), null, null, null, null))) } test("json_tuple - hive key 6 - invalid json (array)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("[invalid JSON string]") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (object start only)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("{") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (no object end)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"foo": "bar"""") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (invalid json)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("\\") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - preserve newlines") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc")))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 7b85286c4dc8c..450222d8cbba3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -54,7 +56,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.default(FloatType), 0.0f) checkEvaluation(Literal.default(DoubleType), 0.0) checkEvaluation(Literal.default(StringType), "") - checkEvaluation(Literal.default(BinaryType), "".getBytes) + checkEvaluation(Literal.default(BinaryType), "".getBytes(StandardCharsets.UTF_8)) checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0)) checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0)) checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0)) @@ -96,7 +98,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("string literals") { checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") - checkEvaluation(Literal("\0"), "\0") + checkEvaluation(Literal("\u0000"), "\u0000") } test("sum two literals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala index d9c91415e249d..032aec01782f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.scalacheck.{Arbitrary, Gen} -import org.scalatest.Matchers -import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 88ed9fdd6465f..27195d3458b8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets + import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -120,7 +122,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { private def checkNaNWithoutCodegen( expression: Expression, - expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) @@ -440,10 +441,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") checkEvaluation(Hex(Literal.create(null, BinaryType)), null) - checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") + checkEvaluation(Hex(Literal("helloHex".getBytes(StandardCharsets.UTF_8))), "68656C6C6F486578") // scalastyle:off // Turn off scala style for non-ascii chars - checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") + checkEvaluation(Hex(Literal("三重的".getBytes(StandardCharsets.UTF_8))), "E4B889E9878DE79A84") // scalastyle:on Seq(LongType, BinaryType, StringType).foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, dt) @@ -452,14 +453,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("unhex") { checkEvaluation(Unhex(Literal.create(null, StringType)), null) - checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes(StandardCharsets.UTF_8)) checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) checkEvaluation(Unhex(Literal("GG")), null) // scalastyle:off // Turn off scala style for non-ascii chars - checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes(StandardCharsets.UTF_8)) checkEvaluation(Unhex(Literal("三重的")), null) // scalastyle:on checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) @@ -552,5 +553,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Round(Literal.create(null, dataType), Literal.create(null, IntegerType)), null) } + + checkEvaluation(Round(-3.5, 0), -4.0) + checkEvaluation(Round(-0.35, 1), -0.4) + checkEvaluation(Round(-35, -1), -40) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 75d17417e5a02..f5bafcc6a783e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -17,15 +17,20 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets + import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.types._ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("md5") { - checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932") + checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))), + "902fbdd2b1df0c4f70b4a5d23525e932") checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "6ac1e56bc78f031059be7be854522c4c") checkEvaluation(Md5(Literal.create(null, BinaryType)), null) @@ -33,30 +38,106 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("sha1") { - checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal("ABC".getBytes(StandardCharsets.UTF_8))), + "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), "5d211bad8f4ee70e16c7d343a838fc344a1ed961") checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) - checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + checkEvaluation(Sha1(Literal("".getBytes(StandardCharsets.UTF_8))), + "da39a3ee5e6b4b0d3255bfef95601890afd80709") checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType) } test("sha2") { - checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) + checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(256)), + DigestUtils.sha256Hex("ABC")) checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) // unsupported bit length checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) - checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) + checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), + Literal.create(null, IntegerType)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) } test("crc32") { - checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L) + checkEvaluation(Crc32(Literal("ABC".getBytes(StandardCharsets.UTF_8))), 2743272264L) checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), 2180413220L) checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } + + private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) + private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) + private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) + + testHash( + new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("bigDecimal", DecimalType.SYSTEM_DEFAULT) + .add("smallDecimal", DecimalType.USER_DEFAULT) + .add("string", StringType) + .add("binary", BinaryType) + .add("date", DateType) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT)) + + testHash( + new StructType() + .add("arrayOfNull", arrayOfNull) + .add("arrayOfString", arrayOfString) + .add("arrayOfArrayOfString", ArrayType(arrayOfString)) + .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) + .add("arrayOfMap", ArrayType(mapOfString)) + .add("arrayOfStruct", ArrayType(structOfString)) + .add("arrayOfUDT", arrayOfUDT)) + + testHash( + new StructType() + .add("mapOfIntAndString", MapType(IntegerType, StringType)) + .add("mapOfStringAndArray", MapType(StringType, arrayOfString)) + .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType)) + .add("mapOfArray", MapType(arrayOfString, arrayOfString)) + .add("mapOfStringAndStruct", MapType(StringType, structOfString)) + .add("mapOfStructAndString", MapType(structOfString, StringType)) + .add("mapOfStruct", MapType(structOfString, structOfString))) + + testHash( + new StructType() + .add("structOfString", structOfString) + .add("structOfStructOfString", new StructType().add("struct", structOfString)) + .add("structOfArray", new StructType().add("array", arrayOfString)) + .add("structOfMap", new StructType().add("map", mapOfString)) + .add("structOfArrayAndMap", + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + private def testHash(inputSchema: StructType): Unit = { + val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get + val encoder = RowEncoder(inputSchema) + val seed = scala.util.Random.nextInt() + test(s"murmur3/xxHash64 hash: ${inputSchema.simpleString}") { + for (_ <- 1 to 10) { + val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] + val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { + case (value, dt) => Literal.create(value, dt) + } + // Only test the interpreted version has same result with codegen version. + checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval()) + checkEvaluation(XxHash64(literals, seed), XxHash64(literals, seed).eval()) + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 31ecf4a9e810a..ff34b1e37be93 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -26,8 +26,7 @@ import org.apache.spark.sql.types._ * A literal value that is not foldable. Used in expression codegen testing to test code path * that behave differently based on foldable values. */ -case class NonFoldableLiteral(value: Any, dataType: DataType) - extends LeafExpression with CodegenFallback { +case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = false override def nullable: Boolean = true @@ -36,7 +35,7 @@ case class NonFoldableLiteral(value: Any, dataType: DataType) override def eval(input: InternalRow): Any = value - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { Literal.create(value, dataType).genCode(ctx, ev) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala new file mode 100644 index 0000000000000..b190d3a00dfb8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.math._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.types._ + +class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { + + def compareArrays(a: Seq[Any], b: Seq[Any], expected: Int): Unit = { + test(s"compare two arrays: a = $a, b = $b") { + val dataType = ArrayType(IntegerType) + val rowType = StructType(StructField("array", dataType, nullable = true) :: Nil) + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow] + val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow] + Seq(Ascending, Descending).foreach { direction => + val sortOrder = direction match { + case Ascending => BoundReference(0, dataType, nullable = true).asc + case Descending => BoundReference(0, dataType, nullable = true).desc + } + val expectedCompareResult = direction match { + case Ascending => signum(expected) + case Descending => -1 * signum(expected) + } + val intOrdering = new InterpretedOrdering(sortOrder :: Nil) + val genOrdering = GenerateOrdering.generate(sortOrder :: Nil) + Seq(intOrdering, genOrdering).foreach { ordering => + assert(ordering.compare(rowA, rowA) === 0) + assert(ordering.compare(rowB, rowB) === 0) + assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult) + assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult) + } + } + } + } + + // Two arrays have the same size. + compareArrays(Seq[Any](), Seq[Any](), 0) + compareArrays(Seq[Any](1), Seq[Any](1), 0) + compareArrays(Seq[Any](1, 2), Seq[Any](1, 2), 0) + compareArrays(Seq[Any](1, 2, 2), Seq[Any](1, 2, 3), -1) + + // Two arrays have different sizes. + compareArrays(Seq[Any](), Seq[Any](1), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 4), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 2), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 2, 2), 1) + + // Arrays having nulls. + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, null), -1) + compareArrays(Seq[Any](), Seq[Any](null), -1) + compareArrays(Seq[Any](null), Seq[Any](null), 0) + compareArrays(Seq[Any](null, null), Seq[Any](null, null), 0) + compareArrays(Seq[Any](null), Seq[Any](null, null), -1) + compareArrays(Seq[Any](null), Seq[Any](1), -1) + compareArrays(Seq[Any](null), Seq[Any](null, 1), -1) + compareArrays(Seq[Any](null, 1), Seq[Any](1, 1), -1) + compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0) + compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1) + + // Test GenerateOrdering for all common types. For each type, we construct random input rows that + // contain two columns of that type, then for pairs of randomly-generated rows we check that + // GenerateOrdering agrees with RowOrdering. + { + val structType = + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true) + val arrayOfStructType = ArrayType(structType) + val complexTypes = ArrayType(IntegerType) :: structType :: arrayOfStructType :: Nil + (DataTypeTestUtils.atomicTypes ++ complexTypes ++ Set(NullType)).foreach { dataType => + test(s"GenerateOrdering with $dataType") { + val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) + val genOrdering = GenerateOrdering.generate( + BoundReference(0, dataType, nullable = true).asc :: + BoundReference(1, dataType, nullable = true).asc :: Nil) + val rowType = StructType( + StructField("a", dataType, nullable = true) :: + StructField("b", dataType, nullable = true) :: Nil) + val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) + assume(maybeDataGenerator.isDefined) + val randGenerator = maybeDataGenerator.get + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + for (_ <- 1 to 50) { + val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") + } + } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 4a644d136f09c..b7a0d44fa7e57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -24,12 +24,12 @@ import org.apache.spark.SparkFunSuite class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { - checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) - checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) + checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001) + checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001) } test("SPARK-9127 codegen with long seed") { - checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001) - checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001) + checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001) + checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 99e3b13ce8c97..2cf8ca7000edc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -382,6 +382,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InitCap(Literal("a b")), "A B") checkEvaluation(InitCap(Literal(" a")), " A") checkEvaluation(InitCap(Literal("the test")), "The Test") + checkEvaluation(InitCap(Literal("sParK")), "Spark") // scalastyle:off // non ascii characters are not allowed in the code, so we disable the scalastyle here. checkEvaluation(InitCap(Literal("世界")), "世界") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala new file mode 100644 index 0000000000000..90e97d718a9fc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.IntegerType + +class SubexpressionEliminationSuite extends SparkFunSuite { + test("Semantic equals and hash") { + val id = ExprId(1) + val a: AttributeReference = AttributeReference("name", IntegerType)() + val b1 = a.withName("name2").withExprId(id) + val b2 = a.withExprId(id) + val b3 = a.withQualifier(Some("qualifierName")) + + assert(b1 != b2) + assert(a != b1) + assert(b1.semanticEquals(b2)) + assert(!b1.semanticEquals(a)) + assert(a.hashCode != b1.hashCode) + assert(b1.hashCode != b2.hashCode) + assert(b1.semanticHash() == b2.semanticHash()) + assert(a != b3) + assert(a.hashCode != b3.hashCode) + assert(a.semanticEquals(b3)) + } + + test("Expression Equivalence - basic") { + val equivalence = new EquivalentExpressions + assert(equivalence.getAllEquivalentExprs.isEmpty) + + val oneA = Literal(1) + val oneB = Literal(1) + val twoA = Literal(2) + var twoB = Literal(2) + + assert(equivalence.getEquivalentExprs(oneA).isEmpty) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + + // Add oneA and test if it is returned. Since it is a group of one, it does not. + assert(!equivalence.addExpr(oneA)) + assert(equivalence.getEquivalentExprs(oneA).size == 1) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.addExpr((oneA))) + assert(equivalence.getEquivalentExprs(oneA).size == 2) + + // Add B and make sure they can see each other. + assert(equivalence.addExpr(oneB)) + // Use exists and reference equality because of how equals is defined. + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.getAllEquivalentExprs.size == 1) + assert(equivalence.getAllEquivalentExprs.head.size == 3) + assert(equivalence.getAllEquivalentExprs.head.contains(oneA)) + assert(equivalence.getAllEquivalentExprs.head.contains(oneB)) + + val add1 = Add(oneA, oneB) + val add2 = Add(oneA, oneB) + + equivalence.addExpr(add1) + equivalence.addExpr(add2) + + assert(equivalence.getAllEquivalentExprs.size == 2) + assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1)) + assert(equivalence.getEquivalentExprs(add2).size == 2) + assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2)) + } + + test("Expression Equivalence - Trees") { + val one = Literal(1) + val two = Literal(2) + + val add = Add(one, two) + val abs = Abs(add) + val add2 = Add(add, add) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + equivalence.addExprTree(abs, true) + equivalence.addExprTree(add2, true) + + // Should only have one equivalence for `one + two` + assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4) + + // Set up the expressions + // one * two, + // (one * two) * (one * two) + // sqrt( (one * two) * (one * two) ) + // (one * two) + sqrt( (one * two) * (one * two) ) + equivalence = new EquivalentExpressions + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + equivalence.addExprTree(mul, true) + equivalence.addExprTree(mul2, true) + equivalence.addExprTree(sqrt, true) + equivalence.addExprTree(sum, true) + + // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found + assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3) + assert(equivalence.getEquivalentExprs(mul).size == 3) + assert(equivalence.getEquivalentExprs(mul2).size == 3) + assert(equivalence.getEquivalentExprs(sqrt).size == 2) + assert(equivalence.getEquivalentExprs(sum).size == 1) + + // Some expressions inspired by TPCH-Q1 + // sum(l_quantity) as sum_qty, + // sum(l_extendedprice) as sum_base_price, + // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + // avg(l_extendedprice) as avg_price, + // avg(l_discount) as avg_disc + equivalence = new EquivalentExpressions + val quantity = Literal(1) + val price = Literal(1.1) + val discount = Literal(.24) + val tax = Literal(0.1) + equivalence.addExprTree(quantity, false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false) + equivalence.addExprTree( + Multiply( + Multiply(price, Subtract(Literal(1), discount)), + Add(Literal(1), tax)), false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(discount, false) + // quantity, price, discount and (price * (1 - discount)) + assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4) + } + + test("Expression equivalence - non deterministic") { + val sum = Add(Rand(0), Rand(0)) + val equivalence = new EquivalentExpressions + equivalence.addExpr(sum) + equivalence.addExpr(sum) + assert(equivalence.getAllEquivalentExprs.isEmpty) + } + + test("Children of CodegenFallback") { + val one = Literal(1) + val two = Add(one, one) + val explode = Explode(two) + val add = Add(two, explode) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + // the `two` inside `explode` should not be added + assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) + assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala new file mode 100644 index 0000000000000..b82cf8d1693e2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.PrivateMethodTester + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.LongType + +class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester { + + test("time window is unevaluable") { + intercept[UnsupportedOperationException] { + evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) + } + } + + private def checkErrorMessage(msg: String, value: String): Unit = { + val validDuration = "10 second" + val validTime = "5 second" + val e1 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration + } + val e2 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration + } + val e3 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), validDuration, validDuration, value).startTime + } + Seq(e1, e2, e3).foreach { e => + e.getMessage.contains(msg) + } + } + + test("blank intervals throw exception") { + for (blank <- Seq(null, " ", "\n", "\t")) { + checkErrorMessage( + "The window duration, slide duration and start time cannot be null or blank.", blank) + } + } + + test("invalid intervals throw exception") { + checkErrorMessage( + "did not correspond to a valid interval string.", "2 apples") + } + + test("intervals greater than a month throws exception") { + checkErrorMessage( + "Intervals greater than or equal to a month is not supported (1 month).", "1 month") + } + + test("interval strings work with and without 'interval' prefix and return microseconds") { + val validDuration = "10 second" + for ((text, seconds) <- Seq( + ("1 second", 1000000), // 1e6 + ("1 minute", 60000000), // 6e7 + ("2 hours", 7200000000L))) { // 72e9 + assert(TimeWindow(Literal(10L), text, validDuration, "0 seconds").windowDuration === seconds) + assert(TimeWindow(Literal(10L), "interval " + text, validDuration, "0 seconds").windowDuration + === seconds) + } + } + + private val parseExpression = PrivateMethod[Long]('parseExpression) + + test("parse sql expression for duration in microseconds - string") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal("5 seconds"))) + assert(dur.isInstanceOf[Long]) + assert(dur === 5000000) + } + + test("parse sql expression for duration in microseconds - integer") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal(100))) + assert(dur.isInstanceOf[Long]) + assert(dur === 100) + } + + test("parse sql expression for duration in microseconds - long") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType))) + assert(dur.isInstanceOf[Long]) + assert(dur === (2 << 52)) + } + + test("parse sql expression for duration in microseconds - invalid interval") { + intercept[IllegalArgumentException] { + TimeWindow.invokePrivate(parseExpression(Literal("2 apples"))) + } + } + + test("parse sql expression for duration in microseconds - invalid expression") { + intercept[AnalysisException] { + TimeWindow.invokePrivate(parseExpression(Rand(123))) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 68545f33e5465..1265908182b3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.scalatest.Matchers @@ -77,16 +78,16 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) - row.update(2, "World".getBytes) + row.update(2, "World".getBytes(StandardCharsets.UTF_8)) val unsafeRow: UnsafeRow = converter.apply(row) assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) + - roundedSize("Hello".getBytes.length) + - roundedSize("World".getBytes.length)) + roundedSize("Hello".getBytes(StandardCharsets.UTF_8).length) + + roundedSize("World".getBytes(StandardCharsets.UTF_8).length)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") - assert(unsafeRow.getBinary(2) === "World".getBytes) + assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } test("basic conversion with primitive, string, date and timestamp types") { @@ -100,7 +101,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val unsafeRow: UnsafeRow = converter.apply(row) - assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + roundedSize("Hello".getBytes.length)) + assert(unsafeRow.getSizeInBytes === + 8 + (8 * 4) + roundedSize("Hello".getBytes(StandardCharsets.UTF_8).length)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") @@ -175,7 +177,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setFloat(6, 600) r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) - r.update(9, "world".getBytes) + r.update(9, "world".getBytes(StandardCharsets.UTF_8)) r.setDecimal(10, Decimal(10), 10) r.setDecimal(11, Decimal(10.00, 38, 18), 38) // r.update(11, Array(11)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index 0d329497758c6..f5374229ca5cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.util.Random +import scala.collection.mutable + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, BoundReference} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, MutableRow, SpecificMutableRow} import org.apache.spark.sql.types.{DataType, IntegerType} -import scala.collection.mutable -import org.scalatest.Assertions._ - class HyperLogLogPlusPlusSuite extends SparkFunSuite { /** Create a HLL++ instance and an input and output buffer. */ @@ -132,7 +131,7 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite { i += 1 } - // Merge the lower and upper halfs. + // Merge the lower and upper halves. hll.merge(buffer1a, buffer1b) // Create the other buffer in reverse diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 9da1068e9ca1d..f57b82bb96399 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -18,13 +18,20 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util._ class CodeFormatterSuite extends SparkFunSuite { def testCase(name: String)(input: String)(expected: String): Unit = { test(name) { - assert(CodeFormatter.format(input).trim === expected.trim) + if (CodeFormatter.format(input).trim !== expected.trim) { + fail( + s""" + |== FAIL: Formatted code doesn't match === + |${sideBySide(CodeFormatter.format(input).trim, expected.trim).mkString("\n")} + """.stripMargin) + } } } @@ -93,4 +100,50 @@ class CodeFormatterSuite extends SparkFunSuite { |/* 004 */ c) """.stripMargin } + + testCase("single line comments") { + """// This is a comment about class A { { { ( ( + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ // This is a comment about class A { { { ( ( + |/* 002 */ class A { + |/* 003 */ class body; + |/* 004 */ } + """.stripMargin + } + + testCase("single line comments /* */ ") { + """/** This is a comment about class A { { { ( ( */ + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ /** This is a comment about class A { { { ( ( */ + |/* 002 */ class A { + |/* 003 */ class body; + |/* 004 */ } + """.stripMargin + } + + testCase("multi-line comments") { + """ /* This is a comment about + |class A { + |class body; ...*/ + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ /* This is a comment about + |/* 002 */ class A { + |/* 003 */ class body; ...*/ + |/* 004 */ class A { + |/* 005 */ class body; + |/* 006 */ } + """.stripMargin + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 2d3f98dbbd3d1..c9616cdb26c20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -34,12 +34,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance.apply(null).getBoolean(0) === false) } - test("GenerateProjection should initialize expressions") { - val expr = And(NondeterministicExpression(), NondeterministicExpression()) - val instance = GenerateProjection.generate(Seq(expr)) - assert(instance.apply(null).getBoolean(0) === false) - } - test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr))() @@ -64,18 +58,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateProjection should not share expression instances") { - val expr1 = MutableExpression() - val instance1 = GenerateProjection.generate(Seq(expr1)) - assert(instance1.apply(null).getBoolean(0) === false) - - val expr2 = MutableExpression() - expr2.mutableState = true - val instance2 = GenerateProjection.generate(Seq(expr2)) - assert(instance1.apply(null).getBoolean(0) === false) - assert(instance2.apply(null).getBoolean(0) === true) - } - test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateMutableProjection.generate(Seq(expr1))() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index 796d60032e1a6..f8342214d9ae0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -90,13 +90,13 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { } private def createUnsafeRow(numFields: Int): UnsafeRow = { - val row = new UnsafeRow + val row = new UnsafeRow(numFields) val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8 // Allocate a larger buffer than needed and point the UnsafeRow to somewhere in the middle. // This way we can test the joiner when the input UnsafeRows are not the entire arrays. val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) - row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) + row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, sizeInBytes) row } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index 59729e7646beb..9f19745cefd20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -74,8 +74,9 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) { info(s"schema size $numFields1, $numFields2") - val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes) - val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes) + val random = new Random() + val schema1 = RandomDataGenerator.randomSchema(random, numFields1, candidateTypes) + val schema2 = RandomDataGenerator.randomSchema(random, numFields2, candidateTypes) // Create the converters needed to convert from external row to internal row and to UnsafeRows. val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 1522ee34e43a5..e2a8eb8ee1d34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import java.nio.charset.StandardCharsets + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -107,7 +109,8 @@ class GeneratedProjectionSuite extends SparkFunSuite { val fields = Array[DataType](StringType, struct) val unsafeProj = UnsafeProjection.create(fields) - val innerRow = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, "".getBytes, + val innerRow = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, + "".getBytes(StandardCharsets.UTF_8), UTF8String.fromString("")) val row1 = InternalRow(UTF8String.fromString(""), innerRow) val unsafe1 = unsafeProj(row1).copy() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 2d080b95b1292..e458eb8a1d362 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,32 +17,20 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Nil } - test("replace distinct with aggregate") { - val input = LocalRelation('a.int, 'b.int) - - val query = Distinct(input) - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = Aggregate(input.output, input.output, input) - - comparePlans(optimized, correctAnswer) - } - test("remove literals in grouping expression") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala new file mode 100644 index 0000000000000..7cd038570bbdf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("Constant Folding", FixedPoint(50), + NullPropagation, + ConstantFolding, + BooleanSimplification, + BinaryComparisonSimplification, + PruneFilters) :: Nil + } + + val nullableRelation = LocalRelation('a.int.withNullability(true)) + val nonNullableRelation = LocalRelation('a.int.withNullability(false)) + + test("Preserve nullable exprs in general") { + for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) { + val plan = nullableRelation.where(e).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + } + + test("Preserve non-deterministic exprs") { + val plan = nonNullableRelation + .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + + test("Nullable Simplification Primitive: <=>") { + val plan = nullableRelation.select('a <=> 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze + comparePlans(actual, correctAnswer) + } + + test("Non-Nullable Simplification Primitive") { + val plan = nonNullableRelation + .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation + .select( + Alias(TrueLiteral, "(a = a)")(), + Alias(TrueLiteral, "(a <=> a)")(), + Alias(TrueLiteral, "(a <= a)")(), + Alias(TrueLiteral, "(a >= a)")(), + Alias(FalseLiteral, "(a < a)")(), + Alias(FalseLiteral, "(a > a)")()) + .analyze + comparePlans(actual, correctAnswer) + } + + test("Expression Normalization") { + val plan = nonNullableRelation.where( + 'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a && + DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a)) + .analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation.analyze + comparePlans(actual, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index cde346e99eb17..8147d06969bbe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -19,24 +19,25 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), NullPropagation, ConstantFolding, BooleanSimplification, - SimplifyFilters) :: Nil + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) @@ -80,33 +81,67 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2) - checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2) + checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5), 'a < 2) checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) checkCondition( ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5), - ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) + 'a === 'b || 'b > 3 && 'a > 3 && 'a < 5) } test("a && (!a || b)") { - checkCondition(('a && (!('a) || 'b )), ('a && 'b)) + checkCondition('a && (!'a || 'b ), 'a && 'b) + + checkCondition('a && ('b || !'a ), 'a && 'b) + + checkCondition((!'a || 'b ) && 'a, 'b && 'a) + + checkCondition(('b || !'a ) && 'a, 'b && 'a) + } + + test("a < 1 && (!(a < 1) || b)") { + checkCondition('a < 1 && (!('a < 1) || 'b), ('a < 1) && 'b) + checkCondition('a < 1 && ('b || !('a < 1)), ('a < 1) && 'b) - checkCondition(('a && ('b || !('a) )), ('a && 'b)) + checkCondition('a <= 1 && (!('a <= 1) || 'b), ('a <= 1) && 'b) + checkCondition('a <= 1 && ('b || !('a <= 1)), ('a <= 1) && 'b) - checkCondition(((!('a) || 'b ) && 'a), ('b && 'a)) + checkCondition('a > 1 && (!('a > 1) || 'b), ('a > 1) && 'b) + checkCondition('a > 1 && ('b || !('a > 1)), ('a > 1) && 'b) - checkCondition((('b || !('a) ) && 'a), ('b && 'a)) + checkCondition('a >= 1 && (!('a >= 1) || 'b), ('a >= 1) && 'b) + checkCondition('a >= 1 && ('b || !('a >= 1)), ('a >= 1) && 'b) } - test("!(a && b) , !(a || b)") { - checkCondition((!('a && 'b)), (!('a) || !('b))) + test("a < 1 && ((a >= 1) || b)") { + checkCondition('a < 1 && ('a >= 1 || 'b ), ('a < 1) && 'b) + checkCondition('a < 1 && ('b || 'a >= 1), ('a < 1) && 'b) + + checkCondition('a <= 1 && ('a > 1 || 'b ), ('a <= 1) && 'b) + checkCondition('a <= 1 && ('b || 'a > 1), ('a <= 1) && 'b) + + checkCondition('a > 1 && (('a <= 1) || 'b), ('a > 1) && 'b) + checkCondition('a > 1 && ('b || ('a <= 1)), ('a > 1) && 'b) + + checkCondition('a >= 1 && (('a < 1) || 'b), ('a >= 1) && 'b) + checkCondition('a >= 1 && ('b || ('a < 1)), ('a >= 1) && 'b) + } + + test("DeMorgan's law") { + checkCondition(!('a && 'b), !'a || !'b) + + checkCondition(!('a || 'b), !'a && !'b) + + checkCondition(!(('a && 'b) || ('c && 'd)), (!'a || !'b) && (!'c || !'d)) - checkCondition(!('a || 'b), (!('a) && !('b))) + checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) } - private val caseInsensitiveAnalyzer = - new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false)) + private val caseInsensitiveConf = new SimpleCatalystConf(false) + private val caseInsensitiveAnalyzer = new Analyzer( + new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf), + caseInsensitiveConf) test("(a && b) || (a && c) => a && (b || c) when case insensitive") { val plan = caseInsensitiveAnalyzer.execute( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala new file mode 100644 index 0000000000000..587437e9aa81d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class CollapseProjectSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubqueryAliases) :: + Batch("CollapseProject", Once, CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two deterministic, independent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse two deterministic, dependent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select(('a_plus_1 + 1).as('a_plus_2), 'b) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation.select( + (('a + 1).as('a_plus_1) + 1).as('a_plus_2), + 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse nondeterministic projects") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse two nondeterministic, independent projects into one") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(Rand(20).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .select(Rand(20).as('rand2)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse one nondeterministic, one deterministic, independent projects into one") { + val query = testRelation + .select(Rand(10).as('rand), 'a) + .select(('a + 1).as('a_plus_1)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .select(('a + 1).as('a_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse project into aggregate") { + val query = testRelation + .groupBy('a, 'b)(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation + .groupBy('a, 'b)(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse common nondeterministic project and aggregate") { + val query = testRelation + .groupBy('a)('a, Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 4a1e7ceaf394b..52b574c0e63c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -17,30 +17,36 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.Explode -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.StringType class ColumnPruningSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Column pruning", FixedPoint(100), - ColumnPruning) :: Nil + PushDownPredicate, + ColumnPruning, + CollapseProject) :: Nil } test("Column pruning for Generate when Generate.join = false") { val input = LocalRelation('a.int, 'b.array(StringType)) - val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze + val query = input.generate(Explode('b), join = false).analyze + val optimized = Optimize.execute(query) - val correctAnswer = - Generate(Explode('b), false, false, None, 's.string :: Nil, - Project('b.attr :: Nil, input)).analyze + val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze comparePlans(optimized, correctAnswer) } @@ -49,16 +55,19 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) val query = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - Project(Seq('a, 'c), - input))).analyze + input + .select('a, 'c) + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze comparePlans(optimized, correctAnswer) } @@ -67,15 +76,18 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('b.array(StringType)) val query = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = true, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), false, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = false, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze comparePlans(optimized, correctAnswer) } @@ -91,5 +103,271 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Column pruning for Expand") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = + Aggregate( + Seq('aa, 'gid), + Seq(sum('c).as("sum")), + Expand( + Seq( + Seq('a, 'b, 'c, Literal.create(null, StringType), 1), + Seq('a, 'b, 'c, 'a, 2)), + Seq('a, 'b, 'c, 'aa.int, 'gid.int), + input)).analyze + val optimized = Optimize.execute(query) + + val expected = + Aggregate( + Seq('aa, 'gid), + Seq(sum('c).as("sum")), + Expand( + Seq( + Seq('c, Literal.create(null, StringType), 1), + Seq('c, 'a, 2)), + Seq('c, 'aa.int, 'gid.int), + Project(Seq('a, 'c), + input))).analyze + + comparePlans(optimized, expected) + } + + test("Column pruning on Filter") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val plan1 = Filter('a > 1, input).analyze + comparePlans(Optimize.execute(plan1), plan1) + val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze + comparePlans(Optimize.execute(query), query) + val plan2 = Filter('b > 1, Project(Seq('a, 'b), input)).analyze + val expected2 = Project(Seq('a, 'b), Filter('b > 1, input)).analyze + comparePlans(Optimize.execute(plan2), expected2) + val plan3 = Project(Seq('a), Filter('b > 1, Project(Seq('a, 'b), input))).analyze + val expected3 = Project(Seq('a), Filter('b > 1, input)).analyze + comparePlans(Optimize.execute(plan3), expected3) + } + + test("Column pruning on except/intersect/distinct") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Except(input, input)).analyze + comparePlans(Optimize.execute(query), query) + + val query2 = Project('a :: Nil, Intersect(input, input)).analyze + comparePlans(Optimize.execute(query2), query2) + val query3 = Project('a :: Nil, Distinct(input)).analyze + comparePlans(Optimize.execute(query3), query3) + } + + test("Column pruning on Project") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze + val expected = Project(Seq('a), input).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("Eliminate the Project with an empty projectList") { + val input = OneRowRelation + val expected = Project(Literal(1).as("1") :: Nil, input).analyze + + val query1 = + Project(Literal(1).as("1") :: Nil, Project(Literal(1).as("1") :: Nil, input)).analyze + comparePlans(Optimize.execute(query1), expected) + + val query2 = + Project(Literal(1).as("1") :: Nil, Project(Nil, input)).analyze + comparePlans(Optimize.execute(query2), expected) + + // to make sure the top Project will not be removed. + comparePlans(Optimize.execute(expected), expected) + } + + test("column pruning for group") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val originalQuery = + testRelation + .groupBy('a)('a, count('b)) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for group with alias") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .groupBy('a)('a as 'c, count('b)) + .select('c) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .groupBy('a)('a as 'c).analyze + + comparePlans(optimized, correctAnswer) + } + + test("column pruning for Project(ne, Limit)") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val originalQuery = + testRelation + .select('a, 'b) + .limit(2) + .select('a) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("push down project past sort") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val x = testRelation.subquery('x) + + // push down valid + val originalQuery = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('a) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.select('a) + .sortBy(SortOrder('a, Ascending)).analyze + + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + + // push down invalid + val originalQuery1 = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b) + } + + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze + + comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1)) + } + + test("Column pruning on Window with useless aggregate functions") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.groupBy('a, 'c, 'd)('a, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).select('a, 'c) + + val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning on Window with selected agg expressions") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.select('a, 'b, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).where('window > 1).select('a, 'c) + + val correctAnswer = + input.select('a, 'b, 'c) + .window(WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window) :: Nil, + 'a :: Nil, 'b.asc :: Nil) + .where('window > 1).select('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning on Window in select") { + val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) + + val originalQuery = + input.select('a, 'b, 'c, 'd, + WindowExpression( + AggregateExpression(Count('b), Complete, isDistinct = false), + WindowSpecDefinition( 'a :: Nil, + SortOrder('b, Ascending) :: Nil, + UnspecifiedFrame)).as('window)).select('a, 'c) + + val correctAnswer = input.select('a, 'c).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning on Union") { + val input1 = LocalRelation('a.int, 'b.string, 'c.double) + val input2 = LocalRelation('c.int, 'd.string, 'e.double) + val query = Project('b :: Nil, + Union(input1 :: input2 :: Nil)).analyze + val expected = Project('b :: Nil, + Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze + comparePlans(Optimize.execute(query), expected) + } + + test("Remove redundant projects in column pruning rule") { + val input = LocalRelation('key.int, 'value.string) + + val query = + Project(Seq($"x.key", $"y.key"), + Join( + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze + + val optimized = Optimize.execute(query) + + val expected = + Join( + Project(Seq($"x.key"), SubqueryAlias("x", input)), + BroadcastHint( + Project(Seq($"y.key"), SubqueryAlias("y", input))), + Inner, None).analyze + + comparePlans(optimized, expected) + } + + implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + private val func = identity[Iterator[OtherTuple]] _ + + test("Column pruning on MapPartitions") { + val input = LocalRelation('_1.int, '_2.int, 'c.int) + val plan1 = MapPartitions(func, input) + val correctAnswer1 = + MapPartitions(func, Project(Seq('_1, '_2), input)).analyze + comparePlans(Optimize.execute(plan1.analyze), correctAnswer1) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 06c592f4905a3..87ad81db11b64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ class CombiningLimitsSuite extends PlanTest { @@ -34,7 +34,8 @@ class CombiningLimitsSuite extends PlanTest { Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, - BooleanSimplification) :: Nil + BooleanSimplification, + SimplifyConditionals) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala new file mode 100644 index 0000000000000..10ed4e46ddd1c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +class ComputeCurrentTimeSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) + } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index e67606288f514..641c89873dcc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,23 +17,21 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, EliminateSubQueries} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types._ -// For implicit conversions -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("ConstantFolding", Once, OptimizeIn, ConstantFolding, @@ -162,7 +160,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) val optimized = Optimize.execute(originalQuery.analyze) @@ -170,7 +168,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1.0) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala new file mode 100644 index 0000000000000..91777375608fd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.NewInstance +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, MapPartitions} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +case class OtherTuple(_1: Int, _2: Int) + +class EliminateSerializationSuite extends PlanTest { + private object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Serialization", FixedPoint(100), + EliminateSerialization) :: Nil + } + + implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + private val func = identity[Iterator[(Int, Int)]] _ + private val func2 = identity[Iterator[OtherTuple]] _ + + def assertObjectCreations(count: Int, plan: LogicalPlan): Unit = { + val newInstances = plan.flatMap(_.expressions.collect { + case n: NewInstance => n + }) + + if (newInstances.size != count) { + fail( + s""" + |Wrong number of object creations in plan: ${newInstances.size} != $count + |$plan + """.stripMargin) + } + } + + test("back to back MapPartitions") { + val input = LocalRelation('_1.int, '_2.int) + val plan = + MapPartitions(func, + MapPartitions(func, input)) + + val optimized = Optimize.execute(plan.analyze) + assertObjectCreations(1, optimized) + } + + test("back to back with object change") { + val input = LocalRelation('_1.int, '_2.int) + val plan = + MapPartitions(func, + MapPartitions(func2, input)) + + val optimized = Optimize.execute(plan.analyze) + assertObjectCreations(2, optimized) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala new file mode 100644 index 0000000000000..8c92ad82ac5be --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class EliminateSortsSuite extends PlanTest { + val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Eliminate Sorts", Once, + EliminateSorts) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("Empty order by clause") { + val x = testRelation + + val query = x.orderBy() + val optimized = Optimize.execute(query.analyze) + val correctAnswer = x.analyze + + comparePlans(optimized, correctAnswer) + } + + test("All the SortOrder are no-op") { + val x = testRelation + + val query = x.orderBy(SortOrder(3, Ascending), SortOrder(-1, Ascending)) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = analyzer.execute(x) + + comparePlans(optimized, correctAnswer) + } + + test("Partial order-by clauses contain no-op SortOrder") { + val x = testRelation + + val query = x.orderBy(SortOrder(3, Ascending), 'a.asc) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = analyzer.execute(x.orderBy('a.asc)) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala new file mode 100644 index 0000000000000..9b6d68aee803a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) :: Nil + } + + private def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + private def afterOptimization(plan: LogicalPlan): LogicalPlan = { + Optimize.execute(analysis.SimpleAnalyzer.execute(plan)) + } + + test("eliminate top level subquery") { + val input = LocalRelation('a.int, 'b.int) + val query = SubqueryAlias("a", input) + comparePlans(afterOptimization(query), input) + } + + test("eliminate mid-tree subquery") { + val input = LocalRelation('a.int, 'b.int) + val query = Filter(TrueLiteral, SubqueryAlias("a", input)) + comparePlans( + afterOptimization(query), + Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) + } + + test("eliminate multiple subqueries") { + val input = LocalRelation('a.int, 'b.int) + val query = Filter(TrueLiteral, + SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input)))) + comparePlans( + afterOptimization(query), + Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ed810a12808f0..df7529d83f7c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.IntegerType class FilterPushdownSuite extends PlanTest { @@ -32,17 +32,14 @@ class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, - EliminateSubQueries) :: - Batch("Filter Pushdown", Once, + EliminateSubqueryAliases) :: + Batch("Filter Pushdown", FixedPoint(10), SamplePushDown, CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, - ColumnPruning, - ProjectCollapsing) :: Nil + CollapseProject) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -65,66 +62,33 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("column pruning for group") { - val originalQuery = - testRelation - .groupBy('a)('a, Count('b)) - .select('a) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a) - .select('a).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for group with alias") { - val originalQuery = - testRelation - .groupBy('a)('a as 'c, Count('b)) - .select('c) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select('a) - .groupBy('a)('a as 'c) - .select('c).analyze - - comparePlans(optimized, correctAnswer) - } - - test("column pruning for Project(ne, Limit)") { + // After this line is unimplemented. + test("simple push down") { val originalQuery = testRelation - .select('a, 'b) - .limit(2) .select('a) + .where('a === 1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation + .where('a === 1) .select('a) - .limit(2).analyze + .analyze comparePlans(optimized, correctAnswer) } - // After this line is unimplemented. - test("simple push down") { + test("combine redundant filters") { val originalQuery = testRelation - .select('a) - .where('a === 1) + .where('a === 1 && 'b === 1) + .where('a === 1 && 'c === 1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a === 1) - .select('a) + .where('a === 1 && 'b === 1 && 'c === 1) .analyze comparePlans(optimized, correctAnswer) @@ -147,7 +111,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("nondeterministic: can't push down filter through project") { + test("nondeterministic: can't push down filter with nondeterministic condition through project") { val originalQuery = testRelation .select(Rand(10).as('rand), 'a) .where('rand > 5 || 'a > 5) @@ -158,36 +122,15 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("nondeterministic: push down part of filter through project") { + test("nondeterministic: can't push down filter through project with nondeterministic field") { val originalQuery = testRelation .select(Rand(10).as('rand), 'a) - .where('rand > 5 && 'a > 5) - .analyze - - val optimized = Optimize.execute(originalQuery) - - val correctAnswer = testRelation .where('a > 5) - .select(Rand(10).as('rand), 'a) - .where('rand > 5) - .analyze - - comparePlans(optimized, correctAnswer) - } - - test("nondeterministic: push down filter through project") { - val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('a > 5 && 'a < 10) .analyze val optimized = Optimize.execute(originalQuery) - val correctAnswer = testRelation - .where('a > 5 && 'a < 10) - .select(Rand(10).as('rand), 'a) - .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, originalQuery) } test("filters: combines filters") { @@ -483,7 +426,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized) + comparePlans(analysis.EliminateSubqueryAliases(originalQuery.analyze), optimized) } test("joins: conjunctive predicates") { @@ -502,7 +445,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } test("joins: conjunctive predicates #2") { @@ -521,7 +464,7 @@ class FilterPushdownSuite extends PlanTest { left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } test("joins: conjunctive predicates #3") { @@ -545,7 +488,7 @@ class FilterPushdownSuite extends PlanTest { condition = Some("z.a".attr === "x.b".attr)) .analyze - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) } val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) @@ -566,6 +509,24 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("generate: non-deterministic predicate referenced no generated column") { + val originalQuery = { + testRelationWithArrayType + .generate(Explode('c_arr), true, false, Some("arr")) + .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6)) + } + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = { + testRelationWithArrayType + .where('b >= 5) + .generate(Explode('c_arr), true, false, Some("arr")) + .where('a + Rand(10).as("rnd") > 6) + .analyze + } + + comparePlans(optimized, correctAnswer) + } + test("generate: part of conjuncts referenced generated column") { val generator = Explode('c_arr) val originalQuery = { @@ -587,7 +548,7 @@ class FilterPushdownSuite extends PlanTest { // Filter("c" > 6) assertResult(classOf[Filter])(optimized.getClass) assertResult(1)(optimized.asInstanceOf[Filter].condition.references.size) - assertResult("c"){ + assertResult("c") { optimized.asInstanceOf[Filter].condition.references.toSeq(0).name } @@ -606,57 +567,25 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("push down project past sort") { - val x = testRelation.subquery('x) - - // push down valid - val originalQuery = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('a) - } - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - x.select('a) - .sortBy(SortOrder('a, Ascending)).analyze - - comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) - - // push down invalid - val originalQuery1 = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) - } - - val optimized1 = Optimize.execute(originalQuery1.analyze) - val correctAnswer1 = - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze - - comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1)) - } - test("push project and filter down into sample") { val x = testRelation.subquery('x) val originalQuery = - Sample(0.0, 0.6, false, 11L, x).select('a) + Sample(0.0, 0.6, false, 11L, x)().select('a) - val originalQueryAnalyzed = EliminateSubQueries(analysis.SimpleAnalyzer.execute(originalQuery)) + val originalQueryAnalyzed = + EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery)) val optimized = Optimize.execute(originalQueryAnalyzed) val correctAnswer = - Sample(0.0, 0.6, false, 11L, x.select('a)) + Sample(0.0, 0.6, false, 11L, x.select('a))() comparePlans(optimized, correctAnswer.analyze) } test("aggregate: push down filter when filter on group by expression") { val originalQuery = testRelation - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .select('a, 'c) .where('a === 2) @@ -664,7 +593,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .where('a === 2) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .analyze comparePlans(optimized, correctAnswer) } @@ -672,7 +601,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: don't push down filter when filter not on group by expression") { val originalQuery = testRelation .select('a, 'b) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L) val optimized = Optimize.execute(originalQuery.analyze) @@ -683,18 +612,135 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filters partially which are subset of group by expressions") { val originalQuery = testRelation .select('a, 'b) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L && 'a === 3) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a === 3) - .groupBy('a)('a, Count('b) as 'c) + .select('a, 'b) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L) .analyze comparePlans(optimized, correctAnswer) } + + test("aggregate: push down filters with alias") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where(('c === 2L || 'aa > 4) && 'aa < 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .where('a + 1 < 3) + .select('a, 'b) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where('c === 2L || 'aa > 4) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: push down filters with literal") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L && 'd === "s") + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .where("s" === "s") + .select('a, 'b) + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: don't push down filters that are nondeterministic") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("broadcast hint") { + val originalQuery = BroadcastHint(testRelation) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = BroadcastHint(testRelation.where('a === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("union") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Union(Seq(testRelation, testRelation2)) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Union(Seq( + testRelation.where('a === 2L), + testRelation2.where('d === 2L))) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("intersect") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Intersect(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Intersect( + testRelation.where('a === 2L), + testRelation2.where('d === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("except") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Except(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Except( + testRelation.where('a === 2L), + testRelation2) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala new file mode 100644 index 0000000000000..e7fdd5a6202b6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class InferFiltersFromConstraintsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) :: + Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) :: + Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("filter: filter out constraints in condition") { + val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze + val correctAnswer = testRelation + .where(IsNotNull('a) && IsNotNull('b) && 'a === 'b && 'a === 1 && 'b === 1).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("single inner join: filter out values on either side on equi-join keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val originalQuery = x.join(y, + condition = Some(("x.a".attr === "y.a".attr) && ("x.a".attr === 1) && ("y.c".attr > 5))) + .analyze + val left = x.where(IsNotNull('a) && "x.a".attr === 1) + val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5 && "y.a".attr === 1) + val correctAnswer = left.join(right, condition = Some("x.a".attr === "y.a".attr)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("single inner join: filter out nulls on either side on non equal keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val originalQuery = x.join(y, + condition = Some(("x.a".attr =!= "y.a".attr) && ("x.b".attr === 1) && ("y.c".attr > 5))) + .analyze + val left = x.where(IsNotNull('a) && IsNotNull('b) && "x.b".attr === 1) + val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5) + val correctAnswer = left.join(right, condition = Some("x.a".attr =!= "y.a".attr)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("single inner join with pre-existing filters: filter out values on either side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val originalQuery = x.where('b > 5).join(y.where('a === 10), + condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).analyze + val left = x.where(IsNotNull('a) && 'a === 10 && IsNotNull('b) && 'b > 5) + val right = y.where(IsNotNull('a) && IsNotNull('b) && 'a === 10 && 'b > 5) + val correctAnswer = left.join(right, + condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("single outer join: no null filters are generated") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val originalQuery = x.join(y, FullOuter, + condition = Some("x.a".attr === "y.a".attr)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + + test("multiple inner joins: filter out values on all sides on equi-join keys") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + val t3 = testRelation.subquery('t3) + val t4 = testRelation.subquery('t4) + + val originalQuery = t1.where('b > 5) + .join(t2, condition = Some("t1.b".attr === "t2.b".attr)) + .join(t3, condition = Some("t2.b".attr === "t3.b".attr)) + .join(t4, condition = Some("t3.b".attr === "t4.b".attr)).analyze + val correctAnswer = t1.where(IsNotNull('b) && 'b > 5) + .join(t2.where(IsNotNull('b) && 'b > 5), condition = Some("t1.b".attr === "t2.b".attr)) + .join(t3.where(IsNotNull('b) && 'b > 5), condition = Some("t2.b".attr === "t3.b".attr)) + .join(t4.where(IsNotNull('b) && 'b > 5), condition = Some("t3.b".attr === "t4.b".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("inner join with filter: filter out values on all sides on equi-join keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = + x.join(y, Inner, Some("x.a".attr === "y.a".attr)).where("x.a".attr > 5).analyze + val correctAnswer = x.where(IsNotNull('a) && 'a.attr > 5) + .join(y.where(IsNotNull('a) && 'a.attr > 5), Inner, Some("x.a".attr === "y.a".attr)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala new file mode 100644 index 0000000000000..c1ebf8b09e08d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class JoinOptimizationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Filter Pushdown", FixedPoint(100), + CombineFilters, + PushDownPredicate, + BooleanSimplification, + ReorderJoin, + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int) + + test("extract filters and joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) { + assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected) + } + + testExtract(x, None) + testExtract(x.where("x.b".attr === 1), None) + testExtract(x.join(y), Some(Seq(x, y), Seq())) + testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).where("x.b".attr === "y.d".attr), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(z), Some(Seq(x, y, z), Seq())) + testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z), + Some(Seq(x, y, z), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq())) + testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr), + Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr))) + } + + test("reorder inner joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + val originalQuery = { + x.join(y).join(z) + .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.join(z, condition = Some("x.b".attr === "z.b".attr)) + .join(y, condition = Some("y.d".attr === "z.a".attr)) + .analyze + + comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer)) + } + + test("broadcasthint sets relation statistics to smallest value") { + val input = LocalRelation('key.int, 'value.string) + + val query = + Project(Seq($"x.key", $"y.key"), + Join( + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze + + val optimized = Optimize.execute(query) + + val expected = + Join( + Project(Seq($"x.key"), SubqueryAlias("x", input)), + BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), + Inner, None).analyze + + comparePlans(optimized, expected) + + val broadcastChildren = optimized.collect { + case Join(_, r, _, _) if r.statistics.sizeInBytes == 1 => r + } + assert(broadcastChildren.size == 1) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index b3df487c84dc8..fdde89d079bc0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.optimizer +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -/* Implicit conversions */ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ - class LikeSimplificationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -62,6 +61,20 @@ class LikeSimplificationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("simplify Like into startsWith and EndsWith") { + val originalQuery = + testRelation + .where(('a like "abc\\%def") || ('a like "abc%def")) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .where(('a like "abc\\%def") || + (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("simplify Like into Contains") { val originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala new file mode 100644 index 0000000000000..dcbc79365c3aa --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Add +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class LimitPushdownSuite extends PlanTest { + + private object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Limit pushdown", FixedPoint(100), + LimitPushDown, + CombineLimits, + ConstantFolding, + BooleanSimplification) :: Nil + } + + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + private val x = testRelation.subquery('x) + private val y = testRelation.subquery('y) + + // Union --------------------------------------------------------------------------------------- + + test("Union: limit to each side") { + val unionQuery = Union(testRelation, testRelation2).limit(1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each side with constant-foldable limit expressions") { + val unionQuery = Union(testRelation, testRelation2).limit(Add(1, 1)) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each side with the new limit number") { + val unionQuery = Union(testRelation, testRelation2.limit(3)).limit(1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: no limit to both sides if children having smaller limit values") { + val unionQuery = Union(testRelation.limit(1), testRelation2.select('d).limit(1)).limit(2) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(testRelation.limit(1), testRelation2.select('d).limit(1))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each sides if children having larger limit values") { + val testLimitUnion = Union(testRelation.limit(3), testRelation2.select('d).limit(4)) + val unionQuery = testLimitUnion.limit(2) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d)))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + // Outer join ---------------------------------------------------------------------------------- + + test("left outer join") { + val originalQuery = x.join(y, LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, y).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("right outer join") { + val originalQuery = x.join(y, RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("larger limits are not pushed on top of smaller ones in right outer join") { + val originalQuery = x.join(y.limit(5), RightOuter).limit(10) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(10, x.join(Limit(5, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and both sides have same statistics") { + assert(x.statistics.sizeInBytes === y.statistics.sizeInBytes) + val originalQuery = x.join(y, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and left side has larger statistics") { + val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) + assert(xBig.statistics.sizeInBytes > y.statistics.sizeInBytes) + val originalQuery = xBig.join(y, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and right side has larger statistics") { + val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) + assert(x.statistics.sizeInBytes < yBig.statistics.sizeInBytes) + val originalQuery = x.join(yBig, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where both sides are limited") { + val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, Limit(2, x).join(Limit(2, y), FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 48cab01ac1004..0e43ce034fb48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -17,24 +17,21 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types._ -// For implicit conversions -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - class OptimizeInSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("AnalysisNodes", Once, - EliminateSubQueries) :: + EliminateSubqueryAliases) :: Batch("ConstantFolding", FixedPoint(10), NullPropagation, ConstantFolding, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala new file mode 100644 index 0000000000000..6e5672ddc36bd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * This is a test for SPARK-7727 if the Optimizer is kept being extendable + */ +class OptimizerExtendableSuite extends SparkFunSuite { + + /** + * Dummy rule for test batches + */ + object DummyRule extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + + /** + * This class represents a dummy extended optimizer that takes the batches of the + * Optimizer and adds custom ones. + */ + class ExtendedOptimizer extends Optimizer { + + // rules set to DummyRule, would not be executed anyways + val myBatches: Seq[Batch] = { + Batch("once", Once, + DummyRule) :: + Batch("fixedPoint", FixedPoint(100), + DummyRule) :: Nil + } + + override def batches: Seq[Batch] = super.batches ++ myBatches + } + + test("Extending batches possible") { + // test simply instantiates the new extended optimizer + val extendedOptimizer = new ExtendedOptimizer() + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala new file mode 100644 index 0000000000000..5e6e54dc741f3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class OuterJoinEliminationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Outer Join Elimination", Once, + OuterJoinElimination, + PushPredicateThroughJoin) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int, 'e.int, 'f.int) + + test("joins: full outer to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where("x.b".attr >= 1 && "y.d".attr >= 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b >= 1) + val right = testRelation1.where('d >= 2) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: full outer to right") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)).where("y.d".attr > 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where('d > 2) + val correctAnswer = + left.join(right, RightOuter, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: full outer to left") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)).where("x.a".attr <=> 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a <=> 2) + val right = testRelation1 + val correctAnswer = + left.join(right, LeftOuter, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: right to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, RightOuter, Option("x.a".attr === "y.d".attr)).where("x.b".attr > 2) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b > 2) + val right = testRelation1 + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: left to inner") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where("y.e".attr.isNotNull) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where('e.isNotNull) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + // evaluating if mixed OR and NOT expressions can eliminate all null-supplying rows + test("joins: left to inner with complicated filter predicates #1") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where(!'e.isNull || ('d.isNotNull && 'f.isNull)) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where(!'e.isNull || ('d.isNotNull && 'f.isNull)) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + // eval(emptyRow) of 'e.in(1, 2) will return null instead of false + test("joins: left to inner with complicated filter predicates #2") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where('e.in(1, 2)) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where('e.in(1, 2)) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + // evaluating if mixed OR and AND expressions can eliminate all null-supplying rows + test("joins: left to inner with complicated filter predicates #3") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) + .where((!'e.isNull || ('d.isNotNull && 'f.isNull)) && 'e.isNull) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1.where((!'e.isNull || ('d.isNotNull && 'f.isNull)) && 'e.isNull) + val correctAnswer = + left.join(right, Inner, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + // evaluating if the expressions that have both left and right attributes + // can eliminate all null-supplying rows + // FULL OUTER => INNER + test("joins: left to inner with complicated filter predicates #4") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where("x.b".attr + 3 === "y.e".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation + val right = testRelation1 + val correctAnswer = + left.join(right, Inner, Option("b".attr + 3 === "e".attr && "a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala deleted file mode 100644 index 1aa89991cc698..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.Rand -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor - - -class ProjectCollapsingSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: - Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil - } - - val testRelation = LocalRelation('a.int, 'b.int) - - test("collapse two deterministic, independent projects into one") { - val query = testRelation - .select(('a + 1).as('a_plus_1), 'b) - .select('a_plus_1, ('b + 1).as('b_plus_1)) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze - - comparePlans(optimized, correctAnswer) - } - - test("collapse two deterministic, dependent projects into one") { - val query = testRelation - .select(('a + 1).as('a_plus_1), 'b) - .select(('a_plus_1 + 1).as('a_plus_2), 'b) - - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = testRelation.select( - (('a + 1).as('a_plus_1) + 1).as('a_plus_2), - 'b).analyze - - comparePlans(optimized, correctAnswer) - } - - test("do not collapse nondeterministic projects") { - val query = testRelation - .select(Rand(10).as('rand)) - .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = query.analyze - - comparePlans(optimized, correctAnswer) - } - - test("collapse two nondeterministic, independent projects into one") { - val query = testRelation - .select(Rand(10).as('rand)) - .select(Rand(20).as('rand2)) - - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = testRelation - .select(Rand(20).as('rand2)).analyze - - comparePlans(optimized, correctAnswer) - } - - test("collapse one nondeterministic, one deterministic, independent projects into one") { - val query = testRelation - .select(Rand(10).as('rand), 'a) - .select(('a + 1).as('a_plus_1)) - - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = testRelation - .select(('a + 1).as('a_plus_1)).analyze - - comparePlans(optimized, correctAnswer) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala new file mode 100644 index 0000000000000..d8cfec5391497 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class PruneFiltersSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Filter Pushdown and Pruning", Once, + CombineFilters, + PruneFilters, + PushDownPredicate, + PushPredicateThroughJoin) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("Constraints of isNull + LeftOuter") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val query = x.where("x.b".attr.isNull).join(y, LeftOuter) + val queryWithUselessFilter = query.where("x.b".attr.isNull) + + val optimized = Optimize.execute(queryWithUselessFilter.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constraints of unionall") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('d.int, 'e.int, 'f.int) + val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + + val query = + tr1.where('a.attr > 10) + .union(tr2.where('d.attr > 10) + .union(tr3.where('g.attr > 10))) + val queryWithUselessFilter = query.where('a.attr > 10) + + val optimized = Optimize.execute(queryWithUselessFilter.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } + + test("Pruning multiple constraints in the same run") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + + val query = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + // different order of "tr2.a" and "tr1.a" + val queryWithUselessFilter = + query.where( + ("tr1.a".attr > 10 || "tr1.c".attr < 10) && + 'd.attr < 100 && + "tr2.a".attr === "tr1.a".attr) + + val optimized = Optimize.execute(queryWithUselessFilter.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } + + test("Partial pruning") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + + // One of the filter condition does not exist in the constraints of its child + // Thus, the filter is not removed + val query = tr1 + .where("tr1.a".attr > 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.d".attr)) + val queryWithExtraFilters = + query.where("tr1.a".attr > 10 && 'd.attr < 100 && "tr1.a".attr === "tr2.a".attr) + + val optimized = Optimize.execute(queryWithExtraFilters.analyze) + val correctAnswer = tr1 + .where("tr1.a".attr > 10) + .join(tr2.where('d.attr < 100), + Inner, + Some("tr1.a".attr === "tr2.a".attr && "tr1.a".attr === "tr2.d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("No predicate is pruned") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val query = x.where("x.b".attr.isNull).join(y, LeftOuter) + val queryWithExtraFilters = query.where("x.b".attr.isNotNull) + + val optimized = Optimize.execute(queryWithExtraFilters.analyze) + val correctAnswer = + testRelation.where("b".attr.isNull).where("b".attr.isNotNull) + .join(testRelation, LeftOuter).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Nondeterministic predicate is not pruned") { + val originalQuery = testRelation.where(Rand(10) > 5).select('a).where(Rand(10) > 5).analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = testRelation.where(Rand(10) > 5).where(Rand(10) > 5).select('a).analyze + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala new file mode 100644 index 0000000000000..f8ae5d9be2084 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceOperatorSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace Operators", FixedPoint(100), + ReplaceDistinctWithAggregate, + ReplaceIntersectWithSemiJoin) :: Nil + } + + test("replace Intersect with Left-semi Join") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + + val query = Intersect(table1, table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Distinct with Aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala deleted file mode 100644 index 1595ad9327423..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - -class SetOperationPushDownSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubQueries) :: - Batch("Union Pushdown", Once, - SetOperationPushDown, - SimplifyFilters) :: Nil - } - - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testUnion = Union(testRelation, testRelation2) - val testIntersect = Intersect(testRelation, testRelation2) - val testExcept = Except(testRelation, testRelation2) - - test("union/intersect/except: filter to each side") { - val unionQuery = testUnion.where('a === 1) - val intersectQuery = testIntersect.where('b < 10) - val exceptQuery = testExcept.where('c >= 5) - - val unionOptimized = Optimize.execute(unionQuery.analyze) - val intersectOptimized = Optimize.execute(intersectQuery.analyze) - val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - val unionCorrectAnswer = - Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze - val intersectCorrectAnswer = - Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze - val exceptCorrectAnswer = - Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze - - comparePlans(unionOptimized, unionCorrectAnswer) - comparePlans(intersectOptimized, intersectCorrectAnswer) - comparePlans(exceptOptimized, exceptCorrectAnswer) - } - - test("union: project to each side") { - val unionQuery = testUnion.select('a) - val unionOptimized = Optimize.execute(unionQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze - comparePlans(unionOptimized, unionCorrectAnswer) - } - - test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val intersectQuery = testIntersect.select('b, 'c) - val exceptQuery = testExcept.select('a, 'b, 'c) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) - val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - comparePlans(intersectOptimized, intersectQuery.analyze) - comparePlans(exceptOptimized, exceptQuery.analyze) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala new file mode 100644 index 0000000000000..b08cdc8a3658e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class SetOperationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Union Pushdown", Once, + CombineUnions, + SetOperationPushDown, + PruneFilters) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) + val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) + val testExcept = Except(testRelation, testRelation2) + + test("union: combine unions into one unions") { + val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation) + val unionQuery2 = Union(testRelation, Union(testRelation2, testRelation)) + val unionOptimized1 = Optimize.execute(unionQuery1.analyze) + val unionOptimized2 = Optimize.execute(unionQuery2.analyze) + + comparePlans(unionOptimized1, unionOptimized2) + + val combinedUnions = Union(unionOptimized1 :: unionOptimized2 :: Nil) + val combinedUnionsOptimized = Optimize.execute(combinedUnions.analyze) + val unionQuery3 = Union(unionQuery1, unionQuery2) + val unionOptimized3 = Optimize.execute(unionQuery3.analyze) + comparePlans(combinedUnionsOptimized, unionOptimized3) + } + + test("except: filter to each side") { + val exceptQuery = testExcept.where('c >= 5) + val exceptOptimized = Optimize.execute(exceptQuery.analyze) + val exceptCorrectAnswer = + Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze + + comparePlans(exceptOptimized, exceptCorrectAnswer) + } + + test("union: filter to each side") { + val unionQuery = testUnion.where('a === 1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.where('a === 1) :: + testRelation2.where('d === 1) :: + testRelation3.where('g === 1) :: Nil).analyze + + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("union: project to each side") { + val unionQuery = testUnion.select('a) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a) :: + testRelation2.select('d) :: + testRelation3.select('g) :: Nil).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val exceptQuery = testExcept.select('a, 'b, 'c) + val exceptOptimized = Optimize.execute(exceptQuery.analyze) + comparePlans(exceptOptimized, exceptQuery.analyze) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala deleted file mode 100644 index 6b1e53cd42b24..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.rules._ - -/* Implicit conversions */ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ - -class SimplifyCaseConversionExpressionsSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Simplify CaseConversionExpressions", Once, - SimplifyCaseConversionExpressions) :: Nil - } - - val testRelation = LocalRelation('a.string) - - test("simplify UPPER(UPPER(str))") { - val originalQuery = - testRelation - .select(Upper(Upper('a)) as 'u) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select(Upper('a) as 'u) - .analyze - - comparePlans(optimized, correctAnswer) - } - - test("simplify UPPER(LOWER(str))") { - val originalQuery = - testRelation - .select(Upper(Lower('a)) as 'u) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .select(Upper('a) as 'u) - .analyze - - comparePlans(optimized, correctAnswer) - } - - test("simplify LOWER(UPPER(str))") { - val originalQuery = - testRelation - .select(Lower(Upper('a)) as 'l) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = testRelation - .select(Lower('a) as 'l) - .analyze - - comparePlans(optimized, correctAnswer) - } - - test("simplify LOWER(LOWER(str))") { - val originalQuery = - testRelation - .select(Lower(Lower('a)) as 'l) - - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = testRelation - .select(Lower('a) as 'l) - .analyze - - comparePlans(optimized, correctAnswer) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala new file mode 100644 index 0000000000000..c02fec30858e5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{IntegerType, NullType} + + +class SimplifyConditionalSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil + } + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + private val trueBranch = (TrueLiteral, Literal(5)) + private val normalBranch = (NonFoldableLiteral(true), Literal(10)) + private val unreachableBranch = (FalseLiteral, Literal(20)) + private val nullBranch = (Literal.create(null, NullType), Literal(30)) + + test("simplify if") { + assertEquivalent( + If(TrueLiteral, Literal(10), Literal(20)), + Literal(10)) + + assertEquivalent( + If(FalseLiteral, Literal(10), Literal(20)), + Literal(20)) + + assertEquivalent( + If(Literal.create(null, NullType), Literal(10), Literal(20)), + Literal(20)) + } + + test("remove unreachable branches") { + // i.e. removing branches whose conditions are always false + assertEquivalent( + CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None), + CaseWhen(normalBranch :: Nil, None)) + } + + test("remove entire CaseWhen if only the else branch is reachable") { + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))), + Literal(30)) + + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None), + Literal.create(null, IntegerType)) + } + + test("remove entire CaseWhen if the first branch is always true") { + assertEquivalent( + CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None), + Literal(5)) + + // Test branch elimination and simplification in combination + assertEquivalent( + CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch + :: Nil, None), + Literal(5)) + + // Make sure this doesn't trigger if there is a non-foldable branch before the true branch + assertEquivalent( + CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None), + CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala new file mode 100644 index 0000000000000..24413e7a2a3f0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules._ + +class SimplifyStringCaseConversionSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Simplify CaseConversionExpressions", Once, + SimplifyCaseConversionExpressions) :: Nil + } + + val testRelation = LocalRelation('a.string) + + test("simplify UPPER(UPPER(str))") { + val originalQuery = + testRelation + .select(Upper(Upper('a)) as 'u) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Upper('a) as 'u) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify UPPER(LOWER(str))") { + val originalQuery = + testRelation + .select(Upper(Lower('a)) as 'u) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Upper('a) as 'u) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify LOWER(UPPER(str))") { + val originalQuery = + testRelation + .select(Lower(Upper('a)) as 'l) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(Lower('a) as 'l) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify LOWER(LOWER(str))") { + val originalQuery = + testRelation + .select(Lower(Lower('a)) as 'l) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(Lower('a) as 'l) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala new file mode 100644 index 0000000000000..1fae64e3bc6b1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.BooleanType + +class TypedFilterOptimizationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("EliminateSerialization", FixedPoint(50), + EliminateSerialization) :: + Batch("EmbedSerializerInFilter", FixedPoint(50), + EmbedSerializerInFilter) :: Nil + } + + implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + + test("back to back filter") { + val input = LocalRelation('_1.int, '_2.int) + val f1 = (i: (Int, Int)) => i._1 > 0 + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = input.filter(f1).filter(f2).analyze + + val optimized = Optimize.execute(query) + + val expected = input.deserialize[(Int, Int)] + .where(callFunction(f1, BooleanType, 'obj)) + .select('obj.as("obj")) + .where(callFunction(f2, BooleanType, 'obj)) + .serialize[(Int, Int)].analyze + + comparePlans(optimized, expected) + } + + test("embed deserializer in filter condition if there is only one filter") { + val input = LocalRelation('_1.int, '_2.int) + val f = (i: (Int, Int)) => i._1 > 0 + + val query = input.filter(f).analyze + + val optimized = Optimize.execute(query) + + val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer) + val condition = callFunction(f, BooleanType, deserializer) + val expected = input.where(condition).analyze + + comparePlans(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala new file mode 100644 index 0000000000000..07b89cb61f2d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -0,0 +1,157 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +abstract class AbstractDataTypeParserSuite extends SparkFunSuite { + + def parse(sql: String): DataType + + def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { + test(s"parse ${dataTypeString.replace("\n", "")}") { + assert(parse(dataTypeString) === expectedDataType) + } + } + + def intercept(sql: String) + + def unsupported(dataTypeString: String): Unit = { + test(s"$dataTypeString is not supported") { + intercept(dataTypeString) + } + } + + checkDataType("int", IntegerType) + checkDataType("integer", IntegerType) + checkDataType("BooLean", BooleanType) + checkDataType("tinYint", ByteType) + checkDataType("smallINT", ShortType) + checkDataType("INT", IntegerType) + checkDataType("INTEGER", IntegerType) + checkDataType("bigint", LongType) + checkDataType("float", FloatType) + checkDataType("dOUBle", DoubleType) + checkDataType("decimal(10, 5)", DecimalType(10, 5)) + checkDataType("decimal", DecimalType.USER_DEFAULT) + checkDataType("DATE", DateType) + checkDataType("timestamp", TimestampType) + checkDataType("string", StringType) + checkDataType("ChaR(5)", StringType) + checkDataType("varchAr(20)", StringType) + checkDataType("cHaR(27)", StringType) + checkDataType("BINARY", BinaryType) + + checkDataType("array", ArrayType(DoubleType, true)) + checkDataType("Array>", ArrayType(MapType(IntegerType, ByteType, true), true)) + checkDataType( + "array>", + ArrayType(StructType(StructField("tinYint", ByteType, true) :: Nil), true) + ) + checkDataType("MAP", MapType(IntegerType, StringType, true)) + checkDataType("MAp>", MapType(IntegerType, ArrayType(DoubleType), true)) + checkDataType( + "MAP>", + MapType(IntegerType, StructType(StructField("varchar", StringType, true) :: Nil), true) + ) + + checkDataType( + "struct", + StructType( + StructField("intType", IntegerType, true) :: + StructField("ts", TimestampType, true) :: Nil) + ) + // It is fine to use the data type string as the column name. + checkDataType( + "Struct", + StructType( + StructField("int", IntegerType, true) :: + StructField("timestamp", TimestampType, true) :: Nil) + ) + checkDataType( + """ + |struct< + | struct:struct, + | MAP:Map, + | arrAy:Array, + | anotherArray:Array> + """.stripMargin, + StructType( + StructField("struct", + StructType( + StructField("deciMal", DecimalType.USER_DEFAULT, true) :: + StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: + StructField("MAP", MapType(TimestampType, StringType), true) :: + StructField("arrAy", ArrayType(DoubleType, true), true) :: + StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) + ) + // Use backticks to quote column names having special characters. + checkDataType( + "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>", + StructType( + StructField("x+y", IntegerType, true) :: + StructField("!@#$%^&*()", StringType, true) :: + StructField("1_2.345<>:\"", StringType, true) :: Nil) + ) + // Empty struct. + checkDataType("strUCt<>", StructType(Nil)) + + unsupported("it is not a data type") + unsupported("struct") + unsupported("struct", + StructType( + StructField("TABLE", StringType, true) :: + StructField("CASE", BooleanType, true) :: Nil) + ) + + unsupported("struct") + + unsupported("struct<`x``y` int>") +} + +class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite { + override def intercept(sql: String): Unit = + intercept[ParseException](CatalystSqlParser.parseDataType(sql)) + + override def parse(sql: String): DataType = + CatalystSqlParser.parseDataType(sql) + + // A column name can be a reserved word in our DDL parser and SqlParser. + unsupported("Struct") + + checkDataType( + "struct", + (new StructType).add("x", IntegerType).add("y", StringType)) + + checkDataType( + "struct<`x``y` int>", + (new StructType).add("x`y", IntegerType)) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala new file mode 100644 index 0000000000000..db96bfb652120 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +/** + * Test various parser errors. + */ +class ErrorParserSuite extends SparkFunSuite { + def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { + val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) + + // Check position. + assert(e.line.isDefined) + assert(e.line.get === line) + assert(e.startPosition.isDefined) + assert(e.startPosition.get === startPosition) + + // Check messages. + val error = e.getMessage + messages.foreach { message => + assert(error.contains(message)) + } + } + + test("no viable input") { + intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") + intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") + intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") + } + + test("extraneous input") { + intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") + intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") + } + + test("mismatched input") { + intercept("select * from r order by q from t", 1, 27, + "mismatched input", + "---------------------------^^^") + intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") + } + + test("semantic errors") { + intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", + "^^^") + intercept("select * from r where a in (select * from t)", 1, 24, + "IN with a Sub-query is currently not supported", + "------------------------^^^") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala new file mode 100644 index 0000000000000..6f40ec67ec6e0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Test basic expression parsing. If a type of expression is supported it should be tested here. + * + * Please note that some of the expressions test don't have to be sound expressions, only their + * structure needs to be valid. Unsound expressions should be caught by the Analyzer or + * CheckAnalysis classes. + */ +class ExpressionParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, e: Expression): Unit = { + compareExpressions(parseExpression(sqlCommand), e) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parseExpression(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("star expressions") { + // Global Star + assertEqual("*", UnresolvedStar(None)) + + // Targeted Star + assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) + } + + // NamedExpression (Alias/Multialias) + test("named expressions") { + // No Alias + val r0 = 'a + assertEqual("a", r0) + + // Single Alias. + val r1 = 'a as "b" + assertEqual("a as b", r1) + assertEqual("a b", r1) + + // Multi-Alias + assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + + // Numeric literals without a space between the literal qualifier and the alias, should not be + // interpreted as such. An unresolved reference should be returned instead. + // TODO add the JIRA-ticket number. + assertEqual("1SL", Symbol("1SL")) + + // Aliased star is allowed. + assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + } + + test("binary logical expressions") { + // And + assertEqual("a and b", 'a && 'b) + + // Or + assertEqual("a or b", 'a || 'b) + + // Combination And/Or check precedence + assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) + assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + + // Multiple AND/OR get converted into a balanced tree + assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) + assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + } + + test("long binary logical expressions") { + def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { + val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) + val e = parseExpression(sql) + assert(e.collect { case _: EqualTo => true }.size === 1000) + assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) + } + testVeryBinaryExpression(" AND ", classOf[And]) + testVeryBinaryExpression(" OR ", classOf[Or]) + } + + test("not expressions") { + assertEqual("not a", !'a) + assertEqual("!a", !'a) + assertEqual("not true > true", Not(GreaterThan(true, true))) + } + + test("exists expression") { + intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + } + + test("comparison expressions") { + assertEqual("a = b", 'a === 'b) + assertEqual("a == b", 'a === 'b) + assertEqual("a <=> b", 'a <=> 'b) + assertEqual("a <> b", 'a =!= 'b) + assertEqual("a != b", 'a =!= 'b) + assertEqual("a < b", 'a < 'b) + assertEqual("a <= b", 'a <= 'b) + assertEqual("a > b", 'a > 'b) + assertEqual("a >= b", 'a >= 'b) + } + + test("between expressions") { + assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) + assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + } + + test("in expressions") { + assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) + assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + } + + test("in sub-query") { + intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + } + + test("like expressions") { + assertEqual("a like 'pattern%'", 'a like "pattern%") + assertEqual("a not like 'pattern%'", !('a like "pattern%")) + assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) + assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + } + + test("is null expressions") { + assertEqual("a is null", 'a.isNull) + assertEqual("a is not null", 'a.isNotNull) + assertEqual("a = b is null", ('a === 'b).isNull) + assertEqual("a = b is not null", ('a === 'b).isNotNull) + } + + test("binary arithmetic expressions") { + // Simple operations + assertEqual("a * b", 'a * 'b) + assertEqual("a / b", 'a / 'b) + assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a % b", 'a % 'b) + assertEqual("a + b", 'a + 'b) + assertEqual("a - b", 'a - 'b) + assertEqual("a & b", 'a & 'b) + assertEqual("a ^ b", 'a ^ 'b) + assertEqual("a | b", 'a | 'b) + + // Check precedences + assertEqual( + "a * t | b ^ c & d - e + f % g DIV h / i * k", + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + } + + test("unary arithmetic expressions") { + assertEqual("+a", 'a) + assertEqual("-a", -'a) + assertEqual("~a", ~'a) + assertEqual("-+~~a", -(~(~'a))) + } + + test("cast expressions") { + // Note that DataType parsing is tested elsewhere. + assertEqual("cast(a as int)", 'a.cast(IntegerType)) + assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) + assertEqual("cast(a as array)", 'a.cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + } + + test("function expressions") { + assertEqual("foo()", 'foo.function()) + assertEqual("foo.bar()", Symbol("foo.bar").function()) + assertEqual("foo(*)", 'foo.function(star())) + assertEqual("count(*)", 'count.function(1)) + assertEqual("foo(a, b)", 'foo.function('a, 'b)) + assertEqual("foo(all a, b)", 'foo.function('a, 'b)) + assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) + assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) + assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + } + + test("window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } + + // Basic window testing. + assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) + assertEqual("foo(*) over ()", windowed()) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + + // Test use of expressions in window functions. + assertEqual( + "sum(product + 1) over (partition by ((product) + (1)) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + assertEqual( + "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + + // Range/Row + val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) + val boundaries = Seq( + ("10 preceding", ValuePreceding(10), CurrentRow), + ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing), + ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), + ("between current row and 5 following", CurrentRow, ValueFollowing(5)), + ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ) + frameTypes.foreach { + case (frameTypeSql, frameType) => + boundaries.foreach { + case (boundarySql, begin, end) => + val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" + val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + assertEqual(query, expr) + } + } + + // We cannot use non integer constants. + intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", + "Frame bound value must be a constant integer.") + + // We cannot use an arbitrary expression. + intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", + "Frame bound value must be a constant integer.") + } + + test("row constructor") { + // Note that '(a)' will be interpreted as a nested expression. + assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) + assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + } + + test("scalar sub-query") { + assertEqual( + "(select max(val) from tbl) > current", + ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + assertEqual( + "a = (select b from s)", + 'a === ScalarSubquery(table("s").select('b))) + } + + test("case when") { + assertEqual("case a when 1 then b when 2 then c else d end", + CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case when a = 1 then b when a = 2 then c else d end", + CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + } + + test("dereference") { + assertEqual("a.b", UnresolvedAttribute("a.b")) + assertEqual("`select`.b", UnresolvedAttribute("select.b")) + assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + } + + test("reference") { + // Regular + assertEqual("a", 'a) + + // Starting with a digit. + assertEqual("1a", Symbol("1a")) + + // Quoted using a keyword. + assertEqual("`select`", 'select) + + // Unquoted using an unreserved keyword. + assertEqual("columns", 'columns) + } + + test("subscript") { + assertEqual("a[b]", 'a.getItem('b)) + assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + } + + test("parenthesis") { + assertEqual("(a)", 'a) + assertEqual("r * (a + b)", 'r * ('a + 'b)) + } + + test("type constructors") { + // Dates. + assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) + intercept[IllegalArgumentException] { + parseExpression("DAtE 'mar 11 2016'") + } + + // Timestamps. + assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", + Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) + intercept[IllegalArgumentException] { + parseExpression("timestamP '2016-33-11 20:54:00.000'") + } + + // Unsupported datatype. + intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") + } + + test("literals") { + // NULL + assertEqual("null", Literal(null)) + + // Boolean + assertEqual("trUe", Literal(true)) + assertEqual("False", Literal(false)) + + // Integral should have the narrowest possible type + assertEqual("787324", Literal(787324)) + assertEqual("7873247234798249234", Literal(7873247234798249234L)) + assertEqual("78732472347982492793712334", + Literal(BigDecimal("78732472347982492793712334").underlying())) + + // Decimal + assertEqual("7873247234798249279371.2334", + Literal(BigDecimal("7873247234798249279371.2334").underlying())) + + // Scientific Decimal + assertEqual("9.0e1", 90d) + assertEqual(".9e+2", 90d) + assertEqual("0.9e+2", 90d) + assertEqual("900e-1", 90d) + assertEqual("900.0E-1", 90d) + assertEqual("9.e+1", 90d) + intercept(".e3") + + // Tiny Int Literal + assertEqual("10Y", Literal(10.toByte)) + intercept("-1000Y") + + // Small Int Literal + assertEqual("10S", Literal(10.toShort)) + intercept("40000S") + + // Long Int Literal + assertEqual("10L", Literal(10L)) + intercept("78732472347982492793712334L") + + // Double Literal + assertEqual("10.0D", Literal(10.0D)) + // TODO we need to figure out if we should throw an exception here! + assertEqual("1E309", Literal(Double.PositiveInfinity)) + } + + test("strings") { + // Single Strings. + assertEqual("\"hello\"", "hello") + assertEqual("'hello'", "hello") + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld") + assertEqual("'hello' \" \" 'world'", "hello world") + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%") + assertEqual("'no-pattern\\%'", "no-pattern\\%") + assertEqual("'pattern\\\\%'", "pattern\\%") + assertEqual("'pattern\\\\\\%'", "pattern\\\\%") + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') + assertEqual("'\\''", "\'") // Single quote + assertEqual("'\\\"'", "\"") // Double quote + assertEqual("'\\b'", "\b") // Backspace + assertEqual("'\\n'", "\n") // Newline + assertEqual("'\\r'", "\r") // Carriage return + assertEqual("'\\t'", "\t") // Tab character + assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") + + // Unicode + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)") + } + + test("intervals") { + def intervalLiteral(u: String, s: String): Literal = { + Literal(CalendarInterval.fromSingleUnitString(u, s)) + } + + // Empty interval statement + intercept("interval", "at least one time unit should be given for interval literal") + + // Single Intervals. + val units = Seq( + "year", + "month", + "week", + "day", + "hour", + "minute", + "second", + "millisecond", + "microsecond") + val forms = Seq("", "s") + val values = Seq("0", "10", "-7", "21") + units.foreach { unit => + forms.foreach { form => + values.foreach { value => + val expected = intervalLiteral(unit, value) + assertEqual(s"interval $value $unit$form", expected) + assertEqual(s"interval '$value' $unit$form", expected) + } + } + } + + // Hive nanosecond notation. + assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) + assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) + + // Non Existing unit + intercept("interval 10 nanoseconds", "No interval can be constructed") + + // Year-Month intervals. + val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") + yearMonthValues.foreach { value => + val result = Literal(CalendarInterval.fromYearMonthString(value)) + assertEqual(s"interval '$value' year to month", result) + } + + // Day-Time intervals. + val datTimeValues = Seq( + "99 11:22:33.123456789", + "-99 11:22:33.123456789", + "10 9:8:7.123456789", + "1 0:0:0", + "-1 0:0:0", + "1 0:0:1") + datTimeValues.foreach { value => + val result = Literal(CalendarInterval.fromDayTimeString(value)) + assertEqual(s"interval '$value' day to second", result) + } + + // Unknown FROM TO intervals + intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") + + // Composed intervals. + assertEqual( + "interval 3 months 22 seconds 1 millisecond", + Literal(new CalendarInterval(3, 22001000L))) + assertEqual( + "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", + Literal(new CalendarInterval(14, + 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) + } + + test("composed expressions") { + assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) + assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala new file mode 100644 index 0000000000000..d090daf7b41eb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +class ParserUtilsSuite extends SparkFunSuite { + + import ParserUtils._ + + test("unescapeSQLString") { + // scalastyle:off nonascii + + // String not including escaped characters and enclosed by double quotes. + assert(unescapeSQLString(""""abcdefg"""") == "abcdefg") + + // String enclosed by single quotes. + assert(unescapeSQLString("""'C0FFEE'""") == "C0FFEE") + + // Strings including single escaped characters. + assert(unescapeSQLString("""'\0'""") == "\u0000") + assert(unescapeSQLString(""""\'"""") == "\'") + assert(unescapeSQLString("""'\"'""") == "\"") + assert(unescapeSQLString(""""\b"""") == "\b") + assert(unescapeSQLString("""'\n'""") == "\n") + assert(unescapeSQLString(""""\r"""") == "\r") + assert(unescapeSQLString("""'\t'""") == "\t") + assert(unescapeSQLString(""""\Z"""") == "\u001A") + assert(unescapeSQLString("""'\\'""") == "\\") + assert(unescapeSQLString(""""\%"""") == "\\%") + assert(unescapeSQLString("""'\_'""") == "\\_") + + // String including '\000' style literal characters. + assert(unescapeSQLString("""'3 + 5 = \070'""") == "3 + 5 = \u0038") + assert(unescapeSQLString(""""\000"""") == "\u0000") + + // String including invalid '\000' style literal characters. + assert(unescapeSQLString(""""\256"""") == "256") + + // String including a '\u0000' style literal characters (\u732B is a cat in Kanji). + assert(unescapeSQLString(""""How cute \u732B are"""") == "How cute \u732B are") + + // String including a surrogate pair character + // (\uD867\uDE3D is Okhotsk atka mackerel in Kanji). + assert(unescapeSQLString(""""\uD867\uDE3D is a fish"""") == "\uD867\uDE3D is a fish") + + // scalastyle:on nonascii + } + + // TODO: Add test cases for other methods in ParserUtils +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala new file mode 100644 index 0000000000000..411e2372f2e07 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PlanParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parsePlan(sqlCommand), plan) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("case insensitive") { + val plan = table("a").select(star()) + assertEqual("sELEct * FroM a", plan) + assertEqual("select * fRoM a", plan) + assertEqual("SELECT * FROM a", plan) + } + + test("show functions") { + assertEqual("show functions", ShowFunctions(None, None)) + assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) + assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) + assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) + intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") + } + + test("describe function") { + assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) + assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) + assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) + assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + } + + test("set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + + assertEqual("select * from a union select * from b", Distinct(a.union(b))) + assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) + assertEqual("select * from a union all select * from b", a.union(b)) + assertEqual("select * from a except select * from b", a.except(b)) + intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") + assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a intersect select * from b", a.intersect(b)) + intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") + assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + } + + test("common table expressions") { + def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + val ctes = namedPlans.map { + case (name, cte) => + name -> SubqueryAlias(name, cte) + }.toMap + With(plan, ctes) + } + assertEqual( + "with cte1 as (select * from a) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + assertEqual( + "with cte1 (select 1) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) + assertEqual( + "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", + cte(table("cte2").select(star()), + "cte1" -> OneRowRelation.select(1), + "cte2" -> table("cte1").select(star()))) + intercept( + "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", + "Name 'cte1' is used for multiple common table expressions") + } + + test("simple select query") { + assertEqual("select 1", OneRowRelation.select(1)) + assertEqual("select a, b", OneRowRelation.select('a, 'b)) + assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual( + "select a, b from db.c having x < 1", + table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) + assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + } + + test("reverse select query") { + assertEqual("from a", table("a")) + assertEqual("from a select b, c", table("a").select('b, 'c)) + assertEqual( + "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) + assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual( + "from (from a union all from b) c select *", + table("a").union(table("b")).as("c").select(star())) + } + + test("multi select query") { + assertEqual( + "from a select * select * where s < 10", + table("a").select(star()).union(table("a").where('s < 10).select(star()))) + intercept( + "from a select * select * from x where a.s < 10", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + assertEqual( + "from a insert into tbl1 select * insert into tbl2 select * where s < 10", + table("a").select(star()).insertInto("tbl1").union( + table("a").where('s < 10).select(star()).insertInto("tbl2"))) + } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = table("t").select(star()) + + val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) + val limitWindowClauses = Seq( + ("", (p: LogicalPlan) => p), + (" limit 10", (p: LogicalPlan) => p.limit(10)), + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), + (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + ) + + val orderSortDistrClusterClauses = Seq( + ("", basePlan), + (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), + (" distribute by a, b", basePlan.distribute('a, 'b)), + (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), + (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) + ) + + orderSortDistrClusterClauses.foreach { + case (s1, p1) => + limitWindowClauses.foreach { + case (s2, pf2) => + assertEqual(baseSql + s1 + s2, pf2(p1)) + } + } + + val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" + intercept(s"$baseSql order by a sort by a", msg) + intercept(s"$baseSql cluster by a distribute by a", msg) + intercept(s"$baseSql order by a cluster by a", msg) + intercept(s"$baseSql order by a distribute by a", msg) + } + + test("insert into") { + val sql = "select * from t" + val plan = table("t").select(star()) + def insert( + partition: Map[String, Option[String]], + overwrite: Boolean = false, + ifNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + + // Single inserts + assertEqual(s"insert overwrite table s $sql", + insert(Map.empty, overwrite = true)) + assertEqual(s"insert overwrite table s if not exists $sql", + insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert into s $sql", + insert(Map.empty)) + assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", + insert(Map("c" -> Option("d"), "e" -> Option("1")))) + assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", + insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) + + // Multi insert + val plan2 = table("t").where('x > 5).select(star()) + assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", + InsertIntoTable( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + InsertIntoTable( + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + } + + test("aggregation") { + val sql = "select a, b, sum(c) as c from d group by a, b" + + // Normal + assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + + // Cube + assertEqual(s"$sql with cube", + table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Rollup + assertEqual(s"$sql with rollup", + table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Grouping Sets + assertEqual(s"$sql grouping sets((a, b), (a), ())", + GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + intercept(s"$sql grouping sets((a, b), (c), ())", + "c doesn't show up in the GROUP BY list") + } + + test("limit") { + val sql = "select * from t" + val plan = table("t").select(star()) + assertEqual(s"$sql limit 10", plan.limit(10)) + assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) + } + + test("window spec") { + // Note that WindowSpecs are testing in the ExpressionParserSuite + val sql = "select * from t" + val plan = table("t").select(star()) + val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + + // Test window resolution. + val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) + assertEqual( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w1""".stripMargin, + WithWindowDefinition(ws1, plan)) + + // Fail with no reference. + intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") + + // Fail when resolved reference is not a window spec. + intercept( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w2""".stripMargin, + "Window reference 'w2' is not a window specification" + ) + } + + test("lateral view") { + // Single lateral view + assertEqual( + "select * from t lateral view explode(x) expl as x", + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .select(star())) + + // Multiple lateral views + assertEqual( + """select * + |from t + |lateral view explode(x) expl + |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .select(star())) + + // Multi-Insert lateral views. + val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + assertEqual( + """from t1 + |lateral view explode(x) expl as x + |insert into t2 + |select * + |lateral view json_tuple(x, y) jtup q, z + |insert into t3 + |select * + |where s < 10 + """.stripMargin, + Union(from + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .select(star()) + .insertInto("t2"), + from.where('s < 10).select(star()).insertInto("t3"))) + + // Unresolved generator. + val expected = table("t") + .generate( + UnresolvedGenerator("posexplode", Seq('x)), + join = true, + outer = false, + Some("posexpl"), + Seq("x", "y")) + .select(star()) + assertEqual( + "select * from t lateral view posexplode(x) posexpl as x, y", + expected) + } + + test("joins") { + // Test single joins. + val testUnconditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t as tt $sql u", + table("t").as("tt").join(table("u"), jt, None).select(star())) + } + val testConditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u as uu on a = b", + table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + } + val testNaturalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t tt natural $sql u as uu", + table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) + } + val testUsingJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u using(a, b)", + table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) + } + val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) + val testExistence = Seq(testUnconditionalJoin, testConditionalJoin, testUsingJoin) + def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { + tests.foreach(_(sql, jt)) + } + test("cross join", Inner, Seq(testUnconditionalJoin)) + test(",", Inner, Seq(testUnconditionalJoin)) + test("join", Inner, testAll) + test("inner join", Inner, testAll) + test("left join", LeftOuter, testAll) + test("left outer join", LeftOuter, testAll) + test("right join", RightOuter, testAll) + test("right outer join", RightOuter, testAll) + test("full join", FullOuter, testAll) + test("full outer join", FullOuter, testAll) + test("left semi join", LeftSemi, testExistence) + test("left anti join", LeftAnti, testExistence) + test("anti join", LeftAnti, testExistence) + + // Test multiple consecutive joins + assertEqual( + "select * from a join b join c right join d", + table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + } + + test("sampled relations") { + val sql = "select * from t" + assertEqual(s"$sql tablesample(100 rows)", + table("t").limit(100).select(star())) + assertEqual(s"$sql tablesample(43 percent) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", + "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + intercept(s"$sql tablesample(bucket 11 out of 10) as x", + s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + } + + test("sub-query") { + val plan = table("t0").select('id) + assertEqual("select id from (t0)", plan) + assertEqual("select id from ((((((t0))))))", plan) + assertEqual( + "(select * from t1) union distinct (select * from t2)", + Distinct(table("t1").select(star()).union(table("t2").select(star())))) + assertEqual( + "select * from ((select * from t1) union (select * from t2)) t", + Distinct( + table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) + assertEqual( + """select id + |from (((select id from t0) + | union all + | (select id from t0)) + | union all + | (select id from t0)) as u_1 + """.stripMargin, + plan.union(plan).union(plan).as("u_1").select('id)) + } + + test("scalar sub-query") { + assertEqual( + "select (select max(b) from s) ss from t", + table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + assertEqual( + "select * from t where a = (select b from s)", + table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + assertEqual( + "select g from t group by g having a > (select b from s)", + table("t") + .groupBy('g)('g) + .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + } + + test("table reference") { + assertEqual("table t", table("t")) + assertEqual("table d.t", table("d", "t")) + } + + test("inline table") { + assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( + Seq('col1.int), + Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual( + "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", + LocalRelation.fromExternalRows( + Seq('a.int, 'b.string), + Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) + intercept("values (a, 'a'), (b, 'b')", + "All expressions in an inline table must be constants.") + intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", + "Number of aliases must match the number of fields in an inline table.") + intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala new file mode 100644 index 0000000000000..297b1931a9557 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier + +class TableIdentifierParserSuite extends SparkFunSuite { + import CatalystSqlParser._ + + test("table identifier") { + // Regular names. + assert(TableIdentifier("q") === parseTableIdentifier("q")) + assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) + + // Illegal names. + intercept[ParseException](parseTableIdentifier("")) + intercept[ParseException](parseTableIdentifier("d.q.g")) + + // SQL Keywords. + val keywords = Seq("select", "from", "where", "left", "right") + keywords.foreach { keyword => + intercept[ParseException](parseTableIdentifier(keyword)) + assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) + assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala new file mode 100644 index 0000000000000..81cc6b123cdd4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType} + +class ConstraintPropagationSuite extends SparkFunSuite { + + private def resolveColumn(tr: LocalRelation, columnName: String): Expression = + resolveColumn(tr.analyze, columnName) + + private def resolveColumn(plan: LogicalPlan, columnName: String): Expression = + plan.resolveQuoted(columnName, caseInsensitiveResolution).get + + private def verifyConstraints(found: ExpressionSet, expected: ExpressionSet): Unit = { + val missing = expected -- found + val extra = found -- expected + if (missing.nonEmpty || extra.nonEmpty) { + fail( + s""" + |== FAIL: Constraints do not match === + |Found: ${found.mkString(",")} + |Expected: ${expected.mkString(",")} + |== Result == + |Missing: ${if (missing.isEmpty) "N/A" else missing.mkString(",")} + |Found but not expected: ${if (extra.isEmpty) "N/A" else extra.mkString(",")} + """.stripMargin) + } + } + + test("propagating constraints in filters") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) + + verifyConstraints(tr + .where('a.attr > 10) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr, "a") > 10, + IsNotNull(resolveColumn(tr, "a"))))) + + verifyConstraints(tr + .where('a.attr > 10) + .select('c.attr, 'a.attr) + .where('c.attr =!= 100) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "c") =!= 100, + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c"))))) + } + + test("propagating constraints in aggregate") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze + + verifyConstraints(aliasedRelation.analyze.constraints, + ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "c1") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")), + resolveColumn(aliasedRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))) + } + + test("propagating constraints in expand") { + val tr = LocalRelation('a.int, 'b.int, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation + // by creating notNullRelation. + val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2) + verifyConstraints(notNullRelation.analyze.constraints, + ExpressionSet(Seq(resolveColumn(notNullRelation.analyze, "c") > 10, + IsNotNull(resolveColumn(notNullRelation.analyze, "c")), + resolveColumn(notNullRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(notNullRelation.analyze, "a")), + resolveColumn(notNullRelation.analyze, "b") > 2, + IsNotNull(resolveColumn(notNullRelation.analyze, "b"))))) + + val expand = Expand( + Seq( + Seq('c, Literal.create(null, StringType), 1), + Seq('c, 'a, 2)), + Seq('c, 'a, 'gid.int), + Project(Seq('a, 'c), + notNullRelation)) + verifyConstraints(expand.analyze.constraints, + ExpressionSet(Seq.empty[Expression])) + } + + test("propagating constraints in aliases") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z)) + + verifyConstraints(aliasedRelation.analyze.constraints, + ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), + resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) + } + + test("propagating constraints in union") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('d.int, 'e.int, 'f.int) + val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + + assert(tr1 + .where('a.attr > 10) + .union(tr2.where('e.attr > 10) + .union(tr3.where('i.attr > 10))) + .analyze.constraints.isEmpty) + + verifyConstraints(tr1 + .where('a.attr > 10) + .union(tr2.where('d.attr > 10) + .union(tr3.where('g.attr > 10))) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a"))))) + + val a = resolveColumn(tr1, "a") + verifyConstraints(tr1 + .where('a.attr > 10) + .union(tr2.where('d.attr > 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a)))) + + val b = resolveColumn(tr1, "b") + verifyConstraints(tr1 + .where('a.attr > 10 && 'b.attr < 10) + .union(tr2.where('d.attr > 11 && 'e.attr < 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b)))) + } + + test("propagating constraints in intersect") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + + verifyConstraints(tr1 + .where('a.attr > 10) + .intersect(tr2.where('b.attr < 100)) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, + resolveColumn(tr1, "b") < 100, + IsNotNull(resolveColumn(tr1, "a")), + IsNotNull(resolveColumn(tr1, "b"))))) + } + + test("propagating constraints in except") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int) + val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + verifyConstraints(tr1 + .where('a.attr > 10) + .except(tr2.where('b.attr < 100)) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a"))))) + } + + test("propagating constraints in inner join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + tr1.resolveQuoted("a", caseInsensitiveResolution).get === + tr2.resolveQuoted("a", caseInsensitiveResolution).get, + tr2.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) + } + + test("propagating constraints in left-semi join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))) + } + + test("propagating constraints in left-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))) + } + + test("propagating constraints in right-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + verifyConstraints(tr1 + .where('a.attr > 10) + .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints, + ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) + } + + test("propagating constraints in full-outer join") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + assert(tr1.where('a.attr > 10) + .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) + .analyze.constraints.isEmpty) + } + + test("infer additional constraints in filters") { + val tr = LocalRelation('a.int, 'b.int, 'c.int) + + verifyConstraints(tr + .where('a.attr > 10 && 'a.attr === 'b.attr) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "b") > 10, + resolveColumn(tr, "a") === resolveColumn(tr, "b"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b"))))) + } + + test("infer constraints on cast") { + val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + verifyConstraints( + tr.where('a.attr === 'b.attr && + 'c.attr + 100 > 'd.attr && + IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints, + ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"), + Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e")), + IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))))) + } + + test("infer isnotnull constraints from compound expressions") { + val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + verifyConstraints( + tr.where('a.attr + 'b.attr === 'c.attr && + IsNotNull( + Cast( + Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") === + Cast(resolveColumn(tr, "c"), LongType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "e")), + IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))))) + + verifyConstraints( + tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) === + Cast(resolveColumn(tr, "c"), LongType), + Cast(resolveColumn(tr, "d"), DoubleType) / + Cast(Cast(10, LongType), DoubleType) === + Cast(resolveColumn(tr, "e"), DoubleType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + verifyConstraints( + tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >= + Cast(resolveColumn(tr, "c"), LongType), + Cast(resolveColumn(tr, "d"), DoubleType) / + Cast(Cast(10, LongType), DoubleType) < + Cast(resolveColumn(tr, "e"), DoubleType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + verifyConstraints( + tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints, + ExpressionSet(Seq( + (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) - + (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) > + Cast(resolveColumn(tr, "e") * 1000, LongType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null. + verifyConstraints( + tr.where('a.attr === 'c.attr && + IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints, + ExpressionSet(Seq( + resolveColumn(tr, "a") === resolveColumn(tr, "c"), + IsNotNull(IsNotNull(resolveColumn(tr, "b"))), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c"))))) + } + + test("infer IsNotNull constraints from non-nullable attributes") { + val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(), + AttributeReference("c", StringType, nullable = false)()) + + verifyConstraints(tr.analyze.constraints, + ExpressionSet(Seq(IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c"))))) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 455a3810c719e..faef9ed274593 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util._ /** * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 2efee1fc54706..71919366999ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,30 +19,51 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ /** * Provides helper methods for comparing plans. */ -abstract class PlanTest extends SparkFunSuite { +abstract class PlanTest extends SparkFunSuite with PredicateHelper { /** * Since attribute references are given globally unique ids during analysis, * we must normalize them to check if two different queries are identical. */ protected def normalizeExprIds(plan: LogicalPlan) = { plan transformAllExpressions { + case s: ScalarSubquery => + ScalarSubquery(s.query, ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => Alias(a.child, a.name)(exprId = ExprId(0)) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) + } + } + + /** + * Normalizes plans: + * - Filter the filter conditions that appear in a plan. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) + * etc., will all now be equivalent. + * - Sample the seed will replaced by 0L. + */ + private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + plan transform { + case filter @ Filter(condition: Expression, child: LogicalPlan) => + Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child) + case sample: Sample => + sample.copy(seed = 0L)(true) } } /** Fails the test if the two plans do not match */ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeExprIds(plan1) - val normalized2 = normalizeExprIds(plan2) + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) if (normalized1 != normalized2) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 62d5f6ac74885..467f76193cfc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union} import org.apache.spark.sql.catalyst.util._ /** @@ -62,4 +61,9 @@ class SameResultSuite extends SparkFunSuite { test("sorts") { assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc)) } + + test("union") { + assertSameResult(Union(Seq(testRelation, testRelation2)), + Union(Seq(testRelation2, testRelation))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index a7de7b052bdc3..c9d36910b0998 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} class RuleExecutorSuite extends SparkFunSuite { @@ -49,6 +51,9 @@ class RuleExecutorSuite extends SparkFunSuite { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } - assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) + val message = intercept[TreeNodeException[LogicalPlan]] { + ToFixedPoint.execute(Literal(100)) + }.getMessage + assert(message.contains("Max iterations (10) reached for batch fixedPoint")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 8fff39906b342..6a188e7e55126 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{IntegerType, StringType, NullType} +import org.apache.spark.sql.types.{IntegerType, NullType, StringType} case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { override def children: Seq[Expression] = optKey.toSeq @@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } +case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { + override def children: Seq[Expression] = map.values.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite { val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) assert(expected === actual) } + + test("expressions inside a map") { + val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2))) + + { + val actual = expression.transform { + case Literal(i: Int, _) => Literal(i + 1) + } + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + + { + val actual = expression.withNewChildren(Seq(Literal(2), Literal(3))) + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala deleted file mode 100644 index 1e3409a9db6eb..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.catalyst.util - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types._ - -class DataTypeParserSuite extends SparkFunSuite { - - def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { - test(s"parse ${dataTypeString.replace("\n", "")}") { - assert(DataTypeParser.parse(dataTypeString) === expectedDataType) - } - } - - def unsupported(dataTypeString: String): Unit = { - test(s"$dataTypeString is not supported") { - intercept[DataTypeException](DataTypeParser.parse(dataTypeString)) - } - } - - checkDataType("int", IntegerType) - checkDataType("integer", IntegerType) - checkDataType("BooLean", BooleanType) - checkDataType("tinYint", ByteType) - checkDataType("smallINT", ShortType) - checkDataType("INT", IntegerType) - checkDataType("INTEGER", IntegerType) - checkDataType("bigint", LongType) - checkDataType("float", FloatType) - checkDataType("dOUBle", DoubleType) - checkDataType("decimal(10, 5)", DecimalType(10, 5)) - checkDataType("decimal", DecimalType.USER_DEFAULT) - checkDataType("DATE", DateType) - checkDataType("timestamp", TimestampType) - checkDataType("string", StringType) - checkDataType("varchAr(20)", StringType) - checkDataType("BINARY", BinaryType) - - checkDataType("array", ArrayType(DoubleType, true)) - checkDataType("Array>", ArrayType(MapType(IntegerType, ByteType, true), true)) - checkDataType( - "array>", - ArrayType(StructType(StructField("tinYint", ByteType, true) :: Nil), true) - ) - checkDataType("MAP", MapType(IntegerType, StringType, true)) - checkDataType("MAp>", MapType(IntegerType, ArrayType(DoubleType), true)) - checkDataType( - "MAP>", - MapType(IntegerType, StructType(StructField("varchar", StringType, true) :: Nil), true) - ) - - checkDataType( - "struct", - StructType( - StructField("intType", IntegerType, true) :: - StructField("ts", TimestampType, true) :: Nil) - ) - // It is fine to use the data type string as the column name. - checkDataType( - "Struct", - StructType( - StructField("int", IntegerType, true) :: - StructField("timestamp", TimestampType, true) :: Nil) - ) - checkDataType( - """ - |struct< - | struct:struct, - | MAP:Map, - | arrAy:Array> - """.stripMargin, - StructType( - StructField("struct", - StructType( - StructField("deciMal", DecimalType.USER_DEFAULT, true) :: - StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: - StructField("MAP", MapType(TimestampType, StringType), true) :: - StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) - ) - // A column name can be a reserved word in our DDL parser and SqlParser. - checkDataType( - "Struct", - StructType( - StructField("TABLE", StringType, true) :: - StructField("CASE", BooleanType, true) :: Nil) - ) - // Use backticks to quote column names having special characters. - checkDataType( - "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>", - StructType( - StructField("x+y", IntegerType, true) :: - StructField("!@#$%^&*()", StringType, true) :: - StructField("1_2.345<>:\"", StringType, true) :: Nil) - ) - // Empty struct. - checkDataType("strUCt<>", StructType(Nil)) - - unsupported("it is not a data type") - unsupported("struct") - unsupported("struct") - unsupported("struct<`x``y` int>") -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 46335941b62d6..6745b4b6c3c67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -22,8 +22,8 @@ import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.unsafe.types.UTF8String class DateTimeUtilsSuite extends SparkFunSuite { @@ -110,6 +110,10 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MILLISECOND, 0) assert(stringToDate(UTF8String.fromString("2015")).get === millisToDays(c.getTimeInMillis)) + c.set(1, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + assert(stringToDate(UTF8String.fromString("0001")).get === + millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) @@ -134,11 +138,15 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) assert(stringToDate(UTF8String.fromString("20150318")).isEmpty) assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToDate(UTF8String.fromString("02015-03-18")).isEmpty) + assert(stringToDate(UTF8String.fromString("015-03-18")).isEmpty) + assert(stringToDate(UTF8String.fromString("015")).isEmpty) + assert(stringToDate(UTF8String.fromString("02015")).isEmpty) } test("string to time") { // Tests with UTC. - var c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(Calendar.MILLISECOND, 0) c.set(1900, 0, 1, 0, 0, 0) @@ -174,9 +182,9 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MILLISECOND, 0) assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === c.getTimeInMillis * 1000) - c.set(2015, 0, 1, 0, 0, 0) + c.set(1, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("2015")).get === + assert(stringToTimestamp(UTF8String.fromString("0001")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) @@ -319,6 +327,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("00238")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) @@ -326,12 +335,22 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("02015-01-18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("015-01-18")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + + // Truncating the fractional seconds + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+00:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123456789+0:00")).get === + c.getTimeInMillis * 1000 + 123456) } test("hours") { @@ -358,6 +377,16 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(getSeconds(c.getTimeInMillis * 1000) === 9) } + test("hours / minutes / seconds") { + Seq(Timestamp.valueOf("2015-06-11 10:12:35.789"), + Timestamp.valueOf("2015-06-11 20:13:40.789"), + Timestamp.valueOf("1900-06-11 12:14:50.789"), + Timestamp.valueOf("1700-02-28 12:14:50.123456")).foreach { t => + val us = fromJavaTimestamp(t) + assert(toJavaTimestamp(us) === t) + } + } + test("get day in year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index 4030a1b1df358..a0c1d97bfc3a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{MetadataBuilder, Metadata} +import org.apache.spark.sql.types.{Metadata, MetadataBuilder} class MetadataSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index d6f273f9e568a..2ffc18a8d14fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite { assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") } + + test("filter pattern") { + val names = Seq("a1", "a2", "b2", "c3") + assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3")) + assert(filterPattern(names, "*a*") === Seq("a1", "a2")) + assert(filterPattern(names, " *a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a.* ") === Seq("a1", "a2")) + assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2")) + assert(filterPattern(names, " a. ") === Seq("a1", "a2")) + assert(filterPattern(names, " d* ") === Nil) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 706ecd29d1355..6b85f12521c2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -122,7 +122,9 @@ class DataTypeSuite extends SparkFunSuite { val right = StructType(List()) val merged = left.merge(right) - assert(merged === left) + assert(DataType.equalsIgnoreCompatibleNullability(merged, left)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where left is empty") { @@ -135,8 +137,9 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(right === merged) - + assert(DataType.equalsIgnoreCompatibleNullability(merged, right)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where both are non-empty") { @@ -154,7 +157,10 @@ class DataTypeSuite extends SparkFunSuite { val merged = left.merge(right) - assert(merged === expected) + assert(DataType.equalsIgnoreCompatibleNullability(merged, expected)) + assert(merged("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("b").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(merged("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) } test("merge where right contains type conflict") { @@ -242,15 +248,15 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(LongType, 8) checkDefaultSize(FloatType, 4) checkDefaultSize(DoubleType, 8) - checkDefaultSize(DecimalType(10, 5), 4096) - checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096) + checkDefaultSize(DecimalType(10, 5), 8) + checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 16) checkDefaultSize(DateType, 4) checkDefaultSize(TimestampType, 8) - checkDefaultSize(StringType, 4096) - checkDefaultSize(BinaryType, 4096) + checkDefaultSize(StringType, 20) + checkDefaultSize(BinaryType, 100) checkDefaultSize(ArrayType(DoubleType, true), 800) - checkDefaultSize(ArrayType(StringType, false), 409600) - checkDefaultSize(MapType(IntegerType, StringType, true), 410000) + checkDefaultSize(ArrayType(StringType, false), 2000) + checkDefaultSize(MapType(IntegerType, StringType, true), 2400) checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400) checkDefaultSize(structType, 812) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 50683947da224..e1675c95907af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.types -import org.apache.spark.SparkFunSuite +import scala.language.postfixOps + import org.scalatest.PrivateMethodTester -import scala.language.postfixOps +import org.apache.spark.SparkFunSuite class DecimalSuite extends SparkFunSuite with PrivateMethodTester { /** Check that a Decimal has the given string representation, precision and scale */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c96855e261ee8..8b1017042cd93 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sql_2.10 + spark-sql_2.11 jar Spark Project SQL http://spark.apache.org/ @@ -36,6 +36,17 @@ + + com.univocity + univocity-parsers + 2.0.2 + jar + + + org.apache.spark + spark-sketch_2.11 + ${project.version} + org.apache.spark spark-core_${scala.binary.version} @@ -72,6 +83,10 @@ org.apache.parquet parquet-hadoop + + org.eclipse.jetty + jetty-servlet + com.fasterxml.jackson.core jackson-databind @@ -91,13 +106,11 @@ mysql mysql-connector-java - 5.1.34 test org.postgresql postgresql - 9.3-1102-jdbc41 test @@ -110,6 +123,11 @@ mockito-core test + + org.apache.xbean + xbean-asm5-shaded + test + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java new file mode 100644 index 0000000000000..086547c793e3b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; +import java.util.LinkedList; + +import scala.collection.Iterator; + +import org.apache.spark.TaskContext; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; + +/** + * An iterator interface used to pull the output from generated function for multiple operators + * (whole stage codegen). + */ +public abstract class BufferedRowIterator { + protected LinkedList currentRows = new LinkedList<>(); + // used when there is no column in output + protected UnsafeRow unsafeRow = new UnsafeRow(0); + private long startTimeNs = System.nanoTime(); + + protected int partitionIndex = -1; + + public boolean hasNext() throws IOException { + if (currentRows.isEmpty()) { + processNext(); + } + return !currentRows.isEmpty(); + } + + public InternalRow next() { + return currentRows.remove(); + } + + /** + * Returns the elapsed time since this object is created. This object represents a pipeline so + * this is a measure of how long the pipeline has been running. + */ + public long durationMs() { + return (System.nanoTime() - startTimeNs) / (1000 * 1000); + } + + /** + * Initializes from array of iterators of InternalRow. + */ + public abstract void init(int index, Iterator[] iters); + + /** + * Append a row to currentRows. + */ + protected void append(InternalRow row) { + currentRows.add(row); + } + + /** + * Returns whether `processNext()` should stop processing next row from `input` or not. + * + * If it returns true, the caller should exit the loop (return from processNext()). + */ + protected boolean shouldStop() { + return !currentRows.isEmpty(); + } + + /** + * Increase the peak execution memory for current task. + */ + protected void incPeakExecutionMemory(long size) { + TaskContext.get().taskMetrics().incPeakExecutionMemory(size); + } + + /** + * Processes the input until have a row as output (currentRow). + * + * After it's called, if currentRow is still null, it means no more rows left. + */ + protected abstract void processNext() throws IOException; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index a2f99d566d471..1f1b5389aa7d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -29,7 +29,6 @@ import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; -import org.apache.spark.unsafe.memory.MemoryLocation; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -61,7 +60,7 @@ public final class UnsafeFixedWidthAggregationMap { /** * Re-used pointer to the current aggregation buffer */ - private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + private final UnsafeRow currentAggregationBuffer; private final boolean enablePerfMetrics; @@ -98,6 +97,7 @@ public UnsafeFixedWidthAggregationMap( long pageSizeBytes, boolean enablePerfMetrics) { this.aggregationBufferSchema = aggregationBufferSchema; + this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; this.map = @@ -120,19 +120,24 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow); } - public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) { + public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) { + return getAggregationBufferFromUnsafeRow(key, key.hashCode()); + } + + public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) { // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( - unsafeGroupingKeyRow.getBaseObject(), - unsafeGroupingKeyRow.getBaseOffset(), - unsafeGroupingKeyRow.getSizeInBytes()); + key.getBaseObject(), + key.getBaseOffset(), + key.getSizeInBytes(), + hash); if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: - boolean putSucceeded = loc.putNewKey( - unsafeGroupingKeyRow.getBaseObject(), - unsafeGroupingKeyRow.getBaseOffset(), - unsafeGroupingKeyRow.getSizeInBytes(), + boolean putSucceeded = loc.append( + key.getBaseObject(), + key.getBaseOffset(), + key.getSizeInBytes(), emptyAggregationBuffer, Platform.BYTE_ARRAY_OFFSET, emptyAggregationBuffer.length @@ -143,11 +148,9 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo } // Reset the pointer to point to the value that we just stored or looked up: - final MemoryLocation address = loc.getValueAddress(); currentAggregationBuffer.pointTo( - address.getBaseObject(), - address.getBaseOffset(), - aggregationBufferSchema.length(), + loc.getValueBase(), + loc.getValueOffset(), loc.getValueLength() ); return currentAggregationBuffer; @@ -165,25 +168,21 @@ public KVIterator iterator() { private final BytesToBytesMap.MapIterator mapLocationIterator = map.destructiveIterator(); - private final UnsafeRow key = new UnsafeRow(); - private final UnsafeRow value = new UnsafeRow(); + private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length()); + private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length()); @Override public boolean next() { if (mapLocationIterator.hasNext()) { final BytesToBytesMap.Location loc = mapLocationIterator.next(); - final MemoryLocation keyAddress = loc.getKeyAddress(); - final MemoryLocation valueAddress = loc.getValueAddress(); key.pointTo( - keyAddress.getBaseObject(), - keyAddress.getBaseOffset(), - groupingKeySchema.length(), + loc.getKeyBase(), + loc.getKeyOffset(), loc.getKeyLength() ); value.pointTo( - valueAddress.getBaseObject(), - valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), + loc.getValueBase(), + loc.getValueOffset(), loc.getValueLength() ); return true; @@ -237,12 +236,16 @@ public void printPerfMetrics() { /** * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]] * - * Note that the map will be reset for inserting new records, and the returned sorter can NOT be used - * to insert records. + * Note that the map will be reset for inserting new records, and the returned sorter can NOT be + * used to insert records. */ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { return new UnsafeKVExternalSorter( - groupingKeySchema, aggregationBufferSchema, - SparkEnv.get().blockManager(), map.getPageSizeBytes(), map); + groupingKeySchema, + aggregationBufferSchema, + SparkEnv.get().blockManager(), + SparkEnv.get().serializerManager(), + map.getPageSizeBytes(), + map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index e2898ef2e2158..8132bba04caeb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -24,6 +24,7 @@ import org.apache.spark.TaskContext; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; @@ -52,14 +53,16 @@ public UnsafeKVExternalSorter( StructType keySchema, StructType valueSchema, BlockManager blockManager, + SerializerManager serializerManager, long pageSizeBytes) throws IOException { - this(keySchema, valueSchema, blockManager, pageSizeBytes, null); + this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, null); } public UnsafeKVExternalSorter( StructType keySchema, StructType valueSchema, BlockManager blockManager, + SerializerManager serializerManager, long pageSizeBytes, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; @@ -77,6 +80,7 @@ public UnsafeKVExternalSorter( sorter = UnsafeExternalSorter.create( taskMemoryManager, blockManager, + serializerManager, taskContext, recordComparator, prefixComparator, @@ -85,19 +89,20 @@ public UnsafeKVExternalSorter( } else { // During spilling, the array in map will not be used, so we can borrow that and use it // as the underline array for in-memory sorter (it's always large enough). + // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, map.getArray()); + null, taskMemoryManager, recordComparator, prefixComparator, map.getArray()); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. // The only new memory we are allocating is the pointer/prefix array. BytesToBytesMap.MapIterator iter = map.iterator(); final int numKeyFields = keySchema.size(); - UnsafeRow row = new UnsafeRow(); + UnsafeRow row = new UnsafeRow(numKeyFields); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); - final Object baseObject = loc.getKeyAddress().getBaseObject(); - final long baseOffset = loc.getKeyAddress().getBaseOffset(); + final Object baseObject = loc.getKeyBase(); + final long baseOffset = loc.getKeyOffset(); // Get encoded memory address // baseObject + baseOffset point to the beginning of the key data in the map, but that @@ -106,7 +111,7 @@ public UnsafeKVExternalSorter( long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8); // Compute prefix - row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); + row.pointTo(baseObject, baseOffset, loc.getKeyLength()); final long prefix = prefixComputer.computePrefix(row); inMemSorter.insertRecord(address, prefix); @@ -115,6 +120,7 @@ public UnsafeKVExternalSorter( sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( taskMemoryManager, blockManager, + serializerManager, taskContext, new KVComparator(ordering, keySchema.length()), prefixComparator, @@ -193,12 +199,14 @@ public void cleanupResources() { private static final class KVComparator extends RecordComparator { private final BaseOrdering ordering; - private final UnsafeRow row1 = new UnsafeRow(); - private final UnsafeRow row2 = new UnsafeRow(); + private final UnsafeRow row1; + private final UnsafeRow row2; private final int numKeyFields; - public KVComparator(BaseOrdering ordering, int numKeyFields) { + KVComparator(BaseOrdering ordering, int numKeyFields) { this.numKeyFields = numKeyFields; + this.row1 = new UnsafeRow(numKeyFields); + this.row2 = new UnsafeRow(numKeyFields); this.ordering = ordering; } @@ -206,17 +214,15 @@ public KVComparator(BaseOrdering ordering, int numKeyFields) { public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { // Note that since ordering doesn't need the total length of the record, we just pass -1 // into the row. - row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); - row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); + row1.pointTo(baseObj1, baseOff1 + 4, -1); + row2.pointTo(baseObj2, baseOff2 + 4, -1); return ordering.compare(row1, row2); } } public class KVSorterIterator extends KVIterator { - private UnsafeRow key = new UnsafeRow(); - private UnsafeRow value = new UnsafeRow(); - private final int numKeyFields = keySchema.size(); - private final int numValueFields = valueSchema.size(); + private UnsafeRow key = new UnsafeRow(keySchema.size()); + private UnsafeRow value = new UnsafeRow(valueSchema.size()); private final UnsafeSorterIterator underlying; private KVSorterIterator(UnsafeSorterIterator underlying) { @@ -236,8 +242,8 @@ public boolean next() throws IOException { // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) int keyLen = Platform.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; - key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); - value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); + key.pointTo(baseObj, recordOffset + 4, keyLen); + value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen); return true; } else { @@ -266,5 +272,5 @@ public UnsafeRow getValue() { public void close() { cleanupResources(); } - }; + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java new file mode 100644 index 0000000000000..5c257bc260873 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.hadoop.BadConfigurationException; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.ConfigurationUtil; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Types; +import org.apache.spark.sql.types.StructType; + +/** + * Base class for custom RecordReaders for Parquet that directly materialize to `T`. + * This class handles computing row groups, filtering on them, setting up the column readers, + * etc. + * This is heavily based on parquet-mr's RecordReader. + * TODO: move this to the parquet-mr project. There are performance benefits of doing it + * this way, albeit at a higher cost to implement. This base class is reusable. + */ +public abstract class SpecificParquetRecordReaderBase extends RecordReader { + protected Path file; + protected MessageType fileSchema; + protected MessageType requestedSchema; + protected StructType sparkSchema; + + /** + * The total number of rows this RecordReader will eventually read. The sum of the + * rows of all the row groups. + */ + protected long totalRowCount; + + protected ParquetFileReader reader; + + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + this.file = split.getPath(); + long[] rowGroupOffsets = split.getRowGroupOffsets(); + + ParquetMetadata footer; + List blocks; + + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) { + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + this.fileSchema = footer.getFileMetaData().getSchema(); + Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); + ReadSupport readSupport = getReadSupportInstance(getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + this.sparkSchema = new CatalystSchemaConverter(configuration).convert(requestedSchema); + this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + /** + * Returns the list of files at 'path' recursively. This skips files that are ignored normally + * by MapReduce. + */ + public static List listDirectory(File path) throws IOException { + List result = new ArrayList<>(); + if (path.isDirectory()) { + for (File f: path.listFiles()) { + result.addAll(listDirectory(f)); + } + } else { + char c = path.getName().charAt(0); + if (c != '.' && c != '_') { + result.add(path.getAbsolutePath()); + } + } + return result; + } + + /** + * Initializes the reader to read the file at `path` with `columns` projected. If columns is + * null, all the columns are projected. + * + * This is exposed for testing to be able to create this reader without the rest of the Hadoop + * split machinery. It is not intended for general use and those not support all the + * configurations. + */ + protected void initialize(String path, List columns) throws IOException { + Configuration config = new Configuration(); + config.set("spark.sql.parquet.binaryAsString", "false"); + config.set("spark.sql.parquet.int96AsTimestamp", "false"); + config.set("spark.sql.parquet.writeLegacyFormat", "false"); + + this.file = new Path(path); + long length = FileSystem.get(config).getFileStatus(this.file).getLen(); + ParquetMetadata footer = readFooter(config, file, range(0, length)); + + List blocks = footer.getBlocks(); + this.fileSchema = footer.getFileMetaData().getSchema(); + + if (columns == null) { + this.requestedSchema = fileSchema; + } else { + Types.MessageTypeBuilder builder = Types.buildMessage(); + for (String s: columns) { + if (!fileSchema.containsField(s)) { + throw new IOException("Can only project existing columns. Unknown field: " + s + + " File schema:\n" + fileSchema); + } + builder.addFields(fileSchema.getType(s)); + } + this.requestedSchema = builder.named("spark_schema"); + } + this.sparkSchema = new CatalystSchemaConverter(config).convert(requestedSchema); + this.reader = new ParquetFileReader(config, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + } + + /** + * Utility classes to abstract over different way to read ints with different encodings. + * TODO: remove this layer of abstraction? + */ + abstract static class IntIterator { + abstract int nextInt() throws IOException; + } + + protected static final class ValuesReaderIntIterator extends IntIterator { + ValuesReader delegate; + + public ValuesReaderIntIterator(ValuesReader delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInteger(); + } + } + + protected static final class RLEIntIterator extends IntIterator { + RunLengthBitPackingHybridDecoder delegate; + + public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInt(); + } + } + + protected static final class NullIntIterator extends IntIterator { + @Override + int nextInt() throws IOException { return 0; } + } + + /** + * Creates a reader for definition and repetition levels, returning an optimized one if + * the levels are not needed. + */ + protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + ColumnDescriptor descriptor) throws IOException { + try { + if (maxLevel == 0) return new NullIntIterator(); + return new RLEIntIterator( + new RunLengthBitPackingHybridDecoder( + BytesUtils.getWidthFromMaxInt(maxLevel), + new ByteArrayInputStream(bytes.toByteArray()))); + } catch (IOException e) { + throw new IOException("could not read levels in page for col " + descriptor, e); + } + } + + private static Map> toSetMultiMap(Map map) { + Map> setMultiMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Set set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + @SuppressWarnings("unchecked") + private Class> getReadSupportClass(Configuration configuration) { + return (Class>) ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static ReadSupport getReadSupportInstance( + Class> readSupportClass){ + try { + return readSupportClass.getConstructor().newInstance(); + } catch (InstantiationException | IllegalAccessException | + NoSuchMethodException | InvocationTargetException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java new file mode 100644 index 0000000000000..ea37a08ab5f55 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.*; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.PrimitiveType; + +import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DecimalType; + +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ValuesReaderIntIterator; +import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.createRLEIterator; + +/** + * Decoder to return values from a single column. + */ +public class VectorizedColumnReader { + /** + * Total number of values read. + */ + private long valuesRead; + + /** + * value that indicates the end of the current page. That is, + * if valuesRead == endOfPageValueCount, we are at the end of the page. + */ + private long endOfPageValueCount; + + /** + * The dictionary, if this column has dictionary encoding. + */ + private final Dictionary dictionary; + + /** + * If true, the current page is dictionary encoded. + */ + private boolean useDictionary; + + /** + * Maximum definition level for this column. + */ + private final int maxDefLevel; + + /** + * Repetition/Definition/Value readers. + */ + private SpecificParquetRecordReaderBase.IntIterator repetitionLevelColumn; + private SpecificParquetRecordReaderBase.IntIterator definitionLevelColumn; + private ValuesReader dataColumn; + + // Only set if vectorized decoding is true. This is used instead of the row by row decoding + // with `definitionLevelColumn`. + private VectorizedRleValuesReader defColumn; + + /** + * Total number of values in this column (in this row group). + */ + private final long totalValueCount; + + /** + * Total values in the current page. + */ + private int pageValueCount; + + private final PageReader pageReader; + private final ColumnDescriptor descriptor; + + public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + throws IOException { + this.descriptor = descriptor; + this.pageReader = pageReader; + this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + + DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); + if (dictionaryPage != null) { + try { + this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); + this.useDictionary = true; + } catch (IOException e) { + throw new IOException("could not decode the dictionary for " + descriptor, e); + } + } else { + this.dictionary = null; + this.useDictionary = false; + } + this.totalValueCount = pageReader.getTotalValueCount(); + if (totalValueCount == 0) { + throw new IOException("totalValueCount == 0"); + } + } + + /** + * Advances to the next value. Returns true if the value is non-null. + */ + private boolean next() throws IOException { + if (valuesRead >= endOfPageValueCount) { + if (valuesRead >= totalValueCount) { + // How do we get here? Throw end of stream exception? + return false; + } + readPage(); + } + ++valuesRead; + // TODO: Don't read for flat schemas + //repetitionLevel = repetitionLevelColumn.nextInt(); + return definitionLevelColumn.nextInt() == maxDefLevel; + } + + /** + * Reads `total` values from this columnReader into column. + */ + void readBatch(int total, ColumnVector column) throws IOException { + int rowId = 0; + while (total > 0) { + // Compute the number of values we want to read in this page. + int leftInPage = (int) (endOfPageValueCount - valuesRead); + if (leftInPage == 0) { + readPage(); + leftInPage = (int) (endOfPageValueCount - valuesRead); + } + int num = Math.min(total, leftInPage); + if (useDictionary) { + // Read and decode dictionary ids. + ColumnVector dictionaryIds = column.reserveDictionaryIds(total); + defColumn.readIntegers( + num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + + if (column.hasDictionary() || (rowId == 0 && + (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { + // Column vector supports lazy decoding of dictionary values so just set the dictionary. + // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some + // non-dictionary encoded values have already been added). + column.setDictionary(dictionary); + } else { + decodeDictionaryIds(rowId, num, column, dictionaryIds); + } + } else { + if (column.hasDictionary() && rowId != 0) { + // This batch already has dictionary encoded values but this new page is not. The batch + // does not support a mix of dictionary and not so we will decode the dictionary. + decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); + } + column.setDictionary(null); + switch (descriptor.getType()) { + case BOOLEAN: + readBooleanBatch(rowId, num, column); + break; + case INT32: + readIntBatch(rowId, num, column); + break; + case INT64: + readLongBatch(rowId, num, column); + break; + case INT96: + readBinaryBatch(rowId, num, column); + break; + case FLOAT: + readFloatBatch(rowId, num, column); + break; + case DOUBLE: + readDoubleBatch(rowId, num, column); + break; + case BINARY: + readBinaryBatch(rowId, num, column); + break; + case FIXED_LEN_BYTE_ARRAY: + readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength()); + break; + default: + throw new IOException("Unsupported type: " + descriptor.getType()); + } + } + + valuesRead += num; + rowId += num; + total -= num; + } + } + + /** + * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. + */ + private void decodeDictionaryIds(int rowId, int num, ColumnVector column, + ColumnVector dictionaryIds) { + switch (descriptor.getType()) { + case INT32: + if (column.dataType() == DataTypes.IntegerType || + DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ByteType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ShortType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + + case INT64: + if (column.dataType() == DataTypes.LongType || + DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + + case FLOAT: + for (int i = rowId; i < rowId + num; ++i) { + column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); + } + break; + + case DOUBLE: + for (int i = rowId; i < rowId + num; ++i) { + column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); + } + break; + case INT96: + if (column.dataType() == DataTypes.TimestampType) { + for (int i = rowId; i < rowId + num; ++i) { + // TODO: Convert dictionary of Binaries to dictionary of Longs + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putLong(i, CatalystRowConverter.binaryToSQLTimestamp(v)); + } + } else { + throw new NotImplementedException(); + } + break; + case BINARY: + // TODO: this is incredibly inefficient as it blows up the dictionary right here. We + // need to do this better. We should probably add the dictionary data to the ColumnVector + // and reuse it across batches. This should mean adding a ByteArray would just update + // the length and offset. + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } + break; + case FIXED_LEN_BYTE_ARRAY: + // DecimalType written in the legacy mode + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v)); + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v)); + } + } else if (DecimalType.isByteArrayDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } + } else { + throw new NotImplementedException(); + } + break; + + default: + throw new NotImplementedException("Unsupported type: " + descriptor.getType()); + } + } + + /** + * For all the read*Batch functions, reads `num` values from this columnReader into column. It + * is guaranteed that num is smaller than the number of values left in the current page. + */ + + private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException { + assert(column.dataType() == DataTypes.BooleanType); + defColumn.readBooleans( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } + + private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || + DecimalType.is32BitDecimalType(column.dataType())) { + defColumn.readIntegers( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else if (column.dataType() == DataTypes.ByteType) { + defColumn.readBytes( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else if (column.dataType() == DataTypes.ShortType) { + defColumn.readShorts( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + if (column.dataType() == DataTypes.LongType || + DecimalType.is64BitDecimalType(column.dataType())) { + defColumn.readLongs( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + } + } + + private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: support implicit cast to double? + if (column.dataType() == DataTypes.FloatType) { + defColumn.readFloats( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + } + } + + private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.DoubleType) { + defColumn.readDoubles( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; + if (column.isArray()) { + defColumn.readBinarys(num, column, rowId, maxDefLevel, data); + } else if (column.dataType() == DataTypes.TimestampType) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putLong(rowId + i, + // Read 12 bytes for INT96 + CatalystRowConverter.binaryToSQLTimestamp(data.readBinary(12))); + } else { + column.putNull(rowId + i); + } + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readFixedLenByteArrayBatch(int rowId, int num, + ColumnVector column, int arrayLen) throws IOException { + VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putInt(rowId + i, + (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + } else { + column.putNull(rowId + i); + } + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putLong(rowId + i, + CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + } else { + column.putNull(rowId + i); + } + } + } else if (DecimalType.isByteArrayDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putByteArray(rowId + i, data.readBinary(arrayLen).getBytes()); + } else { + column.putNull(rowId + i); + } + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readPage() throws IOException { + DataPage page = pageReader.readPage(); + // TODO: Why is this a visitor? + page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }); + } + + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { + this.endOfPageValueCount = valuesRead + pageValueCount; + if (dataEncoding.usesDictionary()) { + this.dataColumn = null; + if (dictionary == null) { + throw new IOException( + "could not read page in col " + descriptor + + " as the dictionary was missing for encoding " + dataEncoding); + } + @SuppressWarnings("deprecation") + Encoding plainDict = Encoding.PLAIN_DICTIONARY; // var to allow warning suppression + if (dataEncoding != plainDict && dataEncoding != Encoding.RLE_DICTIONARY) { + throw new NotImplementedException("Unsupported encoding: " + dataEncoding); + } + this.dataColumn = new VectorizedRleValuesReader(); + this.useDictionary = true; + } else { + if (dataEncoding != Encoding.PLAIN) { + throw new NotImplementedException("Unsupported encoding: " + dataEncoding); + } + this.dataColumn = new VectorizedPlainValuesReader(); + this.useDictionary = false; + } + + try { + dataColumn.initFromPage(pageValueCount, bytes, offset); + } catch (IOException e) { + throw new IOException("could not read page in col " + descriptor, e); + } + } + + private void readPageV1(DataPageV1 page) throws IOException { + this.pageValueCount = page.getValueCount(); + ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); + ValuesReader dlReader; + + // Initialize the decoders. + if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) { + throw new NotImplementedException("Unsupported encoding: " + page.getDlEncoding()); + } + int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); + this.defColumn = new VectorizedRleValuesReader(bitWidth); + dlReader = this.defColumn; + this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + try { + byte[] bytes = page.getBytes().toByteArray(); + rlReader.initFromPage(pageValueCount, bytes, 0); + int next = rlReader.getNextOffset(); + dlReader.initFromPage(pageValueCount, bytes, next); + next = dlReader.getNextOffset(); + initDataReader(page.getValueEncoding(), bytes, next); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + + private void readPageV2(DataPageV2 page) throws IOException { + this.pageValueCount = page.getValueCount(); + this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), + page.getRepetitionLevels(), descriptor); + + int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); + this.defColumn = new VectorizedRleValuesReader(bitWidth); + this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); + this.defColumn.initFromBuffer( + this.pageValueCount, page.getDefinitionLevels().toByteArray()); + try { + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java new file mode 100644 index 0000000000000..51bdf0f0f2291 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.schema.Type; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * A specialized RecordReader that reads into InternalRows or ColumnarBatches directly using the + * Parquet column APIs. This is somewhat based on parquet-mr's ColumnReader. + * + * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * All of these can be handled efficiently and easily with codegen. + * + * This class can either return InternalRows or ColumnarBatches. With whole stage codegen + * enabled, this class returns ColumnarBatches which offers significant performance gains. + * TODO: make this always return ColumnarBatches. + */ +public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { + /** + * Batch of rows that we assemble and the current index we've returned. Every time this + * batch is used up (batchIdx == numBatched), we populated the batch. + */ + private int batchIdx = 0; + private int numBatched = 0; + + /** + * For each request column, the reader to read this column. This is NULL if this column + * is missing from the file, in which case we populate the attribute with NULL. + */ + private VectorizedColumnReader[] columnReaders; + + /** + * The number of rows that have been returned. + */ + private long rowsReturned; + + /** + * The number of rows that have been reading, including the current in flight row group. + */ + private long totalCountLoadedSoFar = 0; + + /** + * For each column, true if the column is missing in the file and we'll instead return NULLs. + */ + private boolean[] missingColumns; + + /** + * columnBatch object that is used for batch decoding. This is created on first use and triggers + * batched decoding. It is not valid to interleave calls to the batched interface with the row + * by row RecordReader APIs. + * This is only enabled with additional flags for development. This is still a work in progress + * and currently unsupported cases will fail with potentially difficult to diagnose errors. + * This should be only turned on for development to work on this feature. + * + * When this is set, the code will branch early on in the RecordReader APIs. There is no shared + * code between the path that uses the MR decoders and the vectorized ones. + * + * TODOs: + * - Implement v2 page formats (just make sure we create the correct decoders). + */ + private ColumnarBatch columnarBatch; + + /** + * If true, this class returns batches instead of rows. + */ + private boolean returnColumnarBatch; + + /** + * The default config on whether columnarBatch should be offheap. + */ + private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException, UnsupportedOperationException { + super.initialize(inputSplit, taskAttemptContext); + initializeInternal(); + } + + /** + * Utility API that will read all the data in path. This circumvents the need to create Hadoop + * objects to use this class. `columns` can contain the list of columns to project. + */ + @Override + public void initialize(String path, List columns) throws IOException, + UnsupportedOperationException { + super.initialize(path, columns); + initializeInternal(); + } + + @Override + public void close() throws IOException { + if (columnarBatch != null) { + columnarBatch.close(); + columnarBatch = null; + } + super.close(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + resultBatch(); + + if (returnColumnarBatch) return nextBatch(); + + if (batchIdx >= numBatched) { + if (!nextBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public Object getCurrentValue() throws IOException, InterruptedException { + if (returnColumnarBatch) return columnarBatch; + return columnarBatch.getRow(batchIdx - 1); + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + /** + * Returns the ColumnarBatch object that will be used for all rows returned by this reader. + * This object is reused. Calling this enables the vectorized reader. This should be called + * before any calls to nextKeyValue/nextBatch. + */ + + // Creates a columnar batch that includes the schema from the data files and the additional + // partition columns appended to the end of the batch. + // For example, if the data contains two columns, with 2 partition columns: + // Columns 0,1: data columns + // Column 2: partitionValues[0] + // Column 3: partitionValues[1] + public void initBatch(MemoryMode memMode, StructType partitionColumns, + InternalRow partitionValues) { + StructType batchSchema = new StructType(); + for (StructField f: sparkSchema.fields()) { + batchSchema = batchSchema.add(f); + } + if (partitionColumns != null) { + for (StructField f : partitionColumns.fields()) { + batchSchema = batchSchema.add(f); + } + } + + columnarBatch = ColumnarBatch.allocate(batchSchema, memMode); + if (partitionColumns != null) { + int partitionIdx = sparkSchema.fields().length; + for (int i = 0; i < partitionColumns.fields().length; i++) { + ColumnVectorUtils.populate(columnarBatch.column(i + partitionIdx), partitionValues, i); + columnarBatch.column(i + partitionIdx).setIsConstant(); + } + } + + // Initialize missing columns with nulls. + for (int i = 0; i < missingColumns.length; i++) { + if (missingColumns[i]) { + columnarBatch.column(i).putNulls(0, columnarBatch.capacity()); + columnarBatch.column(i).setIsConstant(); + } + } + } + + public void initBatch() { + initBatch(DEFAULT_MEMORY_MODE, null, null); + } + + public void initBatch(StructType partitionColumns, InternalRow partitionValues) { + initBatch(DEFAULT_MEMORY_MODE, partitionColumns, partitionValues); + } + + public ColumnarBatch resultBatch() { + if (columnarBatch == null) initBatch(); + return columnarBatch; + } + + /* + * Can be called before any rows are returned to enable returning columnar batches directly. + */ + public void enableReturningBatches() { + returnColumnarBatch = true; + } + + /** + * Advances to the next batch of rows. Returns false if there are no more. + */ + public boolean nextBatch() throws IOException { + columnarBatch.reset(); + if (rowsReturned >= totalRowCount) return false; + checkEndOfRowGroup(); + + int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); + for (int i = 0; i < columnReaders.length; ++i) { + if (columnReaders[i] == null) continue; + columnReaders[i].readBatch(num, columnarBatch.column(i)); + } + rowsReturned += num; + columnarBatch.setNumRows(num); + numBatched = num; + batchIdx = 0; + return true; + } + + private void initializeInternal() throws IOException, UnsupportedOperationException { + /** + * Check that the requested schema is supported. + */ + missingColumns = new boolean[requestedSchema.getFieldCount()]; + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new UnsupportedOperationException("Complex types not supported."); + } + + String[] colPath = requestedSchema.getPaths().get(i); + if (fileSchema.containsPath(colPath)) { + ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); + if (!fd.equals(requestedSchema.getColumns().get(i))) { + throw new UnsupportedOperationException("Schema evolution not supported."); + } + missingColumns[i] = false; + } else { + if (requestedSchema.getColumns().get(i).getMaxDefinitionLevel() == 0) { + // Column is missing in data but the required data is non-nullable. This file is invalid. + throw new IOException("Required column is missing in data file. Col: " + + Arrays.toString(colPath)); + } + missingColumns[i] = true; + } + } + } + + private void checkEndOfRowGroup() throws IOException { + if (rowsReturned != totalCountLoadedSoFar) return; + PageReadStore pages = reader.readNextRowGroup(); + if (pages == null) { + throw new IOException("expecting more rows but reached last block. Read " + + rowsReturned + " out of " + totalRowCount); + } + List columns = requestedSchema.getColumns(); + columnReaders = new VectorizedColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { + if (missingColumns[i]) continue; + columnReaders[i] = new VectorizedColumnReader(columns.get(i), + pages.getPageReader(columns.get(i))); + } + totalCountLoadedSoFar += pages.getRowCount(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java new file mode 100644 index 0000000000000..2672e0453b392 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; + +import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.unsafe.Platform; + +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; + +/** + * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. + */ +public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { + private byte[] buffer; + private int offset; + private int bitOffset; // Only used for booleans. + + public VectorizedPlainValuesReader() { + } + + @Override + public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException { + this.buffer = bytes; + this.offset = offset + Platform.BYTE_ARRAY_OFFSET; + } + + @Override + public void skip() { + throw new UnsupportedOperationException(); + } + + @Override + public final void readBooleans(int total, ColumnVector c, int rowId) { + // TODO: properly vectorize this + for (int i = 0; i < total; i++) { + c.putBoolean(rowId + i, readBoolean()); + } + } + + @Override + public final void readIntegers(int total, ColumnVector c, int rowId) { + c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 4 * total; + } + + @Override + public final void readLongs(int total, ColumnVector c, int rowId) { + c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 8 * total; + } + + @Override + public final void readFloats(int total, ColumnVector c, int rowId) { + c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 4 * total; + } + + @Override + public final void readDoubles(int total, ColumnVector c, int rowId) { + c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 8 * total; + } + + @Override + public final void readBytes(int total, ColumnVector c, int rowId) { + for (int i = 0; i < total; i++) { + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. + c.putByte(rowId + i, Platform.getByte(buffer, offset)); + offset += 4; + } + } + + @Override + public final boolean readBoolean() { + byte b = Platform.getByte(buffer, offset); + boolean v = (b & (1 << bitOffset)) != 0; + bitOffset += 1; + if (bitOffset == 8) { + bitOffset = 0; + offset++; + } + return v; + } + + @Override + public final int readInteger() { + int v = Platform.getInt(buffer, offset); + offset += 4; + return v; + } + + @Override + public final long readLong() { + long v = Platform.getLong(buffer, offset); + offset += 8; + return v; + } + + @Override + public final byte readByte() { + return (byte)readInteger(); + } + + @Override + public final float readFloat() { + float v = Platform.getFloat(buffer, offset); + offset += 4; + return v; + } + + @Override + public final double readDouble() { + double v = Platform.getDouble(buffer, offset); + offset += 8; + return v; + } + + @Override + public final void readBinary(int total, ColumnVector v, int rowId) { + for (int i = 0; i < total; i++) { + int len = readInteger(); + int start = offset; + offset += len; + v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + } + } + + @Override + public final Binary readBinary(int len) { + Binary result = Binary.fromByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len); + offset += len; + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java new file mode 100644 index 0000000000000..62157389013bb --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -0,0 +1,613 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import org.apache.parquet.Preconditions; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.bitpacking.BytePacker; +import org.apache.parquet.column.values.bitpacking.Packer; +import org.apache.parquet.io.ParquetDecodingException; +import org.apache.parquet.io.api.Binary; + +import org.apache.spark.sql.execution.vectorized.ColumnVector; + +/** + * A values reader for Parquet's run-length encoded data. This is based off of the version in + * parquet-mr with these changes: + * - Supports the vectorized interface. + * - Works on byte arrays(byte[]) instead of making byte streams. + * + * This encoding is used in multiple places: + * - Definition/Repetition levels + * - Dictionary ids. + */ +public final class VectorizedRleValuesReader extends ValuesReader + implements VectorizedValuesReader { + // Current decoding mode. The encoded data contains groups of either run length encoded data + // (RLE) or bit packed data. Each group contains a header that indicates which group it is and + // the number of values in the group. + // More details here: https://github.com/Parquet/parquet-format/blob/master/Encodings.md + private enum MODE { + RLE, + PACKED + } + + // Encoded data. + private byte[] in; + private int end; + private int offset; + + // bit/byte width of decoded data and utility to batch unpack them. + private int bitWidth; + private int bytesWidth; + private BytePacker packer; + + // Current decoding mode and values + private MODE mode; + private int currentCount; + private int currentValue; + + // Buffer of decoded values if the values are PACKED. + private int[] currentBuffer = new int[16]; + private int currentBufferIdx = 0; + + // If true, the bit width is fixed. This decoder is used in different places and this also + // controls if we need to read the bitwidth from the beginning of the data stream. + private final boolean fixedWidth; + + public VectorizedRleValuesReader() { + fixedWidth = false; + } + + public VectorizedRleValuesReader(int bitWidth) { + fixedWidth = true; + init(bitWidth); + } + + @Override + public void initFromPage(int valueCount, byte[] page, int start) { + this.offset = start; + this.in = page; + if (fixedWidth) { + if (bitWidth != 0) { + int length = readIntLittleEndian(); + this.end = this.offset + length; + } + } else { + this.end = page.length; + if (this.end != this.offset) init(page[this.offset++] & 255); + } + if (bitWidth == 0) { + // 0 bit width, treat this as an RLE run of valueCount number of 0's. + this.mode = MODE.RLE; + this.currentCount = valueCount; + this.currentValue = 0; + } else { + this.currentCount = 0; + } + } + + // Initialize the reader from a buffer. This is used for the V2 page encoding where the + // definition are in its own buffer. + public void initFromBuffer(int valueCount, byte[] data) { + this.offset = 0; + this.in = data; + this.end = data.length; + if (bitWidth == 0) { + // 0 bit width, treat this as an RLE run of valueCount number of 0's. + this.mode = MODE.RLE; + this.currentCount = valueCount; + this.currentValue = 0; + } else { + this.currentCount = 0; + } + } + + /** + * Initializes the internal state for decoding ints of `bitWidth`. + */ + private void init(int bitWidth) { + Preconditions.checkArgument(bitWidth >= 0 && bitWidth <= 32, "bitWidth must be >= 0 and <= 32"); + this.bitWidth = bitWidth; + this.bytesWidth = BytesUtils.paddedByteCountFromBits(bitWidth); + this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth); + } + + @Override + public int getNextOffset() { + return this.end; + } + + @Override + public boolean readBoolean() { + return this.readInteger() != 0; + } + + @Override + public void skip() { + this.readInteger(); + } + + @Override + public int readValueDictionaryId() { + return readInteger(); + } + + @Override + public int readInteger() { + if (this.currentCount == 0) { this.readNextGroup(); } + + this.currentCount--; + switch (mode) { + case RLE: + return this.currentValue; + case PACKED: + return this.currentBuffer[currentBufferIdx++]; + } + throw new RuntimeException("Unreachable"); + } + + /** + * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader + * reads the definition levels and then will read from `data` for the non-null values. + * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only + * necessary for readIntegers because we also use it to decode dictionaryIds and want to make + * sure it always has a value in range. + * + * This is a batched version of this logic: + * if (this.readInt() == level) { + * c[rowId] = data.readInteger(); + * } else { + * c[rowId] = null; + * } + */ + public void readIntegers(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readIntegers(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putInt(rowId + i, data.readInteger()); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + // TODO: can this code duplication be removed without a perf penalty? + public void readBooleans(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readBooleans(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putBoolean(rowId + i, data.readBoolean()); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readBytes(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readBytes(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putByte(rowId + i, data.readByte()); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readShorts(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + for (int i = 0; i < n; i++) { + c.putShort(rowId + i, (short)data.readInteger()); + } + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putShort(rowId + i, (short)data.readInteger()); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readLongs(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readLongs(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putLong(rowId + i, data.readLong()); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readFloats(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readFloats(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putFloat(rowId + i, data.readFloat()); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readDoubles(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readDoubles(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putDouble(rowId + i, data.readDouble()); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readBinarys(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readBinary(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + data.readBinary(1, c, rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + /** + * Decoding for dictionary ids. The IDs are populated into `values` and the nullability is + * populated into `nulls`. + */ + public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readIntegers(n, values, rowId); + } else { + nulls.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + values.putInt(rowId + i, data.readInteger()); + } else { + nulls.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + + // The RLE reader implements the vectorized decoding interface when used to decode dictionary + // IDs. This is different than the above APIs that decodes definitions levels along with values. + // Since this is only used to decode dictionary IDs, only decoding integers is supported. + @Override + public void readIntegers(int total, ColumnVector c, int rowId) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + c.putInts(rowId, n, currentValue); + break; + case PACKED: + c.putInts(rowId, n, currentBuffer, currentBufferIdx); + currentBufferIdx += n; + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + @Override + public byte readByte() { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBytes(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readLongs(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBinary(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBooleans(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readFloats(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readDoubles(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public Binary readBinary(int len) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + /** + * Reads the next varint encoded int. + */ + private int readUnsignedVarInt() { + int value = 0; + int shift = 0; + int b; + do { + b = in[offset++] & 255; + value |= (b & 0x7F) << shift; + shift += 7; + } while ((b & 0x80) != 0); + return value; + } + + /** + * Reads the next 4 byte little endian int. + */ + private int readIntLittleEndian() { + int ch4 = in[offset] & 255; + int ch3 = in[offset + 1] & 255; + int ch2 = in[offset + 2] & 255; + int ch1 = in[offset + 3] & 255; + offset += 4; + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); + } + + /** + * Reads the next byteWidth little endian int. + */ + private int readIntLittleEndianPaddedOnBitWidth() { + switch (bytesWidth) { + case 0: + return 0; + case 1: + return in[offset++] & 255; + case 2: { + int ch2 = in[offset] & 255; + int ch1 = in[offset + 1] & 255; + offset += 2; + return (ch1 << 8) + ch2; + } + case 3: { + int ch3 = in[offset] & 255; + int ch2 = in[offset + 1] & 255; + int ch1 = in[offset + 2] & 255; + offset += 3; + return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); + } + case 4: { + return readIntLittleEndian(); + } + } + throw new RuntimeException("Unreachable"); + } + + private int ceil8(int value) { + return (value + 7) / 8; + } + + /** + * Reads the next group. + */ + private void readNextGroup() { + int header = readUnsignedVarInt(); + this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; + switch (mode) { + case RLE: + this.currentCount = header >>> 1; + this.currentValue = readIntLittleEndianPaddedOnBitWidth(); + return; + case PACKED: + int numGroups = header >>> 1; + this.currentCount = numGroups * 8; + int bytesToRead = ceil8(this.currentCount * this.bitWidth); + + if (this.currentBuffer.length < this.currentCount) { + this.currentBuffer = new int[this.currentCount]; + } + currentBufferIdx = 0; + int valueIndex = 0; + for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { + this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); + valueIndex += 8; + } + offset += bytesToRead; + return; + default: + throw new ParquetDecodingException("not a valid mode " + this.mode); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java new file mode 100644 index 0000000000000..88418ca53fe1e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import org.apache.spark.sql.execution.vectorized.ColumnVector; + +import org.apache.parquet.io.api.Binary; + +/** + * Interface for value decoding that supports vectorized (aka batched) decoding. + * TODO: merge this into parquet-mr. + */ +public interface VectorizedValuesReader { + boolean readBoolean(); + byte readByte(); + int readInteger(); + long readLong(); + float readFloat(); + double readDouble(); + Binary readBinary(int len); + + /* + * Reads `total` values into `c` start at `c[rowId]` + */ + void readBooleans(int total, ColumnVector c, int rowId); + void readBytes(int total, ColumnVector c, int rowId); + void readIntegers(int total, ColumnVector c, int rowId); + void readLongs(int total, ColumnVector c, int rowId); + void readFloats(int total, ColumnVector c, int rowId); + void readDoubles(int total, ColumnVector c, int rowId); + void readBinary(int total, ColumnVector c, int rowId); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java new file mode 100644 index 0000000000000..69ce54390fead --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized; + +import java.util.Arrays; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.LongType; + +/** + * This is an illustrative implementation of an append-only single-key/single value aggregate hash + * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates + * (and fall back to the `BytesToBytesMap` if a given key isn't found). This can be potentially + * 'codegened' in TungstenAggregate to speed up aggregates w/ key. + * + * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the + * key-value pairs. The index lookups in the array rely on linear probing (with a small number of + * maximum tries) and use an inexpensive hash function which makes it really efficient for a + * majority of lookups. However, using linear probing and an inexpensive hash function also makes it + * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even + * for certain distribution of keys) and requires us to fall back on the latter for correctness. + */ +public class AggregateHashMap { + + private ColumnarBatch batch; + private int[] buckets; + private int numBuckets; + private int numRows = 0; + private int maxSteps = 3; + + private static int DEFAULT_CAPACITY = 1 << 16; + private static double DEFAULT_LOAD_FACTOR = 0.25; + private static int DEFAULT_MAX_STEPS = 3; + + public AggregateHashMap(StructType schema, int capacity, double loadFactor, int maxSteps) { + + // We currently only support single key-value pair that are both longs + assert (schema.size() == 2 && schema.fields()[0].dataType() == LongType && + schema.fields()[1].dataType() == LongType); + + // capacity should be a power of 2 + assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); + + this.maxSteps = maxSteps; + numBuckets = (int) (capacity / loadFactor); + batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP, capacity); + buckets = new int[numBuckets]; + Arrays.fill(buckets, -1); + } + + public AggregateHashMap(StructType schema) { + this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS); + } + + public ColumnarBatch.Row findOrInsert(long key) { + int idx = find(key); + if (idx != -1 && buckets[idx] == -1) { + batch.column(0).putLong(numRows, key); + batch.column(1).putLong(numRows, 0); + buckets[idx] = numRows++; + } + return batch.getRow(buckets[idx]); + } + + @VisibleForTesting + public int find(long key) { + long h = hash(key); + int step = 0; + int idx = (int) h & (numBuckets - 1); + while (step < maxSteps) { + // Return bucket index if it's either an empty slot or already contains the key + if (buckets[idx] == -1) { + return idx; + } else if (equals(idx, key)) { + return idx; + } + idx = (idx + 1) & (numBuckets - 1); + step++; + } + // Didn't find it + return -1; + } + + private long hash(long key) { + return key; + } + + private boolean equals(int idx, long key1) { + return batch.column(0).getLong(buckets[idx]) == key1; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java new file mode 100644 index 0000000000000..ff1f6680a7181 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -0,0 +1,984 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; +import java.math.BigInteger; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.io.api.Binary; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * This class represents a column of values and provides the main APIs to access the data + * values. It supports all the types and contains get/put APIs as well as their batched versions. + * The batched versions are preferable whenever possible. + * + * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these + * columns have child columns. All of the data is stored in the child columns and the parent column + * contains nullability, and in the case of Arrays, the lengths and offsets into the child column. + * Lengths and offsets are encoded identically to INTs. + * Maps are just a special case of a two field struct. + * Strings are handled as an Array of ByteType. + * + * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the + * responsibility of the caller to call reserve() to ensure there is enough room before adding + * elements. This means that the put() APIs do not check as in common cases (i.e. flat schemas), + * the lengths are known up front. + * + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in the current RowBatch. + * + * A ColumnVector should be considered immutable once originally created. In other words, it is not + * valid to call put APIs after reads until reset() is called. + * + * ColumnVectors are intended to be reused. + */ +public abstract class ColumnVector implements AutoCloseable { + /** + * Allocates a column to store elements of `type` on or off heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode) { + if (mode == MemoryMode.OFF_HEAP) { + return new OffHeapColumnVector(capacity, type); + } else { + return new OnHeapColumnVector(capacity, type); + } + } + + /** + * Holder object to return an array. This object is intended to be reused. Callers should + * copy the data out if it needs to be stored. + */ + public static final class Array extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + public final ColumnVector data; + public int length; + public int offset; + + // Populate if binary data is required for the Array. This is stored here as an optimization + // for string data. + public byte[] byteArray; + public int byteArrayOffset; + + // Reused staging buffer, used for loading from offheap. + protected byte[] tmpByteArray = new byte[1]; + + protected Array(ColumnVector data) { + this.data = data; + } + + @Override + public int numElements() { return length; } + + @Override + public ArrayData copy() { + throw new NotImplementedException(); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + + if (dt instanceof BooleanType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = data.getBoolean(offset + i); + } + } + } else if (dt instanceof ByteType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = data.getByte(offset + i); + } + } + } else if (dt instanceof ShortType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = data.getShort(offset + i); + } + } + } else if (dt instanceof IntegerType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = data.getInt(offset + i); + } + } + } else if (dt instanceof FloatType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = data.getFloat(offset + i); + } + } + } else if (dt instanceof DoubleType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = data.getDouble(offset + i); + } + } + } else if (dt instanceof LongType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = data.getLong(offset + i); + } + } + } else if (dt instanceof DecimalType) { + DecimalType decType = (DecimalType)dt; + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = getDecimal(i, decType.precision(), decType.scale()); + } + } + } else if (dt instanceof StringType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = getUTF8String(i).toString(); + } + } + } else if (dt instanceof CalendarIntervalType) { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = getInterval(i); + } + } + } else { + throw new NotImplementedException("Type " + dt); + } + return list; + } + + @Override + public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } + + @Override + public boolean getBoolean(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } + + @Override + public short getShort(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public int getInt(int ordinal) { return data.getInt(offset + ordinal); } + + @Override + public long getLong(int ordinal) { return data.getLong(offset + ordinal); } + + @Override + public float getFloat(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + int month = data.getChildColumn(0).getInt(offset + ordinal); + long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + return new CalendarInterval(month, microseconds); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ArrayData getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + throw new NotImplementedException(); + } + } + + /** + * Returns the data type of this column. + */ + public final DataType dataType() { return type; } + + /** + * Resets this column for writing. The currently stored values are no longer accessible. + */ + public void reset() { + if (isConstant) return; + + if (childColumns != null) { + for (ColumnVector c: childColumns) { + c.reset(); + } + } + numNulls = 0; + elementsAppended = 0; + if (anyNullsSet) { + putNotNulls(0, capacity); + anyNullsSet = false; + } + } + + /** + * Cleans up memory for this column. The column is not usable after this. + * TODO: this should probably have ref-counted semantics. + */ + public abstract void close(); + + /* + * Ensures that there is enough storage to store capcity elements. That is, the put() APIs + * must work for all rowIds < capcity. + */ + public abstract void reserve(int capacity); + + /** + * Returns the number of nulls in this column. + */ + public final int numNulls() { return numNulls; } + + /** + * Returns true if any of the nulls indicator are set for this column. This can be used + * as an optimization to prevent setting nulls. + */ + public final boolean anyNullsSet() { return anyNullsSet; } + + /** + * Returns the off heap ptr for the arrays backing the NULLs and values buffer. Only valid + * to call for off heap columns. + */ + public abstract long nullsNativeAddress(); + public abstract long valuesNativeAddress(); + + /** + * Sets the value at rowId to null/not null. + */ + public abstract void putNotNull(int rowId); + public abstract void putNull(int rowId); + + /** + * Sets the values from [rowId, rowId + count) to null/not null. + */ + public abstract void putNulls(int rowId, int count); + public abstract void putNotNulls(int rowId, int count); + + /** + * Returns whether the value at rowId is NULL. + */ + public abstract boolean isNullAt(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putBoolean(int rowId, boolean value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBooleans(int rowId, int count, boolean value); + + /** + * Returns the value for rowId. + */ + public abstract boolean getBoolean(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putByte(int rowId, byte value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBytes(int rowId, int count, byte value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract byte getByte(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putShort(int rowId, short value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putShorts(int rowId, int count, short value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract short getShort(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putInt(int rowId, int value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putInts(int rowId, int count, int value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putInts(int rowId, int count, int[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be 4-byte little endian ints. + */ + public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract int getInt(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putLong(int rowId, long value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putLongs(int rowId, int count, long value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be 8-byte little endian longs. + */ + public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract long getLong(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putFloat(int rowId, float value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putFloats(int rowId, int count, float value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * src should contain `count` doubles written as ieee format. + */ + public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted floats. + */ + public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract float getFloat(int rowId); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putDouble(int rowId, double value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putDoubles(int rowId, int count, double value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * src should contain `count` doubles written as ieee format. + */ + public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted doubles. + */ + public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract double getDouble(int rowId); + + /** + * Puts a byte array that already exists in this column. + */ + public abstract void putArray(int rowId, int offset, int length); + + /** + * Returns the length of the array at rowid. + */ + public abstract int getArrayLength(int rowId); + + /** + * Returns the offset of the array at rowid. + */ + public abstract int getArrayOffset(int rowId); + + /** + * Returns a utility object to get structs. + */ + public ColumnarBatch.Row getStruct(int rowId) { + resultStruct.rowId = rowId; + return resultStruct; + } + + /** + * Returns a utility object to get structs. + * provided to keep API compabilitity with InternalRow for code generation + */ + public ColumnarBatch.Row getStruct(int rowId, int size) { + resultStruct.rowId = rowId; + return resultStruct; + } + + /** + * Returns the array at rowid. + */ + public final Array getArray(int rowId) { + resultArray.length = getArrayLength(rowId); + resultArray.offset = getArrayOffset(rowId); + return resultArray; + } + + /** + * Loads the data into array.byteArray. + */ + public abstract void loadBytes(Array array); + + /** + * Sets the value at rowId to `value`. + */ + public abstract int putByteArray(int rowId, byte[] value, int offset, int count); + public final int putByteArray(int rowId, byte[] value) { + return putByteArray(rowId, value, 0, value.length); + } + + /** + * Returns the value for rowId. + */ + private Array getByteArray(int rowId) { + Array array = getArray(rowId); + array.data.loadBytes(array); + return array; + } + + /** + * Returns the value for rowId. + */ + public MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + /** + * Returns the decimal for rowId. + */ + public final Decimal getDecimal(int rowId, int precision, int scale) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.createUnsafe(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.createUnsafe(getLong(rowId), precision, scale); + } else { + // TODO: best perf? + byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + + public final void putDecimal(int rowId, Decimal value, int precision) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + putInt(rowId, value.toInt()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + putLong(rowId, value.toLong()); + } else { + BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue(); + putByteArray(rowId, bigInteger.toByteArray()); + } + } + + /** + * Returns the UTF8String for rowId. + */ + public final UTF8String getUTF8String(int rowId) { + if (dictionary == null) { + ColumnVector.Array a = getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return UTF8String.fromBytes(v.getBytes()); + } + } + + /** + * Returns the byte array for rowId. + */ + public final byte[] getBinary(int rowId) { + if (dictionary == null) { + ColumnVector.Array array = getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return v.getBytes(); + } + } + + /** + * Append APIs. These APIs all behave similarly and will append data to the current vector. It + * is not valid to mix the put and append APIs. The append APIs are slower and should only be + * used if the sizes are not known up front. + * In all these cases, the return value is the rowId for the first appended element. + */ + public final int appendNull() { + assert (!(dataType() instanceof StructType)); // Use appendStruct() + reserve(elementsAppended + 1); + putNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNotNull() { + reserve(elementsAppended + 1); + putNotNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendNotNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNotNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendBoolean(boolean v) { + reserve(elementsAppended + 1); + putBoolean(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBooleans(int count, boolean v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendByte(byte v) { + reserve(elementsAppended + 1); + putByte(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBytes(int count, byte v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBytes(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendBytes(int length, byte[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putBytes(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendShort(short v) { + reserve(elementsAppended + 1); + putShort(elementsAppended, v); + return elementsAppended++; + } + + public final int appendShorts(int count, short v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putShorts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendShorts(int length, short[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putShorts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendInt(int v) { + reserve(elementsAppended + 1); + putInt(elementsAppended, v); + return elementsAppended++; + } + + public final int appendInts(int count, int v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putInts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendInts(int length, int[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putInts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendLong(long v) { + reserve(elementsAppended + 1); + putLong(elementsAppended, v); + return elementsAppended++; + } + + public final int appendLongs(int count, long v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putLongs(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendLongs(int length, long[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putLongs(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendFloat(float v) { + reserve(elementsAppended + 1); + putFloat(elementsAppended, v); + return elementsAppended++; + } + + public final int appendFloats(int count, float v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putFloats(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendDouble(double v) { + reserve(elementsAppended + 1); + putDouble(elementsAppended, v); + return elementsAppended++; + } + + public final int appendDoubles(int count, double v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putDoubles(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendDoubles(int length, double[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putDoubles(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendByteArray(byte[] value, int offset, int length) { + int copiedOffset = arrayData().appendBytes(length, value, offset); + reserve(elementsAppended + 1); + putArray(elementsAppended, copiedOffset, length); + return elementsAppended++; + } + + public final int appendArray(int length) { + reserve(elementsAppended + 1); + putArray(elementsAppended, arrayData().elementsAppended, length); + return elementsAppended++; + } + + /** + * Appends a NULL struct. This *has* to be used for structs instead of appendNull() as this + * recursively appends a NULL to its children. + * We don't have this logic as the general appendNull implementation to optimize the more + * common non-struct case. + */ + public final int appendStruct(boolean isNull) { + if (isNull) { + appendNull(); + for (ColumnVector c: childColumns) { + if (c.type instanceof StructType) { + c.appendStruct(true); + } else { + c.appendNull(); + } + } + } else { + appendNotNull(); + } + return elementsAppended; + } + + /** + * Returns the data for the underlying array. + */ + public final ColumnVector arrayData() { return childColumns[0]; } + + /** + * Returns the ordinal's child data column. + */ + public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + + /** + * Returns the elements appended. + */ + public final int getElementsAppended() { return elementsAppended; } + + /** + * Returns true if this column is an array. + */ + public final boolean isArray() { return resultArray != null; } + + /** + * Marks this column as being constant. + */ + public final void setIsConstant() { isConstant = true; } + + /** + * Maximum number of rows that can be stored in this column. + */ + protected int capacity; + + /** + * Data type for this column. + */ + protected final DataType type; + + /** + * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. + */ + protected int numNulls; + + /** + * True if there is at least one NULL byte set. This is an optimization for the writer, to skip + * having to clear NULL bits. + */ + protected boolean anyNullsSet; + + /** + * True if this column's values are fixed. This means the column values never change, even + * across resets. + */ + protected boolean isConstant; + + /** + * Default size of each array length value. This grows as necessary. + */ + protected static final int DEFAULT_ARRAY_LENGTH = 4; + + /** + * Current write cursor (row index) when appending data. + */ + protected int elementsAppended; + + /** + * If this is a nested type (array or struct), the column for the child data. + */ + protected final ColumnVector[] childColumns; + + /** + * Reusable Array holder for getArray(). + */ + protected final Array resultArray; + + /** + * Reusable Struct holder for getStruct(). + */ + protected final ColumnarBatch.Row resultStruct; + + /** + * The Dictionary for this column. + * + * If it's not null, will be used to decode the value in getXXX(). + */ + protected Dictionary dictionary; + + /** + * Reusable column for ids of dictionary. + */ + protected ColumnVector dictionaryIds; + + /** + * Update the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Returns true if this column has a dictionary. + */ + public boolean hasDictionary() { return this.dictionary != null; } + + /** + * Reserve a integer column for ids of dictionary. + */ + public ColumnVector reserveDictionaryIds(int capacity) { + if (dictionaryIds == null) { + dictionaryIds = allocate(capacity, DataTypes.IntegerType, + this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + + /** + * Returns the underlying integer column for ids of dictionary. + */ + public ColumnVector getDictionaryIds() { + return dictionaryIds; + } + + /** + * Sets up the common state and also handles creating the child columns if this is a nested + * type. + */ + protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { + this.capacity = capacity; + this.type = type; + + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { + DataType childType; + int childCapacity = capacity; + if (type instanceof ArrayType) { + childType = ((ArrayType)type).elementType(); + } else { + childType = DataTypes.ByteType; + childCapacity *= DEFAULT_ARRAY_LENGTH; + } + this.childColumns = new ColumnVector[1]; + this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode); + this.resultArray = new Array(this.childColumns[0]); + this.resultStruct = null; + } else if (type instanceof StructType) { + StructType st = (StructType)type; + this.childColumns = new ColumnVector[st.fields().length]; + for (int i = 0; i < childColumns.length; ++i) { + this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + } + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else if (type instanceof CalendarIntervalType) { + // Two columns. Months as int. Microseconds as Long. + this.childColumns = new ColumnVector[2]; + this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); + this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else { + this.childColumns = null; + this.resultArray = null; + this.resultStruct = null; + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java new file mode 100644 index 0000000000000..2dc57dc50d691 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.sql.Date; +import java.util.Iterator; +import java.util.List; + +import org.apache.commons.lang.NotImplementedException; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly + * for debugging or other non-performance critical paths. + * These utilities are mostly used to convert ColumnVectors into other formats. + */ +public class ColumnVectorUtils { + /** + * Populates the entire `col` with `row[fieldIdx]` + */ + public static void populate(ColumnVector col, InternalRow row, int fieldIdx) { + int capacity = col.capacity; + DataType t = col.dataType(); + + if (row.isNullAt(fieldIdx)) { + col.putNulls(0, capacity); + } else { + if (t == DataTypes.BooleanType) { + col.putBooleans(0, capacity, row.getBoolean(fieldIdx)); + } else if (t == DataTypes.ByteType) { + col.putBytes(0, capacity, row.getByte(fieldIdx)); + } else if (t == DataTypes.ShortType) { + col.putShorts(0, capacity, row.getShort(fieldIdx)); + } else if (t == DataTypes.IntegerType) { + col.putInts(0, capacity, row.getInt(fieldIdx)); + } else if (t == DataTypes.LongType) { + col.putLongs(0, capacity, row.getLong(fieldIdx)); + } else if (t == DataTypes.FloatType) { + col.putFloats(0, capacity, row.getFloat(fieldIdx)); + } else if (t == DataTypes.DoubleType) { + col.putDoubles(0, capacity, row.getDouble(fieldIdx)); + } else if (t == DataTypes.StringType) { + UTF8String v = row.getUTF8String(fieldIdx); + byte[] bytes = v.getBytes(); + for (int i = 0; i < capacity; i++) { + col.putByteArray(i, bytes); + } + } else if (t instanceof DecimalType) { + DecimalType dt = (DecimalType)t; + Decimal d = row.getDecimal(fieldIdx, dt.precision(), dt.scale()); + if (dt.precision() <= Decimal.MAX_INT_DIGITS()) { + col.putInts(0, capacity, (int)d.toUnscaledLong()); + } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { + col.putLongs(0, capacity, d.toUnscaledLong()); + } else { + final BigInteger integer = d.toJavaBigDecimal().unscaledValue(); + byte[] bytes = integer.toByteArray(); + for (int i = 0; i < capacity; i++) { + col.putByteArray(i, bytes, 0, bytes.length); + } + } + } else if (t instanceof CalendarIntervalType) { + CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t); + col.getChildColumn(0).putInts(0, capacity, c.months); + col.getChildColumn(1).putLongs(0, capacity, c.microseconds); + } else if (t instanceof DateType) { + Date date = (Date)row.get(fieldIdx, t); + col.putInts(0, capacity, DateTimeUtils.fromJavaDate(date)); + } + } + } + + /** + * Returns the array data as the java primitive array. + * For example, an array of IntegerType will return an int[]. + * Throws exceptions for unhandled schemas. + */ + public static Object toPrimitiveJavaArray(ColumnVector.Array array) { + DataType dt = array.data.dataType(); + if (dt instanceof IntegerType) { + int[] result = new int[array.length]; + ColumnVector data = array.data; + for (int i = 0; i < result.length; i++) { + if (data.isNullAt(array.offset + i)) { + throw new RuntimeException("Cannot handle NULL values."); + } + result[i] = data.getInt(array.offset + i); + } + return result; + } else { + throw new NotImplementedException(); + } + } + + private static void appendValue(ColumnVector dst, DataType t, Object o) { + if (o == null) { + if (t instanceof CalendarIntervalType) { + dst.appendStruct(true); + } else { + dst.appendNull(); + } + } else { + if (t == DataTypes.BooleanType) { + dst.appendBoolean(((Boolean)o).booleanValue()); + } else if (t == DataTypes.ByteType) { + dst.appendByte(((Byte) o).byteValue()); + } else if (t == DataTypes.ShortType) { + dst.appendShort(((Short)o).shortValue()); + } else if (t == DataTypes.IntegerType) { + dst.appendInt(((Integer)o).intValue()); + } else if (t == DataTypes.LongType) { + dst.appendLong(((Long)o).longValue()); + } else if (t == DataTypes.FloatType) { + dst.appendFloat(((Float)o).floatValue()); + } else if (t == DataTypes.DoubleType) { + dst.appendDouble(((Double)o).doubleValue()); + } else if (t == DataTypes.StringType) { + byte[] b =((String)o).getBytes(StandardCharsets.UTF_8); + dst.appendByteArray(b, 0, b.length); + } else if (t instanceof DecimalType) { + DecimalType dt = (DecimalType)t; + Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale()); + if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { + dst.appendLong(d.toUnscaledLong()); + } else { + final BigInteger integer = d.toJavaBigDecimal().unscaledValue(); + byte[] bytes = integer.toByteArray(); + dst.appendByteArray(bytes, 0, bytes.length); + } + } else if (t instanceof CalendarIntervalType) { + CalendarInterval c = (CalendarInterval)o; + dst.appendStruct(false); + dst.getChildColumn(0).appendInt(c.months); + dst.getChildColumn(1).appendLong(c.microseconds); + } else if (t instanceof DateType) { + dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); + } else { + throw new NotImplementedException("Type " + t); + } + } + } + + private static void appendValue(ColumnVector dst, DataType t, Row src, int fieldIdx) { + if (t instanceof ArrayType) { + ArrayType at = (ArrayType)t; + if (src.isNullAt(fieldIdx)) { + dst.appendNull(); + } else { + List values = src.getList(fieldIdx); + dst.appendArray(values.size()); + for (Object o : values) { + appendValue(dst.arrayData(), at.elementType(), o); + } + } + } else if (t instanceof StructType) { + StructType st = (StructType)t; + if (src.isNullAt(fieldIdx)) { + dst.appendStruct(true); + } else { + dst.appendStruct(false); + Row c = src.getStruct(fieldIdx); + for (int i = 0; i < st.fields().length; i++) { + appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i); + } + } + } else { + appendValue(dst, t, src.get(fieldIdx)); + } + } + + /** + * Converts an iterator of rows into a single ColumnBatch. + */ + public static ColumnarBatch toBatch( + StructType schema, MemoryMode memMode, Iterator row) { + ColumnarBatch batch = ColumnarBatch.allocate(schema, memMode); + int n = 0; + while (row.hasNext()) { + Row r = row.next(); + for (int i = 0; i < schema.fields().length; i++) { + appendValue(batch.column(i), schema.fields()[i].dataType(), r, i); + } + n++; + } + batch.setNumRows(n); + return batch; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java new file mode 100644 index 0000000000000..8cece73faa4b9 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -0,0 +1,480 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; +import java.util.*; + +import org.apache.commons.lang.NotImplementedException; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; +import org.apache.spark.sql.catalyst.expressions.MutableRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * This class is the in memory representation of rows as they are streamed through operators. It + * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that + * each operator allocates one of these objects, the storage footprint on the task is negligible. + * + * The layout is a columnar with values encoded in their native format. Each RowBatch contains + * a horizontal partitioning of the data, split into columns. + * + * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. + * + * TODO: + * - There are many TODOs for the existing APIs. They should throw a not implemented exception. + * - Compaction: The batch and columns should be able to compact based on a selection vector. + */ +public final class ColumnarBatch { + private static final int DEFAULT_BATCH_SIZE = 4 * 1024; + private static MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; + + private final StructType schema; + private final int capacity; + private int numRows; + private final ColumnVector[] columns; + + // True if the row is filtered. + private final boolean[] filteredRows; + + // Column indices that cannot have null values. + private final Set nullFilteredColumns; + + // Total number of rows that have been filtered. + private int numRowsFiltered = 0; + + // Staging row returned from getRow. + final Row row; + + public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) { + return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode); + } + + public static ColumnarBatch allocate(StructType type) { + return new ColumnarBatch(type, DEFAULT_BATCH_SIZE, DEFAULT_MEMORY_MODE); + } + + public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) { + return new ColumnarBatch(schema, maxRows, memMode); + } + + /** + * Called to close all the columns in this batch. It is not valid to access the data after + * calling this. This must be called at the end to clean up memory allocations. + */ + public void close() { + for (ColumnVector c: columns) { + c.close(); + } + } + + /** + * Adapter class to interop with existing components that expect internal row. A lot of + * performance is lost with this translation. + */ + public static final class Row extends MutableRow { + protected int rowId; + private final ColumnarBatch parent; + private final int fixedLenRowSize; + private final ColumnVector[] columns; + + // Ctor used if this is a top level row. + private Row(ColumnarBatch parent) { + this.parent = parent; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols()); + this.columns = parent.columns; + } + + // Ctor used if this is a struct. + protected Row(ColumnVector[] columns) { + this.parent = null; + this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); + this.columns = columns; + } + + /** + * Marks this row as being filtered out. This means a subsequent iteration over the rows + * in this batch will not include this row. + */ + public void markFiltered() { + parent.markFiltered(rowId); + } + + public ColumnVector[] columns() { return columns; } + + @Override + public int numFields() { return columns.length; } + + @Override + /** + * Revisit this. This is expensive. This is currently only used in test paths. + */ + public InternalRow copy() { + GenericMutableRow row = new GenericMutableRow(columns.length); + for (int i = 0; i < numFields(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = columns[i].dataType(); + if (dt instanceof BooleanType) { + row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof FloatType) { + row.setFloat(i, getFloat(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else if (dt instanceof StringType) { + row.update(i, getUTF8String(i)); + } else if (dt instanceof BinaryType) { + row.update(i, getBinary(i)); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType)dt; + row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); + } else if (dt instanceof DateType) { + row.setInt(i, getInt(i)); + } else { + throw new RuntimeException("Not implemented. " + dt); + } + } + } + return row; + } + + @Override + public boolean anyNull() { + throw new NotImplementedException(); + } + + @Override + public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + + @Override + public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + + @Override + public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + + @Override + public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + + @Override + public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + + @Override + public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + + @Override + public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return columns[ordinal].getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return columns[ordinal].getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int ordinal) { + return columns[ordinal].getBinary(rowId); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + return columns[ordinal].getStruct(rowId); + } + + @Override + public ArrayData getArray(int ordinal) { + return columns[ordinal].getArray(rowId); + } + + @Override + public MapData getMap(int ordinal) { + throw new NotImplementedException(); + } + + @Override + public Object get(int ordinal, DataType dataType) { + throw new NotImplementedException(); + } + + @Override + public void update(int ordinal, Object value) { + if (value == null) { + setNullAt(ordinal); + } else { + DataType dt = columns[ordinal].dataType(); + if (dt instanceof BooleanType) { + setBoolean(ordinal, (boolean) value); + } else if (dt instanceof IntegerType) { + setInt(ordinal, (int) value); + } else if (dt instanceof ShortType) { + setShort(ordinal, (short) value); + } else if (dt instanceof LongType) { + setLong(ordinal, (long) value); + } else if (dt instanceof FloatType) { + setFloat(ordinal, (float) value); + } else if (dt instanceof DoubleType) { + setDouble(ordinal, (double) value); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType) dt; + setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), + t.precision()); + } else { + throw new NotImplementedException("Datatype not supported " + dt); + } + } + } + + @Override + public void setNullAt(int ordinal) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNull(rowId); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putBoolean(rowId, value); + } + + @Override + public void setByte(int ordinal, byte value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putByte(rowId, value); + } + + @Override + public void setShort(int ordinal, short value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putShort(rowId, value); + } + + @Override + public void setInt(int ordinal, int value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putInt(rowId, value); + } + + @Override + public void setLong(int ordinal, long value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putLong(rowId, value); + } + + @Override + public void setFloat(int ordinal, float value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putFloat(rowId, value); + } + + @Override + public void setDouble(int ordinal, double value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDouble(rowId, value); + } + + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDecimal(rowId, value, precision); + } + } + + /** + * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + */ + public Iterator rowIterator() { + final int maxRows = ColumnarBatch.this.numRows(); + final Row row = new Row(this); + return new Iterator() { + int rowId = 0; + + @Override + public boolean hasNext() { + while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { + ++rowId; + } + return rowId < maxRows; + } + + @Override + public Row next() { + while (rowId < maxRows && ColumnarBatch.this.filteredRows[rowId]) { + ++rowId; + } + if (rowId >= maxRows) { + throw new NoSuchElementException(); + } + row.rowId = rowId++; + return row; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Resets the batch for writing. + */ + public void reset() { + for (int i = 0; i < numCols(); ++i) { + columns[i].reset(); + } + if (this.numRowsFiltered > 0) { + Arrays.fill(filteredRows, false); + } + this.numRows = 0; + this.numRowsFiltered = 0; + } + + /** + * Sets the number of rows that are valid. Additionally, marks all rows as "filtered" if one or + * more of their attributes are part of a non-nullable column. + */ + public void setNumRows(int numRows) { + assert(numRows <= this.capacity); + this.numRows = numRows; + + for (int ordinal : nullFilteredColumns) { + if (columns[ordinal].numNulls != 0) { + for (int rowId = 0; rowId < numRows; rowId++) { + if (!filteredRows[rowId] && columns[ordinal].isNullAt(rowId)) { + filteredRows[rowId] = true; + ++numRowsFiltered; + } + } + } + } + } + + /** + * Returns the number of columns that make up this batch. + */ + public int numCols() { return columns.length; } + + /** + * Returns the number of rows for read, including filtered rows. + */ + public int numRows() { return numRows; } + + /** + * Returns the number of valid rows. + */ + public int numValidRows() { + assert(numRowsFiltered <= numRows); + return numRows - numRowsFiltered; + } + + /** + * Returns the max capacity (in number of rows) for this batch. + */ + public int capacity() { return capacity; } + + /** + * Returns the column at `ordinal`. + */ + public ColumnVector column(int ordinal) { return columns[ordinal]; } + + /** + * Sets (replaces) the column at `ordinal` with column. This can be used to do very efficient + * projections. + */ + public void setColumn(int ordinal, ColumnVector column) { + if (column instanceof OffHeapColumnVector) { + throw new NotImplementedException("Need to ref count columns."); + } + columns[ordinal] = column; + } + + /** + * Returns the row in this batch at `rowId`. Returned row is reused across calls. + */ + public ColumnarBatch.Row getRow(int rowId) { + assert(rowId >= 0); + assert(rowId < numRows); + row.rowId = rowId; + return row; + } + + /** + * Marks this row as being filtered out. This means a subsequent iteration over the rows + * in this batch will not include this row. + */ + public void markFiltered(int rowId) { + assert(!filteredRows[rowId]); + filteredRows[rowId] = true; + ++numRowsFiltered; + } + + /** + * Marks a given column as non-nullable. Any row that has a NULL value for the corresponding + * attribute is filtered out. + */ + public void filterNullsInColumn(int ordinal) { + nullFilteredColumns.add(ordinal); + } + + private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) { + this.schema = schema; + this.capacity = maxRows; + this.columns = new ColumnVector[schema.size()]; + this.nullFilteredColumns = new HashSet<>(); + this.filteredRows = new boolean[maxRows]; + + for (int i = 0; i < schema.fields().length; ++i) { + StructField field = schema.fields()[i]; + columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode); + } + + this.row = new Row(this); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java new file mode 100644 index 0000000000000..b1901411351a2 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.nio.ByteOrder; + +import org.apache.commons.lang.NotImplementedException; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; + +/** + * Column data backed using offheap memory. + */ +public final class OffHeapColumnVector extends ColumnVector { + // The data stored in these two allocations need to maintain binary compatible. We can + // directly pass this buffer to external components. + private long nulls; + private long data; + + // Set iff the type is array. + private long lengthData; + private long offsetData; + + protected OffHeapColumnVector(int capacity, DataType type) { + super(capacity, type, MemoryMode.OFF_HEAP); + if (!ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) { + throw new NotImplementedException("Only little endian is supported."); + } + nulls = 0; + data = 0; + lengthData = 0; + offsetData = 0; + + reserveInternal(capacity); + reset(); + } + + @Override + public long valuesNativeAddress() { + return data; + } + + @Override + public long nullsNativeAddress() { + return nulls; + } + + @Override + public void close() { + Platform.freeMemory(nulls); + Platform.freeMemory(data); + Platform.freeMemory(lengthData); + Platform.freeMemory(offsetData); + nulls = 0; + data = 0; + lengthData = 0; + offsetData = 0; + } + + // + // APIs dealing with nulls + // + + @Override + public void putNotNull(int rowId) { + Platform.putByte(null, nulls + rowId, (byte) 0); + } + + @Override + public void putNull(int rowId) { + Platform.putByte(null, nulls + rowId, (byte) 1); + ++numNulls; + anyNullsSet = true; + } + + @Override + public void putNulls(int rowId, int count) { + long offset = nulls + rowId; + for (int i = 0; i < count; ++i, ++offset) { + Platform.putByte(null, offset, (byte) 1); + } + anyNullsSet = true; + numNulls += count; + } + + @Override + public void putNotNulls(int rowId, int count) { + if (!anyNullsSet) return; + long offset = nulls + rowId; + for (int i = 0; i < count; ++i, ++offset) { + Platform.putByte(null, offset, (byte) 0); + } + } + + @Override + public boolean isNullAt(int rowId) { + return Platform.getByte(null, nulls + rowId) == 1; + } + + // + // APIs dealing with Booleans + // + + @Override + public void putBoolean(int rowId, boolean value) { + Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0)); + } + + @Override + public void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 1 : 0); + for (int i = 0; i < count; ++i) { + Platform.putByte(null, data + rowId + i, v); + } + } + + @Override + public boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + + // + // APIs dealing with Bytes + // + + @Override + public void putByte(int rowId, byte value) { + Platform.putByte(null, data + rowId, value); + + } + + @Override + public void putBytes(int rowId, int count, byte value) { + for (int i = 0; i < count; ++i) { + Platform.putByte(null, data + rowId + i, value); + } + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId, count); + } + + @Override + public byte getByte(int rowId) { + if (dictionary == null) { + return Platform.getByte(null, data + rowId); + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with shorts + // + + @Override + public void putShort(int rowId, short value) { + Platform.putShort(null, data + 2 * rowId, value); + } + + @Override + public void putShorts(int rowId, int count, short value) { + long offset = data + 2 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putShort(null, offset, value); + } + } + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) { + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, + null, data + 2 * rowId, count * 2); + } + + @Override + public short getShort(int rowId) { + if (dictionary == null) { + return Platform.getShort(null, data + 2 * rowId); + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with ints + // + + @Override + public void putInt(int rowId, int value) { + Platform.putInt(null, data + 4 * rowId, value); + } + + @Override + public void putInts(int rowId, int count, int value) { + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putInt(null, offset, value); + } + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, + null, data + 4 * rowId, count * 4); + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, + null, data + 4 * rowId, count * 4); + } + + @Override + public int getInt(int rowId) { + if (dictionary == null) { + return Platform.getInt(null, data + 4 * rowId); + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with Longs + // + + @Override + public void putLong(int rowId, long value) { + Platform.putLong(null, data + 8 * rowId, value); + } + + @Override + public void putLongs(int rowId, int count, long value) { + long offset = data + 8 * rowId; + for (int i = 0; i < count; ++i, offset += 8) { + Platform.putLong(null, offset, value); + } + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, + null, data + 8 * rowId, count * 8); + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, + null, data + 8 * rowId, count * 8); + } + + @Override + public long getLong(int rowId) { + if (dictionary == null) { + return Platform.getLong(null, data + 8 * rowId); + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with floats + // + + @Override + public void putFloat(int rowId, float value) { + Platform.putFloat(null, data + rowId * 4, value); + } + + @Override + public void putFloats(int rowId, int count, float value) { + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putFloat(null, offset, value); + } + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, + null, data + 4 * rowId, count * 4); + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + + @Override + public float getFloat(int rowId) { + if (dictionary == null) { + return Platform.getFloat(null, data + rowId * 4); + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } + } + + + // + // APIs dealing with doubles + // + + @Override + public void putDouble(int rowId, double value) { + Platform.putDouble(null, data + rowId * 8, value); + } + + @Override + public void putDoubles(int rowId, int count, double value) { + long offset = data + 8 * rowId; + for (int i = 0; i < count; ++i, offset += 8) { + Platform.putDouble(null, offset, value); + } + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, + null, data + 8 * rowId, count * 8); + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 8, count * 8); + } + + @Override + public double getDouble(int rowId) { + if (dictionary == null) { + return Platform.getDouble(null, data + rowId * 8); + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with Arrays. + // + @Override + public void putArray(int rowId, int offset, int length) { + assert(offset >= 0 && offset + length <= childColumns[0].capacity); + Platform.putInt(null, lengthData + 4 * rowId, length); + Platform.putInt(null, offsetData + 4 * rowId, offset); + } + + @Override + public int getArrayLength(int rowId) { + return Platform.getInt(null, lengthData + 4 * rowId); + } + + @Override + public int getArrayOffset(int rowId) { + return Platform.getInt(null, offsetData + 4 * rowId); + } + + // APIs dealing with ByteArrays + @Override + public int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + Platform.putInt(null, lengthData + 4 * rowId, length); + Platform.putInt(null, offsetData + 4 * rowId, result); + return result; + } + + @Override + public void loadBytes(ColumnVector.Array array) { + if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length]; + Platform.copyMemory( + null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length); + array.byteArray = array.tmpByteArray; + array.byteArrayOffset = 0; + } + + @Override + public void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); + } + + // Split out the slow path. + private void reserveInternal(int newCapacity) { + if (this.resultArray != null) { + this.lengthData = + Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); + this.offsetData = + Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + } else if (type instanceof ByteType || type instanceof BooleanType) { + this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + } else if (type instanceof ShortType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + } else if (type instanceof IntegerType || type instanceof FloatType || + type instanceof DateType || DecimalType.is32BitDecimalType(type)) { + this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + } else if (type instanceof LongType || type instanceof DoubleType || + DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + } else if (resultStruct != null) { + // Nothing to store. + } else { + throw new RuntimeException("Unhandled " + type); + } + this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); + Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + capacity = newCapacity; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java new file mode 100644 index 0000000000000..e97276800daa8 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.util.Arrays; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; + +/** + * A column backed by an in memory JVM array. This stores the NULLs as a byte per value + * and a java array for the values. + */ +public final class OnHeapColumnVector extends ColumnVector { + // The data stored in these arrays need to maintain binary compatible. We can + // directly pass this buffer to external components. + + // This is faster than a boolean array and we optimize this over memory footprint. + private byte[] nulls; + + // Array for each type. Only 1 is populated for any type. + private byte[] byteData; + private short[] shortData; + private int[] intData; + private long[] longData; + private float[] floatData; + private double[] doubleData; + + // Only set if type is Array. + private int[] arrayLengths; + private int[] arrayOffsets; + + protected OnHeapColumnVector(int capacity, DataType type) { + super(capacity, type, MemoryMode.ON_HEAP); + reserveInternal(capacity); + reset(); + } + + @Override + public long valuesNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + @Override + public long nullsNativeAddress() { + throw new RuntimeException("Cannot get native address for on heap column"); + } + + @Override + public void close() { + } + + // + // APIs dealing with nulls + // + + @Override + public void putNotNull(int rowId) { + nulls[rowId] = (byte)0; + } + + @Override + public void putNull(int rowId) { + nulls[rowId] = (byte)1; + ++numNulls; + anyNullsSet = true; + } + + @Override + public void putNulls(int rowId, int count) { + for (int i = 0; i < count; ++i) { + nulls[rowId + i] = (byte)1; + } + anyNullsSet = true; + numNulls += count; + } + + @Override + public void putNotNulls(int rowId, int count) { + if (!anyNullsSet) return; + for (int i = 0; i < count; ++i) { + nulls[rowId + i] = (byte)0; + } + } + + @Override + public boolean isNullAt(int rowId) { + return nulls[rowId] == 1; + } + + // + // APIs dealing with Booleans + // + + @Override + public void putBoolean(int rowId, boolean value) { + byteData[rowId] = (byte)((value) ? 1 : 0); + } + + @Override + public void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 1 : 0); + for (int i = 0; i < count; ++i) { + byteData[i + rowId] = v; + } + } + + @Override + public boolean getBoolean(int rowId) { + return byteData[rowId] == 1; + } + + // + + // + // APIs dealing with Bytes + // + + @Override + public void putByte(int rowId, byte value) { + byteData[rowId] = value; + } + + @Override + public void putBytes(int rowId, int count, byte value) { + for (int i = 0; i < count; ++i) { + byteData[i + rowId] = value; + } + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + System.arraycopy(src, srcIndex, byteData, rowId, count); + } + + @Override + public byte getByte(int rowId) { + if (dictionary == null) { + return byteData[rowId]; + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with Shorts + // + + @Override + public void putShort(int rowId, short value) { + shortData[rowId] = value; + } + + @Override + public void putShorts(int rowId, int count, short value) { + for (int i = 0; i < count; ++i) { + shortData[i + rowId] = value; + } + } + + @Override + public void putShorts(int rowId, int count, short[] src, int srcIndex) { + System.arraycopy(src, srcIndex, shortData, rowId, count); + } + + @Override + public short getShort(int rowId) { + if (dictionary == null) { + return shortData[rowId]; + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } + } + + + // + // APIs dealing with Ints + // + + @Override + public void putInt(int rowId, int value) { + intData[rowId] = value; + } + + @Override + public void putInts(int rowId, int count, int value) { + for (int i = 0; i < count; ++i) { + intData[i + rowId] = value; + } + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + System.arraycopy(src, srcIndex, intData, rowId, count); + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < count; ++i) { + intData[i + rowId] = Platform.getInt(src, srcOffset); + srcIndex += 4; + srcOffset += 4; + } + } + + @Override + public int getInt(int rowId) { + if (dictionary == null) { + return intData[rowId]; + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with Longs + // + + @Override + public void putLong(int rowId, long value) { + longData[rowId] = value; + } + + @Override + public void putLongs(int rowId, int count, long value) { + for (int i = 0; i < count; ++i) { + longData[i + rowId] = value; + } + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + System.arraycopy(src, srcIndex, longData, rowId, count); + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < count; ++i) { + longData[i + rowId] = Platform.getLong(src, srcOffset); + srcIndex += 8; + srcOffset += 8; + } + } + + @Override + public long getLong(int rowId) { + if (dictionary == null) { + return longData[rowId]; + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with floats + // + + @Override + public void putFloat(int rowId, float value) { floatData[rowId] = value; } + + @Override + public void putFloats(int rowId, int count, float value) { + Arrays.fill(floatData, rowId, rowId + count, value); + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + System.arraycopy(src, srcIndex, floatData, rowId, count); + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + floatData, Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + } + + @Override + public float getFloat(int rowId) { + if (dictionary == null) { + return floatData[rowId]; + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with doubles + // + + @Override + public void putDouble(int rowId, double value) { + doubleData[rowId] = value; + } + + @Override + public void putDoubles(int rowId, int count, double value) { + Arrays.fill(doubleData, rowId, rowId + count, value); + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + System.arraycopy(src, srcIndex, doubleData, rowId, count); + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + } + + @Override + public double getDouble(int rowId) { + if (dictionary == null) { + return doubleData[rowId]; + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } + } + + // + // APIs dealing with Arrays + // + + @Override + public int getArrayLength(int rowId) { + return arrayLengths[rowId]; + } + @Override + public int getArrayOffset(int rowId) { + return arrayOffsets[rowId]; + } + + @Override + public void putArray(int rowId, int offset, int length) { + arrayOffsets[rowId] = offset; + arrayLengths[rowId] = length; + } + + @Override + public void loadBytes(ColumnVector.Array array) { + array.byteArray = byteData; + array.byteArrayOffset = array.offset; + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int length) { + int result = arrayData().appendBytes(length, value, offset); + arrayOffsets[rowId] = result; + arrayLengths[rowId] = length; + return result; + } + + @Override + public void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); + } + + // Spilt this function out since it is the slow path. + private void reserveInternal(int newCapacity) { + if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { + int[] newLengths = new int[newCapacity]; + int[] newOffsets = new int[newCapacity]; + if (this.arrayLengths != null) { + System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + } + arrayLengths = newLengths; + arrayOffsets = newOffsets; + } else if (type instanceof BooleanType) { + if (byteData == null || byteData.length < newCapacity) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; + } + } else if (type instanceof ByteType) { + if (byteData == null || byteData.length < newCapacity) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; + } + } else if (type instanceof ShortType) { + if (shortData == null || shortData.length < newCapacity) { + short[] newData = new short[newCapacity]; + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + shortData = newData; + } + } else if (type instanceof IntegerType || type instanceof DateType || + DecimalType.is32BitDecimalType(type)) { + if (intData == null || intData.length < newCapacity) { + int[] newData = new int[newCapacity]; + if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + intData = newData; + } + } else if (type instanceof LongType || type instanceof TimestampType || + DecimalType.is64BitDecimalType(type)) { + if (longData == null || longData.length < newCapacity) { + long[] newData = new long[newCapacity]; + if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + longData = newData; + } + } else if (type instanceof FloatType) { + if (floatData == null || floatData.length < newCapacity) { + float[] newData = new float[newCapacity]; + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + floatData = newData; + } + } else if (type instanceof DoubleType) { + if (doubleData == null || doubleData.length < newCapacity) { + double[] newData = new double[newCapacity]; + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + doubleData = newData; + } + } else if (resultStruct != null) { + // Nothing to store. + } else { + throw new RuntimeException("Unhandled " + type); + } + + byte[] newNulls = new byte[newCapacity]; + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + nulls = newNulls; + + capacity = newCapacity; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java new file mode 100644 index 0000000000000..c7c6e3868f9bb --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions.java; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.execution.aggregate.TypedAverage; +import org.apache.spark.sql.execution.aggregate.TypedCount; +import org.apache.spark.sql.execution.aggregate.TypedSumDouble; +import org.apache.spark.sql.execution.aggregate.TypedSumLong; + +/** + * :: Experimental :: + * Type-safe functions available for {@link org.apache.spark.sql.Dataset} operations in Java. + * + * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}. + * + * @since 2.0.0 + */ +@Experimental +public class typed { + // Note: make sure to keep in sync with typed.scala + + /** + * Average aggregate function. + * + * @since 2.0.0 + */ + public static TypedColumn avg(MapFunction f) { + return new TypedAverage(f).toColumnJava(); + } + + /** + * Count aggregate function. + * + * @since 2.0.0 + */ + public static TypedColumn count(MapFunction f) { + return new TypedCount(f).toColumnJava(); + } + + /** + * Sum aggregate function for floating point (double) type. + * + * @since 2.0.0 + */ + public static TypedColumn sum(MapFunction f) { + return new TypedSumDouble(f).toColumnJava(); + } + + /** + * Sum aggregate function for integral (long, i.e. 64 bit integer) type. + * + * @since 2.0.0 + */ + public static TypedColumn sumLong(MapFunction f) { + return new TypedSumLong(f).toColumnJava(); + } +} diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory new file mode 100644 index 0000000000000..507100be90967 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1ca2044057e56..226d59d0eae88 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,4 @@ +org.apache.spark.sql.execution.datasources.csv.DefaultSource org.apache.spark.sql.execution.datasources.jdbc.DefaultSource org.apache.spark.sql.execution.datasources.json.DefaultSource org.apache.spark.sql.execution.datasources.parquet.DefaultSource diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css index ddd3a91dd8ef8..303f8ebb8814c 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -20,6 +20,12 @@ text-shadow: none; } +#plan-viz-graph svg g.cluster rect { + fill: #A0DFFF; + stroke: #3EC0FF; + stroke-width: 1px; +} + #plan-viz-graph svg g.node rect { fill: #C3EBFF; stroke: #3EC0FF; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c73f696962de5..bd96941da798d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.Logging -import org.apache.spark.sql.functions.lit +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.catalyst.parser.DataTypeParser +import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ - private[sql] object Column { def apply(colName: String): Column = new Column(colName) @@ -39,14 +40,58 @@ private[sql] object Column { } /** - * A [[Column]] where an [[Encoder]] has been given for the expected return type. + * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. + * To create a [[TypedColumn]], use the `as` function on a [[Column]]. + * + * @tparam T The input type expected for this expression. Can be `Any` if the expression is type + * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). + * @tparam U The output type of this column. + * * @since 1.6.0 */ -class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr) +class TypedColumn[-T, U]( + expr: Expression, + private[sql] val encoder: ExpressionEncoder[U]) + extends Column(expr) { + + /** + * Inserts the specific input type and schema into any expressions that are expected to operate + * on a decoded object. + */ + private[sql] def withInputType( + inputDeserializer: Expression, + inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { + val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes) + val newExpr = expr transform { + case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => + ta.copy(inputDeserializer = Some(unresolvedDeserializer)) + } + new TypedColumn[T, U](newExpr, encoder) + } +} /** * :: Experimental :: - * A column in a [[DataFrame]]. + * A column that will be computed based on the data in a [[DataFrame]]. + * + * A new column is constructed based on the input columns present in a dataframe: + * + * {{{ + * df("columnName") // On a specific DataFrame. + * col("columnName") // A generic column no yet associated with a DataFrame. + * col("columnName.field") // Extracting a struct field + * col("`a.column.with.dots`") // Escape `.` in column names. + * $"columnName" // Scala short hand for a named column. + * expr("a + 1") // A column that is constructed from a parsed SQL Expression. + * lit("abc") // A column that produces a literal (constant) value. + * }}} + * + * [[Column]] objects can be composed to form complex expressions: + * + * {{{ + * $"a" + 1 + * $"a" === $"b" + * }}} * * @groupname java_expr_ops Java-specific expression operators * @groupname expr_ops Expression operators @@ -60,17 +105,49 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => { + case _ if name.endsWith(".*") => val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) UnresolvedStar(Some(parts)) - } case _ => UnresolvedAttribute.quotedString(name) }) /** Creates a column based on the given expression. */ - implicit private def exprToColumn(newExpr: Expression): Column = new Column(newExpr) + private def withExpr(newExpr: Expression): Column = new Column(newExpr) + + /** + * Returns the expression for this column either with an existing or auto assigned name. + */ + private[sql] def named: NamedExpression = expr match { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) + + case u: UnresolvedExtractValue => UnresolvedAlias(u) + + case expr: NamedExpression => expr - override def toString: String = expr.prettyString + // Leave an unaliased generator with an empty list of names since the analyzer will generate + // the correct defaults after the nested expression's type has been resolved. + case explode: Explode => MultiAlias(explode, Nil) + + case jt: JsonTuple => MultiAlias(jt, Nil) + + case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql)) + + // If we have a top level Cast, there is a chance to give it a better alias, if there is a + // NamedExpression under this Cast. + case c: Cast => c.transformUp { + case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to)) + } match { + case ne: NamedExpression => ne + case other => Alias(expr, usePrettyExpression(expr).sql)() + } + + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + } + + override def toString: String = usePrettyExpression(expr).sql override def equals(that: Any): Boolean = that match { case that: Column => that.expr.equals(this.expr) @@ -85,21 +162,24 @@ class Column(protected[sql] val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) /** * Extracts a value or values from a complex type. * The following types of extraction are supported: - * - Given an Array, an integer ordinal can be used to retrieve a single value. - * - Given a Map, a key of the correct type can be used to retrieve an individual value. - * - Given a Struct, a string fieldName can be used to extract that field. - * - Given an Array of Structs, a string fieldName can be used to extract filed - * of every struct in that array, and return an Array of fields + * + * - Given an Array, an integer ordinal can be used to retrieve a single value. + * - Given a Map, a key of the correct type can be used to retrieve an individual value. + * - Given a Struct, a string fieldName can be used to extract that field. + * - Given an Array of Structs, a string fieldName can be used to extract filed + * of every struct in that array, and return an Array of fields * * @group expr_ops * @since 1.4.0 */ - def apply(extraction: Any): Column = UnresolvedExtractValue(expr, lit(extraction).expr) + def apply(extraction: Any): Column = withExpr { + UnresolvedExtractValue(expr, lit(extraction).expr) + } /** * Unary minus, i.e. negate the expression. @@ -115,7 +195,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def unary_- : Column = UnaryMinus(expr) + def unary_- : Column = withExpr { UnaryMinus(expr) } /** * Inversion of boolean expression, i.e. NOT. @@ -131,7 +211,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def unary_! : Column = Not(expr) + def unary_! : Column = withExpr { Not(expr) } /** * Equality test. @@ -147,7 +227,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def === (other: Any): Column = { + def === (other: Any): Column = withExpr { val right = lit(other).expr if (this.expr == right) { logWarning( @@ -173,6 +253,23 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def equalTo(other: Any): Column = this === other + /** + * Inequality test. + * {{{ + * // Scala: + * df.select( df("colA") =!= df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.filter( col("colA").notEqual(col("colB")) ); + * }}} + * + * @group expr_ops + * @since 2.0.0 + */ + def =!= (other: Any): Column = withExpr{ Not(EqualTo(expr, lit(other).expr)) } + /** * Inequality test. * {{{ @@ -187,8 +284,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops * @since 1.3.0 - */ - def !== (other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + */ + @deprecated("!== does not have the same precedence as ===, use =!= instead", "2.0.0") + def !== (other: Any): Column = this =!= other /** * Inequality test. @@ -205,7 +303,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def notEqual(other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + def notEqual(other: Any): Column = withExpr { Not(EqualTo(expr, lit(other).expr)) } /** * Greater than. @@ -221,7 +319,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def > (other: Any): Column = GreaterThan(expr, lit(other).expr) + def > (other: Any): Column = withExpr { GreaterThan(expr, lit(other).expr) } /** * Greater than. @@ -252,7 +350,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def < (other: Any): Column = LessThan(expr, lit(other).expr) + def < (other: Any): Column = withExpr { LessThan(expr, lit(other).expr) } /** * Less than. @@ -282,7 +380,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <= (other: Any): Column = LessThanOrEqual(expr, lit(other).expr) + def <= (other: Any): Column = withExpr { LessThanOrEqual(expr, lit(other).expr) } /** * Less than or equal to. @@ -312,7 +410,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def >= (other: Any): Column = GreaterThanOrEqual(expr, lit(other).expr) + def >= (other: Any): Column = withExpr { GreaterThanOrEqual(expr, lit(other).expr) } /** * Greater than or equal to an expression. @@ -335,7 +433,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <=> (other: Any): Column = EqualNullSafe(expr, lit(other).expr) + def <=> (other: Any): Column = withExpr { EqualNullSafe(expr, lit(other).expr) } /** * Equality test that is safe for null values. @@ -367,8 +465,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.4.0 */ def when(condition: Column, value: Any): Column = this.expr match { - case CaseWhen(branches: Seq[Expression]) => - CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) + case CaseWhen(branches, None) => + withExpr { CaseWhen(branches :+ (condition.expr, lit(value).expr)) } + case CaseWhen(branches, Some(_)) => + throw new IllegalArgumentException( + "when() cannot be applied once otherwise() is applied") case _ => throw new IllegalArgumentException( "when() can only be applied on a Column previously generated by when() function") @@ -396,13 +497,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.4.0 */ def otherwise(value: Any): Column = this.expr match { - case CaseWhen(branches: Seq[Expression]) => - if (branches.size % 2 == 0) { - CaseWhen(branches :+ lit(value).expr) - } else { - throw new IllegalArgumentException( - "otherwise() can only be applied once on a Column previously generated by when()") - } + case CaseWhen(branches, None) => + withExpr { CaseWhen(branches, Option(lit(value).expr)) } + case CaseWhen(branches, Some(_)) => + throw new IllegalArgumentException( + "otherwise() can only be applied once on a Column previously generated by when()") case _ => throw new IllegalArgumentException( "otherwise() can only be applied on a Column previously generated by when()") @@ -424,7 +523,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.5.0 */ - def isNaN: Column = IsNaN(expr) + def isNaN: Column = withExpr { IsNaN(expr) } /** * True if the current expression is null. @@ -432,7 +531,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def isNull: Column = IsNull(expr) + def isNull: Column = withExpr { IsNull(expr) } /** * True if the current expression is NOT null. @@ -440,7 +539,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def isNotNull: Column = IsNotNull(expr) + def isNotNull: Column = withExpr { IsNotNull(expr) } /** * Boolean OR. @@ -455,7 +554,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def || (other: Any): Column = Or(expr, lit(other).expr) + def || (other: Any): Column = withExpr { Or(expr, lit(other).expr) } /** * Boolean OR. @@ -485,7 +584,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def && (other: Any): Column = And(expr, lit(other).expr) + def && (other: Any): Column = withExpr { And(expr, lit(other).expr) } /** * Boolean AND. @@ -515,7 +614,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def + (other: Any): Column = Add(expr, lit(other).expr) + def + (other: Any): Column = withExpr { Add(expr, lit(other).expr) } /** * Sum of this expression and another expression. @@ -545,7 +644,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def - (other: Any): Column = Subtract(expr, lit(other).expr) + def - (other: Any): Column = withExpr { Subtract(expr, lit(other).expr) } /** * Subtraction. Subtract the other expression from this expression. @@ -575,7 +674,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def * (other: Any): Column = Multiply(expr, lit(other).expr) + def * (other: Any): Column = withExpr { Multiply(expr, lit(other).expr) } /** * Multiplication of this expression and another expression. @@ -605,7 +704,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def / (other: Any): Column = Divide(expr, lit(other).expr) + def / (other: Any): Column = withExpr { Divide(expr, lit(other).expr) } /** * Division this expression by another expression. @@ -628,7 +727,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def % (other: Any): Column = Remainder(expr, lit(other).expr) + def % (other: Any): Column = withExpr { Remainder(expr, lit(other).expr) } /** * Modulo (a.k.a. remainder) expression. @@ -638,17 +737,6 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def mod(other: Any): Column = this % other - /** - * A boolean expression that is evaluated to true if the value of this expression is contained - * by the evaluated values of the arguments. - * - * @group expr_ops - * @since 1.3.0 - */ - @deprecated("use isin", "1.5.0") - @scala.annotation.varargs - def in(list: Any*): Column = isin(list : _*) - /** * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. @@ -657,7 +745,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = In(expr, list.map(lit(_).expr)) + def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } /** * SQL like expression. @@ -665,7 +753,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def like(literal: String): Column = Like(expr, lit(literal).expr) + def like(literal: String): Column = withExpr { Like(expr, lit(literal).expr) } /** * SQL RLIKE expression (LIKE with Regex). @@ -673,7 +761,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def rlike(literal: String): Column = RLike(expr, lit(literal).expr) + def rlike(literal: String): Column = withExpr { RLike(expr, lit(literal).expr) } /** * An expression that gets an item at position `ordinal` out of an array, @@ -682,7 +770,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key)) + def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) } /** * An expression that gets a field by name in a [[StructType]]. @@ -690,7 +778,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName)) + def getField(fieldName: String): Column = withExpr { + UnresolvedExtractValue(expr, Literal(fieldName)) + } /** * An expression that returns a substring. @@ -700,7 +790,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def substr(startPos: Column, len: Column): Column = Substring(expr, startPos.expr, len.expr) + def substr(startPos: Column, len: Column): Column = withExpr { + Substring(expr, startPos.expr, len.expr) + } /** * An expression that returns a substring. @@ -710,7 +802,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def substr(startPos: Int, len: Int): Column = Substring(expr, lit(startPos).expr, lit(len).expr) + def substr(startPos: Int, len: Int): Column = withExpr { + Substring(expr, lit(startPos).expr, lit(len).expr) + } /** * Contains the other element. @@ -718,7 +812,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def contains(other: Any): Column = Contains(expr, lit(other).expr) + def contains(other: Any): Column = withExpr { Contains(expr, lit(other).expr) } /** * String starts with. @@ -726,7 +820,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def startsWith(other: Column): Column = StartsWith(expr, lit(other).expr) + def startsWith(other: Column): Column = withExpr { StartsWith(expr, lit(other).expr) } /** * String starts with another string literal. @@ -742,7 +836,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def endsWith(other: Column): Column = EndsWith(expr, lit(other).expr) + def endsWith(other: Column): Column = withExpr { EndsWith(expr, lit(other).expr) } /** * String ends with another string literal. @@ -762,7 +856,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def alias(alias: String): Column = as(alias) + def alias(alias: String): Column = name(alias) /** * Gives the column an alias. @@ -777,10 +871,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = expr match { - case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) - case other => Alias(other, alias)() - } + def as(alias: String): Column = name(alias) /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -792,7 +883,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases) + def as(aliases: Seq[String]): Column = withExpr { MultiAlias(expr, aliases) } /** * Assigns the given aliases to the results of a table generating function. @@ -804,7 +895,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Array[String]): Column = MultiAlias(expr, aliases) + def as(aliases: Array[String]): Column = withExpr { MultiAlias(expr, aliases) } /** * Gives the column an alias. @@ -819,9 +910,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = expr match { - case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) - case other => Alias(other, alias.name)() + def as(alias: Symbol): Column = withExpr { + expr match { + case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias.name)() + } } /** @@ -834,10 +927,30 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String, metadata: Metadata): Column = { + def as(alias: String, metadata: Metadata): Column = withExpr { Alias(expr, alias)(explicitMetadata = Some(metadata)) } + /** + * Gives the column a name (alias). + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".name("colB")) + * }}} + * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * + * @group expr_ops + * @since 2.0.0 + */ + def name(alias: String): Column = withExpr { + expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } + } + /** * Casts the column to a different data type. * {{{ @@ -852,11 +965,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = expr match { - // keeps the name of expression if possible when do cast. - case ne: NamedExpression => UnresolvedAlias(Cast(expr, to)) - case _ => Cast(expr, to) - } + def cast(to: DataType): Column = withExpr { Cast(expr, to) } /** * Casts the column to a different data type, using the canonical string representation @@ -885,7 +994,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def desc: Column = SortOrder(expr, Descending) + def desc: Column = withExpr { SortOrder(expr, Descending) } /** * Returns an ordering used in sorting. @@ -900,7 +1009,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def asc: Column = SortOrder(expr, Ascending) + def asc: Column = withExpr { SortOrder(expr, Ascending) } /** * Prints the expression to the console for debugging purpose. @@ -913,7 +1022,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { if (extended) { println(expr) } else { - println(expr.prettyString) + println(expr.sql) } // scalastyle:on println } @@ -927,7 +1036,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseOR(other: Any): Column = BitwiseOr(expr, lit(other).expr) + def bitwiseOR(other: Any): Column = withExpr { BitwiseOr(expr, lit(other).expr) } /** * Compute bitwise AND of this expression with another expression. @@ -938,7 +1047,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseAND(other: Any): Column = BitwiseAnd(expr, lit(other).expr) + def bitwiseAND(other: Any): Column = withExpr { BitwiseAnd(expr, lit(other).expr) } /** * Compute bitwise XOR of this expression with another expression. @@ -949,7 +1058,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) + def bitwiseXOR(other: Any): Column = withExpr { BitwiseXor(expr, lit(other).expr) } /** * Define a windowing column. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala new file mode 100644 index 0000000000000..d9973b092dc11 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * A handle to a query that is executing continuously in the background as new data arrives. + * All these methods are thread-safe. + * @since 2.0.0 + */ +@Experimental +trait ContinuousQuery { + + /** + * Returns the name of the query. + * @since 2.0.0 + */ + def name: String + + /** + * Returns the SQLContext associated with `this` query + * @since 2.0.0 + */ + def sqlContext: SQLContext + + /** + * Whether the query is currently active or not + * @since 2.0.0 + */ + def isActive: Boolean + + /** + * Returns the [[ContinuousQueryException]] if the query was terminated by an exception. + * @since 2.0.0 + */ + def exception: Option[ContinuousQueryException] + + /** + * Returns current status of all the sources. + * @since 2.0.0 + */ + def sourceStatuses: Array[SourceStatus] + + /** Returns current status of the sink. */ + def sinkStatus: SinkStatus + + /** + * Waits for the termination of `this` query, either by `query.stop()` or by an exception. + * If the query has terminated with an exception, then the exception will be thrown. + * + * If the query has terminated, then all subsequent calls to this method will either return + * immediately (if the query was terminated by `stop()`), or throw the exception + * immediately (if the query has terminated with exception). + * + * @throws ContinuousQueryException, if `this` query has terminated with an exception. + * + * @since 2.0.0 + */ + def awaitTermination(): Unit + + /** + * Waits for the termination of `this` query, either by `query.stop()` or by an exception. + * If the query has terminated with an exception, then the exception will be throw. + * Otherwise, it returns whether the query has terminated or not within the `timeoutMs` + * milliseconds. + * + * If the query has terminated, then all subsequent calls to this method will either return + * `true` immediately (if the query was terminated by `stop()`), or throw the exception + * immediately (if the query has terminated with exception). + * + * @throws ContinuousQueryException, if `this` query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitTermination(timeoutMs: Long): Boolean + + /** + * Blocks until all available data in the source has been processed an committed to the sink. + * This method is intended for testing. Note that in the case of continually arriving data, this + * method may block forever. Additionally, this method is only guaranteed to block until data that + * has been synchronously appended data to a [[org.apache.spark.sql.execution.streaming.Source]] + * prior to invocation. (i.e. `getOffset` must immediately reflect the addition). + */ + def processAllAvailable(): Unit + + /** + * Stops the execution of this query if it is running. This method blocks until the threads + * performing execution has stopped. + * @since 2.0.0 + */ + def stop(): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala new file mode 100644 index 0000000000000..fec38629d914e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution} + +/** + * :: Experimental :: + * Exception that stopped a [[ContinuousQuery]]. + * @param query Query that caused the exception + * @param message Message of this exception + * @param cause Internal cause of this exception + * @param startOffset Starting offset (if known) of the range of data in which exception occurred + * @param endOffset Ending offset (if known) of the range of data in exception occurred + * @since 2.0.0 + */ +@Experimental +class ContinuousQueryException private[sql]( + @transient val query: ContinuousQuery, + val message: String, + val cause: Throwable, + val startOffset: Option[Offset] = None, + val endOffset: Option[Offset] = None) + extends Exception(message, cause) { + + /** Time when the exception occurred */ + val time: Long = System.currentTimeMillis + + override def toString(): String = { + val causeStr = + s"${cause.getMessage} ${cause.getStackTrace.take(10).mkString("", "\n|\t", "\n")}" + s""" + |$causeStr + | + |${query.asInstanceOf[StreamExecution].toDebugString} + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala new file mode 100644 index 0000000000000..1343e81569cbd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef +import org.apache.spark.sql.util.ContinuousQueryListener + +/** + * :: Experimental :: + * A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active + * on a [[SQLContext]]. + * + * @since 2.0.0 + */ +@Experimental +class ContinuousQueryManager(sqlContext: SQLContext) { + + private[sql] val stateStoreCoordinator = + StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env) + private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus) + private val activeQueries = new mutable.HashMap[String, ContinuousQuery] + private val activeQueriesLock = new Object + private val awaitTerminationLock = new Object + + private var lastTerminatedQuery: ContinuousQuery = null + + /** + * Returns a list of active queries associated with this SQLContext + * + * @since 2.0.0 + */ + def active: Array[ContinuousQuery] = activeQueriesLock.synchronized { + activeQueries.values.toArray + } + + /** + * Returns an active query from this SQLContext or throws exception if bad name + * + * @since 2.0.0 + */ + def get(name: String): ContinuousQuery = activeQueriesLock.synchronized { + activeQueries.getOrElse(name, + throw new IllegalArgumentException(s"There is no active query with name $name")) + } + + /** + * Wait until any of the queries on the associated SQLContext has terminated since the + * creation of the context, or since `resetTerminated()` was called. If any query was terminated + * with an exception, then the exception will be thrown. + * + * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either + * return immediately (if the query was terminated by `query.stop()`), + * or throw the exception immediately (if the query was terminated with exception). Use + * `resetTerminated()` to clear past terminations and wait for new terminations. + * + * In the case where multiple queries have terminated since `resetTermination()` was called, + * if any query has terminated with exception, then `awaitAnyTermination()` will + * throw any of the exception. For correctly documenting exceptions across multiple queries, + * users need to stop all of them after any of them terminates with exception, and then check the + * `query.exception()` for each query. + * + * @throws ContinuousQueryException, if any query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitAnyTermination(): Unit = { + awaitTerminationLock.synchronized { + while (lastTerminatedQuery == null) { + awaitTerminationLock.wait(10) + } + if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) { + throw lastTerminatedQuery.exception.get + } + } + } + + /** + * Wait until any of the queries on the associated SQLContext has terminated since the + * creation of the context, or since `resetTerminated()` was called. Returns whether any query + * has terminated or not (multiple may have terminated). If any query has terminated with an + * exception, then the exception will be thrown. + * + * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either + * return `true` immediately (if the query was terminated by `query.stop()`), + * or throw the exception immediately (if the query was terminated with exception). Use + * `resetTerminated()` to clear past terminations and wait for new terminations. + * + * In the case where multiple queries have terminated since `resetTermination()` was called, + * if any query has terminated with exception, then `awaitAnyTermination()` will + * throw any of the exception. For correctly documenting exceptions across multiple queries, + * users need to stop all of them after any of them terminates with exception, and then check the + * `query.exception()` for each query. + * + * @throws ContinuousQueryException, if any query has terminated with an exception + * + * @since 2.0.0 + */ + def awaitAnyTermination(timeoutMs: Long): Boolean = { + + val startTime = System.currentTimeMillis + def isTimedout = System.currentTimeMillis - startTime >= timeoutMs + + awaitTerminationLock.synchronized { + while (!isTimedout && lastTerminatedQuery == null) { + awaitTerminationLock.wait(10) + } + if (lastTerminatedQuery != null && lastTerminatedQuery.exception.nonEmpty) { + throw lastTerminatedQuery.exception.get + } + lastTerminatedQuery != null + } + } + + /** + * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to + * wait for new terminations. + * + * @since 2.0.0 + */ + def resetTerminated(): Unit = { + awaitTerminationLock.synchronized { + lastTerminatedQuery = null + } + } + + /** + * Register a [[ContinuousQueryListener]] to receive up-calls for life cycle events of + * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]]. + * + * @since 2.0.0 + */ + def addListener(listener: ContinuousQueryListener): Unit = { + listenerBus.addListener(listener) + } + + /** + * Deregister a [[ContinuousQueryListener]]. + * + * @since 2.0.0 + */ + def removeListener(listener: ContinuousQueryListener): Unit = { + listenerBus.removeListener(listener) + } + + /** Post a listener event */ + private[sql] def postListenerEvent(event: ContinuousQueryListener.Event): Unit = { + listenerBus.post(event) + } + + /** Start a query */ + private[sql] def startQuery( + name: String, + checkpointLocation: String, + df: DataFrame, + sink: Sink, + trigger: Trigger = ProcessingTime(0)): ContinuousQuery = { + activeQueriesLock.synchronized { + if (activeQueries.contains(name)) { + throw new IllegalArgumentException( + s"Cannot start query with name $name as a query with that name is already active") + } + var nextSourceId = 0L + val logicalPlan = df.logicalPlan.transform { + case StreamingRelation(dataSource, _, output) => + // Materialize source to avoid creating it in every batch + val metadataPath = s"$checkpointLocation/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. + StreamingExecutionRelation(source, output) + } + val query = new StreamExecution( + sqlContext, + name, + checkpointLocation, + logicalPlan, + sink, + trigger) + query.start() + activeQueries.put(name, query) + query + } + } + + /** Notify (by the ContinuousQuery) that the query has been terminated */ + private[sql] def notifyQueryTermination(terminatedQuery: ContinuousQuery): Unit = { + activeQueriesLock.synchronized { + activeQueries -= terminatedQuery.name + } + awaitTerminationLock.synchronized { + if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) { + lastTerminatedQuery = terminatedQuery + } + awaitTerminationLock.notifyAll() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala deleted file mode 100644 index 6336dee7be6a3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ /dev/null @@ -1,2097 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import java.io.CharArrayWriter -import java.util.Properties - -import scala.language.implicitConversions -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal - -import com.fasterxml.jackson.core.JsonFactory -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.encoders.Encoder -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, SQLExecution} -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.execution.datasources.json.JacksonGenerator -import org.apache.spark.sql.sources.HadoopFsRelation -import org.apache.spark.sql.types._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils - - -private[sql] object DataFrame { - def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { - new DataFrame(sqlContext, logicalPlan) - } -} - -/** - * :: Experimental :: - * A distributed collection of data organized into named columns. - * - * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates - * a [[DataFrame]] by pointing Spark SQL to a Parquet data set. - * {{{ - * val people = sqlContext.read.parquet("...") // in Scala - * DataFrame people = sqlContext.read().parquet("...") // in Java - * }}} - * - * Once created, it can be manipulated using the various domain-specific-language (DSL) functions - * defined in: [[DataFrame]] (this class), [[Column]], and [[functions]]. - * - * To select a column from the data frame, use `apply` method in Scala and `col` in Java. - * {{{ - * val ageCol = people("age") // in Scala - * Column ageCol = people.col("age") // in Java - * }}} - * - * Note that the [[Column]] type can also be manipulated through its various functions. - * {{{ - * // The following creates a new column that increases everybody's age by 10. - * people("age") + 10 // in Scala - * people.col("age").plus(10); // in Java - * }}} - * - * A more concrete example in Scala: - * {{{ - * // To create DataFrame using SQLContext - * val people = sqlContext.read.parquet("...") - * val department = sqlContext.read.parquet("...") - * - * people.filter("age > 30") - * .join(department, people("deptId") === department("id")) - * .groupBy(department("name"), "gender") - * .agg(avg(people("salary")), max(people("age"))) - * }}} - * - * and in Java: - * {{{ - * // To create DataFrame using SQLContext - * DataFrame people = sqlContext.read().parquet("..."); - * DataFrame department = sqlContext.read().parquet("..."); - * - * people.filter("age".gt(30)) - * .join(department, people.col("deptId").equalTo(department("id"))) - * .groupBy(department.col("name"), "gender") - * .agg(avg(people.col("salary")), max(people.col("age"))); - * }}} - * - * @groupname basic Basic DataFrame functions - * @groupname dfops Language Integrated Queries - * @groupname rdd RDD Operations - * @groupname output Output Operations - * @groupname action Actions - * @since 1.3.0 - */ -// TODO: Improve documentation. -@Experimental -class DataFrame private[sql]( - @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: QueryExecution) extends Serializable { - - // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure - // you wrap it with `withNewExecutionId` if this actions doesn't call other action. - - /** - * A constructor that automatically analyzes the logical plan. - * - * This reports error eagerly as the [[DataFrame]] is constructed, unless - * [[SQLConf.dataFrameEagerAnalysis]] is turned off. - */ - def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = { - this(sqlContext, { - val qe = sqlContext.executePlan(logicalPlan) - if (sqlContext.conf.dataFrameEagerAnalysis) { - qe.assertAnalyzed() // This should force analysis and throw errors if there are any - } - qe - }) - } - - @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. - case _: Command | - _: InsertIntoTable | - _: CreateTableUsingAsSelect => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - queryExecution.analyzed - } - - protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse { - throw new AnalysisException( - s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") - } - } - - protected[sql] def numericColumns: Seq[Expression] = { - schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolveQuoted(n.name, sqlContext.analyzer.resolver).get - } - } - - /** - * Compose the string representing rows for output - * @param _numRows Number of rows to show - * @param truncate Whether truncate long strings and align cells right - */ - private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { - val numRows = _numRows.max(0) - val sb = new StringBuilder - val takeResult = take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) - val numCols = schema.fieldNames.length - - // For array values, replace Seq and Array with square brackets - // For cells that are beyond 20 characters, replace it with the first 17 and "..." - val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => - row.toSeq.map { cell => - val str = cell match { - case null => "null" - case array: Array[_] => array.mkString("[", ", ", "]") - case seq: Seq[_] => seq.mkString("[", ", ", "]") - case _ => cell.toString - } - if (truncate && str.length > 20) str.substring(0, 17) + "..." else str - }: Seq[String] - } - - // Initialise the width of each column to a minimum value of '3' - val colWidths = Array.fill(numCols)(3) - - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - - sb.append(sep) - - // data - rows.tail.map { - _.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - - sb.append(sep) - - // For Data that has more than "numRows" records - if (hasMoreData) { - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") - } - - sb.toString() - } - - override def toString: String = { - try { - schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") - } catch { - case NonFatal(e) => - s"Invalid tree; ${e.getMessage}:\n$queryExecution" - } - } - - /** - * Returns the object itself. - * @group basic - * @since 1.3.0 - */ - // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = this - - /** - * :: Experimental :: - * Converts this [[DataFrame]] to a strongly-typed [[Dataset]] containing objects of the - * specified type, `U`. - * @group basic - * @since 1.6.0 - */ - @Experimental - def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) - - /** - * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion - * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: - * {{{ - * val rdd: RDD[(Int, String)] = ... - * rdd.toDF() // this implicit conversion creates a DataFrame with column name _1 and _2 - * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name" - * }}} - * @group basic - * @since 1.3.0 - */ - @scala.annotation.varargs - def toDF(colNames: String*): DataFrame = { - require(schema.size == colNames.size, - "The number of columns doesn't match.\n" + - s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + - s"New column names (${colNames.size}): " + colNames.mkString(", ")) - - val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => - Column(oldAttribute).as(newName) - } - select(newCols : _*) - } - - /** - * Returns the schema of this [[DataFrame]]. - * @group basic - * @since 1.3.0 - */ - def schema: StructType = queryExecution.analyzed.schema - - /** - * Returns all column names and their data types as an array. - * @group basic - * @since 1.3.0 - */ - def dtypes: Array[(String, String)] = schema.fields.map { field => - (field.name, field.dataType.toString) - } - - /** - * Returns all column names as an array. - * @group basic - * @since 1.3.0 - */ - def columns: Array[String] = schema.fields.map(_.name) - - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println - - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - withPlan(explain).queryExecution.executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } - - /** - * Only prints the physical plan to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) - - /** - * Returns true if the `collect` and `take` methods can be run locally - * (without any Spark executors). - * @group basic - * @since 1.3.0 - */ - def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] - - /** - * Displays the [[DataFrame]] in a tabular form. Strings more than 20 characters will be - * truncated, and all cells will be aligned right. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * @param numRows Number of rows to show - * - * @group action - * @since 1.3.0 - */ - def show(numRows: Int): Unit = show(numRows, truncate = true) - - /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters - * will be truncated, and all cells will be aligned right. - * @group action - * @since 1.3.0 - */ - def show(): Unit = show(20) - - /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. - * - * @param truncate Whether truncate long strings. If true, strings more than 20 characters will - * be truncated and all cells will be aligned right - * - * @group action - * @since 1.5.0 - */ - def show(truncate: Boolean): Unit = show(20, truncate) - - /** - * Displays the [[DataFrame]] in a tabular form. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * @param numRows Number of rows to show - * @param truncate Whether truncate long strings. If true, strings more than 20 characters will - * be truncated and all cells will be aligned right - * - * @group action - * @since 1.5.0 - */ - // scalastyle:off println - def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) - // scalastyle:on println - - /** - * Returns a [[DataFrameNaFunctions]] for working with missing data. - * {{{ - * // Dropping rows containing any null values. - * df.na.drop() - * }}} - * - * @group dfops - * @since 1.3.1 - */ - def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) - - /** - * Returns a [[DataFrameStatFunctions]] for working statistic functions support. - * {{{ - * // Finding frequent items in column with name 'a'. - * df.stat.freqItems(Seq("a")) - * }}} - * - * @group dfops - * @since 1.4.0 - */ - def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this) - - /** - * Cartesian join with another [[DataFrame]]. - * - * Note that cartesian joins are very expensive without an extra filter that can be pushed down. - * - * @param right Right side of the join operation. - * @group dfops - * @since 1.3.0 - */ - def join(right: DataFrame): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None) - } - - /** - * Inner equi-join with another [[DataFrame]] using the given column. - * - * Different from other join functions, the join column will only appear once in the output, - * i.e. similar to SQL's `JOIN USING` syntax. - * - * {{{ - * // Joining df1 and df2 using the column "user_id" - * df1.join(df2, "user_id") - * }}} - * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * - * @param right Right side of the join operation. - * @param usingColumn Name of the column to join on. This column must exist on both sides. - * @group dfops - * @since 1.4.0 - */ - def join(right: DataFrame, usingColumn: String): DataFrame = { - join(right, Seq(usingColumn)) - } - - /** - * Inner equi-join with another [[DataFrame]] using the given columns. - * - * Different from other join functions, the join columns will only appear once in the output, - * i.e. similar to SQL's `JOIN USING` syntax. - * - * {{{ - * // Joining df1 and df2 using the columns "user_id" and "user_name" - * df1.join(df2, Seq("user_id", "user_name")) - * }}} - * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * - * @param right Right side of the join operation. - * @param usingColumns Names of the columns to join on. This columns must exist on both sides. - * @group dfops - * @since 1.4.0 - */ - def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { - join(right, usingColumns, "inner") - } - - /** - * Equi-join with another [[DataFrame]] using the given columns. - * - * Different from other join functions, the join columns will only appear once in the output, - * i.e. similar to SQL's `JOIN USING` syntax. - * - * Note that if you perform a self-join using this function without aliasing the input - * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since - * there is no way to disambiguate which side of the join you would like to reference. - * - * @param right Right side of the join operation. - * @param usingColumns Names of the columns to join on. This columns must exist on both sides. - * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. - * @group dfops - * @since 1.6.0 - */ - def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { - // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right - // by creating a new instance for one of the branch. - val joined = sqlContext.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] - - // Project only one of the join columns. - val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col)) - val condition = usingColumns.map { col => - catalyst.expressions.EqualTo( - withPlan(joined.left).resolve(col), - withPlan(joined.right).resolve(col)) - }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => - catalyst.expressions.And(cond, eqTo) - } - - withPlan { - Project( - joined.output.filterNot(joinedCols.contains(_)), - Join( - joined.left, - joined.right, - joinType = JoinType(joinType), - condition) - ) - } - } - - /** - * Inner join with another [[DataFrame]], using the given join expression. - * - * {{{ - * // The following two are equivalent: - * df1.join(df2, $"df1Key" === $"df2Key") - * df1.join(df2).where($"df1Key" === $"df2Key") - * }}} - * @group dfops - * @since 1.3.0 - */ - def join(right: DataFrame, joinExprs: Column): DataFrame = join(right, joinExprs, "inner") - - /** - * Join with another [[DataFrame]], using the given join expression. The following performs - * a full outer join between `df1` and `df2`. - * - * {{{ - * // Scala: - * import org.apache.spark.sql.functions._ - * df1.join(df2, $"df1Key" === $"df2Key", "outer") - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer"); - * }}} - * - * @param right Right side of the join. - * @param joinExprs Join expression. - * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. - * @group dfops - * @since 1.3.0 - */ - def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { - // Note that in this function, we introduce a hack in the case of self-join to automatically - // resolve ambiguous join conditions into ones that might make sense [SPARK-6231]. - // Consider this case: df.join(df, df("key") === df("key")) - // Since df("key") === df("key") is a trivially true condition, this actually becomes a - // cartesian join. However, most likely users expect to perform a self join using "key". - // With that assumption, this hack turns the trivially true condition into equality on join - // keys that are resolved to both sides. - - // Trigger analysis so in the case of self-join, the analyzer will clone the plan. - // After the cloning, left and right side will have distinct expression ids. - val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) - .queryExecution.analyzed.asInstanceOf[Join] - - // If auto self join alias is disabled, return the plan. - if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) { - return withPlan(plan) - } - - // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed - if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { - return withPlan(plan) - } - - // Otherwise, find the trivially true predicates and automatically resolves them to both sides. - // By the time we get here, since we have already run analysis, all attributes should've been - // resolved and become AttributeReference. - val cond = plan.condition.map { _.transform { - case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) - if a.sameRef(b) => - catalyst.expressions.EqualTo( - withPlan(plan.left).resolve(a.name), - withPlan(plan.right).resolve(b.name)) - }} - - withPlan { - plan.copy(condition = cond) - } - } - - /** - * Returns a new [[DataFrame]] with each partition sorted by the given expressions. - * - * This is the same operation as "SORT BY" in SQL (Hive QL). - * - * @group dfops - * @since 1.6.0 - */ - @scala.annotation.varargs - def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { - sortWithinPartitions(sortCol, sortCols : _*) - } - - /** - * Returns a new [[DataFrame]] with each partition sorted by the given expressions. - * - * This is the same operation as "SORT BY" in SQL (Hive QL). - * - * @group dfops - * @since 1.6.0 - */ - @scala.annotation.varargs - def sortWithinPartitions(sortExprs: Column*): DataFrame = { - sortInternal(global = false, sortExprs) - } - - /** - * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order. - * {{{ - * // The following 3 are equivalent - * df.sort("sortcol") - * df.sort($"sortcol") - * df.sort($"sortcol".asc) - * }}} - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def sort(sortCol: String, sortCols: String*): DataFrame = { - sort((sortCol +: sortCols).map(apply) : _*) - } - - /** - * Returns a new [[DataFrame]] sorted by the given expressions. For example: - * {{{ - * df.sort($"col1", $"col2".desc) - * }}} - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def sort(sortExprs: Column*): DataFrame = { - sortInternal(global = true, sortExprs) - } - - /** - * Returns a new [[DataFrame]] sorted by the given expressions. - * This is an alias of the `sort` function. - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*) - - /** - * Returns a new [[DataFrame]] sorted by the given expressions. - * This is an alias of the `sort` function. - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) - - /** - * Selects column based on the column name and return it as a [[Column]]. - * Note that the column name can also reference to a nested column like `a.b`. - * @group dfops - * @since 1.3.0 - */ - def apply(colName: String): Column = col(colName) - - /** - * Selects column based on the column name and return it as a [[Column]]. - * Note that the column name can also reference to a nested column like `a.b`. - * @group dfops - * @since 1.3.0 - */ - def col(colName: String): Column = colName match { - case "*" => - Column(ResolvedStar(schema.fieldNames.map(resolve))) - case _ => - val expr = resolve(colName) - Column(expr) - } - - /** - * Returns a new [[DataFrame]] with an alias set. - * @group dfops - * @since 1.3.0 - */ - def as(alias: String): DataFrame = withPlan { - Subquery(alias, logicalPlan) - } - - /** - * (Scala-specific) Returns a new [[DataFrame]] with an alias set. - * @group dfops - * @since 1.3.0 - */ - def as(alias: Symbol): DataFrame = as(alias.name) - - /** - * Returns a new [[DataFrame]] with an alias set. Same as `as`. - * @group dfops - * @since 1.6.0 - */ - def alias(alias: String): DataFrame = as(alias) - - /** - * (Scala-specific) Returns a new [[DataFrame]] with an alias set. Same as `as`. - * @group dfops - * @since 1.6.0 - */ - def alias(alias: Symbol): DataFrame = as(alias) - - /** - * Selects a set of column based expressions. - * {{{ - * df.select($"colA", $"colB" + 1) - * }}} - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def select(cols: Column*): DataFrame = withPlan { - val namedExpressions = cols.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. - case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) - case Column(expr: NamedExpression) => expr - // Leave an unaliased explode with an empty list of names since the analyzer will generate the - // correct defaults after the nested expression's type has been resolved. - case Column(explode: Explode) => MultiAlias(explode, Nil) - case Column(expr: Expression) => Alias(expr, expr.prettyString)() - } - Project(namedExpressions.toSeq, logicalPlan) - } - - /** - * Selects a set of columns. This is a variant of `select` that can only select - * existing columns using column names (i.e. cannot construct expressions). - * - * {{{ - * // The following two are equivalent: - * df.select("colA", "colB") - * df.select($"colA", $"colB") - * }}} - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) - - /** - * Selects a set of SQL expressions. This is a variant of `select` that accepts - * SQL expressions. - * - * {{{ - * // The following are equivalent: - * df.selectExpr("colA", "colB as newName", "abs(colC)") - * df.select(expr("colA"), expr("colB as newName"), expr("abs(colC)")) - * }}} - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def selectExpr(exprs: String*): DataFrame = { - select(exprs.map { expr => - Column(SqlParser.parseExpression(expr)) - }: _*) - } - - /** - * Filters rows using the given condition. - * {{{ - * // The following are equivalent: - * peopleDf.filter($"age" > 15) - * peopleDf.where($"age" > 15) - * }}} - * @group dfops - * @since 1.3.0 - */ - def filter(condition: Column): DataFrame = withPlan { - Filter(condition.expr, logicalPlan) - } - - /** - * Filters rows using the given SQL expression. - * {{{ - * peopleDf.filter("age > 15") - * }}} - * @group dfops - * @since 1.3.0 - */ - def filter(conditionExpr: String): DataFrame = { - filter(Column(SqlParser.parseExpression(conditionExpr))) - } - - /** - * Filters rows using the given condition. This is an alias for `filter`. - * {{{ - * // The following are equivalent: - * peopleDf.filter($"age" > 15) - * peopleDf.where($"age" > 15) - * }}} - * @group dfops - * @since 1.3.0 - */ - def where(condition: Column): DataFrame = filter(condition) - - /** - * Filters rows using the given SQL expression. - * {{{ - * peopleDf.where("age > 15") - * }}} - * @group dfops - * @since 1.5.0 - */ - def where(conditionExpr: String): DataFrame = { - filter(Column(SqlParser.parseExpression(conditionExpr))) - } - - /** - * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * df.groupBy($"department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * df.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def groupBy(cols: Column*): GroupedData = { - GroupedData(this, cols.map(_.expr), GroupedData.GroupByType) - } - - /** - * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, - * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns rolluped by department and group. - * df.rollup($"department", $"group").avg() - * - * // Compute the max age and average salary, rolluped by department and gender. - * df.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group dfops - * @since 1.4.0 - */ - @scala.annotation.varargs - def rollup(cols: Column*): GroupedData = { - GroupedData(this, cols.map(_.expr), GroupedData.RollupType) - } - - /** - * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, - * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * df.cube($"department", $"group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. - * df.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group dfops - * @since 1.4.0 - */ - @scala.annotation.varargs - def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType) - - /** - * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. - * - * This is a variant of groupBy that can only group by existing columns using column names - * (i.e. cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * df.groupBy("department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * df.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def groupBy(col1: String, cols: String*): GroupedData = { - val colNames: Seq[String] = col1 +: cols - GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType) - } - - /** - * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, - * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. - * - * This is a variant of rollup that can only group by existing columns using column names - * (i.e. cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns rolluped by department and group. - * df.rollup("department", "group").avg() - * - * // Compute the max age and average salary, rolluped by department and gender. - * df.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group dfops - * @since 1.4.0 - */ - @scala.annotation.varargs - def rollup(col1: String, cols: String*): GroupedData = { - val colNames: Seq[String] = col1 +: cols - GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType) - } - - /** - * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, - * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. - * - * This is a variant of cube that can only group by existing columns using column names - * (i.e. cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * df.cube("department", "group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. - * df.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group dfops - * @since 1.4.0 - */ - @scala.annotation.varargs - def cube(col1: String, cols: String*): GroupedData = { - val colNames: Seq[String] = col1 +: cols - GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType) - } - - /** - * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. - * {{{ - * // df.agg(...) is a shorthand for df.groupBy().agg(...) - * df.agg("age" -> "max", "salary" -> "avg") - * df.groupBy().agg("age" -> "max", "salary" -> "avg") - * }}} - * @group dfops - * @since 1.3.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs : _*) - } - - /** - * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. - * {{{ - * // df.agg(...) is a shorthand for df.groupBy().agg(...) - * df.agg(Map("age" -> "max", "salary" -> "avg")) - * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * @group dfops - * @since 1.3.0 - */ - def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * (Java-specific) Aggregates on the entire [[DataFrame]] without groups. - * {{{ - * // df.agg(...) is a shorthand for df.groupBy().agg(...) - * df.agg(Map("age" -> "max", "salary" -> "avg")) - * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * @group dfops - * @since 1.3.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * Aggregates on the entire [[DataFrame]] without groups. - * {{{ - * // df.agg(...) is a shorthand for df.groupBy().agg(...) - * df.agg(max($"age"), avg($"salary")) - * df.groupBy().agg(max($"age"), avg($"salary")) - * }}} - * @group dfops - * @since 1.3.0 - */ - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) - - /** - * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function - * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. - * @group dfops - * @since 1.3.0 - */ - def limit(n: Int): DataFrame = withPlan { - Limit(Literal(n), logicalPlan) - } - - /** - * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. - * This is equivalent to `UNION ALL` in SQL. - * @group dfops - * @since 1.3.0 - */ - def unionAll(other: DataFrame): DataFrame = withPlan { - Union(logicalPlan, other.logicalPlan) - } - - /** - * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. - * This is equivalent to `INTERSECT` in SQL. - * @group dfops - * @since 1.3.0 - */ - def intersect(other: DataFrame): DataFrame = withPlan { - Intersect(logicalPlan, other.logicalPlan) - } - - /** - * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. - * This is equivalent to `EXCEPT` in SQL. - * @group dfops - * @since 1.3.0 - */ - def except(other: DataFrame): DataFrame = withPlan { - Except(logicalPlan, other.logicalPlan) - } - - /** - * Returns a new [[DataFrame]] by sampling a fraction of rows. - * - * @param withReplacement Sample with replacement or not. - * @param fraction Fraction of rows to generate. - * @param seed Seed for sampling. - * @group dfops - * @since 1.3.0 - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan) - } - - /** - * Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed. - * - * @param withReplacement Sample with replacement or not. - * @param fraction Fraction of rows to generate. - * @group dfops - * @since 1.3.0 - */ - def sample(withReplacement: Boolean, fraction: Double): DataFrame = { - sample(withReplacement, fraction, Utils.random.nextLong) - } - - /** - * Randomly splits this [[DataFrame]] with the provided weights. - * - * @param weights weights for splits, will be normalized if they don't sum to 1. - * @param seed Seed for sampling. - * @group dfops - * @since 1.4.0 - */ - def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = { - val sum = weights.sum - val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) - normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan)) - }.toArray - } - - /** - * Randomly splits this [[DataFrame]] with the provided weights. - * - * @param weights weights for splits, will be normalized if they don't sum to 1. - * @group dfops - * @since 1.4.0 - */ - def randomSplit(weights: Array[Double]): Array[DataFrame] = { - randomSplit(weights, Utils.random.nextLong) - } - - /** - * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api. - * - * @param weights weights for splits, will be normalized if they don't sum to 1. - * @param seed Seed for sampling. - * @group dfops - */ - private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { - randomSplit(weights.toArray, seed) - } - - /** - * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more - * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of - * the input row are implicitly joined with each row that is output by the function. - * - * The following example uses this function to count the number of books which contain - * a given word: - * - * {{{ - * case class Book(title: String, words: String) - * val df: RDD[Book] - * - * case class Word(word: String) - * val allWords = df.explode('words) { - * case Row(words: String) => words.split(" ").map(Word(_)) - * } - * - * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) - * }}} - * @group dfops - * @since 1.3.0 - */ - def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { - val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - - val elementTypes = schema.toAttributes.map { - attr => (attr.dataType, attr.nullable, attr.name) } - val names = schema.toAttributes.map(_.name) - val convert = CatalystTypeConverters.createToCatalystConverter(schema) - - val rowFunction = - f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) - val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) - - withPlan { - Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) - } - } - - /** - * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero - * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All - * columns of the input row are implicitly joined with each value that is output by the function. - * - * {{{ - * df.explode("words", "word"){words: String => words.split(" ")} - * }}} - * @group dfops - * @since 1.3.0 - */ - def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) - : DataFrame = { - val dataType = ScalaReflection.schemaFor[B].dataType - val attributes = AttributeReference(outputColumn, dataType)() :: Nil - // TODO handle the metadata? - val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable, attr.name) } - - def rowFunction(row: Row): TraversableOnce[InternalRow] = { - val convert = CatalystTypeConverters.createToCatalystConverter(dataType) - f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) - } - val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) - - withPlan { - Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) - } - } - - ///////////////////////////////////////////////////////////////////////////// - - /** - * Returns a new [[DataFrame]] by adding a column or replacing the existing column that has - * the same name. - * @group dfops - * @since 1.3.0 - */ - def withColumn(colName: String, col: Column): DataFrame = { - val resolver = sqlContext.analyzer.resolver - val replaced = schema.exists(f => resolver(f.name, colName)) - if (replaced) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, colName)) col.as(colName) else Column(name) - } - select(colNames : _*) - } else { - select(Column("*"), col.as(colName)) - } - } - - /** - * Returns a new [[DataFrame]] by adding a column with metadata. - */ - private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { - val resolver = sqlContext.analyzer.resolver - val replaced = schema.exists(f => resolver(f.name, colName)) - if (replaced) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, colName)) col.as(colName, metadata) else Column(name) - } - select(colNames : _*) - } else { - select(Column("*"), col.as(colName, metadata)) - } - } - - /** - * Returns a new [[DataFrame]] with a column renamed. - * This is a no-op if schema doesn't contain existingName. - * @group dfops - * @since 1.3.0 - */ - def withColumnRenamed(existingName: String, newName: String): DataFrame = { - val resolver = sqlContext.analyzer.resolver - val shouldRename = schema.exists(f => resolver(f.name, existingName)) - if (shouldRename) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, existingName)) Column(name).as(newName) else Column(name) - } - select(colNames : _*) - } else { - this - } - } - - /** - * Returns a new [[DataFrame]] with a column dropped. - * This is a no-op if schema doesn't contain column name. - * @group dfops - * @since 1.4.0 - */ - def drop(colName: String): DataFrame = { - val resolver = sqlContext.analyzer.resolver - val shouldDrop = schema.exists(f => resolver(f.name, colName)) - if (shouldDrop) { - val colsAfterDrop = schema.filter { field => - val name = field.name - !resolver(name, colName) - }.map(f => Column(f.name)) - select(colsAfterDrop : _*) - } else { - this - } - } - - /** - * Returns a new [[DataFrame]] with a column dropped. - * This version of drop accepts a Column rather than a name. - * This is a no-op if the DataFrame doesn't have a column - * with an equivalent expression. - * @group dfops - * @since 1.4.1 - */ - def drop(col: Column): DataFrame = { - val expression = col match { - case Column(u: UnresolvedAttribute) => - queryExecution.analyzed.resolveQuoted(u.name, sqlContext.analyzer.resolver).getOrElse(u) - case Column(expr: Expression) => expr - } - val attrs = this.logicalPlan.output - val colsAfterDrop = attrs.filter { attr => - attr != expression - }.map(attr => Column(attr)) - select(colsAfterDrop : _*) - } - - /** - * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. - * This is an alias for `distinct`. - * @group dfops - * @since 1.4.0 - */ - def dropDuplicates(): DataFrame = dropDuplicates(this.columns) - - /** - * (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only - * the subset of columns. - * - * @group dfops - * @since 1.4.0 - */ - def dropDuplicates(colNames: Seq[String]): DataFrame = withPlan { - val groupCols = colNames.map(resolve) - val groupColExprIds = groupCols.map(_.exprId) - val aggCols = logicalPlan.output.map { attr => - if (groupColExprIds.contains(attr.exprId)) { - attr - } else { - Alias(First(attr), attr.name)() - } - } - Aggregate(groupCols, aggCols, logicalPlan) - } - - /** - * Returns a new [[DataFrame]] with duplicate rows removed, considering only - * the subset of columns. - * - * @group dfops - * @since 1.4.0 - */ - def dropDuplicates(colNames: Array[String]): DataFrame = dropDuplicates(colNames.toSeq) - - /** - * Computes statistics for numeric columns, including count, mean, stddev, min, and max. - * If no columns are given, this function computes statistics for all numerical columns. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to - * programmatically compute summary statistics, use the `agg` function instead. - * - * {{{ - * df.describe("age", "height").show() - * - * // output: - * // summary age height - * // count 10.0 10.0 - * // mean 53.3 178.05 - * // stddev 11.6 15.7 - * // min 18.0 163.0 - * // max 92.0 192.0 - * }}} - * - * @group action - * @since 1.3.1 - */ - @scala.annotation.varargs - def describe(cols: String*): DataFrame = withPlan { - - // The list of summary statistics to compute, in the form of expressions. - val statistics = List[(String, Expression => Expression)]( - "count" -> Count, - "mean" -> Average, - "stddev" -> StddevSamp, - "min" -> Min, - "max" -> Max) - - val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList - - val ret: Seq[Row] = if (outputCols.nonEmpty) { - val aggExprs = statistics.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) - } - - val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq - - // Pivot the data so each summary is one row - row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => - Row(statistic :: aggregation.toList: _*) - } - } else { - // If there are no output columns, just output a single column that contains the stats. - statistics.map { case (name, _) => Row(name) } - } - - // All columns are string type - val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - LocalRelation.fromExternalRows(schema, ret) - } - - /** - * Returns the first `n` rows. - * @group action - * @since 1.3.0 - */ - def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df => - df.collect(needCallback = false) - } - - /** - * Returns the first row. - * @group action - * @since 1.3.0 - */ - def head(): Row = head(1).head - - /** - * Returns the first row. Alias for head(). - * @group action - * @since 1.3.0 - */ - def first(): Row = head() - - /** - * Returns a new RDD by applying a function to all rows of this DataFrame. - * @group rdd - * @since 1.3.0 - */ - def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f) - - /** - * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], - * and then flattening the results. - * @group rdd - * @since 1.3.0 - */ - def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) - - /** - * Returns a new RDD by applying a function to each partition of this DataFrame. - * @group rdd - * @since 1.3.0 - */ - def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { - rdd.mapPartitions(f) - } - - /** - * Applies a function `f` to all rows. - * @group rdd - * @since 1.3.0 - */ - def foreach(f: Row => Unit): Unit = withNewExecutionId { - rdd.foreach(f) - } - - /** - * Applies a function f to each partition of this [[DataFrame]]. - * @group rdd - * @since 1.3.0 - */ - def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId { - rdd.foreachPartition(f) - } - - /** - * Returns the first `n` rows in the [[DataFrame]]. - * @group action - * @since 1.3.0 - */ - def take(n: Int): Array[Row] = head(n) - - /** - * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. - * @group action - * @since 1.3.0 - */ - def collect(): Array[Row] = collect(needCallback = true) - - private def collect(needCallback: Boolean): Array[Row] = { - def execute(): Array[Row] = withNewExecutionId { - queryExecution.executedPlan.executeCollectPublic() - } - - if (needCallback) { - withCallback("collect", this)(_ => execute()) - } else { - execute() - } - } - - /** - * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. - * @group action - * @since 1.3.0 - */ - def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => - withNewExecutionId { - java.util.Arrays.asList(rdd.collect() : _*) - } - } - - /** - * Returns the number of rows in the [[DataFrame]]. - * @group action - * @since 1.3.0 - */ - def count(): Long = withCallback("count", groupBy().count()) { df => - df.collect(needCallback = false).head.getLong(0) - } - - /** - * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. - * @group dfops - * @since 1.3.0 - */ - def repartition(numPartitions: Int): DataFrame = withPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) - } - - /** - * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into - * `numPartitions`. The resulting DataFrame is hash partitioned. - * - * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). - * - * @group dfops - * @since 1.6.0 - */ - @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = withPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions)) - } - - /** - * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving - * the existing number of partitions. The resulting DataFrame is hash partitioned. - * - * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). - * - * @group dfops - * @since 1.6.0 - */ - @scala.annotation.varargs - def repartition(partitionExprs: Column*): DataFrame = withPlan { - RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None) - } - - /** - * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. - * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. - * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of - * the 100 new partitions will claim 10 of the current partitions. - * @group rdd - * @since 1.4.0 - */ - def coalesce(numPartitions: Int): DataFrame = withPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) - } - - /** - * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. - * This is an alias for `dropDuplicates`. - * @group dfops - * @since 1.3.0 - */ - def distinct(): DataFrame = dropDuplicates() - - /** - * @group basic - * @since 1.3.0 - */ - def persist(): this.type = { - sqlContext.cacheManager.cacheQuery(this) - this - } - - /** - * @group basic - * @since 1.3.0 - */ - def cache(): this.type = persist() - - /** - * @group basic - * @since 1.3.0 - */ - def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheManager.cacheQuery(this, None, newLevel) - this - } - - /** - * @group basic - * @since 1.3.0 - */ - def unpersist(blocking: Boolean): this.type = { - sqlContext.cacheManager.tryUncacheQuery(this, blocking) - this - } - - /** - * @group basic - * @since 1.3.0 - */ - def unpersist(): this.type = unpersist(blocking = false) - - ///////////////////////////////////////////////////////////////////////////// - // I/O - ///////////////////////////////////////////////////////////////////////////// - - /** - * Represents the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. Note that the RDD is - * memoized. Once called, it won't change even if you change any query planning related Spark SQL - * configurations (e.g. `spark.sql.shuffle.partitions`). - * @group rdd - * @since 1.3.0 - */ - lazy val rdd: RDD[Row] = { - // use a local variable to make sure the map closure doesn't capture the whole DataFrame - val schema = this.schema - queryExecution.toRdd.mapPartitions { rows => - val converter = CatalystTypeConverters.createToScalaConverter(schema) - rows.map(converter(_).asInstanceOf[Row]) - } - } - - /** - * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. - * @group rdd - * @since 1.3.0 - */ - def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() - - /** - * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. - * @group rdd - * @since 1.3.0 - */ - def javaRDD: JavaRDD[Row] = toJavaRDD - - /** - * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this - * temporary table is tied to the [[SQLContext]] that was used to create this DataFrame. - * - * @group basic - * @since 1.3.0 - */ - def registerTempTable(tableName: String): Unit = { - sqlContext.registerDataFrameAsTable(this, tableName) - } - - /** - * :: Experimental :: - * Interface for saving the content of the [[DataFrame]] out into external storage. - * - * @group output - * @since 1.4.0 - */ - @Experimental - def write: DataFrameWriter = new DataFrameWriter(this) - - /** - * Returns the content of the [[DataFrame]] as a RDD of JSON strings. - * @group rdd - * @since 1.3.0 - */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - queryExecution.toRdd.mapPartitions { iter => - val writer = new CharArrayWriter() - // create the Generator without separator inserted between 2 records - val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) - - new Iterator[String] { - override def hasNext: Boolean = iter.hasNext - override def next(): String = { - JacksonGenerator(rowSchema, gen)(iter.next()) - gen.flush() - - val json = writer.toString - if (hasNext) { - writer.reset() - } else { - gen.close() - } - - json - } - } - } - } - - /** - * Returns a best-effort snapshot of the files that compose this DataFrame. This method simply - * asks each constituent BaseRelation for its respective files and takes the union of all results. - * Depending on the source relations, this may not find all input files. Duplicates are removed. - */ - def inputFiles: Array[String] = { - val files: Seq[String] = logicalPlan.collect { - case LogicalRelation(fsBasedRelation: FileRelation, _) => - fsBasedRelation.inputFiles - case fr: FileRelation => - fr.inputFiles - }.flatten - files.toSet.toArray - } - - //////////////////////////////////////////////////////////////////////////// - // for Python API - //////////////////////////////////////////////////////////////////////////// - - /** - * Converts a JavaRDD to a PythonRDD. - */ - protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val structType = schema // capture it for closure - val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) - EvaluatePython.javaToPython(rdd) - } - - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // Deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - /** - * @deprecated As of 1.3.0, replaced by `toDF()`. - */ - @deprecated("use toDF", "1.3.0") - def toSchemaRDD: DataFrame = this - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @deprecated As of 1.340, replaced by `write().jdbc()`. - */ - @deprecated("Use write.jdbc()", "1.4.0") - def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { - val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write - w.jdbc(url, table, new Properties) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - * @group output - * @deprecated As of 1.4.0, replaced by `write().jdbc()`. - */ - @deprecated("Use write.jdbc()", "1.4.0") - def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { - val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) - w.jdbc(url, table, new Properties) - } - - /** - * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. - * Files that are written out using this method can be read back in as a [[DataFrame]] - * using the `parquetFile` function in [[SQLContext]]. - * @group output - * @deprecated As of 1.4.0, replaced by `write().parquet()`. - */ - @deprecated("Use write.parquet(path)", "1.4.0") - def saveAsParquetFile(path: String): Unit = { - write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - } - - /** - * Creates a table from the the contents of this DataFrame. - * It will use the default data source configured by spark.sql.sources.default. - * This will fail if the table already exists. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. - */ - @deprecated("Use write.saveAsTable(tableName)", "1.4.0") - def saveAsTable(tableName: String): Unit = { - write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) - } - - /** - * Creates a table from the the contents of this DataFrame, using the default data source - * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. - */ - @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") - def saveAsTable(tableName: String, mode: SaveMode): Unit = { - write.mode(mode).saveAsTable(tableName) - } - - /** - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source and a set of options, - * using [[SaveMode.ErrorIfExists]] as the save mode. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. - */ - @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") - def saveAsTable(tableName: String, source: String): Unit = { - write.format(source).saveAsTable(tableName) - } - - /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. - */ - @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") - def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { - write.format(source).mode(mode).saveAsTable(tableName) - } - - /** - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. - */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: java.util.Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).saveAsTable(tableName) - } - - /** - * (Scala-specific) - * Creates a table from the the contents of this DataFrame based on a given data source, - * [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. - */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).saveAsTable(tableName) - } - - /** - * Saves the contents of this DataFrame to the given path, - * using the default data source configured by spark.sql.sources.default and - * [[SaveMode.ErrorIfExists]] as the save mode. - * @group output - * @deprecated As of 1.4.0, replaced by `write().save(path)`. - */ - @deprecated("Use write.save(path)", "1.4.0") - def save(path: String): Unit = { - write.save(path) - } - - /** - * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, - * using the default data source configured by spark.sql.sources.default. - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. - */ - @deprecated("Use write.mode(mode).save(path)", "1.4.0") - def save(path: String, mode: SaveMode): Unit = { - write.mode(mode).save(path) - } - - /** - * Saves the contents of this DataFrame to the given path based on the given data source, - * using [[SaveMode.ErrorIfExists]] as the save mode. - * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. - */ - @deprecated("Use write.format(source).save(path)", "1.4.0") - def save(path: String, source: String): Unit = { - write.format(source).save(path) - } - - /** - * Saves the contents of this DataFrame to the given path based on the given data source and - * [[SaveMode]] specified by mode. - * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. - */ - @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") - def save(path: String, source: String, mode: SaveMode): Unit = { - write.format(source).mode(mode).save(path) - } - - /** - * Saves the contents of this DataFrame based on the given data source, - * [[SaveMode]] specified by mode, and a set of options. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).save(path)`. - */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") - def save( - source: String, - mode: SaveMode, - options: java.util.Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).save() - } - - /** - * (Scala-specific) - * Saves the contents of this DataFrame based on the given data source, - * [[SaveMode]] specified by mode, and a set of options - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).save(path)`. - */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") - def save( - source: String, - mode: SaveMode, - options: Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).save() - } - - - /** - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. - */ - @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") - def insertInto(tableName: String, overwrite: Boolean): Unit = { - write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) - } - - /** - * Adds the rows from this RDD to the specified table. - * Throws an exception if the table already exists. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().mode(SaveMode.Append).saveAsTable(tableName)`. - */ - @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") - def insertInto(tableName: String): Unit = { - write.mode(SaveMode.Append).insertInto(tableName) - } - - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // End of deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - /** - * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with - * an execution. - */ - private[sql] def withNewExecutionId[T](body: => T): T = { - SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body) - } - - /** - * Wrap a DataFrame action to track the QueryExecution and time cost, then report to the - * user-registered callback functions. - */ - private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = { - try { - df.queryExecution.executedPlan.foreach { plan => - plan.metrics.valuesIterator.foreach(_.reset()) - } - val start = System.nanoTime() - val result = action(df) - val end = System.nanoTime() - sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start) - result - } catch { - case e: Exception => - sqlContext.listenerManager.onFailure(name, df.queryExecution, e) - throw e - } - } - - private def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - withPlan { - Sort(sortOrder, global = global, logicalPlan) - } - } - - /** A convenient function to wrap a logical plan and produce a DataFrame. */ - @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { - new DataFrame(sqlContext, logicalPlan) - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala deleted file mode 100644 index 3b30337f1f877..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -/** - * A container for a [[DataFrame]], used for implicit conversions. - * - * To use this, import implicit conversions in SQL: - * {{{ - * import sqlContext.implicits._ - * }}} - * - * @since 1.3.0 - */ -case class DataFrameHolder private[sql](private val df: DataFrame) { - - // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = df - - def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index f7be5f6b370ab..f0e16eefc775b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -155,7 +155,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def fill(value: Double, cols: Seq[String]): DataFrame = { - val columnEquals = df.sqlContext.analyzer.resolver + val columnEquals = df.sqlContext.sessionState.analyzer.resolver val projections = df.schema.fields.map { f => // Only fill if the column is part of the cols list. if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { @@ -182,7 +182,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def fill(value: String, cols: Seq[String]): DataFrame = { - val columnEquals = df.sqlContext.analyzer.resolver + val columnEquals = df.sqlContext.sessionState.analyzer.resolver val projections = df.schema.fields.map { f => // Only fill if the column is part of the cols list. if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { @@ -200,6 +200,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * The key of the map is the column name, and the value of the map is the replacement value. * The value must be of the following type: * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`. + * Replacement values are cast to the column data type. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -217,6 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * The key of the map is the column name, and the value of the map is the replacement value. * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`. + * Replacement values are cast to the column data type. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -353,7 +355,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case _: String => StringType } - val columnEquals = df.sqlContext.analyzer.resolver + val columnEquals = df.sqlContext.sessionState.analyzer.resolver val projections = df.schema.fields.map { f => val shouldReplace = cols.exists(colName => columnEquals(colName, f.name)) if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) { @@ -382,14 +384,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } } - val columnEquals = df.sqlContext.analyzer.resolver + val columnEquals = df.sqlContext.sessionState.analyzer.resolver val projections = df.schema.fields.map { f => values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => v match { - case v: jl.Float => fillCol[Double](f, v.toDouble) + case v: jl.Float => fillCol[Float](f, v) case v: jl.Double => fillCol[Double](f, v) - case v: jl.Long => fillCol[Double](f, v.toDouble) - case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: jl.Long => fillCol[Long](f, v) + case v: jl.Integer => fillCol[Integer](f, v) case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue()) case v: String => fillCol[String](f, v) } @@ -402,13 +404,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. */ private def fillCol[T](col: StructField, replacement: T): Column = { - col.dataType match { + val quotedColName = "`" + col.name + "`" + val colValue = col.dataType match { case DoubleType | FloatType => - coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)), - lit(replacement).cast(col.dataType)).as(col.name) - case _ => - coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) + nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types + case _ => df.col(quotedColName) } + coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6a194a443ab17..15f2344df6ab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,25 +21,22 @@ import java.util.Properties import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.Path -import org.apache.hadoop.util.StringUtils - +import org.apache.spark.Partition import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.JSONRelation -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} /** * :: Experimental :: * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, - * key-value stores, etc). Use [[SQLContext.read]] to access this. + * key-value stores, etc) or data streams. Use [[SQLContext.read]] to access this. * * @since 1.4.0 */ @@ -78,6 +75,27 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { this } + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataFrameReader = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataFrameReader = option(key, value.toString) + /** * (Scala-specific) Adds input options for the underlying data source. * @@ -99,29 +117,29 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { } /** - * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by - * a local or distributed file system). + * Loads input in as a [[DataFrame]], for data sources that don't require a path (e.g. external + * key-value stores). * * @since 1.4.0 */ - def load(path: String): DataFrame = { - option("path", path).load() + def load(): DataFrame = { + val dataSource = + DataSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap) + Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())) } /** - * Loads input in as a [[DataFrame]], for data sources that don't require a path (e.g. external - * key-value stores). + * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by + * a local or distributed file system). * * @since 1.4.0 */ - def load(): DataFrame = { - val resolved = ResolvedDataSource( - sqlContext, - userSpecifiedSchema = userSpecifiedSchema, - partitionColumns = Array.empty[String], - provider = source, - options = extraOptions.toMap) - DataFrame(sqlContext, LogicalRelation(resolved.relation)) + def load(path: String): DataFrame = { + option("path", path).load() } /** @@ -130,8 +148,44 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.6.0 */ - def load(paths: Array[String]): DataFrame = { - option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() + @scala.annotation.varargs + def load(paths: String*): DataFrame = { + if (paths.isEmpty) { + sqlContext.emptyDataFrame + } else { + sqlContext.baseRelationToDataFrame( + DataSource.apply( + sqlContext, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + } + + /** + * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path + * (e.g. external key-value stores). + * + * @since 2.0.0 + */ + def stream(): DataFrame = { + val dataSource = + DataSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap) + Dataset.ofRows(sqlContext, StreamingRelation(dataSource)) + } + + /** + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + def stream(path: String): DataFrame = { + option("path", path).stream() } /** @@ -152,17 +206,17 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash * your external database systems. * - * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param url JDBC database url of the form `jdbc:subprotocol:subname`. * @param table Name of the table in the external database. * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions + * @param lowerBound the minimum value of `columnName` used to decide partition stride. + * @param upperBound the maximum value of `columnName` used to decide partition stride. + * @param numPartitions the number of partitions. This, along with `lowerBound` (inclusive), + * `upperBound` (exclusive), form partition strides for generated WHERE + * clause expressions used to split the column `columnName` evenly. * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. - * * @since 1.4.0 */ def jdbc( @@ -227,11 +281,67 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. * - * @param path input path + * You can set the following JSON-specific options to deal with non-standard JSON files: + *
  • `primitivesAsString` (default `false`): infers all primitive values as a string type
  • + *
  • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
  • + *
  • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
  • + *
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes + *
  • + *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers + * (e.g. 00012)
  • + *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records + * during parsing.
  • + *
      + *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the + * malformed string into a new field configured by `columnNameOfCorruptRecord`. When + * a schema is set by user, it sets `null` for extra fields.
    • + *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • + *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • + *
    + *
  • `columnNameOfCorruptRecord` (default `_corrupt_record`): allows renaming the new field + * having malformed string created by `PERMISSIVE` mode. This overrides + * `spark.sql.columnNameOfCorruptRecord`.
  • + * * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def json(path: String): DataFrame = format("json").load(path) + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following JSON-specific options to deal with non-standard JSON files: + *
  • `primitivesAsString` (default `false`): infers all primitive values as a string type
  • + *
  • `prefersDecimal` (default `false`): infers all floating-point values as a decimal + * type. If the values do not fit in decimal, then it infers them as doubles.
  • + *
  • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
  • + *
  • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
  • + *
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes + *
  • + *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers + * (e.g. 00012)
  • + *
  • `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all + * character using backslash quoting mechanism
  • + *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records + * during parsing.
  • + *
      + *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the + * malformed string into a new field configured by `columnNameOfCorruptRecord`. When + * a schema is set by user, it sets `null` for extra fields.
    • + *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • + *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • + *
    + *
  • `columnNameOfCorruptRecord` (default `_corrupt_record`): allows renaming the new field + * having malformed string created by `PERMISSIVE` mode. This overrides + * `spark.sql.columnNameOfCorruptRecord`.
  • + * + * @since 1.6.0 + */ + def json(paths: String*): DataFrame = format("json").load(paths : _*) + /** * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and * returns the result as a [[DataFrame]]. @@ -255,19 +365,39 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def json(jsonRDD: RDD[String]): DataFrame = { - val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble - val primitivesAsString = extraOptions.getOrElse("primitivesAsString", "false").toBoolean - sqlContext.baseRelationToDataFrame( - new JSONRelation( - Some(jsonRDD), - samplingRatio, - primitivesAsString, - userSpecifiedSchema, - None, - None)(sqlContext) - ) + val parsedOptions: JSONOptions = new JSONOptions(extraOptions.toMap) + val columnNameOfCorruptRecord = + parsedOptions.columnNameOfCorruptRecord + .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) + val schema = userSpecifiedSchema.getOrElse { + InferSchema.infer( + jsonRDD, + columnNameOfCorruptRecord, + parsedOptions) + } + + Dataset.ofRows( + sqlContext, + LogicalRDD( + schema.toAttributes, + JacksonParser.parse( + jsonRDD, + schema, + columnNameOfCorruptRecord, + parsedOptions))(sqlContext)) } + /** + * Loads a CSV file and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. To avoid going + * through the entire data once, specify the schema explicitly using [[schema]]. + * + * @since 2.0.0 + */ + @scala.annotation.varargs + def csv(paths: String*): DataFrame = format("csv").load(paths : _*) + /** * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty * [[DataFrame]] if no paths are passed in. @@ -276,20 +406,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { */ @scala.annotation.varargs def parquet(paths: String*): DataFrame = { - if (paths.isEmpty) { - sqlContext.emptyDataFrame - } else { - val globbedPaths = paths.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualified) - }.toArray - - sqlContext.baseRelationToDataFrame( - new ParquetRelation( - globbedPaths.map(_.toString), userSpecifiedSchema, None, extraOptions.toMap)(sqlContext)) - } + format("parquet").load(paths: _*) } /** @@ -307,12 +424,16 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName))) + Dataset.ofRows(sqlContext, + sqlContext.sessionState.catalog.lookupRelation( + sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))) } /** - * Loads a text file and returns a [[DataFrame]] with a single string column named "text". - * Each line in the text file is a new row in the resulting DataFrame. For example: + * Loads a text file and returns a [[Dataset]] of String. The underlying schema of the Dataset + * contains a single string column named "value". + * + * Each line in the text file is a new row in the resulting Dataset. For example: * {{{ * // Scala: * sqlContext.read.text("/path/to/spark/README.md") @@ -321,10 +442,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * sqlContext.read().text("/path/to/spark/README.md") * }}} * - * @param path input path - * @since 1.6.0 + * @param paths input path + * @since 2.0.0 */ - def text(path: String): DataFrame = format("text").load(path) + @scala.annotation.varargs + def text(paths: String*): Dataset[String] = { + format("text").load(paths : _*).as[String](sqlContext.implicits.newStringEncoder) + } /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 69c984717526d..3eb1f0f0d58ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql -import java.{util => ju, lang => jl} +import java.{lang => jl, util => ju} import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * :: Experimental :: @@ -33,6 +36,50 @@ import org.apache.spark.sql.execution.stat._ @Experimental final class DataFrameStatFunctions private[sql](df: DataFrame) { + /** + * Calculates the approximate quantiles of a numerical column of a DataFrame. + * + * The result of this algorithm has the following deterministic bound: + * If the DataFrame has N elements and if we request the quantile at probability `p` up to error + * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank + * of `x` is close to (p * N). + * More precisely, + * + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + * + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). + * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient + * Online Computation of Quantile Summaries]] by Greenwald and Khanna. + * + * @param col the name of the numerical column + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. + * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (>= 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. + * @return the approximate quantiles at the given probabilities + * + * @since 2.0.0 + */ + def approxQuantile( + col: String, + probabilities: Array[Double], + relativeError: Double): Array[Double] = { + StatFunctions.multipleApproxQuantiles(df, Seq(col), probabilities, relativeError).head.toArray + } + + /** + * Python-friendly version of [[approxQuantile()]] + */ + private[spark] def approxQuantile( + col: String, + probabilities: List[Double], + relativeError: Double): java.util.List[Double] = { + approxQuantile(col, probabilities.toArray, relativeError).toList.asJava + } + /** * Calculate the sample covariance of two numerical columns of a DataFrame. * @param col1 the name of the first column @@ -106,7 +153,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Null elements will be replaced by "null", and back ticks will be dropped from elements if they * exist. * - * * @param col1 The name of the first column. Distinct items will make the first item of * each row. * @param col2 The name of the second column. Distinct items will make the column names @@ -309,4 +355,168 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), depth, width, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch( + colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), eps, confidence, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(col, CountMinSketch.create(depth, width, seed)) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a [[CountMinSketch]] over column `colName` + * @since 2.0.0 + */ + def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(col, CountMinSketch.create(eps, confidence, seed)) + } + + private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = { + val singleCol = df.select(col) + val colType = singleCol.schema.head.dataType + + val updater: (CountMinSketch, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary` + // instead of `addString` to avoid unnecessary conversion. + case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes) + case ByteType => (sketch, row) => sketch.addLong(row.getByte(0)) + case ShortType => (sketch, row) => sketch.addLong(row.getShort(0)) + case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0)) + case LongType => (sketch, row) => sketch.addLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Count-min Sketch only supports string type and integral types, " + + s"and does not support type $colType." + ) + } + + singleCol.queryExecution.toRdd.aggregate(zero)( + (sketch: CountMinSketch, row: InternalRow) => { + updater(sketch, row) + sketch + }, + (sketch1, sketch2) => sketch1.mergeInPlace(sketch2) + ) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, fpp)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), BloomFilter.create(expectedNumItems, numBits)) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(col, BloomFilter.create(expectedNumItems, numBits)) + } + + private def buildBloomFilter(col: Column, zero: BloomFilter): BloomFilter = { + val singleCol = df.select(col) + val colType = singleCol.schema.head.dataType + + require(colType == StringType || colType.isInstanceOf[IntegralType], + s"Bloom filter only supports string type and integral types, but got $colType.") + + val updater: (BloomFilter, InternalRow) => Unit = colType match { + // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary` + // instead of `putString` to avoid unnecessary conversion. + case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes) + case ByteType => (filter, row) => filter.putLong(row.getByte(0)) + case ShortType => (filter, row) => filter.putLong(row.getShort(0)) + case IntegerType => (filter, row) => filter.putLong(row.getInt(0)) + case LongType => (filter, row) => filter.putLong(row.getLong(0)) + case _ => + throw new IllegalArgumentException( + s"Bloom filter only supports string type and integral types, " + + s"and does not support type $colType." + ) + } + + singleCol.queryExecution.toRdd.aggregate(zero)( + (filter: BloomFilter, row: InternalRow) => { + updater(filter, row) + filter + }, + (filter1, filter2) => filter1.mergeInPlace(filter2) + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7887e559a3025..54d250867fbb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql @@ -21,19 +21,23 @@ import java.util.Properties import scala.collection.JavaConverters._ +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} +import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} +import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelation - +import org.apache.spark.util.Utils /** * :: Experimental :: * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, - * key-value stores, etc). Use [[DataFrame.write]] to access this. + * key-value stores, etc) or data streams. Use [[DataFrame.write]] to access this. * * @since 1.4.0 */ @@ -75,6 +79,35 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * :: Experimental :: + * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run + * the query as fast as possible. + * + * Scala Example: + * {{{ + * def.writer.trigger(ProcessingTime("10 seconds")) + * + * import scala.concurrent.duration._ + * def.writer.trigger(ProcessingTime(10.seconds)) + * }}} + * + * Java Example: + * {{{ + * def.writer.trigger(ProcessingTime.create("10 seconds")) + * + * import java.util.concurrent.TimeUnit + * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.0.0 + */ + @Experimental + def trigger(trigger: Trigger): DataFrameWriter = { + this.trigger = trigger + this + } + /** * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. * @@ -95,6 +128,27 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataFrameWriter = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataFrameWriter = option(key, value.toString) + /** * (Scala-specific) Adds output options for the underlying data source. * @@ -117,9 +171,18 @@ final class DataFrameWriter private[sql](df: DataFrame) { /** * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme. + * laid out on the file system similar to Hive's partitioning scheme. As an example, when we + * partition a dataset by year and then month, the directory layout would look like: + * + * - year=2016/month=01/ + * - year=2016/month=02/ + * + * Partitioning is one of the most widely used techniques to optimize physical data layout. + * It provides a coarse-grained index for skipping unnecessary data reads when queries have + * predicates on the partitioned columns. In order for partitioning to work well, the number + * of distinct values in each column should typically be less than tens of thousands. * - * This is only applicable for Parquet at the moment. + * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. * * @since 1.4.0 */ @@ -129,6 +192,34 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * Buckets the output by the given columns. If specified, the output is laid out on the file + * system similar to Hive's bucketing scheme. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ + @scala.annotation.varargs + def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { + this.numBuckets = Option(numBuckets) + this.bucketColumnNames = Option(colName +: colNames) + this + } + + /** + * Sorts the output in each bucket by the given columns. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ + @scala.annotation.varargs + def sortBy(colName: String, colNames: String*): DataFrameWriter = { + this.sortColumnNames = Option(colName +: colNames) + this + } + /** * Saves the content of the [[DataFrame]] at the specified path. * @@ -145,13 +236,105 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def save(): Unit = { - ResolvedDataSource( + assertNotBucketed() + val dataSource = DataSource( df.sqlContext, - source, - partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), - mode, - extraOptions.toMap, - df) + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + bucketSpec = getBucketSpec, + options = extraOptions.toMap) + + dataSource.write(mode, df) + } + + /** + * Specifies the name of the [[ContinuousQuery]] that can be started with `startStream()`. + * This name must be unique among all the currently active queries in the associated SQLContext. + * + * @since 2.0.0 + */ + def queryName(queryName: String): DataFrameWriter = { + this.extraOptions += ("queryName" -> queryName) + this + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def startStream(path: String): ContinuousQuery = { + option("path", path).startStream() + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def startStream(): ContinuousQuery = { + if (source == "memory") { + val queryName = + extraOptions.getOrElse( + "queryName", throw new AnalysisException("queryName must be specified for memory sink")) + val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified => + new Path(userSpecified).toUri.toString + }.orElse { + val checkpointConfig: Option[String] = + df.sqlContext.conf.getConf( + SQLConf.CHECKPOINT_LOCATION, + None) + + checkpointConfig.map { location => + new Path(location, queryName).toUri.toString + } + }.getOrElse { + Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath + } + + // If offsets have already been created, we trying to resume a query. + val checkpointPath = new Path(checkpointLocation, "offsets") + val fs = checkpointPath.getFileSystem(df.sqlContext.sparkContext.hadoopConfiguration) + if (fs.exists(checkpointPath)) { + throw new AnalysisException( + s"Unable to resume query written to memory sink. Delete $checkpointPath to start over.") + } else { + checkpointPath.toUri.toString + } + + val sink = new MemorySink(df.schema) + val resultDf = Dataset.ofRows(df.sqlContext, new MemoryPlan(sink)) + resultDf.registerTempTable(queryName) + val continuousQuery = df.sqlContext.sessionState.continuousQueryManager.startQuery( + queryName, + checkpointLocation, + df, + sink, + trigger) + continuousQuery + } else { + val dataSource = + DataSource( + df.sqlContext, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + + val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) + val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { + new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString + }) + df.sqlContext.sessionState.continuousQueryManager.startQuery( + queryName, + checkpointLocation, + df, + dataSource.createSink(), + trigger) + } } /** @@ -163,21 +346,86 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(SqlParser.parseTableIdentifier(tableName)) + insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { - val partitions = partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + assertNotBucketed() + val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite + + // A partitioned relation's schema can be different from the input logicalPlan, since + // partition columns are all moved after data columns. We Project to adjust the ordering. + // TODO: this belongs to the analyzer. + val input = normalizedParCols.map { parCols => + val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr => + parCols.contains(attr.name) + } + Project(inputDataCols ++ inputPartCols, df.logicalPlan) + }.getOrElse(df.logicalPlan) + df.sqlContext.executePlan( InsertIntoTable( UnresolvedRelation(tableIdent), partitions.getOrElse(Map.empty[String, Option[String]]), - df.logicalPlan, + input, overwrite, ifNotExists = false)).toRdd } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => + cols.map(normalize(_, "Partition")) + } + + private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols => + cols.map(normalize(_, "Bucketing")) + } + + private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols => + cols.map(normalize(_, "Sorting")) + } + + private def getBucketSpec: Option[BucketSpec] = { + if (sortColumnNames.isDefined) { + require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + } + + for { + n <- numBuckets + } yield { + require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.") + + // partitionBy columns cannot be used in bucketBy + if (normalizedParCols.nonEmpty && + normalizedBucketColNames.get.toSet.intersect(normalizedParCols.get.toSet).nonEmpty) { + throw new AnalysisException( + s"bucketBy columns '${bucketColumnNames.get.mkString(", ")}' should not be part of " + + s"partitionBy columns '${partitioningColumns.get.mkString(", ")}'") + } + + BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) + } + } + + /** + * The given column name may not be equal to any of the existing column names if we were in + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards. + */ + private def normalize(columnName: String, columnType: String): String = { + val validColumnNames = df.logicalPlan.output.map(_.name) + validColumnNames.find(df.sqlContext.sessionState.analyzer.resolver(_, columnName)) + .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + + s"existing columns (${validColumnNames.mkString(", ")})")) + } + + private def assertNotBucketed(): Unit = { + if (numBuckets.isDefined || sortColumnNames.isDefined) { + throw new IllegalArgumentException( + "Currently we don't support writing bucketed data to this data source.") + } + } + /** * Saves the content of the [[DataFrame]] as the specified table. * @@ -197,11 +445,11 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(SqlParser.parseTableIdentifier(tableName)) + saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { - val tableExists = df.sqlContext.catalog.tableExists(tableIdent) + val tableExists = df.sqlContext.sessionState.catalog.tableExists(tableIdent) (tableExists, mode) match { case (true, SaveMode.Ignore) => @@ -210,13 +458,6 @@ final class DataFrameWriter private[sql](df: DataFrame) { case (true, SaveMode.ErrorIfExists) => throw new AnalysisException(s"Table $tableIdent already exists.") - case (true, SaveMode.Append) => - // If it is Append, we just ask insertInto to handle it. We will not use insertInto - // to handle saveAsTable with Overwrite because saveAsTable can change the schema of - // the table. But, insertInto with Overwrite requires the schema of data be the same - // the schema of the table. - insertInto(tableIdent) - case _ => val cmd = CreateTableUsingAsSelect( @@ -224,6 +465,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df.logicalPlan) @@ -244,7 +486,6 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. - * * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { @@ -254,7 +495,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } // connectionProperties should override settings in extraOptions props.putAll(connectionProperties) - val conn = JdbcUtils.createConnection(url, props) + val conn = JdbcUtils.createConnectionFactory(url, props)() try { var tableExists = JdbcUtils.tableExists(conn, url, table) @@ -276,7 +517,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { if (!tableExists) { val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" - conn.prepareStatement(sql).executeUpdate() + val statement = conn.createStatement + try { + statement.executeUpdate(sql) + } finally { + statement.close() + } } } finally { conn.close() @@ -292,6 +538,11 @@ final class DataFrameWriter private[sql](df: DataFrame) { * format("json").save(path) * }}} * + * You can set the following JSON-specific option(s) for writing JSON files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, + * `snappy` and `deflate`).
  • + * * @since 1.4.0 */ def json(path: String): Unit = format("json").save(path) @@ -303,6 +554,11 @@ final class DataFrameWriter private[sql](df: DataFrame) { * format("parquet").save(path) * }}} * + * You can set the following Parquet-specific option(s) for writing Parquet files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names(`none`, `snappy`, `gzip`, and `lzo`). + * This will overwrite `spark.sql.parquet.compression.codec`.
  • + * * @since 1.4.0 */ def parquet(path: String): Unit = format("parquet").save(path) @@ -314,6 +570,11 @@ final class DataFrameWriter private[sql](df: DataFrame) { * format("orc").save(path) * }}} * + * You can set the following ORC-specific option(s) for writing ORC files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names(`none`, `snappy`, `zlib`, and `lzo`). + * This will overwrite `orc.compress`.
  • + * * @since 1.5.0 * @note Currently, this method can only be used together with `HiveContext`. */ @@ -331,10 +592,31 @@ final class DataFrameWriter private[sql](df: DataFrame) { * df.write().text("/path/to/output") * }}} * + * You can set the following option(s) for writing text files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, + * `snappy` and `deflate`).
  • + * * @since 1.6.0 */ def text(path: String): Unit = format("text").save(path) + /** + * Saves the content of the [[DataFrame]] in CSV format at the specified path. + * This is equivalent to: + * {{{ + * format("csv").save(path) + * }}} + * + * You can set the following CSV-specific option(s) for writing CSV files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, + * `snappy` and `deflate`).
  • + * + * @since 2.0.0 + */ + def csv(path: String): Unit = format("csv").save(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// @@ -343,8 +625,15 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var mode: SaveMode = SaveMode.ErrorIfExists + private var trigger: Trigger = ProcessingTime(0L) + private var extraOptions = new scala.collection.mutable.HashMap[String, String] private var partitioningColumns: Option[Seq[String]] = None + private var bucketColumnNames: Option[Seq[String]] = None + + private var numBuckets: Option[Int] = None + + private var sortColumnNames: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 500227e93a472..4edc90d9c38d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1,463 +1,2450 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ package org.apache.spark.sql -import org.apache.spark.annotation.Experimental +import java.io.CharArrayWriter + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + +import com.fasterxml.jackson.core.JsonFactory +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.function._ +import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.optimizer.CombineUnions +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} +import org.apache.spark.sql.execution.datasources.json.JacksonGenerator +import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation} +import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +private[sql] object Dataset { + def apply[T: Encoder](sqlContext: SQLContext, logicalPlan: LogicalPlan): Dataset[T] = { + new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]]) + } + + def ofRows(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { + val qe = sqlContext.executePlan(logicalPlan) + qe.assertAnalyzed() + new Dataset[Row](sqlContext, logicalPlan, RowEncoder(qe.analyzed.schema)) + } +} /** - * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel - * using functional or relational operations. + * A [[Dataset]] is a strongly typed collection of domain-specific objects that can be transformed + * in parallel using functional or relational operations. Each Dataset also has an untyped view + * called a [[DataFrame]], which is a Dataset of [[Row]]. + * + * Operations available on Datasets are divided into transformations and actions. Transformations + * are the ones that produce new Datasets, and actions are the ones that trigger computation and + * return results. Example transformations include map, filter, select, and aggregate (`groupBy`). + * Example actions count, show, or writing data out to file systems. + * + * Datasets are "lazy", i.e. computations are only triggered when an action is invoked. Internally, + * a Dataset represents a logical plan that describes the computation required to produce the data. + * When an action is invoked, Spark's query optimizer optimizes the logical plan and generates a + * physical plan for efficient execution in a parallel and distributed manner. To explore the + * logical plan as well as optimized physical plan, use the `explain` function. + * + * To efficiently support domain-specific objects, an [[Encoder]] is required. The encoder maps + * the domain specific type `T` to Spark's internal type system. For example, given a class `Person` + * with two fields, `name` (string) and `age` (int), an encoder is used to tell Spark to generate + * code at runtime to serialize the `Person` object into a binary structure. This binary structure + * often has much lower memory footprint as well as are optimized for efficiency in data processing + * (e.g. in a columnar format). To understand the internal binary representation for data, use the + * `schema` function. + * + * There are typically two ways to create a Dataset. The most common way is by pointing Spark + * to some files on storage systems, using the `read` function available on a `SparkSession`. + * {{{ + * val people = session.read.parquet("...").as[Person] // Scala + * Dataset people = session.read().parquet("...").as(Encoders.bean(Person.class) // Java + * }}} + * + * Datasets can also be created through transformations available on existing Datasets. For example, + * the following creates a new Dataset by applying a filter on the existing one: + * {{{ + * val names = people.map(_.name) // in Scala; names is a Dataset[String] + * Dataset names = people.map((Person p) -> p.name, Encoders.STRING) // in Java 8 + * }}} * - * A [[Dataset]] differs from an [[RDD]] in the following ways: - * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored - * in the encoded form. This representation allows for additional logical operations and - * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to - * an object. - * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be - * used to serialize the object into a binary format. Encoders are also capable of mapping the - * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime - * reflection based serialization. Operations that change the type of object stored in the - * dataset also need an encoder for the new type. + * Dataset operations can also be untyped, through various domain-specific-language (DSL) + * functions defined in: [[Dataset]] (this class), [[Column]], and [[functions]]. These operations + * are very similar to the operations available in the data frame abstraction in R or Python. * - * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific - * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into - * specific Dataset by calling `df.as[ElementType]`. Similarly you can transform a strongly-typed - * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`. + * To select a column from the Dataset, use `apply` method in Scala and `col` in Java. + * {{{ + * val ageCol = people("age") // in Scala + * Column ageCol = people.col("age") // in Java + * }}} * - * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However, - * making this change to the class hierarchy would break the function signatures for the existing - * functional operations (map, flatMap, etc). As such, this class should be considered a preview - * of the final API. Changes will be made to the interface after Spark 1.6. + * Note that the [[Column]] type can also be manipulated through its various functions. + * {{{ + * // The following creates a new column that increases everybody's age by 10. + * people("age") + 10 // in Scala + * people.col("age").plus(10); // in Java + * }}} + * + * A more concrete example in Scala: + * {{{ + * // To create Dataset[Row] using SQLContext + * val people = session.read.parquet("...") + * val department = session.read.parquet("...") + * + * people.filter("age > 30") + * .join(department, people("deptId") === department("id")) + * .groupBy(department("name"), "gender") + * .agg(avg(people("salary")), max(people("age"))) + * }}} + * + * and in Java: + * {{{ + * // To create Dataset using SQLContext + * Dataset people = session.read().parquet("..."); + * Dataset department = session.read().parquet("..."); + * + * people.filter("age".gt(30)) + * .join(department, people.col("deptId").equalTo(department("id"))) + * .groupBy(department.col("name"), "gender") + * .agg(avg(people.col("salary")), max(people.col("age"))); + * }}} + * + * @groupname basic Basic Dataset functions + * @groupname action Actions + * @groupname untypedrel Untyped Language Integrated Relational Queries + * @groupname typedrel Typed Language Integrated Relational Queries + * @groupname func Functional Transformations + * @groupname rdd RDD Operations + * @groupname output Output Operations * * @since 1.6.0 */ -@Experimental -class Dataset[T] private( +class Dataset[T] private[sql]( @transient val sqlContext: SQLContext, - @transient val queryExecution: QueryExecution, - unresolvedEncoder: Encoder[T]) extends Serializable { + @DeveloperApi @transient val queryExecution: QueryExecution, + encoder: Encoder[T]) + extends Serializable { + + queryExecution.assertAnalyzed() + + // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure + // you wrap it with `withNewExecutionId` if this actions doesn't call other action. + + def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { + this(sqlContext, sqlContext.executePlan(logicalPlan), encoder) + } + + @transient protected[sql] val logicalPlan: LogicalPlan = { + def hasSideEffects(plan: LogicalPlan): Boolean = plan match { + case _: Command | + _: InsertIntoTable | + _: CreateTableUsingAsSelect => true + case _ => false + } + + queryExecution.logical match { + // For various commands (like DDL) and queries with side effects, we force query execution + // to happen right away to let these side effects take place eagerly. + case p if hasSideEffects(p) => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case Union(children) if children.forall(hasSideEffects) => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + queryExecution.analyzed + } + } + + /** + * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is + * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the + * same object type (that will be possibly resolved to a different schema). + */ + private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder) + unresolvedTEncoder.validate(logicalPlan.output) + + /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ + private[sql] val resolvedTEncoder: ExpressionEncoder[T] = + unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + + /** + * The encoder where the expressions used to construct an object from an input row have been + * bound to the ordinals of this [[Dataset]]'s output schema. + */ + private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) + + private implicit def classTag = unresolvedTEncoder.clsTag + + protected[sql] def resolve(colName: String): NamedExpression = { + queryExecution.analyzed.resolveQuoted(colName, sqlContext.sessionState.analyzer.resolver) + .getOrElse { + throw new AnalysisException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + } + + protected[sql] def numericColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + queryExecution.analyzed.resolveQuoted(n.name, sqlContext.sessionState.analyzer.resolver).get + } + } + + /** + * Compose the string representing rows for output + * + * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right + */ + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) + + // For array values, replace Seq and Array with square brackets + // For cells that are beyond 20 characters, replace it with the first 17 and "..." + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { + case r: Row => r + case tuple: Product => Row.fromTuple(tuple) + case o => Row(o) + }.map { row => + row.toSeq.map { cell => + val str = cell match { + case null => "null" + case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") + case array: Array[_] => array.mkString("[", ", ", "]") + case seq: Seq[_] => seq.mkString("[", ", ", "]") + case _ => cell.toString + } + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str + }: Seq[String] + } + + val sb = new StringBuilder + val numCols = schema.fieldNames.length + + // Initialise the width of each column to a minimum value of '3' + val colWidths = Array.fill(numCols)(3) + + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() + + // column names + rows.head.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + + sb.append(sep) + + // data + rows.tail.map { + _.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + + // For Data that has more than "numRows" records + if (hasMoreData) { + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") + } + + sb.toString() + } + + override def toString: String = { + try { + val builder = new StringBuilder + val fields = schema.take(2).map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("[") + builder.append(fields.mkString(", ")) + if (schema.length > 2) { + if (schema.length - fields.size == 1) { + builder.append(" ... 1 more field") + } else { + builder.append(" ... " + (schema.length - 2) + " more fields") + } + } + builder.append("]").toString() + } catch { + case NonFatal(e) => + s"Invalid tree; ${e.getMessage}:\n$queryExecution" + } + } + + /** + * Converts this strongly typed collection of data to generic Dataframe. In contrast to the + * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] + * objects that allow fields to be accessed by ordinal or name. + * + * @group basic + * @since 1.6.0 + */ + // This is declared with parentheses to prevent the Scala compiler from treating + // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = new Dataset[Row](sqlContext, queryExecution, RowEncoder(schema)) + + /** + * :: Experimental :: + * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The + * method used to map columns depend on the type of `U`: + * - When `U` is a class, fields for the class will be mapped to columns of the same name + * (case sensitivity is determined by `spark.sql.caseSensitive`) + * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will + * be assigned to `_1`). + * - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the + * [[DataFrame]] will be used. + * + * If the schema of the [[Dataset]] does not match the desired `U` type, you can use `select` + * along with `alias` or `as` to rearrange or rename as required. + * + * @group basic + * @since 1.6.0 + */ + @Experimental + def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan) + + /** + * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. + * This can be quite convenient in conversion from a RDD of tuples into a [[DataFrame]] with + * meaningful names. For example: + * {{{ + * val rdd: RDD[(Int, String)] = ... + * rdd.toDF() // this implicit conversion creates a DataFrame with column name `_1` and `_2` + * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name" + * }}} + * + * @group basic + * @since 2.0.0 + */ + @scala.annotation.varargs + def toDF(colNames: String*): DataFrame = { + require(schema.size == colNames.size, + "The number of columns doesn't match.\n" + + s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" + + s"New column names (${colNames.size}): " + colNames.mkString(", ")) + + val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => + Column(oldAttribute).as(newName) + } + select(newCols : _*) + } + + /** + * Returns the schema of this [[Dataset]]. + * + * @group basic + * @since 1.6.0 + */ + def schema: StructType = queryExecution.analyzed.schema + + /** + * Prints the schema to the console in a nice tree format. + * + * @group basic + * @since 1.6.0 + */ + // scalastyle:off println + def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * + * @group basic + * @since 1.6.0 + */ + def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Prints the physical plan to the console for debugging purposes. + * + * @group basic + * @since 1.6.0 + */ + def explain(): Unit = explain(extended = false) + + /** + * Returns all column names and their data types as an array. + * + * @group basic + * @since 1.6.0 + */ + def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) + } + + /** + * Returns all column names as an array. + * + * @group basic + * @since 1.6.0 + */ + def columns: Array[String] = schema.fields.map(_.name) + + /** + * Returns true if the `collect` and `take` methods can be run locally + * (without any Spark executors). + * + * @group basic + * @since 1.6.0 + */ + def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + + /** + * Returns true if this [[Dataset]] contains one or more sources that continuously + * return data as it arrives. A [[Dataset]] that reads data from a streaming source + * must be executed as a [[ContinuousQuery]] using the `startStream()` method in + * [[DataFrameWriter]]. Methods that return a single answer, (e.g., `count()` or + * `collect()`) will throw an [[AnalysisException]] when there is a streaming + * source present. + * + * @group basic + * @since 2.0.0 + */ + @Experimental + def isStreaming: Boolean = logicalPlan.find { n => + n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation] + }.isDefined + + /** + * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated, + * and all cells will be aligned right. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * + * @param numRows Number of rows to show + * + * @group action + * @since 1.6.0 + */ + def show(numRows: Int): Unit = show(numRows, truncate = true) + + /** + * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. + * + * @group action + * @since 1.6.0 + */ + def show(): Unit = show(20) + + /** + * Displays the top 20 rows of [[Dataset]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.6.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[Dataset]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.6.0 + */ + // scalastyle:off println + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + // scalastyle:on println + + /** + * Returns a [[DataFrameNaFunctions]] for working with missing data. + * {{{ + * // Dropping rows containing any null values. + * ds.na.drop() + * }}} + * + * @group untypedrel + * @since 1.6.0 + */ + def na: DataFrameNaFunctions = new DataFrameNaFunctions(toDF()) + + /** + * Returns a [[DataFrameStatFunctions]] for working statistic functions support. + * {{{ + * // Finding frequent items in column with name 'a'. + * ds.stat.freqItems(Seq("a")) + * }}} + * + * @group untypedrel + * @since 1.6.0 + */ + def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF()) + + /** + * Cartesian join with another [[DataFrame]]. + * + * Note that cartesian joins are very expensive without an extra filter that can be pushed down. + * + * @param right Right side of the join operation. + * + * @group untypedrel + * @since 2.0.0 + */ + def join(right: DataFrame): DataFrame = withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + } + + /** + * Inner equi-join with another [[DataFrame]] using the given column. + * + * Different from other join functions, the join column will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the column "user_id" + * df1.join(df2, "user_id") + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumn Name of the column to join on. This column must exist on both sides. + * + * @group untypedrel + * @since 2.0.0 + */ + def join(right: DataFrame, usingColumn: String): DataFrame = { + join(right, Seq(usingColumn)) + } + + /** + * Inner equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the columns "user_id" and "user_name" + * df1.join(df2, Seq("user_id", "user_name")) + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * + * @group untypedrel + * @since 2.0.0 + */ + def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + join(right, usingColumns, "inner") + } + + /** + * Equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * + * @group untypedrel + * @since 2.0.0 + */ + def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { + // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right + // by creating a new instance for one of the branch. + val joined = sqlContext.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + .analyzed.asInstanceOf[Join] + + withPlan { + Join( + joined.left, + joined.right, + UsingJoin(JoinType(joinType), usingColumns.map(UnresolvedAttribute(_))), + None) + } + } + + /** + * Inner join with another [[DataFrame]], using the given join expression. + * + * {{{ + * // The following two are equivalent: + * df1.join(df2, $"df1Key" === $"df2Key") + * df1.join(df2).where($"df1Key" === $"df2Key") + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + def join(right: DataFrame, joinExprs: Column): DataFrame = join(right, joinExprs, "inner") + + /** + * Join with another [[DataFrame]], using the given join expression. The following performs + * a full outer join between `df1` and `df2`. + * + * {{{ + * // Scala: + * import org.apache.spark.sql.functions._ + * df1.join(df2, $"df1Key" === $"df2Key", "outer") + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer"); + * }}} + * + * @param right Right side of the join. + * @param joinExprs Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * + * @group untypedrel + * @since 2.0.0 + */ + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + // Note that in this function, we introduce a hack in the case of self-join to automatically + // resolve ambiguous join conditions into ones that might make sense [SPARK-6231]. + // Consider this case: df.join(df, df("key") === df("key")) + // Since df("key") === df("key") is a trivially true condition, this actually becomes a + // cartesian join. However, most likely users expect to perform a self join using "key". + // With that assumption, this hack turns the trivially true condition into equality on join + // keys that are resolved to both sides. + + // Trigger analysis so in the case of self-join, the analyzer will clone the plan. + // After the cloning, left and right side will have distinct expression ids. + val plan = withPlan( + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) + .queryExecution.analyzed.asInstanceOf[Join] + + // If auto self join alias is disabled, return the plan. + if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) { + return withPlan(plan) + } + + // If left/right have no output set intersection, return the plan. + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed + if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { + return withPlan(plan) + } + + // Otherwise, find the trivially true predicates and automatically resolves them to both sides. + // By the time we get here, since we have already run analysis, all attributes should've been + // resolved and become AttributeReference. + val cond = plan.condition.map { _.transform { + case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualTo( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) + }} + + withPlan { + plan.copy(condition = cond) + } + } + + /** + * :: Experimental :: + * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to + * true. + * + * This is similar to the relation `join` function with one important difference in the + * result schema. Since `joinWith` preserves objects present on either side of the join, the + * result schema is similarly nested into a tuple under the column names `_1` and `_2`. + * + * This type of join can be useful both for preserving type-safety with the original object + * types as well as working with relational data where either side of the join has column + * names in common. + * + * @param other Right side of the join. + * @param condition Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * + * @group typedrel + * @since 1.6.0 + */ + @Experimental + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { + val left = this.logicalPlan + val right = other.logicalPlan + + val joined = sqlContext.executePlan(Join(left, right, joinType = + JoinType(joinType), Some(condition.expr))) + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + + val leftData = this.unresolvedTEncoder match { + case e if e.flat => Alias(leftOutput.head, "_1")() + case _ => Alias(CreateStruct(leftOutput), "_1")() + } + val rightData = other.unresolvedTEncoder match { + case e if e.flat => Alias(rightOutput.head, "_2")() + case _ => Alias(CreateStruct(rightOutput), "_2")() + } + + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) + + withTypedPlan { + Project( + leftData :: rightData :: Nil, + joined.analyzed) + } + } + + /** + * :: Experimental :: + * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair + * where `condition` evaluates to true. + * + * @param other Right side of the join. + * @param condition Join expression. + * + * @group typedrel + * @since 1.6.0 + */ + @Experimental + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + joinWith(other, condition, "inner") + } + + /** + * Returns a new [[Dataset]] with each partition sorted by the given expressions. + * + * This is the same operation as "SORT BY" in SQL (Hive QL). + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = { + sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) + } + + /** + * Returns a new [[Dataset]] with each partition sorted by the given expressions. + * + * This is the same operation as "SORT BY" in SQL (Hive QL). + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def sortWithinPartitions(sortExprs: Column*): Dataset[T] = { + sortInternal(global = false, sortExprs) + } + + /** + * Returns a new [[Dataset]] sorted by the specified column, all in ascending order. + * {{{ + * // The following 3 are equivalent + * ds.sort("sortcol") + * ds.sort($"sortcol") + * ds.sort($"sortcol".asc) + * }}} + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def sort(sortCol: String, sortCols: String*): Dataset[T] = { + sort((sortCol +: sortCols).map(apply) : _*) + } + + /** + * Returns a new [[Dataset]] sorted by the given expressions. For example: + * {{{ + * ds.sort($"col1", $"col2".desc) + * }}} + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def sort(sortExprs: Column*): Dataset[T] = { + sortInternal(global = true, sortExprs) + } + + /** + * Returns a new [[Dataset]] sorted by the given expressions. + * This is an alias of the `sort` function. + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*) + + /** + * Returns a new [[Dataset]] sorted by the given expressions. + * This is an alias of the `sort` function. + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*) + + /** + * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. + * + * @group untypedrel + * @since 2.0.0 + */ + def apply(colName: String): Column = col(colName) + + /** + * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. + * + * @group untypedrel + * @since 2.0.0 + */ + def col(colName: String): Column = colName match { + case "*" => + Column(ResolvedStar(queryExecution.analyzed.output)) + case _ => + val expr = resolve(colName) + Column(expr) + } + + /** + * Returns a new [[Dataset]] with an alias set. + * + * @group typedrel + * @since 1.6.0 + */ + def as(alias: String): Dataset[T] = withTypedPlan { + SubqueryAlias(alias, logicalPlan) + } + + /** + * (Scala-specific) Returns a new [[Dataset]] with an alias set. + * + * @group typedrel + * @since 2.0.0 + */ + def as(alias: Symbol): Dataset[T] = as(alias.name) + + /** + * Returns a new [[Dataset]] with an alias set. Same as `as`. + * + * @group typedrel + * @since 2.0.0 + */ + def alias(alias: String): Dataset[T] = as(alias) + + /** + * (Scala-specific) Returns a new [[Dataset]] with an alias set. Same as `as`. + * + * @group typedrel + * @since 2.0.0 + */ + def alias(alias: Symbol): Dataset[T] = as(alias) + + /** + * Selects a set of column based expressions. + * {{{ + * ds.select($"colA", $"colB" + 1) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def select(cols: Column*): DataFrame = withPlan { + Project(cols.map(_.named), logicalPlan) + } + + /** + * Selects a set of columns. This is a variant of `select` that can only select + * existing columns using column names (i.e. cannot construct expressions). + * + * {{{ + * // The following two are equivalent: + * ds.select("colA", "colB") + * ds.select($"colA", $"colB") + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) + + /** + * Selects a set of SQL expressions. This is a variant of `select` that accepts + * SQL expressions. + * + * {{{ + * // The following are equivalent: + * ds.selectExpr("colA", "colB as newName", "abs(colC)") + * ds.select(expr("colA"), expr("colB as newName"), expr("abs(colC)")) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def selectExpr(exprs: String*): DataFrame = { + select(exprs.map { expr => + Column(sqlContext.sessionState.sqlParser.parseExpression(expr)) + }: _*) + } + + /** + * :: Experimental :: + * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. + * + * {{{ + * val ds = Seq(1, 2, 3).toDS() + * val newDS = ds.select(expr("value + 1").as[Int]) + * }}} + * + * @group typedrel + * @since 1.6.0 + */ + @Experimental + def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { + new Dataset[U1]( + sqlContext, + Project( + c1.withInputType( + unresolvedTEncoder.deserializer, + logicalPlan.output).named :: Nil, + logicalPlan), + implicitly[Encoder[U1]]) + } + + /** + * Internal helper function for building typed selects that return tuples. For simplicity and + * code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + */ + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoders = columns.map(_.encoder) + val namedColumns = + columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named) + val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) + + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + } + + /** + * :: Experimental :: + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * + * @group typedrel + * @since 1.6.0 + */ + @Experimental + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = + selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] + + /** + * :: Experimental :: + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * + * @group typedrel + * @since 1.6.0 + */ + @Experimental + def select[U1, U2, U3]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = + selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] + + /** + * :: Experimental :: + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * + * @group typedrel + * @since 1.6.0 + */ + @Experimental + def select[U1, U2, U3, U4]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = + selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] + + /** + * :: Experimental :: + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * + * @group typedrel + * @since 1.6.0 + */ + @Experimental + def select[U1, U2, U3, U4, U5]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4], + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = + selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] + + /** + * Filters rows using the given condition. + * {{{ + * // The following are equivalent: + * peopleDs.filter($"age" > 15) + * peopleDs.where($"age" > 15) + * }}} + * + * @group typedrel + * @since 1.6.0 + */ + def filter(condition: Column): Dataset[T] = withTypedPlan { + Filter(condition.expr, logicalPlan) + } + + /** + * Filters rows using the given SQL expression. + * {{{ + * peopleDs.filter("age > 15") + * }}} + * + * @group typedrel + * @since 1.6.0 + */ + def filter(conditionExpr: String): Dataset[T] = { + filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) + } + + /** + * Filters rows using the given condition. This is an alias for `filter`. + * {{{ + * // The following are equivalent: + * peopleDs.filter($"age" > 15) + * peopleDs.where($"age" > 15) + * }}} + * + * @group typedrel + * @since 1.6.0 + */ + def where(condition: Column): Dataset[T] = filter(condition) + + /** + * Filters rows using the given SQL expression. + * {{{ + * peopleDs.where("age > 15") + * }}} + * + * @group typedrel + * @since 1.6.0 + */ + def where(conditionExpr: String): Dataset[T] = { + filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) + } + + /** + * Groups the [[Dataset]] using the specified columns, so we can run aggregation on them. See + * [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * ds.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * ds.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def groupBy(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) + } + + /** + * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns rolluped by department and group. + * ds.rollup($"department", $"group").avg() + * + * // Compute the max age and average salary, rolluped by department and gender. + * ds.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def rollup(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.RollupType) + } + + /** + * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * ds.cube($"department", $"group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * ds.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def cube(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType) + } + + /** + * Groups the [[Dataset]] using the specified columns, so that we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * ds.groupBy("department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * ds.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def groupBy(col1: String, cols: String*): RelationalGroupedDataset = { + val colNames: Seq[String] = col1 +: cols + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.GroupByType) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` + * must be commutative and associative or the result may be non-deterministic. + * + * @group action + * @since 1.6.0 + */ + @Experimental + def reduce(func: (T, T) => T): T = rdd.reduce(func) + + /** + * :: Experimental :: + * (Java-specific) + * Reduces the elements of this Dataset using the specified binary function. The given `func` + * must be commutative and associative or the result may be non-deterministic. + * + * @group action + * @since 1.6.0 + */ + @Experimental + def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) + + /** + * :: Experimental :: + * (Scala-specific) + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + @Experimental + def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { + val inputPlan = logicalPlan + val withGroupingKey = AppendColumns(func, inputPlan) + val executed = sqlContext.executePlan(withGroupingKey) + + new KeyValueGroupedDataset( + encoderFor[K], + encoderFor[T], + executed, + inputPlan.output, + withGroupingKey.newColumns) + } + + /** + * :: Experimental :: + * (Java-specific) + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + @Experimental + def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + groupByKey(func.call(_))(encoder) + + /** + * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * This is a variant of rollup that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns rolluped by department and group. + * ds.rollup("department", "group").avg() + * + * // Compute the max age and average salary, rolluped by department and gender. + * ds.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def rollup(col1: String, cols: String*): RelationalGroupedDataset = { + val colNames: Seq[String] = col1 +: cols + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.RollupType) + } + + /** + * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * This is a variant of cube that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * ds.cube("department", "group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * ds.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def cube(col1: String, cols: String*): RelationalGroupedDataset = { + val colNames: Seq[String] = col1 +: cols + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.CubeType) + } + + /** + * (Scala-specific) Aggregates on the entire [[Dataset]] without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg("age" -> "max", "salary" -> "avg") + * ds.groupBy().agg("age" -> "max", "salary" -> "avg") + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + groupBy().agg(aggExpr, aggExprs : _*) + } + + /** + * (Scala-specific) Aggregates on the entire [[Dataset]] without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(Map("age" -> "max", "salary" -> "avg")) + * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * (Java-specific) Aggregates on the entire [[Dataset]] without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(Map("age" -> "max", "salary" -> "avg")) + * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * Aggregates on the entire [[Dataset]] without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(max($"age"), avg($"salary")) + * ds.groupBy().agg(max($"age"), avg($"salary")) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) + + /** + * Returns a new [[Dataset]] by taking the first `n` rows. The difference between this function + * and `head` is that `head` is an action and returns an array (by triggering query execution) + * while `limit` returns a new [[Dataset]]. + * + * @group typedrel + * @since 2.0.0 + */ + def limit(n: Int): Dataset[T] = withTypedPlan { + Limit(Literal(n), logicalPlan) + } + + /** + * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset. + * This is equivalent to `UNION ALL` in SQL. + * + * To do a SQL-style set union (that does deduplication of elements), use this function followed + * by a [[distinct]]. + * + * @group typedrel + * @since 2.0.0 + */ + @deprecated("use union()", "2.0.0") + def unionAll(other: Dataset[T]): Dataset[T] = union(other) + + /** + * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset. + * This is equivalent to `UNION ALL` in SQL. + * + * To do a SQL-style set union (that does deduplication of elements), use this function followed + * by a [[distinct]]. + * + * @group typedrel + * @since 2.0.0 + */ + def union(other: Dataset[T]): Dataset[T] = withTypedPlan { + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(logicalPlan, other.logicalPlan)) + } + + /** + * Returns a new [[Dataset]] containing rows only in both this Dataset and another Dataset. + * This is equivalent to `INTERSECT` in SQL. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * @group typedrel + * @since 1.6.0 + */ + def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan { + Intersect(logicalPlan, other.logicalPlan) + } + + /** + * Returns a new [[Dataset]] containing rows in this Dataset but not in another Dataset. + * This is equivalent to `EXCEPT` in SQL. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * @group typedrel + * @since 2.0.0 + */ + def except(other: Dataset[T]): Dataset[T] = withTypedPlan { + Except(logicalPlan, other.logicalPlan) + } + + /** + * Returns a new [[Dataset]] by sampling a fraction of rows. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + * @param seed Seed for sampling. + * + * @group typedrel + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan { + Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + } + + /** + * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + * + * @group typedrel + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + + /** + * Randomly splits this [[Dataset]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * + * For Java API, use [[randomSplitAsList]]. + * + * @group typedrel + * @since 2.0.0 + */ + def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its + // constituent partitions each time a split is materialized which could result in + // overlapping splits. To prevent this, we explicitly sort each input partition to make the + // ordering deterministic. + val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) + val sum = weights.sum + val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) + normalizedCumWeights.sliding(2).map { x => + new Dataset[T]( + sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) + }.toArray + } + + /** + * Returns a Java list that contains randomly split [[Dataset]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * + * @group typedrel + * @since 2.0.0 + */ + def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = { + val values = randomSplit(weights, seed) + java.util.Arrays.asList(values : _*) + } + + /** + * Randomly splits this [[Dataset]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @group typedrel + * @since 2.0.0 + */ + def randomSplit(weights: Array[Double]): Array[Dataset[T]] = { + randomSplit(weights, Utils.random.nextLong) + } + + /** + * Randomly splits this [[Dataset]] with the provided weights. Provided for the Python Api. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + */ + private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = { + randomSplit(weights.toArray, seed) + } + + /** + * :: Experimental :: + * (Scala-specific) Returns a new [[Dataset]] where each row has been expanded to zero or more + * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of + * the input row are implicitly joined with each row that is output by the function. + * + * The following example uses this function to count the number of books which contain + * a given word: + * + * {{{ + * case class Book(title: String, words: String) + * val ds: Dataset[Book] + * + * case class Word(word: String) + * val allWords = ds.explode('words) { + * case Row(words: String) => words.split(" ").map(Word(_)) + * } + * + * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @Experimental + def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + + val elementTypes = schema.toAttributes.map { + attr => (attr.dataType, attr.nullable, attr.name) } + val names = schema.toAttributes.map(_.name) + val convert = CatalystTypeConverters.createToCatalystConverter(schema) + + val rowFunction = + f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) + val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) + + withPlan { + Generate(generator, join = true, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } + } + + /** + * :: Experimental :: + * (Scala-specific) Returns a new [[Dataset]] where a single column has been expanded to zero + * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All + * columns of the input row are implicitly joined with each value that is output by the function. + * + * {{{ + * ds.explode("words", "word") {words: String => words.split(" ")} + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @Experimental + def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) + : DataFrame = { + val dataType = ScalaReflection.schemaFor[B].dataType + val attributes = AttributeReference(outputColumn, dataType)() :: Nil + // TODO handle the metadata? + val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable, attr.name) } + + def rowFunction(row: Row): TraversableOnce[InternalRow] = { + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) + } + val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) + + withPlan { + Generate(generator, join = true, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } + } + + /** + * Returns a new [[Dataset]] by adding a column or replacing the existing column that has + * the same name. + * + * @group untypedrel + * @since 2.0.0 + */ + def withColumn(colName: String, col: Column): DataFrame = { + val resolver = sqlContext.sessionState.analyzer.resolver + val output = queryExecution.analyzed.output + val shouldReplace = output.exists(f => resolver(f.name, colName)) + if (shouldReplace) { + val columns = output.map { field => + if (resolver(field.name, colName)) { + col.as(colName) + } else { + Column(field) + } + } + select(columns : _*) + } else { + select(Column("*"), col.as(colName)) + } + } + + /** + * Returns a new [[Dataset]] by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { + val resolver = sqlContext.sessionState.analyzer.resolver + val output = queryExecution.analyzed.output + val shouldReplace = output.exists(f => resolver(f.name, colName)) + if (shouldReplace) { + val columns = output.map { field => + if (resolver(field.name, colName)) { + col.as(colName, metadata) + } else { + Column(field) + } + } + select(columns : _*) + } else { + select(Column("*"), col.as(colName, metadata)) + } + } - /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ - private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match { - case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output) - case _ => throw new IllegalArgumentException("Only expression encoders are currently supported") + /** + * Returns a new [[Dataset]] with a column renamed. + * This is a no-op if schema doesn't contain existingName. + * + * @group untypedrel + * @since 2.0.0 + */ + def withColumnRenamed(existingName: String, newName: String): DataFrame = { + val resolver = sqlContext.sessionState.analyzer.resolver + val output = queryExecution.analyzed.output + val shouldRename = output.exists(f => resolver(f.name, existingName)) + if (shouldRename) { + val columns = output.map { col => + if (resolver(col.name, existingName)) { + Column(col).as(newName) + } else { + Column(col) + } + } + select(columns : _*) + } else { + toDF() + } } - private implicit def classTag = encoder.clsTag + /** + * Returns a new [[Dataset]] with a column dropped. + * This is a no-op if schema doesn't contain column name. + * + * @group untypedrel + * @since 2.0.0 + */ + def drop(colName: String): DataFrame = { + drop(Seq(colName) : _*) + } - private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = - this(sqlContext, new QueryExecution(sqlContext, plan), encoder) + /** + * Returns a new [[Dataset]] with columns dropped. + * This is a no-op if schema doesn't contain column name(s). + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def drop(colNames: String*): DataFrame = { + val resolver = sqlContext.sessionState.analyzer.resolver + val remainingCols = + schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) + if (remainingCols.size == this.schema.size) { + toDF() + } else { + this.select(remainingCols: _*) + } + } - /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */ - def schema: StructType = encoder.schema + /** + * Returns a new [[Dataset]] with a column dropped. + * This version of drop accepts a Column rather than a name. + * This is a no-op if the Datasetdoesn't have a column + * with an equivalent expression. + * + * @group untypedrel + * @since 2.0.0 + */ + def drop(col: Column): DataFrame = { + val expression = col match { + case Column(u: UnresolvedAttribute) => + queryExecution.analyzed.resolveQuoted( + u.name, sqlContext.sessionState.analyzer.resolver).getOrElse(u) + case Column(expr: Expression) => expr + } + val attrs = this.logicalPlan.output + val colsAfterDrop = attrs.filter { attr => + attr != expression + }.map(attr => Column(attr)) + select(colsAfterDrop : _*) + } - /* ************* * - * Conversions * - * ************* */ + /** + * Returns a new [[Dataset]] that contains only the unique rows from this [[Dataset]]. + * This is an alias for `distinct`. + * + * @group typedrel + * @since 2.0.0 + */ + def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns) /** - * Returns a new `Dataset` where each record has been mapped on to the specified type. The - * method used to map columns depend on the type of `U`: - * - When `U` is a class, fields for the class will be mapped to columns of the same name - * (case sensitivity is determined by `spark.sql.caseSensitive`) - * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will - * be assigned to `_1`). - * - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the - * [[DataFrame]] will be used. + * (Scala-specific) Returns a new [[Dataset]] with duplicate rows removed, considering only + * the subset of columns. * - * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select` - * along with `alias` or `as` to rearrange or rename as required. - * @since 1.6.0 + * @group typedrel + * @since 2.0.0 */ - def as[U : Encoder]: Dataset[U] = { - new Dataset(sqlContext, queryExecution, encoderFor[U]) + def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { + val groupCols = colNames.map(resolve) + val groupColExprIds = groupCols.map(_.exprId) + val aggCols = logicalPlan.output.map { attr => + if (groupColExprIds.contains(attr.exprId)) { + attr + } else { + Alias(new First(attr).toAggregateExpression(), attr.name)() + } + } + Aggregate(groupCols, aggCols, logicalPlan) } /** - * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have - * the same name after two Datasets have been joined. + * Returns a new [[Dataset]] with duplicate rows removed, considering only + * the subset of columns. + * + * @group typedrel + * @since 2.0.0 */ - def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _)) + def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq) /** - * Converts this strongly typed collection of data to generic Dataframe. In contrast to the - * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] - * objects that allow fields to be accessed by ordinal or name. + * Computes statistics for numeric columns, including count, mean, stddev, min, and max. + * If no columns are given, this function computes statistics for all numerical columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[Dataset]]. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * ds.describe("age", "height").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // max 92.0 192.0 + * }}} + * + * @group action + * @since 1.6.0 */ - // This is declared with parentheses to prevent the Scala compiler from treating - // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) + @scala.annotation.varargs + def describe(cols: String*): DataFrame = withPlan { + + // The list of summary statistics to compute, in the form of expressions. + val statistics = List[(String, Expression => Expression)]( + "count" -> ((child: Expression) => Count(child).toAggregateExpression()), + "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), + "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), + "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "max" -> ((child: Expression) => Max(child).toAggregateExpression())) + + val outputCols = + (if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + val aggExprs = statistics.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) + } + + val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) + } + } else { + // If there are no output columns, just output a single column that contains the stats. + statistics.map { case (name, _) => Row(name) } + } + + // All columns are string type + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes + LocalRelation.fromExternalRows(schema, ret) + } /** - * Returns this Dataset. + * Returns the first `n` rows. + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. + * + * @group action * @since 1.6.0 */ - // This is declared with parentheses to prevent the Scala compiler from treating - // `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset. - def toDS(): Dataset[T] = this + def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df => + df.collect(needCallback = false) + } /** - * Converts this Dataset to an RDD. + * Returns the first row. + * @group action * @since 1.6.0 */ - def rdd: RDD[T] = { - val tEnc = encoderFor[T] - val input = queryExecution.analyzed.output - queryExecution.toRdd.mapPartitions { iter => - val bound = tEnc.bind(input) - iter.map(bound.fromRow) - } - } + def head(): T = head(1).head - /* *********************** * - * Functional Operations * - * *********************** */ + /** + * Returns the first row. Alias for head(). + * @group action + * @since 1.6.0 + */ + def first(): T = head() /** * Concise syntax for chaining custom transformations. * {{{ - * def featurize(ds: Dataset[T]) = ... + * def featurize(ds: Dataset[T]): Dataset[U] = ... * - * dataset + * ds * .transform(featurize) * .transform(...) * }}} * + * @group func * @since 1.6.0 */ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) /** + * :: Experimental :: + * (Scala-specific) * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * + * @group func + * @since 1.6.0 + */ + @Experimental + def filter(func: T => Boolean): Dataset[T] = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val function = Literal.create(func, ObjectType(classOf[T => Boolean])) + val condition = Invoke(function, "apply", BooleanType, deserialized.output) + val filter = Filter(condition, deserialized) + withTypedPlan(CatalystSerde.serialize[T](filter)) + } + + /** + * :: Experimental :: + * (Java-specific) + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * + * @group func * @since 1.6.0 */ - def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) + @Experimental + def filter(func: FilterFunction[T]): Dataset[T] = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]])) + val condition = Invoke(function, "call", BooleanType, deserialized.output) + val filter = Filter(condition, deserialized) + withTypedPlan(CatalystSerde.serialize[T](filter)) + } /** + * :: Experimental :: + * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * + * @group func * @since 1.6.0 */ - def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + @Experimental + def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { + MapElements[T, U](func, logicalPlan) + } /** + * :: Experimental :: + * (Java-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * + * @group func + * @since 1.6.0 + */ + @Experimental + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + implicit val uEnc = encoder + withTypedPlan(MapElements[T, U](func, logicalPlan)) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * + * @group func * @since 1.6.0 */ + @Experimental def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { - new Dataset( + new Dataset[U]( sqlContext, - MapPartitions[T, U]( - func, - encoderFor[T], - encoderFor[U], - encoderFor[U].schema.toAttributes, - logicalPlan)) + MapPartitions[T, U](func, logicalPlan), + implicitly[Encoder[U]]) + } + + /** + * :: Experimental :: + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * + * @group func + * @since 1.6.0 + */ + @Experimental + def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala + mapPartitions(func)(encoder) } + /** + * :: Experimental :: + * (Scala-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * + * @group func + * @since 1.6.0 + */ + @Experimental def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) - /* ************** * - * Side effects * - * ************** */ + /** + * :: Experimental :: + * (Java-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * + * @group func + * @since 1.6.0 + */ + @Experimental + def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (T) => Iterator[U] = x => f.call(x).asScala + flatMap(func)(encoder) + } /** - * Runs `func` on each element of this Dataset. + * Applies a function `f` to all rows. + * + * @group action * @since 1.6.0 */ - def foreach(func: T => Unit): Unit = rdd.foreach(func) + def foreach(f: T => Unit): Unit = withNewExecutionId { + rdd.foreach(f) + } /** - * Runs `func` on each partition of this Dataset. + * (Java-specific) + * Runs `func` on each element of this [[Dataset]]. + * + * @group action * @since 1.6.0 */ - def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) + def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) - /* ************* * - * Aggregation * - * ************* */ + /** + * Applies a function f to each partition of this [[Dataset]]. + * + * @group action + * @since 1.6.0 + */ + def foreachPartition(f: Iterator[T] => Unit): Unit = withNewExecutionId { + rdd.foreachPartition(f) + } /** - * Reduces the elements of this Dataset using the specified binary function. The given function - * must be commutative and associative or the result may be non-deterministic. + * (Java-specific) + * Runs `func` on each partition of this [[Dataset]]. + * + * @group action * @since 1.6.0 */ - def reduce(func: (T, T) => T): T = rdd.reduce(func) + def foreachPartition(func: ForeachPartitionFunction[T]): Unit = + foreachPartition(it => func.call(it.asJava)) /** - * Aggregates the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". + * Returns the first `n` rows in the [[Dataset]]. * - * This behaves somewhat differently than the fold operations implemented for non-distributed - * collections in functional languages like Scala. This fold operation may be applied to - * partitions individually, and then those results will be folded into the final result. - * If op is not commutative, then the result may differ from that of a fold applied to a - * non-distributed collection. + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @group action * @since 1.6.0 */ - def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) + def take(n: Int): Array[T] = head(n) /** - * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * Returns the first `n` rows in the [[Dataset]] as a list. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @group action * @since 1.6.0 */ - def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { - val inputPlan = queryExecution.analyzed - val withGroupingKey = AppendColumn(func, inputPlan) - val executed = sqlContext.executePlan(withGroupingKey) + def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*) - new GroupedDataset( - encoderFor[K].resolve(withGroupingKey.newColumns), - encoderFor[T].bind(inputPlan.output), - executed, - inputPlan.output, - withGroupingKey.newColumns) - } + /** + * Returns an array that contains all of [[Row]]s in this [[Dataset]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. + * + * @group action + * @since 1.6.0 + */ + def collect(): Array[T] = collect(needCallback = true) /** - * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions. + * Returns a Java list that contains all of [[Row]]s in this [[Dataset]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * @group action * @since 1.6.0 */ - @scala.annotation.varargs - def groupBy(cols: Column*): GroupedDataset[Row, T] = { - val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias) - val withKey = Project(withKeyColumns, logicalPlan) - val executed = sqlContext.executePlan(withKey) + def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ => + withNewExecutionId { + val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) + java.util.Arrays.asList(values : _*) + } + } - val dataAttributes = executed.analyzed.output.dropRight(cols.size) - val keyAttributes = executed.analyzed.output.takeRight(cols.size) + private def collect(needCallback: Boolean): Array[T] = { + def execute(): Array[T] = withNewExecutionId { + queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) + } - new GroupedDataset( - RowEncoder(keyAttributes.toStructType), - encoderFor[T], - executed, - dataAttributes, - keyAttributes) + if (needCallback) { + withCallback("collect", toDF())(_ => execute()) + } else { + execute() + } } - /* ****************** * - * Typed Relational * - * ****************** */ + /** + * Return an iterator that contains all of [[Row]]s in this [[Dataset]]. + * + * The iterator will consume as much memory as the largest partition in this [[Dataset]]. + * + * Note: this results in multiple Spark jobs, and if the input Dataset is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input Dataset should be cached first. + * + * @group action + * @since 2.0.0 + */ + def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ => + withNewExecutionId { + queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava + } + } /** - * Selects a set of column based expressions. - * {{{ - * df.select($"colA", $"colB" + 1) - * }}} - * @group dfops - * @since 1.3.0 + * Returns the number of rows in the [[Dataset]]. + * @group action + * @since 1.6.0 */ - // Copied from Dataframe to make sure we don't have invalid overloads. - @scala.annotation.varargs - def select(cols: Column*): DataFrame = toDF().select(cols: _*) + def count(): Long = withCallback("count", groupBy().count()) { df => + df.collect(needCallback = false).head.getLong(0) + } /** - * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. * - * {{{ - * val ds = Seq(1, 2, 3).toDS() - * val newDS = ds.select(e[Int]("value + 1")) - * }}} + * @group typedrel * @since 1.6.0 */ - def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = { - new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) + def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { + Repartition(numPartitions, shuffle = true, logicalPlan) } /** - * Internal helper function for building typed selects that return tuples. For simplicity and - * code reuse, we do this without the help of the type system and then use helper functions - * that cast appropriately for the user facing interface. + * Returns a new [[Dataset]] partitioned by the given partitioning expressions into + * `numPartitions`. The resulting Dataset is hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). + * + * @group typedrel + * @since 2.0.0 */ - protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = { - val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } - val unresolvedPlan = Project(aliases, logicalPlan) - val execution = new QueryExecution(sqlContext, unresolvedPlan) - // Rebind the encoders to the nested schema that will be produced by the select. - val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { - case (e: ExpressionEncoder[_], a) if !e.flat => - e.nested(a.toAttribute).resolve(execution.analyzed.output) - case (e, a) => - e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output) - } - new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + @scala.annotation.varargs + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions)) } /** - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. - * @since 1.6.0 + * Returns a new [[Dataset]] partitioned by the given partitioning expressions preserving + * the existing number of partitions. The resulting Datasetis hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). + * + * @group typedrel + * @since 2.0.0 */ - def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = - selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] + @scala.annotation.varargs + def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None) + } /** - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * + * @group rdd * @since 1.6.0 */ - def select[U1, U2, U3]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = - selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] + def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { + Repartition(numPartitions, shuffle = false, logicalPlan) + } /** - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. - * @since 1.6.0 + * Returns a new [[Dataset]] that contains only the unique rows from this [[Dataset]]. + * This is an alias for `dropDuplicates`. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * @group typedrel + * @since 2.0.0 */ - def select[U1, U2, U3, U4]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = - selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] + def distinct(): Dataset[T] = dropDuplicates() /** - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * + * @group basic * @since 1.6.0 */ - def select[U1, U2, U3, U4, U5]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4], - c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = - selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] + def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } - /* **************** * - * Set operations * - * **************** */ + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * + * @group basic + * @since 1.6.0 + */ + def cache(): this.type = persist() /** - * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]]. + * Persist this [[Dataset]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. + * @group basic * @since 1.6.0 */ - def distinct: Dataset[T] = withPlan(Distinct) + def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } /** - * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also - * present in `other`. + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. + * @param blocking Whether to block until all blocks are deleted. + * + * @group basic * @since 1.6.0 */ - def intersect(other: Dataset[T]): Dataset[T] = - withPlan[T](other)(Intersect) + def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } /** - * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]] - * combined. + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. * - * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analagous to `UNION ALL` in SQL. + * @group basic * @since 1.6.0 */ - def union(other: Dataset[T]): Dataset[T] = - withPlan[T](other)(Union) + def unpersist(): this.type = unpersist(blocking = false) /** - * Returns a new [[Dataset]] where any elements present in `other` have been removed. + * Represents the content of the [[Dataset]] as an [[RDD]] of [[Row]]s. Note that the RDD is + * memoized. Once called, it won't change even if you change any query planning related Spark SQL + * configurations (e.g. `spark.sql.shuffle.partitions`). * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. + * @group rdd * @since 1.6.0 */ - def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except) + lazy val rdd: RDD[T] = { + queryExecution.toRdd.mapPartitions { rows => + rows.map(boundTEncoder.fromRow) + } + } - /* ****** * - * Joins * - * ****** */ + /** + * Returns the content of the [[Dataset]] as a [[JavaRDD]] of [[Row]]s. + * @group rdd + * @since 1.6.0 + */ + def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD() /** - * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to - * true. + * Returns the content of the [[Dataset]] as a [[JavaRDD]] of [[Row]]s. + * @group rdd + * @since 1.6.0 + */ + def javaRDD: JavaRDD[T] = toJavaRDD + + /** + * Registers this [[Dataset]] as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SQLContext]] that was used to create this Dataset. * - * This is similar to the relation `join` function with one important difference in the - * result schema. Since `joinWith` preserves objects present on either side of the join, the - * result schema is similarly nested into a tuple under the column names `_1` and `_2`. + * @group basic + * @since 1.6.0 + */ + def registerTempTable(tableName: String): Unit = { + sqlContext.registerDataFrameAsTable(toDF(), tableName) + } + + /** + * :: Experimental :: + * Interface for saving the content of the [[Dataset]] out into external storage or streams. * - * This type of join can be useful both for preserving type-safety with the original object - * types as well as working with relational data where either side of the join has column - * names in common. + * @group output + * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { - val left = this.logicalPlan - val right = other.logicalPlan + @Experimental + def write: DataFrameWriter = new DataFrameWriter(toDF()) - val leftData = this.encoder match { - case e if e.flat => Alias(left.output.head, "_1")() - case _ => Alias(CreateStruct(left.output), "_1")() - } - val rightData = other.encoder match { - case e if e.flat => Alias(right.output.head, "_2")() - case _ => Alias(CreateStruct(right.output), "_2")() + /** + * Returns the content of the [[Dataset]] as a Dataset of JSON strings. + * @since 2.0.0 + */ + def toJSON: Dataset[String] = { + val rowSchema = this.schema + val rdd: RDD[String] = queryExecution.toRdd.mapPartitions { iter => + val writer = new CharArrayWriter() + // create the Generator without separator inserted between 2 records + val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + + new Iterator[String] { + override def hasNext: Boolean = iter.hasNext + override def next(): String = { + JacksonGenerator(rowSchema, gen)(iter.next()) + gen.flush() + + val json = writer.toString + if (hasNext) { + writer.reset() + } else { + gen.close() + } + + json + } + } } - val leftEncoder = - if (encoder.flat) encoder else encoder.nested(leftData.toAttribute) - val rightEncoder = - if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute) - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple( - leftEncoder, - rightEncoder.rebind(right.output, left.output ++ right.output)) + import sqlContext.implicits.newStringEncoder + sqlContext.createDataset(rdd) + } - withPlan[(T, U)](other) { (left, right) => - Project( - leftData :: rightData :: Nil, - Join(left, right, Inner, Some(condition.expr))) + /** + * Returns a best-effort snapshot of the files that compose this Dataset. This method simply + * asks each constituent BaseRelation for its respective files and takes the union of all results. + * Depending on the source relations, this may not find all input files. Duplicates are removed. + * + * @group basic + * @since 2.0.0 + */ + def inputFiles: Array[String] = { + val files: Seq[String] = logicalPlan.collect { + case LogicalRelation(fsBasedRelation: FileRelation, _, _) => + fsBasedRelation.inputFiles + case fr: FileRelation => + fr.inputFiles + }.flatten + files.toSet.toArray + } + + //////////////////////////////////////////////////////////////////////////// + // For Python API + //////////////////////////////////////////////////////////////////////////// + + /** + * Converts a JavaRDD to a PythonRDD. + */ + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val structType = schema // capture it for closure + val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) + EvaluatePython.javaToPython(rdd) + } + + protected[sql] def collectToPython(): Int = { + withNewExecutionId { + PythonRDD.collectAndServe(javaToPython.rdd) } } - /* ************************** * - * Gather to Driver Actions * - * ************************** */ + protected[sql] def toPythonIterator(): Int = { + withNewExecutionId { + PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) + } + } - /** Returns the first element in this [[Dataset]]. */ - def first(): T = rdd.first() + //////////////////////////////////////////////////////////////////////////// + // Private Helpers + //////////////////////////////////////////////////////////////////////////// - /** Collects the elements to an Array. */ - def collect(): Array[T] = rdd.collect() + /** + * Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with + * an execution. + */ + private[sql] def withNewExecutionId[U](body: => U): U = { + SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body) + } - /** Returns the first `num` elements of this [[Dataset]] as an Array. */ - def take(num: Int): Array[T] = rdd.take(num) + /** + * Wrap a Dataset action to track the QueryExecution and time cost, then report to the + * user-registered callback functions. + */ + private def withCallback[U](name: String, df: DataFrame)(action: DataFrame => U) = { + try { + df.queryExecution.executedPlan.foreach { plan => + plan.resetMetrics() + } + val start = System.nanoTime() + val result = action(df) + val end = System.nanoTime() + sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start) + result + } catch { + case e: Exception => + sqlContext.listenerManager.onFailure(name, df.queryExecution, e) + throw e + } + } - /* ******************** * - * Internal Functions * - * ******************** */ + private def withTypedCallback[A, B](name: String, ds: Dataset[A])(action: Dataset[A] => B) = { + try { + ds.queryExecution.executedPlan.foreach { plan => + plan.resetMetrics() + } + val start = System.nanoTime() + val result = action(ds) + val end = System.nanoTime() + sqlContext.listenerManager.onSuccess(name, ds.queryExecution, end - start) + result + } catch { + case e: Exception => + sqlContext.listenerManager.onFailure(name, ds.queryExecution, e) + throw e + } + } - private[sql] def logicalPlan = queryExecution.analyzed + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + withTypedPlan { + Sort(sortOrder, global = global, logicalPlan) + } + } - private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = - new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder) + /** A convenient function to wrap a logical plan and produce a DataFrame. */ + @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { + Dataset.ofRows(sqlContext, logicalPlan) + } - private[sql] def withPlan[R : Encoder]( - other: Dataset[_])( - f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = - new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan)) + /** A convenient function to wrap a logical plan and produce a Dataset. */ + @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + Dataset(sqlContext, logicalPlan) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 45f0098b92887..47b81c17a31dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql /** - * A container for a [[Dataset]], used for implicit conversions. + * A container for a [[Dataset]], used for implicit conversions in Scala. * * To use this, import implicit conversions in SQL: * {{{ @@ -27,9 +27,15 @@ package org.apache.spark.sql * * @since 1.6.0 */ -case class DatasetHolder[T] private[sql](private val df: Dataset[T]) { +case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { + + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. + def toDS(): Dataset[T] = ds // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDS(): Dataset[T] = df + def toDF(): DataFrame = ds.toDF() + + def toDF(colNames: String*): DataFrame = ds.toDF(colNames : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index 717709e4f9312..c5df028485373 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule /** * :: Experimental :: @@ -31,15 +33,18 @@ import org.apache.spark.annotation.Experimental * @since 1.3.0 */ @Experimental -class ExperimentalMethods protected[sql](sqlContext: SQLContext) { +class ExperimentalMethods private[sql]() { /** * Allows extra strategies to be injected into the query planner at runtime. Note this API - * should be consider experimental and is not intended to be stable across releases. + * should be considered experimental and is not intended to be stable across releases. * * @since 1.3.0 */ @Experimental var extraStrategies: Seq[Strategy] = Nil + @Experimental + var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala deleted file mode 100644 index 7cf66b65c8722..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} -import org.apache.spark.sql.types.NumericType - - -/** - * :: Experimental :: - * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. - * - * The main method is the agg function, which has multiple variants. This class also contains - * convenience some first order statistics such as mean, sum for convenience. - * - * @since 1.3.0 - */ -@Experimental -class GroupedData protected[sql]( - df: DataFrame, - groupingExprs: Seq[Expression], - private val groupType: GroupedData.GroupType) { - - private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { - val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - groupingExprs ++ aggExprs - } else { - aggExprs - } - - val aliasedAgg = aggregates.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. - case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - groupType match { - case GroupedData.GroupByType => - DataFrame( - df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) - case GroupedData.RollupType => - DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) - case GroupedData.CubeType => - DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) - } - } - - private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) - : DataFrame = { - - val columnExprs = if (colNames.isEmpty) { - // No columns specified. Use all numeric columns. - df.numericColumns - } else { - // Make sure all specified columns are numeric. - colNames.map { colName => - val namedExpr = df.resolve(colName) - if (!namedExpr.dataType.isInstanceOf[NumericType]) { - throw new AnalysisException( - s""""$colName" is not a numeric column. """ + - "Aggregation function can only be applied on a numeric column.") - } - namedExpr - } - } - toDF(columnExprs.map(f)) - } - - private[this] def strToExpr(expr: String): (Expression => Expression) = { - expr.toLowerCase match { - case "avg" | "average" | "mean" => Average - case "max" => Max - case "min" => Min - case "stddev" | "std" => StddevSamp - case "stddev_pop" => StddevPop - case "stddev_samp" => StddevSamp - case "variance" => VarianceSamp - case "var_pop" => VariancePop - case "var_samp" => VarianceSamp - case "sum" => Sum - case "skewness" => Skewness - case "kurtosis" => Kurtosis - case "count" | "size" => - // Turn count(*) into count(1) - (inputExpr: Expression) => inputExpr match { - case s: Star => Count(Literal(1)) - case _ => Count(inputExpr) - } - } - } - - /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} - * - * @since 1.3.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - agg((aggExpr +: aggExprs).toMap) - } - - /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg(Map( - * "age" -> "max", - * "expense" -> "sum" - * )) - * }}} - * - * @since 1.3.0 - */ - def agg(exprs: Map[String, String]): DataFrame = { - toDF(exprs.map { case (colName, expr) => - strToExpr(expr)(df(colName).expr) - }.toSeq) - } - - /** - * (Java-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); - * }}} - * - * @since 1.3.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.asScala.toMap) - } - - /** - * Compute aggregates by specifying a series of aggregate columns. Note that this function by - * default retains the grouping columns in its output. To not retain grouping columns, set - * `spark.sql.retainGroupColumns` to false. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change - * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. - * {{{ - * // Scala, 1.3.x: - * df.groupBy("department").agg($"department", max("age"), sum("expense")) - * - * // Java, 1.3.x: - * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); - * }}} - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr)) - } - - /** - * Count the number of rows for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * - * @since 1.3.0 - */ - def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")())) - - /** - * Compute the average value for each numeric columns for each group. This is an alias for `avg`. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the average values for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average) - } - - /** - * Compute the max value for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the max values for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Max) - } - - /** - * Compute the mean value for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the mean values for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average) - } - - /** - * Compute the min value for each numeric column for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the min values for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Min) - } - - /** - * Compute the sum for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 1.3.0 - */ - @scala.annotation.varargs - def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Sum) - } -} - - -/** - * Companion object for GroupedData. - */ -private[sql] object GroupedData { - - def apply( - df: DataFrame, - groupingExprs: Seq[Expression], - groupType: GroupType): GroupedData = { - new GroupedData(df, groupingExprs, groupType: GroupType) - } - - /** - * The Grouping Type - */ - private[sql] trait GroupType - - /** - * To indicate it's the GroupBy - */ - private[sql] object GroupByType extends GroupType - - /** - * To indicate it's the CUBE - */ - private[sql] object CubeType extends GroupType - - /** - * To indicate it's the ROLLUP - */ - private[sql] object RollupType extends GroupType -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala deleted file mode 100644 index 96d6e9dd548e5..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.QueryExecution - -/** - * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not - * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing - * [[Dataset]]. - */ -class GroupedDataset[K, T] private[sql]( - private val kEncoder: Encoder[K], - private val tEncoder: Encoder[T], - queryExecution: QueryExecution, - private val dataAttributes: Seq[Attribute], - private val groupingAttributes: Seq[Attribute]) extends Serializable { - - private implicit val kEnc = kEncoder match { - case e: ExpressionEncoder[K] => e.resolve(groupingAttributes) - case other => - throw new UnsupportedOperationException("Only expression encoders are currently supported") - } - - private implicit val tEnc = tEncoder match { - case e: ExpressionEncoder[T] => e.resolve(dataAttributes) - case other => - throw new UnsupportedOperationException("Only expression encoders are currently supported") - } - - private def logicalPlan = queryExecution.analyzed - private def sqlContext = queryExecution.sqlContext - - /** - * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified - * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. - */ - def asKey[L : Encoder]: GroupedDataset[L, T] = - new GroupedDataset( - encoderFor[L], - tEncoder, - queryExecution, - dataAttributes, - groupingAttributes) - - /** - * Returns a [[Dataset]] that contains each unique key. - */ - def keys: Dataset[K] = { - new Dataset[K]( - sqlContext, - Distinct( - Project(groupingAttributes, logicalPlan))) - } - - /** - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and an iterator that contains all of the elements in the group. The - * function can return an iterator containing elements of an arbitrary type which will be returned - * as a new [[Dataset]]. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - */ - def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = { - new Dataset[U]( - sqlContext, - MapGroups(f, groupingAttributes, logicalPlan)) - } - - /** - * Applies the given function to each cogrouped data. For each unique group, the function will - * be passed the grouping key and 2 iterators containing all elements in the group from - * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an - * arbitrary type which will be returned as a new [[Dataset]]. - */ - def cogroup[U, R : Encoder]( - other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.tEncoder - new Dataset[R]( - sqlContext, - CoGroup( - f, - this.groupingAttributes, - other.groupingAttributes, - this.logicalPlan, - other.logicalPlan)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala new file mode 100644 index 0000000000000..05e13e66d137c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.function._ +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution + +/** + * :: Experimental :: + * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not + * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupBy` on an existing + * [[Dataset]]. + * + * @since 2.0.0 + */ +@Experimental +class KeyValueGroupedDataset[K, V] private[sql]( + kEncoder: Encoder[K], + vEncoder: Encoder[V], + val queryExecution: QueryExecution, + private val dataAttributes: Seq[Attribute], + private val groupingAttributes: Seq[Attribute]) extends Serializable { + + // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders + // when constructing new logical plans that will operate on the output of the current + // queryexecution. + + private implicit val unresolvedKEncoder = encoderFor(kEncoder) + private implicit val unresolvedVEncoder = encoderFor(vEncoder) + + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedVEncoder = + unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + + private def logicalPlan = queryExecution.analyzed + private def sqlContext = queryExecution.sqlContext + + /** + * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the + * specified type. The mapping of key columns to the type follows the same rules as `as` on + * [[Dataset]]. + * + * @since 1.6.0 + */ + def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = + new KeyValueGroupedDataset( + encoderFor[L], + unresolvedVEncoder, + queryExecution, + dataAttributes, + groupingAttributes) + + /** + * Returns a [[Dataset]] that contains each unique key. + * + * @since 1.6.0 + */ + def keys: Dataset[K] = { + Dataset[K]( + sqlContext, + Distinct( + Project(groupingAttributes, logicalPlan))) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + Dataset[U]( + sqlContext, + MapGroups( + f, + groupingAttributes, + dataAttributes, + logicalPlan)) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) + flatMapGroups(func) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava))(encoder) + } + + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { + val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) + + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) + flatMapGroups(func) + } + + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { + reduceGroups(f.call _) + } + + /** + * Internal helper function for building typed aggregations that return tuples. For simplicity + * and code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + * TODO: does not handle aggregations that return nonflat results, + */ + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoders = columns.map(_.encoder) + val namedColumns = + columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named) + val keyColumn = if (resolvedKEncoder.flat) { + assert(groupingAttributes.length == 1) + groupingAttributes.head + } else { + Alias(CreateStruct(groupingAttributes), "key")() + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) + val execution = new QueryExecution(sqlContext, aggregate) + + new Dataset( + sqlContext, + execution, + ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) + } + + /** + * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key + * and the result of computing this aggregation over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2, U3]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 + */ + def agg[U1, U2, U3, U4]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] + + /** + * Returns a [[Dataset]] that contains a tuple with each key and the number of items present + * for that key. + * + * @since 1.6.0 + */ + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]())) + + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 + */ + def cogroup[U, R : Encoder]( + other: KeyValueGroupedDataset[K, U])( + f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + implicit val uEncoder = other.unresolvedVEncoder + Dataset[R]( + sqlContext, + CoGroup( + f, + this.groupingAttributes, + other.groupingAttributes, + this.dataAttributes, + other.dataAttributes, + this.logicalPlan, + other.logicalPlan)) + } + + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 + */ + def cogroup[U, R]( + other: KeyValueGroupedDataset[K, U], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = { + cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala new file mode 100644 index 0000000000000..7dbf2e6c7c798 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions + +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} +import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.NumericType + +/** + * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]]. + * + * The main method is the agg function, which has multiple variants. This class also contains + * convenience some first order statistics such as mean, sum for convenience. + * + * @since 2.0.0 + */ +class RelationalGroupedDataset protected[sql]( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: RelationalGroupedDataset.GroupType) { + + private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { + val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + groupingExprs ++ aggExprs + } else { + aggExprs + } + + val aliasedAgg = aggregates.map(alias) + + groupType match { + case RelationalGroupedDataset.GroupByType => + Dataset.ofRows( + df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.RollupType => + Dataset.ofRows( + df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.CubeType => + Dataset.ofRows( + df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + Dataset.ofRows( + df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + } + } + + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + private[this] def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() + } + + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) + : DataFrame = { + + val columnExprs = if (colNames.isEmpty) { + // No columns specified. Use all numeric columns. + df.numericColumns + } else { + // Make sure all specified columns are numeric. + colNames.map { colName => + val namedExpr = df.resolve(colName) + if (!namedExpr.dataType.isInstanceOf[NumericType]) { + throw new AnalysisException( + s""""$colName" is not a numeric column. """ + + "Aggregation function can only be applied on a numeric column.") + } + namedExpr + } + } + toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) + } + + private[this] def strToExpr(expr: String): (Expression => Expression) = { + val exprToFunc: (Expression => Expression) = { + (inputExpr: Expression) => expr.toLowerCase match { + // We special handle a few cases that have alias that are not in function registry. + case "avg" | "average" | "mean" => + UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) + case "stddev" | "std" => + UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false) + // Also special handle count because we need to take care count(*). + case "count" | "size" => + // Turn count(*) into count(1) + inputExpr match { + case s: Star => Count(Literal(1)).toAggregateExpression() + case _ => Count(inputExpr).toAggregateExpression() + } + case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false) + } + } + (inputExpr: Expression) => exprToFunc(inputExpr) + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg( + * "age" -> "max", + * "expense" -> "sum" + * ) + * }}} + * + * @since 1.3.0 + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + agg((aggExpr +: aggExprs).toMap) + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max", + * "expense" -> "sum" + * )) + * }}} + * + * @since 1.3.0 + */ + def agg(exprs: Map[String, String]): DataFrame = { + toDF(exprs.map { case (colName, expr) => + strToExpr(expr)(df(colName).expr) + }.toSeq) + } + + /** + * (Java-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import com.google.common.collect.ImmutableMap; + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); + * }}} + * + * @since 1.3.0 + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.asScala.toMap) + } + + /** + * Compute aggregates by specifying a series of aggregate columns. Note that this function by + * default retains the grouping columns in its output. To not retain grouping columns, set + * `spark.sql.retainGroupColumns` to false. + * + * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. + * + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * + * // Scala: + * import org.apache.spark.sql.functions._ + * df.groupBy("department").agg(max("age"), sum("expense")) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.groupBy("department").agg(max("age"), sum("expense")); + * }}} + * + * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change + * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. + * {{{ + * // Scala, 1.3.x: + * df.groupBy("department").agg($"department", max("age"), sum("expense")) + * + * // Java, 1.3.x: + * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); + * }}} + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame = { + toDF((expr +: exprs).map(_.expr)) + } + + /** + * Count the number of rows for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * + * @since 1.3.0 + */ + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) + + /** + * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def mean(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Average) + } + + /** + * Compute the max value for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the max values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def max(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Max) + } + + /** + * Compute the mean value for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the mean values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def avg(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Average) + } + + /** + * Compute the min value for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the min values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def min(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Min) + } + + /** + * Compute the sum for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def sum(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Sum) + } + + /** + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @since 1.6.0 + */ + def pivot(pivotColumn: String): RelationalGroupedDataset = { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so its consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) // ensure that the output columns are in a consistent logical order + .rdd + .map(_.get(0)) + .take(maxValues + 1) + .toSeq + + if (values.length > maxValues) { + throw new AnalysisException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " + + "to at least the number of distinct values of the pivot column.") + } + + pivot(pivotColumn, values) + } + + /** + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { + groupType match { + case RelationalGroupedDataset.GroupByType => + new RelationalGroupedDataset( + df, + groupingExprs, + RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: RelationalGroupedDataset.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + } + + /** + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { + pivot(pivotColumn, values.asScala) + } +} + + +/** + * Companion object for GroupedData. + */ +private[sql] object RelationalGroupedDataset { + + def apply( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: GroupType): RelationalGroupedDataset = { + new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType) + } + + /** + * The Grouping Type + */ + private[sql] trait GroupType + + /** + * To indicate it's the GroupBy + */ + private[sql] object GroupByType extends GroupType + + /** + * To indicate it's the CUBE + */ + private[sql] object CubeType extends GroupType + + /** + * To indicate it's the ROLLUP + */ + private[sql] object RollupType extends GroupType + + /** + * To indicate it's the PIVOT + */ + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala new file mode 100644 index 0000000000000..e90a04243164b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. + * + * @since 2.0.0 + */ +abstract class RuntimeConfig { + + /** + * Sets the given Spark runtime configuration property. + * + * @since 2.0.0 + */ + def set(key: String, value: String): RuntimeConfig + + /** + * Sets the given Spark runtime configuration property. + * + * @since 2.0.0 + */ + def set(key: String, value: Boolean): RuntimeConfig + + /** + * Sets the given Spark runtime configuration property. + * + * @since 2.0.0 + */ + def set(key: String, value: Long): RuntimeConfig + + /** + * Returns the value of Spark runtime configuration property for the given key. + * + * @throws NoSuchElementException if the key is not set and does not have a default value + * @since 2.0.0 + */ + @throws[NoSuchElementException]("if the key is not set") + def get(key: String): String + + /** + * Returns the value of Spark runtime configuration property for the given key. + * + * @since 2.0.0 + */ + def getOption(key: String): Option[String] + + /** + * Resets the configuration property for the given key. + * + * @since 2.0.0 + */ + def unset(key: String): Unit + + /** + * Sets the given Hadoop configuration property. This is passed directly to Hadoop during I/O. + * + * @since 2.0.0 + */ + def setHadoop(key: String, value: String): RuntimeConfig + + /** + * Returns the value of the Hadoop configuration property. + * + * @throws NoSuchElementException if the key is not set + * @since 2.0.0 + */ + @throws[NoSuchElementException]("if the key is not set") + def getHadoop(key: String): String + + /** + * Returns the value of the Hadoop configuration property. + * + * @since 2.0.0 + */ + def getHadoopOption(key: String): Option[String] + + /** + * Resets the Hadoop configuration property for the given key. + * + * @since 2.0.0 + */ + def unsetHadoop(key: String): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala deleted file mode 100644 index ed8b634ad5630..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ /dev/null @@ -1,677 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.util.Properties - -import scala.collection.immutable -import scala.collection.JavaConverters._ - -import org.apache.parquet.hadoop.ParquetOutputCommitter - -import org.apache.spark.sql.catalyst.CatalystConf - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// This file defines the configuration options for Spark SQL. -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -private[spark] object SQLConf { - - private val sqlConfEntries = java.util.Collections.synchronizedMap( - new java.util.HashMap[String, SQLConfEntry[_]]()) - - /** - * An entry contains all meta information for a configuration. - * - * @param key the key for the configuration - * @param defaultValue the default value for the configuration - * @param valueConverter how to convert a string to the value. It should throw an exception if the - * string does not have the required format. - * @param stringConverter how to convert a value to a string that the user can use it as a valid - * string value. It's usually `toString`. But sometimes, a custom converter - * is necessary. E.g., if T is List[String], `a, b, c` is better than - * `List(a, b, c)`. - * @param doc the document for the configuration - * @param isPublic if this configuration is public to the user. If it's `false`, this - * configuration is only used internally and we should not expose it to the user. - * @tparam T the value type - */ - private[sql] class SQLConfEntry[T] private( - val key: String, - val defaultValue: Option[T], - val valueConverter: String => T, - val stringConverter: T => String, - val doc: String, - val isPublic: Boolean) { - - def defaultValueString: String = defaultValue.map(stringConverter).getOrElse("") - - override def toString: String = { - s"SQLConfEntry(key = $key, defaultValue=$defaultValueString, doc=$doc, isPublic = $isPublic)" - } - } - - private[sql] object SQLConfEntry { - - private def apply[T]( - key: String, - defaultValue: Option[T], - valueConverter: String => T, - stringConverter: T => String, - doc: String, - isPublic: Boolean): SQLConfEntry[T] = - sqlConfEntries.synchronized { - if (sqlConfEntries.containsKey(key)) { - throw new IllegalArgumentException(s"Duplicate SQLConfEntry. $key has been registered") - } - val entry = - new SQLConfEntry[T](key, defaultValue, valueConverter, stringConverter, doc, isPublic) - sqlConfEntries.put(key, entry) - entry - } - - def intConf( - key: String, - defaultValue: Option[Int] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Int] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toInt - } catch { - case _: NumberFormatException => - throw new IllegalArgumentException(s"$key should be int, but was $v") - } - }, _.toString, doc, isPublic) - - def longConf( - key: String, - defaultValue: Option[Long] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Long] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toLong - } catch { - case _: NumberFormatException => - throw new IllegalArgumentException(s"$key should be long, but was $v") - } - }, _.toString, doc, isPublic) - - def doubleConf( - key: String, - defaultValue: Option[Double] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Double] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toDouble - } catch { - case _: NumberFormatException => - throw new IllegalArgumentException(s"$key should be double, but was $v") - } - }, _.toString, doc, isPublic) - - def booleanConf( - key: String, - defaultValue: Option[Boolean] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Boolean] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toBoolean - } catch { - case _: IllegalArgumentException => - throw new IllegalArgumentException(s"$key should be boolean, but was $v") - } - }, _.toString, doc, isPublic) - - def stringConf( - key: String, - defaultValue: Option[String] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[String] = - SQLConfEntry(key, defaultValue, v => v, v => v, doc, isPublic) - - def enumConf[T]( - key: String, - valueConverter: String => T, - validValues: Set[T], - defaultValue: Option[T] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[T] = - SQLConfEntry(key, defaultValue, v => { - val _v = valueConverter(v) - if (!validValues.contains(_v)) { - throw new IllegalArgumentException( - s"The value of $key should be one of ${validValues.mkString(", ")}, but was $v") - } - _v - }, _.toString, doc, isPublic) - - def seqConf[T]( - key: String, - valueConverter: String => T, - defaultValue: Option[Seq[T]] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Seq[T]] = { - SQLConfEntry( - key, defaultValue, _.split(",").map(valueConverter), _.mkString(","), doc, isPublic) - } - - def stringSeqConf( - key: String, - defaultValue: Option[Seq[String]] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Seq[String]] = { - seqConf(key, s => s, defaultValue, doc, isPublic) - } - } - - import SQLConfEntry._ - - val ALLOW_MULTIPLE_CONTEXTS = booleanConf("spark.sql.allowMultipleContexts", - defaultValue = Some(true), - doc = "When set to true, creating multiple SQLContexts/HiveContexts is allowed." + - "When set to false, only one SQLContext/HiveContext is allowed to be created " + - "through the constructor (new SQLContexts/HiveContexts created through newSession " + - "method is allowed). Please note that this conf needs to be set in Spark Conf. Once" + - "a SQLContext/HiveContext has been created, changing the value of this conf will not" + - "have effect.", - isPublic = true) - - val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", - defaultValue = Some(true), - doc = "When set to true Spark SQL will automatically select a compression codec for each " + - "column based on statistics of the data.", - isPublic = false) - - val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize", - defaultValue = Some(10000), - doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " + - "memory utilization and compression, but risk OOMs when caching data.", - isPublic = false) - - val IN_MEMORY_PARTITION_PRUNING = - booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", - defaultValue = Some(true), - doc = "When true, enable partition pruning for in-memory columnar tables.", - isPublic = false) - - val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold", - defaultValue = Some(10 * 1024 * 1024), - doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " + - "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + - "Note that currently statistics are only supported for Hive Metastore tables where the " + - "commandANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") - - val DEFAULT_SIZE_IN_BYTES = longConf( - "spark.sql.defaultSizeInBytes", - doc = "The default table size used in query planning. By default, it is set to a larger " + - "value than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. That is to say " + - "by default the optimizer will not choose to broadcast a table unless it knows for sure its" + - "size is small enough.", - isPublic = false) - - val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions", - defaultValue = Some(200), - doc = "The default number of partitions to use when shuffling data for joins or aggregations.") - - val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = - longConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", - defaultValue = Some(64 * 1024 * 1024), - doc = "The target post-shuffle input size in bytes of a task.") - - val ADAPTIVE_EXECUTION_ENABLED = booleanConf("spark.sql.adaptive.enabled", - defaultValue = Some(false), - doc = "When true, enable adaptive query execution.") - - val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = - intConf("spark.sql.adaptive.minNumPostShufflePartitions", - defaultValue = Some(-1), - doc = "The advisory minimal number of post-shuffle partitions provided to " + - "ExchangeCoordinator. This setting is used in our test to make sure we " + - "have enough parallelism to expose issues that will not be exposed with a " + - "single partition. When the value is a non-positive value, this setting will" + - "not be provided to ExchangeCoordinator.", - isPublic = false) - - val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", - defaultValue = Some(true), - doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + - "manages memory and dynamically generates bytecode for expression evaluation.") - - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default - doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.", - isPublic = false) - - val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default - doc = "When true, use the new optimized Tungsten physical execution backend.", - isPublic = false) - - val DIALECT = stringConf( - "spark.sql.dialect", - defaultValue = Some("sql"), - doc = "The default SQL dialect to use.") - - val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive", - defaultValue = Some(true), - doc = "Whether the query analyzer should be case sensitive or not.") - - val PARQUET_SCHEMA_MERGING_ENABLED = booleanConf("spark.sql.parquet.mergeSchema", - defaultValue = Some(false), - doc = "When true, the Parquet data source merges schemas collected from all data files, " + - "otherwise the schema is picked from the summary file or a random data file " + - "if no summary file is available.") - - val PARQUET_SCHEMA_RESPECT_SUMMARIES = booleanConf("spark.sql.parquet.respectSummaryFiles", - defaultValue = Some(false), - doc = "When true, we make assumption that all part-files of Parquet are consistent with " + - "summary files and we will ignore them when merging schema. Otherwise, if this is " + - "false, which is the default, we will merge all part-files. This should be considered " + - "as expert-only option, and shouldn't be enabled before knowing what it means exactly.") - - val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", - defaultValue = Some(false), - doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + - "Spark SQL, do not differentiate between binary data and strings when writing out the " + - "Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " + - "compatibility with these systems.") - - val PARQUET_INT96_AS_TIMESTAMP = booleanConf("spark.sql.parquet.int96AsTimestamp", - defaultValue = Some(true), - doc = "Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " + - "Spark would also store Timestamp as INT96 because we need to avoid precision lost of the " + - "nanoseconds field. This flag tells Spark SQL to interpret INT96 data as a timestamp to " + - "provide compatibility with these systems.") - - val PARQUET_CACHE_METADATA = booleanConf("spark.sql.parquet.cacheMetadata", - defaultValue = Some(true), - doc = "Turns on caching of Parquet schema metadata. Can speed up querying of static data.") - - val PARQUET_COMPRESSION = enumConf("spark.sql.parquet.compression.codec", - valueConverter = v => v.toLowerCase, - validValues = Set("uncompressed", "snappy", "gzip", "lzo"), - defaultValue = Some("gzip"), - doc = "Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") - - val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", - defaultValue = Some(true), - doc = "Enables Parquet filter push-down optimization when set to true.") - - val PARQUET_WRITE_LEGACY_FORMAT = booleanConf( - key = "spark.sql.parquet.writeLegacyFormat", - defaultValue = Some(false), - doc = "Whether to follow Parquet's format specification when converting Parquet schema to " + - "Spark SQL schema and vice versa.") - - val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( - key = "spark.sql.parquet.output.committer.class", - defaultValue = Some(classOf[ParquetOutputCommitter].getName), - doc = "The output committer class used by Parquet. The specified class needs to be a " + - "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + - "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + - "option must be set in Hadoop Configuration. 2. This option overrides " + - "\"spark.sql.sources.outputCommitterClass\".") - - val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", - defaultValue = Some(false), - doc = "When true, enable filter pushdown for ORC files.") - - val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", - defaultValue = Some(false), - doc = "") - - val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", - defaultValue = Some(false), - doc = "When true, some predicates will be pushed down into the Hive metastore so that " + - "unmatching partitions can be eliminated earlier.") - - val NATIVE_VIEW = booleanConf("spark.sql.nativeView", - defaultValue = Some(false), - doc = "When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " + - "Note that this function is experimental and should ony be used when you are using " + - "non-hive-compatible tables written by Spark SQL. The SQL string used to create " + - "view should be fully qualified, i.e. use `tbl1`.`col1` instead of `*` whenever " + - "possible, or you may get wrong result.", - isPublic = false) - - val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", - defaultValue = Some("_corrupt_record"), - doc = "") - - val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", - defaultValue = Some(5 * 60), - doc = "Timeout in seconds for the broadcast wait time in broadcast joins.") - - // Options that control which operators can be chosen by the query planner. These should be - // considered hints and may be ignored by future versions of Spark SQL. - val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", - defaultValue = Some(true), - doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") - - // This is only used for the thriftserver - val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", - doc = "Set a Fair Scheduler pool for a JDBC client session") - - val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements", - defaultValue = Some(200), - doc = "The number of SQL statements kept in the JDBC/ODBC web UI history.") - - val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions", - defaultValue = Some(200), - doc = "The number of SQL client sessions kept in the JDBC/ODBC web UI history.") - - // This is used to set the default data source - val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default", - defaultValue = Some("org.apache.spark.sql.parquet"), - doc = "The default data source to use in input/output.") - - // This is used to control the when we will split a schema's JSON string to multiple pieces - // in order to fit the JSON string in metastore's table property (by default, the value has - // a length restriction of 4000 characters). We will split the JSON string of a schema - // to its length exceeds the threshold. - val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold", - defaultValue = Some(4000), - doc = "The maximum length allowed in a single cell when " + - "storing additional schema information in Hive's metastore.", - isPublic = false) - - val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", - defaultValue = Some(true), - doc = "When true, automatically discover data partitions.") - - val PARTITION_COLUMN_TYPE_INFERENCE = - booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", - defaultValue = Some(true), - doc = "When true, automatically infer the data types for partitioned columns.") - - val PARTITION_MAX_FILES = - intConf("spark.sql.sources.maxConcurrentWrites", - defaultValue = Some(5), - doc = "The maximum number of concurrent files to open before falling back on sorting when " + - "writing out files using dynamic partitioning.") - - // The output committer class used by HadoopFsRelation. The specified class needs to be a - // subclass of org.apache.hadoop.mapreduce.OutputCommitter. - // - // NOTE: - // - // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. - // 2. This option can be overriden by "spark.sql.parquet.output.committer.class". - val OUTPUT_COMMITTER_CLASS = - stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) - - val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = intConf( - key = "spark.sql.sources.parallelPartitionDiscovery.threshold", - defaultValue = Some(32), - doc = "") - - // Whether to perform eager analysis when constructing a dataframe. - // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = booleanConf( - "spark.sql.eagerAnalysis", - defaultValue = Some(true), - doc = "When true, eagerly applies query analysis on DataFrame operations.", - isPublic = false) - - // Whether to automatically resolve ambiguity in join conditions for self-joins. - // See SPARK-6231. - val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = booleanConf( - "spark.sql.selfJoinAutoResolveAmbiguity", - defaultValue = Some(true), - isPublic = false) - - // Whether to retain group by columns or not in GroupedData.agg. - val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf( - "spark.sql.retainGroupColumns", - defaultValue = Some(true), - isPublic = false) - - val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", - defaultValue = Some(true), doc = "") - - val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", - defaultValue = Some(true), - isPublic = false, - doc = "When true, we could use `datasource`.`path` as table in SQL query" - ) - - object Deprecated { - val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" - val EXTERNAL_SORT = "spark.sql.planner.externalSort" - } -} - -/** - * A class that enables the setting and getting of mutable config parameters/hints. - * - * In the presence of a SQLContext, these can be set and queried by passing SET commands - * into Spark SQL's query functions (i.e. sql()). Otherwise, users of this class can - * modify the hints by programmatically calling the setters and getters of this class. - * - * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). - */ -private[sql] class SQLConf extends Serializable with CatalystConf { - import SQLConf._ - - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ - @transient protected[spark] val settings = java.util.Collections.synchronizedMap( - new java.util.HashMap[String, String]()) - - /** ************************ Spark SQL Params/Hints ******************* */ - // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext? - - /** - * The SQL dialect that is used when parsing queries. This defaults to 'sql' which uses - * a simple SQL parser provided by Spark SQL. This is currently the only option for users of - * SQLContext. - * - * When using a HiveContext, this value defaults to 'hiveql', which uses the Hive 0.12.0 HiveQL - * parser. Users can change this to 'sql' if they want to run queries that aren't supported by - * HiveQL (e.g., SELECT 1). - * - * Note that the choice of dialect does not affect things like what tables are available or - * how query execution is performed. - */ - private[spark] def dialect: String = getConf(DIALECT) - - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) - - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) - - private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) - - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) - - private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) - - private[spark] def targetPostShuffleInputSize: Long = - getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) - - private[spark] def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) - - private[spark] def minNumPostShufflePartitions: Int = - getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) - - private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) - - private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) - - private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) - - private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - - private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW) - - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) - - def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) - - private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) - - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) - - private[spark] def defaultSizeInBytes: Long = - getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) - - private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) - - private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) - - private[spark] def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) - - private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) - - private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) - - private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) - - private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) - - private[spark] def partitionDiscoveryEnabled(): Boolean = - getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) - - private[spark] def partitionColumnTypeInferenceEnabled(): Boolean = - getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) - - private[spark] def parallelPartitionDiscoveryThreshold: Int = - getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) - - // Do not use a value larger than 4000 as the default value of this property. - // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. - private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - - private[spark] def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) - - private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = - getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) - - private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) - - private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) - - /** ********************** SQLConf functionality methods ************ */ - - /** Set Spark SQL configuration properties. */ - def setConf(props: Properties): Unit = settings.synchronized { - props.asScala.foreach { case (k, v) => setConfString(k, v) } - } - - /** Set the given Spark SQL configuration property using a `string` value. */ - def setConfString(key: String, value: String): Unit = { - require(key != null, "key cannot be null") - require(value != null, s"value cannot be null for key: $key") - val entry = sqlConfEntries.get(key) - if (entry != null) { - // Only verify configs in the SQLConf object - entry.valueConverter(value) - } - settings.put(key, value) - } - - /** Set the given Spark SQL configuration property. */ - def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { - require(entry != null, "entry cannot be null") - require(value != null, s"value cannot be null for key: ${entry.key}") - require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") - settings.put(entry.key, entry.stringConverter(value)) - } - - /** Return the value of Spark SQL configuration property for the given key. */ - def getConfString(key: String): String = { - Option(settings.get(key)). - orElse { - // Try to use the default value - Option(sqlConfEntries.get(key)).map(_.defaultValueString) - }. - getOrElse(throw new NoSuchElementException(key)) - } - - /** - * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the - * desired one. - */ - def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { - require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") - Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue) - } - - /** - * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue` in [[SQLConfEntry]]. - */ - def getConf[T](entry: SQLConfEntry[T]): T = { - require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") - Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue). - getOrElse(throw new NoSuchElementException(entry.key)) - } - - /** - * Return the `string` value of Spark SQL configuration property for the given key. If the key is - * not set yet, return `defaultValue`. - */ - def getConfString(key: String, defaultValue: String): String = { - val entry = sqlConfEntries.get(key) - if (entry != null && defaultValue != "") { - // Only verify configs in the SQLConf object - entry.valueConverter(defaultValue) - } - Option(settings.get(key)).getOrElse(defaultValue) - } - - /** - * Return all the configuration properties that have been set (i.e. not the default). - * This creates a new copy of the config properties in the form of a Map. - */ - def getAllConfs: immutable.Map[String, String] = - settings.synchronized { settings.asScala.toMap } - - /** - * Return all the configuration definitions that have been defined in [[SQLConf]]. Each - * definition contains key, defaultValue and doc. - */ - def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { - sqlConfEntries.values.asScala.filter(_.isPublic).map { entry => - (entry.key, entry.defaultValueString, entry.doc) - }.toSeq - } - - private[spark] def unsetConf(key: String): Unit = { - settings.remove(key) - } - - private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { - settings.remove(entry.key) - } - - private[spark] def clear(): Unit = { - settings.clear() - } -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5ad3871093fc8..9259ff40625c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,32 +21,29 @@ import java.beans.{BeanInfo, Introspector} import java.util.Properties import java.util.concurrent.atomic.AtomicReference - import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal -import org.apache.spark.{SparkException, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} -import org.apache.spark.sql.SQLConf.SQLConfEntry -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} -import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} +import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils @@ -61,26 +58,28 @@ import org.apache.spark.util.Utils * @groupname specificdata Specific Data Sources * @groupname config Configuration * @groupname dataframes Custom DataFrame Creation + * @groupname dataset Custom DataFrame Creation * @groupname Ungrouped Support functions for language integrated queries - * * @since 1.0.0 */ class SQLContext private[sql]( @transient val sparkContext: SparkContext, @transient protected[sql] val cacheManager: CacheManager, @transient private[sql] val listener: SQLListener, - val isRootContext: Boolean) - extends org.apache.spark.Logging with Serializable { + val isRootContext: Boolean, + @transient private[sql] val externalCatalog: ExternalCatalog) + extends Logging with Serializable { self => - def this(sparkContext: SparkContext) = { - this(sparkContext, new CacheManager, SQLContext.createListenerAndUI(sparkContext), true) + def this(sc: SparkContext) = { + this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), true, new InMemoryCatalog) } + def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user - // wants to create a new root SQLContext (a SLQContext that is not created by newSession). + // wants to create a new root SQLContext (a SQLContext that is not created by newSession). private val allowMultipleContexts = sparkContext.conf.getBoolean( SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, @@ -112,13 +111,23 @@ class SQLContext private[sql]( sparkContext = sparkContext, cacheManager = cacheManager, listener = listener, - isRootContext = false) + isRootContext = false, + externalCatalog = externalCatalog) } /** - * @return Spark SQL configuration + * Per-session state, e.g. configuration, functions, temporary tables etc. */ - protected[sql] lazy val conf = new SQLConf + @transient + protected[sql] lazy val sessionState: SessionState = new SessionState(self) + protected[spark] def conf: SQLConf = sessionState.conf + + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + */ + @Experimental + def listenerManager: ExecutionListenerManager = sessionState.listenerManager /** * Set Spark SQL configuration properties. @@ -129,7 +138,7 @@ class SQLContext private[sql]( def setConf(props: Properties): Unit = conf.setConf(props) /** Set the given Spark SQL configuration property. */ - private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = conf.setConf(entry, value) + private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = conf.setConf(entry, value) /** * Set the given Spark SQL configuration property. @@ -149,16 +158,16 @@ class SQLContext private[sql]( /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue` in [[SQLConfEntry]]. + * yet, return `defaultValue` in [[ConfigEntry]]. */ - private[sql] def getConf[T](entry: SQLConfEntry[T]): T = conf.getConf(entry) + private[sql] def getConf[T](entry: ConfigEntry[T]): T = conf.getConf(entry) /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the * desired one. */ - private[sql] def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + private[sql] def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = { conf.getConf(entry, defaultValue) } @@ -180,68 +189,17 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - @transient - lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager - - @transient - protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf) - - @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() - - @transient - protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, conf) { - override val extendedResolutionRules = - ExtractPythonUDFs :: - PreInsertCastAndRename :: - (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) - - override val extendedCheckRules = Seq( - datasources.PreWriteCheck(catalog) - ) - } - - @transient - protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer - - @transient - protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) - - @transient - protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect().parse(_)) - - protected[sql] def getSQLDialect(): ParserDialect = { - try { - val clazz = Utils.classForName(dialectClassName) - clazz.newInstance().asInstanceOf[ParserDialect] - } catch { - case NonFatal(e) => - // Since we didn't find the available SQL Dialect, it will fail even for SET command: - // SET spark.sql.dialect=sql; Let's reset as default dialect automatically. - val dialect = conf.dialect - // reset the sql dialect - conf.unsetConf(SQLConf.DIALECT) - // throw out the exception, and the default sql dialect will take effect for next query. - throw new DialectException( - s"""Instantiating dialect '$dialect' failed. - |Reverting to default dialect '${conf.dialect}'""".stripMargin, e) - } + // Extract `spark.sql.*` entries and put it in our SQLConf. + // Subclasses may additionally set these entries in other confs. + SQLContext.getSQLProperties(sparkContext.getConf).asScala.foreach { case (k, v) => + setConf(k, v) } - protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) + protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql) - protected[sql] def executeSql(sql: String): - org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = - new sparkexecution.QueryExecution(this, plan) - - protected[sql] def dialectClassName = if (conf.dialect == "sql") { - classOf[DefaultParserDialect].getCanonicalName - } else { - conf.dialect - } + protected[sql] def executePlan(plan: LogicalPlan) = new QueryExecution(this, plan) /** * Add a jar to SQLContext @@ -250,27 +208,19 @@ class SQLContext private[sql]( sparkContext.addJar(path) } - { - // We extract spark sql settings from SparkContext's conf and put them to - // Spark SQL's conf. - // First, we populate the SQLConf (conf). So, we can make sure that other values using - // those settings in their construction can get the correct settings. - // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version - // and spark.sql.hive.metastore.jars to get correctly constructed. - val properties = new Properties - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) - case _ => - } - // We directly put those settings to conf to avoid of calling setConf, which may have - // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive - // get constructed. If we call setConf directly, the constructed metadataHive may have - // wrong settings, or the construction may fail. - conf.setConf(properties) - // After we have populated SQLConf, we call setConf to populate other confs in the subclass - // (e.g. hiveconf in HiveContext). - properties.asScala.foreach { - case (key, value) => setConf(key, value) + /** A [[FunctionResourceLoader]] that can be used in SessionCatalog. */ + @transient protected[sql] lazy val functionResourceLoader: FunctionResourceLoader = { + new FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + resource.resourceType match { + case JarResource => addJar(resource.uri) + case FileResource => sparkContext.addFile(resource.uri) + case ArchiveResource => + throw new AnalysisException( + "Archive is not allowed to be loaded. If YARN mode is used, " + + "please use --archives options while calling spark-submit.") + } + } } } @@ -284,7 +234,7 @@ class SQLContext private[sql]( */ @Experimental @transient - val experimental: ExperimentalMethods = new ExperimentalMethods(this) + def experimental: ExperimentalMethods = sessionState.experimentalMethods /** * :: Experimental :: @@ -325,10 +275,8 @@ class SQLContext private[sql]( * * @group basic * @since 1.3.0 - * TODO move to SQLSession? */ - @transient - val udf: UDFRegistration = new UDFRegistration(this) + def udf: UDFRegistration = sessionState.udf /** * Returns true if the table is currently cached in-memory. @@ -339,6 +287,15 @@ class SQLContext private[sql]( cacheManager.lookupCachedData(table(tableName)).nonEmpty } + /** + * Returns true if the [[Dataset]] is currently cached in-memory. + * @group cachemgmt + * @since 1.3.0 + */ + private[sql] def isCached(qName: Dataset[_]): Boolean = { + cacheManager.lookupCachedData(qName).nonEmpty + } + /** * Caches the specified table in-memory. * @group cachemgmt @@ -379,17 +336,6 @@ class SQLContext private[sql]( @Experimental object implicits extends SQLImplicits with Serializable { protected override def _sqlContext: SQLContext = self - - /** - * Converts $"col name" into an [[Column]]. - * @since 1.3.0 - */ - // This must live here to preserve binary compatibility with Spark < 1.5. - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } } // scalastyle:on @@ -402,11 +348,11 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) - DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) + Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRDD)(self)) } /** @@ -418,10 +364,10 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes - DataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) + Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) } /** @@ -431,7 +377,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { - DataFrame(this, LogicalRelation(baseRelation)) + Dataset.ofRows(this, LogicalRelation(baseRelation)) } /** @@ -486,7 +432,7 @@ class SQLContext private[sql]( rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} } val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) - DataFrame(this, logicalPlan) + Dataset.ofRows(this, logicalPlan) } @@ -496,7 +442,7 @@ class SQLContext private[sql]( val encoded = data.map(d => enc.toRow(d).copy()) val plan = new LocalRelation(attributes, encoded) - new Dataset[T](this, plan) + Dataset[T](this, plan) } def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { @@ -505,7 +451,11 @@ class SQLContext private[sql]( val encoded = data.map(d => enc.toRow(d)) val plan = LogicalRDD(attributes, encoded)(self) - new Dataset[T](this, plan) + Dataset[T](this, plan) + } + + def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { + createDataset(data.asScala) } /** @@ -517,7 +467,7 @@ class SQLContext private[sql]( // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) - DataFrame(this, logicalPlan) + Dataset.ofRows(this, logicalPlan) } /** @@ -545,7 +495,7 @@ class SQLContext private[sql]( */ @DeveloperApi def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { - DataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } /** @@ -564,7 +514,7 @@ class SQLContext private[sql]( val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) } - DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) + Dataset.ofRows(this, LogicalRDD(attributeSeq, rowRdd)(this)) } /** @@ -592,13 +542,12 @@ class SQLContext private[sql]( val className = beanClass.getName val beanInfo = Introspector.getBeanInfo(beanClass) val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) - DataFrame(self, LocalRelation(attrSeq, rows.toSeq)) + Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) } - /** * :: Experimental :: - * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. + * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. * {{{ * sqlContext.read.parquet("/path/to/file.parquet") * sqlContext.read.schema(schema).json("/path/to/file.json") @@ -670,7 +619,7 @@ class SQLContext private[sql]( tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = SqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -716,7 +665,7 @@ class SQLContext private[sql]( source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = SqlParser.parseTableIdentifier(tableName) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -735,7 +684,10 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. */ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(TableIdentifier(tableName), df.logicalPlan) + sessionState.catalog.createTempTable( + sessionState.sqlParser.parseTableIdentifier(tableName).table, + df.logicalPlan, + overrideIfExists = true) } /** @@ -743,55 +695,63 @@ class SQLContext private[sql]( * cached/persisted before, it's also unpersisted. * * @param tableName the name of the table to be unregistered. - * * @group basic * @since 1.3.0 */ def dropTempTable(tableName: String): Unit = { cacheManager.tryUncacheQuery(table(tableName)) - catalog.unregisterTable(TableIdentifier(tableName)) + sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true) } /** * :: Experimental :: - * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements * in an range from 0 to `end` (exclusive) with step value 1. * - * @since 1.4.1 - * @group dataframe + * @since 2.0.0 + * @group dataset */ @Experimental - def range(end: Long): DataFrame = range(0, end) + def range(end: Long): Dataset[java.lang.Long] = range(0, end) /** * :: Experimental :: - * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements * in an range from `start` to `end` (exclusive) with step value 1. * - * @since 1.4.0 - * @group dataframe + * @since 2.0.0 + * @group dataset + */ + @Experimental + def range(start: Long, end: Long): Dataset[java.lang.Long] = { + range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) + } + + /** + * :: Experimental :: + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with an step value. + * + * @since 2.0.0 + * @group dataset */ @Experimental - def range(start: Long, end: Long): DataFrame = { - createDataFrame( - sparkContext.range(start, end).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) + def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { + range(start, end, step, numPartitions = sparkContext.defaultParallelism) } /** * :: Experimental :: - * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements * in an range from `start` to `end` (exclusive) with an step value, with partition number * specified. * - * @since 1.4.0 - * @group dataframe + * @since 2.0.0 + * @group dataset */ @Experimental - def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { - createDataFrame( - sparkContext.range(start, end, step, numPartitions).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) + def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { + new Dataset(this, Range(start, end, step, numPartitions), Encoders.LONG) } /** @@ -802,7 +762,16 @@ class SQLContext private[sql]( * @since 1.3.0 */ def sql(sqlText: String): DataFrame = { - DataFrame(this, parseSql(sqlText)) + Dataset.ofRows(this, parseSql(sqlText)) + } + + /** + * Executes a SQL query without parsing it, but instead passing it directly to an underlying + * system to process. This is currently only used for Hive DDLs and will be removed as soon + * as Spark can parse all supported Hive DDLs itself. + */ + private[sql] def runNativeSql(sqlText: String): Seq[Row] = { + throw new UnsupportedOperationException } /** @@ -812,11 +781,11 @@ class SQLContext private[sql]( * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(SqlParser.parseTableIdentifier(tableName)) + table(sessionState.sqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { - DataFrame(this, catalog.lookupRelation(tableIdent)) + Dataset.ofRows(this, sessionState.catalog.lookupRelation(tableIdent)) } /** @@ -828,7 +797,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(): DataFrame = { - DataFrame(this, ShowTablesCommand(None)) + Dataset.ofRows(this, ShowTablesCommand(None, None)) } /** @@ -840,9 +809,17 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(databaseName: String): DataFrame = { - DataFrame(this, ShowTablesCommand(Some(databaseName))) + Dataset.ofRows(this, ShowTablesCommand(Some(databaseName), None)) } + /** + * Returns a [[ContinuousQueryManager]] that allows managing all the + * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this` context. + * + * @since 2.0.0 + */ + def streams: ContinuousQueryManager = sessionState.continuousQueryManager + /** * Returns the names of tables in the current database as an array. * @@ -850,9 +827,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tableNames(): Array[String] = { - catalog.getTables(None).map { - case (tableName, _) => tableName - }.toArray + tableNames(sessionState.catalog.getCurrentDatabase) } /** @@ -862,36 +837,12 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tableNames(databaseName: String): Array[String] = { - catalog.getTables(Some(databaseName)).map { - case (tableName, _) => tableName - }.toArray + sessionState.catalog.listTables(databaseName).map(_.table).toArray } - @deprecated("use org.apache.spark.sql.SparkPlanner", "1.6.0") - protected[sql] class SparkPlanner extends sparkexecution.SparkPlanner(this) - - @transient - protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) - @transient protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) - /** - * Prepares a planned SparkPlan for execution by inserting shuffle operations and internal - * row format conversions as needed. - */ - @transient - protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { - val batches = Seq( - Batch("Add exchange", Once, EnsureRequirements(self)), - Batch("Add row converters", Once, EnsureRowFormats) - ) - } - - @deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0") - protected[sql] class QueryExecution(logical: LogicalPlan) - extends sparkexecution.QueryExecution(this, logical) - /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. @@ -918,8 +869,8 @@ class SQLContext private[sql]( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) - DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + Dataset.ofRows(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** @@ -932,301 +883,13 @@ class SQLContext private[sql]( } } - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // Deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. - */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. - */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. - */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. - */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty - * [[DataFrame]] if no paths are passed in. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. - */ - @deprecated("Use read.parquet()", "1.4.0") - @scala.annotation.varargs - def parquetFile(paths: String*): DataFrame = { - if (paths.isEmpty) { - emptyDataFrame - } else { - read.parquet(paths : _*) - } - } - - /** - * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonFile(path: String): DataFrame = { - read.json(path) - } - - /** - * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonFile(path: String, schema: StructType): DataFrame = { - read.schema(schema).json(path) - } - - /** - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonFile(path: String, samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(path) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: RDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an JavaRDD storing JSON objects (one object per record) and applies the given - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. - */ - @deprecated("Use read.json()", "1.4.0") - def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Returns the dataset stored at path as a DataFrame, - * using the default data source configured by spark.sql.sources.default. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. - */ - @deprecated("Use read.load(path)", "1.4.0") - def load(path: String): DataFrame = { - read.load(path) - } - - /** - * Returns the dataset stored at path as a DataFrame, using the given data source. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. - */ - @deprecated("Use read.format(source).load(path)", "1.4.0") - def load(path: String, source: String): DataFrame = { - read.format(source).load(path) - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. - */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") - def load(source: String, options: java.util.Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. - */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") - def load(source: String, options: Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by - * `read().format(source).schema(schema).options(options).load()`. - */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") - def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = - { - read.format(source).schema(schema).options(options).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by - * `read().format(source).schema(schema).options(options).load()`. - */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") - def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { - read.format(source).schema(schema).options(options).load() - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. - */ - @deprecated("use read.jdbc()", "1.4.0") - def jdbc(url: String, table: String): DataFrame = { - read.jdbc(url, table, new Properties) - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. - */ - @deprecated("use read.jdbc()", "1.4.0") - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int): DataFrame = { - read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties) - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. - */ - @deprecated("use read.jdbc()", "1.4.0") - def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { - read.jdbc(url, table, theParts, new Properties) - } - - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // End of deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - - // Register a succesfully instantiatd context to the singleton. This should be at the end of + // Register a successfully instantiated context to the singleton. This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the // construction of the instance. sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - SQLContext.clearInstantiatedContext(self) + SQLContext.clearInstantiatedContext() + SQLContext.clearSqlListener() } }) @@ -1254,6 +917,8 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() + @transient private val sqlListener = new AtomicReference[SQLListener]() + /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1267,13 +932,13 @@ object SQLContext { */ def getOrCreate(sparkContext: SparkContext): SQLContext = { val ctx = activeContext.get() - if (ctx != null) { + if (ctx != null && !ctx.sparkContext.isStopped) { return ctx } synchronized { val ctx = instantiatedContext.get() - if (ctx == null) { + if (ctx == null || ctx.sparkContext.isStopped) { new SQLContext(sparkContext) } else { ctx @@ -1281,18 +946,27 @@ object SQLContext { } } - private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(sqlContext, null) + private[sql] def clearInstantiatedContext(): Unit = { + instantiatedContext.set(null) } private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(null, sqlContext) + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null || ctx.sparkContext.isStopped) { + instantiatedContext.set(sqlContext) + } + } } private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { Option(instantiatedContext.get()) } + private[sql] def clearSqlListener(): Unit = { + sqlListener.set(null) + } + /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1314,7 +988,7 @@ object SQLContext { activeContext.remove() } - private[sql] def getActiveContextOption(): Option[SQLContext] = { + private[sql] def getActive(): Option[SQLContext] = { Option(activeContext.get()) } @@ -1341,9 +1015,27 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. */ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - val listener = new SQLListener(sc.conf) - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - listener + if (sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + sqlListener.get() + } + + /** + * Extract `spark.sql.*` properties from the conf and return them as a [[Properties]]. + */ + private[sql] def getSQLProperties(sparkConf: SparkConf): Properties = { + val properties = new Properties + sparkConf.getAll.foreach { case (key, value) => + if (key.startsWith("spark.sql")) { + properties.setProperty(key, value) + } + } + properties } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 6da46a5f7ef9a..ad69e23540a91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,37 +17,125 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.types.StructField -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder /** * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + * + * @since 1.6.0 */ abstract class SQLImplicits { + protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]() + /** + * Converts $"col name" into an [[Column]]. + * + * @since 2.0.0 + */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } + + /** @since 1.6.0 */ + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T] + + // Primitives + + /** @since 1.6.0 */ + implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt + + /** @since 1.6.0 */ + implicit def newLongEncoder: Encoder[Long] = Encoders.scalaLong + + /** @since 1.6.0 */ + implicit def newDoubleEncoder: Encoder[Double] = Encoders.scalaDouble + + /** @since 1.6.0 */ + implicit def newFloatEncoder: Encoder[Float] = Encoders.scalaFloat + + /** @since 1.6.0 */ + implicit def newByteEncoder: Encoder[Byte] = Encoders.scalaByte + + /** @since 1.6.0 */ + implicit def newShortEncoder: Encoder[Short] = Encoders.scalaShort + + /** @since 1.6.0 */ + implicit def newBooleanEncoder: Encoder[Boolean] = Encoders.scalaBoolean + + /** @since 1.6.0 */ + implicit def newStringEncoder: Encoder[String] = Encoders.STRING + + // Seqs + + /** @since 1.6.1 */ + implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() + + // Arrays + + /** @since 1.6.1 */ + implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder() - implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true) - implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) - implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true) - implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true) - implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true) - implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true) - implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true) - implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true) + /** @since 1.6.1 */ + implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder() + /** @since 1.6.1 */ + implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder() + + /** @since 1.6.1 */ + implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] = + ExpressionEncoder() + + /** + * Creates a [[Dataset]] from an RDD. + * + * @since 1.6.0 + */ implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(rdd)) } @@ -66,75 +154,4 @@ abstract class SQLImplicits { */ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - /** - * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { - DataFrameHolder(_sqlContext.createDataFrame(rdd)) - } - - /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = - { - DataFrameHolder(_sqlContext.createDataFrame(data)) - } - - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. - - /** - * Creates a single column DataFrame from an RDD[Int]. - * @since 1.3.0 - */ - implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { - val dataType = IntegerType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setInt(0, v) - row: InternalRow - } - } - DataFrameHolder( - _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[Long]. - * @since 1.3.0 - */ - implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { - val dataType = LongType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setLong(0, v) - row: InternalRow - } - } - DataFrameHolder( - _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[String]. - * @since 1.3.0 - */ - implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { - val dataType = StringType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.update(0, UTF8String.fromString(v)) - row: InternalRow - } - } - DataFrameHolder( - _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala new file mode 100644 index 0000000000000..5a9852809c0eb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, Sink} + +/** + * :: Experimental :: + * Status and metrics of a streaming [[Sink]]. + * + * @param description Description of the source corresponding to this status + * @param offset Current offset up to which data has been written by the sink + * @since 2.0.0 + */ +@Experimental +class SinkStatus private[sql]( + val description: String, + val offset: Offset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala new file mode 100644 index 0000000000000..2479e67e369ec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{Offset, Source} + +/** + * :: Experimental :: + * Status and metrics of a streaming [[Source]]. + * + * @param description Description of the source corresponding to this status + * @param offset Current offset of the source, if known + * @since 2.0.0 + */ +@Experimental +class SourceStatus private[sql] ( + val description: String, + val offset: Option[Offset]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala deleted file mode 100644 index ea8fce6ca9cf2..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.util.parsing.combinator.RegexParsers - -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.types.StringType - - -/** - * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL - * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. - * - * @param fallback A function that parses an input string to a logical plan - */ -private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { - - // A parser for the key-value part of the "SET [key = [value ]]" syntax - private object SetCommandParser extends RegexParsers { - private val key: Parser[String] = "(?m)[^=]+".r - - private val value: Parser[String] = "(?m).*$".r - - private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)()) - - private val pair: Parser[LogicalPlan] = - (key ~ ("=".r ~> value).?).? ^^ { - case None => SetCommand(None) - case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) - } - - def apply(input: String): LogicalPlan = parseAll(pair, input) match { - case Success(plan, _) => plan - case x => sys.error(x.toString) - } - } - - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val CLEAR = Keyword("CLEAR") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val FUNCTION = Keyword("FUNCTION") - protected val FUNCTIONS = Keyword("FUNCTIONS") - protected val IN = Keyword("IN") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SHOW = Keyword("SHOW") - protected val TABLE = Keyword("TABLE") - protected val TABLES = Keyword("TABLES") - protected val UNCACHE = Keyword("UNCACHE") - - override protected lazy val start: Parser[LogicalPlan] = - cache | uncache | set | show | desc | others - - private lazy val cache: Parser[LogicalPlan] = - CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined) - } - - private lazy val uncache: Parser[LogicalPlan] = - ( UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) - } - | CLEAR ~ CACHE ^^^ ClearCacheCommand - ) - - private lazy val set: Parser[LogicalPlan] = - SET ~> restInput ^^ { - case input => SetCommandParser(input) - } - - // It can be the following patterns: - // SHOW FUNCTIONS; - // SHOW FUNCTIONS mydb.func1; - // SHOW FUNCTIONS func1; - // SHOW FUNCTIONS `mydb.a`.`func1.aa`; - private lazy val show: Parser[LogicalPlan] = - ( SHOW ~> TABLES ~ (IN ~> ident).? ^^ { - case _ ~ dbName => ShowTablesCommand(dbName) - } - | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { - case Some(f) => ShowFunctions(f._1, Some(f._2)) - case None => ShowFunctions(None, None) - } - ) - - private lazy val desc: Parser[LogicalPlan] = - DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ { - case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined) - } - - private lazy val others: Parser[LogicalPlan] = - wholeInput ^^ { - case input => fallback(input) - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala new file mode 100644 index 0000000000000..c4e54b3f90ac5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.Duration + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.annotation.Experimental +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * :: Experimental :: + * Used to indicate how often results should be produced by a [[ContinuousQuery]]. + */ +@Experimental +sealed trait Trigger {} + +/** + * :: Experimental :: + * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0, + * the query will run as fast as possible. + * + * Scala Example: + * {{{ + * def.writer.trigger(ProcessingTime("10 seconds")) + * + * import scala.concurrent.duration._ + * def.writer.trigger(ProcessingTime(10.seconds)) + * }}} + * + * Java Example: + * {{{ + * def.writer.trigger(ProcessingTime.create("10 seconds")) + * + * import java.util.concurrent.TimeUnit + * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + */ +@Experimental +case class ProcessingTime(intervalMs: Long) extends Trigger { + require(intervalMs >= 0, "the interval of trigger should not be negative") +} + +/** + * :: Experimental :: + * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s. + */ +@Experimental +object ProcessingTime { + + /** + * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * + * Example: + * {{{ + * def.writer.trigger(ProcessingTime("10 seconds")) + * }}} + */ + def apply(interval: String): ProcessingTime = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "interval cannot be null or blank.") + } + val cal = if (interval.startsWith("interval")) { + CalendarInterval.fromString(interval) + } else { + CalendarInterval.fromString("interval " + interval) + } + if (cal == null) { + throw new IllegalArgumentException(s"Invalid interval: $interval") + } + if (cal.months > 0) { + throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") + } + new ProcessingTime(cal.microseconds / 1000) + } + + /** + * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * + * Example: + * {{{ + * import scala.concurrent.duration._ + * def.writer.trigger(ProcessingTime(10.seconds)) + * }}} + */ + def apply(interval: Duration): ProcessingTime = { + new ProcessingTime(interval.toMillis) + } + + /** + * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * + * Example: + * {{{ + * def.writer.trigger(ProcessingTime.create("10 seconds")) + * }}} + */ + def create(interval: String): ProcessingTime = { + apply(interval) + } + + /** + * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * + * Example: + * {{{ + * import java.util.concurrent.TimeUnit + * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + */ + def create(interval: Long, unit: TimeUnit): ProcessingTime = { + new ProcessingTime(unit.toMillis(interval)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f5b95e13e47bc..3a043dcc6af22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql -import java.util.{List => JList, Map => JMap} - import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction +import org.apache.spark.sql.execution.python.UserDefinedPythonFunction +import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.types.DataType /** @@ -35,19 +35,17 @@ import org.apache.spark.sql.types.DataType * * @since 1.3.0 */ -class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { - - private val functionRegistry = sqlContext.functionRegistry +class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { log.debug( s""" | Registering new PythonUDF: | name: $name - | command: ${udf.command.toSeq} - | envVars: ${udf.envVars} - | pythonIncludes: ${udf.pythonIncludes} - | pythonExec: ${udf.pythonExec} + | command: ${udf.func.command.toSeq} + | envVars: ${udf.func.envVars} + | pythonIncludes: ${udf.func.pythonIncludes} + | pythonExec: ${udf.func.pythonExec} | dataType: ${udf.dataType} """.stripMargin) @@ -58,10 +56,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined aggregate function (UDAF). * * @param name the name of the UDAF. - * @param udaf the UDAF that needs to be registered. + * @param udaf the UDAF needs to be registered. * @return the registered UDAF. - * - * @since 1.5.0 */ def register( name: String, @@ -71,23 +67,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { udaf } - /** - * Register a user-defined function (UDF). - * - * @param name the name of the UDF. - * @param udf the UDF that needs to be registered. - * @return the registered UDF. - * - * @since 1.6.0 - */ - def register( - name: String, - udf: UserDefinedFunction): UserDefinedFunction = { - functionRegistry.registerFunction(name, udf.builder) - udf - } - - // scalastyle:off + // scalastyle:off line.size.limit /* register 0-22 were generated by this script @@ -103,10 +83,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try($inputTypes).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try($inputTypes).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) }""") } @@ -120,7 +100,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { | * Register a user-defined function with ${i} arguments. | * @since 1.3.0 | */ - |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { + |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { | functionRegistry.registerFunction( | name, | (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e)) @@ -135,10 +115,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -148,10 +128,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -161,10 +141,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -174,10 +154,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -187,10 +167,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -200,10 +180,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -213,10 +193,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -226,10 +206,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -239,10 +219,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -252,10 +232,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -265,10 +245,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -278,10 +258,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -291,10 +271,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -304,10 +284,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -317,10 +297,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -330,10 +310,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -343,10 +323,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -356,10 +336,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -369,10 +349,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -382,10 +362,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -395,10 +375,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -408,10 +388,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -421,10 +401,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -434,7 +414,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 1 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF1[_, _], returnType: DataType) = { + def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) @@ -444,7 +424,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 2 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { + def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) @@ -454,7 +434,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 3 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) @@ -464,7 +444,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 4 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -474,7 +454,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 5 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -484,7 +464,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 6 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -494,7 +474,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 7 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -504,7 +484,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 8 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -514,7 +494,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 9 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -524,7 +504,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 10 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -534,7 +514,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 11 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -544,7 +524,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 12 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -554,7 +534,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 13 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -564,7 +544,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 14 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -574,7 +554,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 15 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -584,7 +564,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 16 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -594,7 +574,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 17 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -604,7 +584,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 18 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -614,7 +594,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 19 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -624,7 +604,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 20 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -634,7 +614,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 21 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -644,11 +624,12 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 22 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } - // scalastyle:on + // scalastyle:on line.size.limit + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala deleted file mode 100644 index 1319391db5375..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import java.util.{List => JList, Map => JMap} - -import org.apache.spark.Accumulator -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.execution.PythonUDF -import org.apache.spark.sql.types.DataType - -/** - * A user-defined function. To create one, use the `udf` functions in [[functions]]. - * As an example: - * {{{ - * // Defined a UDF that returns true or false based on some numeric score. - * val predict = udf((score: Double) => if (score > 0.5) true else false) - * - * // Projects a column that adds a prediction column based on the score column. - * df.select( predict(df("score")) ) - * }}} - * - * @since 1.3.0 - */ -@Experimental -case class UserDefinedFunction protected[sql] ( - f: AnyRef, - dataType: DataType, - inputTypes: Seq[DataType] = Nil, - deterministic: Boolean = true) { - - def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes, deterministic)) - } - - protected[sql] def builder: Seq[Expression] => ScalaUDF = { - (exprs: Seq[Expression]) => - ScalaUDF(f, dataType, exprs, inputTypes, deterministic) - } - - def nondeterministic: UserDefinedFunction = - UserDefinedFunction(f, dataType, inputTypes, deterministic = false) -} - -/** - * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]]. - * This is used by Python API. - */ -private[sql] case class UserDefinedPythonFunction( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType) { - - def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, - accumulator, dataType, e) - } - - /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ - def apply(exprs: Column*): Column = { - val udf = builder(exprs.map(_.expr)) - Column(udf) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index b3f134614c6bb..22ded7a4bf5b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -19,20 +19,20 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.util.matching.Regex + import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, GenericRowWithSchema} +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} - -import scala.util.matching.Regex private[r] object SQLUtils { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) def createSQLContext(jsc: JavaSparkContext): SQLContext = { - new SQLContext(jsc) + SQLContext.getOrCreate(jsc.sc) } def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { @@ -94,13 +94,13 @@ private[r] object SQLUtils { } def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { - val num = schema.fields.size + val num = schema.fields.length val rowRDD = rdd.map(bytesToRow(_, schema)) sqlContext.createDataFrame(rowRDD, schema) } def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { - df.map(r => rowToRBytes(r)) + df.rdd.map(r => rowToRBytes(r)) } private[this] def doConversion(data: Object, dataType: DataType): Object = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala deleted file mode 100644 index 42ec4d3433f16..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.columnar - -import java.nio.{ByteBuffer, ByteOrder} - -import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} -import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor -import org.apache.spark.sql.types._ - -/** - * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is - * extracted from the buffer, instead of directly returning it, the value is set into some field of - * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods - * for primitive values provided by [[MutableRow]]. - */ -private[sql] trait ColumnAccessor { - initialize() - - protected def initialize() - - def hasNext: Boolean - - def extractTo(row: MutableRow, ordinal: Int) - - protected def underlyingBuffer: ByteBuffer -} - -private[sql] abstract class BasicColumnAccessor[JvmType]( - protected val buffer: ByteBuffer, - protected val columnType: ColumnType[JvmType]) - extends ColumnAccessor { - - protected def initialize() {} - - override def hasNext: Boolean = buffer.hasRemaining - - override def extractTo(row: MutableRow, ordinal: Int): Unit = { - extractSingle(row, ordinal) - } - - def extractSingle(row: MutableRow, ordinal: Int): Unit = { - columnType.extract(buffer, row, ordinal) - } - - protected def underlyingBuffer = buffer -} - -private[sql] class NullColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[Any](buffer, NULL) - with NullableColumnAccessor - -private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( - override protected val buffer: ByteBuffer, - override protected val columnType: NativeColumnType[T]) - extends BasicColumnAccessor(buffer, columnType) - with NullableColumnAccessor - with CompressibleColumnAccessor[T] - -private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BOOLEAN) - -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) - -private[sql] class ShortColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, SHORT) - -private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) - -private[sql] class LongColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, LONG) - -private[sql] class FloatColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, FLOAT) - -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) - -private[sql] class StringColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, STRING) - -private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) - with NullableColumnAccessor - -private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) - extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType)) - -private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) - extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType)) - with NullableColumnAccessor - -private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) - extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) - with NullableColumnAccessor - -private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) - extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) - with NullableColumnAccessor - -private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) - extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) - with NullableColumnAccessor - -private[sql] object ColumnAccessor { - def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { - val buf = buffer.order(ByteOrder.nativeOrder) - - dataType match { - case NullType => new NullColumnAccessor(buf) - case BooleanType => new BooleanColumnAccessor(buf) - case ByteType => new ByteColumnAccessor(buf) - case ShortType => new ShortColumnAccessor(buf) - case IntegerType | DateType => new IntColumnAccessor(buf) - case LongType | TimestampType => new LongColumnAccessor(buf) - case FloatType => new FloatColumnAccessor(buf) - case DoubleType => new DoubleColumnAccessor(buf) - case StringType => new StringColumnAccessor(buf) - case BinaryType => new BinaryColumnAccessor(buf) - case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => - new CompactDecimalColumnAccessor(buf, dt) - case dt: DecimalType => new DecimalColumnAccessor(buf, dt) - case struct: StructType => new StructColumnAccessor(buf, struct) - case array: ArrayType => new ArrayColumnAccessor(buf, array) - case map: MapType => new MapColumnAccessor(buf, map) - case udt: UserDefinedType[_] => ColumnAccessor(udt.sqlType, buffer) - case other => - throw new Exception(s"not support type: $other") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala deleted file mode 100644 index 7a7345a7e004b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.columnar - -import java.nio.{ByteBuffer, ByteOrder} - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.ColumnBuilder._ -import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} -import org.apache.spark.sql.types._ - -private[sql] trait ColumnBuilder { - /** - * Initializes with an approximate lower bound on the expected number of elements in this column. - */ - def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false) - - /** - * Appends `row(ordinal)` to the column builder. - */ - def appendFrom(row: InternalRow, ordinal: Int) - - /** - * Column statistics information - */ - def columnStats: ColumnStats - - /** - * Returns the final columnar byte buffer. - */ - def build(): ByteBuffer -} - -private[sql] class BasicColumnBuilder[JvmType]( - val columnStats: ColumnStats, - val columnType: ColumnType[JvmType]) - extends ColumnBuilder { - - protected var columnName: String = _ - - protected var buffer: ByteBuffer = _ - - override def initialize( - initialSize: Int, - columnName: String = "", - useCompression: Boolean = false): Unit = { - - val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize - this.columnName = columnName - - buffer = ByteBuffer.allocate(size * columnType.defaultSize) - buffer.order(ByteOrder.nativeOrder()) - } - - override def appendFrom(row: InternalRow, ordinal: Int): Unit = { - buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) - columnType.append(row, ordinal, buffer) - } - - override def build(): ByteBuffer = { - buffer.flip().asInstanceOf[ByteBuffer] - } -} - -private[sql] class NullColumnBuilder - extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) - with NullableColumnBuilder - -private[sql] abstract class ComplexColumnBuilder[JvmType]( - columnStats: ColumnStats, - columnType: ColumnType[JvmType]) - extends BasicColumnBuilder[JvmType](columnStats, columnType) - with NullableColumnBuilder - -private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( - override val columnStats: ColumnStats, - override val columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T#InternalType](columnStats, columnType) - with NullableColumnBuilder - with AllCompressionSchemes - with CompressibleColumnBuilder[T] - -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) - -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) - -private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) - -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) - -private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) - -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) - -private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) - -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) - -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) - -private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType) - extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) - -private[sql] class DecimalColumnBuilder(dataType: DecimalType) - extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) - -private[sql] class StructColumnBuilder(dataType: StructType) - extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) - -private[sql] class ArrayColumnBuilder(dataType: ArrayType) - extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType)) - -private[sql] class MapColumnBuilder(dataType: MapType) - extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) - -private[sql] object ColumnBuilder { - val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 - - private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = { - if (orig.remaining >= size) { - orig - } else { - // grow in steps of initial size - val capacity = orig.capacity() - val newSize = capacity + size.max(capacity / 8 + 1) - val pos = orig.position() - - ByteBuffer - .allocate(newSize) - .order(ByteOrder.nativeOrder()) - .put(orig.array(), 0, pos) - } - } - - def apply( - dataType: DataType, - initialSize: Int = 0, - columnName: String = "", - useCompression: Boolean = false): ColumnBuilder = { - val builder: ColumnBuilder = dataType match { - case NullType => new NullColumnBuilder - case BooleanType => new BooleanColumnBuilder - case ByteType => new ByteColumnBuilder - case ShortType => new ShortColumnBuilder - case IntegerType | DateType => new IntColumnBuilder - case LongType | TimestampType => new LongColumnBuilder - case FloatType => new FloatColumnBuilder - case DoubleType => new DoubleColumnBuilder - case StringType => new StringColumnBuilder - case BinaryType => new BinaryColumnBuilder - case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => - new CompactDecimalColumnBuilder(dt) - case dt: DecimalType => new DecimalColumnBuilder(dt) - case struct: StructType => new StructColumnBuilder(struct) - case array: ArrayType => new ArrayColumnBuilder(array) - case map: MapType => new MapColumnBuilder(map) - case udt: UserDefinedType[_] => - return apply(udt.sqlType, initialSize, columnName, useCompression) - case other => - throw new Exception(s"not suppported type: $other") - } - - builder.initialize(initialSize, columnName, useCompression) - builder - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala deleted file mode 100644 index 6f3f1bd97ad52..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.HashMap - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each - * group. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param child the input data source. - */ -case class Aggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - override private[sql] lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def requiredChildDistribution: List[Distribution] = { - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - } - - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - - /** - * An aggregate that needs to be computed for each row in a group. - * - * @param unbound Unbound version of this aggregate, used for result substitution. - * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer. - * @param resultAttribute An attribute used to refer to the result of this aggregate in the final - * output. - */ - case class ComputedAggregate( - unbound: AggregateExpression1, - aggregate: AggregateExpression1, - resultAttribute: AttributeReference) - - /** A list of aggregates that need to be computed for each group. */ - private[this] val computedAggregates = aggregateExpressions.flatMap { agg => - agg.collect { - case a: AggregateExpression1 => - ComputedAggregate( - a, - BindReferences.bindReference(a, child.output), - AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) - } - }.toArray - - /** The schema of the result of all aggregate evaluations */ - private[this] val computedSchema = computedAggregates.map(_.resultAttribute) - - /** Creates a new aggregate buffer for a group. */ - private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { - val buffer = new Array[AggregateFunction1](computedAggregates.length) - var i = 0 - while (i < computedAggregates.length) { - buffer(i) = computedAggregates(i).aggregate.newInstance() - i += 1 - } - buffer - } - - /** Named attributes used to substitute grouping attributes into the final result. */ - private[this] val namedGroups = groupingExpressions.map { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute - } - - /** - * A map of substitutions that are used to insert the aggregate expressions and grouping - * expression into the final result expression. - */ - private[this] val resultMap = - (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap - - /** - * Substituted version of aggregateExpressions expressions which are used to compute final - * output rows given a group and the result of all aggregate computations. - */ - private[this] val resultExpressions = aggregateExpressions.map { agg => - agg.transform { - case e: Expression if resultMap.contains(e) => resultMap(e) - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numInputRows = longMetric("numInputRows") - val numOutputRows = longMetric("numOutputRows") - if (groupingExpressions.isEmpty) { - child.execute().mapPartitions { iter => - val buffer = newAggregateBuffer() - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - numInputRows += 1 - var i = 0 - while (i < buffer.length) { - buffer(i).update(currentRow) - i += 1 - } - } - val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) - val aggregateResults = new GenericMutableRow(computedAggregates.length) - - var i = 0 - while (i < buffer.length) { - aggregateResults(i) = buffer(i).eval(EmptyRow) - i += 1 - } - - numOutputRows += 1 - Iterator(resultProjection(aggregateResults)) - } - } else { - child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] - val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) - - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - numInputRows += 1 - val currentGroup = groupingProjection(currentRow) - var currentBuffer = hashTable.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregateBuffer() - hashTable.put(currentGroup.copy(), currentBuffer) - } - - var i = 0 - while (i < currentBuffer.length) { - currentBuffer(i).update(currentRow) - i += 1 - } - } - - new Iterator[InternalRow] { - private[this] val hashTableIter = hashTable.entrySet().iterator() - private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) - private[this] val resultProjection = - new InterpretedMutableProjection( - resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow - - override final def hasNext: Boolean = hashTableIter.hasNext - - override final def next(): InternalRow = { - val currentEntry = hashTableIter.next() - val currentGroup = currentEntry.getKey - val currentBuffer = currentEntry.getValue - numOutputRows += 1 - - var i = 0 - while (i < currentBuffer.length) { - // Evaluating an aggregate buffer returns the result. No row is required since we - // already added all rows in the group using update. - aggregateResults(i) = currentBuffer(i).eval(EmptyRow) - i += 1 - } - resultProjection(joinedRow(aggregateResults, currentGroup)) - } - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f85aeb1b02694..124ec09efd196 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock -import org.apache.spark.Logging -import org.apache.spark.sql.DataFrame +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.Dataset import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK @@ -75,12 +75,12 @@ private[sql] class CacheManager extends Logging { } /** - * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike - * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing - * the in-memory columnar representation of the underlying table is expensive. + * Caches the data produced by the logical representation of the given [[Dataset]]. + * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because + * recomputing the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: DataFrame, + query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -95,13 +95,13 @@ private[sql] class CacheManager extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - sqlContext.executePlan(query.logicalPlan).executedPlan, + sqlContext.executePlan(planToCache).executedPlan, tableName)) } } - /** Removes the data for the given [[DataFrame]] from the cache */ - private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[Dataset]] from the cache */ + private[sql] def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -109,9 +109,12 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ + /** + * Tries to remove the data for the given [[Dataset]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( - query: DataFrame, + query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -123,12 +126,12 @@ private[sql] class CacheManager extends Logging { found } - /** Optionally returns cached data for the given [[DataFrame]] */ - private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[Dataset]] */ + private[sql] def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } - /** Optionally returns cached data for the given LogicalPlan. */ + /** Optionally returns cached data for the given [[LogicalPlan]]. */ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala index 663bc904f39c8..33475bea9af43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala deleted file mode 100644 index 0f72ec6cc107a..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ /dev/null @@ -1,479 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.Random - -import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.hash.HashShuffleManager -import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.util.MutablePair - -/** - * Performs a shuffle that will result in the desired `newPartitioning`. - */ -case class Exchange( - var newPartitioning: Partitioning, - child: SparkPlan, - @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode { - - override def nodeName: String = { - val extraInfo = coordinator match { - case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => - "Shuffle" - case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => - "May shuffle" - case None => "Shuffle without coordinator" - } - - val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" - s"$simpleNodeName($extraInfo)" - } - - /** - * Returns true iff we can support the data type, and we are not doing range partitioning. - */ - private lazy val tungstenMode: Boolean = { - unsafeEnabled && codegenEnabled && GenerateUnsafeProjection.canSupport(child.schema) && - !newPartitioning.isInstanceOf[RangePartitioning] - } - - override def outputPartitioning: Partitioning = newPartitioning - - override def output: Seq[Attribute] = child.output - - // This setting is somewhat counterintuitive: - // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, - // so the planner inserts a converter to convert data into UnsafeRow if needed. - override def outputsUnsafeRows: Boolean = tungstenMode - override def canProcessSafeRows: Boolean = !tungstenMode - override def canProcessUnsafeRows: Boolean = tungstenMode - - /** - * Determines whether records must be defensively copied before being sent to the shuffle. - * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The - * shuffle code assumes that objects are immutable and hence does not perform its own defensive - * copying. In Spark SQL, however, operators' iterators return the same mutable `Row` object. In - * order to properly shuffle the output of these operators, we need to perform our own copying - * prior to sending records to the shuffle. This copying is expensive, so we try to avoid it - * whenever possible. This method encapsulates the logic for choosing when to copy. - * - * In the long run, we might want to push this logic into core's shuffle APIs so that we don't - * have to rely on knowledge of core internals here in SQL. - * - * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. - * - * @param partitioner the partitioner for the shuffle - * @param serializer the serializer that will be used to write rows - * @return true if rows should be copied before being shuffled, false otherwise - */ - private def needToCopyObjectsBeforeShuffle( - partitioner: Partitioner, - serializer: Serializer): Boolean = { - // Note: even though we only use the partitioner's `numPartitions` field, we require it to be - // passed instead of directly passing the number of partitions in order to guard against - // corner-cases where a partitioner constructed with `numPartitions` partitions may output - // fewer partitions (like RangePartitioner, for example). - val conf = child.sqlContext.sparkContext.conf - val shuffleManager = SparkEnv.get.shuffleManager - val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] - val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - if (sortBasedShuffleOn) { - val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { - // If we're using the original SortShuffleManager and the number of output partitions is - // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which - // doesn't buffer deserialized records. - // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. - false - } else if (serializer.supportsRelocationOfSerializedObjects) { - // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records - // prior to sorting them. This optimization is only applied in cases where shuffle - // dependency does not specify an aggregator or ordering and the record serializer has - // certain properties. If this optimization is enabled, we can safely avoid the copy. - // - // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only - // need to check whether the optimization is enabled and supported by our serializer. - false - } else { - // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must - // copy. - true - } - } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { - // We're using hash-based shuffle, so we don't need to copy. - false - } else { - // Catch-all case to safely handle any future ShuffleManager implementations. - true - } - } - - @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf - - private val serializer: Serializer = { - if (tungstenMode) { - new UnsafeRowSerializer(child.output.size) - } else { - new SparkSqlSerializer(sparkConf) - } - } - - override protected def doPrepare(): Unit = { - // If an ExchangeCoordinator is needed, we register this Exchange operator - // to the coordinator when we do prepare. It is important to make sure - // we register this operator right before the execution instead of register it - // in the constructor because it is possible that we create new instances of - // Exchange operators when we transform the physical plan - // (then the ExchangeCoordinator will hold references of unneeded Exchanges). - // So, we should only call registerExchange just before we start to execute - // the plan. - coordinator match { - case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) - case None => - } - } - - /** - * Returns a [[ShuffleDependency]] that will partition rows of its child based on - * the partitioning scheme defined in `newPartitioning`. Those partitions of - * the returned ShuffleDependency will be the input of shuffle. - */ - private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { - val rdd = child.execute() - val part: Partitioner = newPartitioning match { - case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) - case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions) - case RangePartitioning(sortingExpressions, numPartitions) => - // Internally, RangePartitioner runs a job on the RDD that samples keys to compute - // partition bounds. To get accurate samples, we need to copy the mutable keys. - val rddForSampling = rdd.mapPartitions { iter => - val mutablePair = new MutablePair[InternalRow, Null]() - iter.map(row => mutablePair.update(row.copy(), null)) - } - // We need to use an interpreted ordering here because generated orderings cannot be - // serialized and this ordering needs to be created on the driver in order to be passed into - // Spark core code. - implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output) - new RangePartitioner(numPartitions, rddForSampling, ascending = true) - case SinglePartition => - new Partitioner { - override def numPartitions: Int = 1 - override def getPartition(key: Any): Int = 0 - } - case _ => sys.error(s"Exchange not implemented for $newPartitioning") - // TODO: Handle BroadcastPartitioning. - } - def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match { - case RoundRobinPartitioning(numPartitions) => - // Distributes elements evenly across output partitions, starting from a random partition. - var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions) - (row: InternalRow) => { - // The HashPartitioner will handle the `mod` by the number of partitions - position += 1 - position - } - case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)() - case RangePartitioning(_, _) | SinglePartition => identity - case _ => sys.error(s"Exchange not implemented for $newPartitioning") - } - val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { - if (needToCopyObjectsBeforeShuffle(part, serializer)) { - rdd.mapPartitions { iter => - val getPartitionKey = getPartitionKeyExtractor() - iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } - } - } else { - rdd.mapPartitions { iter => - val getPartitionKey = getPartitionKeyExtractor() - val mutablePair = new MutablePair[Int, InternalRow]() - iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } - } - } - } - - // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds - // are in the form of (partitionId, row) and every partitionId is in the expected range - // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. - val dependency = - new ShuffleDependency[Int, InternalRow, InternalRow]( - rddWithPartitionIds, - new PartitionIdPassthrough(part.numPartitions), - Some(serializer)) - - dependency - } - - /** - * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. - * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional - * partition start indices array. If this optional array is defined, the returned - * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. - */ - private[sql] def preparePostShuffleRDD( - shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], - specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { - // If an array of partition start indices is provided, we need to use this array - // to create the ShuffledRowRDD. Also, we need to update newPartitioning to - // update the number of post-shuffle partitions. - specifiedPartitionStartIndices.foreach { indices => - assert(newPartitioning.isInstanceOf[HashPartitioning]) - newPartitioning = newPartitioning.withNumPartitions(indices.length) - } - new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { - coordinator match { - case Some(exchangeCoordinator) => - val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) - assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) - shuffleRDD - case None => - val shuffleDependency = prepareShuffleDependency() - preparePostShuffleRDD(shuffleDependency) - } - } -} - -object Exchange { - def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { - Exchange(newPartitioning, child, None: Option[ExchangeCoordinator]) - } -} - -/** - * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] - * of input data meets the - * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. Also ensure that the - * input partition ordering requirements are met. - */ -private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { - private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions - - private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize - - private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled - - private def minNumPostShufflePartitions: Option[Int] = { - val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions - if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None - } - - /** - * Given a required distribution, returns a partitioning that satisfies that distribution. - */ - private def createPartitioning( - requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { - requiredDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) - case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) - case dist => sys.error(s"Do not know how to satisfy distribution $dist") - } - } - - /** - * Adds [[ExchangeCoordinator]] to [[Exchange]]s if adaptive query execution is enabled - * and partitioning schemes of these [[Exchange]]s support [[ExchangeCoordinator]]. - */ - private def withExchangeCoordinator( - children: Seq[SparkPlan], - requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { - val supportsCoordinator = - if (children.exists(_.isInstanceOf[Exchange])) { - // Right now, ExchangeCoordinator only support HashPartitionings. - children.forall { - case e @ Exchange(hash: HashPartitioning, _, _) => true - case child => - child.outputPartitioning match { - case hash: HashPartitioning => true - case collection: PartitioningCollection => - collection.partitionings.exists(_.isInstanceOf[HashPartitioning]) - case _ => false - } - } - } else { - // In this case, although we do not have Exchange operators, we may still need to - // shuffle data when we have more than one children because data generated by - // these children may not be partitioned in the same way. - // Please see the comment in withCoordinator for more details. - val supportsDistribution = - requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) - children.length > 1 && supportsDistribution - } - - val withCoordinator = - if (adaptiveExecutionEnabled && supportsCoordinator) { - val coordinator = - new ExchangeCoordinator( - children.length, - targetPostShuffleInputSize, - minNumPostShufflePartitions) - children.zip(requiredChildDistributions).map { - case (e: Exchange, _) => - // This child is an Exchange, we need to add the coordinator. - e.copy(coordinator = Some(coordinator)) - case (child, distribution) => - // If this child is not an Exchange, we need to add an Exchange for now. - // Ideally, we can try to avoid this Exchange. However, when we reach here, - // there are at least two children operators (because if there is a single child - // and we can avoid Exchange, supportsCoordinator will be false and we - // will not reach here.). Although we can make two children have the same number of - // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. - // For example, let's say we have the following plan - // Join - // / \ - // Agg Exchange - // / \ - // Exchange t2 - // / - // t1 - // In this case, because a post-shuffle partition can include multiple pre-shuffle - // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes - // after shuffle. So, even we can use the child Exchange operator of the Join to - // have a number of post-shuffle partitions that matches the number of partitions of - // Agg, we cannot say these two children are partitioned in the same way. - // Here is another case - // Join - // / \ - // Agg1 Agg2 - // / \ - // Exchange1 Exchange2 - // / \ - // t1 t2 - // In this case, two Aggs shuffle data with the same column of the join condition. - // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same - // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 - // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle - // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its - // pre-shuffle partitions by using another partitionStartIndices [0, 4]. - // So, Agg1 and Agg2 are actually not co-partitioned. - // - // It will be great to introduce a new Partitioning to represent the post-shuffle - // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. - val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) - assert(targetPartitioning.isInstanceOf[HashPartitioning]) - Exchange(targetPartitioning, child, Some(coordinator)) - } - } else { - // If we do not need ExchangeCoordinator, the original children are returned. - children - } - - withCoordinator - } - - private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { - val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution - val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering - var children: Seq[SparkPlan] = operator.children - assert(requiredChildDistributions.length == children.length) - assert(requiredChildOrderings.length == children.length) - - // Ensure that the operator's children satisfy their output distribution requirements: - children = children.zip(requiredChildDistributions).map { case (child, distribution) => - if (child.outputPartitioning.satisfies(distribution)) { - child - } else { - Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) - } - } - - // If the operator has multiple children and specifies child output distributions (e.g. join), - // then the children's output partitionings must be compatible: - if (children.length > 1 - && requiredChildDistributions.toSet != Set(UnspecifiedDistribution) - && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { - - // First check if the existing partitions of the children all match. This means they are - // partitioned by the same partitioning into the same number of partitions. In that case, - // don't try to make them match `defaultPartitions`, just use the existing partitioning. - // TODO: this should be a cost based decision. For example, a big relation should probably - // maintain its existing number of partitions and smaller partitions should be shuffled. - // defaultPartitions is arbitrary. - val numPartitions = children.head.outputPartitioning.numPartitions - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => { - child.outputPartitioning.guarantees( - createPartitioning(distribution, numPartitions)) - } - } - - children = if (useExistingPartitioning) { - children - } else { - children.zip(requiredChildDistributions).map { - case (child, distribution) => { - val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - Exchange(targetPartitioning, child) - } - } - } - } - } - - // Now, we need to add ExchangeCoordinator if necessary. - // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. - // However, with the way that we plan the query, we do not have a place where we have a - // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator - // at here for now. - // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, - // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. - children = withExchangeCoordinator(children, requiredChildDistributions) - - // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: - children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => - if (requiredOrdering.nonEmpty) { - // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. - if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { - sqlContext.planner.BasicOperators.getSortOperator( - requiredOrdering, - global = false, - child) - } else { - child - } - } else { - child - } - } - - operator.withNewChildren(children) - } - - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator: SparkPlan => ensureDistributionAndOrdering(operator) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 7a466cf6a0a94..392c48fb7b93b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,14 +18,19 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation} -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{Row, SQLContext} - +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.toCommentSafeString +import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} +import org.apache.spark.sql.types.{AtomicType, DataType} object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -74,6 +79,8 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil + override protected final def otherCopyArgs: Seq[AnyRef] = sqlContext :: Nil + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] @@ -82,6 +89,8 @@ private[sql] case class LogicalRDD( case _ => false } + override def producedAttributes: AttributeSet = outputSet + @transient override lazy val statistics: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. @@ -93,42 +102,255 @@ private[sql] case class LogicalRDD( private[sql] case class PhysicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], - extraInformation: String, - override val outputsUnsafeRows: Boolean = false) - extends LeafNode { + override val nodeName: String) extends LeafNode { - protected override def doExecute(): RDD[InternalRow] = rdd + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") -} + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map { r => + numOutputRows += 1 + proj(r) + } + } + } -private[sql] object PhysicalRDD { - def createFromDataSource( - output: Seq[Attribute], - rdd: RDD[InternalRow], - relation: BaseRelation): PhysicalRDD = { - PhysicalRDD(output, rdd, relation.toString, relation.isInstanceOf[HadoopFsRelation]) + override def simpleString: String = { + s"Scan $nodeName${output.mkString("[", ",", "]")}" } } -/** Logical plan node for scanning data from a local collection. */ -private[sql] -case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[InternalRow])(sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { +private[sql] trait DataSourceScan extends LeafNode { + val rdd: RDD[InternalRow] + val relation: BaseRelation - override def children: Seq[LogicalPlan] = Nil + override val nodeName: String = relation.toString + + // Ignore rdd when checking results + override def sameResult(plan: SparkPlan): Boolean = plan match { + case other: DataSourceScan => relation == other.relation && metadata == other.metadata + case _ => false + } +} - override def newInstance(): this.type = - LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type] +/** Physical plan node for scanning data from a relation. */ +private[sql] case class RowDataSourceScan( + output: Seq[Attribute], + rdd: RDD[InternalRow], + @transient relation: BaseRelation, + override val outputPartitioning: Partitioning, + override val metadata: Map[String, String] = Map.empty) + extends DataSourceScan with CodegenSupport { - override def sameResult(plan: LogicalPlan): Boolean = plan match { - case LogicalRDD(_, otherRDD) => rows == rows + private[sql] override lazy val metrics = + Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + val outputUnsafeRows = relation match { + case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] => + !SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + case _: HadoopFsRelation => true case _ => false } - @transient override lazy val statistics: Statistics = Statistics( - // TODO: Improve the statistics estimation. - // This is made small enough so it can be broadcasted. - sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1 - ) + protected override def doExecute(): RDD[InternalRow] = { + val unsafeRow = if (outputUnsafeRows) { + rdd + } else { + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } + } + + val numOutputRows = longMetric("numOutputRows") + unsafeRow.map { r => + numOutputRows += 1 + r + } + } + + override def simpleString: String = { + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" + s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + rdd :: Nil + } + + override protected def doProduce(ctx: CodegenContext): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + // PhysicalRDD always just has one input + val input = ctx.freshName("input") + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val exprRows = output.zipWithIndex.map{ case (a, i) => + new BoundReference(i, a.dataType, a.nullable) + } + val row = ctx.freshName("row") + ctx.INPUT_ROW = row + ctx.currentVars = null + val columnsRowInput = exprRows.map(_.gen(ctx)) + val inputRow = if (outputUnsafeRows) row else null + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, columnsRowInput, inputRow).trim} + | if (shouldStop()) return; + |} + """.stripMargin + } +} + +/** Physical plan node for scanning data from a batched relation. */ +private[sql] case class BatchedDataSourceScan( + output: Seq[Attribute], + rdd: RDD[InternalRow], + @transient relation: BaseRelation, + override val outputPartitioning: Partitioning, + override val metadata: Map[String, String] = Map.empty) + extends DataSourceScan with CodegenSupport { + + private[sql] override lazy val metrics = + Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), + "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } + + override def simpleString: String = { + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" + val metadataStr = metadataEntries.mkString(" ", ", ", "") + s"BatchedScan $nodeName${output.mkString("[", ",", "]")}$metadataStr" + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + rdd :: Nil + } + + private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String, + dataType: DataType, nullable: Boolean): ExprCode = { + val javaType = ctx.javaType(dataType) + val value = ctx.getValue(columnVar, dataType, ordinal) + val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } + val valueVar = ctx.freshName("value") + val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" + val code = s"/* ${toCommentSafeString(str)} */\n" + (if (nullable) { + s""" + boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal); + $javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value); + """ + } else { + s"$javaType ${valueVar} = $value;" + }).trim + ExprCode(code, isNullVar, valueVar) + } + + // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen + // never requires UnsafeRow as input. + override protected def doProduce(ctx: CodegenContext): String = { + val input = ctx.freshName("input") + // PhysicalRDD always just has one input + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + // metrics + val numOutputRows = metricTerm(ctx, "numOutputRows") + val scanTimeMetric = metricTerm(ctx, "scanTime") + val scanTimeTotalNs = ctx.freshName("scanTime") + ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + + val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" + val batch = ctx.freshName("batch") + ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + + val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector" + val idx = ctx.freshName("batchIdx") + ctx.addMutableState("int", idx, s"$idx = 0;") + val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) + val columnAssigns = colVars.zipWithIndex.map { case (name, i) => + ctx.addMutableState(columnVectorClz, name, s"$name = null;") + s"$name = $batch.column($i);" + } + + val nextBatch = ctx.freshName("nextBatch") + ctx.addNewFunction(nextBatch, + s""" + |private void $nextBatch() throws java.io.IOException { + | long getBatchStart = System.nanoTime(); + | if ($input.hasNext()) { + | $batch = ($columnarBatchClz)$input.next(); + | $numOutputRows.add($batch.numRows()); + | $idx = 0; + | ${columnAssigns.mkString("", "\n", "\n")} + | } + | $scanTimeTotalNs += System.nanoTime() - getBatchStart; + |}""".stripMargin) + + ctx.currentVars = null + val rowidx = ctx.freshName("rowIdx") + val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => + genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) + } + s""" + |if ($batch == null) { + | $nextBatch(); + |} + |while ($batch != null) { + | int numRows = $batch.numRows(); + | while ($idx < numRows) { + | int $rowidx = $idx++; + | ${consume(ctx, columnsBatchInput).trim} + | if (shouldStop()) return; + | } + | $batch = null; + | $nextBatch(); + |} + |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); + |$scanTimeTotalNs = 0; + """.stripMargin + } +} + +private[sql] object DataSourceScan { + // Metadata keys + val INPUT_PATHS = "InputPaths" + val PUSHED_FILTERS = "PushedFilters" + + def create( + output: Seq[Attribute], + rdd: RDD[InternalRow], + relation: BaseRelation, + metadata: Map[String, String] = Map.empty): DataSourceScan = { + val outputPartitioning = { + val bucketSpec = relation match { + // TODO: this should be closer to bucket planning. + case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec + case _ => None + } + + def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse { + throw new AnalysisException(s"bucket column $colName not found in existing columns " + + s"(${output.map(_.name).mkString(", ")})") + } + + bucketSpec.map { spec => + val numBuckets = spec.numBuckets + val bucketColumns = spec.bucketColumnNames.map(toAttribute) + HashPartitioning(bucketColumns, numBuckets) + }.getOrElse { + UnknownPartitioning(0) + } + } + + relation match { + case r: HadoopFsRelation if r.fileFormat.supportBatch(r.sqlContext, relation.schema) => + BatchedDataSourceScan(output, rdd, relation, outputPartitioning, metadata) + case _ => + RowDataSourceScan(output, rdd, relation, outputPartitioning, metadata) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index a458881f40948..bd23b7e3ad683 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -21,7 +21,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * Apply the all of the GroupExpressions to every input row, hence we will get @@ -35,20 +37,26 @@ case class Expand( projections: Seq[Seq[Expression]], output: Seq[Attribute], child: SparkPlan) - extends UnaryNode { + extends UnaryNode with CodegenSupport { + + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) // The GroupExpressions can output data with arbitrary partitioning, so set it // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + + private[this] val projection = + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // TODO Move out projection objects creation and transfer to - // workers via closure. However we can't assume the Projection - // is serializable because of the code gen, so we have to - // create the projections within each of the partition processing. - val groups = projections.map(ee => newProjection(ee, child.output)).toArray + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitions { iter => + val groups = projections.map(projection).toArray new Iterator[InternalRow] { private[this] var result: InternalRow = _ private[this] var idx = -1 // -1 means the initial state @@ -70,9 +78,125 @@ case class Expand( idx = 0 } + numOutputRows += 1 result } } } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + /* + * When the projections list looks like: + * expr1A, exprB, expr1C + * expr2A, exprB, expr2C + * ... + * expr(N-1)A, exprB, expr(N-1)C + * + * i.e. column A and C have different values for each output row, but column B stays constant. + * + * The generated code looks something like (note that B is only computed once in declaration): + * + * // part 1: declare all the columns + * colA = ... + * colB = ... + * colC = ... + * + * // part 2: code that computes the columns + * for (row = 0; row < N; row++) { + * switch (row) { + * case 0: + * colA = ... + * colC = ... + * case 1: + * colA = ... + * colC = ... + * ... + * case N - 1: + * colA = ... + * colC = ... + * } + * // increment metrics and consume output values + * } + * + * We use a for loop here so we only includes one copy of the consume code and avoid code + * size explosion. + */ + + // Set input variables + ctx.currentVars = input + + // Tracks whether a column has the same output for all rows. + // Size of sameOutput array should equal N. + // If sameOutput(i) is true, then the i-th column has the same value for all output rows given + // an input row. + val sameOutput: Array[Boolean] = output.indices.map { colIndex => + projections.map(p => p(colIndex)).toSet.size == 1 + }.toArray + + // Part 1: declare variables for each column + // If a column has the same value for all output rows, then we also generate its computation + // right after declaration. Otherwise its value is computed in the part 2. + val outputColumns = output.indices.map { col => + val firstExpr = projections.head(col) + if (sameOutput(col)) { + // This column is the same across all output rows. Just generate code for it here. + BindReferences.bindReference(firstExpr, child.output).gen(ctx) + } else { + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; + """.stripMargin + ExprCode(code, isNull, value) + } + } + + // Part 2: switch/case statements + val cases = projections.zipWithIndex.map { case (exprs, row) => + var updateCode = "" + for (col <- exprs.indices) { + if (!sameOutput(col)) { + val ev = BindReferences.bindReference(exprs(col), child.output).gen(ctx) + updateCode += + s""" + |${ev.code} + |${outputColumns(col).isNull} = ${ev.isNull}; + |${outputColumns(col).value} = ${ev.value}; + """.stripMargin + } + } + + s""" + |case $row: + | ${updateCode.trim} + | break; + """.stripMargin + } + + val numOutput = metricTerm(ctx, "numOutputRows") + val i = ctx.freshName("i") + // these column have to declared before the loop. + val evaluate = evaluateVariables(outputColumns) + ctx.copyResult = true + s""" + |$evaluate + |for (int $i = 0; $i < ${projections.length}; $i ++) { + | switch ($i) { + | ${cases.mkString("\n").trim} + | } + | $numOutput.add(1); + | ${consume(ctx, outputColumns)} + |} + """.stripMargin + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 78e33d9f233a6..9938d2169f1c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -54,13 +55,18 @@ case class Generate( child: SparkPlan) extends UnaryNode { + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + override def producedAttributes: AttributeSet = AttributeSet(output) + val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition - if (join) { - child.execute().mapPartitions { iter => - val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) + val rows = if (join) { + child.execute().mapPartitionsInternal { iter => + val generatorNullRow = new GenericInternalRow(generator.elementTypes.size) val joinedRow = new JoinedRow iter.flatMap { row => @@ -70,18 +76,26 @@ case class Generate( if (outer && outputRows.isEmpty) { joinedRow.withRight(generatorNullRow) :: Nil } else { - outputRows.map(or => joinedRow.withRight(or)) + outputRows.map(joinedRow.withRight) } - } ++ LazyIterator(() => boundGenerator.terminate()).map { row => + } ++ LazyIterator(boundGenerator.terminate).map { row => // we leave the left side as the last element of its child output // keep it the same as Hive does joinedRow.withRight(row) } } } else { - child.execute().mapPartitions { iter => - iter.flatMap(row => boundGenerator.eval(row)) ++ - LazyIterator(() => boundGenerator.terminate()) + child.execute().mapPartitionsInternal { iter => + iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate) + } + } + + val numOutputRows = longMetric("numOutputRows") + rows.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(output, output) + iter.map { r => + numOutputRows += 1 + proj(r) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala index 6a8850129f1ac..431f02102e8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateOrdering} -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, Ascending, Expression} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, GenerateUnsafeProjection} object GroupedIterator { def apply( @@ -115,7 +115,8 @@ class GroupedIterator private( false } else { // Skip to next group. - while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) { + // currentRow may be overwritten by `hasNext`, so we should compare them first. + while (keyOrdering.compare(currentGroup, currentRow) == 0 && input.hasNext) { currentRow = input.next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index ba7f6287ac6c3..f8aec9e7a1d1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -29,15 +30,29 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - protected override def doExecute(): RDD[InternalRow] = rdd + private val unsafeRows: Array[InternalRow] = { + val proj = UnsafeProjection.create(output, output) + rows.map(r => proj(r).copy()).toArray + } + + private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows) + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + rdd.map { r => + numOutputRows += 1 + r + } + } override def executeCollect(): Array[InternalRow] = { - rows.toArray + unsafeRows } override def executeTake(limit: Int): Array[InternalRow] = { - rows.take(limit).toArray + unsafeRows.take(limit) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index fc9174549e642..f5e1e77263b5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -18,50 +18,76 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} /** * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. + * + * While this is not a public class, we should avoid changing the function names for the sake of + * changing them, because a lot of developers use the feature for debugging. */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { - val analyzer = sqlContext.analyzer - val optimizer = sqlContext.optimizer - val planner = sqlContext.planner - val cacheManager = sqlContext.cacheManager - val prepareForExecution = sqlContext.prepareForExecution - def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) + // TODO: Move the planner an optimizer into here from SessionState. + protected def planner = sqlContext.sessionState.planner + + def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } + + lazy val analyzed: LogicalPlan = sqlContext.sessionState.analyzer.execute(logical) - lazy val analyzed: LogicalPlan = analyzer.execute(logical) lazy val withCachedData: LogicalPlan = { assertAnalyzed() - cacheManager.useCachedData(analyzed) + sqlContext.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) - // TODO: Don't just pick the first one... + lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData) + lazy val sparkPlan: SparkPlan = { - SparkPlan.currentContext.set(sqlContext) - planner.plan(optimizedPlan).next() + SQLContext.setActive(sqlContext) + planner.plan(ReturnAnswer(optimizedPlan)).next() } + // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + /** + * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal + * row format conversions as needed. + */ + protected def prepareForExecution(plan: SparkPlan): SparkPlan = { + preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } + + /** A sequence of rules that will be applied in order to the physical plan before execution. */ + protected def preparations: Seq[Rule[SparkPlan]] = Seq( + python.ExtractPythonUDFs, + PlanSubqueries(sqlContext), + EnsureRequirements(sqlContext.conf), + CollapseCodegenStages(sqlContext.conf), + ReuseExchange(sqlContext.conf)) + protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } - def simpleString: String = + def simpleString: String = { s"""== Physical Plan == |${stringOrError(executedPlan)} """.stripMargin.trim - + } override def toString: String = { def output = @@ -76,7 +102,22 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} - |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} """.stripMargin.trim } + + /** A special namespace for commands that can be used to debug query execution. */ + // scalastyle:off + object debug { + // scalastyle:on + + /** + * Prints to stdout all the generated code found in this plan (i.e. the output of each + * WholeStageCodegen subtree). + */ + def codegen(): Unit = { + // scalastyle:off println + println(org.apache.spark.sql.execution.debug.codegenString(executedPlan)) + // scalastyle:on println + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 1422e15549c94..0a11b16d0ed35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, + SparkListenerSQLExecutionStart} import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -45,25 +46,14 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.listener.onExecutionStart( - executionId, - callSite.shortForm, - callSite.longForm, - queryExecution.toString, - SparkPlanGraph(queryExecution.executedPlan), - System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. - // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new - // SQL event types to SparkListener since it's a public API, we cannot guarantee that. - // - // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. - // - // The worst case is onExecutionEnd may happen before onJobStart when the listener thread - // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. - sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala new file mode 100644 index 0000000000000..efd8760cd2474 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Performs (external) sorting. + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will + * spill every `frequency` records. + */ +case class Sort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryNode with CodegenSupport { + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + override private[sql] lazy val metrics = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) + + def createSorter(): UnsafeExternalRowSorter = { + val ordering = newOrdering(sortOrder, output) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + val sorter = new UnsafeExternalRowSorter( + schema, ordering, prefixComparator, prefixComputer, pageSize) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter + } + + protected override def doExecute(): RDD[InternalRow] = { + val dataSize = longMetric("dataSize") + val spillSize = longMetric("spillSize") + + child.execute().mapPartitionsInternal { iter => + val sorter = createSorter() + + val metrics = TaskContext.get().taskMetrics() + // Remember spill data size of this task before execute this operator so that we can + // figure out how many bytes we spilled for this operator. + val spillSizeBefore = metrics.memoryBytesSpilled + + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + + dataSize += sorter.getPeakMemoryUsage + spillSize += metrics.memoryBytesSpilled - spillSizeBefore + metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage) + + sortedIterator + } + } + + override def usedInputs: AttributeSet = AttributeSet(Seq.empty) + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + // Name of sorter variable used in codegen. + private var sorterVariable: String = _ + + override protected def doProduce(ctx: CodegenContext): String = { + val needToSort = ctx.freshName("needToSort") + ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + + // Initialize the class member variables. This includes the instance of the Sorter and + // the iterator to return sorted rows. + val thisPlan = ctx.addReferenceObj("plan", this) + sorterVariable = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, + s"$sorterVariable = $thisPlan.createSorter();") + val metrics = ctx.freshName("metrics") + ctx.addMutableState(classOf[TaskMetrics].getName, metrics, + s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") + val sortedIterator = ctx.freshName("sortedIter") + ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + + val addToSorter = ctx.freshName("addToSorter") + ctx.addNewFunction(addToSorter, + s""" + | private void $addToSorter() throws java.io.IOException { + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin.trim) + + // The child could change `copyResult` to true, but we had already consumed all the rows, + // so `copyResult` should be reset to `false`. + ctx.copyResult = false + + val outputRow = ctx.freshName("outputRow") + val dataSize = metricTerm(ctx, "dataSize") + val spillSize = metricTerm(ctx, "spillSize") + val spillSizeBefore = ctx.freshName("spillSizeBefore") + s""" + | if ($needToSort) { + | $addToSorter(); + | Long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | $sortedIterator = $sorterVariable.sort(); + | $dataSize.add($sorterVariable.getPeakMemoryUsage()); + | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); + | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSort = false; + | } + | + | while ($sortedIterator.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | ${consume(ctx, null, outputRow)} + | if (shouldStop()) return; + | } + """.stripMargin.trim + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + s""" + |${row.code} + |$sorterVariable.insertRow((UnsafeRow)${row.value}); + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index e17b50edc62dd..909f124d2c9cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -21,8 +21,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} - +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparator, PrefixComparators} object SortPrefixUtils { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala new file mode 100644 index 0000000000000..cbde777d98415 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.ExperimentalMethods +import org.apache.spark.sql.catalyst.optimizer.Optimizer + +class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer { + override def batches: Seq[Batch] = super.batches :+ Batch( + "User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 8bb293ae87e64..4091f65aecb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -17,26 +17,26 @@ package org.apache.spark.sql.execution +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ -import org.apache.spark.Logging +import org.apache.spark.{broadcast, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType - -object SparkPlan { - protected[sql] val currentContext = new ThreadLocal[SQLContext]() -} +import org.apache.spark.util.ThreadUtils /** * The base class for physical operators. @@ -49,20 +49,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SparkPlan.currentContext.get() + protected[spark] final val sqlContext = SQLContext.getActive().orNull protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the + // the value of subexpressionEliminationEnabled will be set by the deserializer after the // constructor has run. - val codegenEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.codegenEnabled - } else { - false - } - val unsafeEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.unsafeEnabled + val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.subexpressionEliminationEnabled } else { false } @@ -72,17 +67,29 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ private val prepareCalled = new AtomicBoolean(false) - /** Overridden make copy also propogates sqlContext to copied plan. */ + /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) super.makeCopy(newArgs) } + /** + * Return all metadata that describes more details of this SparkPlan. + */ + private[sql] def metadata: Map[String, String] = Map.empty + /** * Return all metrics containing metrics of this SparkPlan. */ private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty + /** + * Reset all the metrics. + */ + private[sql] def resetMetrics(): Unit = { + metrics.valuesIterator.foreach(_.reset()) + } + /** * Return a LongSQLMetric according to the name. */ @@ -103,48 +110,86 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false + /** + * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute after + * preparations. Concrete implementations of SparkPlan should override doExecute. + */ + final def execute(): RDD[InternalRow] = executeQuery { + doExecute() + } /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). + * Returns the result of this query as a broadcast variable by delegating to doBroadcast after + * preparations. Concrete implementations of SparkPlan should override doBroadcast. */ - def canProcessSafeRows: Boolean = true + final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery { + doExecuteBroadcast() + } /** - * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute - * after adding query plan information to created RDDs for visualization. - * Concrete implementations of SparkPlan should override doExecute instead. + * Execute a query after preparing the query and adding query plan information to created RDDs + * for visualization. */ - final def execute(): RDD[InternalRow] = { - if (children.nonEmpty) { - val hasUnsafeInputs = children.exists(_.outputsUnsafeRows) - val hasSafeInputs = children.exists(!_.outputsUnsafeRows) - assert(!(hasSafeInputs && hasUnsafeInputs), - "Child operators should output rows in the same format") - assert(canProcessSafeRows || canProcessUnsafeRows, - "Operator must be able to process at least one row format") - assert(!hasSafeInputs || canProcessSafeRows, - "Operator will receive safe rows as input but cannot process safe rows") - assert(!hasUnsafeInputs || canProcessUnsafeRows, - "Operator will receive unsafe rows as input but cannot process unsafe rows") - } + private final def executeQuery[T](query: => T): T = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() - doExecute() + waitForSubqueries() + query } } + /** + * List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node. + * This list is populated by [[prepareSubqueries]], which is called in [[prepare]]. + */ + @transient + private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])] + + /** + * Finds scalar subquery expressions in this plan node and starts evaluating them. + * The list of subqueries are added to [[subqueryResults]]. + */ + protected def prepareSubqueries(): Unit = { + val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e}) + allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e => + val futureResult = Future { + // Each subquery should return only one row (and one column). We take two here and throws + // an exception later if the number of rows is greater than one. + e.executedPlan.executeTake(2) + }(SparkPlan.subqueryExecutionContext) + subqueryResults += e -> futureResult + } + } + + /** + * Blocks the thread until all subqueries finish evaluation and update the results. + */ + protected def waitForSubqueries(): Unit = { + // fill in the result of subqueries + subqueryResults.foreach { case (e, futureResult) => + val rows = Await.result(futureResult, Duration.Inf) + if (rows.length > 1) { + sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}") + } + if (rows.length == 1) { + assert(rows(0).numFields == 1, + s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis") + e.updateResult(rows(0).get(0, e.dataType)) + } else { + // If there is no rows returned, the result should be null. + e.updateResult(null) + } + } + subqueryResults.clear() + } + /** * Prepare a SparkPlan for execution. It's idempotent. */ final def prepare(): Unit = { if (prepareCalled.compareAndSet(false, true)) { doPrepare() + prepareSubqueries() children.foreach(_.prepare()) } } @@ -165,11 +210,86 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ protected def doExecute(): RDD[InternalRow] + /** + * Overridden by concrete implementations of SparkPlan. + * Produces the result of the query as a broadcast variable. + */ + protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + throw new UnsupportedOperationException(s"$nodeName does not implement doExecuteBroadcast") + } + + /** + * Packing the UnsafeRows into byte array for faster serialization. + * The byte arrays are in the following format: + * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] + * + * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also + * compressed. + */ + private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = { + execute().mapPartitionsInternal { iter => + var count = 0 + val buffer = new Array[Byte](4 << 10) // 4K + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(codec.compressedOutputStream(bos)) + while (iter.hasNext && (n < 0 || count < n)) { + val row = iter.next().asInstanceOf[UnsafeRow] + out.writeInt(row.getSizeInBytes) + row.writeToStream(out, buffer) + count += 1 + } + out.writeInt(-1) + out.flush() + out.close() + Iterator(bos.toByteArray) + } + } + + /** + * Decode the byte arrays back to UnsafeRows and put them into buffer. + */ + private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = { + val nFields = schema.length + + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(codec.compressedInputStream(bis)) + + new Iterator[InternalRow] { + private var sizeOfNextRow = ins.readInt() + override def hasNext: Boolean = sizeOfNextRow >= 0 + override def next(): InternalRow = { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(nFields) + row.pointTo(bs, sizeOfNextRow) + sizeOfNextRow = ins.readInt() + row + } + } + } + /** * Runs this query returning the result as an array. */ def executeCollect(): Array[InternalRow] = { - execute().map(_.copy()).collect() + val byteArrayRdd = getByteArrayRdd() + + val results = ArrayBuffer[InternalRow]() + byteArrayRdd.collect().foreach { bytes => + decodeUnsafeRows(bytes).foreach(results.+=) + } + results.toArray + } + + /** + * Runs this query returning the result as an iterator of InternalRow. + * + * Note: this will trigger multiple jobs (one for each partition). + */ + def executeToIterator(): Iterator[InternalRow] = { + getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows) } /** @@ -190,7 +310,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ return new Array[InternalRow](0) } - val childRDD = execute().map(_.copy()) + val childRDD = getByteArrayRdd(n) val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length @@ -198,7 +318,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ while (buf.size < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the first iteration, just try all partitions next. // Otherwise, interpolate the number of partitions we need to try, but overestimate it @@ -212,101 +332,45 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = n - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val sc = sqlContext.sparkContext - val res = - sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) - - res.foreach(buf ++= _.take(n - buf.size)) - partsScanned += numPartsToTry - } + val res = sc.runJob(childRDD, + (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p) - buf.toArray - } + res.foreach { r => + decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=) + } - private[this] def isTesting: Boolean = sys.props.contains("spark.testing") + partsScanned += p.size + } - protected def newProjection( - expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } + if (buf.size > n) { + buf.take(n).toArray } else { - new InterpretedProjection(expressions, inputSchema) + buf.toArray } } + private[this] def isTesting: Boolean = sys.props.contains("spark.testing") + protected def newMutableProjection( expressions: Seq[Expression], - inputSchema: Seq[Attribute]): () => MutableProjection = { - log.debug( - s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if(codegenEnabled) { - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } - } else { - () => new InterpretedMutableProjection(expressions, inputSchema) - } + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean = false): () => MutableProjection = { + log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") + GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled) { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } - } else { - InterpretedPredicate.create(expression, inputSchema) - } + GeneratePredicate.generate(expression, inputSchema) } protected def newOrdering( - order: Seq[SortOrder], - inputSchema: Seq[Attribute]): Ordering[InternalRow] = { - if (codegenEnabled) { - try { - GenerateOrdering.generate(order, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate ordering, fallback to interpreted", e) - new InterpretedOrdering(order, inputSchema) - } - } - } else { - new InterpretedOrdering(order, inputSchema) - } + order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { + GenerateOrdering.generate(order, inputSchema) } + /** * Creates a row ordering for the given schema, in natural ascending order. */ @@ -318,8 +382,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } +object SparkPlan { + private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) +} + private[sql] trait LeafNode extends SparkPlan { override def children: Seq[SparkPlan] = Nil + override def producedAttributes: AttributeSet = outputSet +} + +object UnaryNode { + def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match { + case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head)) + case _ => None + } } private[sql] trait UnaryNode extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala new file mode 100644 index 0000000000000..247f55da1d2a0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.exchange.ReusedExchange +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.util.Utils + +/** + * :: DeveloperApi :: + * Stores information about a SQL SparkPlan. + */ +@DeveloperApi +class SparkPlanInfo( + val nodeName: String, + val simpleString: String, + val children: Seq[SparkPlanInfo], + val metadata: Map[String, String], + val metrics: Seq[SQLMetricInfo]) { + + override def hashCode(): Int = { + // hashCode of simpleString should be good enough to distinguish the plans from each other + // within a plan + simpleString.hashCode + } + + override def equals(other: Any): Boolean = other match { + case o: SparkPlanInfo => + nodeName == o.nodeName && simpleString == o.simpleString && children == o.children + case _ => false + } +} + +private[sql] object SparkPlanInfo { + + def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val children = plan match { + case ReusedExchange(_, child) => child :: Nil + case _ => plan.children ++ plan.subqueries + } + val metrics = plan.metrics.toSeq.map { case (key, metric) => + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, + Utils.getFormattedClassName(metric.param)) + } + + new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), + plan.metadata, metrics) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 0f98fe88b2101..8d05ae470dec1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -18,29 +18,27 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} +import org.apache.spark.sql.internal.SQLConf -@Experimental -class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { - val sparkContext: SparkContext = sqlContext.sparkContext +class SparkPlanner( + val sparkContext: SparkContext, + val conf: SQLConf, + val extraStrategies: Seq[Strategy]) + extends SparkStrategies { - def codegenEnabled: Boolean = sqlContext.conf.codegenEnabled - - def unsafeEnabled: Boolean = sqlContext.conf.unsafeEnabled - - def numPartitions: Int = sqlContext.conf.numShufflePartitions + def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - sqlContext.experimental.extraStrategies ++ ( + extraStrategies ++ ( + FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: - TakeOrderedAndProject :: - HashAggregation :: + SpecialLimits :: Aggregation :: - LeftSemiJoin :: + ExistenceJoin :: EquiJoinSelection :: InMemoryScans :: BasicOperators :: @@ -69,7 +67,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = + val filterCondition: Option[Expression] = prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) // Right now we still use a projection even if the only evaluation is applying an alias diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala new file mode 100644 index 0000000000000..8ed6ed21d0170 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -0,0 +1,792 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources._ + +/** + * Concrete parser for Spark SQL statements. + */ +object SparkSqlParser extends AbstractSqlParser{ + val astBuilder = new SparkSqlAstBuilder +} + +/** + * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. + */ +class SparkSqlAstBuilder extends AstBuilder { + import org.apache.spark.sql.catalyst.parser.ParserUtils._ + + /** + * Create a [[SetCommand]] logical plan. + * + * Note that we assume that everything after the SET keyword is assumed to be a part of the + * key-value pair. The split between key and value is made by searching for the first `=` + * character in the raw string. + */ + override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) { + // Construct the command. + val raw = remainder(ctx.SET.getSymbol) + val keyValueSeparatorIndex = raw.indexOf('=') + if (keyValueSeparatorIndex >= 0) { + val key = raw.substring(0, keyValueSeparatorIndex).trim + val value = raw.substring(keyValueSeparatorIndex + 1).trim + SetCommand(Some(key -> Option(value))) + } else if (raw.nonEmpty) { + SetCommand(Some(raw.trim -> None)) + } else { + SetCommand(None) + } + } + + /** + * Create a [[SetDatabaseCommand]] logical plan. + */ + override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) { + SetDatabaseCommand(ctx.db.getText) + } + + /** + * Create a [[ShowTablesCommand]] logical plan. + * Example SQL : + * {{{ + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; + * }}} + */ + override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { + ShowTablesCommand( + Option(ctx.db).map(_.getText), + Option(ctx.pattern).map(string)) + } + + /** + * Create a [[ShowDatabasesCommand]] logical plan. + * Example SQL: + * {{{ + * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; + * }}} + */ + override def visitShowDatabases(ctx: ShowDatabasesContext): LogicalPlan = withOrigin(ctx) { + ShowDatabasesCommand(Option(ctx.pattern).map(string)) + } + + /** + * A command for users to list the properties for a table. If propertyKey is specified, the value + * for the propertyKey is returned. If propertyKey is not specified, all the keys and their + * corresponding values are returned. + * The syntax of using this command in SQL is: + * {{{ + * SHOW TBLPROPERTIES table_name[('propertyKey')]; + * }}} + */ + override def visitShowTblProperties( + ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { + ShowTablePropertiesCommand( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.key).map(visitTablePropertyKey)) + } + + /** + * Create a [[RefreshTable]] logical plan. + */ + override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { + RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) + } + + /** + * Create a [[CacheTableCommand]] logical plan. + */ + override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { + val query = Option(ctx.query).map(plan) + CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null) + } + + /** + * Create an [[UncacheTableCommand]] logical plan. + */ + override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { + UncacheTableCommand(ctx.identifier.getText) + } + + /** + * Create a [[ClearCacheCommand]] logical plan. + */ + override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { + ClearCacheCommand + } + + /** + * Create an [[ExplainCommand]] logical plan. + */ + override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { + val options = ctx.explainOption.asScala + if (options.exists(_.FORMATTED != null)) { + logWarning("Unsupported operation: EXPLAIN FORMATTED option") + } + + // Create the explain comment. + val statement = plan(ctx.statement) + if (isExplainableStatement(statement)) { + ExplainCommand(statement, extended = options.exists(_.EXTENDED != null), + codegen = options.exists(_.CODEGEN != null)) + } else { + ExplainCommand(OneRowRelation) + } + } + + /** + * Determine if a plan should be explained at all. + */ + protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match { + case _: datasources.DescribeCommand => false + case _ => true + } + + /** + * Create a [[DescribeCommand]] logical plan. + */ + override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { + // FORMATTED and columns are not supported. Return null and let the parser decide what to do + // with this (create an exception or pass it on to a different system). + if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) { + null + } else { + datasources.DescribeCommand( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXTENDED != null) + } + } + + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ + type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. + */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + assert(!temporary || !ifNotExists, + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.", + ctx) + (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan. + * + * TODO add bucketing and partitioning. + */ + override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + if (external) { + throw new ParseException("Unsupported operation: EXTERNAL option", ctx) + } + val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val provider = ctx.tableProvider.qualifiedName.getText + + if (ctx.query != null) { + // Get the backing query. + val query = plan(ctx.query) + + // Determine the storage mode. + val mode = if (ifNotExists) { + SaveMode.Ignore + } else if (temp) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query) + } else { + val struct = Option(ctx.colTypeList).map(createStructType) + CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false) + } + } + + /** + * Convert a table property list into a key-value map. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + ctx.tableProperty.asScala.map { property => + val key = visitTablePropertyKey(property.key) + val value = Option(property.value).map(string).orNull + key -> value + }.toMap + } + + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + + /** + * Create a [[CreateDatabase]] command. + * + * For example: + * {{{ + * CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] + * [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)] + * }}} + */ + override def visitCreateDatabase(ctx: CreateDatabaseContext): LogicalPlan = withOrigin(ctx) { + CreateDatabase( + ctx.identifier.getText, + ctx.EXISTS != null, + Option(ctx.locationSpec).map(visitLocationSpec), + Option(ctx.comment).map(string), + Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) + } + + /** + * Create an [[AlterDatabaseProperties]] command. + * + * For example: + * {{{ + * ALTER (DATABASE|SCHEMA) database SET DBPROPERTIES (property_name=property_value, ...); + * }}} + */ + override def visitSetDatabaseProperties( + ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { + AlterDatabaseProperties( + ctx.identifier.getText, + visitTablePropertyList(ctx.tablePropertyList)) + } + + /** + * Create a [[DropDatabase]] command. + * + * For example: + * {{{ + * DROP (DATABASE|SCHEMA) [IF EXISTS] database [RESTRICT|CASCADE]; + * }}} + */ + override def visitDropDatabase(ctx: DropDatabaseContext): LogicalPlan = withOrigin(ctx) { + DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE != null) + } + + /** + * Create a [[DescribeDatabase]] command. + * + * For example: + * {{{ + * DESCRIBE DATABASE [EXTENDED] database; + * }}} + */ + override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) { + DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null) + } + + /** + * Create a [[CreateFunction]] command. + * + * For example: + * {{{ + * CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name + * [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']]; + * }}} + */ + override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { + val resources = ctx.resource.asScala.map { resource => + val resourceType = resource.identifier.getText.toLowerCase + resourceType match { + case "jar" | "file" | "archive" => + resourceType -> string(resource.STRING) + case other => + throw new ParseException(s"Resource Type '$resourceType' is not supported.", ctx) + } + } + + // Extract database, name & alias. + val (database, function) = visitFunctionName(ctx.qualifiedName) + CreateFunction( + database, + function, + string(ctx.className), + resources, + ctx.TEMPORARY != null) + } + + /** + * Create a [[DropFunction]] command. + * + * For example: + * {{{ + * DROP [TEMPORARY] FUNCTION [IF EXISTS] function; + * }}} + */ + override def visitDropFunction(ctx: DropFunctionContext): LogicalPlan = withOrigin(ctx) { + val (database, function) = visitFunctionName(ctx.qualifiedName) + DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null) + } + + /** + * Create a function database (optional) and name pair. + */ + private def visitFunctionName(ctx: QualifiedNameContext): (Option[String], String) = { + ctx.identifier().asScala.map(_.getText) match { + case Seq(db, fn) => (Option(db), fn) + case Seq(fn) => (None, fn) + case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx) + } + } + + /** + * Create a [[DropTable]] command. + */ + override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { + if (ctx.PURGE != null) { + throw new ParseException("Unsupported operation: PURGE option", ctx) + } + if (ctx.REPLICATION != null) { + throw new ParseException("Unsupported operation: REPLICATION clause", ctx) + } + DropTable( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXISTS != null, + ctx.VIEW != null) + } + + /** + * Create a [[AlterTableRename]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; + * }}} + */ + override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { + AlterTableRename( + visitTableIdentifier(ctx.from), + visitTableIdentifier(ctx.to), + ctx.VIEW != null) + } + + /** + * Create an [[AlterTableSetProperties]] command. + * + * For example: + * {{{ + * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); + * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); + * }}} + */ + override def visitSetTableProperties( + ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + AlterTableSetProperties( + visitTableIdentifier(ctx.tableIdentifier), + visitTablePropertyList(ctx.tablePropertyList), + ctx.VIEW != null) + } + + /** + * Create an [[AlterTableUnsetProperties]] command. + * + * For example: + * {{{ + * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * }}} + */ + override def visitUnsetTableProperties( + ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + AlterTableUnsetProperties( + visitTableIdentifier(ctx.tableIdentifier), + visitTablePropertyList(ctx.tablePropertyList).keys.toSeq, + ctx.EXISTS != null, + ctx.VIEW != null) + } + + /** + * Create an [[AlterTableSerDeProperties]] command. + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props]; + * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties; + * }}} + */ + override def visitSetTableSerDe(ctx: SetTableSerDeContext): LogicalPlan = withOrigin(ctx) { + AlterTableSerDeProperties( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.STRING).map(string), + Option(ctx.tablePropertyList).map(visitTablePropertyList), + // TODO a partition spec is allowed to have optional values. This is currently violated. + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) + } + + // TODO: don't even bother parsing alter table commands related to bucketing and skewing + + override def visitBucketTable(ctx: BucketTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... CLUSTERED BY ... INTO N BUCKETS") + } + + override def visitUnclusterTable(ctx: UnclusterTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT CLUSTERED") + } + + override def visitUnsortTable(ctx: UnsortTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SORTED") + } + + override def visitSkewTable(ctx: SkewTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... SKEWED BY ...") + } + + override def visitUnskewTable(ctx: UnskewTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SKEWED") + } + + override def visitUnstoreTable(ctx: UnstoreTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... NOT STORED AS DIRECTORIES") + } + + override def visitSetTableSkewLocations( + ctx: SetTableSkewLocationsContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... SET SKEWED LOCATION ...") + } + + /** + * Create an [[AlterTableAddPartition]] command. + * + * For example: + * {{{ + * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] + * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec + * }}} + * + * ALTER VIEW ... ADD PARTITION ... is not supported because the concept of partitioning + * is associated with physical tables + */ + override def visitAddTablePartition( + ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } + // Create partition spec to location mapping. + val specsAndLocs = if (ctx.partitionSpec.isEmpty) { + ctx.partitionSpecLocation.asScala.map { + splCtx => + val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec) + val location = Option(splCtx.locationSpec).map(visitLocationSpec) + spec -> location + } + } else { + // Alter View: the location clauses are not allowed. + ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec(_) -> None) + } + AlterTableAddPartition( + visitTableIdentifier(ctx.tableIdentifier), + specsAndLocs, + ctx.EXISTS != null) + } + + /** + * Create an [[AlterTableExchangePartition]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 EXCHANGE PARTITION spec WITH TABLE table2; + * }}} + */ + override def visitExchangeTablePartition( + ctx: ExchangeTablePartitionContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... EXCHANGE PARTITION ...") + } + + /** + * Create an [[AlterTableRenamePartition]] command + * + * For example: + * {{{ + * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2; + * }}} + */ + override def visitRenameTablePartition( + ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { + AlterTableRenamePartition( + visitTableIdentifier(ctx.tableIdentifier), + visitNonOptionalPartitionSpec(ctx.from), + visitNonOptionalPartitionSpec(ctx.to)) + } + + /** + * Create an [[AlterTableDropPartition]] command + * + * For example: + * {{{ + * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; + * ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]; + * }}} + * + * ALTER VIEW ... DROP PARTITION ... is not supported because the concept of partitioning + * is associated with physical tables + */ + override def visitDropTablePartitions( + ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } + if (ctx.PURGE != null) { + throw new AnalysisException(s"Operation not allowed: PURGE") + } + AlterTableDropPartition( + visitTableIdentifier(ctx.tableIdentifier), + ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), + ctx.EXISTS != null) + } + + /** + * Create an [[AlterTableArchivePartition]] command + * + * For example: + * {{{ + * ALTER TABLE table ARCHIVE PARTITION spec; + * }}} + */ + override def visitArchiveTablePartition( + ctx: ArchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... ARCHIVE PARTITION ...") + } + + /** + * Create an [[AlterTableUnarchivePartition]] command + * + * For example: + * {{{ + * ALTER TABLE table UNARCHIVE PARTITION spec; + * }}} + */ + override def visitUnarchiveTablePartition( + ctx: UnarchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... UNARCHIVE PARTITION ...") + } + + /** + * Create an [[AlterTableSetFileFormat]] command + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] SET FILEFORMAT file_format; + * }}} + */ + override def visitSetTableFileFormat( + ctx: SetTableFileFormatContext): LogicalPlan = withOrigin(ctx) { + // AlterTableSetFileFormat currently takes both a GenericFileFormat and a + // TableFileFormatContext. This is a bit weird because it should only take one. It also should + // use a CatalogFileFormat instead of either a String or a Sequence of Strings. We will address + // this in a follow-up PR. + val (fileFormat, genericFormat) = ctx.fileFormat match { + case s: GenericFileFormatContext => + (Seq.empty[String], Option(s.identifier.getText)) + case s: TableFileFormatContext => + val elements = Seq(s.inFmt, s.outFmt) ++ Option(s.serdeCls).toSeq + (elements.map(string), None) + } + AlterTableSetFileFormat( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + fileFormat, + genericFormat)( + command(ctx)) + } + + /** + * Create an [[AlterTableSetLocation]] command + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] SET LOCATION "loc"; + * }}} + */ + override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { + AlterTableSetLocation( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + visitLocationSpec(ctx.locationSpec)) + } + + /** + * Create an [[AlterTableTouch]] command + * + * For example: + * {{{ + * ALTER TABLE table TOUCH [PARTITION spec]; + * }}} + */ + override def visitTouchTable(ctx: TouchTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... TOUCH ...") + } + + /** + * Create an [[AlterTableCompact]] command + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] COMPACT 'compaction_type'; + * }}} + */ + override def visitCompactTable(ctx: CompactTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... COMPACT ...") + } + + /** + * Create an [[AlterTableMerge]] command + * + * For example: + * {{{ + * ALTER TABLE table [PARTITION spec] CONCATENATE; + * }}} + */ + override def visitConcatenateTable(ctx: ConcatenateTableContext): LogicalPlan = withOrigin(ctx) { + throw new AnalysisException("Operation not allowed: ALTER TABLE ... CONCATENATE") + } + + /** + * Create an [[AlterTableChangeCol]] command + * + * For example: + * {{{ + * ALTER TABLE tableIdentifier [PARTITION spec] + * CHANGE [COLUMN] col_old_name col_new_name column_type [COMMENT col_comment] + * [FIRST|AFTER column_name] [CASCADE|RESTRICT]; + * }}} + */ + override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = withOrigin(ctx) { + val col = visitColType(ctx.colType()) + val comment = if (col.metadata.contains("comment")) { + Option(col.metadata.getString("comment")) + } else { + None + } + + AlterTableChangeCol( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + ctx.oldName.getText, + // We could also pass in a struct field - seems easier. + col.name, + col.dataType, + comment, + Option(ctx.after).map(_.getText), + // Note that Restrict and Cascade are mutually exclusive. + ctx.RESTRICT != null, + ctx.CASCADE != null)( + command(ctx)) + } + + /** + * Create an [[AlterTableAddCol]] command + * + * For example: + * {{{ + * ALTER TABLE tableIdentifier [PARTITION spec] + * ADD COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT] + * }}} + */ + override def visitAddColumns(ctx: AddColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableAddCol( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + createStructType(ctx.colTypeList), + // Note that Restrict and Cascade are mutually exclusive. + ctx.RESTRICT != null, + ctx.CASCADE != null)( + command(ctx)) + } + + /** + * Create an [[AlterTableReplaceCol]] command + * + * For example: + * {{{ + * ALTER TABLE tableIdentifier [PARTITION spec] + * REPLACE COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT] + * }}} + */ + override def visitReplaceColumns(ctx: ReplaceColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableReplaceCol( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + createStructType(ctx.colTypeList), + // Note that Restrict and Cascade are mutually exclusive. + ctx.RESTRICT != null, + ctx.CASCADE != null)( + command(ctx)) + } + + /** + * Create location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create a [[BucketSpec]]. + */ + override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { + BucketSpec( + ctx.INTEGER_VALUE.getText.toInt, + visitIdentifierList(ctx.identifierList), + Option(ctx.orderedIdentifierList).toSeq + .flatMap(_.orderedIdentifier.asScala) + .map(_.identifier.getText)) + } + + /** + * Convert a nested constants list into a sequence of string sequences. + */ + override def visitNestedConstantList( + ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) { + ctx.constantList.asScala.map(visitConstantList) + } + + /** + * Convert a constants list into a String sequence. + */ + override def visitConstantList(ctx: ConstantListContext): Seq[String] = withOrigin(ctx) { + ctx.constant.asScala.map(visitStringConstant) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index b19ad4f1c563e..c590f7c6c3e8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -22,18 +22,15 @@ import java.util.{HashMap => JavaHashMap} import scala.reflect.ClassTag -import com.clearspring.analytics.stream.cardinality.HyperLogLog -import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} +import com.esotericsoftware.kryo.io.{Input, Output} import com.twitter.chill.ResourcePool +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair -import org.apache.spark.util.collection.OpenHashSet -import org.apache.spark.{SparkConf, SparkEnv} private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { @@ -43,16 +40,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) - kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], - new HyperLogLogSerializer) kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer) kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer) - // Specific hashsets must come first TODO: Move to core. - kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer) - kryo.register(classOf[LongHashSet], new LongHashSetSerializer) - kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], - new OpenHashSetSerializer) kryo.register(classOf[Decimal]) kryo.register(classOf[JavaHashMap[_, _]]) @@ -62,7 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co } private[execution] class KryoResourcePool(size: Int) - extends ResourcePool[SerializerInstance](size) { + extends ResourcePool[SerializerInstance](size) { val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -86,7 +76,7 @@ private[sql] object SparkSqlSerializer { def serialize[T: ClassTag](o: T): Array[Byte] = acquireRelease { k => - k.serialize(o).array() + JavaUtils.bufferToArray(k.serialize(o)) } def deserialize[T: ClassTag](bytes: Array[Byte]): T = @@ -116,92 +106,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] { new java.math.BigDecimal(input.readString()) } } - -private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] { - def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) { - val bytes = hyperLogLog.getBytes() - output.writeInt(bytes.length) - output.writeBytes(bytes) - } - - def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = { - val length = input.readInt() - val bytes = input.readBytes(length) - HyperLogLog.Builder.build(bytes) - } -} - -private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { - def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) { - val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val row = iterator.next() - rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = { - val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] - val numItems = input.readInt() - val set = new OpenHashSet[Any](numItems + 1) - var i = 0 - while (i < numItems) { - val row = - new GenericInternalRow(rowSerializer.read( - kryo, - input, - classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) - set.add(row) - i += 1 - } - set - } -} - -private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] { - def write(kryo: Kryo, output: Output, hs: IntegerHashSet) { - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val value: Int = iterator.next() - output.writeInt(value) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = { - val numItems = input.readInt() - val set = new IntegerHashSet - var i = 0 - while (i < numItems) { - val value = input.readInt() - set.add(value) - i += 1 - } - set - } -} - -private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] { - def write(kryo: Kryo, output: Output, hs: LongHashSet) { - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val value = iterator.next() - output.writeLong(value) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = { - val numItems = input.readInt() - val set = new LongHashSet - var i = 0 - while (i < numItems) { - val value = input.readLong() - set.add(value) - i += 1 - } - set - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f4464e0b916f8..c15aaed3654ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,34 +17,62 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.{Strategy, execution} +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _} +import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _} +import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.execution.streaming.MemoryPlan +import org.apache.spark.sql.internal.SQLConf private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => - object LeftSemiJoin extends Strategy with PredicateHelper { + /** + * Plans special cases of limit operators. + */ + object SpecialLimits extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.ReturnAnswer(rootPlan) => rootPlan match { + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => + execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil + case logical.Limit( + IntegerLiteral(limit), + logical.Project(projectList, logical.Sort(order, true, child))) => + execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil + case logical.Limit(IntegerLiteral(limit), child) => + execution.CollectLimit(limit, planLater(child)) :: Nil + case other => planLater(other) :: Nil + } + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => + execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil + case logical.Limit( + IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) => + execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil + case _ => Nil + } + } + + object ExistenceJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys( - LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - joins.BroadcastLeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil + LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) // Find left semi joins where at least some predicates can be evaluated by matching join keys - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - joins.LeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil - // no predicate can be evaluated by matching hash keys - case logical.Join(left, right, LeftSemi, condition) => - joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil + case ExtractEquiJoinKeys( + LeftExistence(jt), leftKeys, rightKeys, condition, left, right) => + Seq(joins.ShuffledHashJoin( + leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) case _ => Nil } } @@ -53,11 +81,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ object CanBroadcast { - def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { - case BroadcastHint(p) => Some(p) - case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p) - case _ => None + def unapply(plan: LogicalPlan): Option[LogicalPlan] = { + if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) { + Some(plan) + } else { + None + } } } @@ -68,28 +97,48 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Join implementations are chosen with the following precedence: * * - Broadcast: if one side of the join has an estimated physical size that is smaller than the - * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold * or if that side has an explicit broadcast hint (e.g. the user applied the * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side * of the join will be broadcasted and the other side will be streamed, with no shuffling * performed. If both sides of the join are eligible to be broadcasted then the - * - Sort merge: if the matching join keys are sortable and - * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join - * will be used. - * - Hash: will be chosen if neither of the above optimizations apply to this join. + * - Shuffle hash join: if the average size of a single partition is small enough to build a hash + * table. + * - Sort merge: if the matching join keys are sortable. */ object EquiJoinSelection extends Strategy with PredicateHelper { - private[this] def makeBroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: LogicalPlan, - right: LogicalPlan, - condition: Option[Expression], - side: joins.BuildSide): Seq[SparkPlan] = { - val broadcastHashJoin = execution.joins.BroadcastHashJoin( - leftKeys, rightKeys, side, planLater(left), planLater(right)) - condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil + /** + * Matches a plan whose single partition should be small enough to build a hash table. + * + * Note: this assume that the number of partition is fixed, requires additional work if it's + * dynamic. + */ + def canBuildHashMap(plan: LogicalPlan): Boolean = { + plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + } + + /** + * Returns whether plan a is much smaller (3X) than plan b. + * + * The cost to build hash map is higher than sorting, we should only build hash map on a table + * that is much smaller than other one. Since we does not have the statistic for number of rows, + * use the size of bytes here as estimation. + */ + private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { + a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes + } + + /** + * Returns whether we should use shuffle hash join or not. + * + * We should only use shuffle hash join when: + * 1) any single partition of a small table could fit in memory. + * 2) the smaller table is much smaller (3X) than the other one. + */ + private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = { + canBuildHashMap(left) && muchSmaller(left, right) || + canBuildHashMap(right) && muchSmaller(right, left) } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -97,47 +146,57 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Inner joins -------------------------------------------------------------------------- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - val mergeJoin = - joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) - condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil - - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => + if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) || + !RowOrdering.isOrderable(leftKeys) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight + BuildRight } else { - joins.BuildLeft + BuildLeft } - val hashJoin = joins.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) - condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + Seq(joins.ShuffledHashJoin( + leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right))) + + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if RowOrdering.isOrderable(leftKeys) => + joins.SortMergeJoin( + leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right)) :: Nil // --- Outer joins -------------------------------------------------------------------------- case ExtractEquiJoinKeys( LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys( RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right))) - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => - joins.SortMergeOuterJoin( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) || + !RowOrdering.isOrderable(leftKeys) => + Seq(joins.ShuffledHashJoin( + leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right))) + + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) || + !RowOrdering.isOrderable(leftKeys) => + Seq(joins.ShuffledHashJoin( + leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right))) - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joins.ShuffledHashOuterJoin( + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + if RowOrdering.isOrderable(leftKeys) => + joins.SortMergeJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil // --- Cases where this strategy does not apply --------------------------------------------- @@ -146,41 +205,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object HashAggregation extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Aggregations that can be performed in two phases, before and after the shuffle. - case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) if !canBeConvertedToNewAggregation(plan) => - execution.Aggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - execution.Aggregate( - partial = true, - groupingExpressions, - partialComputation, - planLater(child))) :: Nil + /** + * Used to plan aggregation queries that are computed incrementally as part of a + * [[org.apache.spark.sql.ContinuousQuery]]. Currently this rule is injected into the planner + * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] + */ + object StatefulAggregationStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalAggregation( + namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - case _ => Nil - } + aggregate.Utils.planStreamingAggregation( + namedGroupingExpressions, + aggregateExpressions, + rewrittenResultExpressions, + planLater(child)) - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { - case a: logical.Aggregate => - if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { - a.newAggregation.isDefined - } else { - Utils.checkInvalidAggregateFunction2(a) - false - } - case _ => false + case _ => Nil } - - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = - exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } /** @@ -188,106 +230,46 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && - sqlContext.conf.codegenEnabled => - val converted = p.newAggregation - converted match { - case None => Nil // Cannot convert to new aggregation code path. - case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => - // A single aggregate expression might appear multiple times in resultExpressions. - // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.distinct - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. - val aggregateFunctionToAttribute = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> attribute - }.toMap - - val (functionsWithDistinct, functionsWithoutDistinct) = - aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - // This is a sanity check. We should not reach here when we have multiple distinct - // column sets (aggregate.NewAggregation will not match). - sys.error( - "Multiple distinct column sets are not supported by the new aggregation" + - "code path.") - } + case PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) => + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets. Our MultipleDistinctRewriter should take care this case. + sys.error("You hit a query analyzer bug. Please report your query to " + + "Spark user mailing list.") + } - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - - // The original `resultExpressions` are a set of expressions which may reference - // aggregate expressions, grouping column values, and constants. When aggregate operator - // emits output rows, we will use `resultExpressions` to generate an output projection - // which takes the grouping columns and final aggregate result buffer as input. - // Thus, we must re-write the result expressions so that their attributes match up with - // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case AggregateExpression2(aggregateFunction, _, isDistinct) => - // The final aggregation buffer's attributes will be `finalAggregationAttributes`, - // so replace each aggregate expression by its corresponding attribute in the set: - aggregateFunctionToAttribute(aggregateFunction, isDistinct) - case expression => - // Since we're using `namedGroupingAttributes` to extract the grouping key - // columns, we need to replace grouping key expressions with their corresponding - // attributes. We do not rely on the equality check at here since attributes may - // differ cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + val aggregateOperator = + if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { + if (functionsWithDistinct.nonEmpty) { + sys.error("Distinct columns cannot exist in Aggregate operator containing " + + "aggregate functions which don't support partial aggregation.") + } else { + aggregate.Utils.planAggregateWithoutPartial( + groupingExpressions, + aggregateExpressions, + resultExpressions, + planLater(child)) } + } else if (functionsWithDistinct.isEmpty) { + aggregate.Utils.planAggregateWithoutDistinct( + groupingExpressions, + aggregateExpressions, + resultExpressions, + planLater(child)) + } else { + aggregate.Utils.planAggregateWithOneDistinct( + groupingExpressions, + functionsWithDistinct, + functionsWithoutDistinct, + resultExpressions, + planLater(child)) + } - val aggregateOperator = - if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { - if (functionsWithDistinct.nonEmpty) { - sys.error("Distinct columns cannot exist in Aggregate operator containing " + - "aggregate functions which don't support partial aggregation.") - } else { - aggregate.Utils.planAggregateWithoutPartial( - namedGroupingExpressions.map(_._2), - aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } - } else if (functionsWithDistinct.isEmpty) { - aggregate.Utils.planAggregateWithoutDistinct( - namedGroupingExpressions.map(_._2), - aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } else { - aggregate.Utils.planAggregateWithOneDistinct( - namedGroupingExpressions.map(_._2), - functionsWithDistinct, - functionsWithoutDistinct, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } - - aggregateOperator - } + aggregateOperator case _ => Nil } @@ -295,22 +277,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join( - CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi => + case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) => execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil - case logical.Join( - left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi => + planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil + case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) => execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil + planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // TODO CartesianProduct doesn't support the Left Semi Join - case logical.Join(left, right, joinType, None) if joinType != LeftSemi => + case logical.Join(left, right, Inner, None) => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, @@ -328,6 +307,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } else { joins.BuildLeft } + // This join could be very slow or even hang forever joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil @@ -336,18 +316,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) - object TakeOrderedAndProject extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil - case logical.Limit( - IntegerLiteral(limit), - logical.Project(projectList, logical.Sort(order, true, child))) => - execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil - case _ => Nil - } - } - object InMemoryScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => @@ -364,103 +332,81 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions - /** - * Picks an appropriate sort operator. - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ - def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - TungstenSort.supportsSchema(child.schema)) { - execution.TungstenSort(sortExprs, global, child) - } else { - execution.Sort(sortExprs, global, child) - } - } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil + case MemoryPlan(sink, output) => + val encoder = RowEncoder(sink.schema) + LocalTableScan(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil + case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") - - case logical.MapPartitions(f, tEnc, uEnc, output, child) => - execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil - case logical.AppendColumn(f, tEnc, uEnc, newCol, child) => - execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil - case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => - execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil - case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, - leftGroup, rightGroup, left, right) => - execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup, + case logical.Intersect(left, right) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by semi-join in the optimizer") + + case logical.DeserializeToObject(deserializer, child) => + execution.DeserializeToObject(deserializer, planLater(child)) :: Nil + case logical.SerializeFromObject(serializer, child) => + execution.SerializeFromObject(serializer, planLater(child)) :: Nil + case logical.MapPartitions(f, in, out, child) => + execution.MapPartitions(f, in, out, planLater(child)) :: Nil + case logical.MapElements(f, in, out, child) => + execution.MapElements(f, in, out, planLater(child)) :: Nil + case logical.AppendColumns(f, in, out, child) => + execution.AppendColumns(f, in, out, planLater(child)) :: Nil + case logical.MapGroups(f, key, in, out, grouping, data, child) => + execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil + case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) => + execution.CoGroup( + f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, planLater(left), planLater(right)) :: Nil case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { - execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil + ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil } else { execution.Coalesce(numPartitions, planLater(child)) :: Nil } case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. - getSortOperator(sortExprs, global = false, planLater(child)) :: Nil + execution.Sort(sortExprs, global = false, child = planLater(child)) :: Nil case logical.Sort(sortExprs, global, child) => - getSortOperator(sortExprs, global, planLater(child)):: Nil + execution.Sort(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => - // If unsafe mode is enabled and we support these data types in Unsafe, use the - // Tungsten project. Otherwise, use the normal project. - if (sqlContext.conf.unsafeEnabled && - UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { - execution.TungstenProject(projectList, planLater(child)) :: Nil - } else { - execution.Project(projectList, planLater(child)) :: Nil - } + execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil - case e @ logical.Expand(_, _, _, child) => + case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled - if (useNewAggregation && a.newAggregation.isDefined) { - // If this logical.Aggregate can be planned to use new aggregation code path - // (i.e. it can be planned by the Strategy Aggregation), we will not use the old - // aggregation code path. - Nil - } else { - Utils.checkInvalidAggregateFunction2(a) - execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil - } - } - case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => - execution.Window( - projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + case logical.Window(windowExprs, partitionSpec, orderSpec, child) => + execution.Window(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScan(output, data) :: Nil - case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child)) :: Nil - case Unions(unionChildren) => + case logical.LocalLimit(IntegerLiteral(limit), child) => + execution.LocalLimit(limit, planLater(child)) :: Nil + case logical.GlobalLimit(IntegerLiteral(limit), child) => + execution.GlobalLimit(limit, planLater(child)) :: Nil + case logical.Union(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil - case logical.Intersect(left, right) => - execution.Intersect(planLater(left), planLater(right)) :: Nil case g @ logical.Generate(generator, join, outer, _, _, child) => execution.Generate( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil + case r @ logical.Range(start, end, step, numSlices, output) => + execution.Range(start, step, numSlices, r.numElements, output) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => - execution.Exchange(HashPartitioning( + exchange.ShuffleExchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil - case e @ EvaluatePython(udf, child, _) => - BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil - case BroadcastHint(child) => apply(child) + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil + case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil } } @@ -476,23 +422,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) - if partitionsCols.nonEmpty => + case c: CreateTableUsingAsSelect if c.temporary && c.partitionColumns.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => + case c: CreateTableUsingAsSelect if c.temporary => val cmd = CreateTempTableUsingAsSelect( - tableIdent, provider, Array.empty[String], mode, opts, query) + c.tableIdent, c.provider, Array.empty[String], c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") case describe @ LogicalDescribeCommand(table, isExtended) => - val resultPlan = self.sqlContext.executePlan(table).executedPlan - ExecutedCommand( - RunnableDescribeCommand(resultPlan, describe.output, isExtended)) :: Nil + ExecutedCommand(RunnableDescribeCommand(table, describe.output, isExtended)) :: Nil - case logical.ShowFunctions(db, pattern) => ExecutedCommand(ShowFunctions(db, pattern)) :: Nil + case logical.ShowFunctions(db, pattern) => + ExecutedCommand(ShowFunctions(db, pattern)) :: Nil case logical.DescribeFunction(function, extended) => ExecutedCommand(DescribeFunction(function, extended)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 7e981268de392..a23ebec95333b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import com.google.common.io.ByteStreams -import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform @@ -94,7 +94,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) - private[this] var row: UnsafeRow = new UnsafeRow() + private[this] var row: UnsafeRow = new UnsafeRow(numFields) private[this] var rowTuple: (Int, UnsafeRow) = (0, row) private[this] val EOF: Int = -1 @@ -117,7 +117,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) rowSize = readSize() if (rowSize == EOF) { // We are returning the last row in this stream dIn.close() @@ -152,7 +152,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala new file mode 100644 index 0000000000000..29acc38ab3584 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{broadcast, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.toCommentSafeString +import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * An interface for those physical operators that support codegen. + */ +trait CodegenSupport extends SparkPlan { + + /** Prefix used in the current operator's variable names. */ + private def variablePrefix: String = this match { + case _: TungstenAggregate => "agg" + case _: BroadcastHashJoin => "bhj" + case _: SortMergeJoin => "smj" + case _: PhysicalRDD => "rdd" + case _: DataSourceScan => "scan" + case _ => nodeName.toLowerCase + } + + /** + * Creates a metric using the specified name. + * + * @return name of the variable representing the metric + */ + def metricTerm(ctx: CodegenContext, name: String): String = { + val metric = ctx.addReferenceObj(name, longMetric(name)) + val value = ctx.freshName("metricValue") + val cls = classOf[LongSQLMetricValue].getName + ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();") + value + } + + /** + * Whether this SparkPlan support whole stage codegen or not. + */ + def supportCodegen: Boolean = true + + /** + * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan. + */ + protected var parent: CodegenSupport = null + + /** + * Returns all the RDDs of InternalRow which generates the input rows. + * + * Note: right now we support up to two RDDs. + */ + def upstreams(): Seq[RDD[InternalRow]] + + /** + * Returns Java source code to process the rows from upstream. + */ + final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + this.parent = parent + ctx.freshNamePrefix = variablePrefix + waitForSubqueries() + s""" + |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */ + |${doProduce(ctx)} + """.stripMargin + } + + /** + * Generate the Java source code to process, should be overridden by subclass to support codegen. + * + * doProduce() usually generate the framework, for example, aggregation could generate this: + * + * if (!initialized) { + * # create a hash map, then build the aggregation hash map + * # call child.produce() + * initialized = true; + * } + * while (hashmap.hasNext()) { + * row = hashmap.next(); + * # build the aggregation results + * # create variables for results + * # call consume(), which will call parent.doConsume() + * if (shouldStop()) return; + * } + */ + protected def doProduce(ctx: CodegenContext): String + + /** + * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume(). + */ + final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = { + val inputVars = + if (row != null) { + ctx.currentVars = null + ctx.INPUT_ROW = row + output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + } else { + assert(outputVars != null) + assert(outputVars.length == output.length) + // outputVars will be used to generate the code for UnsafeRow, so we should copy them + outputVars.map(_.copy()) + } + + val rowVar = if (row != null) { + ExprCode("", "false", row) + } else { + if (outputVars.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + val evaluateInputs = evaluateVariables(outputVars) + // generate the code to create a UnsafeRow + ctx.currentVars = outputVars + val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + val code = s""" + |$evaluateInputs + |${ev.code.trim} + """.stripMargin.trim + ExprCode(code, "false", ev.value) + } else { + // There is no columns + ExprCode("", "false", "unsafeRow") + } + } + + ctx.freshNamePrefix = parent.variablePrefix + val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) + s""" + | + |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */ + |$evaluated + |${parent.doConsume(ctx, inputVars, rowVar)} + """.stripMargin + } + + /** + * Returns source code to evaluate all the variables, and clear the code of them, to prevent + * them to be evaluated twice. + */ + protected def evaluateVariables(variables: Seq[ExprCode]): String = { + val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") + variables.foreach(_.code = "") + evaluate + } + + /** + * Returns source code to evaluate the variables for required attributes, and clear the code + * of evaluated variables, to prevent them to be evaluated twice. + */ + protected def evaluateRequiredVariables( + attributes: Seq[Attribute], + variables: Seq[ExprCode], + required: AttributeSet): String = { + val evaluateVars = new StringBuilder + variables.zipWithIndex.foreach { case (ev, i) => + if (ev.code != "" && required.contains(attributes(i))) { + evaluateVars.append(ev.code.trim + "\n") + ev.code = "" + } + } + evaluateVars.toString() + } + + /** + * The subset of inputSet those should be evaluated before this plan. + * + * We will use this to insert some code to access those columns that are actually used by current + * plan before calling doConsume(). + */ + def usedInputs: AttributeSet = references + + /** + * Generate the Java source code to process the rows from child SparkPlan. + * + * This should be override by subclass to support codegen. + * + * For example, Filter will generate the code like this: + * + * # code to evaluate the predicate expression, result is isNull1 and value2 + * if (isNull1 || !value2) continue; + * # call consume(), which will call parent.doConsume() + * + * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). + */ + def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + throw new UnsupportedOperationException + } +} + + +/** + * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. + * + * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes + * an RDD iterator of InternalRow. + */ +case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport { + + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def doExecute(): RDD[InternalRow] = { + child.execute() + } + + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + child.doExecuteBroadcast() + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.execute() :: Nil + } + + override def doProduce(ctx: CodegenContext): String = { + val input = ctx.freshName("input") + // Right now, InputAdapter is only used when there is one upstream. + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val row = ctx.freshName("row") + s""" + | while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | ${consume(ctx, null, row).trim} + | if (shouldStop()) return; + | } + """.stripMargin + } + + override def simpleString: String = "INPUT" + + override def treeChildren: Seq[SparkPlan] = Nil +} + +object WholeStageCodegen { + val PIPELINE_DURATION_METRIC = "duration" +} + +/** + * WholeStageCodegen compile a subtree of plans that support codegen together into single Java + * function. + * + * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not): + * + * WholeStageCodegen Plan A FakeInput Plan B + * ========================================================================= + * + * -> execute() + * | + * doExecute() ---------> upstreams() -------> upstreams() ------> execute() + * | + * +-----------------> produce() + * | + * doProduce() -------> produce() + * | + * doProduce() + * | + * doConsume() <--------- consume() + * | + * doConsume() <-------- consume() + * + * SparkPlan A should override doProduce() and doConsume(). + * + * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, + * used to generated code for BoundReference. + */ +case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport { + + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override private[sql] lazy val metrics = Map( + "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, + WholeStageCodegen.PIPELINE_DURATION_METRIC)) + + /** + * Generates code for this subtree. + * + * @return the tuple of the codegen context and the actual generated source. + */ + def doCodeGen(): (CodegenContext, String) = { + val ctx = new CodegenContext + val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) + val source = s""" + public Object generate(Object[] references) { + return new GeneratedIterator(references); + } + + /** Codegened pipeline for: + * ${toCommentSafeString(child.treeString.trim)} + */ + final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + + private Object[] references; + ${ctx.declareMutableStates()} + + public GeneratedIterator(Object[] references) { + this.references = references; + } + + public void init(int index, scala.collection.Iterator inputs[]) { + partitionIndex = index; + ${ctx.initMutableStates()} + } + + ${ctx.declareAddedFunctions()} + + protected void processNext() throws java.io.IOException { + ${code.trim} + } + } + """.trim + + // try to compile, helpful for debug + val cleanedSource = CodeFormatter.stripExtraNewLines(source) + logDebug(s"\n${CodeFormatter.format(cleanedSource)}") + CodeGenerator.compile(cleanedSource) + (ctx, cleanedSource) + } + + override def doExecute(): RDD[InternalRow] = { + val (ctx, cleanedSource) = doCodeGen() + val references = ctx.references.toArray + + val durationMs = longMetric("pipelineTime") + + val rdds = child.asInstanceOf[CodegenSupport].upstreams() + assert(rdds.size <= 2, "Up to two upstream RDDs can be supported") + if (rdds.length == 1) { + rdds.head.mapPartitionsWithIndex { (index, iter) => + val clazz = CodeGenerator.compile(cleanedSource) + val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] + buffer.init(index, Array(iter)) + new Iterator[InternalRow] { + override def hasNext: Boolean = { + val v = buffer.hasNext + if (!v) durationMs += buffer.durationMs() + v + } + override def next: InternalRow = buffer.next() + } + } + } else { + // Right now, we support up to two upstreams. + rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) => + val partitionIndex = TaskContext.getPartitionId() + val clazz = CodeGenerator.compile(cleanedSource) + val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] + buffer.init(partitionIndex, Array(leftIter, rightIter)) + new Iterator[InternalRow] { + override def hasNext: Boolean = { + val v = buffer.hasNext + if (!v) durationMs += buffer.durationMs() + v + } + override def next: InternalRow = buffer.next() + } + } + } + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + throw new UnsupportedOperationException + } + + override def doProduce(ctx: CodegenContext): String = { + throw new UnsupportedOperationException + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val doCopy = if (ctx.copyResult) { + ".copy()" + } else { + "" + } + s""" + |${row.code} + |append(${row.value}$doCopy); + """.stripMargin.trim + } + + override def innerChildren: Seq[SparkPlan] = { + child :: Nil + } + + private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match { + case InputAdapter(c) => c :: Nil + case other => other.children.flatMap(collectInputs) + } + + override def treeChildren: Seq[SparkPlan] = { + collectInputs(child) + } + + override def simpleString: String = "WholeStageCodegen" +} + + +/** + * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. + */ +case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { + + private def supportCodegen(e: Expression): Boolean = e match { + case e: LeafExpression => true + case e: CaseWhen => e.shouldCodegen + // CodegenFallback requires the input to be an InternalRow + case e: CodegenFallback => false + case _ => true + } + + private def numOfNestedFields(dataType: DataType): Int = dataType match { + case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum + case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) + case a: ArrayType => numOfNestedFields(a.elementType) + case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) + case _ => 1 + } + + private def supportCodegen(plan: SparkPlan): Boolean = plan match { + case plan: CodegenSupport if plan.supportCodegen => + val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) + // the generated code will be huge if there are too many columns + val hasTooManyOutputFields = + numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields + val hasTooManyInputFields = + plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields) + !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields + case _ => false + } + + /** + * Inserts a InputAdapter on top of those that do not support codegen. + */ + private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match { + case j @ SortMergeJoin(_, _, _, _, left, right) if j.supportCodegen => + // The children of SortMergeJoin should do codegen separately. + j.copy(left = InputAdapter(insertWholeStageCodegen(left)), + right = InputAdapter(insertWholeStageCodegen(right))) + case p if !supportCodegen(p) => + // collapse them recursively + InputAdapter(insertWholeStageCodegen(p)) + case p => + p.withNewChildren(p.children.map(insertInputAdapter)) + } + + /** + * Inserts a WholeStageCodegen on top of those that support codegen. + */ + private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match { + case plan: CodegenSupport if supportCodegen(plan) => + WholeStageCodegen(insertInputAdapter(plan)) + case other => + other.withNewChildren(other.children.map(insertWholeStageCodegen)) + } + + def apply(plan: SparkPlan): SparkPlan = { + if (conf.wholeStageEnabled) { + insertWholeStageCodegen(plan) + } else { + plan + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 53c5ccf8fa37e..8e9214fa258b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution +import java.util + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType -import org.apache.spark.rdd.RDD -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -42,6 +49,8 @@ import org.apache.spark.util.collection.CompactBuffer * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame * and we add some rows to the frame. Examples are: * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * - Offset frame: The frame consist of one row, which is an offset number of rows away from the + * current row. Only [[OffsetWindowFunction]]s can be processed in an offset frame. * * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame * boundary can be either Row or Range based: @@ -72,14 +81,14 @@ import org.apache.spark.util.collection.CompactBuffer * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]]. */ case class Window( - projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute) + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = { if (partitionSpec.isEmpty) { @@ -95,8 +104,6 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def canProcessUnsafeRows: Boolean = true - /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. @@ -122,12 +129,10 @@ case class Window( // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() // Flip the sign of the offset when processing the order is descending - val boundOffset = - if (sortExpr.direction == Descending) { - -offset - } else { - offset - } + val boundOffset = sortExpr.direction match { + case Descending => -offset + case Ascending => offset + } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) val bound = newMutableProjection(boundExpr :: Nil, child.output)() @@ -149,43 +154,108 @@ case class Window( } /** - * Create a frame processor. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame boundaries. - * @param functions to process in the frame. - * @param ordinal at which the processor starts writing to the output. - * @return a frame processor. + * Collection containing an entry for each window frame to process. Each entry contains a frames' + * WindowExpressions and factory function for the WindowFrameFunction. */ - private[this] def createFrameProcessor( - frame: WindowFrame, - functions: Array[WindowFunction], - ordinal: Int): WindowFunctionFrame = frame match { - // Growing Frame. - case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => - val uBoundOrdering = createBoundOrdering(frameType, high) - new UnboundedPrecedingWindowFunctionFrame(ordinal, functions, uBoundOrdering) - - // Shrinking Frame. - case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => - val lBoundOrdering = createBoundOrdering(frameType, low) - new UnboundedFollowingWindowFunctionFrame(ordinal, functions, lBoundOrdering) - - // Moving Frame. - case SpecifiedWindowFrame(frameType, - FrameBoundaryExtractor(low), FrameBoundaryExtractor(high)) => - val lBoundOrdering = createBoundOrdering(frameType, low) - val uBoundOrdering = createBoundOrdering(frameType, high) - new SlidingWindowFunctionFrame(ordinal, functions, lBoundOrdering, uBoundOrdering) - - // Entire Partition Frame. - case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => - new UnboundedWindowFunctionFrame(ordinal, functions) - - // Error - case fr => - sys.error(s"Unsupported Frame $fr for functions: $functions") + private[this] lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add a function and its function to the map for a given frame. + def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es.append(e) + fns.append(fn) + } + + // Collect all valid window functions and group them by their frame. + windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to a (unbound) expression and frame factory pair. + var numExpressions = 0 + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + def processor = AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled)) + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + target: MutableRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + functions, + child.output, + (expressions, schema) => + newMutableProjection(expressions, schema, subexpressionEliminationEnabled), + offset) + + // Growing Frame. + case ("AGGREGATE", frameType, None, Some(high)) => + target: MutableRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, high)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, Some(low), None) => + target: MutableRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low)) + } + + // Moving Frame. + case ("AGGREGATE", frameType, Some(low), Some(high)) => + target: MutableRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low), + createBoundOrdering(frameType, high)) + } + + // Entire Partition Frame. + case ("AGGREGATE", frameType, None, None) => + target: MutableRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair. + (expressions, factory) + } } /** @@ -197,111 +267,120 @@ case class Window( * @return the final resulting projection. */ private[this] def createResultProjection( - expressions: Seq[Expression]): MutableProjection = { + expressions: Seq[Expression]): UnsafeProjection = { val references = expressions.zipWithIndex.map{ case (e, i) => // Results of window expressions will be on the right side of child's output BoundReference(child.output.size + i, e.dataType, e.nullable) } val unboundToRefMap = expressions.zip(references).toMap val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - newMutableProjection( - projectList ++ patchedWindowExpression, - child.output)() + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) } protected override def doExecute(): RDD[InternalRow] = { - // Prepare processing. - // Group the window expression by their processing frame. - val windowExprs = windowExpression.flatMap { - _.collect { - case e: WindowExpression => e - } - } - - // Create Frame processor factories and order the unbound window expressions by the frame they - // are processed in; this is the order in which their results will be written to window - // function result buffer. - val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) - val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) - val unboundExpressions = scala.collection.mutable.Buffer.empty[Expression] - framedWindowExprs.zipWithIndex.foreach { - case ((frame, unboundFrameExpressions), index) => - // Track the ordinal. - val ordinal = unboundExpressions.size - - // Track the unbound expressions - unboundExpressions ++= unboundFrameExpressions - - // Bind the expressions. - val functions = unboundFrameExpressions.map { e => - BindReferences.bindReference(e.windowFunction, child.output) - }.toArray - - // Create the frame processor factory. - factories(index) = () => createFrameProcessor(frame, functions, ordinal) - } + // Unwrap the expressions and factories from the map. + val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray // Start processing. child.execute().mapPartitions { stream => new Iterator[InternalRow] { // Get all relevant projections. - val result = createResultProjection(unboundExpressions) - val grouping = if (child.outputsUnsafeRows) { - UnsafeProjection.create(partitionSpec, child.output) - } else { - newProjection(partitionSpec, child.output) - } + val result = createResultProjection(expressions) + val grouping = UnsafeProjection.create(partitionSpec, child.output) // Manage the stream and the grouping. - var nextRow: InternalRow = EmptyRow - var nextGroup: InternalRow = EmptyRow + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null var nextRowAvailable: Boolean = false private[this] def fetchNextRow() { nextRowAvailable = stream.hasNext if (nextRowAvailable) { - nextRow = stream.next() + nextRow = stream.next().asInstanceOf[UnsafeRow] nextGroup = grouping(nextRow) } else { - nextRow = EmptyRow - nextGroup = EmptyRow + nextRow = null + nextGroup = null } } fetchNextRow() // Manage the current partition. - var rows: CompactBuffer[InternalRow] = _ - val frames: Array[WindowFunctionFrame] = factories.map(_()) + val rows = ArrayBuffer.empty[UnsafeRow] + val inputFields = child.output.length + var sorter: UnsafeExternalSorter = null + var rowBuffer: RowBuffer = null + val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. // Before we start to fetch new input rows, make a copy of nextGroup. val currentGroup = nextGroup.copy() - rows = new CompactBuffer + + // clear last partition + if (sorter != null) { + // the last sorter of this task will be cleaned up via task completion listener + sorter.cleanupResources() + sorter = null + } else { + rows.clear() + } + while (nextRowAvailable && nextGroup == currentGroup) { - rows += nextRow.copy() + if (sorter == null) { + rows += nextRow.copy() + + if (rows.length >= 4096) { + // We will not sort the rows, so prefixComparator and recordComparator are null. + sorter = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes) + rows.foreach { r => + sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0) + } + rows.clear() + } + } else { + sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, + nextRow.getSizeInBytes, 0) + } fetchNextRow() } + if (sorter != null) { + rowBuffer = new ExternalRowBuffer(sorter, inputFields) + } else { + rowBuffer = new ArrayRowBuffer(rows) + } // Setup the frames. var i = 0 while (i < numFrames) { - frames(i).prepare(rows) + frames(i).prepare(rowBuffer.copy()) i += 1 } // Setup iteration rowIndex = 0 - rowsSize = rows.size + rowsSize = rowBuffer.size() } // Iteration var rowIndex = 0 - var rowsSize = 0 + var rowsSize = 0L + override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable val join = new JoinedRow - val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) override final def next(): InternalRow = { // Load the next partition if we need to. if (rowIndex >= rowsSize && nextRowAvailable) { @@ -311,13 +390,14 @@ case class Window( if (rowIndex < rowsSize) { // Get the results for the window frames. var i = 0 + val current = rowBuffer.next() while (i < numFrames) { - frames(i).write(windowFunctionResult) + frames(i).write(rowIndex, current) i += 1 } // 'Merge' the input row with the window function result - join(rows(rowIndex), windowFunctionResult) + join(current, windowFunctionResult) rowIndex += 1 // Return the projection. @@ -333,14 +413,18 @@ case class Window( * Function for comparing boundary values. */ private[execution] abstract class BoundOrdering { - def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int + def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int } /** * Compare the input index to the bound of the output index. */ private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { - override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = inputIndex - (outputIndex + offset) } @@ -351,148 +435,198 @@ private[execution] final case class RangeBoundOrdering( ordering: Ordering[InternalRow], current: Projection, bound: Projection) extends BoundOrdering { - override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = - ordering.compare(current(input(inputIndex)), bound(input(outputIndex))) + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = + ordering.compare(current(inputRow), bound(outputRow)) } /** - * A window function calculates the results of a number of window functions for a window frame. - * Before use a frame must be prepared by passing it all the rows in the current partition. After - * preparation the update method can be called to fill the output rows. - * - * TODO How to improve performance? A few thoughts: - * - Window functions are expensive due to its distribution and ordering requirements. - * Unfortunately it is up to the Spark engine to solve this. Improvements in the form of project - * Tungsten are on the way. - * - The window frame processing bit can be improved though. But before we start doing that we - * need to see how much of the time and resources are spent on partitioning and ordering, and - * how much time and resources are spent processing the partitions. There are a couple ways to - * improve on the current situation: - * - Reduce memory footprint by performing streaming calculations. This can only be done when - * there are no Unbound/Unbounded Following calculations present. - * - Use Tungsten style memory usage. - * - Use code generation in general, and use the approach to aggregation taken in the - * GeneratedAggregate class in specific. - * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * The interface of row buffer for a partition */ -private[execution] abstract class WindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) { +private[execution] abstract class RowBuffer { - // Make sure functions are initialized. - functions.foreach(_.init()) + /** Number of rows. */ + def size(): Int - /** Number of columns the window function frame is managing */ - val numColumns = functions.length + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow - /** - * Create a fresh thread safe copy of the frame. - * - * @return the copied frame. - */ - def copy: WindowFunctionFrame + /** Skip the next `n` rows. */ + def skip(n: Int): Unit - /** - * Create new instances of the functions. - * - * @return an array containing copies of the current window functions. - */ - protected final def copyFunctions: Array[WindowFunction] = functions.map(_.newInstance()) + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer +} - /** - * Prepare the frame for calculating the results for a partition. - * - * @param rows to calculate the frame results for. - */ - def prepare(rows: CompactBuffer[InternalRow]): Unit +/** + * A row buffer based on ArrayBuffer (the number of rows is limited) + */ +private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { - /** - * Write the result for the current row to the given target row. - * - * @param target row to write the result for the current row to. - */ - def write(target: GenericMutableRow): Unit + private[this] var cursor: Int = -1 - /** Reset the current window functions. */ - protected final def reset(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).reset() - i += 1 + /** Number of rows. */ + def size(): Int = buffer.length + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow = { + cursor += 1 + if (cursor < buffer.length) { + buffer(cursor) + } else { + null } } - /** Prepare an input row for processing. */ - protected final def prepare(input: InternalRow): Array[AnyRef] = { - val prepared = new Array[AnyRef](numColumns) - var i = 0 - while (i < numColumns) { - prepared(i) = functions(i).prepareInputParameters(input) - i += 1 - } - prepared + /** Skip the next `n` rows. */ + def skip(n: Int): Unit = { + cursor += n } - /** Evaluate a prepared buffer (iterator). */ - protected final def evaluatePrepared(iterator: java.util.Iterator[Array[AnyRef]]): Unit = { - reset() - while (iterator.hasNext) { - val prepared = iterator.next() - var i = 0 - while (i < numColumns) { - functions(i).update(prepared(i)) - i += 1 - } + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer = { + new ArrayRowBuffer(buffer) + } +} + +/** + * An external buffer of rows based on UnsafeExternalSorter + */ +private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) + extends RowBuffer { + + private[this] val iter: UnsafeSorterIterator = sorter.getIterator + + private[this] val currentRow = new UnsafeRow(numFields) + + /** Number of rows. */ + def size(): Int = iter.getNumRecords() + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow = { + if (iter.hasNext) { + iter.loadNext() + currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + currentRow + } else { + null } - evaluate() } - /** Evaluate a prepared buffer (array). */ - protected final def evaluatePrepared(prepared: Array[Array[AnyRef]], - fromIndex: Int, toIndex: Int): Unit = { + /** Skip the next `n` rows. */ + def skip(n: Int): Unit = { var i = 0 - while (i < numColumns) { - val function = functions(i) - function.reset() - var j = fromIndex - while (j < toIndex) { - function.update(prepared(j)(i)) - j += 1 - } - function.evaluate() + while (i < n && iter.hasNext) { + iter.loadNext() i += 1 } } - /** Update an array of window functions. */ - protected final def update(input: InternalRow): Unit = { - var i = 0 - while (i < numColumns) { - val aggregate = functions(i) - val preparedInput = aggregate.prepareInputParameters(input) - aggregate.update(preparedInput) - i += 1 + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer = { + new ExternalRowBuffer(sorter, numFields) + } +} + +/** + * A window function calculates the results of a number of window functions for a window frame. + * Before use a frame must be prepared by passing it all the rows in the current partition. After + * preparation the update method can be called to fill the output rows. + */ +private[execution] abstract class WindowFunctionFrame { + /** + * Prepare the frame for calculating the results for a partition. + * + * @param rows to calculate the frame results for. + */ + def prepare(rows: RowBuffer): Unit + + /** + * Write the current results to the target row. + */ + def write(index: Int, current: InternalRow): Unit +} + +/** + * The offset window frame calculates frames containing LEAD/LAG statements. + * + * @param target to write results to. + * @param expressions to shift a number of rows. + * @param inputSchema required for creating a projection. + * @param newMutableProjection function used to create the projection. + * @param offset by which rows get moved within a partition. + */ +private[execution] final class OffsetWindowFunctionFrame( + target: MutableRow, + ordinal: Int, + expressions: Array[Expression], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + offset: Int) extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** Index of the input row currently used for output. */ + private[this] var inputIndex = 0 + + /** Row used when there is no valid input. */ + private[this] val emptyRow = new GenericInternalRow(inputSchema.size) + + /** Row used to combine the offset and the current row. */ + private[this] val join = new JoinedRow + + /** Create the projection. */ + private[this] val projection = { + // Collect the expressions and bind them. + val inputAttrs = inputSchema.map(_.withNullability(true)) + val numInputAttributes = inputAttrs.size + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { + case e: OffsetWindowFunction => + val input = BindReferences.bindReference(e.input, inputAttrs) + if (e.default == null || e.default.foldable && e.default.eval() == null) { + // Without default value. + input + } else { + // With default value. + val default = BindReferences.bindReference(e.default, inputAttrs).transform { + // Shift the input reference to its default version. + case BoundReference(o, dataType, nullable) => + BoundReference(o + numInputAttributes, dataType, nullable) + } + org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) + } + case e => + BindReferences.bindReference(e, inputAttrs) } + + // Create the projection. + newMutableProjection(boundExpressions, Nil)().target(target) } - /** Evaluate the window functions. */ - protected final def evaluate(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).evaluate() - i += 1 + override def prepare(rows: RowBuffer): Unit = { + input = rows + // drain the first few rows if offset is larger than zero + inputIndex = 0 + while (inputIndex < offset) { + input.next() + inputIndex += 1 } + inputIndex = offset } - /** Fill a target row with the current window function results. */ - protected final def fill(target: GenericMutableRow, rowIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - target.update(ordinal + i, functions(i).get(rowIndex)) - i += 1 + override def write(index: Int, current: InternalRow): Unit = { + if (inputIndex >= 0 && inputIndex < input.size) { + val r = input.next() + join(r, current) + } else { + join(emptyRow, current) } + projection(join) + inputIndex += 1 } } @@ -500,78 +634,78 @@ private[execution] abstract class WindowFunctionFrame( * The sliding window frame calculates frames with the following SQL form: * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class SlidingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], + target: MutableRow, + processor: AggregateProcessor, lbound: BoundOrdering, - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: RowBuffer = null - /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ - private[this] var inputHighIndex = 0 + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null - /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ - private[this] var inputLowIndex = 0 + /** The rows within current sliding window. */ + private[this] val buffer = new util.ArrayDeque[InternalRow]() - /** Buffer used for storing prepared input for the window functions. */ - private[this] val buffer = new java.util.ArrayDeque[Array[AnyRef]] + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ + private[this] var inputHighIndex = 0 - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ + private[this] var inputLowIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: RowBuffer): Unit = { input = rows + nextRow = rows.next() inputHighIndex = 0 inputLowIndex = 0 - outputIndex = 0 buffer.clear() } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. - while (inputHighIndex < input.size && - ubound.compare(input, inputHighIndex, outputIndex) <= 0) { - buffer.offer(prepare(input(inputHighIndex))) + while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { + buffer.add(nextRow.copy()) + nextRow = input.next() inputHighIndex += 1 bufferUpdated = true } // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. - while (inputLowIndex < inputHighIndex && - lbound.compare(input, inputLowIndex, outputIndex) < 0) { - buffer.pop() + while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { + buffer.remove() inputLowIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer.iterator()) - fill(target, outputIndex) + processor.initialize(input.size) + val iter = buffer.iterator() + while (iter.hasNext) { + processor.update(iter.next()) + } + processor.evaluate(target) } - - // Move to the next row. - outputIndex += 1 } - - /** Copy the frame. */ - override def copy: SlidingWindowFunctionFrame = - new SlidingWindowFunctionFrame(ordinal, copyFunctions, lbound, ubound) } /** @@ -582,36 +716,30 @@ private[execution] final class SlidingWindowFunctionFrame( * Its results are the same for each and every row in the partition. This class can be seen as a * special case of a sliding window, but is optimized for the unbound case. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. */ private[execution] final class UnboundedWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) extends WindowFunctionFrame(ordinal, functions) { - - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 + target: MutableRow, + processor: AggregateProcessor) extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() - outputIndex = 0 - val iterator = rows.iterator - while (iterator.hasNext) { - update(iterator.next()) + override def prepare(rows: RowBuffer): Unit = { + val size = rows.size() + processor.initialize(size) + var i = 0 + while (i < size) { + processor.update(rows.next()) + i += 1 } - evaluate() } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - fill(target, outputIndex) - outputIndex += 1 + override def write(index: Int, current: InternalRow): Unit = { + // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate + // for each row. + processor.evaluate(target) } - - /** Copy the frame. */ - override def copy: UnboundedWindowFunctionFrame = - new UnboundedWindowFunctionFrame(ordinal, copyFunctions) } /** @@ -624,58 +752,53 @@ private[execution] final class UnboundedWindowFunctionFrame( * is not the case when there is no lower bound, given the additive nature of most aggregates * streaming updates and partial evaluation suffice and no buffering is needed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class UnboundedPrecedingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + target: MutableRow, + processor: AggregateProcessor, + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: RowBuffer = null - /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ - private[this] var inputIndex = 0 + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ + private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() + override def prepare(rows: RowBuffer): Unit = { input = rows + nextRow = rows.next() inputIndex = 0 - outputIndex = 0 + processor.initialize(input.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. - while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { - update(input(inputIndex)) + while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { + processor.update(nextRow) + nextRow = input.next() inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluate() - fill(target, outputIndex) + processor.evaluate(target) } - - // Move to the next row. - outputIndex += 1 } - - /** Copy the frame. */ - override def copy: UnboundedPrecedingWindowFunctionFrame = - new UnboundedPrecedingWindowFunctionFrame(ordinal, copyFunctions, ubound) } /** @@ -688,67 +811,197 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( * * This is a very expensive operator to use, O(n * (n - 1) /2), because we need to maintain a * buffer and must do full recalculation after each row. Reverse iteration would be possible, if - * the communitativity of the used window functions can be guaranteed. + * the commutativity of the used window functions can be guaranteed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. */ private[execution] final class UnboundedFollowingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - lbound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { - - /** Buffer used for storing prepared input for the window functions. */ - private[this] var buffer: Array[Array[AnyRef]] = _ + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: RowBuffer = null - /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ private[this] var inputIndex = 0 - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: RowBuffer): Unit = { input = rows inputIndex = 0 - outputIndex = 0 - val size = input.size - buffer = Array.ofDim(size) - var i = 0 - while (i < size) { - buffer(i) = prepare(input(i)) - i += 1 - } - evaluatePrepared(buffer, 0, buffer.length) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 + + // Duplicate the input to have a new iterator + val tmp = input.copy() // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. - while (inputIndex < input.size && lbound.compare(input, inputIndex, outputIndex) < 0) { + tmp.skip(inputIndex) + var nextRow = tmp.next() + while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { + nextRow = tmp.next() inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer, inputIndex, buffer.length) - fill(target, outputIndex) + processor.initialize(input.size) + while (nextRow != null) { + processor.update(nextRow) + nextRow = tmp.next() + } + processor.evaluate(target) + } + } +} + +/** + * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way, + * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying + * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. + * + * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions + * require the size of the partition processed, this value is exposed to them when the processor is + * constructed. + * + * Processing of distinct aggregates is currently not supported. + * + * The implementation is split into an object which takes care of construction, and a the actual + * processor class. + */ +private[execution] object AggregateProcessor { + def apply( + functions: Array[Expression], + ordinal: Int, + inputAttributes: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + AggregateProcessor = { + val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] + val initialValues = mutable.Buffer.empty[Expression] + val updateExpressions = mutable.Buffer.empty[Expression] + val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) + val imperatives = mutable.Buffer.empty[ImperativeAggregate] + + // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then + // serialized to executor side. These functions all reference a global singleton window + // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect + // the singleton instance created on driver side instead of using executor side + // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID. + val partitionSize: Option[AttributeReference] = { + val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f }) + aggs.headOption.map(_.n) + } + + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to + // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. + partitionSize.foreach { n => + aggBufferAttributes += n + initialValues += NoOp + updateExpressions += NoOp + } + + // Add an AggregateFunction to the AggregateProcessor. + functions.foreach { + case agg: DeclarativeAggregate => + aggBufferAttributes ++= agg.aggBufferAttributes + initialValues ++= agg.initialValues + updateExpressions ++= agg.updateExpressions + evaluateExpressions += agg.evaluateExpression + case agg: ImperativeAggregate => + val offset = aggBufferAttributes.size + val imperative = BindReferences.bindReference(agg + .withNewInputAggBufferOffset(offset) + .withNewMutableAggBufferOffset(offset), + inputAttributes) + imperatives += imperative + aggBufferAttributes ++= imperative.aggBufferAttributes + val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) + initialValues ++= noOps + updateExpressions ++= noOps + evaluateExpressions += imperative + case other => + sys.error(s"Unsupported Aggregate Function: $other") } - // Move to the next row. - outputIndex += 1 + // Create the projections. + val initialProjection = newMutableProjection( + initialValues, + partitionSize.toSeq)() + val updateProjection = newMutableProjection( + updateExpressions, + aggBufferAttributes ++ inputAttributes)() + val evaluateProjection = newMutableProjection( + evaluateExpressions, + aggBufferAttributes)() + + // Create the processor + new AggregateProcessor( + aggBufferAttributes.toArray, + initialProjection, + updateProjection, + evaluateProjection, + imperatives.toArray, + partitionSize.isDefined) + } +} + +/** + * This class manages the processing of a number of aggregate functions. See the documentation of + * the object for more information. + */ +private[execution] final class AggregateProcessor( + private[this] val bufferSchema: Array[AttributeReference], + private[this] val initialProjection: MutableProjection, + private[this] val updateProjection: MutableProjection, + private[this] val evaluateProjection: MutableProjection, + private[this] val imperatives: Array[ImperativeAggregate], + private[this] val trackPartitionSize: Boolean) { + + private[this] val join = new JoinedRow + private[this] val numImperatives = imperatives.length + private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + initialProjection.target(buffer) + updateProjection.target(buffer) + + /** Create the initial state. */ + def initialize(size: Int): Unit = { + // Some initialization expressions are dependent on the partition size so we have to + // initialize the size before initializing all other fields, and we have to pass the buffer to + // the initialization projection. + if (trackPartitionSize) { + buffer.setInt(0, size) + } + initialProjection(buffer) + var i = 0 + while (i < numImperatives) { + imperatives(i).initialize(buffer) + i += 1 + } + } + + /** Update the buffer. */ + def update(input: InternalRow): Unit = { + updateProjection(join(buffer, input)) + var i = 0 + while (i < numImperatives) { + imperatives(i).update(buffer, input) + i += 1 + } } - /** Copy the frame. */ - override def copy: UnboundedFollowingWindowFunctionFrame = - new UnboundedFollowingWindowFunctionFrame(ordinal, copyFunctions, lbound) + /** Evaluate buffer. */ + def evaluate(target: MutableRow): Unit = + evaluateProjection.target(target)(buffer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 99fb7a40b72e1..042c7319018be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.Logging +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import scala.collection.mutable.ArrayBuffer - /** - * The base class of [[SortBasedAggregationIterator]]. + * The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]]. * It mainly contains two parts: * 1. It initializes aggregate functions. * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of @@ -33,64 +33,58 @@ import scala.collection.mutable.ArrayBuffer * is used to generate result. */ abstract class AggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + inputAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends Iterator[InternalRow] with Logging { + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection)) + extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// // Initializing functions. /////////////////////////////////////////////////////////////////////////// - // An Seq of all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - protected val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - require( - allAggregateExpressions.map(_.mode).distinct.length <= 2, - s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.") - /** - * The distinct modes of AggregateExpressions. Right now, we can handle the following mode: - * - Partial-only: all AggregateExpressions have the mode of Partial; - * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge); - * - Final-only: all AggregateExpressions have the mode of Final; - * - Final-Complete: some AggregateExpressions have the mode of Final and - * others have the mode of Complete; - * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions - * with mode Complete in completeAggregateExpressions; and - * - Grouping-only: there is no AggregateExpression. + * The following combinations of AggregationMode are supported: + * - Partial + * - PartialMerge (for single distinct) + * - Partial and PartialMerge (for single distinct) + * - Final + * - Complete (for SortBasedAggregate with functions that does not support Partial) + * - Final and Complete (currently not used) + * + * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression + * could have a flag to tell it's final or not. */ - protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption + { + val modes = aggregateExpressions.map(_.mode).distinct.toSet + require(modes.size <= 2, + s"$aggregateExpressions are not supported because they have more than 2 distinct modes.") + require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)), + s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.") + } // Initialize all AggregateFunctions by binding references if necessary, // and set inputBufferOffset and mutableBufferOffset. - protected val allAggregateFunctions: Array[AggregateFunction2] = { + protected def initializeAggregateFunctions( + expressions: Seq[AggregateExpression], + startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 - var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + var inputBufferOffset: Int = startingInputBufferOffset + val functions = new Array[AggregateFunction](expressions.length) var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match { + while (i < expressions.length) { + val func = expressions(i).aggregateFunction + val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of // this function is Partial or Complete because we will call eval of this // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, valueAttributes) + BindReferences.bindReference(func, inputAttributes) case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. @@ -117,15 +111,18 @@ abstract class AggregationIterator( functions } + protected val aggregateFunctions: Array[AggregateFunction] = + initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset) + // Positions of those imperative aggregate functions in allAggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and // func2 and func3 are imperative aggregate functions. // ImperativeAggregateFunctionPositions will be [1, 2]. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { + protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = { val positions = new ArrayBuffer[Int]() var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { case agg: DeclarativeAggregate => case _ => positions += i } @@ -134,17 +131,9 @@ abstract class AggregationIterator( positions.toArray } - // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - - // All imperative aggregate functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - // The projection used to initialize buffer values for all expression-based aggregates. - private[this] val expressionAggInitialProjection = { - val initExpressions = allAggregateFunctions.flatMap { + protected[this] val expressionAggInitialProjection = { + val initExpressions = aggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.initialValues // For the positions corresponding to imperative aggregate functions, we'll use special // no-op expressions which are ignored during projection code-generation. @@ -154,248 +143,112 @@ abstract class AggregationIterator( } // All imperative AggregateFunctions. - private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = allImperativeAggregateFunctionPositions - .map(allAggregateFunctions) + .map(aggregateFunctions) .map(_.asInstanceOf[ImperativeAggregate]) - /////////////////////////////////////////////////////////////////////////// - // Methods and fields used by sub-classes. - /////////////////////////////////////////////////////////////////////////// - // Initializing functions used to process a row. - protected val processRow: (MutableRow, InternalRow) => Unit = { - val rowToBeProcessed = new JoinedRow - val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. - expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 + protected def generateProcessRow( + expressions: Seq[AggregateExpression], + functions: Seq[AggregateFunction], + inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = { + val joinedRow = new JoinedRow + if (expressions.nonEmpty) { + val mergeExpressions = functions.zipWithIndex.flatMap { + case (ae: DeclarativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => ae.updateExpressions + case PartialMerge | Final => ae.mergeExpressions } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) { - // If initialInputBufferOffset, the input value does not contain - // grouping keys. - // This part is pretty hacky. - allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq - } else { - groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - } - // val inputAggregationBufferSchema = - // groupingKeyAttributes ++ - // allAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection( - mergeExpressions, - aggregationBufferSchema ++ inputAggregationBufferSchema)() - - (currentBuffer: MutableRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - // The first initialInputBufferOffset values of the input aggregation buffer is - // for grouping expressions and distinct columns. - val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - - val mergeInputSchema = - aggregationBufferSchema ++ - groupingAttributesAndDistinctColumns ++ - nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalExpressionAggMergeProjection = - newMutableProjection(mergeExpressions, mergeInputSchema)() - - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffers. - finalExpressionAggMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = - completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 + case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val updateFunctions = functions.zipWithIndex.collect { + case (ae: ImperativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => + (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row) + case PartialMerge | Final => + (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } + } + // This projection is used to merge buffer values for all expression-based aggregates. + val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) + val updateProjection = + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + // Process all expression-based aggregate functions. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. + var i = 0 + while (i < updateFunctions.length) { + updateFunctions(i)(currentBuffer, row) + i += 1 } - + } + } else { // Grouping only. - case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {} - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + (currentBuffer: MutableRow, row: InternalRow) => {} } } - // Initializing the function used to generate the output row. - protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { - val rowToBeEvaluated = new JoinedRow - val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType)) - val mutableOutput = if (outputsUnsafeRows) { - UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) - } else { - safeOutputRow - } - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not - // support generic getter), we create a mutable projection to output the - // JoinedRow(currentGroupingKey, currentBuffer) - val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes) - val resultProjection = - newMutableProjection( - groupingKeyAttributes ++ bufferSchema, - groupingKeyAttributes ++ bufferSchema)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer)) - // rowToBeEvaluated(currentGroupingKey, currentBuffer) - } - - // Final-only, Complete-only and Final-Complete: every output row contains values representing - // resultExpressions. - case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val bufferSchemata = - allAggregateFunctions.flatMap(_.aggBufferAttributes) - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - // TODO: Use unsafe row. - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - newMutableProjection( - resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) - } + protected val processRow: (MutableRow, InternalRow) => Unit = + generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) - // Grouping-only: we only output values of grouping expressions. - case (None, None) => - val resultProjection = - newMutableProjection(resultExpressions, groupingKeyAttributes)() - resultProjection.target(mutableOutput) + protected val groupingProjection: UnsafeProjection = + UnsafeProjection.create(groupingExpressions, inputAttributes) + protected val groupingAttributes = groupingExpressions.map(_.toAttribute) - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(currentGroupingKey) + // Initializing the function used to generate the output row. + protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val joinedRow = new JoinedRow + val modes = aggregateExpressions.map(_.mode).distinct + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + if (modes.contains(Final) || modes.contains(Complete)) { + val evalExpressions = aggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + case agg: AggregateFunction => NoOp + } + val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + expressionAggEvalProjection.target(aggregateResult) + + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + // Generate results for all expression-based aggregate functions. + expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 } - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + val resultProjection = UnsafeProjection.create( + groupingAttributes ++ bufferAttributes, + groupingAttributes ++ bufferAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + } + } else { + // Grouping-only: we only output values based on grouping expressions. + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(currentGroupingKey) + } } } + protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow = + generateResultProjection() + /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(buffer: MutableRow): Unit = { expressionAggInitialProjection.target(buffer)(EmptyRow) @@ -405,10 +258,4 @@ abstract class AggregationIterator( i += 1 } } - - /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. - */ - protected def newBuffer: MutableRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 4d37106e007f5..9fcfea8381ac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -22,31 +22,31 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.execution.metric.SQLMetrics case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - override private[sql] lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputsUnsafeRows: Boolean = false + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } - override def canProcessUnsafeRows: Boolean = false + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) - override def canProcessSafeRows: Boolean = true + override private[sql] lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -67,42 +67,32 @@ case class SortBasedAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. val hasInput = iter.hasNext if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. - Iterator[InternalRow]() + Iterator[UnsafeRow]() } else { - val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) { - UnsafeProjection.create(groupingExpressions, child.output) - } else { - newMutableProjection(groupingExpressions, child.output)() - } val outputIter = new SortBasedAggregationIterator( - groupingKeyProjection, - groupingExpressions.map(_.toAttribute), + groupingExpressions, child.output, iter, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, - outputsUnsafeRows, - numInputRows, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. numOutputRows += 1 - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter } @@ -111,7 +101,7 @@ case class SortBasedAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions val keyString = groupingExpressions.mkString("[", ",", "]") val functionString = allAggregateExpressions.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 64c673064f576..de1491d357405 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -19,42 +19,38 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.execution.metric.LongSQLMetric /** - * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been - * sorted by values of [[groupingKeyAttributes]]. + * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been + * sorted by values of [[groupingExpressions]]. */ class SortBasedAggregationIterator( - groupingKeyProjection: InternalRow => InternalRow, - groupingKeyAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean, - numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends AggregationIterator( - groupingKeyAttributes, + groupingExpressions, valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - override protected def newBuffer: MutableRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + newMutableProjection) { + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ + private def newBuffer: MutableRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) @@ -76,10 +72,10 @@ class SortBasedAggregationIterator( /////////////////////////////////////////////////////////////////////////// // The partition key of the current partition. - private[this] var currentGroupingKey: InternalRow = _ + private[this] var currentGroupingKey: UnsafeRow = _ // The partition key of next partition. - private[this] var nextGroupingKey: InternalRow = _ + private[this] var nextGroupingKey: UnsafeRow = _ // The first row of next partition. private[this] var firstRowInNextGroup: InternalRow = _ @@ -90,13 +86,16 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + // An SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be + // compared to MutableRow (aggregation buffer) directly. + private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) + protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) val inputRow = inputIterator.next() - nextGroupingKey = groupingKeyProjection(inputRow).copy() + nextGroupingKey = groupingProjection(inputRow).copy() firstRowInNextGroup = inputRow.copy() - numInputRows += 1 sortedInputHasNewGroup = true } else { // This inputIter is empty. @@ -113,19 +112,18 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) // The search will stop when we see the next group or there is no // input row left in the iter. while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. val currentRow = inputIterator.next() - val groupingKey = groupingKeyProjection(currentRow) - numInputRows += 1 + val groupingKey = groupingProjection(currentRow) // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, currentRow) + processRow(sortBasedAggregationBuffer, safeProj(currentRow)) } else { // We find a new group. findNextPartition = true @@ -146,7 +144,7 @@ class SortBasedAggregationIterator( override final def hasNext: Boolean = sortedInputHasNewGroup - override final def next(): InternalRow = { + override final def next(): UnsafeRow = { if (hasNext) { // Process the current group. processCurrentSortedGroup() @@ -162,8 +160,8 @@ class SortBasedAggregationIterator( } } - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { initializeBuffer(sortBasedAggregationBuffer) - generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 15616915f7364..f585759e583c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,68 +17,71 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{LongType, StructType} +import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryNode { + extends UnaryNode with CodegenSupport { private[this] val aggregateBufferAttributes = { - (nonCompleteAggregateExpressions ++ completeAggregateExpressions) - .flatMap(_.aggregateFunction.aggBufferAttributes) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) } - require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes)) + require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) + + override lazy val allAttributes: Seq[Attribute] = + child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) override private[sql] lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - override def outputsUnsafeRows: Boolean = true - - override def canProcessUnsafeRows: Boolean = true - - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } - // This is for testing. We force TungstenAggregationIterator to fall back to sort-based - // aggregation once it has processed a given number of input rows. - private val testFallbackStartsAt: Option[Int] = { + // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash + // map and/or the sort-based aggregation once it has processed a given number of input rows. + private val testFallbackStartsAt: Option[(Int, Int)] = { sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { case null | "" => None - case fallbackStartsAt => Some(fallbackStartsAt.toInt) + case fallbackStartsAt => + val splits = fallbackStartsAt.split(",").map(_.trim) + Some((splits.head.toInt, splits.last.toInt)) } } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") @@ -94,17 +97,15 @@ case class TungstenAggregate( val aggregationIterator = new TungstenAggregationIterator( groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), child.output, iter, testFallbackStartsAt, - numInputRows, numOutputRows, dataSize, spillSize) @@ -118,8 +119,590 @@ case class TungstenAggregate( } } + // all the mode of aggregate expressions + private val modes = aggregateExpressions.map(_.mode).distinct + + override def usedInputs: AttributeSet = inputSet + + override def supportCodegen: Boolean = { + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + if (groupingExpressions.isEmpty) { + doProduceWithoutKeys(ctx) + } else { + doProduceWithKeys(ctx) + } + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + if (groupingExpressions.isEmpty) { + doConsumeWithoutKeys(ctx, input) + } else { + doConsumeWithKeys(ctx, input) + } + } + + // The variables used as aggregation buffer + private var bufVars: Seq[ExprCode] = _ + + private def doProduceWithoutKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // generate variables for aggregation buffer + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val initExpr = functions.flatMap(f => f.initialValues) + bufVars = initExpr.map { e => + val isNull = ctx.freshName("bufIsNull") + val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") + // The initial expression should not access any column + val ev = e.gen(ctx) + val initVars = s""" + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + """.stripMargin + ExprCode(ev.code + initVars, isNull, value) + } + val initBufVar = evaluateVariables(bufVars) + + // generate variables for output + val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val aggResults = functions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx) + } + val evaluateAggResults = evaluateVariables(aggResults) + // evaluate result expressions + ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + } + (resultVars, s""" + |$evaluateAggResults + |${evaluateVariables(resultVars)} + """.stripMargin) + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // output the aggregate buffer directly + (bufVars, "") + } else { + // no aggregate function, the result should be literals + val resultVars = resultExpressions.map(_.gen(ctx)) + (resultVars, evaluateVariables(resultVars)) + } + + val doAgg = ctx.freshName("doAggregateWithoutKey") + ctx.addNewFunction(doAgg, + s""" + | private void $doAgg() throws java.io.IOException { + | // initialize aggregation buffer + | $initBufVar + | + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin) + + val numOutput = metricTerm(ctx, "numOutputRows") + s""" + | while (!$initAgg) { + | $initAgg = true; + | $doAgg(); + | + | // output the result + | ${genResult.trim} + | + | $numOutput.add(1); + | ${consume(ctx, resultVars).trim} + | } + """.stripMargin + } + + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // only have DeclarativeAggregate + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + ctx.currentVars = bufVars ++ input + // TODO: support subexpression elimination + val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) + // aggregate buffer should be updated atomic + val updates = aggVals.zipWithIndex.map { case (ev, i) => + s""" + | ${bufVars(i).isNull} = ${ev.isNull}; + | ${bufVars(i).value} = ${ev.value}; + """.stripMargin + } + s""" + | // do aggregate + | ${evaluateVariables(aggVals)} + | // update aggregation buffer + | ${updates.mkString("\n").trim} + """.stripMargin + } + + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val declFunctions = aggregateExpressions.map(_.aggregateFunction) + .filter(_.isInstanceOf[DeclarativeAggregate]) + .map(_.asInstanceOf[DeclarativeAggregate]) + private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + + // The name for Vectorized HashMap + private var vectorizedHashMapTerm: String = _ + + // We currently only enable vectorized hashmap for long key/value types and partial aggregates + private val isVectorizedHashMapEnabled: Boolean = sqlContext.conf.columnarAggregateMapEnabled && + (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) && + modes.forall(mode => mode == Partial || mode == PartialMerge) + + // The name for UnsafeRow HashMap + private var hashMapTerm: String = _ + private var sorterTerm: String = _ + + /** + * This is called by generated Java class, should be public. + */ + def createHashMap(): UnsafeFixedWidthAggregationMap = { + // create initialized aggregate buffer + val initExpr = declFunctions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + + // create hashMap + new UnsafeFixedWidthAggregationMap( + initialBuffer, + bufferSchema, + groupingKeySchema, + TaskContext.get().taskMemoryManager(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes, + false // disable tracking of performance metrics + ) + } + + /** + * This is called by generated Java class, should be public. + */ + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + /** + * Called by generated Java class to finish the aggregate and return a KVIterator. + */ + def finishAggregate( + hashMap: UnsafeFixedWidthAggregationMap, + sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { + + // update peak execution memory + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val peakMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() + metrics.incPeakExecutionMemory(peakMemory) + // TODO: update data size and spill size + + if (sorter == null) { + // not spilled + return hashMap.iterator() + } + + // merge the final hashMap into sorter + sorter.merge(hashMap.destructAndCreateExternalSorter()) + hashMap.free() + val sortedIter = sorter.sortedIterator() + + // Create a KVIterator based on the sorted iterator. + new KVIterator[UnsafeRow, UnsafeRow] { + + // Create a MutableProjection to merge the rows of same key together + val mergeExpr = declFunctions.flatMap(_.mergeExpressions) + val mergeProjection = newMutableProjection( + mergeExpr, + aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), + subexpressionEliminationEnabled)() + val joinedRow = new JoinedRow() + + var currentKey: UnsafeRow = null + var currentRow: UnsafeRow = null + var nextKey: UnsafeRow = if (sortedIter.next()) { + sortedIter.getKey + } else { + null + } + + override def next(): Boolean = { + if (nextKey != null) { + currentKey = nextKey.copy() + currentRow = sortedIter.getValue.copy() + nextKey = null + // use the first row as aggregate buffer + mergeProjection.target(currentRow) + + // merge the following rows with same key together + var findNextGroup = false + while (!findNextGroup && sortedIter.next()) { + val key = sortedIter.getKey + if (currentKey.equals(key)) { + mergeProjection(joinedRow(currentRow, sortedIter.getValue)) + } else { + // We find a new group. + findNextGroup = true + nextKey = key + } + } + + true + } else { + false + } + } + + override def getKey: UnsafeRow = currentKey + override def getValue: UnsafeRow = currentRow + override def close(): Unit = { + sortedIter.close() + } + } + } + + /** + * Generate the code for output. + */ + private def generateResultCode( + ctx: CodegenContext, + keyTerm: String, + bufferTerm: String, + plan: String): String = { + if (modes.contains(Final) || modes.contains(Complete)) { + // generate output using resultExpressions + ctx.currentVars = null + ctx.INPUT_ROW = keyTerm + val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + val evaluateKeyVars = evaluateVariables(keyVars) + ctx.INPUT_ROW = bufferTerm + val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + val evaluateBufferVars = evaluateVariables(bufferVars) + // evaluate the aggregation result + ctx.currentVars = bufferVars + val aggResults = declFunctions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx) + } + val evaluateAggResults = evaluateVariables(aggResults) + // generate the final result + ctx.currentVars = keyVars ++ aggResults + val inputAttrs = groupingAttributes ++ aggregateAttributes + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, inputAttrs).gen(ctx) + } + s""" + $evaluateKeyVars + $evaluateBufferVars + $evaluateAggResults + ${consume(ctx, resultVars)} + """ + + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // This should be the last operator in a stage, we should output UnsafeRow directly + val joinerTerm = ctx.freshName("unsafeRowJoiner") + ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, + s"$joinerTerm = $plan.createUnsafeJoiner();") + val resultRow = ctx.freshName("resultRow") + s""" + UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + ${consume(ctx, null, resultRow)} + """ + + } else { + // generate result based on grouping key + ctx.INPUT_ROW = keyTerm + ctx.currentVars = null + val eval = resultExpressions.map{ e => + BindReferences.bindReference(e, groupingAttributes).gen(ctx) + } + consume(ctx, eval) + } + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap") + val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap") + val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, vectorizedHashMapClassName, + groupingKeySchema, bufferSchema) + // Create a name for iterator from vectorized HashMap + val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter") + if (isVectorizedHashMapEnabled) { + ctx.addMutableState(vectorizedHashMapClassName, vectorizedHashMapTerm, + s"$vectorizedHashMapTerm = new $vectorizedHashMapClassName();") + ctx.addMutableState( + "java.util.Iterator", + iterTermForVectorizedHashMap, "") + } + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, "") + sorterTerm = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + + val doAgg = ctx.freshName("doAggregateWithKeys") + ctx.addNewFunction(doAgg, + s""" + ${if (isVectorizedHashMapEnabled) vectorizedHashMapGenerator.generate() else ""} + private void $doAgg() throws java.io.IOException { + $hashMapTerm = $thisPlan.createHashMap(); + ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + + ${if (isVectorizedHashMapEnabled) { + s"$iterTermForVectorizedHashMap = $vectorizedHashMapTerm.rowIterator();"} else ""} + + $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); + } + """) + + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) + val numOutput = metricTerm(ctx, "numOutputRows") + + // The child could change `copyResult` to true, but we had already consumed all the rows, + // so `copyResult` should be reset to `false`. + ctx.copyResult = false + + // Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow + def outputFromGeneratedMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + val row = ctx.freshName("vectorizedHashMapRow") + ctx.currentVars = null + ctx.INPUT_ROW = row + var schema: StructType = groupingKeySchema + bufferSchema.foreach(i => schema = schema.add(i)) + val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex + .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }) + Option( + s""" + | while ($iterTermForVectorizedHashMap.hasNext()) { + | $numOutput.add(1); + | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = + | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) + | $iterTermForVectorizedHashMap.next(); + | ${generateRow.code} + | ${consume(ctx, Seq.empty, {generateRow.value})} + | + | if (shouldStop()) return; + | } + | + | $vectorizedHashMapTerm.close(); + """.stripMargin) + } else None + } + + s""" + if (!$initAgg) { + $initAgg = true; + $doAgg(); + } + + // output the result + ${outputFromGeneratedMap.getOrElse("")} + + while ($iterTerm.next()) { + $numOutput.add(1); + UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + $outputCode + + if (shouldStop()) return; + } + + $iterTerm.close(); + if ($sorterTerm == null) { + $hashMapTerm.free(); + } + """ + } + + private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + + // create grouping key + ctx.currentVars = input + val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( + ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val vectorizedRowKeys = ctx.generateExpressions( + groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val unsafeRowKeys = unsafeRowKeyCode.value + val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") + val vectorizedRowBuffer = ctx.freshName("vectorizedAggBuffer") + + // only have DeclarativeAggregate + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + + // generate hash code for key + val hashExpr = Murmur3Hash(groupingExpressions, 42) + ctx.currentVars = input + val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx) + + val inputAttr = aggregateBufferAttributes ++ child.output + ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input + + val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, + incCounter) = if (testFallbackStartsAt.isDefined) { + val countTerm = ctx.freshName("fallbackCounter") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + (s"$countTerm < ${testFallbackStartsAt.get._1}", + s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") + } else { + ("true", "true", "", "") + } + + // We first generate code to probe and update the vectorized hash map. If the probe is + // successful the corresponding vectorized row buffer will hold the mutable row + val findOrInsertInVectorizedHashMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + Option( + s""" + |if ($checkFallbackForGeneratedHashMap) { + | ${vectorizedRowKeys.map(_.code).mkString("\n")} + | if (${vectorizedRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $vectorizedRowBuffer = $vectorizedHashMapTerm.findOrInsert( + | ${vectorizedRowKeys.map(_.value).mkString(", ")}); + | } + |} + """.stripMargin) + } else { + None + } + } + + val updateRowInVectorizedHashMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + ctx.INPUT_ROW = vectorizedRowBuffer + val vectorizedRowEvals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) + val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable) + } + Option( + s""" + |// evaluate aggregate function + |${evaluateVariables(vectorizedRowEvals)} + |// update vectorized row + |${updateVectorizedRow.mkString("\n").trim} + """.stripMargin) + } else None + } + + // Next, we generate code to probe and update the unsafe row hash map. + val findOrInsertInUnsafeRowMap: String = { + s""" + | if ($vectorizedRowBuffer == null) { + | // generate grouping key + | ${unsafeRowKeyCode.code.trim} + | ${hashEval.code.trim} + | if ($checkFallbackForBytesToBytesMap) { + | // try to get the buffer from hash map + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + | } + | if ($unsafeRowBuffer == null) { + | if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + | } else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + | } + | $resetCounter + | // the hash map had be spilled, it should have enough memory now, + | // try to allocate buffer again. + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + | if ($unsafeRowBuffer == null) { + | // failed to allocate the first page + | throw new OutOfMemoryError("No enough memory for aggregation"); + | } + | } + | } + """.stripMargin + } + + val updateRowInUnsafeRowMap: String = { + ctx.INPUT_ROW = unsafeRowBuffer + val unsafeRowBufferEvals = + updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) + val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + } + s""" + |// evaluate aggregate function + |${evaluateVariables(unsafeRowBufferEvals)} + |// update unsafe row buffer + |${updateUnsafeRowBuffer.mkString("\n").trim} + """.stripMargin + } + + + // We try to do hash map based in-memory aggregation first. If there is not enough memory (the + // hash map will return null for new key), we spill the hash map to disk to free memory, then + // continue to do in-memory aggregation and spilling until all the rows had been processed. + // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. + s""" + UnsafeRow $unsafeRowBuffer = null; + org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $vectorizedRowBuffer = null; + + ${findOrInsertInVectorizedHashMap.getOrElse("")} + + $findOrInsertInUnsafeRowMap + + $incCounter + + if ($vectorizedRowBuffer != null) { + // update vectorized row + ${updateRowInVectorizedHashMap.getOrElse("")} + } else { + // update unsafe row + $updateRowInUnsafeRowMap + } + """ + } + override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions testFallbackStartsAt match { case None => @@ -135,11 +718,8 @@ case class TungstenAggregate( } object TungstenAggregate { - def supportsAggregate( - groupingExpressions: Seq[Expression], - aggregateBufferAttributes: Seq[Attribute]): Boolean = { + def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupingExpressions) + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index ce8d592c368ee..09384a482d9fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,17 +17,16 @@ package org.apache.spark.sql.execution.aggregate -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, TaskContext} +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator /** * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. @@ -63,15 +62,11 @@ import org.apache.spark.sql.types.StructType * * @param groupingExpressions * expressions for grouping keys - * @param nonCompleteAggregateExpressions - * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], - * [[PartialMerge]], or [[Final]]. - * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' + * @param aggregateExpressions + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. + * @param aggregateAttributes the attributes of the aggregateExpressions' * outputs when they are stored in the final aggregation buffer. - * @param completeAggregateExpressions - * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. - * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs - * when they are stored in the final aggregation buffer. * @param resultExpressions * expressions for generating output rows. * @param newMutableProjection @@ -83,392 +78,73 @@ import org.apache.spark.sql.types.StructType */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - testFallbackStartsAt: Option[Int], - numInputRows: LongSQLMetric, + testFallbackStartsAt: Option[(Int, Int)], numOutputRows: LongSQLMetric, dataSize: LongSQLMetric, spillSize: LongSQLMetric) - extends Iterator[UnsafeRow] with Logging { + extends AggregationIterator( + groupingExpressions, + originalInputAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) with Logging { /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// - // A Seq containing all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - private[this] val allAggregateExpressions: Seq[AggregateExpression2] = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - // Check to make sure we do not have more than three modes in our AggregateExpressions. - // If we have, users are hitting a bug and we throw an IllegalStateException. - if (allAggregateExpressions.map(_.mode).distinct.length > 2) { - throw new IllegalStateException( - s"$allAggregateExpressions should have no more than 2 kinds of modes.") - } - // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled - // - // The modes of AggregateExpressions. Right now, we can handle the following mode: - // - Partial-only: - // All AggregateExpressions have the mode of Partial. - // For this case, aggregationMode is (Some(Partial), None). - // - PartialMerge-only: - // All AggregateExpressions have the mode of PartialMerge). - // For this case, aggregationMode is (Some(PartialMerge), None). - // - Final-only: - // All AggregateExpressions have the mode of Final. - // For this case, aggregationMode is (Some(Final), None). - // - Final-Complete: - // Some AggregateExpressions have the mode of Final and - // others have the mode of Complete. For this case, - // aggregationMode is (Some(Final), Some(Complete)). - // - Complete-only: - // nonCompleteAggregateExpressions is empty and we have AggregateExpressions - // with mode Complete in completeAggregateExpressions. For this case, - // aggregationMode is (None, Some(Complete)). - // - Grouping-only: - // There is no AggregateExpression. For this case, AggregationMode is (None,None). - // - private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption - } - - // Initialize all AggregateFunctions by binding references, if necessary, - // and setting inputBufferOffset and mutableBufferOffset. - private def initializeAllAggregateFunctions( - startingInputBufferOffset: Int): Array[AggregateFunction2] = { - var mutableBufferOffset = 0 - var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) - var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length - // We need to use this mode instead of func.mode in order to handle aggregation mode switching - // when switching to sort-based aggregation: - val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2 - val funcWithBoundReferences = mode match { - case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] => - // We need to create BoundReferences if the function is not an - // expression-based aggregate function (it does not support code-gen) and the mode of - // this function is Partial or Complete because we will call eval of this - // function's children in the update method of this aggregate function. - // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, originalInputAttributes) - case _ => - // We only need to set inputBufferOffset for aggregate functions with mode - // PartialMerge and Final. - val updatedFunc = func match { - case function: ImperativeAggregate => - function.withNewInputAggBufferOffset(inputBufferOffset) - case function => function - } - inputBufferOffset += func.aggBufferSchema.length - updatedFunc - } - val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { - case function: ImperativeAggregate => - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - function.withNewMutableAggBufferOffset(mutableBufferOffset) - case function => function - } - mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length - functions(i) = funcWithUpdatedAggBufferOffset - i += 1 - } - functions - } - - private[this] var allAggregateFunctions: Array[AggregateFunction2] = - initializeAllAggregateFunctions(initialInputBufferOffset) - - // Positions of those imperative aggregate functions in allAggregateFunctions. - // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are imperative aggregate functions. Then - // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be - // updated when falling back to sort-based aggregation because the positions of the aggregate - // functions do not change in that case. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { - val positions = new ArrayBuffer[Int]() - var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { - case agg: DeclarativeAggregate => - case _ => positions += i - } - i += 1 - } - positions.toArray - } - /////////////////////////////////////////////////////////////////////////// // Part 2: Methods and fields used by setting aggregation buffer values, // processing input rows from inputIter, and generating output // rows. /////////////////////////////////////////////////////////////////////////// - // The projection used to initialize buffer values for all expression-based aggregates. - // Note that this projection does not need to be updated when switching to sort-based aggregation - // because the schema of empty aggregation buffers does not change in that case. - private[this] val expressionAggInitialProjection: MutableProjection = { - val initExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.initialValues - // For the positions corresponding to imperative aggregate functions, we'll use special - // no-op expressions which are ignored during projection code-generation. - case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) - } - newMutableProjection(initExpressions, Nil)() - } - // Creates a new aggregation buffer and initializes buffer values. - // This function should be only called at most three times (when we create the hash map, - // when we switch to sort-based aggregation, and when we create the re-used buffer for - // sort-based aggregation). + // This function should be only called at most two times (when we create the hash map, + // and when we create the re-used buffer for sort-based aggregation). private def createNewAggregationBuffer(): UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) .apply(new GenericMutableRow(bufferSchema.length)) // Initialize declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initialize imperative aggregates' buffer values - allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) buffer } - // Creates a function used to process a row based on the given inputAttributes. - private def generateProcessRow( - inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { - - val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - val joinedRow = new JoinedRow() - - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. - expressionAggUpdateProjection(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val mergeExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val nonCompleteAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - val input = joinedRow(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffer values in row to - // currentBuffer. - finalMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // Grouping only. - case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {} - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") - } - } - // Creates a function used to generate output rows. - private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { - - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - unsafeRowJoiner.join(currentGroupingKey, currentBuffer) - } - - // Final-only, Complete-only and Final-Complete: a output row is generated based on - // resultExpressions. - case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val joinedRow = new JoinedRow() - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() - // These are the attributes of the row produced by `expressionAggEvalProjection` - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) - - val allImperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } - - // Grouping-only: a output row is generated from values of grouping expressions. - case (None, None) => - val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(currentGroupingKey) - } - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { + // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) + } + } else { + super.generateResultProjection() } } - // An UnsafeProjection used to extract grouping keys from the input rows. - private[this] val groupProjection = - UnsafeProjection.create(groupingExpressions, originalInputAttributes) - - // A function used to process a input row. Its first argument is the aggregation buffer - // and the second argument is the input row. - private[this] var processRow: (UnsafeRow, InternalRow) => Unit = - generateProcessRow(originalInputAttributes) - - // A function used to generate output rows based on the grouping keys (first argument) - // and the corresponding aggregation buffer (second argument). - private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = - generateResultProjection() - // An aggregation buffer containing initial buffer values. It is used to // initialize other aggregation buffers. private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() @@ -482,7 +158,7 @@ class TungstenAggregationIterator( // all groups and their corresponding aggregation buffers for hash-based aggregation. private[this] val hashMap = new UnsafeFixedWidthAggregationMap( initialAggregationBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)), + StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), TaskContext.get().taskMemoryManager(), 1024 * 16, // initial capacity @@ -495,25 +171,23 @@ class TungstenAggregationIterator( // hashMap. If there is not enough memory, it will multiple hash-maps, spilling // after each becomes full then using sort to merge these spills, finally do sort // based aggregation. - private def processInputs(fallbackStartsAt: Int): Unit = { + private def processInputs(fallbackStartsAt: (Int, Int)): Unit = { if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. - val groupingKey = groupProjection.apply(null) + val groupingKey = groupingProjection.apply(null) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) while (inputIter.hasNext) { val newInput = inputIter.next() - numInputRows += 1 processRow(buffer, newInput) } } else { var i = 0 while (inputIter.hasNext) { val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) + val groupingKey = groupingProjection.apply(newInput) var buffer: UnsafeRow = null - if (i < fallbackStartsAt) { + if (i < fallbackStartsAt._2) { buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) } if (buffer == null) { @@ -565,25 +239,18 @@ class TungstenAggregationIterator( private def switchToSortBasedAggregation(): Unit = { logInfo("falling back to sort based aggregation.") - // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. - val newAggregationMode = aggregationMode match { - case (Some(Partial), None) => (Some(PartialMerge), None) - case (None, Some(Complete)) => (Some(Final), None) - case (Some(Final), Some(Complete)) => (Some(Final), None) + // Basically the value of the KVIterator returned by externalSorter + // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. + val newExpressions = aggregateExpressions.map { + case agg @ AggregateExpression(_, Partial, _, _) => + agg.copy(mode = PartialMerge) + case agg @ AggregateExpression(_, Complete, _, _) => + agg.copy(mode = Final) case other => other } - aggregationMode = newAggregationMode - - allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0) - - // Basically the value of the KVIterator returned by externalSorter - // will just aggregation buffer. At here, we use inputAggBufferAttributes. - val newInputAttributes: Seq[Attribute] = - allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - - // Set up new processRow and generateOutput. - processRow = generateProcessRow(newInputAttributes) - generateOutput = generateResultProjection() + val newFunctions = initializeAggregateFunctions(newExpressions, 0) + val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) + sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes) // Step 5: Get the sorted iterator from the externalSorter. sortedKVIterator = externalSorter.sortedIterator() @@ -632,6 +299,9 @@ class TungstenAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + // The function used to process rows in a group + private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null + // Processes rows in the current group. It will stop when it find a new group. private def processCurrentSortedGroup(): Unit = { // First, we need to copy nextGroupingKey to currentGroupingKey. @@ -640,7 +310,7 @@ class TungstenAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -655,16 +325,15 @@ class TungstenAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey.equals(groupingKey)) { - processRow(sortBasedAggregationBuffer, inputAggregationBuffer) + sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer) hasNext = sortedKVIterator.next() } else { // We find a new group. findNextPartition = true // copyFrom will fail when - nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy() - firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy() - + nextGroupingKey.copyFrom(groupingKey) + firstRowInNextGroup.copyFrom(inputAggregationBuffer) } } // We have not seen a new group. It means that there is no new row in the input @@ -683,7 +352,7 @@ class TungstenAggregationIterator( /** * Start processing input rows. */ - processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue)) + processInputs(testFallbackStartsAt.getOrElse((Int.MaxValue, Int.MaxValue))) // If we did not switch to sort-based aggregation in processInputs, // we pre-load the first key-value pair from the map (to make hasNext idempotent). @@ -747,10 +416,10 @@ class TungstenAggregationIterator( val mapMemory = hashMap.getPeakMemoryUsedBytes val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) val peakMemory = Math.max(mapMemory, sorterMemory) + val metrics = TaskContext.get().taskMetrics() dataSize += peakMemory - spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) + spillSize += metrics.memoryBytesSpilled - spillSizeBefore + metrics.incPeakExecutionMemory(peakMemory) } numOutputRows += 1 res diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala new file mode 100644 index 0000000000000..535e64cb34442 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import scala.language.existentials + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.types._ + +object TypedAggregateExpression { + def apply[BUF : Encoder, OUT : Encoder]( + aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { + val bufferEncoder = encoderFor[BUF] + // We will insert the deserializer and function call expression at the bottom of each serializer + // expression while executing `TypedAggregateExpression`, which means multiply serializer + // expressions will all evaluate the same sub-expression at bottom. To avoid the re-evaluating, + // here we always use one single serializer expression to serialize the buffer object into a + // single-field row, no matter whether the encoder is flat or not. We also need to update the + // deserializer to read in all fields from that single-field row. + // TODO: remove this trick after we have better integration of subexpression elimination and + // whole stage codegen. + val bufferSerializer = if (bufferEncoder.flat) { + bufferEncoder.namedExpressions.head + } else { + Alias(CreateStruct(bufferEncoder.serializer), "buffer")() + } + + val bufferDeserializer = if (bufferEncoder.flat) { + bufferEncoder.deserializer transformUp { + case b: BoundReference => bufferSerializer.toAttribute + } + } else { + bufferEncoder.deserializer transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal) + } + } + + val outputEncoder = encoderFor[OUT] + val outputType = if (outputEncoder.flat) { + outputEncoder.schema.head.dataType + } else { + outputEncoder.schema + } + + new TypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + bufferSerializer, + bufferDeserializer, + outputEncoder.serializer, + outputEncoder.deserializer.dataType, + outputType) + } +} + +/** + * A helper class to hook [[Aggregator]] into the aggregation system. + */ +case class TypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + inputDeserializer: Option[Expression], + bufferSerializer: NamedExpression, + bufferDeserializer: Expression, + outputSerializer: Seq[Expression], + outputExternalType: DataType, + dataType: DataType) extends DeclarativeAggregate with NonSQLExpression { + + override def nullable: Boolean = true + + override def deterministic: Boolean = true + + override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer + + override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved + + override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq) + + override def inputTypes: Seq[AbstractDataType] = Nil + + private def aggregatorLiteral = + Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]])) + + private def bufferExternalType = bufferDeserializer.dataType + + override lazy val aggBufferAttributes: Seq[AttributeReference] = + bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil + + override lazy val initialValues: Seq[Expression] = { + val zero = Literal.fromObject(aggregator.zero, bufferExternalType) + ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil + } + + override lazy val updateExpressions: Seq[Expression] = { + val reduced = Invoke( + aggregatorLiteral, + "reduce", + bufferExternalType, + bufferDeserializer :: inputDeserializer.get :: Nil) + + ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil + } + + override lazy val mergeExpressions: Seq[Expression] = { + val leftBuffer = bufferDeserializer transform { + case a: AttributeReference => a.left + } + val rightBuffer = bufferDeserializer transform { + case a: AttributeReference => a.right + } + val merged = Invoke( + aggregatorLiteral, + "merge", + bufferExternalType, + leftBuffer :: rightBuffer :: Nil) + + ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil + } + + override lazy val evaluateExpression: Expression = { + val resultObj = Invoke( + aggregatorLiteral, + "finish", + outputExternalType, + bufferDeserializer :: Nil) + + dataType match { + case s: StructType => + ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil) + case _ => + assert(outputSerializer.length == 1) + outputSerializer.head transform { + case b: BoundReference => resultObj + } + } + } + + override def toString: String = { + val input = inputDeserializer match { + case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString + case Some(deserializer) => deserializer.dataType.simpleString + case _ => "unknown" + } + + s"$nodeName($input)" + } + + override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala new file mode 100644 index 0000000000000..395cc7ab91709 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.StructType + +/** + * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' + * for extremely fast key-value lookups while evaluating aggregates (and fall back to the + * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in TungstenAggregate to speed + * up aggregates w/ key. + * + * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the + * key-value pairs. The index lookups in the array rely on linear probing (with a small number of + * maximum tries) and use an inexpensive hash function which makes it really efficient for a + * majority of lookups. However, using linear probing and an inexpensive hash function also makes it + * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even + * for certain distribution of keys) and requires us to fall back on the latter for correctness. We + * also use a secondary columnar batch that logically projects over the original columnar batch and + * is equivalent to the `BytesToBytesMap` aggregate buffer. + * + * NOTE: This vectorized hash map currently doesn't support nullable keys and falls back to the + * `BytesToBytesMap` to store them. + */ +class VectorizedHashMapGenerator( + ctx: CodegenContext, + generatedClassName: String, + groupingKeySchema: StructType, + bufferSchema: StructType) { + val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key"))) + val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value"))) + val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ") + + def generate(): String = { + s""" + |public class $generatedClassName { + |${initializeAggregateHashMap()} + | + |${generateFindOrInsert()} + | + |${generateEquals()} + | + |${generateHashFunction()} + | + |${generateRowIterator()} + | + |${generateClose()} + |} + """.stripMargin + } + + private def initializeAggregateHashMap(): String = { + val generatedSchema: String = + s""" + |new org.apache.spark.sql.types.StructType() + |${(groupingKeySchema ++ bufferSchema).map(key => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") + .mkString("\n")}; + """.stripMargin + + val generatedAggBufferSchema: String = + s""" + |new org.apache.spark.sql.types.StructType() + |${bufferSchema.map(key => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") + .mkString("\n")}; + """.stripMargin + + s""" + | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; + | private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch; + | private int[] buckets; + | private int numBuckets; + | private int maxSteps; + | private int numRows = 0; + | private org.apache.spark.sql.types.StructType schema = $generatedSchema + | private org.apache.spark.sql.types.StructType aggregateBufferSchema = + | $generatedAggBufferSchema + | + | public $generatedClassName() { + | // TODO: These should be generated based on the schema + | int DEFAULT_CAPACITY = 1 << 16; + | double DEFAULT_LOAD_FACTOR = 0.25; + | int DEFAULT_MAX_STEPS = 2; + | assert (DEFAULT_CAPACITY > 0 && ((DEFAULT_CAPACITY & (DEFAULT_CAPACITY - 1)) == 0)); + | this.maxSteps = DEFAULT_MAX_STEPS; + | numBuckets = (int) (DEFAULT_CAPACITY / DEFAULT_LOAD_FACTOR); + | + | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, + | org.apache.spark.memory.MemoryMode.ON_HEAP, DEFAULT_CAPACITY); + | + | // TODO: Possibly generate this projection in TungstenAggregate directly + | aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate( + | aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, DEFAULT_CAPACITY); + | for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) { + | aggregateBufferBatch.setColumn(i, batch.column(i+${groupingKeys.length})); + | } + | + | buckets = new int[numBuckets]; + | java.util.Arrays.fill(buckets, -1); + | } + """.stripMargin + } + + /** + * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For + * instance, if we have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private long hash(long agg_key, long agg_key1) { + * return agg_key ^ agg_key1; + * } + * }}} + */ + private def generateHashFunction(): String = { + s""" + |// TODO: Improve this hash function + |private long hash($groupingKeySignature) { + | return ${groupingKeys.map(_._2).mkString(" | ")}; + |} + """.stripMargin + } + + /** + * Generates a method that returns true if the group-by keys exist at a given index in the + * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private boolean equals(int idx, long agg_key, long agg_key1) { + * return batch.column(0).getLong(buckets[idx]) == agg_key && + * batch.column(1).getLong(buckets[idx]) == agg_key1; + * } + * }}} + */ + private def generateEquals(): String = { + s""" + |private boolean equals(int idx, $groupingKeySignature) { + | return ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")}; + |} + """.stripMargin + } + + /** + * Generates a method that returns a mutable + * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the + * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the + * generated method adds the corresponding row in the associated + * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert( + * long agg_key, long agg_key1) { + * long h = hash(agg_key, agg_key1); + * int step = 0; + * int idx = (int) h & (numBuckets - 1); + * while (step < maxSteps) { + * // Return bucket index if it's either an empty slot or already contains the key + * if (buckets[idx] == -1) { + * batch.column(0).putLong(numRows, agg_key); + * batch.column(1).putLong(numRows, agg_key1); + * batch.column(2).putLong(numRows, 0); + * buckets[idx] = numRows++; + * return batch.getRow(buckets[idx]); + * } else if (equals(idx, agg_key, agg_key1)) { + * return batch.getRow(buckets[idx]); + * } + * idx = (idx + 1) & (numBuckets - 1); + * step++; + * } + * // Didn't find it + * return null; + * } + * }}} + */ + private def generateFindOrInsert(): String = { + s""" + |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${ + groupingKeySignature}) { + | long h = hash(${groupingKeys.map(_._2).mkString(", ")}); + | int step = 0; + | int idx = (int) h & (numBuckets - 1); + | while (step < maxSteps) { + | // Return bucket index if it's either an empty slot or already contains the key + | if (buckets[idx] == -1) { + | ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")} + | ${bufferValues.zipWithIndex.map(k => + s"batch.column(${groupingKeys.length + k._2}).putNull(numRows);") + .mkString("\n")} + | buckets[idx] = numRows++; + | batch.setNumRows(numRows); + | aggregateBufferBatch.setNumRows(numRows); + | return aggregateBufferBatch.getRow(buckets[idx]); + | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) { + | return aggregateBufferBatch.getRow(buckets[idx]); + | } + | idx = (idx + 1) & (numBuckets - 1); + | step++; + | } + | // Didn't find it + | return null; + |} + """.stripMargin + } + + private def generateRowIterator(): String = { + s""" + |public java.util.Iterator + | rowIterator() { + | return batch.rowIterator(); + |} + """.stripMargin + } + + private def generateClose(): String = { + s""" + |public void close() { + | batch.close(); + |} + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala new file mode 100644 index 0000000000000..c39a78da6f9be --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.sql.{Encoder, TypedColumn} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.Aggregator + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines internal implementations for aggregators. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] { + override def zero: Double = 0.0 + override def reduce(b: Double, a: IN): Double = b + f(a) + override def merge(b1: Double, b2: Double): Double = b1 + b2 + override def finish(reduction: Double): Double = reduction + + override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]() + override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() + + // Java api support + def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) + + def toColumnJava: TypedColumn[IN, java.lang.Double] = { + toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]] + } +} + + +class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] { + override def zero: Long = 0L + override def reduce(b: Long, a: IN): Long = b + f(a) + override def merge(b1: Long, b2: Long): Long = b1 + b2 + override def finish(reduction: Long): Long = reduction + + override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() + override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() + + // Java api support + def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long]) + + def toColumnJava: TypedColumn[IN, java.lang.Long] = { + toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]] + } +} + + +class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] { + override def zero: Long = 0 + override def reduce(b: Long, a: IN): Long = { + if (f(a) == null) b else b + 1 + } + override def merge(b1: Long, b2: Long): Long = b1 + b2 + override def finish(reduction: Long): Long = reduction + + override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() + override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() + + // Java api support + def this(f: MapFunction[IN, Object]) = this(x => f.call(x)) + def toColumnJava: TypedColumn[IN, java.lang.Long] = { + toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]] + } +} + + +class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] { + override def zero: (Double, Long) = (0.0, 0L) + override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2) + override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2 + override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]() + override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() + + // Java api support + def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) + def toColumnJava: TypedColumn[IN, java.lang.Double] = { + toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d2f56e0fc14a4..f5776e7b8d49a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MutableRow, _} +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction2} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ @@ -324,7 +324,7 @@ private[sql] case class ScalaUDAF( udaf: UserDefinedAggregateFunction, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with Logging { + extends ImperativeAggregate with NonSQLExpression with Logging { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -332,11 +332,6 @@ private[sql] case class ScalaUDAF( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - require( - children.length == udaf.inputSchema.length, - s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + - s"but ${children.length} are provided.") - override def nullable: Boolean = true override def dataType: DataType = udaf.dataType @@ -366,13 +361,7 @@ private[sql] case class ScalaUDAF( val inputAttributes = childrenSchema.toAttributes log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") - try { - GenerateMutableProjection.generate(children, inputAttributes)() - } catch { - case e: Exception => - log.error("Failed to generate mutable projection, fallback to interpreted", e) - new InterpretedMutableProjection(children, inputAttributes) - } + GenerateMutableProjection.generate(children, inputAttributes)() } private[this] lazy val inputToScalaConverters: Any => Any = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index eaafd83158a15..4682949fa1c7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -28,43 +29,60 @@ object Utils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + aggregateExpressions: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = completeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } - + val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = Nil, - nonCompleteAggregateAttributes = Nil, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, + requiredChildDistributionExpressions = Some(groupingExpressions), + groupingExpressions = groupingExpressions, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, resultExpressions = resultExpressions, child = child ) :: Nil } + private def createAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + groupingExpressions: Seq[NamedExpression] = Nil, + aggregateExpressions: Seq[AggregateExpression] = Nil, + aggregateAttributes: Seq[Attribute] = Nil, + initialInputBufferOffset: Int = 0, + resultExpressions: Seq[NamedExpression] = Nil, + child: SparkPlan): SparkPlan = { + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } + } + def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + aggregateExpressions: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. - val usesTungstenAggregate = - child.sqlContext.conf.unsafeEnabled && - TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // 1. Create an Aggregate Operator for partial aggregations. @@ -76,220 +94,240 @@ object Utils { groupingAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = groupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialResultExpressions, - child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + val partialAggregate = createAggregate( + requiredChildDistributionExpressions = None, groupingExpressions = groupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + aggregateExpressions = partialAggregateExpressions, + aggregateAttributes = partialAggregateAttributes, initialInputBufferOffset = 0, resultExpressions = partialResultExpressions, child = child) - } // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) - val finalAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = groupingExpressions.length, - resultExpressions = resultExpressions, - child = partialAggregate) - } else { - SortBasedAggregate( + val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, child = partialAggregate) - } finalAggregate :: Nil } def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], - functionsWithDistinct: Seq[AggregateExpression2], - functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + functionsWithDistinct: Seq[AggregateExpression], + functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct - val usesTungstenAggregate = - child.sqlContext.conf.unsafeEnabled && - TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one - // DISTINCT aggregate function, all of those functions will have the same column expression. + // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is // disallowed because those two distinct aggregates have different column expressions. - val distinctColumnExpression: Expression = { - val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children - assert(allDistinctColumnExpressions.length == 1) - allDistinctColumnExpressions.head - } - val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match { + val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctExpressions = distinctExpressions.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute + val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. val partialAggregate: SparkPlan = { - val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. - val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression - val partialAggregateResult = - groupingAttributes ++ - Seq(distinctColumnAttribute) ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) - } + createAggregate( + groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) } // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { - val partialMergeAggregateExpressions = - functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialMergeAggregateResult = - groupingAttributes ++ - Seq(distinctColumnAttribute) ++ - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes :+ distinctColumnAttribute, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes :+ distinctColumnAttribute, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) - } + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes ++ distinctAttributes), + groupingExpressions = groupingAttributes ++ distinctAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + // 3. Create an Aggregate operator for partial aggregation (for distinct) + val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap + val rewrittenDistinctFunctions = functionsWithDistinct.map { + // Children of an AggregateFunction with DISTINCT keyword has already + // been evaluated. At here, we need to replace original children + // to AttributeReferences. + case agg @ AggregateExpression(aggregateFunction, mode, true, _) => + aggregateFunction.transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] } - // 3. Create an Aggregate Operator for the final aggregation. + val partialDistinctAggregate: SparkPlan = { + val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. + val expr = AggregateExpression(func, Partial, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = functionsWithDistinct(i).resultAttribute + (expr, attr) + }.unzip + + val partialAggregateResult = groupingAttributes ++ + mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ + distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + createAggregate( + groupingExpressions = groupingAttributes, + aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = partialAggregateResult, + child = partialMergeAggregate) + } + + // 4. Create an Aggregate Operator for the final aggregation. val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } - - val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { - // Children of an AggregateFunction with DISTINCT keyword has already - // been evaluated. At here, we need to replace original children - // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction.transformDown { - case expr if expr == distinctColumnExpression => distinctColumnAttribute - }.asInstanceOf[AggregateFunction2] + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. - val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true) - - val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) - (rewrittenAggregateExpression, aggregateFunctionAttribute) + val expr = AggregateExpression(func, Final, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = functionsWithDistinct(i).resultAttribute + (expr, attr) }.unzip - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = partialDistinctAggregate) + } + + finalAndCompleteAggregate :: Nil + } + + /** + * Plans a streaming aggregation using the following progression: + * - Partial Aggregation + * - Shuffle + * - Partial Merge (now there is at most 1 tuple per group) + * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) + * - PartialMerge (now there is at most 1 tuple per group) + * - StateStoreSave (saves the tuple for the next batch) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregation( + groupingExpressions: Seq[NamedExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + val restored = StateStoreRestore(groupingAttributes, None, partialMerged1) + + val partialMerged2: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored) + } + + val saved = StateStoreSave(groupingAttributes, None, partialMerged2) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = saved) } finalAndCompleteAggregate :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d5a803f8c4b24..344aaff348e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,35 +17,56 @@ package org.apache.spark.sql.execution -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} -import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.MutablePair -import org.apache.spark.util.random.PoissonSampler -import org.apache.spark.{HashPartitioner, SparkEnv} +import org.apache.spark.sql.types.LongType +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} +case class Project(projectList: Seq[NamedExpression], child: SparkPlan) + extends UnaryNode with CodegenSupport { -case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override private[sql] lazy val metrics = Map( - "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } - @transient lazy val buildProjection = newMutableProjection(projectList, child.output) + override def usedInputs: AttributeSet = { + // only the attributes those are used at least twice should be evaluated before this plan, + // otherwise we could defer the evaluation until output attribute is actually used. + val usedExprIds = projectList.flatMap(_.collect { + case a: Attribute => a.exprId + }) + val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet + references.filter(a => usedMoreThanOnce.contains(a.exprId)) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val exprs = projectList.map(x => + ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) + ctx.currentVars = input + val resultVars = exprs.map(_.gen(ctx)) + // Evaluation of non-deterministic expressions can't be deferred. + val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) + s""" + |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))} + |${consume(ctx, resultVars)} + """.stripMargin + } protected override def doExecute(): RDD[InternalRow] = { - val numRows = longMetric("numRows") - child.execute().mapPartitions { iter => - val reusableProjection = buildProjection() - iter.map { row => - numRows += 1 - reusableProjection(row) - } + child.execute().mapPartitionsInternal { iter => + val project = UnsafeProjection.create(projectList, child.output, + subexpressionEliminationEnabled) + iter.map(project) } } @@ -53,55 +74,130 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends } -/** - * A variant of [[Project]] that returns [[UnsafeRow]]s. - */ -case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { +case class Filter(condition: Expression, child: SparkPlan) + extends UnaryNode with CodegenSupport with PredicateHelper { - override private[sql] lazy val metrics = Map( - "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + // Split out all the IsNotNulls from condition. + private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { + case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true + case _ => false + } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true + // The columns that will filtered out by `IsNotNull` could be considered as not nullable. + private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId) - override def output: Seq[Attribute] = projectList.map(_.toAttribute) + // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate + // all the variables at the beginning to take advantage of short circuiting. + override def usedInputs: AttributeSet = AttributeSet.empty - /** Rewrite the project list to use unsafe expressions as needed. */ - protected val unsafeProjectList = projectList.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - - protected override def doExecute(): RDD[InternalRow] = { - val numRows = longMetric("numRows") - child.execute().mapPartitions { iter => - val project = UnsafeProjection.create(unsafeProjectList, child.output) - iter.map { row => - numRows += 1 - project(row) + override def output: Seq[Attribute] = { + child.output.map { a => + if (a.nullable && notNullAttributes.contains(a.exprId)) { + a.withNullability(false) + } else { + a } } } - override def outputOrdering: Seq[SortOrder] = child.outputOrdering -} + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } -case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } - private[sql] override lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val numOutput = metricTerm(ctx, "numOutputRows") + + /** + * Generates code for `c`, using `in` for input attributes and `attrs` for nullability. + */ + def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = { + val bound = BindReferences.bindReference(c, attrs) + val evaluated = evaluateRequiredVariables(child.output, in, c.references) + + // Generate the code for the predicate. + val ev = ExpressionCanonicalizer.execute(bound).gen(ctx) + val nullCheck = if (bound.nullable) { + s"${ev.isNull} || " + } else { + s"" + } + + s""" + |$evaluated + |${ev.code} + |if (${nullCheck}!${ev.value}) continue; + """.stripMargin + } + + ctx.currentVars = input + + // To generate the predicates we will follow this algorithm. + // For each predicate that is not IsNotNull, we will generate them one by one loading attributes + // as necessary. For each of both attributes, if there is a IsNotNull predicate we will generate + // that check *before* the predicate. After all of these predicates, we will generate the + // remaining IsNotNull checks that were not part of other predicates. + // This has the property of not doing redundant IsNotNull checks and taking better advantage of + // short-circuiting, not loading attributes until they are needed. + // This is very perf sensitive. + // TODO: revisit this. We can consider reordering predicates as well. + val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) + val generated = otherPreds.map { c => + val nullChecks = c.references.map { r => + val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} + if (idx != -1 && !generatedIsNotNullChecks(idx)) { + generatedIsNotNullChecks(idx) = true + // Use the child's output. The nullability is what the child produced. + genPredicate(notNullPreds(idx), input, child.output) + } else { + "" + } + }.mkString("\n").trim + + // Here we use *this* operator's output with this output's nullability since we already + // enforced them with the IsNotNull checks above. + s""" + |$nullChecks + |${genPredicate(c, input, output)} + """.stripMargin.trim + }.mkString("\n") + + val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) => + if (!generatedIsNotNullChecks(idx)) { + genPredicate(c, input, child.output) + } else { + "" + } + }.mkString("\n") + + // Reset the isNull to false for the not-null columns, then the followed operators could + // generate better code (remove dead branches). + val resultVars = input.zipWithIndex.map { case (ev, i) => + if (notNullAttributes.contains(child.output(i).exprId)) { + ev.isNull = "false" + } + ev + } + + s""" + |$generated + |$nullChecks + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin + } protected override def doExecute(): RDD[InternalRow] = { - val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val predicate = newPredicate(condition, child.output) iter.filter { row => - numInputRows += 1 val r = predicate(row) if (r) numOutputRows += 1 r @@ -110,12 +206,6 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - - override def canProcessUnsafeRows: Boolean = true - - override def canProcessSafeRows: Boolean = true } /** @@ -133,14 +223,11 @@ case class Sample( upperBound: Double, withReplacement: Boolean, seed: Long, - child: SparkPlan) - extends UnaryNode -{ + child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { @@ -155,107 +242,222 @@ case class Sample( child.execute().randomSampleWithRange(lowerBound, upperBound, seed) } } -} - -/** - * Union two plans, without a distinct. This is UNION ALL in SQL. - */ -case class Union(children: Seq[SparkPlan]) extends SparkPlan { - // TODO: attributes output by union should be distinct for nullability purposes - override def output: Seq[Attribute] = children.head.output - override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = - sparkContext.union(children.map(_.execute())) -} - -/** - * Take the first limit elements. Note that the implementation is different depending on whether - * this is a terminal operator or not. If it is terminal and is invoked using executeCollect, - * this operator uses something similar to Spark's take method on the Spark driver. If it is not - * terminal or is invoked using execute, we first take the limit on each partition, and then - * repartition all the data to a single partition to compute the global limit. - */ -case class Limit(limit: Int, child: SparkPlan) - extends UnaryNode { - // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: - // partition local limit -> exchange into one partition -> partition local limit again - /** We must copy rows when sort based shuffle is on */ - private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = SinglePartition + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } - override def executeCollect(): Array[InternalRow] = child.executeTake(limit) + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val numOutput = metricTerm(ctx, "numOutputRows") + val sampler = ctx.freshName("sampler") - protected override def doExecute(): RDD[InternalRow] = { - val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) { - child.execute().mapPartitions { iter => - iter.take(limit).map(row => (false, row.copy())) - } + if (withReplacement) { + val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName + val initSampler = ctx.freshName("initSampler") + ctx.addMutableState(s"$samplerClass", sampler, + s"$initSampler();") + + ctx.addNewFunction(initSampler, + s""" + | private void $initSampler() { + | $sampler = new $samplerClass($upperBound - $lowerBound, false); + | java.util.Random random = new java.util.Random(${seed}L); + | long randomSeed = random.nextLong(); + | int loopCount = 0; + | while (loopCount < partitionIndex) { + | randomSeed = random.nextLong(); + | loopCount += 1; + | } + | $sampler.setSeed(randomSeed); + | } + """.stripMargin.trim) + + val samplingCount = ctx.freshName("samplingCount") + s""" + | int $samplingCount = $sampler.sample(); + | while ($samplingCount-- > 0) { + | $numOutput.add(1); + | ${consume(ctx, input)} + | } + """.stripMargin.trim } else { - child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Boolean, InternalRow]() - iter.take(limit).map(row => mutablePair.update(false, row)) - } + val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName + ctx.addMutableState(s"$samplerClass", sampler, + s""" + | $sampler = new $samplerClass($lowerBound, $upperBound, false); + | $sampler.setSeed(${seed}L + partitionIndex); + """.stripMargin.trim) + + s""" + | if ($sampler.sample() == 0) continue; + | $numOutput.add(1); + | ${consume(ctx, input)} + """.stripMargin.trim } - val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) - shuffled.mapPartitions(_.take(limit).map(_._2)) } } -/** - * Take the first limit elements as defined by the sortOrder, and do projection if needed. - * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, - * or having a [[Project]] operator between them. - * This could have been named TopK, but Spark's top operator does the opposite in ordering - * so we name it TakeOrdered to avoid confusion. - */ -case class TakeOrderedAndProject( - limit: Int, - sortOrder: Seq[SortOrder], - projectList: Option[Seq[NamedExpression]], - child: SparkPlan) extends UnaryNode { +case class Range( + start: Long, + step: Long, + numSlices: Int, + numElements: BigInt, + output: Seq[Attribute]) + extends LeafNode with CodegenSupport { - override def output: Seq[Attribute] = { - val projectOutput = projectList.map(_.map(_.toAttribute)) - projectOutput.getOrElse(child.output) - } + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputPartitioning: Partitioning = SinglePartition + // output attributes should not affect the results + override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements) - // We need to use an interpreted ordering here because generated orderings cannot be serialized - // and this ordering needs to be created on the driver in order to be passed into Spark core code. - private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) + override def upstreams(): Seq[RDD[InternalRow]] = { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) + .map(i => InternalRow(i)) :: Nil + } - // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. - @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) + protected override def doProduce(ctx: CodegenContext): String = { + val numOutput = metricTerm(ctx, "numOutputRows") + + val initTerm = ctx.freshName("initRange") + ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") + val partitionEnd = ctx.freshName("partitionEnd") + ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;") + val number = ctx.freshName("number") + ctx.addMutableState("long", number, s"$number = 0L;") + val overflow = ctx.freshName("overflow") + ctx.addMutableState("boolean", overflow, s"$overflow = false;") + + val value = ctx.freshName("value") + val ev = ExprCode("", "false", value) + val BigInt = classOf[java.math.BigInteger].getName + val checkEnd = if (step > 0) { + s"$number < $partitionEnd" + } else { + s"$number > $partitionEnd" + } - private def collectData(): Array[InternalRow] = { - val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - projection.map(data.map(_)).getOrElse(data) + ctx.addNewFunction("initRange", + s""" + | private void initRange(int idx) { + | $BigInt index = $BigInt.valueOf(idx); + | $BigInt numSlice = $BigInt.valueOf(${numSlices}L); + | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L); + | $BigInt step = $BigInt.valueOf(${step}L); + | $BigInt start = $BigInt.valueOf(${start}L); + | + | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); + | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $number = Long.MAX_VALUE; + | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $number = Long.MIN_VALUE; + | } else { + | $number = st.longValue(); + | } + | + | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) + | .multiply(step).add(start); + | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { + | $partitionEnd = Long.MAX_VALUE; + | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { + | $partitionEnd = Long.MIN_VALUE; + | } else { + | $partitionEnd = end.longValue(); + | } + | + | $numOutput.add(($partitionEnd - $number) / ${step}L); + | } + """.stripMargin) + + val input = ctx.freshName("input") + // Right now, Range is only used when there is one upstream. + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + s""" + | // initialize Range + | if (!$initTerm) { + | $initTerm = true; + | initRange(partitionIndex); + | } + | + | while (!$overflow && $checkEnd) { + | long $value = $number; + | $number += ${step}L; + | if ($number < $value ^ ${step}L < 0) { + | $overflow = true; + | } + | ${consume(ctx, Seq(ev))} + | if (shouldStop()) return; + | } + """.stripMargin } - override def executeCollect(): Array[InternalRow] = { - collectData() + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + sqlContext + .sparkContext + .parallelize(0 until numSlices, numSlices) + .mapPartitionsWithIndex { (i, _) => + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize + val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1) + + new Iterator[InternalRow] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + + numOutputRows += 1 + unsafeRow.setLong(0, ret) + unsafeRow + } + } + } } +} - // TODO: Terminal split should be implemented differently from non-terminal split. - // TODO: Pick num splits based on |limit|. - protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def simpleString: String = { - val orderByString = sortOrder.mkString("[", ",", "]") - val outputString = output.mkString("[", ",", "]") +/** + * Union two plans, without a distinct. This is UNION ALL in SQL. + */ +case class Union(children: Seq[SparkPlan]) extends SparkPlan { + override def output: Seq[Attribute] = + children.map(_.output).transpose.map(attrs => + attrs.head.withNullability(attrs.exists(_.nullable))) - s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" - } + protected override def doExecute(): RDD[InternalRow] = + sparkContext.union(children.map(_.execute())) } /** @@ -275,8 +477,6 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = { child.execute().coalesce(numPartitions, shuffle = false) } - - override def canProcessUnsafeRows: Boolean = true } /** @@ -291,18 +491,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { } } -/** - * Returns the rows in left that also appear in right using the built in spark - * intersection function. - */ -case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output: Seq[Attribute] = children.head.output - - protected override def doExecute(): RDD[InternalRow] = { - left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) - } -} - /** * A plan node that does nothing but lie about the output of its child. Used to spice a * (hopefully structurally equivalent) tree from a different optimization sequence into an already @@ -315,119 +503,16 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl } /** - * Applies the given function to each input row and encodes the result. - */ -case class MapPartitions[T, U]( - func: Iterator[T] => Iterator[U], - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - output: Seq[Attribute], - child: SparkPlan) extends UnaryNode { - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val tBoundEncoder = tEncoder.bind(child.output) - func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow) - } - } -} - -/** - * Applies the given function to each input row, appending the encoded result at the end of the row. - */ -case class AppendColumns[T, U]( - func: T => U, - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - newColumns: Seq[Attribute], - child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = child.output ++ newColumns - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val tBoundEncoder = tEncoder.bind(child.output) - val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema) - iter.map { row => - val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row))) - combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow - } - } - } -} - -/** - * Groups the input rows together and calls the function with each group and an iterator containing - * all elements in the group. The result of this function is encoded and flattened before - * being output. + * A plan as subquery. + * + * This is used to generate tree string for SparkScalarSubquery. */ -case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], - kEncoder: ExpressionEncoder[K], - tEncoder: ExpressionEncoder[T], - uEncoder: ExpressionEncoder[U], - groupingAttributes: Seq[Attribute], - output: Seq[Attribute], - child: SparkPlan) extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val grouped = GroupedIterator(iter, groupingAttributes, child.output) - val groupKeyEncoder = kEncoder.bind(groupingAttributes) - - grouped.flatMap { case (key, rowIter) => - val result = func( - groupKeyEncoder.fromRow(key), - rowIter.map(tEncoder.fromRow)) - result.map(uEncoder.toRow) - } - } - } -} +case class Subquery(name: String, child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering -/** - * Co-groups the data from left and right children, and calls the function with each group and 2 - * iterators containing all elements in the group from left and right side. - * The result of this function is encoded and flattened before being output. - */ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], - kEncoder: ExpressionEncoder[K], - leftEnc: ExpressionEncoder[Left], - rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], - output: Seq[Attribute], - leftGroup: Seq[Attribute], - rightGroup: Seq[Attribute], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil - - override protected def doExecute(): RDD[InternalRow] = { - left.execute().zipPartitions(right.execute()) { (leftData, rightData) => - val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) - val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val groupKeyEncoder = kEncoder.bind(leftGroup) - - new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { - case (key, leftResult, rightResult) => - val result = func( - groupKeyEncoder.fromRow(key), - leftResult.map(leftEnc.fromRow), - rightResult.map(rightEnc.fromRow)) - result.map(rEncoder.toRow) - } - } + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala new file mode 100644 index 0000000000000..7cde04b62619e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteBuffer, ByteOrder} + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} +import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.types._ + +/** + * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is + * extracted from the buffer, instead of directly returning it, the value is set into some field of + * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods + * for primitive values provided by [[MutableRow]]. + */ +private[columnar] trait ColumnAccessor { + initialize() + + protected def initialize() + + def hasNext: Boolean + + def extractTo(row: MutableRow, ordinal: Int): Unit + + protected def underlyingBuffer: ByteBuffer +} + +private[columnar] abstract class BasicColumnAccessor[JvmType]( + protected val buffer: ByteBuffer, + protected val columnType: ColumnType[JvmType]) + extends ColumnAccessor { + + protected def initialize() {} + + override def hasNext: Boolean = buffer.hasRemaining + + override def extractTo(row: MutableRow, ordinal: Int): Unit = { + extractSingle(row, ordinal) + } + + def extractSingle(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } + + protected def underlyingBuffer = buffer +} + +private[columnar] class NullColumnAccessor(buffer: ByteBuffer) + extends BasicColumnAccessor[Any](buffer, NULL) + with NullableColumnAccessor + +private[columnar] abstract class NativeColumnAccessor[T <: AtomicType]( + override protected val buffer: ByteBuffer, + override protected val columnType: NativeColumnType[T]) + extends BasicColumnAccessor(buffer, columnType) + with NullableColumnAccessor + with CompressibleColumnAccessor[T] + +private[columnar] class BooleanColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BOOLEAN) + +private[columnar] class ByteColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BYTE) + +private[columnar] class ShortColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, SHORT) + +private[columnar] class IntColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, INT) + +private[columnar] class LongColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, LONG) + +private[columnar] class FloatColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, FLOAT) + +private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DOUBLE) + +private[columnar] class StringColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, STRING) + +private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer) + extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) + with NullableColumnAccessor + +private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) + extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType)) + +private[columnar] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) + extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType)) + with NullableColumnAccessor + +private[columnar] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) + extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) + with NullableColumnAccessor + +private[columnar] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) + extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) + with NullableColumnAccessor + +private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) + extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) + with NullableColumnAccessor + +private[columnar] object ColumnAccessor { + @tailrec + def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { + val buf = buffer.order(ByteOrder.nativeOrder) + + dataType match { + case NullType => new NullColumnAccessor(buf) + case BooleanType => new BooleanColumnAccessor(buf) + case ByteType => new ByteColumnAccessor(buf) + case ShortType => new ShortColumnAccessor(buf) + case IntegerType | DateType => new IntColumnAccessor(buf) + case LongType | TimestampType => new LongColumnAccessor(buf) + case FloatType => new FloatColumnAccessor(buf) + case DoubleType => new DoubleColumnAccessor(buf) + case StringType => new StringColumnAccessor(buf) + case BinaryType => new BinaryColumnAccessor(buf) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + new CompactDecimalColumnAccessor(buf, dt) + case dt: DecimalType => new DecimalColumnAccessor(buf, dt) + case struct: StructType => new StructColumnAccessor(buf, struct) + case array: ArrayType => new ArrayColumnAccessor(buf, array) + case map: MapType => new MapColumnAccessor(buf, map) + case udt: UserDefinedType[_] => ColumnAccessor(udt.sqlType, buffer) + case other => + throw new Exception(s"not support type: $other") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala new file mode 100644 index 0000000000000..d30655e0c4a20 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.columnar.ColumnBuilder._ +import org.apache.spark.sql.execution.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} +import org.apache.spark.sql.types._ + +private[columnar] trait ColumnBuilder { + /** + * Initializes with an approximate lower bound on the expected number of elements in this column. + */ + def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false): Unit + + /** + * Appends `row(ordinal)` to the column builder. + */ + def appendFrom(row: InternalRow, ordinal: Int): Unit + + /** + * Column statistics information + */ + def columnStats: ColumnStats + + /** + * Returns the final columnar byte buffer. + */ + def build(): ByteBuffer +} + +private[columnar] class BasicColumnBuilder[JvmType]( + val columnStats: ColumnStats, + val columnType: ColumnType[JvmType]) + extends ColumnBuilder { + + protected var columnName: String = _ + + protected var buffer: ByteBuffer = _ + + override def initialize( + initialSize: Int, + columnName: String = "", + useCompression: Boolean = false): Unit = { + + val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize + this.columnName = columnName + + buffer = ByteBuffer.allocate(size * columnType.defaultSize) + buffer.order(ByteOrder.nativeOrder()) + } + + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { + buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) + columnType.append(row, ordinal, buffer) + } + + override def build(): ByteBuffer = { + if (buffer.capacity() > buffer.position() * 1.1) { + // trim the buffer + buffer = ByteBuffer + .allocate(buffer.position()) + .order(ByteOrder.nativeOrder()) + .put(buffer.array(), 0, buffer.position()) + } + buffer.flip().asInstanceOf[ByteBuffer] + } +} + +private[columnar] class NullColumnBuilder + extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) + with NullableColumnBuilder + +private[columnar] abstract class ComplexColumnBuilder[JvmType]( + columnStats: ColumnStats, + columnType: ColumnType[JvmType]) + extends BasicColumnBuilder[JvmType](columnStats, columnType) + with NullableColumnBuilder + +private[columnar] abstract class NativeColumnBuilder[T <: AtomicType]( + override val columnStats: ColumnStats, + override val columnType: NativeColumnType[T]) + extends BasicColumnBuilder[T#InternalType](columnStats, columnType) + with NullableColumnBuilder + with AllCompressionSchemes + with CompressibleColumnBuilder[T] + +private[columnar] +class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) + +private[columnar] +class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) + +private[columnar] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) + +private[columnar] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) + +private[columnar] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) + +private[columnar] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) + +private[columnar] +class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) + +private[columnar] +class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) + +private[columnar] +class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) + +private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType) + extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) + +private[columnar] class DecimalColumnBuilder(dataType: DecimalType) + extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) + +private[columnar] class StructColumnBuilder(dataType: StructType) + extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) + +private[columnar] class ArrayColumnBuilder(dataType: ArrayType) + extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType)) + +private[columnar] class MapColumnBuilder(dataType: MapType) + extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) + +private[columnar] object ColumnBuilder { + val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024 + val MAX_BATCH_SIZE_IN_BYTE = 4 * 1024 * 1024L + + private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = { + if (orig.remaining >= size) { + orig + } else { + // grow in steps of initial size + val capacity = orig.capacity() + val newSize = capacity + size.max(capacity) + val pos = orig.position() + + ByteBuffer + .allocate(newSize) + .order(ByteOrder.nativeOrder()) + .put(orig.array(), 0, pos) + } + } + + def apply( + dataType: DataType, + initialSize: Int = 0, + columnName: String = "", + useCompression: Boolean = false): ColumnBuilder = { + val builder: ColumnBuilder = dataType match { + case NullType => new NullColumnBuilder + case BooleanType => new BooleanColumnBuilder + case ByteType => new ByteColumnBuilder + case ShortType => new ShortColumnBuilder + case IntegerType | DateType => new IntColumnBuilder + case LongType | TimestampType => new LongColumnBuilder + case FloatType => new FloatColumnBuilder + case DoubleType => new DoubleColumnBuilder + case StringType => new StringColumnBuilder + case BinaryType => new BinaryColumnBuilder + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + new CompactDecimalColumnBuilder(dt) + case dt: DecimalType => new DecimalColumnBuilder(dt) + case struct: StructType => new StructColumnBuilder(struct) + case array: ArrayType => new ArrayColumnBuilder(array) + case map: MapType => new MapColumnBuilder(map) + case udt: UserDefinedType[_] => + return apply(udt.sqlType, initialSize, columnName, useCompression) + case other => + throw new Exception(s"not supported type: $other") + } + + builder.initialize(initialSize, columnName, useCompression) + builder + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala similarity index 87% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index ba61003ba41c6..5d4476989a369 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, GenericInternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { +private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)() val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() @@ -32,7 +32,7 @@ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes) } -private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { +private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { val (forAttribute, schema) = { val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) @@ -45,10 +45,10 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. */ -private[sql] sealed trait ColumnStats extends Serializable { +private[columnar] sealed trait ColumnStats extends Serializable { protected var count = 0 protected var nullCount = 0 - protected var sizeInBytes = 0L + private[columnar] var sizeInBytes = 0L /** * Gathers statistics information from `row(ordinal)`. @@ -72,14 +72,14 @@ private[sql] sealed trait ColumnStats extends Serializable { /** * A no-op ColumnStats only used for testing purposes. */ -private[sql] class NoopColumnStats extends ColumnStats { +private[columnar] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } -private[sql] class BooleanColumnStats extends ColumnStats { +private[columnar] class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true @@ -97,7 +97,7 @@ private[sql] class BooleanColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ByteColumnStats extends ColumnStats { +private[columnar] class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue @@ -115,7 +115,7 @@ private[sql] class ByteColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ShortColumnStats extends ColumnStats { +private[columnar] class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue @@ -133,7 +133,7 @@ private[sql] class ShortColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class IntColumnStats extends ColumnStats { +private[columnar] class IntColumnStats extends ColumnStats { protected var upper = Int.MinValue protected var lower = Int.MaxValue @@ -151,7 +151,7 @@ private[sql] class IntColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class LongColumnStats extends ColumnStats { +private[columnar] class LongColumnStats extends ColumnStats { protected var upper = Long.MinValue protected var lower = Long.MaxValue @@ -169,7 +169,7 @@ private[sql] class LongColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class FloatColumnStats extends ColumnStats { +private[columnar] class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue @@ -187,7 +187,7 @@ private[sql] class FloatColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class DoubleColumnStats extends ColumnStats { +private[columnar] class DoubleColumnStats extends ColumnStats { protected var upper = Double.MinValue protected var lower = Double.MaxValue @@ -205,7 +205,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class StringColumnStats extends ColumnStats { +private[columnar] class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null @@ -223,7 +223,7 @@ private[sql] class StringColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class BinaryColumnStats extends ColumnStats { +private[columnar] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { @@ -235,7 +235,7 @@ private[sql] class BinaryColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } -private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) protected var upper: Decimal = null @@ -256,7 +256,7 @@ private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends Column new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats { +private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats { val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 68e509eb5047d..f9d606e37ea89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer +import scala.annotation.tailrec import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.InternalRow @@ -39,9 +40,9 @@ import org.apache.spark.unsafe.types.UTF8String * so we do not have helper methods for them. * * - * WARNNING: This only works with HeapByteBuffer + * WARNING: This only works with HeapByteBuffer */ -object ByteBufferHelper { +private[columnar] object ByteBufferHelper { def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -73,7 +74,7 @@ object ByteBufferHelper { * * @tparam JvmType Underlying Java type to represent the elements. */ -private[sql] sealed abstract class ColumnType[JvmType] { +private[columnar] sealed abstract class ColumnType[JvmType] { // The catalyst data type of this column. def dataType: DataType @@ -142,7 +143,7 @@ private[sql] sealed abstract class ColumnType[JvmType] { override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[sql] object NULL extends ColumnType[Any] { +private[columnar] object NULL extends ColumnType[Any] { override def dataType: DataType = NullType override def defaultSize: Int = 0 @@ -152,7 +153,7 @@ private[sql] object NULL extends ColumnType[Any] { override def getField(row: InternalRow, ordinal: Int): Any = null } -private[sql] abstract class NativeColumnType[T <: AtomicType]( +private[columnar] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, val defaultSize: Int) extends ColumnType[T#InternalType] { @@ -163,7 +164,7 @@ private[sql] abstract class NativeColumnType[T <: AtomicType]( def scalaTag: TypeTag[dataType.InternalType] = dataType.tag } -private[sql] object INT extends NativeColumnType(IntegerType, 4) { +private[columnar] object INT extends NativeColumnType(IntegerType, 4) { override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -192,7 +193,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) { } } -private[sql] object LONG extends NativeColumnType(LongType, 8) { +private[columnar] object LONG extends NativeColumnType(LongType, 8) { override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } @@ -220,7 +221,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) { } } -private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { +private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) { override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } @@ -248,7 +249,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { } } -private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { +private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) { override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } @@ -276,7 +277,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { } } -private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { +private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) { override def append(v: Boolean, buffer: ByteBuffer): Unit = { buffer.put(if (v) 1: Byte else 0: Byte) } @@ -302,7 +303,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { } } -private[sql] object BYTE extends NativeColumnType(ByteType, 1) { +private[columnar] object BYTE extends NativeColumnType(ByteType, 1) { override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } @@ -330,7 +331,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 1) { } } -private[sql] object SHORT extends NativeColumnType(ShortType, 2) { +private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } @@ -362,7 +363,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) { * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper * objects. */ -private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { +private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { @@ -387,7 +388,7 @@ private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { } } -private[sql] object STRING +private[columnar] object STRING extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { override def actualSize(row: InternalRow, ordinal: Int): Int = { @@ -425,7 +426,7 @@ private[sql] object STRING override def clone(v: UTF8String): UTF8String = v.clone() } -private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) +private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) extends NativeColumnType(DecimalType(precision, scale), 8) { override def extract(buffer: ByteBuffer): Decimal = { @@ -467,13 +468,13 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) } } -private[sql] object COMPACT_DECIMAL { +private[columnar] object COMPACT_DECIMAL { def apply(dt: DecimalType): COMPACT_DECIMAL = { COMPACT_DECIMAL(dt.precision, dt.scale) } } -private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) +private[columnar] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { def serialize(value: JvmType): Array[Byte] @@ -492,7 +493,7 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: } } -private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { +private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def dataType: DataType = BinaryType @@ -512,7 +513,7 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def deserialize(bytes: Array[Byte]): Array[Byte] = bytes } -private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) +private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) extends ByteArrayColumnType[Decimal](12) { override val dataType: DataType = DecimalType(precision, scale) @@ -539,16 +540,16 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) } } -private[sql] object LARGE_DECIMAL { +private[columnar] object LARGE_DECIMAL { def apply(dt: DecimalType): LARGE_DECIMAL = { LARGE_DECIMAL(dt.precision, dt.scale) } } -private[sql] case class STRUCT(dataType: StructType) +private[columnar] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { - private val numOfFields: Int = dataType.fields.size + private val numOfFields: Int = dataType.fields.length override def defaultSize: Int = 20 @@ -574,11 +575,10 @@ private[sql] case class STRUCT(dataType: StructType) assert(buffer.hasArray) val cursor = buffer.position() buffer.position(cursor + sizeInBytes) - val unsafeRow = new UnsafeRow + val unsafeRow = new UnsafeRow(numOfFields) unsafeRow.pointTo( buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, - numOfFields, sizeInBytes) unsafeRow } @@ -586,7 +586,7 @@ private[sql] case class STRUCT(dataType: StructType) override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) +private[columnar] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { override def defaultSize: Int = 16 @@ -625,7 +625,7 @@ private[sql] case class ARRAY(dataType: ArrayType) override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) +private[columnar] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { override def defaultSize: Int = 32 @@ -663,7 +663,8 @@ private[sql] case class MAP(dataType: MapType) override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() } -private[sql] object ColumnType { +private[columnar] object ColumnType { + @tailrec def apply(dataType: DataType): ColumnType[_] = { dataType match { case NullType => NULL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala similarity index 75% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ff9393b465b7a..e2e33e32463fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator, UnsafeRowWriter} import org.apache.spark.sql.types._ /** @@ -88,7 +88,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case array: ArrayType => classOf[ArrayColumnAccessor].getName case t: MapType => classOf[MapColumnAccessor].getName } - ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;") + ctx.addMutableState(accessorCls, accessorName, "") val createCode = dt match { case t if ctx.isPrimitiveType(dt) => @@ -114,6 +114,42 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera (createCode, extract + patch) }.unzip + /* + * 200 = 6000 bytes / 30 (up to 30 bytes per one call)) + * the maximum byte code size to be compiled for HotSpot is 8000. + * We should keep less than 8000 + */ + val numberOfStatementsThreshold = 200 + val (initializerAccessorCalls, extractorCalls) = + if (initializeAccessors.length <= numberOfStatementsThreshold) { + (initializeAccessors.mkString("\n"), extractors.mkString("\n")) + } else { + val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) + val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) + var groupedAccessorsLength = 0 + groupedAccessorsItr.zipWithIndex.map { case (body, i) => + groupedAccessorsLength += 1 + val funcName = s"accessors$i" + val funcCode = s""" + |private void $funcName() { + | ${body.mkString("\n")} + |} + """.stripMargin + ctx.addNewFunction(funcName, funcCode) + } + groupedExtractorsItr.zipWithIndex.map { case (body, i) => + val funcName = s"extractors$i" + val funcCode = s""" + |private void $funcName() { + | ${body.mkString("\n")} + |} + """.stripMargin + ctx.addNewFunction(funcName, funcCode) + } + ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"), + (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) + } + val code = s""" import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -121,9 +157,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; - import org.apache.spark.sql.columnar.MutableUnsafeRow; + import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; - public SpecificColumnarIterator generate($exprType[] expr) { + public SpecificColumnarIterator generate(Object[] references) { return new SpecificColumnarIterator(); } @@ -131,9 +167,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow(); - private BufferHolder bufferHolder = new BufferHolder(); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private UnsafeRow unsafeRow = new UnsafeRow($numFields); + private BufferHolder bufferHolder = new BufferHolder(unsafeRow); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -143,14 +179,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private DataType[] columnTypes = null; private int[] columnIndexes = null; - ${declareMutableStates(ctx)} + ${ctx.declareMutableStates()} public SpecificColumnarIterator() { this.nativeOrder = ByteOrder.nativeOrder(); this.buffers = new byte[${columnTypes.length}][]; this.mutableRow = new MutableUnsafeRow(rowWriter); - - ${initMutableStates(ctx)} } public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { @@ -159,6 +193,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera this.columnIndexes = columnIndexes; } + ${ctx.declareAddedFunctions()} + public boolean hasNext() { if (currentRow < numRowsInBatch) { return true; @@ -173,7 +209,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera for (int i = 0; i < columnIndexes.length; i ++) { buffers[i] = batch.buffers()[columnIndexes[i]]; } - ${initializeAccessors.mkString("\n")} + ${initializerAccessorCalls} return hasNext(); } @@ -181,15 +217,15 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; bufferHolder.reset(); - rowWriter.initialize(bufferHolder, $numFields); - ${extractors.mkString("\n")} - unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); + rowWriter.zeroOutNullBytes(); + ${extractorCalls} + unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } }""" logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") - compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator] + CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 7eb1ad7cd8198..1f964b1fc1dce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -15,21 +15,24 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Accumulable, Accumulator, Accumulators} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel -import org.apache.spark.{Accumulable, Accumulator, Accumulators} private[sql] object InMemoryRelation { def apply( @@ -38,9 +41,7 @@ private[sql] object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, - if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), - tableName)() + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() } /** @@ -50,7 +51,8 @@ private[sql] object InMemoryRelation { * @param buffers The buffers for serialized columns * @param stats The stat of columns */ -private[sql] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) +private[columnar] +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -59,10 +61,12 @@ private[sql] case class InMemoryRelation( storageLevel: StorageLevel, @transient child: SparkPlan, tableName: Option[String])( - @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, - @transient private var _statistics: Statistics = null, - private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) - extends LogicalPlan with MultiInstanceRelation { + @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private[sql] var _statistics: Statistics = null, + private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) + extends logical.LeafNode with MultiInstanceRelation { + + override def producedAttributes: AttributeSet = outputSet private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = if (_batchStats == null) { @@ -125,7 +129,7 @@ private[sql] case class InMemoryRelation( private def buildBuffers(): Unit = { val output = child.output - val cached = child.execute().mapPartitions { rowIterator => + val cached = child.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -133,7 +137,9 @@ private[sql] case class InMemoryRelation( }.toArray var rowCount = 0 - while (rowIterator.hasNext && rowCount < batchSize) { + var totalSize = 0L + while (rowIterator.hasNext && rowCount < batchSize + && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) { val row = rowIterator.next() // Added for SPARK-6082. This assertion can be useful for scenarios when something @@ -141,14 +147,16 @@ private[sql] case class InMemoryRelation( // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat // hard to decipher. assert( - row.numFields == columnBuilders.size, + row.numFields == columnBuilders.length, s"Row column number mismatch, expected ${output.size} columns, " + s"but got ${row.numFields}." + s"\nRow content: $row") var i = 0 + totalSize = 0 while (i < row.numFields) { columnBuilders(i).appendFrom(row, i) + totalSize += columnBuilders(i).columnStats.sizeInBytes i += 1 } rowCount += 1 @@ -158,7 +166,9 @@ private[sql] case class InMemoryRelation( .flatMap(_.values)) batchStats += stats - CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats) + CachedBatch(rowCount, columnBuilders.map { builder => + JavaUtils.bufferToArray(builder.build()) + }, stats) } def hasNext: Boolean = rowIterator.hasNext @@ -175,8 +185,6 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers, statisticsToBePropagated, batchStats) } - override def children: Seq[LogicalPlan] = Seq.empty - override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), @@ -208,6 +216,9 @@ private[sql] case class InMemoryColumnarTableScan( @transient relation: InMemoryRelation) extends LeafNode { + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = attributes // The cached version does not change the outputPartitioning of the original SparkPlan. @@ -216,8 +227,6 @@ private[sql] case class InMemoryColumnarTableScan( // The cached version does not change the outputOrdering of the original SparkPlan. override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering - override def outputsUnsafeRows: Boolean = true - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression @@ -280,6 +289,8 @@ private[sql] case class InMemoryColumnarTableScan( private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + if (enableAccumulators) { readPartitions.setValue(0) readBatches.setValue(0) @@ -292,7 +303,7 @@ private[sql] case class InMemoryColumnarTableScan( val relOutput = relation.output val buffers = relation.cachedColumnBuffers - buffers.mapPartitions { cachedBatchIterator => + buffers.mapPartitionsInternal { cachedBatchIterator => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) @@ -326,12 +337,18 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator } + // update SQL metrics + val withMetrics = cachedBatchesToScan.map { batch => + numOutputRows += batch.numRows + batch + } + val columnTypes = requestedColumnDataTypes.map { case udt: UserDefinedType[_] => udt.sqlType case other => other }.toArray val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(cachedBatchesToScan, columnTypes, requestedColumnIndices.toArray) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) if (enableAccumulators && columnarIterator.hasNext) { readPartitions += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 7eaecfe047c3f..2465633162c4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.expressions.MutableRow -private[sql] trait NullableColumnAccessor extends ColumnAccessor { +private[columnar] trait NullableColumnAccessor extends ColumnAccessor { private var nullsBuffer: ByteBuffer = _ private var nullCount: Int = _ private var seenNulls: Int = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala index 76cfddf1cd01a..3a1931bfb5c84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow * +---+-----+---------+ * }}} */ -private[sql] trait NullableColumnBuilder extends ColumnBuilder { +private[columnar] trait NullableColumnBuilder extends ColumnBuilder { protected var nulls: ByteBuffer = _ protected var nullCount: Int = _ private var pos: Int = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index cb205defbb1ad..6579b5068e65a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { +private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 161021ff96154..63eae1b8685ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} +import org.apache.spark.sql.execution.columnar.{ColumnBuilder, NativeColumnBuilder} import org.apache.spark.sql.types.AtomicType /** @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.AtomicType * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: AtomicType] +private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala similarity index 83% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 9322b772fd898..b90d00b15b180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -15,15 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: AtomicType] { +private[columnar] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} def compressedSize: Int @@ -37,13 +38,13 @@ private[sql] trait Encoder[T <: AtomicType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: AtomicType] { +private[columnar] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean } -private[sql] trait CompressionScheme { +private[columnar] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_]): Boolean @@ -53,15 +54,15 @@ private[sql] trait CompressionScheme { def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } -private[sql] trait WithCompressionSchemes { +private[columnar] trait WithCompressionSchemes { def schemes: Seq[CompressionScheme] } -private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { +private[columnar] trait AllCompressionSchemes extends WithCompressionSchemes { override val schemes: Seq[CompressionScheme] = CompressionScheme.all } -private[sql] object CompressionScheme { +private[columnar] object CompressionScheme { val all: Seq[CompressionScheme] = Seq(PassThrough, RunLengthEncoding, DictionaryEncoding, BooleanBitSet, IntDelta, LongDelta) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 41c9a284e3e4a..941f03b745a07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer @@ -23,11 +23,11 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ -private[sql] case object PassThrough extends CompressionScheme { +private[columnar] case object PassThrough extends CompressionScheme { override val typeId = 0 override def supports(columnType: ColumnType[_]): Boolean = true @@ -64,7 +64,7 @@ private[sql] case object PassThrough extends CompressionScheme { } } -private[sql] case object RunLengthEncoding extends CompressionScheme { +private[columnar] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { @@ -172,7 +172,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } -private[sql] case object DictionaryEncoding extends CompressionScheme { +private[columnar] case object DictionaryEncoding extends CompressionScheme { override val typeId = 2 // 32K unique values allowed @@ -281,7 +281,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } -private[sql] case object BooleanBitSet extends CompressionScheme { +private[columnar] case object BooleanBitSet extends CompressionScheme { override val typeId = 3 val BITS_PER_LONG = 64 @@ -371,7 +371,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { } } -private[sql] case object IntDelta extends CompressionScheme { +private[columnar] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -451,7 +451,7 @@ private[sql] case object IntDelta extends CompressionScheme { } } -private[sql] case object LongDelta extends CompressionScheme { +private[columnar] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala new file mode 100644 index 0000000000000..5d00c805a6afe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -0,0 +1,530 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import java.util.NoSuchElementException + +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.debug._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * A logical command that is executed for its side-effects. `RunnableCommand`s are + * wrapped in `ExecutedCommand` during execution. + */ +private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { + override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty + def run(sqlContext: SQLContext): Seq[Row] +} + +/** + * A physical operator that executes the run method of a `RunnableCommand` and + * saves the result to prevent multiple executions. + */ +private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. + */ + protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + cmd.run(sqlContext).map(converter(_).asInstanceOf[InternalRow]) + } + + override def output: Seq[Attribute] = cmd.output + + override def children: Seq[SparkPlan] = Nil + + override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray + + override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray + + protected override def doExecute(): RDD[InternalRow] = { + sqlContext.sparkContext.parallelize(sideEffectResult, 1) + } + + override def argString: String = cmd.toString +} + + +case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging { + + private def keyValueOutput: Seq[Attribute] = { + val schema = StructType( + StructField("key", StringType, false) :: + StructField("value", StringType, false) :: Nil) + schema.toAttributes + } + + private val (_output, runFunc): (Seq[Attribute], SQLContext => Seq[Row]) = kv match { + // Configures the deprecated "mapred.reduce.tasks" property. + case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + if (value.toInt < 1) { + val msg = + s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " + + "determining the number of reducers is not supported." + throw new IllegalArgumentException(msg) + } else { + sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value) + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) + } + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + + s"External sort will continue to be used.") + Seq(Row(SQLConf.Deprecated.EXTERNAL_SORT, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " + + s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " + + s"continue to be true.") + Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.TUNGSTEN_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.TUNGSTEN_ENABLED} is deprecated and " + + s"will be ignored. Tungsten will continue to be used.") + Seq(Row(SQLConf.Deprecated.TUNGSTEN_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.CODEGEN_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.CODEGEN_ENABLED} is deprecated and " + + s"will be ignored. Codegen will continue to be used.") + Seq(Row(SQLConf.Deprecated.CODEGEN_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.UNSAFE_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.UNSAFE_ENABLED} is deprecated and " + + s"will be ignored. Unsafe mode will continue to be used.") + Seq(Row(SQLConf.Deprecated.UNSAFE_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.SORTMERGE_JOIN, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.SORTMERGE_JOIN} is deprecated and " + + s"will be ignored. Sort merge join will continue to be used.") + Seq(Row(SQLConf.Deprecated.SORTMERGE_JOIN, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED} is " + + s"deprecated and will be ignored. Vectorized parquet reader will be used instead.") + Seq(Row(SQLConf.PARQUET_VECTORIZED_READER_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + // Configures a single property. + case Some((key, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.setConf(key, value) + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) + + // (In Hive, "SET" returns all changed properties while "SET -v" returns all properties.) + // Queries all key-value pairs that are set in the SQLConf of the sqlContext. + case None => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq + } + (keyValueOutput, runFunc) + + // Queries all properties along with their default values and docs that are defined in the + // SQLConf of the sqlContext. + case Some(("-v", None)) => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) => + Row(key, defaultValue, doc) + } + } + val schema = StructType( + StructField("key", StringType, false) :: + StructField("default", StringType, false) :: + StructField("meaning", StringType, false) :: Nil) + (schema.toAttributes, runFunc) + + // Queries the deprecated "mapred.reduce.tasks" property. + case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, sqlContext.conf.numShufflePartitions.toString)) + } + (keyValueOutput, runFunc) + + // Queries a single property. + case Some((key, None)) => + val runFunc = (sqlContext: SQLContext) => { + val value = + try sqlContext.getConf(key) catch { + case _: NoSuchElementException => "" + } + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) + } + + override val output: Seq[Attribute] = _output + + override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext) + +} + +/** + * An explain command for users to see how a command will be executed. + * + * Note that this command takes in a logical plan, runs the optimizer on the logical plan + * (but do NOT actually execute it). + */ +case class ExplainCommand( + logicalPlan: LogicalPlan, + override val output: Seq[Attribute] = + Seq(AttributeReference("plan", StringType, nullable = true)()), + extended: Boolean = false, + codegen: Boolean = false) + extends RunnableCommand { + + // Run through the optimizer to generate the physical plan. + override def run(sqlContext: SQLContext): Seq[Row] = try { + // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. + val queryExecution = sqlContext.executePlan(logicalPlan) + val outputString = + if (codegen) { + codegenString(queryExecution.executedPlan) + } else if (extended) { + queryExecution.toString + } else { + queryExecution.simpleString + } + Seq(Row(outputString)) + } catch { case cause: TreeNodeException[_] => + ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) + } +} + + +case class CacheTableCommand( + tableName: String, + plan: Option[LogicalPlan], + isLazy: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + plan.foreach { logicalPlan => + sqlContext.registerDataFrameAsTable(Dataset.ofRows(sqlContext, logicalPlan), tableName) + } + sqlContext.cacheTable(tableName) + + if (!isLazy) { + // Performs eager caching + sqlContext.table(tableName).count() + } + + Seq.empty[Row] + } + + override def output: Seq[Attribute] = Seq.empty +} + + +case class UncacheTableCommand(tableName: String) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.table(tableName).unpersist(blocking = false) + Seq.empty[Row] + } + + override def output: Seq[Attribute] = Seq.empty +} + +/** + * Clear all cached data from the in-memory cache. + */ +case object ClearCacheCommand extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.clearCache() + Seq.empty[Row] + } + + override def output: Seq[Attribute] = Seq.empty +} + + +case class DescribeCommand( + table: TableIdentifier, + override val output: Seq[Attribute], + isExtended: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val relation = sqlContext.sessionState.catalog.lookupRelation(table) + relation.schema.fields.map { field => + val cmtKey = "comment" + val comment = if (field.metadata.contains(cmtKey)) field.metadata.getString(cmtKey) else "" + Row(field.name, field.dataType.simpleString, comment) + } + } +} + +/** + * A command for users to get tables in the given database. + * If a databaseName is not given, the current database will be used. + * The syntax of using this command in SQL is: + * {{{ + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; + * }}} + */ +case class ShowTablesCommand( + databaseName: Option[String], + tableIdentifierPattern: Option[String]) extends RunnableCommand { + + // The result of SHOW TABLES has two columns, tableName and isTemporary. + override val output: Seq[Attribute] = { + AttributeReference("tableName", StringType, nullable = false)() :: + AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we need to return a Seq of rows, we will call getTables directly + // instead of calling tables in sqlContext. + val catalog = sqlContext.sessionState.catalog + val db = databaseName.getOrElse(catalog.getCurrentDatabase) + val tables = + tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) + tables.map { t => + val isTemp = t.database.isEmpty + Row(t.table, isTemp) + } + } +} + +/** + * A command for users to list the databases/schemas. + * If a databasePattern is supplied then the databases that only matches the + * pattern would be listed. + * The syntax of using this command in SQL is: + * {{{ + * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; + * }}} + */ +case class ShowDatabasesCommand(databasePattern: Option[String]) extends RunnableCommand { + + // The result of SHOW DATABASES has one column called 'result' + override val output: Seq[Attribute] = { + AttributeReference("result", StringType, nullable = false)() :: Nil + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val databases = + databasePattern.map(catalog.listDatabases(_)).getOrElse(catalog.listDatabases()) + databases.map { d => Row(d) } + } +} + +/** + * A command for users to list the properties for a table If propertyKey is specified, the value + * for the propertyKey is returned. If propertyKey is not specified, all the keys and their + * corresponding values are returned. + * The syntax of using this command in SQL is: + * {{{ + * SHOW TBLPROPERTIES table_name[('propertyKey')]; + * }}} + */ +case class ShowTablePropertiesCommand( + table: TableIdentifier, + propertyKey: Option[String]) extends RunnableCommand { + + override val output: Seq[Attribute] = { + val schema = AttributeReference("value", StringType, nullable = false)() :: Nil + propertyKey match { + case None => AttributeReference("key", StringType, nullable = false)() :: schema + case _ => schema + } + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + + if (catalog.isTemporaryTable(table)) { + Seq.empty[Row] + } else { + val catalogTable = sqlContext.sessionState.catalog.getTableMetadata(table) + + propertyKey match { + case Some(p) => + val propValue = catalogTable + .properties + .getOrElse(p, s"Table ${catalogTable.qualifiedName} does not have property: $p") + Seq(Row(propValue)) + case None => + catalogTable.properties.map(p => Row(p._1, p._2)).toSeq + } + } + } +} + +/** + * A command for users to list all of the registered functions. + * The syntax of using this command in SQL is: + * {{{ + * SHOW FUNCTIONS [LIKE pattern] + * }}} + * For the pattern, '*' matches any sequence of characters (including no characters) and + * '|' is for alternation. + * For example, "show functions like 'yea*|windo*'" will return "window" and "year". + * + * TODO currently we are simply ignore the db + */ +case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("function", StringType, nullable = false) :: Nil) + + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val dbName = db.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase) + // If pattern is not specified, we use '*', which is used to + // match any sequence of characters (including no characters). + val functionNames = + sqlContext.sessionState.catalog + .listFunctions(dbName, pattern.getOrElse("*")) + .map(_.unquotedString) + // The session catalog caches some persistent functions in the FunctionRegistry + // so there can be duplicates. + functionNames.distinct.sorted.map(Row(_)) + } +} + +/** + * A command for users to get the usage of a registered function. + * The syntax of using this command in SQL is + * {{{ + * DESCRIBE FUNCTION [EXTENDED] upper; + * }}} + */ +case class DescribeFunction( + functionName: String, + isExtended: Boolean) extends RunnableCommand { + + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("function_desc", StringType, nullable = false) :: Nil) + + schema.toAttributes + } + + private def replaceFunctionName(usage: String, functionName: String): String = { + if (usage == null) { + "To be added." + } else { + usage.replaceAll("_FUNC_", functionName) + } + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + // Hard code "<>", "!=", "between", and "case" for now as there is no corresponding functions. + functionName.toLowerCase match { + case "<>" => + Row(s"Function: $functionName") :: + Row(s"Usage: a <> b - Returns TRUE if a is not equal to b") :: Nil + case "!=" => + Row(s"Function: $functionName") :: + Row(s"Usage: a != b - Returns TRUE if a is not equal to b") :: Nil + case "between" => + Row(s"Function: between") :: + Row(s"Usage: a [NOT] BETWEEN b AND c - " + + s"evaluate if a is [not] in between b and c") :: Nil + case "case" => + Row(s"Function: case") :: + Row(s"Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + + s"When a = b, returns c; when a = d, return e; else return f") :: Nil + case _ => sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { + case Some(info) => + val result = + Row(s"Function: ${info.getName}") :: + Row(s"Class: ${info.getClassName}") :: + Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil + + if (isExtended) { + result :+ + Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") + } else { + result + } + + case None => Seq(Row(s"Function: $functionName not found.")) + } + } + } +} + +case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.setCurrentDatabase(databaseName) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala new file mode 100644 index 0000000000000..fc37a142cda1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -0,0 +1,540 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.types._ + + + +// Note: The definition of these commands are based on the ones described in +// https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + +/** + * A DDL command expected to be parsed and run in an underlying system instead of in Spark. + */ +abstract class NativeDDLCommand(val sql: String) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.runNativeSql(sql) + } + + override val output: Seq[Attribute] = { + Seq(AttributeReference("result", StringType, nullable = false)()) + } + +} + +/** + * A command for users to create a new database. + * + * It will issue an error message when the database with the same name already exists, + * unless 'ifNotExists' is true. + * The syntax of using this command in SQL is: + * {{{ + * CREATE DATABASE|SCHEMA [IF NOT EXISTS] database_name + * }}} + */ +case class CreateDatabase( + databaseName: String, + ifNotExists: Boolean, + path: Option[String], + comment: Option[String], + props: Map[String, String]) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + catalog.createDatabase( + CatalogDatabase( + databaseName, + comment.getOrElse(""), + path.getOrElse(catalog.getDefaultDBPath(databaseName)), + props), + ifNotExists) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} + + +/** + * A command for users to remove a database from the system. + * + * 'ifExists': + * - true, if database_name does't exist, no action + * - false (default), if database_name does't exist, a warning message will be issued + * 'cascade': + * - true, the dependent objects are automatically dropped before dropping database. + * - false (default), it is in the Restrict mode. The database cannot be dropped if + * it is not empty. The inclusive tables must be dropped at first. + * + * The syntax of using this command in SQL is: + * {{{ + * DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]; + * }}} + */ +case class DropDatabase( + databaseName: String, + ifExists: Boolean, + cascade: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} + +/** + * A command for users to add new (key, value) pairs into DBPROPERTIES + * If the database does not exist, an error message will be issued to indicate the database + * does not exist. + * The syntax of using this command in SQL is: + * {{{ + * ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) + * }}} + */ +case class AlterDatabaseProperties( + databaseName: String, + props: Map[String, String]) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val db: CatalogDatabase = catalog.getDatabaseMetadata(databaseName) + catalog.alterDatabase(db.copy(properties = db.properties ++ props)) + + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} + +/** + * A command for users to show the name of the database, its comment (if one has been set), and its + * root location on the filesystem. When extended is true, it also shows the database's properties + * If the database does not exist, an error message will be issued to indicate the database + * does not exist. + * The syntax of using this command in SQL is + * {{{ + * DESCRIBE DATABASE [EXTENDED] db_name + * }}} + */ +case class DescribeDatabase( + databaseName: String, + extended: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val dbMetadata: CatalogDatabase = + sqlContext.sessionState.catalog.getDatabaseMetadata(databaseName) + val result = + Row("Database Name", dbMetadata.name) :: + Row("Description", dbMetadata.description) :: + Row("Location", dbMetadata.locationUri) :: Nil + + if (extended) { + val properties = + if (dbMetadata.properties.isEmpty) { + "" + } else { + dbMetadata.properties.toSeq.mkString("(", ", ", ")") + } + result :+ Row("Properties", properties) + } else { + result + } + } + + override val output: Seq[Attribute] = { + AttributeReference("database_description_item", StringType, nullable = false)() :: + AttributeReference("database_description_value", StringType, nullable = false)() :: Nil + } +} + +/** + * Drops a table/view from the metastore and removes it if it is cached. + * + * The syntax of this command is: + * {{{ + * DROP TABLE [IF EXISTS] table_name; + * DROP VIEW [IF EXISTS] [db_name.]view_name; + * }}} + */ +case class DropTable( + tableName: TableIdentifier, + ifExists: Boolean, + isView: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (!catalog.tableExists(tableName)) { + if (!ifExists) { + val objectName = if (isView) "View" else "Table" + logError(s"$objectName '${tableName.quotedString}' does not exist") + } + } else { + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadataOption(tableName).map(_.tableType match { + case CatalogTableType.VIRTUAL_VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIRTUAL_VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") + case _ => + }) + try { + sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString)) + } catch { + case NonFatal(e) => log.warn(s"${e.getMessage}", e) + } + catalog.invalidateTable(tableName) + catalog.dropTable(tableName, ifExists) + } + Seq.empty[Row] + } +} + +/** + * A command that sets table/view properties. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...); + * ALTER VIEW view1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...); + * }}} + */ +case class AlterTableSetProperties( + tableName: TableIdentifier, + properties: Map[String, String], + isView: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, tableName, isView) + val table = catalog.getTableMetadata(tableName) + val newProperties = table.properties ++ properties + if (DDLUtils.isDatasourceTable(newProperties)) { + throw new AnalysisException( + "alter table properties is not supported for tables defined using the datasource API") + } + val newTable = table.copy(properties = newProperties) + catalog.alterTable(newTable) + Seq.empty[Row] + } + +} + +/** + * A command that unsets table/view properties. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...); + * ALTER VIEW view1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...); + * }}} + */ +case class AlterTableUnsetProperties( + tableName: TableIdentifier, + propKeys: Seq[String], + ifExists: Boolean, + isView: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, tableName, isView) + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table properties is not supported for datasource tables") + } + if (!ifExists) { + propKeys.foreach { k => + if (!table.properties.contains(k)) { + throw new AnalysisException( + s"attempted to unset non-existent property '$k' in table '$tableName'") + } + } + } + val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) } + val newTable = table.copy(properties = newProperties) + catalog.alterTable(newTable) + Seq.empty[Row] + } + +} + +/** + * A command that sets the serde class and/or serde properties of a table/view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props]; + * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties; + * }}} + */ +case class AlterTableSerDeProperties( + tableName: TableIdentifier, + serdeClassName: Option[String], + serdeProperties: Option[Map[String, String]], + partition: Option[Map[String, String]]) + extends RunnableCommand { + + // should never happen if we parsed things correctly + require(serdeClassName.isDefined || serdeProperties.isDefined, + "alter table attempted to set neither serde class name nor serde properties") + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + // Do not support setting serde for datasource tables + if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table serde is not supported for datasource tables") + } + val newTable = table.withNewStorage( + serde = serdeClassName.orElse(table.storage.serde), + serdeProperties = table.storage.serdeProperties ++ serdeProperties.getOrElse(Map())) + catalog.alterTable(newTable) + Seq.empty[Row] + } + +} + +/** + * Add Partition in ALTER TABLE: add the table partitions. + * + * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE, + * EXCEPT that it is ILLEGAL to specify a LOCATION clause. + * An error message will be issued if the partition exists, unless 'ifNotExists' is true. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] + * }}} + */ +case class AlterTableAddPartition( + tableName: TableIdentifier, + partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])], + ifNotExists: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table add partition is not allowed for tables defined using the datasource API") + } + val parts = partitionSpecsAndLocs.map { case (spec, location) => + // inherit table storage format (possibly except for location) + CatalogTablePartition(spec, table.storage.copy(locationUri = location)) + } + catalog.createPartitions(tableName, parts, ignoreIfExists = ifNotExists) + Seq.empty[Row] + } + +} + +/** + * Alter a table partition's spec. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2; + * }}} + */ +case class AlterTableRenamePartition( + tableName: TableIdentifier, + oldPartition: TablePartitionSpec, + newPartition: TablePartitionSpec) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.renamePartitions( + tableName, Seq(oldPartition), Seq(newPartition)) + Seq.empty[Row] + } + +} + +/** + * Drop Partition in ALTER TABLE: to drop a particular partition for a table. + * + * This removes the data and metadata for this partition. + * The data is actually moved to the .Trash/Current directory if Trash is configured, + * unless 'purge' is true, but the metadata is completely lost. + * An error message will be issued if the partition does not exist, unless 'ifExists' is true. + * Note: purge is always false when the target is a view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; + * }}} + */ +case class AlterTableDropPartition( + tableName: TableIdentifier, + specs: Seq[TablePartitionSpec], + ifExists: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table drop partition is not allowed for tables defined using the datasource API") + } + catalog.dropPartitions(tableName, specs, ignoreIfNotExists = ifExists) + Seq.empty[Row] + } + +} + +case class AlterTableSetFileFormat( + tableName: TableIdentifier, + partitionSpec: Option[TablePartitionSpec], + fileFormat: Seq[String], + genericFormat: Option[String])(sql: String) + extends NativeDDLCommand(sql) with Logging + +/** + * A command that sets the location of a table or a partition. + * + * For normal tables, this just sets the location URI in the table/partition's storage format. + * For datasource tables, this sets a "path" parameter in the table/partition's serde properties. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table_name [PARTITION partition_spec] SET LOCATION "loc"; + * }}} + */ +case class AlterTableSetLocation( + tableName: TableIdentifier, + partitionSpec: Option[TablePartitionSpec], + location: String) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + partitionSpec match { + case Some(spec) => + // Partition spec is specified, so we set the location only for this partition + val part = catalog.getPartition(tableName, spec) + val newPart = + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table set location for partition is not allowed for tables defined " + + "using the datasource API") + } else { + part.copy(storage = part.storage.copy(locationUri = Some(location))) + } + catalog.alterPartitions(tableName, Seq(newPart)) + case None => + // No partition spec is specified, so we set the location for the table itself + val newTable = + if (DDLUtils.isDatasourceTable(table)) { + table.withNewStorage( + locationUri = Some(location), + serdeProperties = table.storage.serdeProperties ++ Map("path" -> location)) + } else { + table.withNewStorage(locationUri = Some(location)) + } + catalog.alterTable(newTable) + } + Seq.empty[Row] + } + +} + +case class AlterTableChangeCol( + tableName: TableIdentifier, + partitionSpec: Option[TablePartitionSpec], + oldColName: String, + newColName: String, + dataType: DataType, + comment: Option[String], + afterColName: Option[String], + restrict: Boolean, + cascade: Boolean)(sql: String) + extends NativeDDLCommand(sql) with Logging + +case class AlterTableAddCol( + tableName: TableIdentifier, + partitionSpec: Option[TablePartitionSpec], + columns: StructType, + restrict: Boolean, + cascade: Boolean)(sql: String) + extends NativeDDLCommand(sql) with Logging + +case class AlterTableReplaceCol( + tableName: TableIdentifier, + partitionSpec: Option[TablePartitionSpec], + columns: StructType, + restrict: Boolean, + cascade: Boolean)(sql: String) + extends NativeDDLCommand(sql) with Logging + + +private object DDLUtils { + + def isDatasourceTable(props: Map[String, String]): Boolean = { + props.contains("spark.sql.sources.provider") + } + + def isDatasourceTable(table: CatalogTable): Boolean = { + isDatasourceTable(table.properties) + } + + /** + * If the command ALTER VIEW is to alter a table or ALTER TABLE is to alter a view, + * issue an exception [[AnalysisException]]. + */ + def verifyAlterTableType( + catalog: SessionCatalog, + tableIdentifier: TableIdentifier, + isView: Boolean): Unit = { + catalog.getTableMetadataOption(tableIdentifier).map(_.tableType match { + case CatalogTableType.VIRTUAL_VIEW if !isView => + throw new AnalysisException( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead") + case o if o != CatalogTableType.VIRTUAL_VIEW && isView => + throw new AnalysisException( + s"Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead") + case _ => + }) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala new file mode 100644 index 0000000000000..c6e601799f527 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo + + +/** + * The DDL command that creates a function. + * To create a temporary function, the syntax of using this command in SQL is: + * {{{ + * CREATE TEMPORARY FUNCTION functionName + * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] + * }}} + * + * To create a permanent function, the syntax in SQL is: + * {{{ + * CREATE FUNCTION [databaseName.]functionName + * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] + * }}} + */ +// TODO: Use Seq[FunctionResource] instead of Seq[(String, String)] for resources. +case class CreateFunction( + databaseName: Option[String], + functionName: String, + className: String, + resources: Seq[(String, String)], + isTemp: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (isTemp) { + if (databaseName.isDefined) { + throw new AnalysisException( + s"It is not allowed to provide database name when defining a temporary function. " + + s"However, database name ${databaseName.get} is provided.") + } + // We first load resources and then put the builder in the function registry. + // Please note that it is allowed to overwrite an existing temp function. + catalog.loadFunctionResources(resources) + val info = new ExpressionInfo(className, functionName) + val builder = catalog.makeFunctionBuilder(functionName, className) + catalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + } else { + // For a permanent, we will store the metadata into underlying external catalog. + // This function will be loaded into the FunctionRegistry when a query uses it. + // We do not load it into FunctionRegistry right now. + // TODO: should we also parse "IF NOT EXISTS"? + catalog.createFunction( + CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources), + ignoreIfExists = false) + } + Seq.empty[Row] + } +} + +/** + * The DDL command that drops a function. + * ifExists: returns an error if the function doesn't exist, unless this is true. + * isTemp: indicates if it is a temporary function. + */ +case class DropFunction( + databaseName: Option[String], + functionName: String, + ifExists: Boolean, + isTemp: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (isTemp) { + if (databaseName.isDefined) { + throw new AnalysisException( + s"It is not allowed to provide database name when dropping a temporary function. " + + s"However, database name ${databaseName.get} is provided.") + } + catalog.dropTempFunction(functionName, ifExists) + } else { + // We are dropping a permanent function. + catalog.dropFunction( + FunctionIdentifier(functionName, databaseName), + ignoreIfNotExists = ifExists) + } + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala new file mode 100644 index 0000000000000..0b419851746ad --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} + +/** + * A command to create a table with the same definition of the given existing table. + * + * The syntax of using this command in SQL is: + * {{{ + * CREATE TABLE [IF NOT EXISTS] [db_name.]table_name + * LIKE [other_db_name.]existing_table_name + * }}} + */ +case class CreateTableLike( + targetTable: TableIdentifier, + sourceTable: TableIdentifier, + ifNotExists: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (!catalog.tableExists(sourceTable)) { + throw new AnalysisException( + s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'") + } + if (catalog.isTemporaryTable(sourceTable)) { + throw new AnalysisException( + s"Source table in CREATE TABLE LIKE cannot be temporary: '$sourceTable'") + } + + val tableToCreate = catalog.getTableMetadata(sourceTable).copy( + identifier = targetTable, + tableType = CatalogTableType.MANAGED_TABLE, + createTime = System.currentTimeMillis, + lastAccessTime = -1).withNewStorage(locationUri = None) + + catalog.createTable(tableToCreate, ifNotExists) + Seq.empty[Row] + } +} + + +// TODO: move the rest of the table commands from ddl.scala to this file + +/** + * A command to create a table. + * + * Note: This is currently used only for creating Hive tables. + * This is not intended for temporary tables. + * + * The syntax of using this command in SQL is: + * {{{ + * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1 data_type [COMMENT col_comment], ...)] + * [COMMENT table_comment] + * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] + * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] + * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) + * [STORED AS DIRECTORIES] + * [ROW FORMAT row_format] + * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] + * [AS select_statement]; + * }}} + */ +case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.createTable(table, ifNotExists) + Seq.empty[Row] + } + +} + + +/** + * A command that renames a table/view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; + * }}} + */ +case class AlterTableRename( + oldName: TableIdentifier, + newName: TableIdentifier, + isView: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, oldName, isView) + catalog.invalidateTable(oldName) + catalog.renameTable(oldName, newName) + Seq.empty[Row] + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala deleted file mode 100644 index e5f60b15e7359..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ /dev/null @@ -1,370 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.NoSuchElementException - -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} - -/** - * A logical command that is executed for its side-effects. `RunnableCommand`s are - * wrapped in `ExecutedCommand` during execution. - */ -private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty - def run(sqlContext: SQLContext): Seq[Row] -} - -/** - * A physical operator that executes the run method of a `RunnableCommand` and - * saves the result to prevent multiple executions. - */ -private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { - /** - * A concrete command should override this lazy field to wrap up any side effects caused by the - * command or any other computation that should be evaluated exactly once. The value of this field - * can be used as the contents of the corresponding RDD generated from the physical plan of this - * command. - * - * The `execute()` method of all the physical command classes should reference `sideEffectResult` - * so that the command can be executed eagerly right after the command query is created. - */ - protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { - val converter = CatalystTypeConverters.createToCatalystConverter(schema) - cmd.run(sqlContext).map(converter(_).asInstanceOf[InternalRow]) - } - - override def output: Seq[Attribute] = cmd.output - - override def children: Seq[SparkPlan] = Nil - - override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray - - override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray - - protected override def doExecute(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(sideEffectResult, 1) - } - - override def argString: String = cmd.toString -} - - -case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging { - - private def keyValueOutput: Seq[Attribute] = { - val schema = StructType( - StructField("key", StringType, false) :: - StructField("value", StringType, false) :: Nil) - schema.toAttributes - } - - private val (_output, runFunc): (Seq[Attribute], SQLContext => Seq[Row]) = kv match { - // Configures the deprecated "mapred.reduce.tasks" property. - case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { - logWarning( - s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") - if (value.toInt < 1) { - val msg = - s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " + - "determining the number of reducers is not supported." - throw new IllegalArgumentException(msg) - } else { - sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value) - Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) - } - } - (keyValueOutput, runFunc) - - case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { - logWarning( - s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + - s"External sort will continue to be used.") - Seq(Row(SQLConf.Deprecated.EXTERNAL_SORT, "true")) - } - (keyValueOutput, runFunc) - - // Configures a single property. - case Some((key, Some(value))) => - val runFunc = (sqlContext: SQLContext) => { - sqlContext.setConf(key, value) - Seq(Row(key, value)) - } - (keyValueOutput, runFunc) - - // (In Hive, "SET" returns all changed properties while "SET -v" returns all properties.) - // Queries all key-value pairs that are set in the SQLConf of the sqlContext. - case None => - val runFunc = (sqlContext: SQLContext) => { - sqlContext.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq - } - (keyValueOutput, runFunc) - - // Queries all properties along with their default values and docs that are defined in the - // SQLConf of the sqlContext. - case Some(("-v", None)) => - val runFunc = (sqlContext: SQLContext) => { - sqlContext.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) => - Row(key, defaultValue, doc) - } - } - val schema = StructType( - StructField("key", StringType, false) :: - StructField("default", StringType, false) :: - StructField("meaning", StringType, false) :: Nil) - (schema.toAttributes, runFunc) - - // Queries the deprecated "mapred.reduce.tasks" property. - case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => - val runFunc = (sqlContext: SQLContext) => { - logWarning( - s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") - Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, sqlContext.conf.numShufflePartitions.toString)) - } - (keyValueOutput, runFunc) - - // Queries a single property. - case Some((key, None)) => - val runFunc = (sqlContext: SQLContext) => { - val value = - try { - if (key == SQLConf.DIALECT.key) { - sqlContext.conf.dialect - } else { - sqlContext.getConf(key) - } - } catch { - case _: NoSuchElementException => "" - } - Seq(Row(key, value)) - } - (keyValueOutput, runFunc) - } - - override val output: Seq[Attribute] = _output - - override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext) - -} - -/** - * An explain command for users to see how a command will be executed. - * - * Note that this command takes in a logical plan, runs the optimizer on the logical plan - * (but do NOT actually execute it). - */ -case class ExplainCommand( - logicalPlan: LogicalPlan, - override val output: Seq[Attribute] = - Seq(AttributeReference("plan", StringType, nullable = false)()), - extended: Boolean = false) - extends RunnableCommand { - - // Run through the optimizer to generate the physical plan. - override def run(sqlContext: SQLContext): Seq[Row] = try { - // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. - val queryExecution = sqlContext.executePlan(logicalPlan) - val outputString = if (extended) queryExecution.toString else queryExecution.simpleString - - outputString.split("\n").map(Row(_)) - } catch { case cause: TreeNodeException[_] => - ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) - } -} - - -case class CacheTableCommand( - tableName: String, - plan: Option[LogicalPlan], - isLazy: Boolean) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - plan.foreach { logicalPlan => - sqlContext.registerDataFrameAsTable(DataFrame(sqlContext, logicalPlan), tableName) - } - sqlContext.cacheTable(tableName) - - if (!isLazy) { - // Performs eager caching - sqlContext.table(tableName).count() - } - - Seq.empty[Row] - } - - override def output: Seq[Attribute] = Seq.empty -} - - -case class UncacheTableCommand(tableName: String) extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.table(tableName).unpersist(blocking = false) - Seq.empty[Row] - } - - override def output: Seq[Attribute] = Seq.empty -} - -/** - * Clear all cached data from the in-memory cache. - */ -case object ClearCacheCommand extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.clearCache() - Seq.empty[Row] - } - - override def output: Seq[Attribute] = Seq.empty -} - - -case class DescribeCommand( - child: SparkPlan, - override val output: Seq[Attribute], - isExtended: Boolean) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - child.schema.fields.map { field => - val cmtKey = "comment" - val comment = if (field.metadata.contains(cmtKey)) field.metadata.getString(cmtKey) else "" - Row(field.name, field.dataType.simpleString, comment) - } - } -} - -/** - * A command for users to get tables in the given database. - * If a databaseName is not given, the current database will be used. - * The syntax of using this command in SQL is: - * {{{ - * SHOW TABLES [IN databaseName] - * }}} - */ -case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { - - // The result of SHOW TABLES has two columns, tableName and isTemporary. - override val output: Seq[Attribute] = { - val schema = StructType( - StructField("tableName", StringType, false) :: - StructField("isTemporary", BooleanType, false) :: Nil) - - schema.toAttributes - } - - override def run(sqlContext: SQLContext): Seq[Row] = { - // Since we need to return a Seq of rows, we will call getTables directly - // instead of calling tables in sqlContext. - val rows = sqlContext.catalog.getTables(databaseName).map { - case (tableName, isTemporary) => Row(tableName, isTemporary) - } - - rows - } -} - -/** - * A command for users to list all of the registered functions. - * The syntax of using this command in SQL is: - * {{{ - * SHOW FUNCTIONS - * }}} - * TODO currently we are simply ignore the db - */ -case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand { - override val output: Seq[Attribute] = { - val schema = StructType( - StructField("function", StringType, nullable = false) :: Nil) - - schema.toAttributes - } - - override def run(sqlContext: SQLContext): Seq[Row] = pattern match { - case Some(p) => - try { - val regex = java.util.regex.Pattern.compile(p) - sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) - } catch { - // probably will failed in the regex that user provided, then returns empty row. - case _: Throwable => Seq.empty[Row] - } - case None => - sqlContext.functionRegistry.listFunction().map(Row(_)) - } -} - -/** - * A command for users to get the usage of a registered function. - * The syntax of using this command in SQL is - * {{{ - * DESCRIBE FUNCTION [EXTENDED] upper; - * }}} - */ -case class DescribeFunction( - functionName: String, - isExtended: Boolean) extends RunnableCommand { - - override val output: Seq[Attribute] = { - val schema = StructType( - StructField("function_desc", StringType, nullable = false) :: Nil) - - schema.toAttributes - } - - private def replaceFunctionName(usage: String, functionName: String): String = { - if (usage == null) { - "To be added." - } else { - usage.replaceAll("_FUNC_", functionName) - } - } - - override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.functionRegistry.lookupFunction(functionName) match { - case Some(info) => - val result = - Row(s"Function: ${info.getName}") :: - Row(s"Class: ${info.getClassName}") :: - Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil - - if (isExtended) { - result :+ Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") - } else { - result - } - - case None => Seq(Row(s"Function: $functionName is not found.")) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala new file mode 100644 index 0000000000000..41cff07472d1e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.SequenceFile.CompressionType +import org.apache.hadoop.io.compress.{BZip2Codec, DeflateCodec, GzipCodec, Lz4Codec, SnappyCodec} + +import org.apache.spark.util.Utils + +private[datasources] object CompressionCodecs { + private val shortCompressionCodecNames = Map( + "none" -> null, + "uncompressed" -> null, + "bzip2" -> classOf[BZip2Codec].getName, + "deflate" -> classOf[DeflateCodec].getName, + "gzip" -> classOf[GzipCodec].getName, + "lz4" -> classOf[Lz4Codec].getName, + "snappy" -> classOf[SnappyCodec].getName) + + /** + * Return the full version of the given codec class. + * If it is already a class name, just return it. + */ + def getCodecClassName(name: String): String = { + val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) + try { + // Validate the codec name + if (codecName != null) { + Utils.classForName(codecName) + } + codecName + } catch { + case e: ClassNotFoundException => + throw new IllegalArgumentException(s"Codec [$codecName] " + + s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.") + } + } + + /** + * Set compression configurations to Hadoop `Configuration`. + * `codec` should be a full class path + */ + def setCodecConfiguration(conf: Configuration, codec: String): Unit = { + if (codec != null) { + conf.set("mapreduce.output.fileoutputformat.compress", "true") + conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) + conf.set("mapreduce.output.fileoutputformat.compress.codec", codec) + conf.set("mapreduce.map.output.compress", "true") + conf.set("mapreduce.map.output.compress.codec", codec) + } else { + // This infers the option `compression` is set to `uncompressed` or `none`. + conf.set("mapreduce.output.fileoutputformat.compress", "false") + conf.set("mapreduce.map.output.compress", "false") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala deleted file mode 100644 index 6969b423d01b9..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ /dev/null @@ -1,186 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.datasources - -import scala.language.implicitConversions -import scala.util.matching.Regex - -import org.apache.spark.Logging -import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.types._ - - -/** - * A parser for foreign DDL commands. - */ -class DDLParser(parseQuery: String => LogicalPlan) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parse(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => parseQuery(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... - */ - protected lazy val createTable: Parser[LogicalPlan] = { - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. - val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableIdent, - provider, - temp.isDefined, - Array.empty[String], - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableIdent, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - } - - // This is the same as tableIdentifier in SqlParser. - protected lazy val tableIdentifier: Parser[TableIdentifier] = - (ident <~ ".").? ~ ident ^^ { - case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { - case e ~ tableIdent => - DescribeCommand(UnresolvedRelation(tableIdent, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> tableIdentifier ^^ { - case tableIndet => - RefreshTable(tableIndet) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala new file mode 100644 index 0000000000000..10fde152ab2a9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -0,0 +1,421 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.ServiceLoader + +import scala.collection.JavaConverters._ +import scala.language.{existentials, implicitConversions} +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.util.Utils + +/** + * The main class responsible for representing a pluggable Data Source in Spark SQL. In addition to + * acting as the canonical set of parameters that can describe a Data Source, this class is used to + * resolve a description to a concrete implementation that can be used in a query plan + * (either batch or streaming) or to write out data using an external library. + * + * From an end user's perspective a DataSource description can be created explicitly using + * [[org.apache.spark.sql.DataFrameReader]] or CREATE TABLE USING DDL. Additionally, this class is + * used when resolving a description from a metastore to a concrete implementation. + * + * Many of the arguments to this class are optional, though depending on the specific API being used + * these optional arguments might be filled in during resolution using either inference or external + * metadata. For example, when reading a partitioned table from a file system, partition columns + * will be inferred from the directory layout even if they are not specified. + * + * @param paths A list of file system paths that hold data. These will be globbed before and + * qualified. This option only works when reading from a [[FileFormat]]. + * @param userSpecifiedSchema An optional specification of the schema of the data. When present + * we skip attempting to infer the schema. + * @param partitionColumns A list of column names that the relation is partitioned by. When this + * list is empty, the relation is unpartitioned. + * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. + */ +case class DataSource( + sqlContext: SQLContext, + className: String, + paths: Seq[String] = Nil, + userSpecifiedSchema: Option[StructType] = None, + partitionColumns: Seq[String] = Seq.empty, + bucketSpec: Option[BucketSpec] = None, + options: Map[String, String] = Map.empty) extends Logging { + + lazy val providingClass: Class[_] = lookupDataSource(className) + + /** A map to maintain backward compatibility in case we move data sources around. */ + private val backwardCompatibilityMap = Map( + "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName, + "com.databricks.spark.csv" -> classOf[csv.DefaultSource].getCanonicalName + ) + + /** Given a provider name, look up the data source class definition. */ + private def lookupDataSource(provider0: String): Class[_] = { + val provider = backwardCompatibilityMap.getOrElse(provider0, provider0) + val provider2 = s"$provider.DefaultSource" + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { + // the provider format did not match any given registered aliases + case Nil => + Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + if (provider == "avro" || provider == "com.databricks.spark.avro") { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please use Spark package " + + "http://spark-packages.org/package/databricks/spark-avro", + error) + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please find packages at " + + "http://spark-packages.org", + error) + } + } + } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input + sys.error(s"Multiple sources found for $provider " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") + } + } + + private def inferFileFormatSchema(format: FileFormat): StructType = { + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val allPaths = caseInsensitiveOptions.get("path") + val globbedPaths = allPaths.toSeq.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualified) + }.toArray + + val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) + userSpecifiedSchema.orElse { + format.inferSchema( + sqlContext, + caseInsensitiveOptions, + fileCatalog.allFiles()) + }.getOrElse { + throw new AnalysisException("Unable to infer schema. It must be specified manually.") + } + } + + /** Returns the name and schema of the source that can be used to continually read data. */ + def sourceSchema(): (String, StructType) = { + providingClass.newInstance() match { + case s: StreamSourceProvider => + s.sourceSchema(sqlContext, userSpecifiedSchema, className, options) + + case format: FileFormat => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val path = caseInsensitiveOptions.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + }) + (s"FileSource[$path]", inferFileFormatSchema(format)) + case _ => + throw new UnsupportedOperationException( + s"Data source $className does not support streamed reading") + } + } + + /** Returns a source that can be used to continually read data. */ + def createSource(metadataPath: String): Source = { + providingClass.newInstance() match { + case s: StreamSourceProvider => + s.createSource(sqlContext, metadataPath, userSpecifiedSchema, className, options) + + case format: FileFormat => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val path = caseInsensitiveOptions.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + }) + + val dataSchema = inferFileFormatSchema(format) + + def dataFrameBuilder(files: Array[String]): DataFrame = { + Dataset.ofRows( + sqlContext, + LogicalRelation( + DataSource( + sqlContext, + paths = files, + userSpecifiedSchema = Some(dataSchema), + className = className, + options = + new CaseInsensitiveMap(options.filterKeys(_ != "path"))).resolveRelation())) + } + + new FileStreamSource( + sqlContext, metadataPath, path, Some(dataSchema), className, dataFrameBuilder) + case _ => + throw new UnsupportedOperationException( + s"Data source $className does not support streamed reading") + } + } + + /** Returns a sink that can be used to continually write data. */ + def createSink(): Sink = { + providingClass.newInstance() match { + case s: StreamSinkProvider => s.createSink(sqlContext, options, partitionColumns) + case format: FileFormat => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val path = caseInsensitiveOptions.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + }) + + new FileStreamSink(sqlContext, path, format) + case _ => + throw new UnsupportedOperationException( + s"Data source $className does not support streamed writing") + } + } + + /** + * Returns true if there is a single path that has a metadata log indicating which files should + * be read. + */ + def hasMetadata(path: Seq[String]): Boolean = { + path match { + case Seq(singlePath) => + try { + val hdfsPath = new Path(singlePath) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val metadataPath = new Path(hdfsPath, FileStreamSink.metadataDir) + val res = fs.exists(metadataPath) + res + } catch { + case NonFatal(e) => + logWarning(s"Error while looking for metadata directory.") + false + } + case _ => false + } + } + + /** Create a resolved [[BaseRelation]] that can be used to read data from this [[DataSource]] */ + def resolveRelation(): BaseRelation = { + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val relation = (providingClass.newInstance(), userSpecifiedSchema) match { + // TODO: Throw when too much is given. + case (dataSource: SchemaRelationProvider, Some(schema)) => + dataSource.createRelation(sqlContext, caseInsensitiveOptions, schema) + case (dataSource: RelationProvider, None) => + dataSource.createRelation(sqlContext, caseInsensitiveOptions) + case (_: SchemaRelationProvider, None) => + throw new AnalysisException(s"A schema needs to be specified when using $className.") + case (_: RelationProvider, Some(_)) => + throw new AnalysisException(s"$className does not allow user-specified schemas.") + + // We are reading from the results of a streaming query. Load files from the metadata log + // instead of listing them using HDFS APIs. + case (format: FileFormat, _) + if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => + val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) + val fileCatalog = + new StreamFileCatalog(sqlContext, basePath) + val dataSchema = userSpecifiedSchema.orElse { + format.inferSchema( + sqlContext, + caseInsensitiveOptions, + fileCatalog.allFiles()) + }.getOrElse { + throw new AnalysisException( + s"Unable to infer schema for $format at ${fileCatalog.allFiles().mkString(",")}. " + + "It must be specified manually") + } + + HadoopFsRelation( + sqlContext, + fileCatalog, + partitionSchema = fileCatalog.partitionSpec().partitionColumns, + dataSchema = dataSchema, + bucketSpec = None, + format, + options) + + // This is a non-streaming file based datasource. + case (format: FileFormat, _) => + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val globbedPaths = allPaths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified) + + if (globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + // Sufficient to check head of the globPath seq for non-glob scenario + if (!fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + }.toArray + + // If they gave a schema, then we try and figure out the types of the partition columns + // from that schema. + val partitionSchema = userSpecifiedSchema.map { schema => + StructType( + partitionColumns.map { c => + // TODO: Case sensitivity. + schema + .find(_.name.toLowerCase() == c.toLowerCase()) + .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) + }) + } + + val fileCatalog: FileCatalog = + new HDFSFileCatalog(sqlContext, options, globbedPaths, partitionSchema) + val dataSchema = userSpecifiedSchema.orElse { + format.inferSchema( + sqlContext, + caseInsensitiveOptions, + fileCatalog.allFiles()) + }.getOrElse { + throw new AnalysisException( + s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. " + + "It must be specified manually") + } + + val enrichedOptions = + format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles()) + + HadoopFsRelation( + sqlContext, + fileCatalog, + partitionSchema = fileCatalog.partitionSpec().partitionColumns, + dataSchema = dataSchema.asNullable, + bucketSpec = bucketSpec, + format, + enrichedOptions) + + case _ => + throw new AnalysisException( + s"$className is not a valid Spark SQL Data Source.") + } + + relation + } + + /** Writes the give [[DataFrame]] out to this [[DataSource]]. */ + def write( + mode: SaveMode, + data: DataFrame): BaseRelation = { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + throw new AnalysisException("Cannot save interval data type into external storage.") + } + + providingClass.newInstance() match { + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, mode, options, data) + case format: FileFormat => + // Don't glob path for the write path. The contracts here are: + // 1. Only one output path can be specified on the write path; + // 2. Output path must be a legal HDFS style file system path; + // 3. It's OK that the output path doesn't exist yet; + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val outputPath = { + val path = new Path(caseInsensitiveOptions.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + })) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + val caseSensitive = sqlContext.conf.caseSensitiveAnalysis + PartitioningUtils.validatePartitionColumnDataTypes( + data.schema, partitionColumns, caseSensitive) + + // If we are appending to a table that already exists, make sure the partitioning matches + // up. If we fail to load the table for whatever reason, ignore the check. + if (mode == SaveMode.Append) { + val existingPartitionColumnSet = try { + Some( + resolveRelation() + .asInstanceOf[HadoopFsRelation] + .location + .partitionSpec() + .partitionColumns + .fieldNames + .toSet) + } catch { + case e: Exception => + None + } + + existingPartitionColumnSet.foreach { ex => + if (ex.map(_.toLowerCase) != partitionColumns.map(_.toLowerCase()).toSet) { + throw new AnalysisException( + s"Requested partitioning does not equal existing partitioning: " + + s"$ex != ${partitionColumns.toSet}.") + } + } + } + + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. + val plan = + InsertIntoHadoopFsRelation( + outputPath, + partitionColumns.map(UnresolvedAttribute.quoted), + bucketSpec, + format, + () => Unit, // No existing table needs to be refreshed. + options, + data.logicalPlan, + mode) + sqlContext.executePlan(plan).toRdd + + case _ => + sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") + } + + // We replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. + copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7265d6a4de2e6..ac3c52e901795 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -17,28 +17,72 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.DataSourceScan.{INPUT_PATHS, PUSHED_FILTERS} +import org.apache.spark.sql.execution.command.ExecutedCommand import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, TaskContext} + +/** + * Replaces generic operations with specific variants that are designed to work with Spark + * SQL Data Sources. + */ +private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case i @ logical.InsertIntoTable( + l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false) + if query.resolved && t.schema.asNullable == query.schema.asNullable => + + // Sanity checks + if (t.location.paths.size != 1) { + throw new AnalysisException( + "Can only write data to relations with a single path.") + } + + val outputPath = t.location.paths.head + val inputPaths = query.collect { + case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.paths + }.flatten + + val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append + if (overwrite && inputPaths.contains(outputPath)) { + throw new AnalysisException( + "Cannot overwrite a path that is also being read from.") + } + + InsertIntoHadoopFsRelation( + outputPath, + t.partitionSchema.fields.map(_.name).map(UnresolvedAttribute(_)), + t.bucketSpec, + t.fileFormat, + () => t.refresh(), + t.options, + query, + mode) + } +} /** * A Strategy for planning scans over data sources defined using the sources API. */ private[sql] object DataSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( l, projects, @@ -46,215 +90,41 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (requestedColumns, allPredicates, _) => toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) => pruneFilterProject( l, projects, filters, (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _)) => pruneFilterProject( l, projects, filters, (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil - // Scanning partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) - if t.partitionSpec.partitionColumns.nonEmpty => - // We divide the filter expressions into 3 parts - val partitionColumns = AttributeSet( - t.partitionColumns.map(c => l.output.find(_.name == c.name).get)) - - // Only pruning the partition keys - val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns)) - - // Only pushes down predicates that do not reference partition keys. - val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) - - // Predicates with both partition keys and attributes - val combineFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet - - val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray - - logInfo { - val total = t.partitionSpec.partitions.length - val selected = selectedPartitions.length - val percentPruned = (1 - selected.toDouble / total.toDouble) * 100 - s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." - } - - val scan = buildPartitionedTableScan( - l, - projects, - pushedFilters, - t.partitionSpec.partitionColumns, - selectedPartitions) - - combineFilters - .reduceLeftOption(expressions.And) - .map(execution.Filter(_, scan)).getOrElse(scan) :: Nil - - // Scanning non-partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) => - // See buildPartitionedTableScan for the reason that we need to create a shard - // broadcast HadoopConf. - val sharedHadoopConf = SparkHadoopUtil.get.conf - val confBroadcast = - t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) - pruneFilterProject( - l, - projects, - filters, - (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil - - case l @ LogicalRelation(baseRelation: TableScan, _) => - execution.PhysicalRDD.createFromDataSource( + case l @ LogicalRelation(baseRelation: TableScan, _, _) => + execution.DataSourceScan.create( l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil - case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _), + case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), part, query, overwrite, false) if part.isEmpty => - execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil - - case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _), part, query, overwrite, false) => - val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - execution.ExecutedCommand(InsertIntoHadoopFsRelation(t, query, mode)) :: Nil + ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case _ => Nil } - private def buildPartitionedTableScan( - logicalRelation: LogicalRelation, - projections: Seq[NamedExpression], - filters: Seq[Expression], - partitionColumns: StructType, - partitions: Array[Partition]): SparkPlan = { - val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] - - // Because we are creating one RDD per partition, we need to have a shared HadoopConf. - // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high. - val sharedHadoopConf = SparkHadoopUtil.get.conf - val confBroadcast = - relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) - val partitionColumnNames = partitionColumns.fieldNames.toSet - - // Now, we create a scan builder, which will be used by pruneFilterProject. This scan builder - // will union all partitions and attach partition values if needed. - val scanBuilder = { - (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { - val requiredDataColumns = - requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) - - // Builds RDD[Row]s for each selected partition. - val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => - // Don't scan any partition columns to save I/O. Here we are being optimistic and - // assuming partition columns data stored in data files are always consistent with those - // partition values encoded in partition directory paths. - val dataRows = relation.buildInternalScan( - requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast) - - // Merges data values with partition values. - mergeWithPartitionValues( - requiredColumns, - requiredDataColumns, - partitionColumns, - partitionValues, - dataRows) - } - - val unionedRows = - if (perPartitionRows.length == 0) { - relation.sqlContext.emptyResult - } else { - new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) - } - - unionedRows - } - } - - // Create the scan operator. If needed, add Filter and/or Project on top of the scan. - // The added Filter/Project is on top of the unioned RDD. We do not want to create - // one Filter/Project for every partition. - val sparkPlan = pruneFilterProject( - logicalRelation, - projections, - filters, - scanBuilder) - - sparkPlan - } - - private def mergeWithPartitionValues( - requiredColumns: Seq[Attribute], - dataColumns: Seq[Attribute], - partitionColumnSchema: StructType, - partitionValues: InternalRow, - dataRows: RDD[InternalRow]): RDD[InternalRow] = { - // If output columns contain any partition column(s), we need to merge scanned data - // columns and requested partition columns to form the final result. - if (requiredColumns != dataColumns) { - // Builds `AttributeReference`s for all partition columns so that we can use them to project - // required partition columns. Note that if a partition column appears in `requiredColumns`, - // we should use the `AttributeReference` in `requiredColumns`. - val partitionColumns = { - val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap - partitionColumnSchema.toAttributes.map { a => - requiredColumnMap.getOrElse(a.name, a) - } - } - - val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => { - // Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and - // `UnsafeProjection`. Because the projection may also adjust column order. - val mutableJoinedRow = new JoinedRow() - val unsafePartitionValues = UnsafeProjection.create(partitionColumnSchema)(partitionValues) - val unsafeProjection = - UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns) - - iterator.map { unsafeDataRow => - unsafeProjection(mutableJoinedRow(unsafeDataRow, unsafePartitionValues)) - } - } - - // This is an internal RDD whose call site the user should not be concerned with - // Since we create many of these (one per partition), the time spent on computing - // the call site may add up. - Utils.withDummyCallSite(dataRows.sparkContext) { - new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false) - } - } else { - dataRows - } - } - - protected def prunePartitions( - predicates: Seq[Expression], - partitionSpec: PartitionSpec): Seq[Partition] = { - val PartitionSpec(partitionColumns, partitions) = partitionSpec - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - if (partitionPruningPredicates.nonEmpty) { - val predicate = - partitionPruningPredicates - .reduceOption(expressions.And) - .getOrElse(Literal(true)) - - val boundPredicate = InterpretedPredicate.create(predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) + // Get the bucket ID based on the bucketing values. + // Restriction: Bucket pruning works iff the bucketing column has one and only one column. + def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType)) + mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + val bucketIdGeneration = UnsafeProjection.create( + HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, + bucketColumn :: Nil) - partitions.filter { case Partition(values, _) => boundPredicate(values) } - } else { - partitions - } + bucketIdGeneration(mutableRow).getInt(0) } // Based on Public API. @@ -315,6 +185,21 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + val metadata: Map[String, String] = { + val pairs = ArrayBuffer.empty[(String, String)] + + if (pushedFilters.nonEmpty) { + pairs += (PUSHED_FILTERS -> pushedFilters.mkString("[", ", ", "]")) + } + + relation.relation match { + case r: HadoopFsRelation => pairs += INPUT_PATHS -> r.location.paths.mkString(", ") + case _ => + } + + pairs.toMap + } + if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && filterSet.subsetOf(projectSet)) { @@ -329,21 +214,22 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Don't request columns that are only referenced by pushed filters. .filterNot(handledSet.contains) - val scan = execution.PhysicalRDD.createFromDataSource( + val scan = execution.DataSourceScan.create( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation) + relation.relation, metadata) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { // Don't request columns that are only referenced by pushed filters. val requestedColumns = (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq - val scan = execution.PhysicalRDD.createFromDataSource( + val scan = execution.DataSourceScan.create( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation) - execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + relation.relation, metadata) + execution.Project( + projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } @@ -453,8 +339,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { * * @return A pair of `Seq[Expression]` and `Seq[Filter]`. The first element contains all Catalyst * predicate [[Expression]]s that are either not convertible or cannot be handled by - * `relation`. The second element contains all converted data source [[Filter]]s that can - * be handled by `relation`. + * `relation`. The second element contains all converted data source [[Filter]]s that + * will be pushed down to the data source. */ protected[sql] def selectFilters( relation: BaseRelation, @@ -476,7 +362,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Catalyst predicate expressions that cannot be translated to data source filters. val unrecognizedPredicates = predicates.filterNot(translatedMap.contains) - // Data source filters that cannot be handled by `relation` + // Data source filters that cannot be handled by `relation`. The semantic of a unhandled filter + // at here is that a data source may not be able to apply this filter to every row + // of the underlying dataset. val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet val (unhandled, handled) = translated.partition { @@ -491,6 +379,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Translated data source filters that can be handled by `relation` val (_, handledFilters) = handled.unzip - (unrecognizedPredicates ++ unhandledPredicates, handledFilters) + // translated contains all filters that have been converted to the public Filter interface. + // We should always push them to the data source no matter whether the data source can apply + // a filter to every row or not. + val (_, translatedFilters) = translated.unzip + + (unrecognizedPredicates ++ unhandledPredicates, translatedFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala new file mode 100644 index 0000000000000..468e101fedb8b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.{InputFileNameHolder, RDD} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow + +/** + * A single file that should be read, along with partition column values that + * need to be prepended to each row. The reading should start at the first + * valid record found after `offset`. + */ +case class PartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long) { + override def toString: String = { + s"path: $filePath, range: $start-${start + length}, partition values: $partitionValues" + } +} + +/** + * A collection of files that should be read as a single task possibly from multiple partitioned + * directories. + * + * TODO: This currently does not take locality information about the files into account. + */ +case class FilePartition(index: Int, files: Seq[PartitionedFile]) extends Partition + +class FileScanRDD( + @transient val sqlContext: SQLContext, + readFunction: (PartitionedFile) => Iterator[InternalRow], + @transient val filePartitions: Seq[FilePartition]) + extends RDD[InternalRow](sqlContext.sparkContext, Nil) { + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val iterator = new Iterator[Object] with AutoCloseable { + private[this] val files = split.asInstanceOf[FilePartition].files.toIterator + private[this] var currentIterator: Iterator[Object] = null + + def hasNext = (currentIterator != null && currentIterator.hasNext) || nextIterator() + def next() = currentIterator.next() + + /** Advances to the next file. Returns true if a new non-empty iterator is available. */ + private def nextIterator(): Boolean = { + if (files.hasNext) { + val nextFile = files.next() + logInfo(s"Reading File $nextFile") + InputFileNameHolder.setInputFileName(nextFile.filePath) + currentIterator = readFunction(nextFile) + hasNext + } else { + InputFileNameHolder.unsetInputFileName() + false + } + } + + override def close() = { + InputFileNameHolder.unsetInputFileName() + } + } + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener(context => iterator.close()) + + iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. + } + + override protected def getPartitions: Array[Partition] = filePartitions.toArray +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala new file mode 100644 index 0000000000000..80a9156ddcdca --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan} +import org.apache.spark.sql.sources._ + +/** + * A strategy for planning scans over collections of files that might be partitioned or bucketed + * by user specified columns. + * + * At a high level planning occurs in several phases: + * - Split filters by when they need to be evaluated. + * - Prune the schema of the data requested based on any projections present. Today this pruning + * is only done on top level columns, but formats should support pruning of nested columns as + * well. + * - Construct a reader function by passing filters and the schema into the FileFormat. + * - Using an partition pruning predicates, enumerate the list of files that should be read. + * - Split the files into tasks and construct a FileScanRDD. + * - Add any projection or filters that must be evaluated after the scan. + * + * Files are assigned into tasks using the following algorithm: + * - If the table is bucketed, group files by bucket id into the correct number of partitions. + * - If the table is not bucketed or bucketing is turned off: + * - If any file is larger than the threshold, split it into pieces based on that threshold + * - Sort the files by decreasing file size. + * - Assign the ordered files to buckets using the following algorithm. If the current partition + * is under the threshold with the addition of the next file, add it. If not, open a new bucket + * and add it. Proceed to the next file. + */ +private[sql] object FileSourceStrategy extends Strategy with Logging { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _)) => + // Filters on this relation fall into four categories based on where we can use them to avoid + // reading unneeded data: + // - partition keys only - used to prune directories to read + // - bucket keys only - optionally used to prune files to read + // - keys stored in the data only - optionally used to skip groups of data in files + // - filters that need to be evaluated again after the scan + val filterSet = ExpressionSet(filters) + + // The attribute name of predicate could be different than the one in schema in case of + // case insensitive, we should change them to match the one in schema, so we donot need to + // worry about case sensitivity anymore. + val normalizedFilters = filters.map { e => + e transform { + case a: AttributeReference => + a.withName(l.output.find(_.semanticEquals(a)).get.name) + } + } + + val partitionColumns = + l.resolve(files.partitionSchema, files.sqlContext.sessionState.analyzer.resolver) + val partitionSet = AttributeSet(partitionColumns) + val partitionKeyFilters = + ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + + val dataColumns = + l.resolve(files.dataSchema, files.sqlContext.sessionState.analyzer.resolver) + + // Partition keys are not available in the statistics of the files. + val dataFilters = normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty) + + // Predicates with both partition keys and attributes need to be evaluated after the scan. + val afterScanFilters = filterSet -- partitionKeyFilters + logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}") + + val selectedPartitions = files.location.listFiles(partitionKeyFilters.toSeq) + + val filterAttributes = AttributeSet(afterScanFilters) + val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects + val requiredAttributes = AttributeSet(requiredExpressions) + + val readDataColumns = + dataColumns + .filter(requiredAttributes.contains) + .filterNot(partitionColumns.contains) + val prunedDataSchema = readDataColumns.toStructType + logInfo(s"Pruned Data Schema: ${prunedDataSchema.simpleString(5)}") + + val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) + logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") + + val readFile = files.fileFormat.buildReader( + sqlContext = files.sqlContext, + dataSchema = files.dataSchema, + partitionSchema = files.partitionSchema, + requiredSchema = prunedDataSchema, + filters = pushedDownFilters, + options = files.options) + + val plannedPartitions = files.bucketSpec match { + case Some(bucketing) if files.sqlContext.conf.bucketingEnabled => + logInfo(s"Planning with ${bucketing.numBuckets} buckets") + val bucketed = + selectedPartitions.flatMap { p => + p.files.map(f => PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen)) + }.groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) + } + + (0 until bucketing.numBuckets).map { bucketId => + FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + } + + case _ => + val maxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes + val openCostInBytes = files.sqlContext.conf.filesOpenCostInBytes + logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + val splitFiles = selectedPartitions.flatMap { partition => + partition.files.flatMap { file => + (0L to file.getLen by maxSplitBytes).map { offset => + val remaining = file.getLen - offset + val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining + PartitionedFile(partition.values, file.getPath.toUri.toString, offset, size) + } + } + }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + val partitions = new ArrayBuffer[FilePartition] + val currentFiles = new ArrayBuffer[PartitionedFile] + var currentSize = 0L + + /** Add the given file to the current partition. */ + def addFile(file: PartitionedFile): Unit = { + currentSize += file.length + openCostInBytes + currentFiles.append(file) + } + + /** Close the current partition and move to the next. */ + def closePartition(): Unit = { + if (currentFiles.nonEmpty) { + val newPartition = + FilePartition( + partitions.size, + currentFiles.toArray.toSeq) // Copy to a new Array. + partitions.append(newPartition) + } + currentFiles.clear() + currentSize = 0 + } + + // Assign files to partitions using "First Fit Decreasing" (FFD) + // TODO: consider adding a slop factor here? + splitFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() + } + addFile(file) + } + closePartition() + partitions + } + + val scan = + DataSourceScan.create( + readDataColumns ++ partitionColumns, + new FileScanRDD( + files.sqlContext, + readFile, + plannedPartitions), + files, + Map( + "Format" -> files.fileFormat.toString, + "PushedFilters" -> pushedDownFilters.mkString("[", ", ", "]"), + "ReadSchema" -> prunedDataSchema.simpleString)) + + val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) + val withFilter = afterScanFilter.map(execution.Filter(_, scan)).getOrElse(scan) + val withProjections = if (projects == withFilter.output) { + withFilter + } else { + execution.Project(projects, withFilter) + } + + withProjections :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala new file mode 100644 index 0000000000000..18f9b55895a64 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.{FileSplit, LineRecordReader} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +/** + * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines + * in that file. + */ +class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] { + private val iterator = { + val fileSplit = new FileSplit( + new Path(new URI(file.filePath)), + file.start, + file.length, + // TODO: Implement Locality + Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val reader = new LineRecordReader() + reader.initialize(fileSplit, hadoopAttemptContext) + new RecordReaderIterator(reader) + } + + override def hasNext: Boolean = iterator.hasNext + + override def next(): Text = iterator.next() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala index 3b7dc2e8d0210..37c2c4517ccf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.sources.InsertableRelation @@ -34,7 +34,7 @@ private[sql] case class InsertIntoDataSource( override def run(sqlContext: SQLContext): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] - val data = DataFrame(sqlContext, query) + val data = Dataset.ofRows(sqlContext, query) // Apply the schema of the existing table to the new data. val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 735d52f808868..889c0204f8e6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -22,15 +22,16 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat + import org.apache.spark._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ -import org.apache.spark.util.Utils - /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. @@ -55,18 +56,29 @@ import org.apache.spark.util.Utils * thrown during job commitment, also aborts the job. */ private[sql] case class InsertIntoHadoopFsRelation( - @transient relation: HadoopFsRelation, + outputPath: Path, + partitionColumns: Seq[Attribute], + bucketSpec: Option[BucketSpec], + fileFormat: FileFormat, + refreshFunction: () => Unit, + options: Map[String, String], @transient query: LogicalPlan, mode: SaveMode) extends RunnableCommand { + override def children: Seq[LogicalPlan] = query :: Nil + override def run(sqlContext: SQLContext): Seq[Row] = { - require( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + // Most formats don't do well with duplicate columns, so lets not allow that + if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) { + val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to file.") + } val hadoopConf = sqlContext.sparkContext.hadoopConfiguration - val outputPath = new Path(relation.paths.head) val fs = outputPath.getFileSystem(hadoopConf) val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) @@ -75,11 +87,9 @@ private[sql] case class InsertIntoHadoopFsRelation( case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => - Utils.tryOrIOException { - if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { - throw new IOException(s"Unable to clear output " + - s"directory $qualifiedOutputPath prior to writing to it") - } + if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { + throw new IOException(s"Unable to clear output " + + s"directory $qualifiedOutputPath prior to writing to it") } true case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => @@ -93,50 +103,33 @@ private[sql] case class InsertIntoHadoopFsRelation( val isAppend = pathExists && (mode == SaveMode.Append) if (doInsertion) { - val job = new Job(hadoopConf) + val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, qualifiedOutputPath) - // A partitioned relation schema's can be different from the input logicalPlan, since - // partition columns are all moved after data column. We Project to adjust the ordering. - // TODO: this belongs in the analyzer. - val project = Project( - relation.schema.map(field => UnresolvedAttribute.quoted(field.name)), query) - val queryExecution = DataFrame(sqlContext, project).queryExecution + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = query.output.filterNot(partitionSet.contains) + val queryExecution = Dataset.ofRows(sqlContext, query).queryExecution SQLExecution.withNewExecutionId(sqlContext, queryExecution) { - val df = sqlContext.internalCreateDataFrame(queryExecution.toRdd, relation.schema) - val partitionColumns = relation.partitionColumns.fieldNames - - // Some pre-flight checks. - require( - df.schema == relation.schema, - s"""DataFrame must have the same schema as the relation to which is inserted. - |DataFrame schema: ${df.schema} - |Relation schema: ${relation.schema} - """.stripMargin) - val partitionColumnsInSpec = relation.partitionColumns.fieldNames - require( - partitionColumnsInSpec.sameElements(partitionColumns), - s"""Partition columns mismatch. - |Expected: ${partitionColumnsInSpec.mkString(", ")} - |Actual: ${partitionColumns.mkString(", ")} - """.stripMargin) - - val writerContainer = if (partitionColumns.isEmpty) { + val relation = + WriteRelation( + sqlContext, + dataColumns.toStructType, + qualifiedOutputPath.toString, + fileFormat.prepareWrite(sqlContext, _, options, dataColumns.toStructType), + bucketSpec) + + val writerContainer = if (partitionColumns.isEmpty && bucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { - val output = df.queryExecution.executedPlan.output - val (partitionOutput, dataOutput) = - output.partition(a => partitionColumns.contains(a.name)) - new DynamicPartitionWriterContainer( relation, job, - partitionOutput, - dataOutput, - output, + partitionColumns = partitionColumns, + dataColumns = dataColumns, + inputSchema = query.output, PartitioningUtils.DEFAULT_PARTITION_NAME, sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), isAppend) @@ -147,9 +140,9 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writerContainer.writeRows _) + sqlContext.sparkContext.runJob(queryExecution.toRdd, writerContainer.writeRows _) writerContainer.commitJob() - relation.refresh() + refreshFunction() } catch { case cause: Throwable => logError("Aborting job.", cause) writerContainer.abortJob() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 219dae88e515d..0e0748ff32df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} @@ -30,7 +31,8 @@ import org.apache.spark.sql.sources.BaseRelation */ case class LogicalRelation( relation: BaseRelation, - expectedOutputAttributes: Option[Seq[Attribute]] = None) + expectedOutputAttributes: Option[Seq[Attribute]] = None, + metastoreTableIdentifier: Option[TableIdentifier] = None) extends LeafNode with MultiInstanceRelation { override val output: Seq[AttributeReference] = { @@ -49,7 +51,7 @@ case class LogicalRelation( // Logical Relations are distinct if they have different output for the sake of transformations. override def equals(other: Any): Boolean = other match { - case l @ LogicalRelation(otherRelation, _) => relation == otherRelation && output == l.output + case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output case _ => false } @@ -58,7 +60,7 @@ case class LogicalRelation( } override def sameResult(otherPlan: LogicalPlan): Boolean = otherPlan match { - case LogicalRelation(otherRelation, _) => relation == otherRelation + case LogicalRelation(otherRelation, _, _) => relation == otherRelation case _ => false } @@ -74,7 +76,11 @@ case class LogicalRelation( /** Used to lookup original attribute capitalization */ val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o))) - def newInstance(): this.type = LogicalRelation(relation).asInstanceOf[this.type] + def newInstance(): this.type = + LogicalRelation( + relation, + expectedOutputAttributes, + metastoreTableIdentifier).asInstanceOf[this.type] override def simpleString: String = s"Relation[${output.mkString(",")}] $relation" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala new file mode 100644 index 0000000000000..468228053c964 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +private[datasources] object ParseModes { + val PERMISSIVE_MODE = "PERMISSIVE" + val DROP_MALFORMED_MODE = "DROPMALFORMED" + val FAIL_FAST_MODE = "FAILFAST" + + val DEFAULT = PERMISSIVE_MODE + + def isValidMode(mode: String): Boolean = { + mode.toUpperCase match { + case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true + case _ => false + } + } + + def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE + def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE + def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { + mode.toUpperCase == PERMISSIVE_MODE + } else { + true // We default to permissive is the mode string is not valid + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 86bc3a1b6dab2..3ac2ff494fa81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,12 +32,23 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ -private[sql] case class Partition(values: InternalRow, path: String) +object PartitionDirectory { + def apply(values: InternalRow, path: String): PartitionDirectory = + apply(values, new Path(path)) +} + +/** + * Holds a directory in a partitioned collection of files as well as as the partition values + * in the form of a Row. Before scanning, the files at `path` need to be enumerated. + */ +private[sql] case class PartitionDirectory(values: InternalRow, path: Path) -private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) +private[sql] case class PartitionSpec( + partitionColumns: StructType, + partitions: Seq[PartitionDirectory]) private[sql] object PartitionSpec { - val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory]) } private[sql] object PartitioningUtils { @@ -75,14 +86,15 @@ private[sql] object PartitioningUtils { private[sql] def parsePartitions( paths: Seq[Path], defaultPartitionName: String, - typeInference: Boolean): PartitionSpec = { + typeInference: Boolean, + basePaths: Set[Path]): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. - val (partitionValues, optBasePaths) = paths.map { path => - parsePartition(path, defaultPartitionName, typeInference) + val (partitionValues, optDiscoveredBasePaths) = paths.map { path => + parsePartition(path, defaultPartitionName, typeInference, basePaths) }.unzip // We create pairs of (path -> path's partition value) here - // If the corresponding partition value is None, the pair will be skiped + // If the corresponding partition value is None, the pair will be skipped val pathsWithPartitionValues = paths.zip(partitionValues).flatMap(x => x._2.map(x._1 -> _)) if (pathsWithPartitionValues.isEmpty) { @@ -101,11 +113,16 @@ private[sql] object PartitioningUtils { // It will be recognised as conflicting directory structure: // "hdfs://host:9000/invalidPath" // "hdfs://host:9000/path" - val basePaths = optBasePaths.flatMap(x => x) + // TODO: Selective case sensitivity. + val discoveredBasePaths = optDiscoveredBasePaths.flatMap(x => x).map(_.toString.toLowerCase()) assert( - basePaths.distinct.size == 1, + discoveredBasePaths.distinct.size == 1, "Conflicting directory structures detected. Suspicious paths:\b" + - basePaths.distinct.mkString("\n\t", "\n\t", "\n\n")) + discoveredBasePaths.distinct.mkString("\n\t", "\n\t", "\n\n") + + "If provided paths are partition directories, please set " + + "\"basePath\" in the options of the data source to specify the " + + "root directory of the table. If there are multiple root directories, " + + "please load them separately and then union them.") val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) @@ -122,7 +139,7 @@ private[sql] object PartitioningUtils { // Finally, we create `Partition`s based on paths and resolved partition values. val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map { case (PartitionValues(_, literals), (path, _)) => - Partition(InternalRow.fromSeq(literals.map(_.value)), path.toString) + PartitionDirectory(InternalRow.fromSeq(literals.map(_.value)), path) } PartitionSpec(StructType(fields), partitions) @@ -131,7 +148,7 @@ private[sql] object PartitioningUtils { /** * Parses a single partition, returns column names and values of each partition column, also - * the base path. For example, given: + * the path when we stop partition discovery. For example, given: * {{{ * path = hdfs://:/path/to/partition/a=42/b=hello/c=3.14 * }}} @@ -144,40 +161,63 @@ private[sql] object PartitioningUtils { * Literal.create("hello", StringType), * Literal.create(3.14, FloatType))) * }}} - * and the base path: + * and the path when we stop the discovery is: * {{{ - * /path/to/partition + * hdfs://:/path/to/partition * }}} */ private[sql] def parsePartition( path: Path, defaultPartitionName: String, - typeInference: Boolean): (Option[PartitionValues], Option[Path]) = { + typeInference: Boolean, + basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null - var chopped = path - var basePath = path + // currentPath is the current path that we will use to parse partition column value. + var currentPath: Path = path while (!finished) { // Sometimes (e.g., when speculative task is enabled), temporary directories may be left - // uncleaned. Here we simply ignore them. - if (chopped.getName.toLowerCase == "_temporary") { + // uncleaned. Here we simply ignore them. + if (currentPath.getName.toLowerCase == "_temporary") { return (None, None) } - val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference) - maybeColumn.foreach(columns += _) - basePath = chopped - chopped = chopped.getParent - finished = (maybeColumn.isEmpty && !columns.isEmpty) || chopped.getParent == null + if (basePaths.contains(currentPath)) { + // If the currentPath is one of base paths. We should stop. + finished = true + } else { + // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. + // Once we get the string, we try to parse it and find the partition column and value. + val maybeColumn = + parsePartitionColumn(currentPath.getName, defaultPartitionName, typeInference) + maybeColumn.foreach(columns += _) + + // Now, we determine if we should stop. + // When we hit any of the following cases, we will stop: + // - In this iteration, we could not parse the value of partition column and value, + // i.e. maybeColumn is None, and columns is not empty. At here we check if columns is + // empty to handle cases like /table/a=1/_temporary/something (we need to find a=1 in + // this case). + // - After we get the new currentPath, this new currentPath represent the top level dir + // i.e. currentPath.getParent == null. For the example of "/table/a=1/", + // the top level dir is "/table". + finished = + (maybeColumn.isEmpty && !columns.isEmpty) || currentPath.getParent == null + + if (!finished) { + // For the above example, currentPath will be "/table/". + currentPath = currentPath.getParent + } + } } if (columns.isEmpty) { (None, Some(path)) } else { val (columnNames, values) = columns.reverse.unzip - (Some(PartitionValues(columnNames, values)), Some(basePath)) + (Some(PartitionValues(columnNames, values)), Some(currentPath)) } } @@ -214,7 +254,9 @@ private[sql] object PartitioningUtils { if (pathsWithPartitionValues.isEmpty) { Seq.empty } else { - val distinctPartColNames = pathsWithPartitionValues.map(_._2.columnNames).distinct + // TODO: Selective case sensitivity. + val distinctPartColNames = + pathsWithPartitionValues.map(_._2.columnNames.map(_.toLowerCase())).distinct assert( distinctPartColNames.size == 1, listConflictingPartitionColumns(pathsWithPartitionValues)) @@ -299,10 +341,10 @@ private[sql] object PartitioningUtils { def validatePartitionColumnDataTypes( schema: StructType, - partitionColumns: Array[String], + partitionColumns: Seq[String], caseSensitive: Boolean): Unit = { - ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach { + partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach { field => field.dataType match { case _: AtomicType => // OK case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column") @@ -310,6 +352,26 @@ private[sql] object PartitioningUtils { } } + def partitionColumnsSchema( + schema: StructType, + partitionColumns: Seq[String], + caseSensitive: Boolean): StructType = { + val equality = columnNameEquality(caseSensitive) + StructType(partitionColumns.map { col => + schema.find(f => equality(f.name, col)).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema $schema") + } + }).asNullable + } + + private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = { + if (caseSensitive) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + } + /** * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" * types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala new file mode 100644 index 0000000000000..f03ae94d55838 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.mapreduce.RecordReader + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * An adaptor from a Hadoop [[RecordReader]] to an [[Iterator]] over the values returned. + * + * Note that this returns [[Object]]s instead of [[InternalRow]] because we rely on erasure to pass + * column batches by pretending they are rows. + */ +class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] { + private[this] var havePair = false + private[this] var finished = false + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !rowReader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + rowReader.close() + } + havePair = !finished + } + !finished + } + + override def next(): T = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + rowReader.getCurrentValue + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala deleted file mode 100644 index 86a306b8f941d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ /dev/null @@ -1,249 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.datasources - -import java.util.ServiceLoader - -import scala.collection.JavaConverters._ -import scala.language.{existentials, implicitConversions} -import scala.util.{Success, Failure, Try} - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.util.StringUtils - -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{CalendarIntervalType, StructType} -import org.apache.spark.util.Utils - - -case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) - - -object ResolvedDataSource extends Logging { - - /** A map to maintain backward compatibility in case we move data sources around. */ - private val backwardCompatibilityMap = Map( - "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName, - "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName, - "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName, - "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName, - "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName, - "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName - ) - - /** Given a provider name, look up the data source class definition. */ - def lookupDataSource(provider0: String): Class[_] = { - val provider = backwardCompatibilityMap.getOrElse(provider0, provider0) - val provider2 = s"$provider.DefaultSource" - val loader = Utils.getContextOrSparkClassLoader - val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - - serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { - /** the provider format did not match any given registered aliases */ - case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => dataSource - case Failure(error) => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - throw new ClassNotFoundException( - s"Failed to load class for data source: $provider.", error) - } - } - /** there is exactly one registered alias */ - case head :: Nil => head.getClass - /** There are multiple registered aliases for the input */ - case sources => sys.error(s"Multiple sources found for $provider, " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name.") - } - } - - /** Create a [[ResolvedDataSource]] for reading data in. */ - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String]): ResolvedDataSource = { - val clazz: Class[_] = lookupDataSource(provider) - def className: String = clazz.getCanonicalName - val relation = userSpecifiedSchema match { - case Some(schema: StructType) => clazz.newInstance() match { - case dataSource: SchemaRelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - if (caseInsensitiveOptions.contains("paths")) { - throw new AnalysisException(s"$className does not support paths option.") - } - dataSource.createRelation(sqlContext, caseInsensitiveOptions, schema) - case dataSource: HadoopFsRelationProvider => - val maybePartitionsSchema = if (partitionColumns.isEmpty) { - None - } else { - Some(partitionColumnsSchema( - schema, partitionColumns, sqlContext.conf.caseSensitiveAnalysis)) - } - - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - if (caseInsensitiveOptions.contains("paths") && - caseInsensitiveOptions.contains("path")) { - throw new AnalysisException(s"Both path and paths options are present.") - } - caseInsensitiveOptions.get("paths") - .map(_.split("(? - val hdfsPath = new Path(pathString) - val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) - } - } - - val dataSchema = - StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable - - dataSource.createRelation( - sqlContext, - paths, - Some(dataSchema), - maybePartitionsSchema, - caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - throw new AnalysisException(s"$className does not allow user-specified schemas.") - case _ => - throw new AnalysisException(s"$className is not a RelationProvider.") - } - - case None => clazz.newInstance() match { - case dataSource: RelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - if (caseInsensitiveOptions.contains("paths")) { - throw new AnalysisException(s"$className does not support paths option.") - } - dataSource.createRelation(sqlContext, caseInsensitiveOptions) - case dataSource: HadoopFsRelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - if (caseInsensitiveOptions.contains("paths") && - caseInsensitiveOptions.contains("path")) { - throw new AnalysisException(s"Both path and paths options are present.") - } - caseInsensitiveOptions.get("paths") - .map(_.split("(? - val hdfsPath = new Path(pathString) - val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString) - } - } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - throw new AnalysisException( - s"A schema needs to be specified when using $className.") - case _ => - throw new AnalysisException( - s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") - } - } - new ResolvedDataSource(clazz, relation) - } - - def partitionColumnsSchema( - schema: StructType, - partitionColumns: Array[String], - caseSensitive: Boolean): StructType = { - val equality = columnNameEquality(caseSensitive) - StructType(partitionColumns.map { col => - schema.find(f => equality(f.name, col)).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $schema") - } - }).asNullable - } - - private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = { - if (caseSensitive) { - org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution - } else { - org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution - } - } - - /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */ - def apply( - sqlContext: SQLContext, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { - throw new AnalysisException("Cannot save interval data type into external storage.") - } - val clazz: Class[_] = lookupDataSource(provider) - val relation = clazz.newInstance() match { - case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: HadoopFsRelationProvider => - // Don't glob path for the write path. The contracts here are: - // 1. Only one output path can be specified on the write path; - // 2. Output path must be a legal HDFS style file system path; - // 3. It's OK that the output path doesn't exist yet; - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val outputPath = { - val path = new Path(caseInsensitiveOptions("path")) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - - val caseSensitive = sqlContext.conf.caseSensitiveAnalysis - PartitioningUtils.validatePartitionColumnDataTypes( - data.schema, partitionColumns, caseSensitive) - - val equality = columnNameEquality(caseSensitive) - val dataSchema = StructType( - data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) - val r = dataSource.createRelation( - sqlContext, - Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), - caseInsensitiveOptions) - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This - // will be adjusted within InsertIntoHadoopFsRelation. - sqlContext.executePlan( - InsertIntoHadoopFsRelation( - r, - data.logicalPlan, - mode)).toRdd - r - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") - } - ResolvedDataSource(clazz, relation) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 1b59b19d9420d..815d1d01ef343 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -19,36 +19,42 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.JavaConverters._ - import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.sql._ +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter -import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{StructType, StringType} -import org.apache.spark.util.SerializableConfiguration - +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** A container for all the details required when writing to a table. */ +case class WriteRelation( + sqlContext: SQLContext, + dataSchema: StructType, + path: String, + prepareJobForWrite: Job => OutputWriterFactory, + bucketSpec: Option[BucketSpec]) private[sql] abstract class BaseWriterContainer( - @transient val relation: HadoopFsRelation, + @transient val relation: WriteRelation, @transient private val job: Job, isAppend: Boolean) - extends SparkHadoopMapReduceUtil - with Logging - with Serializable { + extends Logging with Serializable { protected val dataSchema = relation.dataSchema protected val serializableConf = - new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + new SerializableConfiguration(job.getConfiguration) // This UUID is used to avoid output file name collision between different appending write jobs. // These jobs may belong to different SparkContext instances. Concrete data source implementations @@ -70,12 +76,7 @@ private[sql] abstract class BaseWriterContainer( @transient private var taskAttemptId: TaskAttemptID = _ @transient protected var taskAttemptContext: TaskAttemptContext = _ - protected val outputPath: String = { - assert( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - relation.paths.head - } + protected val outputPath: String = relation.path protected var outputWriterFactory: OutputWriterFactory = _ @@ -90,8 +91,7 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate // unique task output files. - SparkHadoopUtil.get.getConfigurationFromJobContext(job). - set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -101,7 +101,7 @@ private[sql] abstract class BaseWriterContainer( // committer, since their initialization involve the job configuration, which can be potentially // decorated in `prepareJobForWrite`. outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) @@ -111,7 +111,7 @@ private[sql] abstract class BaseWriterContainer( def executorSideSetup(taskContext: TaskContext): Unit = { setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupTask(taskAttemptContext) } @@ -124,6 +124,25 @@ private[sql] abstract class BaseWriterContainer( } } + protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = { + try { + outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) + } catch { + case e: org.apache.hadoop.fs.FileAlreadyExistsException => + if (outputCommitter.getClass.getName.contains("Direct")) { + // SPARK-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry + // attempts, the task will fail because the output file is created from a prior attempt. + // This often means the most visible error to the user is misleading. Augment the error + // to tell the user to look for the actual error. + throw new SparkException("The output file already exists but this could be due to a " + + "failure from an earlier attempt. Look through the earlier logs or stage page for " + + "the first error.\n File exists error: " + e, e) + } else { + throw e + } + } + } + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) @@ -131,24 +150,15 @@ private[sql] abstract class BaseWriterContainer( // If we are appending data to an existing dir, we will only use the output committer // associated with the file output format since it is not safe to use a custom // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the the appending job fails. + // leave partial data in the destination dir when the appending job fails. // // See SPARK-8578 for more details logInfo( s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + "for appending.") defaultOutputCommitter - } else if (speculationEnabled) { - // When speculation is enabled, it's not safe to use customized output committer classes, - // especially direct output committers (e.g. `DirectParquetOutputCommitter`). - // - // See SPARK-9899 for more details. - logInfo( - s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + - "because spark.speculation is configured to be true.") - defaultOutputCommitter } else { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val committerClass = configuration.getClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) @@ -183,10 +193,8 @@ private[sql] abstract class BaseWriterContainer( private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, true, splitId) - // scalastyle:off jobcontext + this.taskId = new TaskID(this.jobId, TaskType.MAP, splitId) this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - // scalastyle:on jobcontext } private def setupConf(): Unit = { @@ -225,41 +233,37 @@ private[sql] abstract class BaseWriterContainer( * A writer that writes all of the rows in a partition to a single file. */ private[sql] class DefaultWriterContainer( - relation: HadoopFsRelation, + relation: WriteRelation, job: Job, isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { executorSideSetup(taskContext) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + val configuration = taskAttemptContext.getConfiguration configuration.set("spark.sql.sources.output.path", outputPath) - val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) + var writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) - var writerClosed = false - // If anything below fails, we should abort the task. try { - while (iterator.hasNext) { - val internalRow = iterator.next() - writer.writeInternal(internalRow) - } - - commitTask() + Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val internalRow = iterator.next() + writer.writeInternal(internalRow) + } + commitTask() + }(catchBlock = abortTask()) } catch { - case cause: Throwable => - logError("Aborting task.", cause) - abortTask() - throw new SparkException("Task failed while writing rows.", cause) + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) } def commitTask(): Unit = { try { - assert(writer != null, "OutputWriter instance should have been initialized") - if (!writerClosed) { + if (writer != null) { writer.close() - writerClosed = true + writer = null } super.commitTask() } catch { @@ -272,9 +276,8 @@ private[sql] class DefaultWriterContainer( def abortTask(): Unit = { try { - if (!writerClosed) { + if (writer != null) { writer.close() - writerClosed = true } } finally { super.abortTask() @@ -289,7 +292,7 @@ private[sql] class DefaultWriterContainer( * writer externally sorts the remaining rows and then writes out them out one file at a time. */ private[sql] class DynamicPartitionWriterContainer( - relation: HadoopFsRelation, + relation: WriteRelation, job: Job, partitionColumns: Seq[Attribute], dataColumns: Seq[Attribute], @@ -299,139 +302,145 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, isAppend) { - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] - executorSideSetup(taskContext) + private val bucketSpec = relation.bucketSpec - var outputWritersCleared = false + private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) + } - // Returns the partition key given an input row - val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + private val sortColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) + } + + private def bucketIdExpression: Option[Expression] = bucketSpec.map { spec => + // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can + // guarantee the data distribution is same between shuffle and bucketed data source, which + // enables us to only shuffle one side when join a bucketed table and a normal one. + HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression + } - // Expressions that given a partition key build a string like: col1=val/col2=val/... - val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + // Expressions that given a partition key build a string like: col1=val/col2=val/... + private def partitionStringExpression: Seq[Expression] = { + partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( - PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + PartitioningUtils.escapePathName _, + StringType, + Seq(Cast(c, StringType)), + Seq(StringType)) val str = If(IsNull(c), Literal(defaultPartitionName), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } + } + + private def getBucketIdFromKey(key: InternalRow): Option[Int] = bucketSpec.map { _ => + key.getInt(partitionColumns.length) + } + + /** + * Open and returns a new OutputWriter given a partition key and optional bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + */ + private def newOutputWriter( + key: InternalRow, + getPartitionString: UnsafeProjection): OutputWriter = { + val configuration = taskAttemptContext.getConfiguration + val path = if (partitionColumns.nonEmpty) { + val partitionPath = getPartitionString(key).getString(0) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + new Path(getWorkPath, partitionPath).toString + } else { + configuration.set("spark.sql.sources.output.path", outputPath) + getWorkPath + } + val bucketId = getBucketIdFromKey(key) + val newWriter = super.newOutputWriter(path, bucketId) + newWriter.initConverter(dataSchema) + newWriter + } + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + executorSideSetup(taskContext) + + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns + val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) + + val sortingKeySchema = StructType(sortingExpressions.map { + case a: Attribute => StructField(a.name, a.dataType, a.nullable) + // The sorting expressions are all `Attribute` except bucket id. + case _ => StructField("bucketId", IntegerType, nullable = false) + }) + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) // Returns the partition path given a partition key. val getPartitionString = UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - // If anything below fails, we should abort the task. - try { - // This will be filled in if we have to fall back on sorting. - var sorter: UnsafeKVExternalSorter = null - while (iterator.hasNext && sorter == null) { - val inputRow = iterator.next() - val currentKey = getPartitionKey(inputRow) - var currentWriter = outputWriters.get(currentKey) - - if (currentWriter == null) { - if (outputWriters.size < maxOpenFiles) { - currentWriter = newOutputWriter(currentKey) - outputWriters.put(currentKey.copy(), currentWriter) - currentWriter.writeInternal(getOutputRow(inputRow)) - } else { - logInfo(s"Maximum partitions reached, falling back on sorting.") - sorter = new UnsafeKVExternalSorter( - StructType.fromAttributes(partitionColumns), - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - sorter.insertKV(currentKey, getOutputRow(inputRow)) - } - } else { - currentWriter.writeInternal(getOutputRow(inputRow)) - } - } + // Sorts the data before write, so that we only need one writer at the same time. + // TODO: inject a local sort operator in planning. + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + logInfo(s"Sorting complete. Writing out partition files one at a time.") - // If the sorter is not null that means that we reached the maxFiles above and need to finish - // using external sort. - if (sorter != null) { - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) - } + val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { + identity + } else { + UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { + case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + }) + } + + val sortedIterator = sorter.sortedIterator() - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val sortedIterator = sorter.sortedIterator() - var currentKey: InternalRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (currentKey != sortedIterator.getKey) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. - currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey) - } + // If anything below fails, we should abort the task. + var currentWriter: OutputWriter = null + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { + var currentKey: UnsafeRow = null + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") - currentWriter.writeInternal(sortedIterator.getValue) + currentWriter = newOutputWriter(currentKey, getPartitionString) } - } finally { - if (currentWriter != null) { currentWriter.close() } + currentWriter.writeInternal(sortedIterator.getValue) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null } - } - commitTask() - } catch { - case cause: Throwable => - logError("Aborting task.", cause) + commitTask() + }(catchBlock = { + if (currentWriter != null) { + currentWriter.close() + } abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - - /** Open and returns a new OutputWriter given a partition key. */ - def newOutputWriter(key: InternalRow): OutputWriter = { - val partitionPath = getPartitionString(key).getString(0) - val path = new Path(getWorkPath, partitionPath) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) - newWriter.initConverter(dataSchema) - newWriter - } - - def clearOutputWriters(): Unit = { - if (!outputWritersCleared) { - outputWriters.asScala.values.foreach(_.close()) - outputWriters.clear() - outputWritersCleared = true - } - } - - def commitTask(): Unit = { - try { - clearOutputWriters() - super.commitTask() - } catch { - case cause: Throwable => - throw new RuntimeException("Failed to commit task", cause) - } - } - - def abortTask(): Unit = { - try { - clearOutputWriters() - } finally { - super.abortTask() - } + }) + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala new file mode 100644 index 0000000000000..6008d73717f77 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +/** + * A container for bucketing information. + * Bucketing is a technology for decomposing data sets into more manageable parts, and the number + * of buckets is fixed so it does not fluctuate with data. + * + * @param numBuckets number of buckets. + * @param bucketColumnNames the names of the columns that used to generate the bucket id. + * @param sortColumnNames the names of the columns that used to sort data in each bucket. + */ +private[sql] case class BucketSpec( + numBuckets: Int, + bucketColumnNames: Seq[String], + sortColumnNames: Seq[String]) + +private[sql] object BucketingUtils { + // The file name of bucketed data should have 3 parts: + // 1. some other information in the head of file name + // 2. bucket id part, some numbers, starts with "_" + // * The other-information part may use `-` as separator and may have numbers at the end, + // e.g. a normal parquet file without bucketing may have name: + // part-r-00000-2dd664f9-d2c4-4ffe-878f-431234567891.gz.parquet, and we will mistakenly + // treat `431234567891` as bucket id. So here we pick `_` as separator. + // 3. optional file extension part, in the tail of file name, starts with `.` + // An example of bucketed parquet file name with bucket id 3: + // part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r + + def getBucketId(fileName: String): Option[Int] = fileName match { + case bucketedFileName(bucketId) => Some(bucketId.toInt) + case other => None + } + + def bucketIdToString(id: Int): String = f"_$id%05d" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala new file mode 100644 index 0000000000000..ea843a10137f2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal +import java.text.NumberFormat +import java.util.Locale + +import scala.util.control.Exception._ +import scala.util.Try + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +private[csv] object CSVInferSchema { + + /** + * Similar to the JSON schema inference + * 1. Infer type of each row + * 2. Merge row types to find common type + * 3. Replace any null types with string type + */ + def infer( + tokenRdd: RDD[Array[String]], + header: Array[String], + nullValue: String = ""): StructType = { + + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val rootTypes: Array[DataType] = + tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) + + val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) + } + + StructType(structFields) + } + + private def inferRowType(nullValue: String) + (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { + var i = 0 + while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. + rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue) + i+=1 + } + rowSoFar + } + + def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { + first.zipAll(second, NullType, NullType).map { case (a, b) => + findTightestCommonType(a, b).getOrElse(NullType) + } + } + + /** + * Infer type of string field. Given known type Double, and a string "1", there is no + * point checking if it is an Int, as the final type must be Double or higher. + */ + def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { + if (field == null || field.isEmpty || field == nullValue) { + typeSoFar + } else { + typeSoFar match { + case NullType => tryParseInteger(field) + case IntegerType => tryParseInteger(field) + case LongType => tryParseLong(field) + case DoubleType => tryParseDouble(field) + case TimestampType => tryParseTimestamp(field) + case BooleanType => tryParseBoolean(field) + case StringType => StringType + case other: DataType => + throw new UnsupportedOperationException(s"Unexpected data type $other") + } + } + } + + private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) { + IntegerType + } else { + tryParseLong(field) + } + + private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) { + LongType + } else { + tryParseDouble(field) + } + + private def tryParseDouble(field: String): DataType = { + if ((allCatch opt field.toDouble).isDefined) { + DoubleType + } else { + tryParseTimestamp(field) + } + } + + def tryParseTimestamp(field: String): DataType = { + if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { + TimestampType + } else { + tryParseBoolean(field) + } + } + + def tryParseBoolean(field: String): DataType = { + if ((allCatch opt field.toBoolean).isDefined) { + BooleanType + } else { + stringType() + } + } + + // Defining a function to return the StringType constant is necessary in order to work around + // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions; + // see issue #128 for more details. + private def stringType(): DataType = { + StringType + } + + private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence + + /** + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + case (t1, t2) if t1 == t2 => Some(t1) + case (NullType, t1) => Some(t1) + case (t1, NullType) => Some(t1) + case (StringType, t2) => Some(StringType) + case (t1, StringType) => Some(StringType) + + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => + val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) + Some(numericPrecedence(index)) + + case _ => None + } +} + +private[csv] object CSVTypeCast { + + /** + * Casts given string datum to specified type. + * Currently we do not support complex types (ArrayType, MapType, StructType). + * + * For string types, this is simply the datum. For other types. + * For other nullable types, this is null if the string datum is empty. + * + * @param datum string value + * @param castType SparkSQL type + */ + def castTo( + datum: String, + castType: DataType, + nullable: Boolean = true, + nullValue: String = ""): Any = { + + if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) { + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + case _: DoubleType => Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + case _: BooleanType => datum.toBoolean + case dt: DecimalType => + val value = new BigDecimal(datum.replaceAll(",", "")) + Decimal(value, dt.precision, dt.scale) + // TODO(hossein): would be good to support other common timestamp formats + case _: TimestampType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + DateTimeUtils.stringToTime(datum).getTime * 1000L + // TODO(hossein): would be good to support other common date formats + case _: DateType => + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + case _: StringType => UTF8String.fromString(datum) + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + } + } + } + + /** + * Helper method that converts string representation of a character to actual character. + * It handles some Java escaped strings and throws exception if given string is longer than one + * character. + */ + @throws[IllegalArgumentException] + def toChar(str: String): Char = { + if (str.charAt(0) == '\\') { + str.charAt(1) + match { + case 't' => '\t' + case 'r' => '\r' + case 'b' => '\b' + case 'f' => '\f' + case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options + case '\'' => '\'' + case 'u' if str == """\u0000""" => '\u0000' + case _ => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + } + } else if (str.length == 1) { + str.charAt(0) + } else { + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala new file mode 100644 index 0000000000000..7b9d3b605a891 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.nio.charset.StandardCharsets + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} + +private[sql] class CSVOptions(@transient private val parameters: Map[String, String]) + extends Logging with Serializable { + + private def getChar(paramName: String, default: Char): Char = { + val paramValue = parameters.get(paramName) + paramValue match { + case None => default + case Some(value) if value.length == 0 => '\u0000' + case Some(value) if value.length == 1 => value.charAt(0) + case _ => throw new RuntimeException(s"$paramName cannot be more than one character") + } + } + + private def getInt(paramName: String, default: Int): Int = { + val paramValue = parameters.get(paramName) + paramValue match { + case None => default + case Some(value) => try { + value.toInt + } catch { + case e: NumberFormatException => + throw new RuntimeException(s"$paramName should be an integer. Found $value") + } + } + } + + private def getBool(paramName: String, default: Boolean = false): Boolean = { + val param = parameters.getOrElse(paramName, default.toString) + if (param.toLowerCase == "true") { + true + } else if (param.toLowerCase == "false") { + false + } else { + throw new Exception(s"$paramName flag can be true or false") + } + } + + val delimiter = CSVTypeCast.toChar( + parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) + private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val charset = parameters.getOrElse("encoding", + parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) + + val quote = getChar("quote", '\"') + val escape = getChar("escape", '\\') + val comment = getChar("comment", '\u0000') + + val headerFlag = getBool("header") + val inferSchemaFlag = getBool("inferSchema") + val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") + val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") + + // Parse mode flags + if (!ParseModes.isValidMode(parseMode)) { + logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") + } + + val failFast = ParseModes.isFailFastMode(parseMode) + val dropMalformed = ParseModes.isDropMalformedMode(parseMode) + val permissive = ParseModes.isPermissiveMode(parseMode) + + val nullValue = parameters.getOrElse("nullValue", "") + + val compressionCodec: Option[String] = { + val name = parameters.get("compression").orElse(parameters.get("codec")) + name.map(CompressionCodecs.getCodecClassName) + } + + val maxColumns = getInt("maxColumns", 20480) + + val maxCharsPerColumn = getInt("maxCharsPerColumn", 1000000) + + val inputBufferSize = 128 + + val isCommentSet = this.comment != '\u0000' + + val rowSeparator = "\n" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala new file mode 100644 index 0000000000000..c3d863f547dab --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.{ByteArrayOutputStream, OutputStreamWriter, StringReader} +import java.nio.charset.StandardCharsets + +import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWriterSettings} + +import org.apache.spark.internal.Logging + +/** + * Read and parse CSV-like input + * + * @param params Parameters object + * @param headers headers for the columns + */ +private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) { + + protected lazy val parser: CsvParser = { + val settings = new CsvParserSettings() + val format = settings.getFormat + format.setDelimiter(params.delimiter) + format.setLineSeparator(params.rowSeparator) + format.setQuote(params.quote) + format.setQuoteEscape(params.escape) + format.setComment(params.comment) + settings.setIgnoreLeadingWhitespaces(params.ignoreLeadingWhiteSpaceFlag) + settings.setIgnoreTrailingWhitespaces(params.ignoreTrailingWhiteSpaceFlag) + settings.setReadInputOnSeparateThread(false) + settings.setInputBufferSize(params.inputBufferSize) + settings.setMaxColumns(params.maxColumns) + settings.setNullValue(params.nullValue) + settings.setMaxCharsPerColumn(params.maxCharsPerColumn) + settings.setParseUnescapedQuotesUntilDelimiter(true) + if (headers != null) settings.setHeaders(headers: _*) + + new CsvParser(settings) + } +} + +/** + * Converts a sequence of string to CSV string + * + * @param params Parameters object for configuration + * @param headers headers for columns + */ +private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { + private val writerSettings = new CsvWriterSettings + private val format = writerSettings.getFormat + + format.setDelimiter(params.delimiter) + format.setLineSeparator(params.rowSeparator) + format.setQuote(params.quote) + format.setQuoteEscape(params.escape) + format.setComment(params.comment) + + writerSettings.setNullValue(params.nullValue) + writerSettings.setEmptyValue(params.nullValue) + writerSettings.setSkipEmptyLines(true) + writerSettings.setQuoteAllFields(false) + writerSettings.setHeaders(headers: _*) + + def writeRow(row: Seq[String], includeHeader: Boolean): String = { + val buffer = new ByteArrayOutputStream() + val outputWriter = new OutputStreamWriter(buffer, StandardCharsets.UTF_8) + val writer = new CsvWriter(outputWriter, writerSettings) + + if (includeHeader) { + writer.writeHeaders() + } + writer.writeRow(row.toArray: _*) + writer.close() + buffer.toString.stripLineEnd + } +} + +/** + * Parser for parsing a line at a time. Not efficient for bulk data. + * + * @param params Parameters object + */ +private[sql] class LineCsvReader(params: CSVOptions) + extends CsvReader(params, null) { + /** + * parse a line + * + * @param line a String with no newline at the end + * @return array of strings where each string is a field in the CSV record + */ + def parseLine(line: String): Array[String] = { + parser.beginParsing(new StringReader(line)) + val parsed = parser.parseNext() + parser.stopParsing() + parsed + } +} + +/** + * Parser for parsing lines in bulk. Use this when efficiency is desired. + * + * @param iter iterator over lines in the file + * @param params Parameters object + * @param headers headers for the columns + */ +private[sql] class BulkCsvReader( + iter: Iterator[String], + params: CSVOptions, + headers: Seq[String]) + extends CsvReader(params, headers) with Iterator[Array[String]] { + + private val reader = new StringIteratorReader(iter) + parser.beginParsing(reader) + private var nextRecord = parser.parseNext() + + /** + * get the next parsed line. + * @return array of strings where each string is a field in the CSV record + */ + override def next(): Array[String] = { + val curRecord = nextRecord + if(curRecord != null) { + nextRecord = parser.parseNext() + } else { + throw new NoSuchElementException("next record is null") + } + curRecord + } + + override def hasNext: Boolean = nextRecord != null + +} + +/** + * A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at + * end of each line Univocity parser requires a Reader that provides access to the data to be + * parsed and needs the newlines to be present + * @param iter iterator over RDD[String] + */ +private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader { + + private var next: Long = 0 + private var length: Long = 0 // length of input so far + private var start: Long = 0 + private var str: String = null // current string from iter + + /** + * fetch next string from iter, if done with current one + * pretend there is a new line at the end of every string we get from from iter + */ + private def refill(): Unit = { + if (length == next) { + if (iter.hasNext) { + str = iter.next() + start = length + length += (str.length + 1) // allowance for newline removed by SparkContext.textFile() + } else { + str = null + } + } + } + + /** + * read the next character, if at end of string pretend there is a new line + */ + override def read(): Int = { + refill() + if (next >= length) { + -1 + } else { + val cur = next - start + next += 1 + if (cur == str.length) '\n' else str.charAt(cur.toInt) + } + } + + /** + * read from str into cbuf + */ + override def read(cbuf: Array[Char], off: Int, len: Int): Int = { + refill() + var n = 0 + if ((off < 0) || (off > cbuf.length) || (len < 0) || + ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException() + } else if (len == 0) { + n = 0 + } else { + if (next >= length) { // end of input + n = -1 + } else { + n = Math.min(length - next, len).toInt // lesser of amount of input available or buf size + if (n == length - next) { + str.getChars((next - start).toInt, (next - start + n - 1).toInt, cbuf, off) + cbuf(off + n - 1) = '\n' + } else { + str.getChars((next - start).toInt, (next - start + n).toInt, cbuf, off) + } + next += n + if (n < len) { + val m = read(cbuf, off + n, len - n) // have more space, fetch more input from iter + if(m != -1) n += m + } + } + } + + n + } + + override def skip(ns: Long): Long = { + throw new IllegalArgumentException("Skip not implemented") + } + + override def ready: Boolean = { + refill() + true + } + + override def markSupported: Boolean = false + + override def mark(readAheadLimit: Int): Unit = { + throw new IllegalArgumentException("Mark not implemented") + } + + override def reset(): Unit = { + throw new IllegalArgumentException("Mark and hence reset not implemented") + } + + override def close(): Unit = { } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala new file mode 100644 index 0000000000000..54fb03b6d3bf7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import scala.util.control.NonFatal + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.mapreduce.RecordWriter +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat + +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +object CSVRelation extends Logging { + + def univocityTokenizer( + file: RDD[String], + header: Seq[String], + firstLine: String, + params: CSVOptions): RDD[Array[String]] = { + // If header is set, make sure firstLine is materialized before sending to executors. + file.mapPartitions { iter => + new BulkCsvReader( + if (params.headerFlag) iter.filterNot(_ == firstLine) else iter, + params, + headers = header) + } + } + + def csvParser( + schema: StructType, + requiredColumns: Array[String], + params: CSVOptions): Array[String] => Option[InternalRow] = { + val schemaFields = schema.fields + val requiredFields = StructType(requiredColumns.map(schema(_))).fields + val safeRequiredFields = if (params.dropMalformed) { + // If `dropMalformed` is enabled, then it needs to parse all the values + // so that we can decide which row is malformed. + requiredFields ++ schemaFields.filterNot(requiredFields.contains(_)) + } else { + requiredFields + } + val safeRequiredIndices = new Array[Int](safeRequiredFields.length) + schemaFields.zipWithIndex.filter { + case (field, _) => safeRequiredFields.contains(field) + }.foreach { + case (field, index) => safeRequiredIndices(safeRequiredFields.indexOf(field)) = index + } + val requiredSize = requiredFields.length + val row = new GenericMutableRow(requiredSize) + + (tokens: Array[String]) => { + if (params.dropMalformed && schemaFields.length != tokens.length) { + logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + None + } else if (params.failFast && schemaFields.length != tokens.length) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: " + + s"${tokens.mkString(params.delimiter.toString)}") + } else { + val indexSafeTokens = if (params.permissive && schemaFields.length > tokens.length) { + tokens ++ new Array[String](schemaFields.length - tokens.length) + } else if (params.permissive && schemaFields.length < tokens.length) { + tokens.take(schemaFields.length) + } else { + tokens + } + try { + var index: Int = 0 + var subIndex: Int = 0 + while (subIndex < safeRequiredIndices.length) { + index = safeRequiredIndices(subIndex) + val field = schemaFields(index) + // It anyway needs to try to parse since it decides if this row is malformed + // or not after trying to cast in `DROPMALFORMED` mode even if the casted + // value is not stored in the row. + val value = CSVTypeCast.castTo( + indexSafeTokens(index), + field.dataType, + field.nullable, + params.nullValue) + if (subIndex < requiredSize) { + row(subIndex) = value + } + subIndex = subIndex + 1 + } + Some(row) + } catch { + case NonFatal(e) if params.dropMalformed => + logWarning("Parse exception. " + + s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + None + } + } + } + } + + def parseCsv( + tokenizedRDD: RDD[Array[String]], + schema: StructType, + requiredColumns: Array[String], + options: CSVOptions): RDD[InternalRow] = { + val parser = csvParser(schema, requiredColumns, options) + tokenizedRDD.flatMap(parser(_).toSeq) + } + + // Skips the header line of each file if the `header` option is set to true. + def dropHeaderLine( + file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = { + // TODO What if the first partitioned file consists of only comments and empty lines? + if (csvOptions.headerFlag && file.start == 0) { + val nonEmptyLines = if (csvOptions.isCommentSet) { + val commentPrefix = csvOptions.comment.toString + lines.dropWhile { line => + line.trim.isEmpty || line.trim.startsWith(commentPrefix) + } + } else { + lines.dropWhile(_.trim.isEmpty) + } + + if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) + } + } +} + +private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { + override def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + if (bucketId.isDefined) sys.error("csv doesn't support bucketing") + new CsvOutputWriter(path, dataSchema, context, params) + } +} + +private[sql] class CsvOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext, + params: CSVOptions) extends OutputWriter with Logging { + + // create the Generator without separator inserted between 2 records + private[this] val text = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.csv$extension") + } + }.getRecordWriter(context) + } + + private var firstRow: Boolean = params.headerFlag + + private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq) + + private def rowToString(row: Seq[Any]): Seq[String] = row.map { field => + if (field != null) { + field.toString + } else { + params.nullValue + } + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + // TODO: Instead of converting and writing every row, we should use the univocity buffer + val resultString = csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), firstRow) + if (firstRow) { + firstRow = false + } + text.set(resultString) + recordWriter.write(NullWritable.get(), text) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala new file mode 100644 index 0000000000000..06a371b88bc02 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.nio.charset.{Charset, StandardCharsets} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection.BitSet + +/** + * Provides access to CSV data from pure SQL statements. + */ +class DefaultSource extends FileFormat with DataSourceRegister { + + override def shortName(): String = "csv" + + override def toString: String = "CSV" + + override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] + + override def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val csvOptions = new CSVOptions(options) + + // TODO: Move filtering. + val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString) + val rdd = baseRdd(sqlContext, csvOptions, paths) + val firstLine = findFirstLine(csvOptions, rdd) + val firstRow = new LineCsvReader(csvOptions).parseLine(firstLine) + + val header = if (csvOptions.headerFlag) { + firstRow + } else { + firstRow.zipWithIndex.map { case (value, index) => s"C$index" } + } + + val parsedRdd = tokenRdd(sqlContext, csvOptions, header, paths) + val schema = if (csvOptions.inferSchemaFlag) { + CSVInferSchema.infer(parsedRdd, header, csvOptions.nullValue) + } else { + // By default fields are assumed to be StringType + val schemaFields = header.map { fieldName => + StructField(fieldName.toString, StringType, nullable = true) + } + StructType(schemaFields) + } + Some(schema) + } + + override def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val conf = job.getConfiguration + val csvOptions = new CSVOptions(options) + csvOptions.compressionCodec.foreach { codec => + CompressionCodecs.setCodecConfiguration(conf, codec) + } + + new CSVOutputWriterFactory(csvOptions) + } + + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { + val csvOptions = new CSVOptions(options) + val headers = requiredSchema.fields.map(_.name) + + val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + + (file: PartitionedFile) => { + val lineIterator = { + val conf = broadcastedConf.value.value + new HadoopFileLinesReader(file, conf).map { line => + new String(line.getBytes, 0, line.getLength, csvOptions.charset) + } + } + + CSVRelation.dropHeaderLine(file, lineIterator, csvOptions) + + val unsafeRowIterator = { + val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers) + val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions) + tokenizedIterator.flatMap(parser(_).toSeq) + } + + // Appends partition values + val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val joinedRow = new JoinedRow() + val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput) + + unsafeRowIterator.map { dataRow => + appendPartitionColumns(joinedRow(dataRow, file.partitionValues)) + } + } + } + + private def baseRdd( + sqlContext: SQLContext, + options: CSVOptions, + inputPaths: Seq[String]): RDD[String] = { + readText(sqlContext, options, inputPaths.mkString(",")) + } + + private def tokenRdd( + sqlContext: SQLContext, + options: CSVOptions, + header: Array[String], + inputPaths: Seq[String]): RDD[Array[String]] = { + val rdd = baseRdd(sqlContext, options, inputPaths) + // Make sure firstLine is materialized before sending to executors + val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null + CSVRelation.univocityTokenizer(rdd, header, firstLine, options) + } + + /** + * Returns the first line of the first non-empty file in path + */ + private def findFirstLine(options: CSVOptions, rdd: RDD[String]): String = { + if (options.isCommentSet) { + val comment = options.comment.toString + rdd.filter { line => + line.trim.nonEmpty && !line.startsWith(comment) + }.first() + } else { + rdd.filter { line => + line.trim.nonEmpty + }.first() + } + } + + private def readText( + sqlContext: SQLContext, + options: CSVOptions, + location: String): RDD[String] = { + if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { + sqlContext.sparkContext.textFile(location) + } else { + val charset = options.charset + sqlContext.sparkContext + .hadoopFile[LongWritable, Text, TextInputFormat](location) + .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e7deeff13dc4d..2e88d588bee66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,22 +17,25 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. + * * @param table The table to be described. * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. * It is effective only when the table is a Hive table. */ case class DescribeCommand( - table: LogicalPlan, - isExtended: Boolean) extends LogicalPlan with Command { + table: TableIdentifier, + isExtended: Boolean) + extends LogicalPlan with logical.Command { override def children: Seq[LogicalPlan] = Seq.empty @@ -42,16 +45,17 @@ case class DescribeCommand( new MetadataBuilder().putString("comment", "name of the column").build())(), AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), - AttributeReference("comment", StringType, nullable = false, + AttributeReference("comment", StringType, nullable = true, new MetadataBuilder().putString("comment", "comment of the column").build())() ) } /** - * Used to represent the operation of create table using a data source. - * @param allowExisting If it is true, we will do nothing when the table already exists. - * If it is false, an exception will be thrown - */ + * Used to represent the operation of create table using a data source. + * + * @param allowExisting If it is true, we will do nothing when the table already exists. + * If it is false, an exception will be thrown + */ case class CreateTableUsing( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], @@ -59,7 +63,7 @@ case class CreateTableUsing( temporary: Boolean, options: Map[String, String], allowExisting: Boolean, - managedIfNoPath: Boolean) extends LogicalPlan with Command { + managedIfNoPath: Boolean) extends LogicalPlan with logical.Command { override def output: Seq[Attribute] = Seq.empty override def children: Seq[LogicalPlan] = Seq.empty @@ -67,8 +71,8 @@ case class CreateTableUsing( /** * A node used to support CTAS statements and saveAsTable for the data source API. - * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer - * can analyze the logical plan that will be used to populate the table. + * This node is a [[logical.UnaryNode]] instead of a [[logical.Command]] because we want the + * analyzer can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ case class CreateTableUsingAsSelect( @@ -76,9 +80,10 @@ case class CreateTableUsingAsSelect( provider: String, temporary: Boolean, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], - child: LogicalPlan) extends UnaryNode { + child: LogicalPlan) extends logical.UnaryNode { override def output: Seq[Attribute] = Seq.empty[Attribute] } @@ -88,12 +93,21 @@ case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { + if (tableIdent.database.isDefined) { + throw new AnalysisException( + s"Temporary table '$tableIdent' should not have specified a database") + } + def run(sqlContext: SQLContext): Seq[Row] = { - val resolved = ResolvedDataSource( - sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) - sqlContext.catalog.registerTable( - tableIdent, - DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) + val dataSource = DataSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + className = provider, + options = options) + sqlContext.sessionState.catalog.createTempTable( + tableIdent.table, + Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan, + overrideIfExists = true) Seq.empty[Row] } @@ -107,12 +121,24 @@ case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { + if (tableIdent.database.isDefined) { + throw new AnalysisException( + s"Temporary table '$tableIdent' should not have specified a database") + } + override def run(sqlContext: SQLContext): Seq[Row] = { - val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) - sqlContext.catalog.registerTable( - tableIdent, - DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) + val df = Dataset.ofRows(sqlContext, query) + val dataSource = DataSource( + sqlContext, + className = provider, + partitionColumns = partitionColumns, + bucketSpec = None, + options = options) + val result = dataSource.write(mode, df) + sqlContext.sessionState.catalog.createTempTable( + tableIdent.table, + Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan, + overrideIfExists = true) Seq.empty[Row] } @@ -123,17 +149,17 @@ case class RefreshTable(tableIdent: TableIdentifier) override def run(sqlContext: SQLContext): Seq[Row] = { // Refresh the given table's metadata first. - sqlContext.catalog.refreshTable(tableIdent) + sqlContext.sessionState.catalog.refreshTable(tableIdent) // If this table is cached as a InMemoryColumnarRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent) + val logicalPlan = sqlContext.sessionState.catalog.lookupRelation(tableIdent) // Use lookupCachedData directly since RefreshTable also takes databaseName. val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty if (isCached) { // Create a data frame to represent the table. // TODO: Use uncacheTable once it supports database name. - val df = DataFrame(sqlContext, logicalPlan) + val df = Dataset.ofRows(sqlContext, logicalPlan) // Uncache the logicalPlan. sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) // Cache it again. @@ -161,8 +187,3 @@ class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] override def -(key: String): Map[String, String] = baseMap - key.toLowerCase } - -/** - * The exception thrown from the DDL parser. - */ -class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala index 6773afc794f9c..4dcd261f5cbe9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -1,26 +1,26 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} class DefaultSource extends RelationProvider with DataSourceRegister { @@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister { sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) val partitionColumn = parameters.getOrElse("partitionColumn", null) val lowerBound = parameters.getOrElse("lowerBound", null) val upperBound = parameters.getOrElse("upperBound", null) val numPartitions = parameters.getOrElse("numPartitions", null) - if (driver != null) DriverRegistry.register(driver) - if (partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) { sys.error("Partitioning incompletely specified") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7ccd61ed469e9..7a6c0f9fed2f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -21,7 +21,7 @@ import java.sql.{Driver, DriverManager} import scala.collection.mutable -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** @@ -51,10 +51,5 @@ object DriverRegistry extends Logging { } } } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 018a009fbda6d..6a5564addf48c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,20 +17,23 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp} import java.util.Properties +import scala.util.control.NonFatal + import org.apache.commons.lang3.StringUtils +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} /** * Data corresponding to one partition of a JDBCRDD. @@ -39,7 +42,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par override def index: Int = idx } - private[sql] object JDBCRDD extends Logging { /** @@ -118,32 +120,39 @@ private[sql] object JDBCRDD extends Logging { */ def resolveTable(url: String, table: String, properties: Properties): StructType = { val dialect = JdbcDialects.get(url) - val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)() + val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)() try { - val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() + val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0") try { - val rsmd = rs.getMetaData - val ncols = rsmd.getColumnCount - val fields = new Array[StructField](ncols) - var i = 0 - while (i < ncols) { - val columnName = rsmd.getColumnLabel(i + 1) - val dataType = rsmd.getColumnType(i + 1) - val typeName = rsmd.getColumnTypeName(i + 1) - val fieldSize = rsmd.getPrecision(i + 1) - val fieldScale = rsmd.getScale(i + 1) - val isSigned = rsmd.isSigned(i + 1) - val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls - val metadata = new MetadataBuilder().putString("name", columnName) - val columnType = - dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( - getCatalystType(dataType, fieldSize, fieldScale, isSigned)) - fields(i) = StructField(columnName, columnType, nullable, metadata.build()) - i = i + 1 + val rs = statement.executeQuery() + try { + val rsmd = rs.getMetaData + val ncols = rsmd.getColumnCount + val fields = new Array[StructField](ncols) + var i = 0 + while (i < ncols) { + val columnName = rsmd.getColumnLabel(i + 1) + val dataType = rsmd.getColumnType(i + 1) + val typeName = rsmd.getColumnTypeName(i + 1) + val fieldSize = rsmd.getPrecision(i + 1) + val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) + val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + val metadata = new MetadataBuilder() + .putString("name", columnName) + .putLong("scale", fieldScale) + val columnType = + dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( + getCatalystType(dataType, fieldSize, fieldScale, isSigned)) + fields(i) = StructField(columnName, columnType, nullable, metadata.build()) + i = i + 1 + } + return new StructType(fields) + } finally { + rs.close() } - return new StructType(fields) } finally { - rs.close() + statement.close() } } finally { conn.close() @@ -161,40 +170,73 @@ private[sql] object JDBCRDD extends Logging { * @return A Catalyst schema corresponding to columns in the given order. */ private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { - val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*) - new StructType(columns map { name => fieldMap(name) }) + val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*) + new StructType(columns.map(name => fieldMap(name))) } /** - * Given a driver string and an url, return a function that loads the - * specified driver string then returns a connection to the JDBC url. - * getConnector is run on the driver code, while the function it returns - * is run on the executor. - * - * @param driver - The class name of the JDBC driver for the given url, or null if the class name - * is not necessary. - * @param url - The JDBC url to connect to. - * - * @return A function that loads the driver and connects to the url. + * Converts value to SQL expression. */ - def getConnector(driver: String, url: String, properties: Properties): () => Connection = { - () => { - try { - if (driver != null) DriverRegistry.register(driver) - } catch { - case e: ClassNotFoundException => - logWarning(s"Couldn't find class $driver", e) - } - DriverManager.getConnection(url, properties) - } + private def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") + case _ => value + } + + private def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + + /** + * Turns a single Filter into a String representing a SQL expression. + * Returns None for an unhandled filter. + */ + private[jdbc] def compileFilter(f: Filter): Option[String] = { + Option(f match { + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case EqualNullSafe(attr, value) => + s"(NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " + + s"${compileValue(value)} IS NULL) OR ($attr IS NULL AND ${compileValue(value)} IS NULL))" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" + case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" + case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" + case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" + case In(attr, value) => s"$attr IN (${compileValue(value)})" + case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null) + case Or(f1, f2) => + // We can't compile Or filter unless both sub-filters are compiled successfully. + // It applies too for the following And filter. + // If we can make sure compileFilter supports all filters, we can remove this check. + val or = Seq(f1, f2).flatMap(compileFilter(_)) + if (or.size == 2) { + or.map(p => s"($p)").mkString(" OR ") + } else { + null + } + case And(f1, f2) => + val and = Seq(f1, f2).flatMap(compileFilter(_)) + if (and.size == 2) { + and.map(p => s"($p)").mkString(" AND ") + } else { + null + } + case _ => null + }) } + + /** * Build and return JDBCRDD from the given information. * * @param sc - Your SparkContext. * @param schema - The Catalyst schema of the underlying database table. - * @param driver - The class name of the JDBC driver for the given url. * @param url - The JDBC url to connect to. * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. * @param requiredColumns - The names of the columns to SELECT. @@ -207,7 +249,6 @@ private[sql] object JDBCRDD extends Logging { def scanTable( sc: SparkContext, schema: StructType, - driver: String, url: String, properties: Properties, fqTable: String, @@ -218,12 +259,13 @@ private[sql] object JDBCRDD extends Logging { val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, - getConnector(driver, url, properties), + JdbcUtils.createConnectionFactory(url, properties), pruneSchema(schema, requiredColumns), fqTable, quotedColumns, filters, parts, + url, properties) } } @@ -241,6 +283,7 @@ private[sql] class JDBCRDD( columns: Array[String], filters: Array[Filter], partitions: Array[Partition], + url: String, properties: Properties) extends RDD[InternalRow](sc, Nil) { @@ -258,52 +301,24 @@ private[sql] class JDBCRDD( if (sb.length == 0) "1" else sb.substring(1) } - /** - * Converts value to SQL expression. - */ - private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" - case _ => value - } - - private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") - - /** - * Turns a single Filter into a String representing a SQL expression. - * Returns null for an unhandled filter. - */ - private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" - case LessThan(attr, value) => s"$attr < ${compileValue(value)}" - case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" - case _ => null - } - /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = { - val filterStrings = filters map compileFilter filter (_ != null) - if (filterStrings.size > 0) { - val sb = new StringBuilder("WHERE ") - filterStrings.foreach(x => sb.append(x).append(" AND ")) - sb.substring(0, sb.length - 5) - } else "" - } + private val filterWhereClause: String = + filters.flatMap(JDBCRDD.compileFilter).mkString(" AND ") /** * A WHERE clause representing both `filters`, if any, and the current partition. */ private def getWhereClause(part: JDBCPartition): String = { if (part.whereClause != null && filterWhereClause.length > 0) { - filterWhereClause + " AND " + part.whereClause + "WHERE " + filterWhereClause + " AND " + part.whereClause } else if (part.whereClause != null) { "WHERE " + part.whereClause + } else if (filterWhereClause.length > 0) { + "WHERE " + filterWhereClause } else { - filterWhereClause + "" } } @@ -324,25 +339,27 @@ private[sql] class JDBCRDD( case object StringConversion extends JDBCConversion case object TimestampConversion extends JDBCConversion case object BinaryConversion extends JDBCConversion + case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion /** * Maps a StructType to a type tag list. */ - def getConversions(schema: StructType): Array[JDBCConversion] = { - schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Fixed(p, s) => DecimalConversion(p, s) - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => - if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") - }).toArray + def getConversions(schema: StructType): Array[JDBCConversion] = + schema.fields.map(sf => getConversions(sf.dataType, sf.metadata)) + + private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match { + case BooleanType => BooleanConversion + case DateType => DateConversion + case DecimalType.Fixed(p, s) => DecimalConversion(p, s) + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata)) + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } /** @@ -359,6 +376,9 @@ private[sql] class JDBCRDD( context.addTaskCompletionListener{ context => close() } val part = thePart.asInstanceOf[JDBCPartition] val conn = getConnection() + val dialect = JdbcDialects.get(url) + import scala.collection.JavaConverters._ + dialect.beforeFetch(conn, properties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to @@ -420,16 +440,44 @@ private[sql] class JDBCRDD( mutableRow.update(i, null) } case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) - case BinaryLongConversion => { + case BinaryLongConversion => val bytes = rs.getBytes(pos) var ans = 0L var j = 0 while (j < bytes.size) { ans = 256 * ans + (255 & bytes(j)) - j = j + 1; + j = j + 1 } mutableRow.setLong(i, ans) - } + case ArrayConversion(elementConversion) => + val array = rs.getArray(pos).getArray + if (array != null) { + val data = elementConversion match { + case TimestampConversion => + array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => + nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) + } + case StringConversion => + array.asInstanceOf[Array[java.lang.String]] + .map(UTF8String.fromString) + case DateConversion => + array.asInstanceOf[Array[java.sql.Date]].map { date => + nullSafeConvert(date, DateTimeUtils.fromJavaDate) + } + case DecimalConversion(p, s) => + array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => + nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s)) + } + case BinaryLongConversion => + throw new IllegalArgumentException(s"Unsupported array element conversion $i") + case _: ArrayConversion => + throw new IllegalArgumentException("Nested arrays unsupported") + case _ => array.asInstanceOf[Array[Any]] + } + mutableRow.update(i, new GenericArrayData(data)) + } else { + mutableRow.update(i, null) + } } if (rs.wasNull) mutableRow.setNullAt(i) i = i + 1 @@ -459,12 +507,20 @@ private[sql] class JDBCRDD( } try { if (null != conn) { + if (!conn.isClosed && !conn.getAutoCommit) { + try { + conn.commit() + } catch { + case NonFatal(e) => logWarning("Exception committing transaction", e) + } + } conn.close() } logInfo("closed connection") } catch { case e: Exception => logWarning("Exception closing connection", e) } + closed = true } override def hasNext: Boolean = { @@ -488,4 +544,12 @@ private[sql] class JDBCRDD( nextValue } } + + private def nullSafeConvert[T](input: T, f: T => Any): Any = { + if (input == null) { + null + } else { + f(input) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index f9300dc2cb529..9e336422d1f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -23,9 +23,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Instructions on how to partition the table among workers. @@ -44,6 +44,12 @@ private[sql] object JDBCRelation { * exactly once. The parameters minValue and maxValue are advisory in that * incorrect values may cause the partitioning to be poor, but no data * will fail to be represented. + * + * Null value predicate is added to the first partition where clause to include + * the rows with null value for the partitions column. + * + * @param partitioning partition information to generate the where clause for each partition + * @return an array of partitions with where clause for each partition */ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) @@ -66,7 +72,7 @@ private[sql] object JDBCRelation { if (upperBound == null) { lowerBound } else if (lowerBound == null) { - upperBound + s"$upperBound or $column is null" } else { s"$lowerBound AND $upperBound" } @@ -90,13 +96,16 @@ private[sql] case class JDBCRelation( override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) + // Check if JDBCRDD.compileFilter can accept input filters + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { + filters.filter(JDBCRDD.compileFilter(_).isEmpty) + } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - val driver: String = DriverRegistry.getDriverClassName(url) // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sqlContext.sparkContext, schema, - driver, url, properties, table, @@ -110,4 +119,9 @@ private[sql] case class JDBCRelation( .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .jdbc(url, table, properties) } + + override def toString: String = { + // credentials should not be included in the plan output, table information is sufficient. + s"JDBCRelation(${table})" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index f89d55b20e212..065c8572b06a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, PreparedStatement} import java.util.Properties +import scala.collection.JavaConverters._ import scala.util.Try +import scala.util.control.NonFatal -import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.JdbcDialects -import org.apache.spark.sql.types._ +import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} +import org.apache.spark.sql.types._ /** * Util functions for JDBC tables. @@ -33,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row} object JdbcUtils extends Logging { /** - * Establishes a JDBC connection. + * Returns a factory for creating connections to the given JDBC URL. + * + * @param url the JDBC url to connect to. + * @param properties JDBC connection properties. */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)() + def createConnectionFactory(url: String, properties: Properties): () => Connection = { + val userSpecifiedDriverClass = Option(properties.getProperty("driver")) + userSpecifiedDriverClass.foreach(DriverRegistry.register) + // Performing this part of the logic on the driver guards against the corner-case where the + // driver returned for a URL is different on the driver and executors due to classpath + // differences. + val driverClass: String = userSpecifiedDriverClass.getOrElse { + DriverManager.getDriver(url).getClass.getCanonicalName + } + () => { + userSpecifiedDriverClass.foreach(DriverRegistry.register) + val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { + case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d + case d if d.getClass.getCanonicalName == driverClass => d + }.getOrElse { + throw new IllegalStateException( + s"Did not find registered driver with class $driverClass") + } + driver.connect(url, properties) + } } /** @@ -47,29 +70,66 @@ object JdbcUtils extends Logging { // Somewhat hacky, but there isn't a good way to identify whether a table exists for all // SQL database systems using JDBC meta data calls, considering "table" could also include - // the database name. Query used to find table exists can be overriden by the dialects. - Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess + // the database name. Query used to find table exists can be overridden by the dialects. + Try { + val statement = conn.prepareStatement(dialect.getTableExistsQuery(table)) + try { + statement.executeQuery() + } finally { + statement.close() + } + }.isSuccess } /** * Drops a table from the JDBC database. */ def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + val statement = conn.createStatement + try { + statement.executeUpdate(s"DROP TABLE $table") + } finally { + statement.close() + } } /** * Returns a PreparedStatement that inserts a row into table via conn. */ def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 + val columns = rddSchema.fields.map(_.name).mkString(",") + val placeholders = rddSchema.fields.map(_ => "?").mkString(",") + val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)" + conn.prepareStatement(sql) + } + + /** + * Retrieve standard jdbc types. + * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) + * @return The default JdbcType for this DataType + */ + def getCommonJDBCType(dt: DataType): Option[JdbcType] = { + dt match { + case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) + case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) + case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) + case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) + case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) + case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) + case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) + case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) + case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) + case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) + case t: DecimalType => Option( + JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) + case _ => None } - conn.prepareStatement(sql.toString()) + } + + private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { + dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( + throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) } /** @@ -92,11 +152,23 @@ object JdbcUtils extends Logging { iterator: Iterator[Row], rddSchema: StructType, nullTypes: Array[Int], - batchSize: Int): Iterator[Byte] = { + batchSize: Int, + dialect: JdbcDialect): Iterator[Byte] = { val conn = getConnection() var committed = false + val supportsTransactions = try { + conn.getMetaData().supportsDataManipulationTransactionsOnly() || + conn.getMetaData().supportsDataDefinitionAndDataManipulationTransactions() + } catch { + case NonFatal(e) => + logWarning("Exception while detecting transaction support", e) + true + } + try { - conn.setAutoCommit(false) // Everything in the same db transaction. + if (supportsTransactions) { + conn.setAutoCommit(false) // Everything in the same db transaction. + } val stmt = insertStatement(conn, table, rddSchema) try { var rowCount = 0 @@ -121,6 +193,14 @@ object JdbcUtils extends Logging { case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case ArrayType(et, _) => + // remove type length parameters from end of type name + val typeName = getJdbcType(et, dialect).databaseTypeDefinition + .toLowerCase.split("\\(")(0) + val array = conn.createArrayOf( + typeName, + row.getSeq[AnyRef](i).toArray) + stmt.setArray(i + 1, array) case _ => throw new IllegalArgumentException( s"Can't translate non-null value for field $i") } @@ -140,14 +220,18 @@ object JdbcUtils extends Logging { } finally { stmt.close() } - conn.commit() + if (supportsTransactions) { + conn.commit() + } committed = true } finally { if (!committed) { // The stage must fail. We got here through an exception path, so // let the exception through unless rollback() or close() want to // tell the user about another problem. - conn.rollback() + if (supportsTransactions) { + conn.rollback() + } conn.close() } else { // The stage must succeed. We cannot propagate any exception close() might throw. @@ -167,28 +251,12 @@ object JdbcUtils extends Logging { def schemaString(df: DataFrame, url: String): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { + df.schema.fields foreach { field => val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) + val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") - }} + } if (sb.length < 2) "" else sb.substring(2) } @@ -199,34 +267,17 @@ object JdbcUtils extends Logging { df: DataFrame, url: String, table: String, - properties: Properties = new Properties()) { + properties: Properties) { val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) + getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + val getConnection: () => Connection = createConnectionFactory(url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => - savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize) + savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index b9914c581a657..8e8238a594a03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -26,39 +26,47 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils private[sql] object InferSchema { + /** * Infer the type of a collection of json records in three stages: * 1. Infer the type of each record * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. Replace any remaining null fields with string, the top type */ - def apply( + def infer( json: RDD[String], - samplingRatio: Double = 1.0, columnNameOfCorruptRecords: String, - primitivesAsString: Boolean = false): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) { + configOptions: JSONOptions): StructType = { + require(configOptions.samplingRatio > 0, + s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") + val shouldHandleCorruptRecord = configOptions.permissive + val schemaData = if (configOptions.samplingRatio > 0.99) { json } else { - json.sample(withReplacement = false, samplingRatio, 1) + json.sample(withReplacement = false, configOptions.samplingRatio, 1) } // perform schema inference on each row and merge afterwards val rootType = schemaData.mapPartitions { iter => val factory = new JsonFactory() - iter.map { row => + configOptions.setJacksonOptions(factory) + iter.flatMap { row => try { Utils.tryWithResource(factory.createParser(row)) { parser => parser.nextToken() - inferField(parser, primitivesAsString) + Some(inferField(parser, configOptions)) } } catch { + case _: JsonParseException if shouldHandleCorruptRecord => + Some(StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))) case _: JsonParseException => - StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) + None } } - }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) + }.treeAggregate[DataType]( + StructType(Seq()))( + compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord), + compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -71,14 +79,14 @@ private[sql] object InferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, primitivesAsString: Boolean): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, primitivesAsString) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -95,7 +103,7 @@ private[sql] object InferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, primitivesAsString), + inferField(parser, configOptions), nullable = true) } @@ -107,14 +115,15 @@ private[sql] object InferSchema { // the type as we pass through all JSON objects. var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { - elementType = compatibleType(elementType, inferField(parser, primitivesAsString)) + elementType = compatibleType( + elementType, inferField(parser, configOptions)) } ArrayType(elementType) - case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if primitivesAsString => StringType + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType - case (VALUE_TRUE | VALUE_FALSE) if primitivesAsString => StringType + case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => import JsonParser.NumberType._ @@ -125,9 +134,19 @@ private[sql] object InferSchema { // when we see a Java BigInteger, we use DecimalType. case BIG_INTEGER | BIG_DECIMAL => val v = parser.getDecimalValue - DecimalType(v.precision(), v.scale()) + if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) { + DecimalType(Math.max(v.precision(), v.scale()), v.scale()) + } else { + DoubleType + } + case FLOAT | DOUBLE if configOptions.prefersDecimal => + val v = parser.getDecimalValue + if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) { + DecimalType(Math.max(v.precision(), v.scale()), v.scale()) + } else { + DoubleType + } case FLOAT | DOUBLE => - // TODO(davies): Should we use decimal if possible? DoubleType } @@ -138,7 +157,7 @@ private[sql] object InferSchema { /** * Convert NullType to StringType and remove StructTypes with no fields */ - private def canonicalizeType: DataType => Option[DataType] = { + private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { case at @ ArrayType(elementType, _) => for { canonicalType <- canonicalizeType(elementType) @@ -147,15 +166,15 @@ private[sql] object InferSchema { } case StructType(fields) => - val canonicalFields = for { + val canonicalFields: Array[StructField] = for { field <- fields - if field.name.nonEmpty + if field.name.length > 0 canonicalType <- canonicalizeType(field.dataType) } yield { field.copy(dataType = canonicalType) } - if (canonicalFields.nonEmpty) { + if (canonicalFields.length > 0) { Some(StructType(canonicalFields)) } else { // per SPARK-8093: empty structs should be deleted @@ -166,28 +185,56 @@ private[sql] object InferSchema { case other => Some(other) } + private def withCorruptField( + struct: StructType, + columnNameOfCorruptRecords: String): StructType = { + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + struct.add(columnNameOfCorruptRecords, StringType, nullable = true) + } else { + // Otherwise, just return this struct. + struct + } + } + /** * Remove top-level ArrayType wrappers and merge the remaining schemas */ - private def compatibleRootType: (DataType, DataType) => DataType = { - case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2) - case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2) + private def compatibleRootType( + columnNameOfCorruptRecords: String, + shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { + // Since we support array of json objects at the top level, + // we need to check the element type and find the root level data type. + case (ArrayType(ty1, _), ty2) => + compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) + case (ty1, ArrayType(ty2, _)) => + compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + case (struct: StructType, NullType) => struct + case (NullType, struct: StructType) => struct + case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => + withCorruptField(struct, columnNameOfCorruptRecords) + case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => + withCorruptField(struct, columnNameOfCorruptRecords) + // If we get anything else, we call compatibleType. + // Usually, when we reach here, ty1 and ty2 are two StructTypes. case (ty1, ty2) => compatibleType(ty1, ty2) } /** * Returns the most general data type for two given data types. */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { + def compatibleType(t1: DataType, t2: DataType): DataType = { HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough // in most case, also have better precision. - case (DoubleType, t: DecimalType) => - DoubleType - case (t: DecimalType, DoubleType) => + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => DoubleType + case (t1: DecimalType, t2: DecimalType) => val scale = math.max(t1.scale, t2.scale) val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) @@ -209,6 +256,14 @@ private[sql] object InferSchema { case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + // The case that given `DecimalType` is capable of given `IntegralType` is handled in + // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when + // the given `DecimalType` is not capable of the given `IntegralType`. + case (t1: IntegralType, t2: DecimalType) => + compatibleType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + compatibleType(t1, DecimalType.forType(t2)) + // strings and every string is a Json object. case (_, _) => StringType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala new file mode 100644 index 0000000000000..66f1126fb9ae6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} + +/** + * Options for the JSON data source. + * + * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. + */ +private[sql] class JSONOptions( + @transient private val parameters: Map[String, String]) + extends Logging with Serializable { + + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + val primitivesAsString = + parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + val prefersDecimal = + parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false) + val allowComments = + parameters.get("allowComments").map(_.toBoolean).getOrElse(false) + val allowUnquotedFieldNames = + parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) + val allowSingleQuotes = + parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) + val allowNumericLeadingZeros = + parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) + val allowNonNumericNumbers = + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + val allowBackslashEscapingAnyCharacter = + parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) + private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord") + + // Parse mode flags + if (!ParseModes.isValidMode(parseMode)) { + logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") + } + + val failFast = ParseModes.isFailFastMode(parseMode) + val dropMalformed = ParseModes.isDropMalformedMode(parseMode) + val permissive = ParseModes.isPermissiveMode(parseMode) + + /** Sets config options on a Jackson [[JsonFactory]]. */ + def setJacksonOptions(factory: JsonFactory): Unit = { + factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) + factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES, allowUnquotedFieldNames) + factory.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, allowSingleQuotes) + factory.configure(JsonParser.Feature.ALLOW_NUMERIC_LEADING_ZEROS, allowNumericLeadingZeros) + factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) + factory.configure(JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER, + allowBackslashEscapingAnyCharacter) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 85b52f04c8d01..7364a1dc0658a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -20,78 +20,114 @@ package org.apache.spark.sql.execution.datasources.json import java.io.CharArrayWriter import com.fasterxml.jackson.core.JsonFactory -import com.google.common.base.Objects +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} import org.apache.hadoop.mapred.{JobConf, TextInputFormat} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.spark.Logging -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.util.SerializableConfiguration - -class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +class DefaultSource extends FileFormat with DataSourceRegister { override def shortName(): String = "json" - override def createRelation( + override def inferSchema( sqlContext: SQLContext, - paths: Array[String], - dataSchema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) - - new JSONRelation( - None, - samplingRatio, - primitivesAsString, - dataSchema, - None, - partitionColumns, - paths)(sqlContext) + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) { + None + } else { + val parsedOptions: JSONOptions = new JSONOptions(options) + val columnNameOfCorruptRecord = + parsedOptions.columnNameOfCorruptRecord + .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) + val jsonFiles = files.filterNot { status => + val name = status.getPath.getName + name.startsWith("_") || name.startsWith(".") + }.toArray + + val jsonSchema = InferSchema.infer( + createBaseRdd(sqlContext, jsonFiles), + columnNameOfCorruptRecord, + parsedOptions) + checkConstraints(jsonSchema) + + Some(jsonSchema) + } } -} -private[sql] class JSONRelation( - val inputRDD: Option[RDD[String]], - val samplingRatio: Double, - val primitivesAsString: Boolean, - val maybeDataSchema: Option[StructType], - val maybePartitionSpec: Option[PartitionSpec], - override val userDefinedPartitionColumns: Option[StructType], - override val paths: Array[String] = Array.empty[String])(@transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) { + override def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val conf = job.getConfiguration + val parsedOptions: JSONOptions = new JSONOptions(options) + parsedOptions.compressionCodec.foreach { codec => + CompressionCodecs.setCodecConfiguration(conf, codec) + } - /** Constraints to be imposed on schema to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") + new OutputWriterFactory { + override def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new JsonOutputWriter(path, bucketId, dataSchema, context) + } } } - override val needConversion: Boolean = false + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + val broadcastedConf = + sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + + val parsedOptions: JSONOptions = new JSONOptions(options) + val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord + .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val joinedRow = new JoinedRow() + + file => { + val lines = new HadoopFileLinesReader(file, broadcastedConf.value.value).map(_.toString) + + val rows = JacksonParser.parseJson( + lines, + requiredSchema, + columnNameOfCorruptRecord, + parsedOptions) + + val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + rows.map { row => + appendPartitionColumns(joinedRow(row, file.partitionValues)) + } + } + } - private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + private def createBaseRdd(sqlContext: SQLContext, inputPaths: Seq[FileStatus]): RDD[String] = { + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration val paths = inputPaths.map(_.getPath) @@ -106,78 +142,27 @@ private[sql] class JSONRelation( classOf[Text]).map(_._2.toString) // get the text line } - override lazy val dataSchema = { - val jsonSchema = maybeDataSchema.getOrElse { - val files = cachedLeafStatuses().filterNot { status => - val name = status.getPath.getName - name.startsWith("_") || name.startsWith(".") - }.toArray - InferSchema( - inputRDD.getOrElse(createBaseRdd(files)), - samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord, - primitivesAsString) - } - checkConstraints(jsonSchema) - - jsonSchema - } - - override private[sql] def buildInternalScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputPaths: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_))) - val rows = JacksonParser( - inputRDD.getOrElse(createBaseRdd(inputPaths)), - requiredDataSchema, - sqlContext.conf.columnNameOfCorruptRecord) - - rows.mapPartitions { iterator => - val unsafeProjection = UnsafeProjection.create(requiredDataSchema) - iterator.map(unsafeProjection) + /** Constraints to be imposed on schema to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") } } - override def equals(other: Any): Boolean = other match { - case that: JSONRelation => - ((inputRDD, that.inputRDD) match { - case (Some(thizRdd), Some(thatRdd)) => thizRdd eq thatRdd - case (None, None) => true - case _ => false - }) && paths.toSet == that.paths.toSet && - dataSchema == that.dataSchema && - schema == that.schema - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode( - inputRDD, - paths.toSet, - dataSchema, - schema, - partitionColumns) - } - - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - new OutputWriterFactory { - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, dataSchema, context) - } - } - } + override def toString: String = "JSON" + override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] } private[json] class JsonOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with Logging { + extends OutputWriter with Logging { private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records @@ -187,11 +172,12 @@ private[json] class JsonOutputWriter( private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString.json$extension") } }.getRecordWriter(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 3f34520afe6b6..8b920ecafaeed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.execution.datasources.json -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils} - -import scala.collection.Map - import com.fasterxml.jackson.core._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ private[sql] object JacksonGenerator { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 4f53eeb081b93..aeee2600a19ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream -import com.fasterxml.jackson.core._ - import scala.collection.mutable.ArrayBuffer +import com.fasterxml.jackson.core._ + +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -32,18 +33,48 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[sql] object JacksonParser { - def apply( - json: RDD[String], +private[json] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) + +object JacksonParser extends Logging { + + def parse( + input: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { - parseJson(json, schema, columnNameOfCorruptRecords) + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): RDD[InternalRow] = { + + input.mapPartitions { iter => + parseJson(iter, schema, columnNameOfCorruptRecords, configOptions) + } } /** * Parse the current token (and related children) according to a desired schema + * This is an wrapper for the method `convertField()` to handle a row wrapped + * with an array. */ - private[sql] def convertField( + def convertRootField( + factory: JsonFactory, + parser: JsonParser, + schema: DataType): Any = { + import com.fasterxml.jackson.core.JsonToken._ + (parser.getCurrentToken, schema) match { + case (START_ARRAY, st: StructType) => + // SPARK-3308: support reading top level JSON arrays and take every element + // in such an array as a row + convertArray(factory, parser, st) + + case (START_OBJECT, ArrayType(st, _)) => + // the business end of SPARK-3308: + // when an object is found but an array is requested just wrap it in a list + convertField(factory, parser, st) :: Nil + + case _ => + convertField(factory, parser, schema) + } + } + + private def convertField( factory: JsonFactory, parser: JsonParser, schema: DataType): Any = { @@ -83,7 +114,7 @@ private[sql] object JacksonParser { DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => - parser.getLongValue * 1000L + parser.getLongValue * 1000000L case (_, StringType) => val writer = new ByteArrayOutputStream() @@ -106,7 +137,7 @@ private[sql] object JacksonParser { lowerCaseValue.equals("-inf")) { value.toFloat } else { - sys.error(s"Cannot parse $value as FloatType.") + throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") } case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) => @@ -123,7 +154,7 @@ private[sql] object JacksonParser { lowerCaseValue.equals("-inf")) { value.toDouble } else { - sys.error(s"Cannot parse $value as DoubleType.") + throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") } case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) => @@ -150,19 +181,9 @@ private[sql] object JacksonParser { case (START_OBJECT, st: StructType) => convertObject(factory, parser, st) - case (START_ARRAY, st: StructType) => - // SPARK-3308: support reading top level JSON arrays and take every element - // in such an array as a row - convertArray(factory, parser, st) - case (START_ARRAY, ArrayType(st, _)) => convertArray(factory, parser, st) - case (START_OBJECT, ArrayType(st, _)) => - // the business end of SPARK-3308: - // when an object is found but an array is requested just wrap it in a list - convertField(factory, parser, st) :: Nil - case (START_OBJECT, MapType(StringType, kt, _)) => convertMap(factory, parser, kt) @@ -170,7 +191,11 @@ private[sql] object JacksonParser { convertField(factory, parser, udt.sqlType) case (token, dataType) => - sys.error(s"Failed to parse a value for data type $dataType (current token: $token).") + // We cannot parse this token based on the given data type. So, we throw a + // SparkSQLJsonProcessingException and this exception will be caught by + // parseJson method. + throw new SparkSQLJsonProcessingException( + s"Failed to parse a value for data type $dataType (current token: $token).") } } @@ -225,53 +250,59 @@ private[sql] object JacksonParser { new GenericArrayData(values.toArray) } - private def parseJson( - json: RDD[String], + def parseJson( + input: Iterator[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): Iterator[InternalRow] = { def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present - val row = new GenericMutableRow(schema.length) - for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) { - require(schema(corruptIndex).dataType == StringType) - row.update(corruptIndex, UTF8String.fromString(record)) + if (configOptions.failFast) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: $record") + } + if (configOptions.dropMalformed) { + logWarning(s"Dropping malformed line: $record") + Nil + } else { + val row = new GenericMutableRow(schema.length) + for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) { + require(schema(corruptIndex).dataType == StringType) + row.update(corruptIndex, UTF8String.fromString(record)) + } + Seq(row) } - - Seq(row) } - json.mapPartitions { iter => - val factory = new JsonFactory() - - iter.flatMap { record => - if (record.trim.isEmpty) { - Nil - } else { - try { - Utils.tryWithResource(factory.createParser(record)) { parser => - parser.nextToken() - - convertField(factory, parser, schema) match { - case null => failedRecord(record) - case row: InternalRow => row :: Nil - case array: ArrayData => - if (array.numElements() == 0) { - Nil - } else { - array.toArray[InternalRow](schema) - } - case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of " + - "the file (or each string in the RDD) is a valid JSON object or " + - "an array of JSON objects.") - } + val factory = new JsonFactory() + configOptions.setJacksonOptions(factory) + + input.flatMap { record => + if (record.trim.isEmpty) { + Nil + } else { + try { + Utils.tryWithResource(factory.createParser(record)) { parser => + parser.nextToken() + + convertRootField(factory, parser, schema) match { + case null => failedRecord(record) + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray[InternalRow](schema) + } + case _ => + failedRecord(record) } - } catch { - case _: JsonProcessingException => - failedRecord(record) } + } catch { + case _: JsonProcessingException => + failedRecord(record) + case _: SparkSQLJsonProcessingException => + failedRecord(record) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index a958373eb769d..850e807b8677e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -22,14 +22,13 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.io.api.RecordMaterializer -import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema._ +import org.apache.parquet.schema.Type.Repetition -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -58,9 +57,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with */ override def init(context: InitContext): ReadContext = { catalystRequestedSchema = { - // scalastyle:off jobcontext val conf = context.getConfiguration - // scalastyle:on jobcontext val schemaString = conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) assert(schemaString != null, "Parquet requested schema not set.") StructType.fromString(schemaString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 1f653cd3d3cb1..6bf82bee67881 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -25,14 +25,15 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{DOUBLE, INT32, INT64, BINARY, FIXED_LEN_BYTE_ARRAY} import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} +import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -327,8 +328,8 @@ private[parquet] class CatalystRowConverter( // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying // it. val buffer = value.toByteBuffer - val offset = buffer.position() - val numBytes = buffer.limit() - buffer.position() + val offset = buffer.arrayOffset() + buffer.position() + val numBytes = buffer.remaining() updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes)) } } @@ -368,37 +369,15 @@ private[parquet] class CatalystRowConverter( } protected def decimalFromBinary(value: Binary): Decimal = { - if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { + if (precision <= Decimal.MAX_LONG_DIGITS) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - val unscaled = binaryToUnscaledLong(value) + val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) } } - - private def binaryToUnscaledLong(binary: Binary): Long = { - // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here - // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without - // copying it. - val buffer = binary.toByteBuffer - val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() - - var unscaled = 0L - var i = start - - while (i < end) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * (end - start) - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - unscaled - } } private class CatalystIntDictionaryAwareDecimalConverter( @@ -658,3 +637,36 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = elementConverter.start() } } + +private[parquet] object CatalystRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. + val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.arrayOffset() + buffer.position() + val end = buffer.arrayOffset() + buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } + + def binaryToSQLTimestamp(binary: Binary): SQLTimestamp = { + assert(binary.length() == 12, s"Timestamps (with nanoseconds) are expected to be stored in" + + s" 12-byte long binaries. Found a ${binary.length()}-byte binary instead.") + val buffer = binary.toByteBuffer.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buffer.getLong + val julianDay = buffer.getInt + DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 7f3394c20ed3d..6f6340f541ada 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema._ import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ -import org.apache.parquet.schema._ -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, SQLConf} /** * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and @@ -37,7 +38,6 @@ import org.apache.spark.sql.{AnalysisException, SQLConf} * [[MessageType]] schemas. * * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md - * * @constructor * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL @@ -65,7 +65,8 @@ private[parquet] class CatalystSchemaConverter( def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean, - writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean) + writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean) /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. @@ -108,6 +109,9 @@ private[parquet] class CatalystSchemaConverter( def typeString = if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + def typeNotSupported() = + throw new AnalysisException(s"Parquet type not supported: $typeString") + def typeNotImplemented() = throw new AnalysisException(s"Parquet type not yet supported: $typeString") @@ -141,7 +145,10 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) + case UINT_8 => typeNotSupported() + case UINT_16 => typeNotSupported() + case UINT_32 => typeNotSupported() case TIME_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -149,7 +156,8 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) + case UINT_64 => typeNotSupported() case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -163,9 +171,10 @@ private[parquet] class CatalystSchemaConverter( case BINARY => originalType match { - case UTF8 | ENUM => StringType + case UTF8 | ENUM | JSON => StringType case null if assumeBinaryIsString => StringType case null => BinaryType + case BSON => BinaryType case DECIMAL => makeDecimalType() case _ => illegalType() } @@ -394,7 +403,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -404,7 +413,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -535,7 +544,7 @@ private[parquet] object CatalystSchemaConverter { !name.matches(".*[ ,;{}()\n\t=].*"), s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". |Please use alias to rename it. - """.stripMargin.split("\n").mkString(" ")) + """.stripMargin.split("\n").mkString(" ").trim) } def checkFieldNames(schema: StructType): StructType = { @@ -560,10 +569,6 @@ private[parquet] object CatalystSchemaConverter { // Returns the minimum number of bytes needed to store a decimal with a given `precision`. val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */ - - val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */ - // Max precision of a decimal value stored in `numBytes` bytes def maxPrecisionForBytes(numBytes: Int): Int = { Math.round( // convert double to long diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 483363d2c1a21..67bfd39697ed7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -29,12 +29,12 @@ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext import org.apache.parquet.io.api.{Binary, RecordConsumer} -import org.apache.spark.Logging -import org.apache.spark.sql.SQLConf +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, minBytesForPrecision} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -253,13 +253,13 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi writeLegacyParquetFormat match { // Standard mode, 1 <= precision <= 9, writes as INT32 - case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer // Standard mode, 10 <= precision <= 18, writes as INT64 - case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY - case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY case _ => binaryWriterUsingUnscaledBytes @@ -429,7 +429,7 @@ private[parquet] object CatalystWriteSupport { def setSchema(schema: StructType, configuration: Configuration): Unit = { schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) configuration.set(SPARK_ROW_SCHEMA, schema.json) - configuration.set( + configuration.setIfUnset( ParquetOutputFormat.WRITER_VERSION, ParquetProperties.WriterVersion.PARQUET_1_0.toString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala deleted file mode 100644 index 300e8677b312f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.parquet - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.parquet.Log -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} - -/** - * An output committer for writing Parquet files. In stead of writing to the `_temporary` folder - * like what [[ParquetOutputCommitter]] does, this output committer writes data directly to the - * destination folder. This can be useful for data stored in S3, where directory operations are - * relatively expensive. - * - * To enable this output committer, users may set the "spark.sql.parquet.output.committer.class" - * property via Hadoop [[Configuration]]. Not that this property overrides - * "spark.sql.sources.outputCommitterClass". - * - * *NOTE* - * - * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's - * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are - * left empty). - */ -private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) - extends ParquetOutputCommitter(outputPath, context) { - val LOG = Log.getLog(classOf[ParquetOutputCommitter]) - - override def getWorkPath: Path = outputPath - override def abortTask(taskContext: TaskAttemptContext): Unit = {} - override def commitTask(taskContext: TaskAttemptContext): Unit = {} - override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true - override def setupJob(jobContext: JobContext): Unit = {} - override def setupTask(taskContext: TaskAttemptContext): Unit = {} - - override def commitJob(jobContext: JobContext) { - val configuration = { - // scalastyle:off jobcontext - ContextUtil.getConfiguration(jobContext) - // scalastyle:on jobcontext - } - val fileSystem = outputPath.getFileSystem(configuration) - - if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { - try { - val outputStatus = fileSystem.getFileStatus(outputPath) - val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) - try { - ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) - } catch { case e: Exception => - LOG.warn("could not write summary file for " + outputPath, e) - val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fileSystem.exists(metadataPath)) { - fileSystem.delete(metadataPath, true) - } - } - } catch { - case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) - } - } - - if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { - try { - val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) - fileSystem.create(successPath).close() - } catch { - case e: Exception => LOG.warn("could not write success file for " + outputPath, e) - } - } - } -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 07714329370a5..95afdc789f322 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable -import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary import org.apache.parquet.schema.OriginalType import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName @@ -207,11 +207,26 @@ private[sql] object ParquetFilters { */ } + /** + * SPARK-11955: The optional fields will have metadata StructType.metadataKeyForOptionalField. + * These fields only exist in one side of merged schemas. Due to that, we can't push down filters + * using such fields, otherwise Parquet library will throw exception. Here we filter out such + * fields. + */ + private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match { + case StructType(fields) => + fields.filter { f => + !f.metadata.contains(StructType.metadataKeyForOptionalField) || + !f.metadata.getBoolean(StructType.metadataKeyForOptionalField) + }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) } + case _ => Array.empty[(String, DataType)] + } + /** * Converts data sources filters to Parquet filter predicates. */ def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { - val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + val dataTypeOf = getFieldMap(schema).toMap relaxParquetValidTypeMap @@ -219,7 +234,7 @@ private[sql] object ParquetFilters { // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, // which can be casted to `false` implicitly. Please refer to the `eval` method of these - // operators and the `SimplifyFilters` rule for details. + // operators and the `PruneFilters` rule for details. // Hyukjin: // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. @@ -231,33 +246,46 @@ private[sql] object ParquetFilters { // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) => + case sources.IsNull(name) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.IsNotNull(name) => + case sources.IsNotNull(name) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.EqualTo(name, value) => + case sources.EqualTo(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) => + case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) => + case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) => makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) => + case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThan(name, value) => + case sources.LessThan(name, value) if dataTypeOf.contains(name) => makeLt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) => + case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) => makeLtEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThan(name, value) => + case sources.GreaterThan(name, value) if dataTypeOf.contains(name) => makeGt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) => + case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) => makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.In(name, valueSet) => + makeInSet.lift(dataTypeOf(name)).map(_(name, valueSet.toSet)) + case sources.And(lhs, rhs) => - (createFilter(schema, lhs) ++ createFilter(schema, rhs)).reduceOption(FilterApi.and) + // At here, it is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. + for { + lhsFilter <- createFilter(schema, lhs) + rhsFilter <- createFilter(schema, rhs) + } yield FilterApi.and(lhsFilter, rhsFilter) case sources.Or(lhs, rhs) => for { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala new file mode 100644 index 0000000000000..00352f23ae660 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.parquet.hadoop.metadata.CompressionCodecName + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf + +/** + * Options for the Parquet data source. + */ +class ParquetOptions( + @transient private val parameters: Map[String, String], + @transient private val sqlConf: SQLConf) + extends Logging with Serializable { + + import ParquetOptions._ + + /** + * Compression codec to use. By default use the value specified in SQLConf. + * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. + */ + val compressionCodec: String = { + val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase + if (!shortParquetCompressionCodecNames.contains(codecName)) { + val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) + throw new IllegalArgumentException(s"Codec [$codecName] " + + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") + } + shortParquetCompressionCodecNames(codecName).name() + } +} + + +object ParquetOptions { + // The parquet compression short names + private val shortParquetCompressionCodecNames = Map( + "none" -> CompressionCodecName.UNCOMPRESSED, + "uncompressed" -> CompressionCodecName.UNCOMPRESSED, + "snappy" -> CompressionCodecName.SNAPPY, + "gzip" -> CompressionCodecName.GZIP, + "lzo" -> CompressionCodecName.LZO) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 5a7c6b95b565f..bfe7aefe4100c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -19,210 +19,58 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI import java.util.logging.{Logger => JLogger} -import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.{Failure, Try} -import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.parquet.{Log => ApacheParquetLog} +import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType -import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} - +import org.apache.spark.sql.types.{AtomicType, DataType, StructType} +import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource + extends FileFormat + with DataSourceRegister + with Logging + with Serializable { override def shortName(): String = "parquet" - override def createRelation( - sqlContext: SQLContext, - paths: Array[String], - schema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext) - } -} + override def toString: String = "ParquetFormat" -// NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriter { + override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] - private val recordWriter: RecordWriter[Void, InternalRow] = { - val outputFormat = { - new ParquetOutputFormat[InternalRow]() { - // Here we override `getDefaultWorkFile` for two reasons: - // - // 1. To allow appending. We need to generate unique output file names to avoid - // overwriting existing files (either exist before the write job, or are just written - // by other tasks within the same write job). - // - // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses - // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all - // partitions in the case of dynamic partitioning. - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) - val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") - } - } - } - - outputFormat.getRecordWriter(context) - } - - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) - - override def close(): Unit = recordWriter.close(context) -} - -private[sql] class ParquetRelation( - override val paths: Array[String], - private val maybeDataSchema: Option[StructType], - // This is for metastore conversion. - private val maybePartitionSpec: Option[PartitionSpec], - override val userDefinedPartitionColumns: Option[StructType], - parameters: Map[String, String])( - val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) - with Logging { - - private[sql] def this( - paths: Array[String], - maybeDataSchema: Option[StructType], - maybePartitionSpec: Option[PartitionSpec], - parameters: Map[String, String])( - sqlContext: SQLContext) = { - this( - paths, - maybeDataSchema, - maybePartitionSpec, - maybePartitionSpec.map(_.partitionColumns), - parameters)(sqlContext) - } - - // Should we merge schemas from all Parquet part-files? - private val shouldMergeSchemas = - parameters - .get(ParquetRelation.MERGE_SCHEMA) - .map(_.toBoolean) - .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) - - private val mergeRespectSummaries = - sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) - - private val maybeMetastoreSchema = parameters - .get(ParquetRelation.METASTORE_SCHEMA) - .map(DataType.fromJson(_).asInstanceOf[StructType]) - - private lazy val metadataCache: MetadataCache = { - val meta = new MetadataCache - meta.refresh() - meta - } - - override def equals(other: Any): Boolean = other match { - case that: ParquetRelation => - val schemaEquality = if (shouldMergeSchemas) { - this.shouldMergeSchemas == that.shouldMergeSchemas - } else { - this.dataSchema == that.dataSchema && - this.schema == that.schema - } - - this.paths.toSet == that.paths.toSet && - schemaEquality && - this.maybeDataSchema == that.maybeDataSchema && - this.partitionColumns == that.partitionColumns - - case _ => false - } - - override def hashCode(): Int = { - if (shouldMergeSchemas) { - Objects.hashCode( - Boolean.box(shouldMergeSchemas), - paths.toSet, - maybeDataSchema, - partitionColumns) - } else { - Objects.hashCode( - Boolean.box(shouldMergeSchemas), - paths.toSet, - dataSchema, - schema, - maybeDataSchema, - partitionColumns) - } - } - - /** Constraints on schema of dataframe to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to parquet format") - } - } - - override def dataSchema: StructType = { - val schema = maybeDataSchema.getOrElse(metadataCache.dataSchema) - // check if schema satisfies the constraints - // before moving forward - checkConstraints(schema) - schema - } - - override private[sql] def refresh(): Unit = { - super.refresh() - metadataCache.refresh() - } - - // Parquet data source always uses Catalyst internal representations. - override val needConversion: Boolean = false - - override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum + override def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = { - // scalastyle:off jobcontext - ContextUtil.getConfiguration(job) - // scalastyle:on jobcontext - } + val parquetOptions = new ParquetOptions(options, sqlContext.sessionState.conf) - // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible - val committerClassName = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) - if (committerClassName == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { - conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, - classOf[DirectParquetOutputCommitter].getCanonicalName) - } + val conf = ContextUtil.getConfiguration(job) val committerClass = conf.getClass( @@ -249,7 +97,12 @@ private[sql] class ParquetRelation( job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport]) - CatalystWriteSupport.setSchema(dataSchema, conf) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushdowning filters. + val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, + dataSchema).asInstanceOf[StructType] + CatalystWriteSupport.setSchema(dataSchemaToWrite, conf) // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) // and `CatalystWriteSupport` (writing actual rows to Parquet files). @@ -266,234 +119,299 @@ private[sql] class ParquetRelation( sqlContext.conf.writeLegacyParquetFormat.toString) // Sets compression scheme - conf.set( - ParquetOutputFormat.COMPRESSION, - ParquetRelation - .shortParquetCompressionCodecNames - .getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, - CompressionCodecName.UNCOMPRESSED).name()) + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) new OutputWriterFactory { override def newInstance( - path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, bucketId, context) } } } - override def buildInternalScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) - val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown - val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString - val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp + def inferSchema( + sqlContext: SQLContext, + parameters: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + // Should we merge schemas from all Parquet part-files? + val shouldMergeSchemas = + parameters + .get(ParquetRelation.MERGE_SCHEMA) + .map(_.toBoolean) + .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) - // When merging schemas is enabled and the column of the given filter does not exist, - // Parquet emits an exception which is an issue of Parquet (PARQUET-389). - val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown - - // Parquet row group size. We will use this value as the value for - // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value - // of these flags are smaller than the parquet row group size. - val parquetBlockSize = ParquetOutputFormat.getLongBlockSize(broadcastedConf.value.value) - - // Create the function to set variable Parquet confs at both driver and executor side. - val initLocalJobFuncOpt = - ParquetRelation.initializeLocalJobFunc( - requiredColumns, - filters, - dataSchema, - parquetBlockSize, - useMetadataCache, - safeParquetFilterPushDown, - assumeBinaryIsString, - assumeInt96IsTimestamp) _ - - // Create the function to set input paths at the driver side. - val setInputPaths = - ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _ - - Utils.withDummyCallSite(sqlContext.sparkContext) { - new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, - broadcastedConf = broadcastedConf, - initDriverSideJobFuncOpt = Some(setInputPaths), - initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = classOf[ParquetInputFormat[InternalRow]], - valueClass = classOf[InternalRow]) { - - val cacheMetadata = useMetadataCache - - @transient val cachedStatuses = inputFiles.map { f => - // In order to encode the authority of a Path containing special characters such as '/' - // (which does happen in some S3N credentials), we need to use the string returned by the - // URI of the path to create a new Path. - val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) - new FileStatus( - f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) - }.toSeq - - private def escapePathUserInfo(path: Path): Path = { - val uri = path.toUri - new Path(new URI( - uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, - uri.getQuery, uri.getFragment)) - } + val mergeRespectSummaries = + sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = new ParquetInputFormat[InternalRow] { - override def listStatus(jobContext: JobContext): JList[FileStatus] = { - if (cacheMetadata) cachedStatuses.asJava else super.listStatus(jobContext) - } + val filesByType = splitFiles(files) + + // Sees which file(s) we need to touch in order to figure out the schema. + // + // Always tries the summary files first if users don't require a merged schema. In this case, + // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row + // groups information, and could be much smaller for large Parquet files with lots of row + // groups. If no summary file is available, falls back to some random part-file. + // + // NOTE: Metadata stored in the summary files are merged from all part-files. However, for + // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know + // how to merge them correctly if some key is associated with different values in different + // part-files. When this happens, Parquet simply gives up generating the summary file. This + // implies that if a summary file presents, then: + // + // 1. Either all part-files have exactly the same Spark SQL schema, or + // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus + // their schemas may differ from each other). + // + // Here we tend to be pessimistic and take the second case into account. Basically this means + // we can't trust the summary files if users require a merged schema, and must touch all part- + // files to do the merge. + val filesToTouch = + if (shouldMergeSchemas) { + // Also includes summary files, 'cause there might be empty partition directories. + + // If mergeRespectSummaries config is true, we assume that all part-files are the same for + // their schema with summary files, so we ignore them when merging schema. + // If the config is disabled, which is the default setting, we merge all part-files. + // In this mode, we only need to merge schemas contained in all those summary files. + // You should enable this configuration only if you are very sure that for the parquet + // part-files to read there are corresponding summary files containing correct schema. + + // As filed in SPARK-11500, the order of files to touch is a matter, which might affect + // the ordering of the output columns. There are several things to mention here. + // + // 1. If mergeRespectSummaries config is false, then it merges schemas by reducing from + // the first part-file so that the columns of the lexicographically first file show + // first. + // + // 2. If mergeRespectSummaries config is true, then there should be, at least, + // "_metadata"s for all given files, so that we can ensure the columns of + // the lexicographically first file show first. + // + // 3. If shouldMergeSchemas is false, but when multiple files are given, there is + // no guarantee of the output order, since there might not be a summary file for the + // lexicographically first file, which ends up putting ahead the columns of + // the other files. However, this should be okay since not enabling + // shouldMergeSchemas means (assumes) all the files have the same schemas. + + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq() + } else { + filesByType.data } + needMerged ++ filesByType.metadata ++ filesByType.commonMetadata + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + filesByType.commonMetadata.headOption + // Falls back to "_metadata" + .orElse(filesByType.metadata.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(filesByType.data.headOption) + .toSeq + } + ParquetRelation.mergeSchemasInParallel(filesToTouch, sqlContext) + } - val jobContext = newJobContext(getConf(isDriverSide = true), jobId) - val rawSplits = inputFormat.getSplits(jobContext) + case class FileTypes( + data: Seq[FileStatus], + metadata: Seq[FileStatus], + commonMetadata: Seq[FileStatus]) + + private def splitFiles(allFiles: Seq[FileStatus]): FileTypes = { + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. + val leaves = allFiles.filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + }.toArray.sortBy(_.getPath.toString) + + FileTypes( + data = leaves.filterNot(f => isSummaryFile(f.getPath)), + metadata = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE), + commonMetadata = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)) + } - Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition( - id, i, rawSplits.get(i).asInstanceOf[InputSplit with Writable]) - } + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + /** + * Returns whether the reader will return the rows as batch or not. + */ + override def supportBatch(sqlContext: SQLContext, schema: StructType): Boolean = { + val conf = SQLContext.getActive().get.conf + conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + val parquetConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + parquetConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) + parquetConf.set( + CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + CatalystSchemaConverter.checkFieldNames(requiredSchema).json) + parquetConf.set( + CatalystWriteSupport.SPARK_ROW_SCHEMA, + CatalystSchemaConverter.checkFieldNames(requiredSchema).json) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushdowning filters. + val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, + requiredSchema).asInstanceOf[StructType] + CatalystWriteSupport.setSchema(dataSchemaToWrite, parquetConf) + + // Sets flags for `CatalystSchemaConverter` + parquetConf.setBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sqlContext.conf.getConf(SQLConf.PARQUET_BINARY_AS_STRING)) + parquetConf.setBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP)) + + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + val returningBatch = + supportBatch(sqlContext, StructType(partitionSchema.fields ++ dataSchema.fields)) + + // Try to push down filters when filter push-down is enabled. + val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + + val broadcastedConf = + sqlContext.sparkContext.broadcast(new SerializableConfiguration(parquetConf)) + + // TODO: if you move this into the closure it reverts to the default values. + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). + val enableVectorizedParquetReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled && + dataSchema.forall(_.dataType.isInstanceOf[AtomicType]) + + (file: PartitionedFile) => { + assert(file.partitionValues.numFields == partitionSchema.size) + + val fileSplit = + new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + + val split = + new org.apache.parquet.hadoop.ParquetInputSplit( + fileSplit.getPath, + fileSplit.getStart, + fileSplit.getStart + fileSplit.getLength, + fileSplit.getLength, + fileSplit.getLocations, + null) + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedConf.value.value, attemptId) + + val parquetReader = if (enableVectorizedParquetReader) { + val vectorizedReader = new VectorizedParquetRecordReader() + vectorizedReader.initialize(split, hadoopAttemptContext) + logDebug(s"Appending $partitionSchema ${file.partitionValues}") + vectorizedReader.initBatch(partitionSchema, file.partitionValues) + if (returningBatch) { + vectorizedReader.enableReturningBatches() } + vectorizedReader + } else { + logDebug(s"Falling back to parquet-mr") + val reader = pushed match { + case Some(filter) => + new ParquetRecordReader[InternalRow]( + new CatalystReadSupport, + FilterCompat.get(filter, null)) + case _ => + new ParquetRecordReader[InternalRow](new CatalystReadSupport) + } + reader.initialize(split, hadoopAttemptContext) + reader + } + + val iter = new RecordReaderIterator(parquetReader) + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && + enableVectorizedParquetReader) { + iter.asInstanceOf[Iterator[InternalRow]] + } else { + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val joinedRow = new JoinedRow() + val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + + // This is a horrible erasure hack... if we type the iterator above, then it actually check + // the type in next() and we get a class cast exception. If we make that function return + // Object, then we can defer the cast until later! + iter.asInstanceOf[Iterator[InternalRow]] + .map(d => appendPartitionColumns(joinedRow(d, file.partitionValues))) } } } +} - private class MetadataCache { - // `FileStatus` objects of all "_metadata" files. - private var metadataStatuses: Array[FileStatus] = _ - - // `FileStatus` objects of all "_common_metadata" files. - private var commonMetadataStatuses: Array[FileStatus] = _ - - // `FileStatus` objects of all data files (Parquet part-files). - var dataStatuses: Array[FileStatus] = _ - - // Schema of the actual Parquet files, without partition columns discovered from partition - // directory paths. - var dataSchema: StructType = null - - // Schema of the whole table, including partition columns. - var schema: StructType = _ - - // Cached leaves - var cachedLeaves: Set[FileStatus] = null - - /** - * Refreshes `FileStatus`es, footers, partition spec, and table schema. - */ - def refresh(): Unit = { - val currentLeafStatuses = cachedLeafStatuses() - - // Check if cachedLeafStatuses is changed or not - val leafStatusesChanged = (cachedLeaves == null) || - !cachedLeaves.equals(currentLeafStatuses) - - if (leafStatusesChanged) { - cachedLeaves = currentLeafStatuses.toIterator.toSet - - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. - val leaves = currentLeafStatuses.filter { f => - isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - }.toArray - - dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) - metadataStatuses = - leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) - commonMetadataStatuses = - leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) - - dataSchema = { - val dataSchema0 = maybeDataSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(throw new AnalysisException( - s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + - paths.mkString("\n\t"))) - - // If this Parquet relation is converted from a Hive Metastore table, must reconcile case - // case insensitivity issue and possible schema mismatch (probably caused by schema - // evolution). - maybeMetastoreSchema - .map(ParquetRelation.mergeMetastoreParquetSchema(_, dataSchema0)) - .getOrElse(dataSchema0) +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[sql] class ParquetOutputWriter( + path: String, + bucketId: Option[Int], + context: TaskAttemptContext) + extends OutputWriter { + + private val recordWriter: RecordWriter[Void, InternalRow] = { + val outputFormat = { + new ParquetOutputFormat[InternalRow]() { + // Here we override `getDefaultWorkFile` for two reasons: + // + // 1. To allow appending. We need to generate unique output file names to avoid + // overwriting existing files (either exist before the write job, or are just written + // by other tasks within the same write job). + // + // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses + // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all + // partitions in the case of dynamic partitioning. + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + // It has the `.parquet` extension at the end because (de)compression tools + // such as gunzip would not be able to decompress this as the compression + // is not applied on this whole file but on each "page" in Parquet format. + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } } } - private def isSummaryFile(file: Path): Boolean = { - file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || - file.getName == ParquetFileWriter.PARQUET_METADATA_FILE - } + outputFormat.getRecordWriter(context) + } - private def readSchema(): Option[StructType] = { - // Sees which file(s) we need to touch in order to figure out the schema. - // - // Always tries the summary files first if users don't require a merged schema. In this case, - // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row - // groups information, and could be much smaller for large Parquet files with lots of row - // groups. If no summary file is available, falls back to some random part-file. - // - // NOTE: Metadata stored in the summary files are merged from all part-files. However, for - // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know - // how to merge them correctly if some key is associated with different values in different - // part-files. When this happens, Parquet simply gives up generating the summary file. This - // implies that if a summary file presents, then: - // - // 1. Either all part-files have exactly the same Spark SQL schema, or - // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus - // their schemas may differ from each other). - // - // Here we tend to be pessimistic and take the second case into account. Basically this means - // we can't trust the summary files if users require a merged schema, and must touch all part- - // files to do the merge. - val filesToTouch = - if (shouldMergeSchemas) { - // Also includes summary files, 'cause there might be empty partition directories. - - // If mergeRespectSummaries config is true, we assume that all part-files are the same for - // their schema with summary files, so we ignore them when merging schema. - // If the config is disabled, which is the default setting, we merge all part-files. - // In this mode, we only need to merge schemas contained in all those summary files. - // You should enable this configuration only if you are very sure that for the parquet - // part-files to read there are corresponding summary files containing correct schema. - - val needMerged: Seq[FileStatus] = - if (mergeRespectSummaries) { - Seq() - } else { - dataStatuses - } - (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq - } else { - // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet - // don't have this. - commonMetadataStatuses.headOption - // Falls back to "_metadata" - .orElse(metadataStatuses.headOption) - // Summary file(s) not found, the Parquet file is either corrupted, or different part- - // files contain conflicting user defined metadata (two or more values are associated - // with a same key in different files). In either case, we fall back to any of the - // first part-file, and just assume all schemas are consistent. - .orElse(dataStatuses.headOption) - .toSeq - } + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - assert( - filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined, - "No predefined schema found, " + - s"and no Parquet data files or summary files found under ${paths.mkString(", ")}.") + override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) - ParquetRelation.mergeSchemasInParallel(filesToTouch, sqlContext) - } - } + override def close(): Unit = recordWriter.close(context) } private[sql] object ParquetRelation extends Logging { @@ -504,6 +422,10 @@ private[sql] object ParquetRelation extends Logging { // internally. private[sql] val METASTORE_SCHEMA = "metastoreSchema" + // If a ParquetRelation is converted from a Hive metastore table, this option is set to the + // original Hive table name. + private[sql] val METASTORE_TABLE_NAME = "metastoreTableName" + /** * If parquet's block size (row group size) setting is larger than the min split size, * we use parquet's block size setting as the min split size. Otherwise, we will create @@ -537,7 +459,7 @@ private[sql] object ParquetRelation extends Logging { parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean)(job: Job): Unit = { - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val conf = job.getConfiguration conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. @@ -580,7 +502,7 @@ private[sql] object ParquetRelation extends Logging { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } - overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + overrideMinSplitSize(parquetBlockSize, job.getConfiguration) } private[parquet] def readSchema( @@ -615,7 +537,7 @@ private[sql] object ParquetRelation extends Logging { logInfo( s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(serializedSchema.get) + LegacyTypeStringParser.parse(serializedSchema.get) } .recover { case cause: Throwable => logWarning( @@ -650,7 +572,7 @@ private[sql] object ParquetRelation extends Logging { * distinguish binary and string). This method generates a correct schema by merging Metastore * schema data types and Parquet schema field names. */ - private[parquet] def mergeMetastoreParquetSchema( + private[sql] def mergeMetastoreParquetSchema( metastoreSchema: StructType, parquetSchema: StructType): StructType = { def schemaConflictMessage: String = @@ -766,12 +688,37 @@ private[sql] object ParquetRelation extends Logging { assumeInt96IsTimestamp = assumeInt96IsTimestamp, writeLegacyParquetFormat = writeLegacyParquetFormat) - footers.map { footer => - ParquetRelation.readSchemaFromFooter(footer, converter) - }.reduceOption(_ merge _).iterator + if (footers.isEmpty) { + Iterator.empty + } else { + var mergedSchema = ParquetRelation.readSchemaFromFooter(footers.head, converter) + footers.tail.foreach { footer => + val schema = ParquetRelation.readSchemaFromFooter(footer, converter) + try { + mergedSchema = mergedSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema of file ${footer.getFile}:\n${schema.treeString}", cause) + } + } + Iterator.single(mergedSchema) + } }.collect() - partiallyMergedSchemas.reduceOption(_ merge _) + if (partiallyMergedSchemas.isEmpty) { + None + } else { + var finalSchema = partiallyMergedSchemas.head + partiallyMergedSchemas.tail.foreach { schema => + try { + finalSchema = finalSchema.merge(schema) + } catch { case cause: SparkException => + throw new SparkException( + s"Failed merging schema:\n${schema.treeString}", cause) + } + } + Some(finalSchema) + } } /** @@ -798,7 +745,7 @@ private[sql] object ParquetRelation extends Logging { logInfo( s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(schemaString).asInstanceOf[StructType] + LegacyTypeStringParser.parse(schemaString).asInstanceOf[StructType] }.recoverWith { case cause: Throwable => logWarning( @@ -837,17 +784,9 @@ private[sql] object ParquetRelation extends Logging { // scalastyle:on classforname redirect(JLogger.getLogger("parquet")) } catch { case _: Throwable => - // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly jar + // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block // should be removed after this issue is fixed. } } - - // The parquet compression short names - val shortParquetCompressionCodecNames = Map( - "NONE" -> CompressionCodecName.UNCOMPRESSED, - "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, - "SNAPPY" -> CompressionCodecName.SNAPPY, - "GZIP" -> CompressionCodecName.GZIP, - "LZO" -> CompressionCodecName.LZO) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1a8e7ab202dc2..28ac4583e9b25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation} -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} /** * Try to replaces [[UnresolvedRelation]]s with [[ResolvedDataSource]]. @@ -32,14 +34,12 @@ private[sql] class ResolveDataSource(sqlContext: SQLContext) extends Rule[Logica def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if u.tableIdentifier.database.isDefined => try { - val resolved = ResolvedDataSource( + val dataSource = DataSource( sqlContext, - userSpecifiedSchema = None, - partitionColumns = Array(), - provider = u.tableIdentifier.database.get, - options = Map("path" -> u.tableIdentifier.table)) - val plan = LogicalRelation(resolved.relation) - u.alias.map(a => Subquery(u.alias.get, plan)).getOrElse(plan) + paths = u.tableIdentifier.table :: Nil, + className = u.tableIdentifier.database.get) + val plan = LogicalRelation(dataSource.resolveRelation()) + u.alias.map(a => SubqueryAlias(u.alias.get, plan)).getOrElse(plan) } catch { case e: ClassNotFoundException => u case e: Exception => @@ -61,7 +61,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { // We are inserting into an InsertableRelation or HadoopFsRelation. case i @ InsertIntoTable( - l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _), _, child, _, _) => { + l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) => // First, make sure the data to be inserted have the same number of fields with the // schema of the relation. if (l.output.size != child.output.size) { @@ -70,7 +70,6 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { s"statement generates the same number of columns as its schema.") } castAndRenameChildOutput(i, l.output, child) - } } /** If necessary, cast data types and rename fields to the expected types and names. */ @@ -102,20 +101,23 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { /** * A rule to do various checks before inserting into or writing to a data source table. */ -private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) { +private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) + extends (LogicalPlan => Unit) { + def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) } def apply(plan: LogicalPlan): Unit = { plan.foreach { case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: InsertableRelation, _), partition, query, overwrite, ifNotExists) => + l @ LogicalRelation(t: InsertableRelation, _, _), + partition, query, overwrite, ifNotExists) => // Right now, we do not support insert into a data source table with partition specs. if (partition.nonEmpty) { failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") } else { // Get all input data source relations of the query. val srcRelations = query.collect { - case LogicalRelation(src: BaseRelation, _) => src + case LogicalRelation(src: BaseRelation, _, _) => src } if (srcRelations.contains(t)) { failAnalysis( @@ -126,10 +128,10 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } case logical.InsertIntoTable( - LogicalRelation(r: HadoopFsRelation, _), part, query, overwrite, _) => + LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _) => // We need to make sure the partition columns specified by users do match partition // columns of the relation. - val existingPartitionColumns = r.partitionColumns.fieldNames.toSet + val existingPartitionColumns = r.partitionSchema.fieldNames.toSet val specifiedPartitionColumns = part.keySet if (existingPartitionColumns != specifiedPartitionColumns) { failAnalysis(s"Specified partition columns " + @@ -141,11 +143,11 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } PartitioningUtils.validatePartitionColumnDataTypes( - r.schema, part.keySet.toArray, catalog.conf.caseSensitiveAnalysis) + r.schema, part.keySet.toSeq, conf.caseSensitiveAnalysis) // Get all input data source relations of the query. val srcRelations = query.collect { - case LogicalRelation(src: BaseRelation, _) => src + case LogicalRelation(src: BaseRelation, _, _) => src } if (srcRelations.contains(r)) { failAnalysis( @@ -165,22 +167,22 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => + case c: CreateTableUsingAsSelect => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent)) { + if (c.mode == SaveMode.Overwrite && catalog.tableExists(c.tableIdent)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(tableIdent)) match { + EliminateSubqueryAliases(catalog.lookupRelation(c.tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). - case l @ LogicalRelation(dest: BaseRelation, _) => + case l @ LogicalRelation(dest: BaseRelation, _, _) => // Get all input data source relations of the query. - val srcRelations = query.collect { - case LogicalRelation(src: BaseRelation, _) => src + val srcRelations = c.child.collect { + case LogicalRelation(src: BaseRelation, _, _) => src } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableIdent that is also being read from.") + s"Cannot overwrite table ${c.tableIdent} that is also being read from.") } else { // OK } @@ -192,7 +194,17 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } PartitioningUtils.validatePartitionColumnDataTypes( - query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis) + c.child.schema, c.partitionColumns, conf.caseSensitiveAnalysis) + + for { + spec <- c.bucketSpec + sortColumnName <- spec.sortColumnNames + sortColumn <- c.child.schema.find(_.name == sortColumnName) + } { + if (!RowOrdering.isOrderable(sortColumn.dataType)) { + failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.") + } + } case _ => // OK } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 52c4421d7e87e..94ecb7a28663c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -17,43 +17,25 @@ package org.apache.spark.sql.execution.datasources.text -import com.google.common.base.Objects -import org.apache.hadoop.fs.{Path, FileStatus} -import org.apache.hadoop.io.{NullWritable, Text, LongWritable} -import org.apache.hadoop.mapred.{TextInputFormat, JobConf} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, GenericMutableRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} -import org.apache.spark.sql.columnar.MutableUnsafeRow import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration /** * A data source for reading text files. */ -class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - - override def createRelation( - sqlContext: SQLContext, - paths: Array[String], - dataSchema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - dataSchema.foreach(verifySchema) - new TextRelation(None, partitionColumns, paths)(sqlContext) - } +class DefaultSource extends FileFormat with DataSourceRegister { override def shortName(): String = "text" @@ -68,90 +50,79 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { s"Text data source supports only a string column, but you have ${tpe.simpleString}.") } } -} -private[sql] class TextRelation( - val maybePartitionSpec: Option[PartitionSpec], - override val userDefinedPartitionColumns: Option[StructType], - override val paths: Array[String] = Array.empty[String]) - (@transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) { - - /** Data schema is always a single column, named "text". */ - override def dataSchema: StructType = new StructType().add("text", StringType) - - /** This is an internal data source that outputs internal row format. */ - override val needConversion: Boolean = false - - - override private[sql] def buildInternalScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputPaths: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) - val paths = inputPaths.map(_.getPath).sortBy(_.toUri) + override def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = Some(new StructType().add("value", StringType)) - if (paths.nonEmpty) { - FileInputFormat.setInputPaths(job, paths: _*) + override def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + verifySchema(dataSchema) + + val conf = job.getConfiguration + val compressionCodec = options.get("compression").map(CompressionCodecs.getCodecClassName) + compressionCodec.foreach { codec => + CompressionCodecs.setCodecConfiguration(conf, codec) } - sqlContext.sparkContext.hadoopRDD( - conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) - .mapPartitions { iter => - val bufferHolder = new BufferHolder - val unsafeRowWriter = new UnsafeRowWriter - val unsafeRow = new UnsafeRow - - iter.map { case (_, line) => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.initialize(bufferHolder, 1) - unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize()) - unsafeRow - } - } - } - - /** Write path. */ - override def prepareJobForWrite(job: Job): OutputWriterFactory = { new OutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { + if (bucketId.isDefined) { + throw new AnalysisException("Text doesn't support bucketing") + } new TextOutputWriter(path, dataSchema, context) } } } - override def equals(other: Any): Boolean = other match { - case that: TextRelation => - paths.toSet == that.paths.toSet && partitionColumns == that.partitionColumns - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode(paths.toSet, partitionColumns) + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + val broadcastedConf = + sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + + file => { + val unsafeRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + + new HadoopFileLinesReader(file, broadcastedConf.value.value).map { line => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) + unsafeRow.setTotalSize(bufferHolder.totalSize()) + unsafeRow + } + } } } class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter - with SparkHadoopMapRedUtil { + extends OutputWriter { private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.txt$extension") } }.getRecordWriter(context) } @@ -168,3 +139,4 @@ class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemp recordWriter.close(context) } } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 74892e4e13fa4..17eae88b49dec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.execution import scala.collection.mutable.HashSet +import org.apache.spark.{Accumulator, AccumulatorParam} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.{Accumulator, AccumulatorParam, Logging} +import org.apache.spark.sql.internal.SQLConf /** * Contains methods for debugging query execution. @@ -32,12 +35,38 @@ import org.apache.spark.{Accumulator, AccumulatorParam, Logging} * Usage: * {{{ * import org.apache.spark.sql.execution.debug._ - * sql("SELECT key FROM src").debug() - * dataFrame.typeCheck() + * sql("SELECT 1").debug() + * sql("SELECT 1").debugCodegen() * }}} */ package object debug { + /** Helper function to evade the println() linter. */ + private def debugPrint(msg: String): Unit = { + // scalastyle:off println + println(msg) + // scalastyle:on println + } + + def codegenString(plan: SparkPlan): String = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]() + plan transform { + case s: WholeStageCodegen => + codegenSubtrees += s + s + case s => s + } + var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" + for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" + output += s + output += "\nGenerated code:\n" + val (_, source) = s.doCodeGen() + output += s"${CodeFormatter.format(source)}\n" + } + output + } + /** * Augments [[SQLContext]] with debug methods. */ @@ -48,9 +77,9 @@ package object debug { } /** - * Augments [[DataFrame]]s with debug methods. + * Augments [[Dataset]]s with debug methods. */ - implicit class DebugQuery(query: DataFrame) extends Logging { + implicit class DebugQuery(query: Dataset[_]) extends Logging { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() @@ -59,15 +88,23 @@ package object debug { visited += new TreeNodeRef(s) DebugNode(s) } - logDebug(s"Results returned: ${debugPlan.execute().count()}") + debugPrint(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { case d: DebugNode => d.dumpStats() case _ => } } + + /** + * Prints to stdout all the generated code found in this plan (i.e. the output of each + * WholeStageCodegen subtree). + */ + def debugCodegen(): Unit = { + debugPrint(codegenString(query.queryExecution.executedPlan)) + } } - private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport { def output: Seq[Attribute] = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { @@ -84,22 +121,24 @@ package object debug { /** * A collection of metrics for each column of output. + * * @param elementTypes the actual runtime types for the output. Useful when there are bugs - * causing the wrong data to be projected. + * causing the wrong data to be projected. */ case class ColumnMetrics( - elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) + elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) + val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0) val numColumns: Int = child.output.size val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { - logDebug(s"== ${child.simpleString} ==") - logDebug(s"Tuples output: ${tupleCount.value}") - child.output.zip(columnStats).foreach { case(attr, metric) => + debugPrint(s"== ${child.simpleString} ==") + debugPrint(s"Tuples output: ${tupleCount.value}") + child.output.zip(columnStats).foreach { case (attr, metric) => val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") - logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes") + debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } @@ -107,6 +146,7 @@ package object debug { child.execute().mapPartitions { iter => new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext + def next(): InternalRow = { val currentRow = iter.next() tupleCount += 1 @@ -123,5 +163,17 @@ package object debug { } } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + consume(ctx, input) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala new file mode 100644 index 0000000000000..102a9356df311 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ + +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.util.ThreadUtils + +/** + * A [[BroadcastExchange]] collects, transforms and finally broadcasts the result of a transformed + * SparkPlan. + */ +case class BroadcastExchange( + mode: BroadcastMode, + child: SparkPlan) extends Exchange { + + override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + + override def sameResult(plan: SparkPlan): Boolean = plan match { + case p: BroadcastExchange => + mode.compatibleWith(p.mode) && child.sameResult(p.child) + case _ => false + } + + @transient + private val timeout: Duration = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + @transient + private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + Future { + // This will run in another thread. Set the execution id so that we can connect these jobs + // with the correct execution. + SQLExecution.withExecutionId(sparkContext, executionId) { + // Note that we use .executeCollect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = child.executeCollect() + + // Construct and broadcast the relation. + sparkContext.broadcast(mode.transform(input)) + } + }(BroadcastExchange.executionContext) + } + + override protected def doPrepare(): Unit = { + // Materialize the future. + relationFuture + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "BroadcastExchange does not support the execute() code path.") + } + + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + val result = Await.result(relationFuture, timeout) + result.asInstanceOf[broadcast.Broadcast[T]] + } +} + +object BroadcastExchange { + private[execution] val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala new file mode 100644 index 0000000000000..4864db7f2ac9b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] + * of input data meets the + * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for + * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the + * input partition ordering requirements are met. + */ +case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { + private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions + + private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize + + private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled + + private def minNumPostShufflePartitions: Option[Int] = { + val minNumPostShufflePartitions = conf.minNumPostShufflePartitions + if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None + } + + /** + * Given a required distribution, returns a partitioning that satisfies that distribution. + */ + private def createPartitioning( + requiredDistribution: Distribution, + numPartitions: Int): Partitioning = { + requiredDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) + case dist => sys.error(s"Do not know how to satisfy distribution $dist") + } + } + + /** + * Adds [[ExchangeCoordinator]] to [[ShuffleExchange]]s if adaptive query execution is enabled + * and partitioning schemes of these [[ShuffleExchange]]s support [[ExchangeCoordinator]]. + */ + private def withExchangeCoordinator( + children: Seq[SparkPlan], + requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { + val supportsCoordinator = + if (children.exists(_.isInstanceOf[ShuffleExchange])) { + // Right now, ExchangeCoordinator only support HashPartitionings. + children.forall { + case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true + case child => + child.outputPartitioning match { + case hash: HashPartitioning => true + case collection: PartitioningCollection => + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) + case _ => false + } + } + } else { + // In this case, although we do not have Exchange operators, we may still need to + // shuffle data when we have more than one children because data generated by + // these children may not be partitioned in the same way. + // Please see the comment in withCoordinator for more details. + val supportsDistribution = + requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + children.length > 1 && supportsDistribution + } + + val withCoordinator = + if (adaptiveExecutionEnabled && supportsCoordinator) { + val coordinator = + new ExchangeCoordinator( + children.length, + targetPostShuffleInputSize, + minNumPostShufflePartitions) + children.zip(requiredChildDistributions).map { + case (e: ShuffleExchange, _) => + // This child is an Exchange, we need to add the coordinator. + e.copy(coordinator = Some(coordinator)) + case (child, distribution) => + // If this child is not an Exchange, we need to add an Exchange for now. + // Ideally, we can try to avoid this Exchange. However, when we reach here, + // there are at least two children operators (because if there is a single child + // and we can avoid Exchange, supportsCoordinator will be false and we + // will not reach here.). Although we can make two children have the same number of + // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. + // For example, let's say we have the following plan + // Join + // / \ + // Agg Exchange + // / \ + // Exchange t2 + // / + // t1 + // In this case, because a post-shuffle partition can include multiple pre-shuffle + // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes + // after shuffle. So, even we can use the child Exchange operator of the Join to + // have a number of post-shuffle partitions that matches the number of partitions of + // Agg, we cannot say these two children are partitioned in the same way. + // Here is another case + // Join + // / \ + // Agg1 Agg2 + // / \ + // Exchange1 Exchange2 + // / \ + // t1 t2 + // In this case, two Aggs shuffle data with the same column of the join condition. + // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same + // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 + // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle + // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its + // pre-shuffle partitions by using another partitionStartIndices [0, 4]. + // So, Agg1 and Agg2 are actually not co-partitioned. + // + // It will be great to introduce a new Partitioning to represent the post-shuffle + // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. + val targetPartitioning = + createPartitioning(distribution, defaultNumPreShufflePartitions) + assert(targetPartitioning.isInstanceOf[HashPartitioning]) + ShuffleExchange(targetPartitioning, child, Some(coordinator)) + } + } else { + // If we do not need ExchangeCoordinator, the original children are returned. + children + } + + withCoordinator + } + + private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { + val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution + val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering + var children: Seq[SparkPlan] = operator.children + assert(requiredChildDistributions.length == children.length) + assert(requiredChildOrderings.length == children.length) + + // Ensure that the operator's children satisfy their output distribution requirements: + children = children.zip(requiredChildDistributions).map { + case (child, distribution) if child.outputPartitioning.satisfies(distribution) => + child + case (child, BroadcastDistribution(mode)) => + BroadcastExchange(mode, child) + case (child, distribution) => + ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + } + + // If the operator has multiple children and specifies child output distributions (e.g. join), + // then the children's output partitionings must be compatible: + def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match { + case UnspecifiedDistribution => false + case BroadcastDistribution(_) => false + case _ => true + } + if (children.length > 1 + && requiredChildDistributions.exists(requireCompatiblePartitioning) + && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + + // First check if the existing partitions of the children all match. This means they are + // partitioned by the same partitioning into the same number of partitions. In that case, + // don't try to make them match `defaultPartitions`, just use the existing partitioning. + val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max + val useExistingPartitioning = children.zip(requiredChildDistributions).forall { + case (child, distribution) => + child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + + children = if (useExistingPartitioning) { + // We do not need to shuffle any child's output. + children + } else { + // We need to shuffle at least one child's output. + // Now, we will determine the number of partitions that will be used by created + // partitioning schemes. + val numPartitions = { + // Let's see if we need to shuffle all child's outputs when we use + // maxChildrenNumPartitions. + val shufflesAllChildren = children.zip(requiredChildDistributions).forall { + case (child, distribution) => + !child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the + // number of partitions. Otherwise, we use maxChildrenNumPartitions. + if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions + } + + children.zip(requiredChildDistributions).map { + case (child, distribution) => + val targetPartitioning = createPartitioning(distribution, numPartitions) + if (child.outputPartitioning.guarantees(targetPartitioning)) { + child + } else { + child match { + // If child is an exchange, we replace it with + // a new one having targetPartitioning. + case ShuffleExchange(_, c, _) => ShuffleExchange(targetPartitioning, c) + case _ => ShuffleExchange(targetPartitioning, child) + } + } + } + } + } + + // Now, we need to add ExchangeCoordinator if necessary. + // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. + // However, with the way that we plan the query, we do not have a place where we have a + // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator + // at here for now. + // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, + // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. + children = withExchangeCoordinator(children, requiredChildDistributions) + + // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: + children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => + if (requiredOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. + if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { + Sort(requiredOrdering, global = false, child = child) + } else { + child + } + } else { + child + } + } + + operator.withNewChildren(children) + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator @ ShuffleExchange(partitioning, child, _) => + child.children match { + case ShuffleExchange(childPartitioning, baseChild, _)::Nil => + if (childPartitioning.guarantees(partitioning)) child else operator + case _ => operator + } + case operator: SparkPlan => ensureDistributionAndOrdering(operator) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala new file mode 100644 index 0000000000000..df7ad48812051 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * Base class for operators that exchange data among multiple threads or processes. + * + * Exchanges are the key class of operators that enable parallelism. Although the implementation + * differs significantly, the concept is similar to the exchange operator described in + * "Volcano -- An Extensible and Parallel Query Evaluation System" by Goetz Graefe. + */ +abstract class Exchange extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + +/** + * A wrapper for reused exchange to have different output, because two exchanges which produce + * logically identical output will have distinct sets of output attribute ids, so we need to + * preserve the original ids because they're what downstream operators are expecting. + */ +case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) extends LeafNode { + + override def sameResult(plan: SparkPlan): Boolean = { + // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here. + plan.sameResult(child) + } + + def doExecute(): RDD[InternalRow] = { + child.execute() + } + + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + child.executeBroadcast() + } + + // Do not repeat the same tree in explain. + override def treeChildren: Seq[SparkPlan] = Nil +} + +/** + * Find out duplicated exchanges in the spark plan, then use the same exchange for all the + * references. + */ +case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] { + + def apply(plan: SparkPlan): SparkPlan = { + if (!conf.exchangeReuseEnabled) { + return plan + } + // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. + val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]() + plan.transformUp { + case exchange: Exchange => + // the exchanges that have same results usually also have same schemas (same column names). + val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]()) + val samePlan = sameSchema.find { e => + exchange.sameResult(e) + } + if (samePlan.isDefined) { + // Keep the output of this exchange, the following plans require that to resolve + // attributes. + ReusedExchange(exchange.output, samePlan.get) + } else { + sameSchema += exchange + exchange + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala similarity index 76% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 8dbd69e1f44b8..fb60d68f986d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -15,15 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.sql.execution.exchange -import java.util.{Map => JMap, HashMap => JHashMap} +import java.util.{HashMap => JHashMap, Map => JMap} +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, SimpleFutureAction, ShuffleDependency, MapOutputStatistics} +import org.apache.spark.{MapOutputStatistics, ShuffleDependency, SimpleFutureAction} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} /** * A coordinator used to determines how we shuffle data between stages generated by Spark SQL. @@ -32,9 +35,9 @@ import org.apache.spark.sql.catalyst.InternalRow * * A coordinator is constructed with three parameters, `numExchanges`, * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. - * - `numExchanges` is used to indicated that how many [[Exchange]]s that will be registered to - * this coordinator. So, when we start to do any actual work, we have a way to make sure that - * we have got expected number of [[Exchange]]s. + * - `numExchanges` is used to indicated that how many [[ShuffleExchange]]s that will be registered + * to this coordinator. So, when we start to do any actual work, we have a way to make sure that + * we have got expected number of [[ShuffleExchange]]s. * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's * input data size. With this parameter, we can estimate the number of post-shuffle partitions. * This parameter is configured through @@ -44,26 +47,27 @@ import org.apache.spark.sql.catalyst.InternalRow * partitions. * * The workflow of this coordinator is described as follows: - * - Before the execution of a [[SparkPlan]], for an [[Exchange]] operator, + * - Before the execution of a [[SparkPlan]], for an [[ShuffleExchange]] operator, * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. * This happens in the `doPrepare` method. - * - Once we start to execute a physical plan, an [[Exchange]] registered to this coordinator will - * call `postShuffleRDD` to get its corresponding post-shuffle [[ShuffledRowRDD]]. - * If this coordinator has made the decision on how to shuffle data, this [[Exchange]] will - * immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. + * - Once we start to execute a physical plan, an [[ShuffleExchange]] registered to this + * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle + * [[ShuffledRowRDD]]. + * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]] + * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. * - If this coordinator has not made the decision on how to shuffle data, it will ask those - * registered [[Exchange]]s to submit their pre-shuffle stages. Then, based on the the size - * statistics of pre-shuffle partitions, this coordinator will determine the number of + * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the + * size statistics of pre-shuffle partitions, this coordinator will determine the number of * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices * to a single post-shuffle partition whenever necessary. * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered - * [[Exchange]]s. So, when an [[Exchange]] calls `postShuffleRDD`, this coordinator can - * lookup the corresponding [[RDD]]. + * [[ShuffleExchange]]s. So, when an [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator + * can lookup the corresponding [[RDD]]. * * The strategy used to determine the number of post-shuffle partitions is described as follows. * To determine the number of post-shuffle partitions, we have a target input size for a * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages - * corresponding to the registered [[Exchange]]s, we will do a pass of those statistics and + * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and * pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until * the size of a post-shuffle partition is equal or greater than the target size. * For example, we have two stages with the following pre-shuffle partition size statistics: @@ -82,11 +86,11 @@ private[sql] class ExchangeCoordinator( extends Logging { // The registered Exchange operators. - private[this] val exchanges = ArrayBuffer[Exchange]() + private[this] val exchanges = ArrayBuffer[ShuffleExchange]() // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[Exchange, ShuffledRowRDD] = - new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + private[this] val postShuffleRDDs: JMap[ShuffleExchange, ShuffledRowRDD] = + new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. // This variable will only be updated by doEstimationIfNecessary, which is protected by @@ -94,10 +98,11 @@ private[sql] class ExchangeCoordinator( @volatile private[this] var estimated: Boolean = false /** - * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be - * called in the `doPrepare` method of an [[Exchange]] operator. + * Registers an [[ShuffleExchange]] operator to this coordinator. This method is only allowed to + * be called in the `doPrepare` method of an [[ShuffleExchange]] operator. */ - def registerExchange(exchange: Exchange): Unit = synchronized { + @GuardedBy("this") + def registerExchange(exchange: ShuffleExchange): Unit = synchronized { exchanges += exchange } @@ -109,7 +114,7 @@ private[sql] class ExchangeCoordinator( */ private[sql] def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { - // If we have mapOutputStatistics.length <= numExchange, it is because we do not submit + // If we have mapOutputStatistics.length < numExchange, it is because we do not submit // a stage when the number of partitions of this dependency is 0. assert(mapOutputStatistics.length <= numExchanges) @@ -121,6 +126,8 @@ private[sql] class ExchangeCoordinator( val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum // The max at here is to make sure that when we have an empty table, we // only have a single post-shuffle partition. + // There is no particular reason that we pick 16. We just need a number to + // prevent maxPostShuffleInputSize from being set to 0. val maxPostShuffleInputSize = math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) @@ -135,6 +142,12 @@ private[sql] class ExchangeCoordinator( // Make sure we do get the same number of pre-shuffle partitions for those stages. val distinctNumPreShufflePartitions = mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of pre-shuffle partitions + // is that when we add Exchanges, we set the number of pre-shuffle partitions + // (i.e. map output partitions) using a static setting, which is the value of + // spark.sql.shuffle.partitions. Even if two input RDDs are having different + // number of partitions, they will have the same number of pre-shuffle partitions + // (i.e. map output partitions). assert( distinctNumPreShufflePartitions.length == 1, "There should be only one distinct value of the number pre-shuffle partitions " + @@ -177,6 +190,7 @@ private[sql] class ExchangeCoordinator( partitionStartIndices.toArray } + @GuardedBy("this") private def doEstimationIfNecessary(): Unit = synchronized { // It is unlikely that this method will be called from multiple threads // (when multiple threads trigger the execution of THIS physical) @@ -188,7 +202,7 @@ private[sql] class ExchangeCoordinator( // Make sure we have the expected number of registered Exchange operators. assert(exchanges.length == numExchanges) - val newPostShuffleRDDs = new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + val newPostShuffleRDDs = new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) // Submit all map stages val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() @@ -209,11 +223,11 @@ private[sql] class ExchangeCoordinator( // Wait for the finishes of those submitted map stages. val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length) - i = 0 - while (i < submittedStageFutures.length) { + var j = 0 + while (j < submittedStageFutures.length) { // This call is a blocking call. If the stage has not finished, we will wait at here. - mapOutputStatistics(i) = submittedStageFutures(i).get() - i += 1 + mapOutputStatistics(j) = submittedStageFutures(j).get() + j += 1 } // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the @@ -225,14 +239,14 @@ private[sql] class ExchangeCoordinator( Some(estimatePartitionStartIndices(mapOutputStatistics)) } - i = 0 - while (i < numExchanges) { - val exchange = exchanges(i) + var k = 0 + while (k < numExchanges) { + val exchange = exchanges(k) val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(i), partitionStartIndices) + exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) newPostShuffleRDDs.put(exchange, rdd) - i += 1 + k += 1 } // Finally, we set postShuffleRDDs and estimated. @@ -243,7 +257,7 @@ private[sql] class ExchangeCoordinator( } } - def postShuffleRDD(exchange: Exchange): ShuffledRowRDD = { + def postShuffleRDD(exchange: ShuffleExchange): ShuffledRowRDD = { doEstimationIfNecessary() if (!postShuffleRDDs.containsKey(exchange)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala new file mode 100644 index 0000000000000..7e35db7dd8a79 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import java.util.Random + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.util.MutablePair + +/** + * Performs a shuffle that will result in the desired `newPartitioning`. + */ +case class ShuffleExchange( + var newPartitioning: Partitioning, + child: SparkPlan, + @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { + + override def nodeName: String = { + val extraInfo = coordinator match { + case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case None => "" + } + + val simpleNodeName = "Exchange" + s"$simpleNodeName$extraInfo" + } + + override def outputPartitioning: Partitioning = newPartitioning + + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + + override protected def doPrepare(): Unit = { + // If an ExchangeCoordinator is needed, we register this Exchange operator + // to the coordinator when we do prepare. It is important to make sure + // we register this operator right before the execution instead of register it + // in the constructor because it is possible that we create new instances of + // Exchange operators when we transform the physical plan + // (then the ExchangeCoordinator will hold references of unneeded Exchanges). + // So, we should only call registerExchange just before we start to execute + // the plan. + coordinator match { + case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) + case None => + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { + ShuffleExchange.prepareShuffleDependency( + child.execute(), child.output, newPartitioning, serializer) + } + + /** + * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. + * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional + * partition start indices array. If this optional array is defined, the returned + * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. + */ + private[sql] def preparePostShuffleRDD( + shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { + // If an array of partition start indices is provided, we need to use this array + // to create the ShuffledRowRDD. Also, we need to update newPartitioning to + // update the number of post-shuffle partitions. + specifiedPartitionStartIndices.foreach { indices => + assert(newPartitioning.isInstanceOf[HashPartitioning]) + newPartitioning = UnknownPartitioning(indices.length) + } + new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + } + + /** + * Caches the created ShuffleRowRDD so we can reuse that. + */ + private var cachedShuffleRDD: ShuffledRowRDD = null + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + // Returns the same ShuffleRowRDD if this plan is used by multiple plans. + if (cachedShuffleRDD == null) { + cachedShuffleRDD = coordinator match { + case Some(exchangeCoordinator) => + val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) + assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) + shuffleRDD + case None => + val shuffleDependency = prepareShuffleDependency() + preparePostShuffleRDD(shuffleDependency) + } + } + cachedShuffleRDD + } +} + +object ShuffleExchange { + def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = { + ShuffleExchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) + } + + /** + * Determines whether records must be defensively copied before being sent to the shuffle. + * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The + * shuffle code assumes that objects are immutable and hence does not perform its own defensive + * copying. In Spark SQL, however, operators' iterators return the same mutable `Row` object. In + * order to properly shuffle the output of these operators, we need to perform our own copying + * prior to sending records to the shuffle. This copying is expensive, so we try to avoid it + * whenever possible. This method encapsulates the logic for choosing when to copy. + * + * In the long run, we might want to push this logic into core's shuffle APIs so that we don't + * have to rely on knowledge of core internals here in SQL. + * + * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. + * + * @param partitioner the partitioner for the shuffle + * @param serializer the serializer that will be used to write rows + * @return true if rows should be copied before being shuffled, false otherwise + */ + private def needToCopyObjectsBeforeShuffle( + partitioner: Partitioner, + serializer: Serializer): Boolean = { + // Note: even though we only use the partitioner's `numPartitions` field, we require it to be + // passed instead of directly passing the number of partitions in order to guard against + // corner-cases where a partitioner constructed with `numPartitions` partitions may output + // fewer partitions (like RangePartitioner, for example). + val conf = SparkEnv.get.conf + val shuffleManager = SparkEnv.get.shuffleManager + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] + val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + if (sortBasedShuffleOn) { + val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + // If we're using the original SortShuffleManager and the number of output partitions is + // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which + // doesn't buffer deserialized records. + // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. + false + } else if (serializer.supportsRelocationOfSerializedObjects) { + // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records + // prior to sorting them. This optimization is only applied in cases where shuffle + // dependency does not specify an aggregator or ordering and the record serializer has + // certain properties. If this optimization is enabled, we can safely avoid the copy. + // + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only + // need to check whether the optimization is enabled and supported by our serializer. + false + } else { + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must + // copy. + true + } + } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { + // We're using hash-based shuffle, so we don't need to copy. + false + } else { + // Catch-all case to safely handle any future ShuffleManager implementations. + true + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + private[sql] def prepareShuffleDependency( + rdd: RDD[InternalRow], + outputAttributes: Seq[Attribute], + newPartitioning: Partitioning, + serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { + val part: Partitioner = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) + case HashPartitioning(_, n) => + new Partitioner { + override def numPartitions: Int = n + // For HashPartitioning, the partitioning key is already a valid partition ID, as we use + // `HashPartitioning.partitionIdExpression` to produce partitioning key. + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } + case RangePartitioning(sortingExpressions, numPartitions) => + // Internally, RangePartitioner runs a job on the RDD that samples keys to compute + // partition bounds. To get accurate samples, we need to copy the mutable keys. + val rddForSampling = rdd.mapPartitionsInternal { iter => + val mutablePair = new MutablePair[InternalRow, Null]() + iter.map(row => mutablePair.update(row.copy(), null)) + } + implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes) + new RangePartitioner(numPartitions, rddForSampling, ascending = true) + case SinglePartition => + new Partitioner { + override def numPartitions: Int = 1 + override def getPartition(key: Any): Int = 0 + } + case _ => sys.error(s"Exchange not implemented for $newPartitioning") + // TODO: Handle BroadcastPartitioning. + } + def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => + // Distributes elements evenly across output partitions, starting from a random partition. + var position = new Random(TaskContext.get().partitionId()).nextInt(numPartitions) + (row: InternalRow) => { + // The HashPartitioner will handle the `mod` by the number of partitions + position += 1 + position + } + case h: HashPartitioning => + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) + row => projection(row).getInt(0) + case RangePartitioning(_, _) | SinglePartition => identity + case _ => sys.error(s"Exchange not implemented for $newPartitioning") + } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { + if (needToCopyObjectsBeforeShuffle(part, serializer)) { + rdd.mapPartitionsInternal { iter => + val getPartitionKey = getPartitionKeyExtractor() + iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } + } + } else { + rdd.mapPartitionsInternal { iter => + val getPartitionKey = getPartitionKeyExtractor() + val mutablePair = new MutablePair[Int, InternalRow]() + iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + } + } + } + + // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds + // are in the form of (partitionId, row) and every partitionId is in the expected range + // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. + val dependency = + new ShuffleDependency[Int, InternalRow, InternalRow]( + rddWithPartitionIds, + new PartitionIdPassthrough(part.numPartitions), + serializer) + + dependency + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 1d381e2eaef38..a8f854136c1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,110 +17,383 @@ package org.apache.spark.sql.execution.joins -import scala.concurrent._ -import scala.concurrent.duration._ - +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.ThreadUtils -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.sql.types.LongType /** * Performs an inner hash join of two child relations. When the output RDD of this operator is * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed * relation is not shuffled. */ case class BroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, buildSide: BuildSide, + condition: Option[Expression], left: SparkPlan, right: SparkPlan) - extends BinaryNode with HashJoin { + extends BinaryNode with HashJoin with CodegenSupport { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - val timeout: Duration = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + val mode = HashedRelationBroadcastMode(buildKeys) + buildSide match { + case BuildLeft => + BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil } } - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() + streamedPlan.execute().mapPartitions { streamedIter => + val hashed = broadcastRelation.value.asReadOnlyCopy() + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) + join(streamedIter, hashed, numOutputRows) + } + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + streamedPlan.asInstanceOf[CodegenSupport].upstreams() + } + + override def doProduce(ctx: CodegenContext): String = { + streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + joinType match { + case Inner => codegenInner(ctx, input) + case LeftOuter | RightOuter => codegenOuter(ctx, input) + case LeftSemi => codegenSemi(ctx, input) + case LeftAnti => codegenAnti(ctx, input) + case x => + throw new IllegalArgumentException( + s"BroadcastHashJoin should not take $x as the JoinType") + } + } - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + /** + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. + */ + private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { + // create a name for HashedRelation + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() + val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) + val relationTerm = ctx.freshName("relation") + val clsName = broadcastRelation.value.getClass.getName + ctx.addMutableState(clsName, relationTerm, + s""" + | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); + | incPeakExecutionMemory($relationTerm.estimatedSize()); + """.stripMargin) + (broadcastRelation, relationTerm) + } - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - val numBuildRows = buildSide match { - case BuildLeft => longMetric("numLeftRows") - case BuildRight => longMetric("numRightRows") + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + private def genStreamSideJoinKey( + ctx: CodegenContext, + input: Seq[ExprCode]): (ExprCode, String) = { + ctx.currentVars = input + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + // generate the join key as Long + val ev = streamedKeys.head.gen(ctx) + (ev, ev.isNull) + } else { + // generate the join key as UnsafeRow + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + (ev, s"${ev.value}.anyNull()") } + } - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map { row => - numBuildRows += 1 - row.copy() - }.collect() - // The following line doesn't run in a job so we cannot track the metric value. However, we - // have already tracked it in the above lines. So here we can use - // `SQLMetrics.nullLongMetric` to ignore it. - val hashed = HashedRelation( - input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) - sparkContext.broadcast(hashed) + /** + * Generates the code for variable of build side. + */ + private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { + ctx.currentVars = null + ctx.INPUT_ROW = matched + buildPlan.output.zipWithIndex.map { case (a, i) => + val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) + if (joinType == Inner) { + ev + } else { + // the variables are needed even there is no matched rows + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, isNull, value) } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) + } } - protected override def doPrepare(): Unit = { - broadcastFuture + /** + * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi + * and Left Anti joins. + */ + private def getJoinCondition( + ctx: CodegenContext, + input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + // filter the output via condition + ctx.currentVars = input ++ buildVars + val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx) + s""" + |$eval + |${ev.code} + |if (${ev.isNull} || !${ev.value}) continue; + """.stripMargin + } else { + "" + } + (matched, checkCondition, buildVars) } - protected override def doExecute(): RDD[InternalRow] = { - val numStreamedRows = buildSide match { - case BuildLeft => longMetric("numRightRows") - case BuildRight => longMetric("numLeftRows") + /** + * Generates the code for Inner join. + */ + private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars } - val numOutputRows = longMetric("numOutputRows") + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |if ($matched == null) continue; + |$checkCondition + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin + + } else { + ctx.copyResult = true + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |if ($matches == null) continue; + |while ($matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + |} + """.stripMargin + } + } - val broadcastRelation = Await.result(broadcastFuture, timeout) + /** + * Generates the code for left or right outer join. + */ + private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val numOutput = metricTerm(ctx, "numOutputRows") - streamedPlan.execute().mapPartitions { streamedIter => - val hashedRelation = broadcastRelation.value - hashedRelation match { - case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) - case _ => - } - hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) + // filter the output via condition + val conditionPassed = ctx.freshName("conditionPassed") + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + ctx.currentVars = input ++ buildVars + val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx) + s""" + |boolean $conditionPassed = true; + |${eval.trim} + |${ev.code} + |if ($matched != null) { + | $conditionPassed = !${ev.isNull} && ${ev.value}; + |} + """.stripMargin + } else { + s"final boolean $conditionPassed = true;" + } + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |${checkCondition.trim} + |if (!$conditionPassed) { + | $matched = null; + | // reset the variables those are already evaluated. + | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")} + |} + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin + + } else { + ctx.copyResult = true + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |boolean $found = false; + |// the last iteration of this loop is to emit an empty row if there is no matched rows. + |while ($matches != null && $matches.hasNext() || !$found) { + | UnsafeRow $matched = $matches != null && $matches.hasNext() ? + | (UnsafeRow) $matches.next() : null; + | ${checkCondition.trim} + | if (!$conditionPassed) continue; + | $found = true; + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + |} + """.stripMargin } } -} -object BroadcastHashJoin { + /** + * Generates the code for left semi join. + */ + private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |if ($matched == null) continue; + |$checkCondition + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |if ($matches == null) continue; + |boolean $found = false; + |while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $found = true; + |} + |if (!$found) continue; + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } + } + + /** + * Generates the code for anti join. + */ + private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") - private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + | if ($matched != null) { + | // Evaluate the condition. + | $checkCondition + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value}); + | if ($matches != null) { + | // Evaluate the condition. + | boolean $found = false; + | while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $found = true; + | } + | if ($found) continue; + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala deleted file mode 100644 index ab81bd7b3fc04..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.{InternalAccumulator, TaskContext} - -/** - * Performs a outer hash join for two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -case class BroadcastHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - val timeout = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - val numBuildRows = joinType match { - case RightOuter => longMetric("numLeftRows") - case LeftOuter => longMetric("numRightRows") - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map { row => - numBuildRows += 1 - row.copy() - }.collect() - // The following line doesn't run in a job so we cannot track the metric value. However, we - // have already tracked it in the above lines. So here we can use - // `SQLMetrics.nullLongMetric` to ignore it. - val hashed = HashedRelation( - input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size) - sparkContext.broadcast(hashed) - } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - } - - protected override def doPrepare(): Unit = { - broadcastFuture - } - - override def doExecute(): RDD[InternalRow] = { - val numStreamedRows = joinType match { - case RightOuter => longMetric("numRightRows") - case LeftOuter => longMetric("numLeftRows") - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - val numOutputRows = longMetric("numOutputRows") - - val broadcastRelation = Await.result(broadcastFuture, timeout) - - streamedPlan.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow() - val hashTable = broadcastRelation.value - val keyGenerator = streamedKeyGenerator - - hashTable match { - case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) - case _ => - } - - val resultProj = resultProjection - joinType match { - case LeftOuter => - streamedIter.flatMap(currentRow => { - numStreamedRows += 1 - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) - }) - - case RightOuter => - streamedIter.flatMap(currentRow => { - numStreamedRows += 1 - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) - }) - - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala deleted file mode 100644 index c5cd6a2fd6372..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.{InternalAccumulator, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. - */ -case class BroadcastLeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan, - condition: Option[Expression]) extends BinaryNode with HashSemiJoin { - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") - val numOutputRows = longMetric("numOutputRows") - - val input = right.execute().map { row => - numRightRows += 1 - row.copy() - }.collect() - - if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) - val broadcastedRelation = sparkContext.broadcast(hashSet) - - left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) - } - } else { - val hashRelation = - HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) - val broadcastedRelation = sparkContext.broadcast(hashRelation) - - left.execute().mapPartitions { streamIter => - val hashedRelation = broadcastedRelation.value - hashedRelation match { - case unsafe: UnsafeHashedRelation => - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) - case _ => - } - hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 05d20f511aef8..4ba710c10a41a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.CompactBuffer - +import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class BroadcastNestedLoopJoin( left: SparkPlan, @@ -33,11 +33,8 @@ case class BroadcastNestedLoopJoin( buildSide: BuildSide, joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) /** BuildRight means the right relation <=> the broadcast relation. */ @@ -46,14 +43,21 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } - override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true + override def requiredChildDistribution: Seq[Distribution] = buildSide match { + case BuildLeft => + BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { - UnsafeProjection.create(schema) + if (joinType == LeftSemi) { + UnsafeProjection.create(output, output) } else { - identity[InternalRow] + // Always put the stream side on left to simplify implementation + // both of left and right side could be null + UnsafeProjection.create( + output, (streamed.output ++ broadcast.output).map(_.withNullability(true))) } } @@ -61,127 +65,267 @@ case class BroadcastNestedLoopJoin( override def output: Seq[Attribute] = { joinType match { + case Inner => + left.output ++ right.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case Inner => - // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case - left.output ++ right.output - case x => // TODO support the Left Semi Join + case LeftExistence(_) => + left.output + case x => throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + @transient private lazy val boundCondition = { + if (condition.isDefined) { + newPredicate(condition.get, streamed.output ++ broadcast.output) + } else { + (r: InternalRow) => true + } + } - protected override def doExecute(): RDD[InternalRow] = { - val (numStreamedRows, numBuildRows) = buildSide match { - case BuildRight => (longMetric("numLeftRows"), longMetric("numRightRows")) - case BuildLeft => (longMetric("numRightRows"), longMetric("numLeftRows")) + /** + * The implementation for InnerJoin. + */ + private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + + streamedIter.flatMap { streamedRow => + val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r)) + if (condition.isDefined) { + joinedRows.filter(boundCondition) + } else { + joinedRows + } + } } - val numOutputRows = longMetric("numOutputRows") + } + + /** + * The implementation for these joins: + * + * LeftOuter with BuildRight + * RightOuter with BuildLeft + */ + private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + val nulls = new GenericMutableRow(broadcast.output.size) - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map { row => - numBuildRows += 1 - row.copy() - }.collect().toIndexedSeq) + // Returns an iterator to avoid copy the rows. + new Iterator[InternalRow] { + // current row from stream side + private var streamRow: InternalRow = null + // have found a match for current row or not + private var foundMatch: Boolean = false + // the matched result row + private var resultRow: InternalRow = null + // the next index of buildRows to try + private var nextIndex: Int = 0 - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[InternalRow] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + private def findNextMatch(): Boolean = { + if (streamRow == null) { + if (!streamedIter.hasNext) { + return false + } + streamRow = streamedIter.next() + nextIndex = 0 + foundMatch = false + } + while (nextIndex < buildRows.length) { + resultRow = joinedRow(streamRow, buildRows(nextIndex)) + nextIndex += 1 + if (boundCondition(resultRow)) { + foundMatch = true + return true + } + } + if (!foundMatch) { + resultRow = joinedRow(streamRow, nulls) + streamRow = null + true + } else { + resultRow = null + streamRow = null + findNextMatch() + } + } + + override def hasNext(): Boolean = { + resultRow != null || findNextMatch() + } + override def next(): InternalRow = { + val r = resultRow + resultRow = null + r + } + } + } + } + + /** + * The implementation for these joins: + * + * LeftSemi with BuildRight + * Anti with BuildRight + */ + private def leftExistenceJoin( + relation: Broadcast[Array[InternalRow]], + exists: Boolean): RDD[InternalRow] = { + assert(buildSide == BuildRight) + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - val resultProj = genResultProjection + if (condition.isDefined) { + streamedIter.filter(l => + buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists + ) + } else if (buildRows.nonEmpty == exists) { + streamedIter + } else { + Iterator.empty + } + } + } + + /** + * The implementation for these joins: + * + * LeftOuter with BuildLeft + * RightOuter with BuildRight + * FullOuter + * LeftSemi with BuildLeft + * Anti with BuildLeft + */ + private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val streamRdd = streamed.execute() + + val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val matched = new BitSet(buildRows.length) + val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => var i = 0 - var streamRowMatched = false - numStreamedRows += 1 - - while (i < broadcastedRelation.value.size) { - val broadcastedRow = broadcastedRelation.value(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() - streamRowMatched = true - includedBroadcastTuples += i - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() - streamRowMatched = true - includedBroadcastTuples += i - case _ => + while (i < buildRows.length) { + if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { + matched.set(i) } i += 1 } - - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() - case _ => - } } - Iterator((matchedRows, includedBroadcastTuples)) + Seq(matched).toIterator } - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = includedBroadcastTuples.fold( - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - )(_ ++ _) + val matchedBroadcastRows = matchedBuildRows.fold( + new BitSet(relation.value.length) + )(_ | _) - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - val resultProj = genResultProjection + if (joinType == LeftSemi) { + assert(buildSide == BuildLeft) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + if (matchedBroadcastRows.get(i)) { + buf += rel(i).copy() + } + i += 1 + } + return sparkContext.makeRDD(buf) + } - /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[InternalRow] = { + val notMatchedBroadcastRows: Seq[InternalRow] = { + val nulls = new GenericMutableRow(streamed.output.size) val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 - val rel = broadcastedRelation.value - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - val joinedRow = new JoinedRow - joinedRow.withLeft(leftNulls) - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - buf += resultProj(joinedRow.withRight(rel(i))).copy() - } - i += 1 - } - case (LeftOuter | FullOuter, BuildLeft) => - val joinedRow = new JoinedRow - joinedRow.withRight(rightNulls) - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - buf += resultProj(joinedRow.withLeft(rel(i))).copy() - } - i += 1 + val buildRows = relation.value + val joinedRow = new JoinedRow + joinedRow.withLeft(nulls) + while (i < buildRows.length) { + if (!matchedBroadcastRows.get(i)) { + buf += joinedRow.withRight(buildRows(i)).copy() + } + i += 1 + } + buf + } + + if (joinType == LeftAnti) { + return sparkContext.makeRDD(notMatchedBroadcastRows) + } + + val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + val nulls = new GenericMutableRow(broadcast.output.size) + + streamedIter.flatMap { streamedRow => + var i = 0 + var foundMatch = false + val matchedRows = new CompactBuffer[InternalRow] + + while (i < buildRows.length) { + if (boundCondition(joinedRow(streamedRow, buildRows(i)))) { + matchedRows += joinedRow.copy() + foundMatch = true } - case _ => + i += 1 + } + + if (!foundMatch && joinType == FullOuter) { + matchedRows += joinedRow(streamedRow, nulls).copy() + } + matchedRows.iterator } - buf.toSeq } - // TODO: Breaks lineage. sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), - sparkContext.makeRDD(broadcastRowsWithNulls) - ).map { row => - // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here. - numOutputRows += 1 - row + matchedStreamRows, + sparkContext.makeRDD(notMatchedBroadcastRows) + ) + } + + protected override def doExecute(): RDD[InternalRow] = { + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() + + val resultRdd = (joinType, buildSide) match { + case (Inner, _) => + innerJoin(broadcastedRelation) + case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => + outerJoin(broadcastedRelation) + case (LeftSemi, BuildRight) => + leftExistenceJoin(broadcastedRelation, exists = true) + case (LeftAnti, BuildRight) => + leftExistenceJoin(broadcastedRelation, exists = false) + case _ => + /** + * LeftOuter with BuildLeft + * RightOuter with BuildRight + * FullOuter + * LeftSemi with BuildLeft + * Anti with BuildLeft + */ + defaultJoin(broadcastedRelation) + } + + val numOutputRows = longMetric("numOutputRows") + resultRdd.mapPartitionsInternal { iter => + val resultProj = genResultProjection + iter.map { r => + numOutputRows += 1 + resultProj(r) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 0243e196dbc37..edb4c5a16fb0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -17,40 +17,85 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.rdd.RDD +import org.apache.spark._ +import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +/** + * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, + * will be much faster than building the right partition for every row in left RDD, it also + * materialize the right RDD (in case of the right RDD is nondeterministic). + */ +private[spark] +class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) + extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { + + override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { + // We will not sort the rows, so prefixComparator and recordComparator are null. + val sorter = UnsafeExternalSorter.create( + context.taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + context, + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes) + + val partition = split.asInstanceOf[CartesianPartition] + for (y <- rdd2.iterator(partition.s2, context)) { + sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0) + } + + // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] + def createIter(): Iterator[UnsafeRow] = { + val iter = sorter.getIterator + val unsafeRow = new UnsafeRow(numFieldsOfRight) + new Iterator[UnsafeRow] { + override def hasNext: Boolean = { + iter.hasNext + } + override def next(): UnsafeRow = { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + unsafeRow + } + } + } + + val resultIter = + for (x <- rdd1.iterator(partition.s1, context); + y <- createIter()) yield (x, y) + CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( + resultIter, sorter.cleanupResources) + } +} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") val numOutputRows = longMetric("numOutputRows") - val leftResults = left.execute().map { row => - numLeftRows += 1 - row.copy() - } - val rightResults = right.execute().map { row => - numRightRows += 1 - row.copy() - } + val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] + val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - leftResults.cartesian(rightResults).mapPartitions { iter => - val joinedRow = new JoinedRow + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + pair.mapPartitionsInternal { iter => + val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) iter.map { r => numOutputRows += 1 - joinedRow(r._1, r._2) + joiner.join(r._1, r._2) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7ce4a517838cb..d6feedc27244b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -19,118 +19,208 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.LongSQLMetric - +import org.apache.spark.sql.types.{IntegralType, LongType} trait HashJoin { self: SparkPlan => val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] + val joinType: JoinType val buildSide: BuildSide + val condition: Option[Expression] val left: SparkPlan val right: SparkPlan + override def output: Seq[Attribute] = { + joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case LeftExistence(_) => + left.output + case x => + throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType") + } + } + protected lazy val (buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) } - protected lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) + protected lazy val (buildKeys, streamedKeys) = { + require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), + "Join keys from two sides should have same types") + val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + buildSide match { + case BuildLeft => (lkeys, rkeys) + case BuildRight => (rkeys, lkeys) + } } - override def output: Seq[Attribute] = left.output ++ right.output - - protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(self.schema)) + /** + * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. + * + * If not, returns the original expressions. + */ + private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + var keyExpr: Expression = null + var width = 0 + keys.foreach { e => + e.dataType match { + case dt: IntegralType if dt.defaultSize <= 8 - width => + if (width == 0) { + if (e.dataType != LongType) { + keyExpr = Cast(e, LongType) + } else { + keyExpr = e + } + width = dt.defaultSize + } else { + val bits = dt.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + width -= bits + } + // TODO: support BooleanType, DateType and TimestampType + case other => + return keys + } + } + keyExpr :: Nil } - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + protected def buildSideKeyGenerator(): Projection = + UnsafeProjection.create(buildKeys) - protected def buildSideKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildPlan.output) - } else { - newMutableProjection(buildKeys, buildPlan.output)() - } + protected def streamSideKeyGenerator(): UnsafeProjection = + UnsafeProjection.create(streamedKeys) + + @transient private[this] lazy val boundCondition = if (condition.isDefined) { + newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) + } else { + (r: InternalRow) => true + } - protected def streamSideKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedPlan.output) + protected def createResultProjection(): (InternalRow) => InternalRow = { + if (joinType == LeftSemi) { + UnsafeProjection.create(output, output) } else { - newMutableProjection(streamedKeys, streamedPlan.output)() + // Always put the stream side on left to simplify implementation + // both of left and right side could be null + UnsafeProjection.create( + output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true))) } + } - protected def hashJoin( + private def innerJoin( streamIter: Iterator[InternalRow], - numStreamRows: LongSQLMetric, - hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = - { - new Iterator[InternalRow] { - private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: Seq[InternalRow] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. - private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(self.schema) - } else { - identity[InternalRow] - } + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinRow = new JoinedRow + val joinKeys = streamSideKeyGenerator() + streamIter.flatMap { srow => + joinRow.withLeft(srow) + val matches = hashedRelation.get(joinKeys(srow)) + if (matches != null) { + matches.map(joinRow.withRight(_)).filter(boundCondition) + } else { + Seq.empty } + } + } - private[this] val joinKeys = streamSideKeyGenerator + private def outerJoin( + streamedIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinedRow = new JoinedRow() + val keyGenerator = streamSideKeyGenerator() + val nullRow = new GenericInternalRow(buildPlan.output.length) + + streamedIter.flatMap { currentRow => + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + val buildIter = hashedRelation.get(rowKey) + new RowIterator { + private var found = false + override def advanceNext(): Boolean = { + while (buildIter != null && buildIter.hasNext) { + val nextBuildRow = buildIter.next() + if (boundCondition(joinedRow.withRight(nextBuildRow))) { + found = true + return true + } + } + if (!found) { + joinedRow.withRight(nullRow) + found = true + return true + } + false + } + override def getRow: InternalRow = joinedRow + }.toScala + } + } - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || - (streamIter.hasNext && fetchNext()) + private def semiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter { current => + val key = joinKeys(current) + lazy val buildIter = hashedRelation.get(key) + !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists { + (row: InternalRow) => boundCondition(joinedRow(current, row)) + }) + } + } - override final def next(): InternalRow = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - currentMatchPosition += 1 - numOutputRows += 1 - resultProjection(ret) - } + private def antiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter { current => + val key = joinKeys(current) + lazy val buildIter = hashedRelation.get(key) + key.anyNull || buildIter == null || (condition.isDefined && !buildIter.exists { + row => boundCondition(joinedRow(current, row)) + }) + } + } - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false if the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - numStreamRows += 1 - val key = joinKeys(currentStreamedRow) - if (!key.anyNull) { - currentHashMatches = hashedRelation.get(key) - } - } + protected def join( + streamedIter: Iterator[InternalRow], + hashed: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + + val joinedIter = joinType match { + case Inner => + innerJoin(streamedIter, hashed) + case LeftOuter | RightOuter => + outerJoin(streamedIter, hashed) + case LeftSemi => + semiJoin(streamedIter, hashed) + case LeftAnti => + antiJoin(streamedIter, hashed) + case x => + throw new IllegalArgumentException( + s"BroadcastHashJoin should not take $x as the JoinType") + } - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } + val resultProj = createResultProjection + joinedIter.map { r => + numOutputRows += 1 + resultProj(r) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala deleted file mode 100644 index 15b06b1537f8c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ /dev/null @@ -1,248 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.util.collection.CompactBuffer - - -trait HashOuterJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val joinType: JoinType - val condition: Option[Expression] - val left: SparkPlan - val right: SparkPlan - - override def output: Seq[Attribute] = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - } - - protected[this] lazy val (buildPlan, streamedPlan) = joinType match { - case RightOuter => (left, right) - case LeftOuter => (right, left) - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - protected[this] lazy val (buildKeys, streamedKeys) = joinType match { - case RightOuter => (leftKeys, rightKeys) - case LeftOuter => (rightKeys, leftKeys) - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter - && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(self.schema)) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode - - protected def buildKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildPlan.output) - } else { - newMutableProjection(buildKeys, buildPlan.output)() - } - - protected[this] def streamedKeyGenerator: Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedPlan.output) - } else { - newProjection(streamedKeys, streamedPlan.output) - } - } - - protected[this] def resultProjection: InternalRow => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(self.schema) - } else { - identity[InternalRow] - } - } - - @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) - @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - - @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - @transient private[this] lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. - - protected[this] def leftOuterIterator( - key: InternalRow, - joinedRow: JoinedRow, - rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (rightIter != null) { - rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } else { - temp - } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } - } - ret.iterator - } - - protected[this] def rightOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (leftIter != null) { - leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } else { - temp - } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } - } - ret.iterator - } - - protected[this] def fullOuterIterator( - key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], - joinedRow: JoinedRow, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[InternalRow] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if boundCondition(joinedRow.withRight(r)) => - numOutputRows += 1 - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - joinedRow.copy() - - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - numOutputRows += 1 - joinedRow.withRight(rightNullRow).copy() - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. - - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. - case (r, idx) if !rightMatchedSet.contains(idx) => - numOutputRows += 1 - joinedRow(leftNullRow, r).copy() - } - } else { - leftIter.iterator.map[InternalRow] { l => - numOutputRows += 1 - joinedRow(l, rightNullRow).copy() - } ++ rightIter.iterator.map[InternalRow] { r => - numOutputRows += 1 - joinedRow(leftNullRow, r).copy() - } - } - } - - // This is only used by FullOuter - protected[this] def buildHashTable( - iter: Iterator[InternalRow], - numIterRows: LongSQLMetric, - keyGenerator: Projection): java.util.HashMap[InternalRow, CompactBuffer[InternalRow]] = { - val hashTable = new java.util.HashMap[InternalRow, CompactBuffer[InternalRow]]() - while (iter.hasNext) { - val currentRow = iter.next() - numIterRows += 1 - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[InternalRow]() - hashTable.put(rowKey.copy(), existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala deleted file mode 100644 index beb141ade616d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.LongSQLMetric - - -trait HashSemiJoin { - self: SparkPlan => - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val left: SparkPlan - val right: SparkPlan - val condition: Option[Expression] - - override def output: Seq[Attribute] = left.output - - protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(left.schema) - && UnsafeProjection.canSupport(right.schema)) - } - - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe - override def canProcessSafeRows: Boolean = !supportUnsafe - - protected def leftKeyGenerator: Projection = - if (supportUnsafe) { - UnsafeProjection.create(leftKeys, left.output) - } else { - newMutableProjection(leftKeys, left.output)() - } - - protected def rightKeyGenerator: Projection = - if (supportUnsafe) { - UnsafeProjection.create(rightKeys, right.output) - } else { - newMutableProjection(rightKeys, right.output)() - } - - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - protected def buildKeyHashSet( - buildIter: Iterator[InternalRow], numBuildRows: LongSQLMetric): java.util.Set[InternalRow] = { - val hashSet = new java.util.HashSet[InternalRow]() - - // Create a Hash set of buildKeys - val rightKey = rightKeyGenerator - while (buildIter.hasNext) { - val currentRow = buildIter.next() - numBuildRows += 1 - val rowKey = rightKey(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey.copy()) - } - } - } - - hashSet - } - - protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - numStreamRows: LongSQLMetric, - hashSet: java.util.Set[InternalRow], - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator - streamIter.filter(current => { - numStreamRows += 1 - val key = joinKeys(current) - val r = !key.anyNull && hashSet.contains(key) - if (r) numOutputRows += 1 - r - }) - } - - protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - numStreamRows: LongSQLMetric, - hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator - val joinedRow = new JoinedRow - streamIter.filter { current => - numStreamRows += 1 - val key = joinKeys(current) - lazy val rowBuffer = hashedRelation.get(key) - val r = !key.anyNull && rowBuffer != null && rowBuffer.exists { - (row: InternalRow) => boundCondition(joinedRow(current, row)) - } - if (r) numOutputRows += 1 - r - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cc8abb1ba463c..0427db4e3bf25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,313 +18,199 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.nio.ByteOrder -import java.util.{HashMap => JavaHashMap} -import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager} +import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.execution.local.LocalNode -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.Utils -import org.apache.spark.util.collection.CompactBuffer -import org.apache.spark.{SparkConf, SparkEnv} - +import org.apache.spark.util.{KnownSizeEstimation, Utils} /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[execution] sealed trait HashedRelation { - def get(key: InternalRow): Seq[InternalRow] - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { - out.writeInt(serialized.length) // Write the length of serialized bytes first - out.write(serialized) - } - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def readBytes(in: ObjectInput): Array[Byte] = { - val serializedSize = in.readInt() // Read the length of serialized bytes first - val bytes = new Array[Byte](serializedSize) - in.readFully(bytes) - bytes - } -} - - -/** - * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. - */ -private[joins] final class GeneralHashedRelation( - private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) - extends HashedRelation with Externalizable { - - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) - - override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) - - override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) - } +private[execution] sealed trait HashedRelation extends KnownSizeEstimation { + /** + * Returns matched rows. + * + * Returns null if there is no matched rows. + */ + def get(key: InternalRow): Iterator[InternalRow] - override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + /** + * Returns matched rows for a key that has only one column with LongType. + * + * Returns null if there is no matched rows. + */ + def get(key: Long): Iterator[InternalRow] = { + throw new UnsupportedOperationException } -} - - -/** - * A specialized [[HashedRelation]] that maps key into a single value. This implementation - * assumes the key is unique. - */ -private[joins] -final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) - extends HashedRelation with Externalizable { - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) + /** + * Returns the matched single row. + */ + def getValue(key: InternalRow): InternalRow - override def get(key: InternalRow): Seq[InternalRow] = { - val v = hashTable.get(key) - if (v eq null) null else CompactBuffer(v) + /** + * Returns the matched single row with key that have only one column of LongType. + */ + def getValue(key: Long): InternalRow = { + throw new UnsupportedOperationException } - def getValue(key: InternalRow): InternalRow = hashTable.get(key) + /** + * Returns true iff all the keys are unique. + */ + def keyIsUnique: Boolean - override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) - } + /** + * Returns a read-only copy of this, to be safely used in current thread. + */ + def asReadOnlyCopy(): HashedRelation - override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) - } + /** + * Release any used resources. + */ + def close(): Unit } -// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. - - private[execution] object HashedRelation { - def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = { - apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator) - } - + /** + * Create a HashedRelation from an Iterator of InternalRow. + */ def apply( input: Iterator[InternalRow], - numInputRows: LongSQLMetric, - keyGenerator: Projection, - sizeEstimate: Int = 64): HashedRelation = { - - if (keyGenerator.isInstanceOf[UnsafeProjection]) { - return UnsafeHashedRelation( - input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + key: Seq[Expression], + sizeEstimate: Int = 64, + taskMemoryManager: TaskMemoryManager = null): HashedRelation = { + val mm = Option(taskMemoryManager).getOrElse { + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) } - // TODO: Use Spark's HashMap implementation. - val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate) - var currentRow: InternalRow = null - - // Whether the join key is unique. If the key is unique, we can convert the underlying - // hash map into one specialized for this. - var keyIsUnique = true - - // Create a mapping of buildKeys -> rows - while (input.hasNext) { - currentRow = input.next() - numInputRows += 1 - val rowKey = keyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[InternalRow]() - hashTable.put(rowKey.copy(), newMatchList) - newMatchList - } else { - keyIsUnique = false - existingMatchList - } - matchList += currentRow.copy() - } - } - - if (keyIsUnique) { - val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size) - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - uniqHashTable.put(entry.getKey, entry.getValue()(0)) - } - new UniqueKeyHashedRelation(uniqHashTable) + if (key.length == 1 && key.head.dataType == LongType) { + LongHashedRelation(input, key, sizeEstimate, mm) } else { - new GeneralHashedRelation(hashTable) + UnsafeHashedRelation(input, key, sizeEstimate, mm) } } } /** - * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key - * into a sequence of values. - * - * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use - * BytesToBytesMap for better memory performance (multiple values for the same are stored as a - * continuous byte array. + * A HashedRelation for UnsafeRow, which is backed BytesToBytesMap. * * It's serialized in the following format: * [number of keys] - * [size of key] [size of all values in bytes] [key bytes] [bytes for all values] - * ... - * - * All the values are serialized as following: - * [number of fields] [number of bytes] [underlying bytes of UnsafeRow] - * ... + * [size of key] [size of value] [key bytes] [bytes for value] */ -private[joins] final class UnsafeHashedRelation( - private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) +private[joins] class UnsafeHashedRelation( + private var numFields: Int, + private var binaryMap: BytesToBytesMap) extends HashedRelation with Externalizable { - private[joins] def this() = this(null) // Needed for serialization + private[joins] def this() = this(0, null) // Needed for serialization - // Use BytesToBytesMap in executor for better performance (it's created when deserialization) - // This is used in broadcast joins and distributed mode only - @transient private[this] var binaryMap: BytesToBytesMap = _ + override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues() - /** - * Return the size of the unsafe map on the executors. - * - * For broadcast joins, this hashed relation is bigger on the driver because it is - * represented as a Java hash map there. While serializing the map to the executors, - * however, we rehash the contents in a binary map to reduce the memory footprint on - * the executors. - * - * For non-broadcast joins or in local mode, return 0. - */ - def getUnsafeSize: Long = { - if (binaryMap != null) { - binaryMap.getTotalMemoryConsumption - } else { - 0 - } + override def asReadOnlyCopy(): UnsafeHashedRelation = { + new UnsafeHashedRelation(numFields, binaryMap) } - override def get(key: InternalRow): Seq[InternalRow] = { - val unsafeKey = key.asInstanceOf[UnsafeRow] + override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption + + // re-used in get()/getValue() + var resultRow = new UnsafeRow(numFields) - if (binaryMap != null) { - // Used in Broadcast join - val map = binaryMap // avoid the compiler error - val loc = new map.Location // this could be allocated in stack - binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes, loc) - if (loc.isDefined) { - val buffer = CompactBuffer[UnsafeRow]() - - val base = loc.getValueAddress.getBaseObject - var offset = loc.getValueAddress.getBaseOffset - val last = loc.getValueAddress.getBaseOffset + loc.getValueLength - while (offset < last) { - val numFields = Platform.getInt(base, offset) - val sizeInBytes = Platform.getInt(base, offset + 4) - offset += 8 - - val row = new UnsafeRow - row.pointTo(base, offset, numFields, sizeInBytes) - buffer += row - offset += sizeInBytes + override def get(key: InternalRow): Iterator[InternalRow] = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) + if (loc.isDefined) { + new Iterator[UnsafeRow] { + private var _hasNext = true + override def hasNext: Boolean = _hasNext + override def next(): UnsafeRow = { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + _hasNext = loc.nextValue() + resultRow } - buffer - } else { - null } + } else { + null + } + } + def getValue(key: InternalRow): InternalRow = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) + if (loc.isDefined) { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + resultRow } else { - // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin) - hashTable.get(unsafeKey) + null } } - override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - if (binaryMap != null) { - // This could happen when a cached broadcast object need to be dumped into disk to free memory - out.writeInt(binaryMap.numElements()) - - var buffer = new Array[Byte](64) - def write(addr: MemoryLocation, length: Int): Unit = { - if (buffer.length < length) { - buffer = new Array[Byte](length) - } - Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset, - buffer, Platform.BYTE_ARRAY_OFFSET, length) - out.write(buffer, 0, length) - } + override def close(): Unit = { + binaryMap.free() + } - val iter = binaryMap.iterator() - while (iter.hasNext) { - val loc = iter.next() - // [key size] [values size] [key bytes] [values bytes] - out.writeInt(loc.getKeyLength) - out.writeInt(loc.getValueLength) - write(loc.getKeyAddress, loc.getKeyLength) - write(loc.getValueAddress, loc.getValueLength) + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { + out.writeInt(numFields) + // TODO: move these into BytesToBytesMap + out.writeInt(binaryMap.numKeys()) + out.writeInt(binaryMap.numValues()) + + var buffer = new Array[Byte](64) + def write(base: Object, offset: Long, length: Int): Unit = { + if (buffer.length < length) { + buffer = new Array[Byte](length) } + Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length) + out.write(buffer, 0, length) + } - } else { - assert(hashTable != null) - out.writeInt(hashTable.size()) - - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - val key = entry.getKey - val values = entry.getValue - - // write all the values as single byte array - var totalSize = 0L - var i = 0 - while (i < values.length) { - totalSize += values(i).getSizeInBytes + 4 + 4 - i += 1 - } - assert(totalSize < Integer.MAX_VALUE, "values are too big") - - // [key size] [values size] [key bytes] [values bytes] - out.writeInt(key.getSizeInBytes) - out.writeInt(totalSize.toInt) - out.write(key.getBytes) - i = 0 - while (i < values.length) { - // [num of fields] [num of bytes] [row bytes] - // write the integer in native order, so they can be read by UNSAFE.getInt() - if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { - out.writeInt(values(i).numFields()) - out.writeInt(values(i).getSizeInBytes) - } else { - out.writeInt(Integer.reverseBytes(values(i).numFields())) - out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) - } - out.write(values(i).getBytes) - i += 1 - } - } + val iter = binaryMap.iterator() + while (iter.hasNext) { + val loc = iter.next() + // [key size] [values size] [key bytes] [value bytes] + out.writeInt(loc.getKeyLength) + out.writeInt(loc.getValueLength) + write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength) + write(loc.getValueBase, loc.getValueOffset, loc.getValueLength) } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { + numFields = in.readInt() + resultRow = new UnsafeRow(numFields) val nKeys = in.readInt() + val nValues = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory // TODO(josh): This needs to be revisited before we merge this patch; making this change now // so that tests compile: val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "false"), Long.MaxValue, Long.MaxValue, 1), 0) + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) @@ -340,7 +226,7 @@ private[joins] final class UnsafeHashedRelation( var i = 0 var keyBuffer = new Array[Byte](1024) var valuesBuffer = new Array[Byte](1024) - while (i < nKeys) { + while (i < nValues) { val keySize = in.readInt() val valuesSize = in.readInt() if (keySize > keyBuffer.length) { @@ -352,13 +238,11 @@ private[joins] final class UnsafeHashedRelation( } in.readFully(valuesBuffer, 0, valuesSize) - // put it into binary map val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) - assert(!loc.isDefined, "Duplicated key found!") - val putSuceeded = loc.putNewKey( - keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize, + val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize, valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize) if (!putSuceeded) { + binaryMap.free() throw new IOException("Could not allocate memory to grow BytesToBytesMap") } i += 1 @@ -370,31 +254,503 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - numInputRows: LongSQLMetric, - keyGenerator: UnsafeProjection, - sizeEstimate: Int): HashedRelation = { + key: Seq[Expression], + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): HashedRelation = { + + val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) + .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) - // Use a Java hash table here because unsafe maps expect fixed size records - val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) + val binaryMap = new BytesToBytesMap( + taskMemoryManager, + // Only 70% of the slots can be used before growing, more capacity help to reduce collision + (sizeEstimate * 1.5 + 1).toInt, + pageSizeBytes) // Create a mapping of buildKeys -> rows + val keyGenerator = UnsafeProjection.create(key) + var numFields = 0 + while (input.hasNext) { + val row = input.next().asInstanceOf[UnsafeRow] + numFields = row.numFields() + val key = keyGenerator(row) + if (!key.anyNull) { + val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + val success = loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) + if (!success) { + binaryMap.free() + throw new SparkException("There is no enough memory to build hash map") + } + } + } + + new UnsafeHashedRelation(numFields, binaryMap) + } +} + +/** + * An append-only hash map mapping from key of Long to UnsafeRow. + * + * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array + * (`page`) in this format: + * + * [bytes of row1][address1][bytes of row2][address1] ... + * + * address1 (8 bytes) is the offset and size of next value for the same key as row1, any key + * could have multiple values. the address at the end of last value for every key is 0. + * + * The keys and addresses of their values could be stored in two modes: + * + * 1) sparse mode: the keys and addresses are stored in `array` as: + * + * [key1][address1][key2][address2]...[] + * + * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1 + * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to address + * hash collision. + * + * 2) dense mode: all the addresses are packed into a single array of long, as: + * + * [address1] [address2] ... + * + * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is + * determined by `key1 - minKey`. + * + * The map is created as sparse mode, then key-value could be appended into it. Once finish + * appending, caller could all optimize() to try to turn the map into dense mode, which is faster + * to probe. + * + * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/ + */ +private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) + extends MemoryConsumer(mm) with Externalizable { + + // Whether the keys are stored in dense mode or not. + private var isDense = false + + // The minimum key + private var minKey = Long.MaxValue + + // The maxinum key + private var maxKey = Long.MinValue + + // The array to store the key and offset of UnsafeRow in the page. + // + // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... + // Dense mode: [offset1 | size1] [offset2 | size2] + private var array: Array[Long] = null + private var mask: Int = 0 + + // The page to store all bytes of UnsafeRow and the pointer to next rows. + // [row1][pointer1] [row2][pointer2] + private var page: Array[Byte] = null + + // Current write cursor in the page. + private var cursor = Platform.BYTE_ARRAY_OFFSET + + // The total number of values of all keys. + private var numValues = 0 + + // The number of unique keys. + private var numKeys = 0 + + // needed by serializer + def this() = { + this( + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0), + 0) + } + + private def acquireMemory(size: Long): Unit = { + // do not support spilling + val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) + if (got < size) { + freeMemory(got) + throw new SparkException(s"Can't acquire $size bytes memory to build hash relation, " + + s"got $got bytes") + } + } + + private def freeMemory(size: Long): Unit = { + mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this) + } + + private def init(): Unit = { + if (mm != null) { + var n = 1 + while (n < capacity) n *= 2 + acquireMemory(n * 2 * 8 + (1 << 20)) + array = new Array[Long](n * 2) + mask = n * 2 - 2 + page = new Array[Byte](1 << 20) // 1M bytes + } + } + + init() + + def spill(size: Long, trigger: MemoryConsumer): Long = 0L + + /** + * Returns whether all the keys are unique. + */ + def keyIsUnique: Boolean = numKeys == numValues + + /** + * Returns total memory consumption. + */ + def getTotalMemoryConsumption: Long = array.length * 8 + page.length + + /** + * Returns the first slot of array that store the keys (sparse mode). + */ + private def firstSlot(key: Long): Int = { + val h = key * 0x9E3779B9L + (h ^ (h >> 32)).toInt & mask + } + + /** + * Returns the next probe in the array. + */ + private def nextSlot(pos: Int): Int = (pos + 2) & mask + + private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { + val offset = address >>> 32 + val size = address & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + resultRow + } + + /** + * Returns the single UnsafeRow for given key, or null if not found. + */ + def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >= 0 && key <= maxKey && array(idx) > 0) { + return getRow(array(idx), resultRow) + } + } else { + var pos = firstSlot(key) + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return getRow(array(pos + 1), resultRow) + } + pos = nextSlot(pos) + } + } + null + } + + /** + * Returns an interator of UnsafeRow for multiple linked values. + */ + private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + new Iterator[UnsafeRow] { + var addr = address + override def hasNext: Boolean = addr != 0 + override def next(): UnsafeRow = { + val offset = addr >>> 32 + val size = addr & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + addr = Platform.getLong(page, offset + size) + resultRow + } + } + } + + /** + * Returns an iterator for all the values for the given key, or null if no value found. + */ + def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >=0 && key <= maxKey && array(idx) > 0) { + return valueIter(array(idx), resultRow) + } + } else { + var pos = firstSlot(key) + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return valueIter(array(pos + 1), resultRow) + } + pos = nextSlot(pos) + } + } + null + } + + /** + * Appends the key and row into this map. + */ + def append(key: Long, row: UnsafeRow): Unit = { + if (key < minKey) { + minKey = key + } + if (key > maxKey) { + maxKey = key + } + + // There is 8 bytes for the pointer to next value + if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { + val used = page.length + if (used * 2L > (1L << 31)) { + sys.error("Can't allocate a page that is larger than 2G") + } + acquireMemory(used * 2) + val newPage = new Array[Byte](used * 2) + System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) + page = newPage + freeMemory(used) + } + + // copy the bytes of UnsafeRow + val offset = cursor + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) + cursor += row.getSizeInBytes + Platform.putLong(page, cursor, 0) + cursor += 8 + numValues += 1 + updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes) + } + + /** + * Update the address in array for given key. + */ + private def updateIndex(key: Long, address: Long): Unit = { + var pos = firstSlot(key) + while (array(pos) != key && array(pos + 1) != 0) { + pos = nextSlot(pos) + } + if (array(pos + 1) == 0) { + // this is the first value for this key, put the address in array. + array(pos) = key + array(pos + 1) = address + numKeys += 1 + if (numKeys * 4 > array.length) { + // reach half of the capacity + growArray() + } + } else { + // there are some values for this key, put the address in the front of them. + val pointer = (address >>> 32) + (address & 0xffffffffL) + Platform.putLong(page, pointer, array(pos + 1)) + array(pos + 1) = address + } + } + + private def growArray(): Unit = { + var old_array = array + val n = array.length + numKeys = 0 + acquireMemory(n * 2 * 8) + array = new Array[Long](n * 2) + mask = n * 2 - 2 + var i = 0 + while (i < old_array.length) { + if (old_array(i + 1) > 0) { + updateIndex(old_array(i), old_array(i + 1)) + } + i += 2 + } + old_array = null // release the reference to old array + freeMemory(n * 8) + } + + /** + * Try to turn the map into dense mode, which is faster to probe. + */ + def optimize(): Unit = { + val range = maxKey - minKey + // Convert to dense mode if it does not require more memory or could fit within L1 cache + if (range < array.length || range < 1024) { + try { + acquireMemory((range + 1) * 8) + } catch { + case e: SparkException => + // there is no enough memory to convert + return + } + val denseArray = new Array[Long]((range + 1).toInt) + var i = 0 + while (i < array.length) { + if (array(i + 1) > 0) { + val idx = (array(i) - minKey).toInt + denseArray(idx) = array(i + 1) + } + i += 2 + } + val old_length = array.length + array = denseArray + isDense = true + freeMemory(old_length * 8) + } + } + + /** + * Free all the memory acquired by this map. + */ + def free(): Unit = { + if (page != null) { + freeMemory(page.length) + page = null + } + if (array != null) { + freeMemory(array.length * 8) + array = null + } + } + + override def writeExternal(out: ObjectOutput): Unit = { + out.writeBoolean(isDense) + out.writeLong(minKey) + out.writeLong(maxKey) + out.writeInt(numKeys) + out.writeInt(numValues) + + out.writeInt(array.length) + val buffer = new Array[Byte](4 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) + out.write(buffer, 0, size) + offset += size + } + + val used = cursor - Platform.BYTE_ARRAY_OFFSET + out.writeInt(used) + out.write(page, 0, used) + } + + override def readExternal(in: ObjectInput): Unit = { + isDense = in.readBoolean() + minKey = in.readLong() + maxKey = in.readLong() + numKeys = in.readInt() + numValues = in.readInt() + + val length = in.readInt() + array = new Array[Long](length) + mask = length - 2 + val buffer = new Array[Byte](4 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + in.readFully(buffer, 0, size) + Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) + offset += size + } + + val numBytes = in.readInt() + page = new Array[Byte](numBytes) + in.readFully(page) + } +} + +private[joins] class LongHashedRelation( + private var nFields: Int, + private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable { + + private var resultRow: UnsafeRow = new UnsafeRow(nFields) + + // Needed for serialization (it is public to make Java serialization work) + def this() = this(0, null) + + override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) + + override def estimatedSize: Long = map.getTotalMemoryConsumption + + override def get(key: InternalRow): Iterator[InternalRow] = { + if (key.isNullAt(0)) { + null + } else { + get(key.getLong(0)) + } + } + + override def getValue(key: InternalRow): InternalRow = { + if (key.isNullAt(0)) { + null + } else { + getValue(key.getLong(0)) + } + } + + override def get(key: Long): Iterator[InternalRow] = map.get(key, resultRow) + + override def getValue(key: Long): InternalRow = map.getValue(key, resultRow) + + override def keyIsUnique: Boolean = map.keyIsUnique + + override def close(): Unit = { + map.free() + } + + override def writeExternal(out: ObjectOutput): Unit = { + out.writeInt(nFields) + out.writeObject(map) + } + + override def readExternal(in: ObjectInput): Unit = { + nFields = in.readInt() + resultRow = new UnsafeRow(nFields) + map = in.readObject().asInstanceOf[LongToUnsafeRowMap] + } +} + +/** + * Create hashed relation with key that is long. + */ +private[joins] object LongHashedRelation { + def apply( + input: Iterator[InternalRow], + key: Seq[Expression], + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): LongHashedRelation = { + + val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) + val keyGenerator = UnsafeProjection.create(key) + + // Create a mapping of key -> rows + var numFields = 0 while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] - numInputRows += 1 + numFields = unsafeRow.numFields() val rowKey = keyGenerator(unsafeRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[UnsafeRow]() - hashTable.put(rowKey.copy(), newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += unsafeRow.copy() + if (!rowKey.isNullAt(0)) { + val key = rowKey.getLong(0) + map.append(key, unsafeRow) } } + map.optimize() + new LongHashedRelation(numFields, map) + } +} + +/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */ +private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) + extends BroadcastMode { + + override def transform(rows: Array[InternalRow]): HashedRelation = { + HashedRelation(rows.iterator, canonicalizedKey, rows.length) + } + + private lazy val canonicalizedKey: Seq[Expression] = { + key.map { e => e.canonicalized } + } - new UnsafeHashedRelation(hashTable) + override def compatibleWith(other: BroadcastMode): Boolean = other match { + case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey + case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala deleted file mode 100644 index efa7b49410edc..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output: Seq[Attribute] = left.output - - override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - - /** The Streamed Relation */ - override def left: SparkPlan = streamed - - /** The Broadcast relation */ - override def right: SparkPlan = broadcast - - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") - val numOutputRows = longMetric("numOutputRows") - - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map { row => - numRightRows += 1 - row.copy() - }.collect().toIndexedSeq) - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - - streamedIter.filter(streamedRow => { - numLeftRows += 1 - var i = 0 - var matched = false - - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matched = true - } - i += 1 - } - if (matched) { - numOutputRows += 1 - } - matched - }) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala deleted file mode 100644 index bf3b05be981fb..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. - */ -case class LeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan, - condition: Option[Expression]) extends BinaryNode with HashSemiJoin { - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") - val numOutputRows = longMetric("numOutputRows") - - right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => - if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter, numRightRows) - hashSemiJoin(streamIter, numLeftRows, hashSet, numOutputRows) - } else { - val hashRelation = HashedRelation(buildIter, numRightRows, rightKeyGenerator) - hashSemiJoin(streamIter, numLeftRows, hashRelation, numOutputRows) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 755986af8b95e..0c3e3c3fc18a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -17,46 +17,60 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. + * Performs a hash join of two child relations by first shuffling the data using the join keys. */ case class ShuffledHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, buildSide: BuildSide, + condition: Option[Expression], left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + override def outputPartitioning: Partitioning = joinType match { + case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftAnti => left.outputPartitioning + case LeftSemi => left.outputPartitioning + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType") + } override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { + val context = TaskContext.get() + val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) + // This relation is usually used until the end of task. + context.addTaskCompletionListener((t: TaskContext) => + relation.close() + ) + relation + } + protected override def doExecute(): RDD[InternalRow] = { - val (numBuildRows, numStreamedRows) = buildSide match { - case BuildLeft => (longMetric("numLeftRows"), longMetric("numRightRows")) - case BuildRight => (longMetric("numRightRows"), longMetric("numLeftRows")) - } val numOutputRows = longMetric("numOutputRows") - - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, numBuildRows, buildSideKeyGenerator) - hashJoin(streamIter, numStreamedRows, hashed, numOutputRows) + streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => + val hashed = buildHashedRelation(buildIter) + join(streamIter, hashed, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala deleted file mode 100644 index 6b2cb9d8f6893..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.collection.JavaConverters._ - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -case class ShuffledHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - - protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") - val numOutputRows = longMetric("numOutputRows") - - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - joinType match { - case LeftOuter => - val hashed = HashedRelation(rightIter, numRightRows, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - val resultProj = resultProjection - leftIter.flatMap( currentRow => { - numLeftRows += 1 - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) - }) - - case RightOuter => - val hashed = HashedRelation(leftIter, numLeftRows, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - val resultProj = resultProjection - rightIter.flatMap ( currentRow => { - numRightRows += 1 - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) - }) - - case FullOuter => - // TODO(davies): use UnsafeRow - val leftHashTable = - buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)).asScala - val rightHashTable = - buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)).asScala - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), - joinedRow, - numOutputRows) - } - - case x => - throw new IllegalArgumentException( - s"ShuffledHashOuterJoin should not take $x as the JoinType") - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 17030947b7bbc..0e7b2f2f3187f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -22,9 +22,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.util.collection.BitSet /** * Performs an sort merge join of two child relations. @@ -32,18 +35,40 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} case class SortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryNode { + right: SparkPlan) extends BinaryNode with CodegenSupport { override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def output: Seq[Attribute] = left.output ++ right.output + override def output: Seq[Attribute] = { + joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + (left.output ++ right.output).map(_.withNullability(true)) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + override def outputPartitioning: Partitioning = joinType match { + case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + // For left and right outer joins, the output is partitioned by the streamed input's join keys. + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -53,104 +78,395 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - protected[this] def isUnsafeMode: Boolean = { - (codegenEnabled && unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(schema)) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) } + private def createLeftKeyGenerator(): Projection = + UnsafeProjection.create(leftKeys, left.output) + + private def createRightKeyGenerator(): Projection = + UnsafeProjection.create(rightKeys, right.output) + protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") val numOutputRows = longMetric("numOutputRows") left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - new RowIterator { - // The projection used to extract keys from input rows of the left child. - private[this] val leftKeyGenerator = { - if (isUnsafeMode) { - // It is very important to use UnsafeProjection if input rows are UnsafeRows. - // Otherwise, GenerateProjection will cause wrong results. - UnsafeProjection.create(leftKeys, left.output) - } else { - newProjection(leftKeys, left.output) - } + val boundCondition: (InternalRow) => Boolean = { + condition.map { cond => + newPredicate(cond, left.output ++ right.output) + }.getOrElse { + (r: InternalRow) => true } + } + // An ordering that can be used to compare keys from both sides. + val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) - // The projection used to extract keys from input rows of the right child. - private[this] val rightKeyGenerator = { - if (isUnsafeMode) { - // It is very important to use UnsafeProjection if input rows are UnsafeRows. - // Otherwise, GenerateProjection will cause wrong results. - UnsafeProjection.create(rightKeys, right.output) - } else { - newProjection(rightKeys, right.output) - } - } + joinType match { + case Inner => + new RowIterator { + // The projection used to extract keys from input rows of the left child. + private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output) - // An ordering that can be used to compare keys from both sides. - private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ - private[this] var currentMatchIdx: Int = -1 - private[this] val smjScanner = new SortMergeJoinScanner( - leftKeyGenerator, - rightKeyGenerator, - keyOrdering, - RowIterator.fromScala(leftIter), - numLeftRows, - RowIterator.fromScala(rightIter), - numRightRows - ) - private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } + // The projection used to extract keys from input rows of the right child. + private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output) + + // An ordering that can be used to compare keys from both sides. + private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) + private[this] val joinRow = new JoinedRow + private[this] val resultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(schema) - override def advanceNext(): Boolean = { - if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow currentMatchIdx = 0 - } else { - currentRightMatches = null - currentLeftRow = null - currentMatchIdx = -1 } - } - if (currentLeftRow != null) { - joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) - currentMatchIdx += 1 - numOutputRows += 1 - true - } else { - false - } - } - override def getRow: InternalRow = resultProjection(joinRow) - }.toScala + override def advanceNext(): Boolean = { + while (currentMatchIdx >= 0) { + if (currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + return false + } + } + joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } + } + false + } + + override def getRow: InternalRow = resultProjection(joinRow) + }.toScala + + case LeftOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(leftIter), + bufferedIter = RowIterator.fromScala(rightIter) + ) + val rightNullRow = new GenericInternalRow(right.output.length) + new LeftOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala + + case RightOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(rightIter), + bufferedIter = RowIterator.fromScala(leftIter) + ) + val leftNullRow = new GenericInternalRow(left.output.length) + new RightOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala + + case FullOuter => + val leftNullRow = new GenericInternalRow(left.output.length) + val rightNullRow = new GenericInternalRow(right.output.length) + val smjScanner = new SortMergeFullOuterJoinScanner( + leftKeyGenerator = createLeftKeyGenerator(), + rightKeyGenerator = createRightKeyGenerator(), + keyOrdering, + leftIter = RowIterator.fromScala(leftIter), + rightIter = RowIterator.fromScala(rightIter), + boundCondition, + leftNullRow, + rightNullRow) + + new FullOuterIterator( + smjScanner, + resultProj, + numOutputRows).toScala + + case x => + throw new IllegalArgumentException( + s"SortMergeJoin should not take $x as the JoinType") + } + } } + + override def supportCodegen: Boolean = { + joinType == Inner + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + + private def createJoinKey( + ctx: CodegenContext, + row: String, + keys: Seq[Expression], + input: Seq[Attribute]): Seq[ExprCode] = { + ctx.INPUT_ROW = row + keys.map(BindReferences.bindReference(_, input).gen(ctx)) + } + + private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = { + vars.zipWithIndex.map { case (ev, i) => + val value = ctx.freshName("value") + ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "") + val code = + s""" + |$value = ${ev.value}; + """.stripMargin + ExprCode(code, "false", value) + } + } + + private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { + val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) => + s""" + |if (comp == 0) { + | comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)}; + |} + """.stripMargin.trim + } + s""" + |comp = 0; + |${comparisons.mkString("\n")} + """.stripMargin + } + + /** + * Generate a function to scan both left and right to find a match, returns the term for + * matched one row from left side and buffered rows from right side. + */ + private def genScanner(ctx: CodegenContext): (String, String) = { + // Create class member for next row from both sides. + val leftRow = ctx.freshName("leftRow") + ctx.addMutableState("InternalRow", leftRow, "") + val rightRow = ctx.freshName("rightRow") + ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") + + // Create variables for join keys from both sides. + val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) + val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") + val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) + val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") + // Copy the right key as class members so they could be used in next function call. + val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) + + // A list to hold all matched rows from right side. + val matches = ctx.freshName("matches") + val clsName = classOf[java.util.ArrayList[InternalRow]].getName + ctx.addMutableState(clsName, matches, s"$matches = new $clsName();") + // Copy the left keys as class members so they could be used in next function call. + val matchedKeyVars = copyKeys(ctx, leftKeyVars) + + ctx.addNewFunction("findNextInnerJoinRows", + s""" + |private boolean findNextInnerJoinRows( + | scala.collection.Iterator leftIter, + | scala.collection.Iterator rightIter) { + | $leftRow = null; + | int comp = 0; + | while ($leftRow == null) { + | if (!leftIter.hasNext()) return false; + | $leftRow = (InternalRow) leftIter.next(); + | ${leftKeyVars.map(_.code).mkString("\n")} + | if ($leftAnyNull) { + | $leftRow = null; + | continue; + | } + | if (!$matches.isEmpty()) { + | ${genComparision(ctx, leftKeyVars, matchedKeyVars)} + | if (comp == 0) { + | return true; + | } + | $matches.clear(); + | } + | + | do { + | if ($rightRow == null) { + | if (!rightIter.hasNext()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return !$matches.isEmpty(); + | } + | $rightRow = (InternalRow) rightIter.next(); + | ${rightKeyTmpVars.map(_.code).mkString("\n")} + | if ($rightAnyNull) { + | $rightRow = null; + | continue; + | } + | ${rightKeyVars.map(_.code).mkString("\n")} + | } + | ${genComparision(ctx, leftKeyVars, rightKeyVars)} + | if (comp > 0) { + | $rightRow = null; + | } else if (comp < 0) { + | if (!$matches.isEmpty()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return true; + | } + | $leftRow = null; + | } else { + | $matches.add($rightRow.copy()); + | $rightRow = null;; + | } + | } while ($leftRow != null); + | } + | return false; // unreachable + |} + """.stripMargin) + + (leftRow, matches) + } + + /** + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. + */ + private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = leftRow + left.output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, "") + val code = + s""" + |$isNull = $leftRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. + */ + private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = rightRow + right.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).gen(ctx) + } + } + + /** + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. + */ + private def splitVarsByCondition( + attributes: Seq[Attribute], + variables: Seq[ExprCode]): (String, String) = { + if (condition.isDefined) { + val condRefs = condition.get.references + val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => + condRefs.contains(a) + } + val beforeCond = evaluateVariables(used.map(_._2)) + val afterCond = evaluateVariables(notUsed.map(_._2)) + (beforeCond, afterCond) + } else { + (evaluateVariables(variables), "") + } + } + + override def doProduce(ctx: CodegenContext): String = { + ctx.copyResult = true + val leftInput = ctx.freshName("leftInput") + ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") + val rightInput = ctx.freshName("rightInput") + ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") + + val (leftRow, matches) = genScanner(ctx) + + // Create variables for row from both sides. + val leftVars = createLeftVars(ctx, leftRow) + val rightRow = ctx.freshName("rightRow") + val rightVars = createRightVar(ctx, rightRow) + + val size = ctx.freshName("size") + val i = ctx.freshName("i") + val numOutput = metricTerm(ctx, "numOutputRows") + val (beforeLoop, condCheck) = if (condition.isDefined) { + // Split the code of creating variables based on whether it's used by condition or not. + val loaded = ctx.freshName("loaded") + val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) + val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + // Generate code for condition + ctx.currentVars = leftVars ++ rightVars + val cond = BindReferences.bindReference(condition.get, output).gen(ctx) + // evaluate the columns those used by condition before loop + val before = s""" + |boolean $loaded = false; + |$leftBefore + """.stripMargin + + val checking = s""" + |$rightBefore + |${cond.code} + |if (${cond.isNull} || !${cond.value}) continue; + |if (!$loaded) { + | $loaded = true; + | $leftAfter + |} + |$rightAfter + """.stripMargin + (before, checking) + } else { + (evaluateVariables(leftVars), "") + } + + s""" + |while (findNextInnerJoinRows($leftInput, $rightInput)) { + | int $size = $matches.size(); + | ${beforeLoop.trim} + | for (int $i = 0; $i < $size; $i ++) { + | InternalRow $rightRow = (InternalRow) $matches.get($i); + | ${condCheck.trim} + | $numOutput.add(1); + | ${consume(ctx, leftVars ++ rightVars)} + | } + | if (shouldStop()) return; + |} + """.stripMargin + } } /** - * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * Helper class that is used to implement [[SortMergeJoin]]. * * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` @@ -173,9 +489,7 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - numStreamedRows: LongSQLMetric, - bufferedIter: RowIterator, - numBufferedRows: LongSQLMetric) { + bufferedIter: RowIterator) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -300,7 +614,6 @@ private[joins] class SortMergeJoinScanner( if (streamedIter.advanceNext()) { streamedRow = streamedIter.getRow streamedRowKey = streamedKeyGenerator(streamedRow) - numStreamedRows += 1 true } else { streamedRow = null @@ -318,7 +631,6 @@ private[joins] class SortMergeJoinScanner( while (!foundRow && bufferedIter.advanceNext()) { bufferedRow = bufferedIter.getRow bufferedRowKey = bufferedKeyGenerator(bufferedRow) - numBufferedRows += 1 foundRow = !bufferedRowKey.anyNull } if (!foundRow) { @@ -348,3 +660,305 @@ private[joins] class SortMergeJoinScanner( } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } } + +/** + * An iterator for outputting rows in left outer join. + */ +private class LeftOuterIterator( + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) +} + +/** + * An iterator for outputting rows in right outer join. + */ +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { + + protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) + protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) +} + +/** + * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]]. + * + * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the + * streamed side will output 0 or many rows, one for each matching row on the buffered side. + * If there are no matches, then the buffered side of the joined output will be a null row. + * + * In left outer join, the left is the streamed side and the right is the buffered side. + * In right outer join, the right is the streamed side and the left is the buffered side. + * + * @param smjScanner a scanner that streams rows and buffers any matching rows + * @param bufferedSideNullRow the default row to return when a streamed row has no matches + * @param boundCondition an additional filter condition for buffered rows + * @param resultProj how the output should be projected + * @param numOutputRows an accumulator metric for the number of rows output + */ +private abstract class OneSideOuterIterator( + smjScanner: SortMergeJoinScanner, + bufferedSideNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) extends RowIterator { + + // A row to store the joined result, reused many times + protected[this] val joinedRow: JoinedRow = new JoinedRow() + + // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row + private[this] var bufferIndex: Int = 0 + + // This iterator is initialized lazily so there should be no matches initially + assert(smjScanner.getBufferedMatches.length == 0) + + // Set output methods to be overridden by subclasses + protected def setStreamSideOutput(row: InternalRow): Unit + protected def setBufferedSideOutput(row: InternalRow): Unit + + /** + * Advance to the next row on the stream side and populate the buffer with matches. + * @return whether there are more rows in the stream to consume. + */ + private def advanceStream(): Boolean = { + bufferIndex = 0 + if (smjScanner.findNextOuterJoinRows()) { + setStreamSideOutput(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching rows in the buffer, so return the null row + setBufferedSideOutput(bufferedSideNullRow) + } else { + // Find the next row in the buffer that satisfied the bound condition + if (!advanceBufferUntilBoundConditionSatisfied()) { + setBufferedSideOutput(bufferedSideNullRow) + } + } + true + } else { + // Stream has been exhausted + false + } + } + + /** + * Advance to the next row in the buffer that satisfies the bound condition. + * @return whether there is such a row in the current buffer. + */ + private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { + setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + foundMatch = boundCondition(joinedRow) + bufferIndex += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() + if (r) numOutputRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} + +private class SortMergeFullOuterJoinScanner( + leftKeyGenerator: Projection, + rightKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + leftIter: RowIterator, + rightIter: RowIterator, + boundCondition: InternalRow => Boolean, + leftNullRow: InternalRow, + rightNullRow: InternalRow) { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftRow: InternalRow = _ + private[this] var leftRowKey: InternalRow = _ + private[this] var rightRow: InternalRow = _ + private[this] var rightRowKey: InternalRow = _ + + private[this] var leftIndex: Int = 0 + private[this] var rightIndex: Int = 0 + private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] var leftMatched: BitSet = new BitSet(1) + private[this] var rightMatched: BitSet = new BitSet(1) + + advancedLeft() + advancedRight() + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the left iterator and compute the new row's join key. + * @return true if the left iterator returned a row and false otherwise. + */ + private def advancedLeft(): Boolean = { + if (leftIter.advanceNext()) { + leftRow = leftIter.getRow + leftRowKey = leftKeyGenerator(leftRow) + true + } else { + leftRow = null + leftRowKey = null + false + } + } + + /** + * Advance the right iterator and compute the new row's join key. + * @return true if the right iterator returned a row and false otherwise. + */ + private def advancedRight(): Boolean = { + if (rightIter.advanceNext()) { + rightRow = rightIter.getRow + rightRowKey = rightKeyGenerator(rightRow) + true + } else { + rightRow = null + rightRowKey = null + false + } + } + + /** + * Populate the left and right buffers with rows matching the provided key. + * This consumes rows from both iterators until their keys are different from the matching key. + */ + private def findMatchingRows(matchingKey: InternalRow): Unit = { + leftMatches.clear() + rightMatches.clear() + leftIndex = 0 + rightIndex = 0 + + while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { + leftMatches += leftRow.copy() + advancedLeft() + } + while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { + rightMatches += rightRow.copy() + advancedRight() + } + + if (leftMatches.size <= leftMatched.capacity) { + leftMatched.clear() + } else { + leftMatched = new BitSet(leftMatches.size) + } + if (rightMatches.size <= rightMatched.capacity) { + rightMatched.clear() + } else { + rightMatched = new BitSet(rightMatches.size) + } + } + + /** + * Scan the left and right buffers for the next valid match. + * + * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers. + * If a left row has no valid matches on the right, or a right row has no valid matches on the + * left, then the row is joined with the null row and the result is considered a valid match. + * + * @return true if a valid match is found, false otherwise. + */ + private def scanNextInBuffered(): Boolean = { + while (leftIndex < leftMatches.size) { + while (rightIndex < rightMatches.size) { + joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) + if (boundCondition(joinedRow)) { + leftMatched.set(leftIndex) + rightMatched.set(rightIndex) + rightIndex += 1 + return true + } + rightIndex += 1 + } + rightIndex = 0 + if (!leftMatched.get(leftIndex)) { + // the left row has never matched any right row, join it with null row + joinedRow(leftMatches(leftIndex), rightNullRow) + leftIndex += 1 + return true + } + leftIndex += 1 + } + + while (rightIndex < rightMatches.size) { + if (!rightMatched.get(rightIndex)) { + // the right row has never matched any left row, join it with null row + joinedRow(leftNullRow, rightMatches(rightIndex)) + rightIndex += 1 + return true + } + rightIndex += 1 + } + + // There are no more valid matches in the left and right buffers + false + } + + // --- Public methods -------------------------------------------------------------------------- + + def getJoinedRow(): JoinedRow = joinedRow + + def advanceNext(): Boolean = { + // If we already buffered some matching rows, use them directly + if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { + if (scanNextInBuffered()) { + return true + } + } + + if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) { + joinedRow(leftRow.copy(), rightNullRow) + advancedLeft() + true + } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) { + joinedRow(leftNullRow, rightRow.copy()) + advancedRight() + true + } else if (leftRow != null && rightRow != null) { + // Both rows are present and neither have null values, + // so we populate the buffers with rows matching the next key + val comp = keyOrdering.compare(leftRowKey, rightRowKey) + if (comp <= 0) { + findMatchingRows(leftRowKey.copy()) + } else { + findMatchingRows(rightRowKey.copy()) + } + scanNextInBuffered() + true + } else { + // Both iterators have been consumed + false + } + } +} + +private class FullOuterIterator( + smjScanner: SortMergeFullOuterJoinScanner, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric) extends RowIterator { + private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() + + override def advanceNext(): Boolean = { + val r = smjScanner.advanceNext() + if (r) numRows += 1 + r + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala deleted file mode 100644 index 7e854e6702f77..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ /dev/null @@ -1,505 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} -import org.apache.spark.util.collection.BitSet - -/** - * Performs an sort merge outer join of two child relations. - */ -case class SortMergeOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def output: Seq[Attribute] = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - (left.output ++ right.output).map(_.withNullability(true)) - case x => - throw new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") - } - } - - override def outputPartitioning: Partitioning = joinType match { - // For left and right outer joins, the output is partitioned by the streamed input's join keys. - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException( - s"${getClass.getSimpleName} should not take $x as the JoinType") - } - - override def outputOrdering: Seq[SortOrder] = joinType match { - // For left and right outer joins, the output is ordered by the streamed input's join keys. - case LeftOuter => requiredOrders(leftKeys) - case RightOuter => requiredOrders(rightKeys) - // there are null rows in both streams, so there is no order - case FullOuter => Nil - case x => throw new IllegalArgumentException( - s"SortMergeOuterJoin should not take $x as the JoinType") - } - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { - // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. - keys.map(SortOrder(_, Ascending)) - } - - private def isUnsafeMode: Boolean = { - (codegenEnabled && unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(schema)) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode - - private def createLeftKeyGenerator(): Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(leftKeys, left.output) - } else { - newProjection(leftKeys, left.output) - } - } - - private def createRightKeyGenerator(): Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(rightKeys, right.output) - } else { - newProjection(rightKeys, right.output) - } - } - - override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") - val numOutputRows = longMetric("numOutputRows") - - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // An ordering that can be used to compare keys from both sides. - val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - val boundCondition: (InternalRow) => Boolean = { - condition.map { cond => - newPredicate(cond, left.output ++ right.output) - }.getOrElse { - (r: InternalRow) => true - } - } - val resultProj: InternalRow => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } - - joinType match { - case LeftOuter => - val smjScanner = new SortMergeJoinScanner( - streamedKeyGenerator = createLeftKeyGenerator(), - bufferedKeyGenerator = createRightKeyGenerator(), - keyOrdering, - streamedIter = RowIterator.fromScala(leftIter), - numLeftRows, - bufferedIter = RowIterator.fromScala(rightIter), - numRightRows - ) - val rightNullRow = new GenericInternalRow(right.output.length) - new LeftOuterIterator( - smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala - - case RightOuter => - val smjScanner = new SortMergeJoinScanner( - streamedKeyGenerator = createRightKeyGenerator(), - bufferedKeyGenerator = createLeftKeyGenerator(), - keyOrdering, - streamedIter = RowIterator.fromScala(rightIter), - numRightRows, - bufferedIter = RowIterator.fromScala(leftIter), - numLeftRows - ) - val leftNullRow = new GenericInternalRow(left.output.length) - new RightOuterIterator( - smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala - - case FullOuter => - val leftNullRow = new GenericInternalRow(left.output.length) - val rightNullRow = new GenericInternalRow(right.output.length) - val smjScanner = new SortMergeFullOuterJoinScanner( - leftKeyGenerator = createLeftKeyGenerator(), - rightKeyGenerator = createRightKeyGenerator(), - keyOrdering, - leftIter = RowIterator.fromScala(leftIter), - numLeftRows, - rightIter = RowIterator.fromScala(rightIter), - numRightRows, - boundCondition, - leftNullRow, - rightNullRow) - - new FullOuterIterator( - smjScanner, - resultProj, - numOutputRows).toScala - - case x => - throw new IllegalArgumentException( - s"SortMergeOuterJoin should not take $x as the JoinType") - } - } - } -} - -/** - * An iterator for outputting rows in left outer join. - */ -private class LeftOuterIterator( - smjScanner: SortMergeJoinScanner, - rightNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) - extends OneSideOuterIterator( - smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { - - protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) - protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) -} - -/** - * An iterator for outputting rows in right outer join. - */ -private class RightOuterIterator( - smjScanner: SortMergeJoinScanner, - leftNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) - extends OneSideOuterIterator( - smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { - - protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) - protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) -} - -/** - * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]]. - * - * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the - * streamed side will output 0 or many rows, one for each matching row on the buffered side. - * If there are no matches, then the buffered side of the joined output will be a null row. - * - * In left outer join, the left is the streamed side and the right is the buffered side. - * In right outer join, the right is the streamed side and the left is the buffered side. - * - * @param smjScanner a scanner that streams rows and buffers any matching rows - * @param bufferedSideNullRow the default row to return when a streamed row has no matches - * @param boundCondition an additional filter condition for buffered rows - * @param resultProj how the output should be projected - * @param numOutputRows an accumulator metric for the number of rows output - */ -private abstract class OneSideOuterIterator( - smjScanner: SortMergeJoinScanner, - bufferedSideNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) extends RowIterator { - - // A row to store the joined result, reused many times - protected[this] val joinedRow: JoinedRow = new JoinedRow() - - // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row - private[this] var bufferIndex: Int = 0 - - // This iterator is initialized lazily so there should be no matches initially - assert(smjScanner.getBufferedMatches.length == 0) - - // Set output methods to be overridden by subclasses - protected def setStreamSideOutput(row: InternalRow): Unit - protected def setBufferedSideOutput(row: InternalRow): Unit - - /** - * Advance to the next row on the stream side and populate the buffer with matches. - * @return whether there are more rows in the stream to consume. - */ - private def advanceStream(): Boolean = { - bufferIndex = 0 - if (smjScanner.findNextOuterJoinRows()) { - setStreamSideOutput(smjScanner.getStreamedRow) - if (smjScanner.getBufferedMatches.isEmpty) { - // There are no matching rows in the buffer, so return the null row - setBufferedSideOutput(bufferedSideNullRow) - } else { - // Find the next row in the buffer that satisfied the bound condition - if (!advanceBufferUntilBoundConditionSatisfied()) { - setBufferedSideOutput(bufferedSideNullRow) - } - } - true - } else { - // Stream has been exhausted - false - } - } - - /** - * Advance to the next row in the buffer that satisfies the bound condition. - * @return whether there is such a row in the current buffer. - */ - private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { - var foundMatch: Boolean = false - while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { - setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) - foundMatch = boundCondition(joinedRow) - bufferIndex += 1 - } - foundMatch - } - - override def advanceNext(): Boolean = { - val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() - if (r) numOutputRows += 1 - r - } - - override def getRow: InternalRow = resultProj(joinedRow) -} - -private class SortMergeFullOuterJoinScanner( - leftKeyGenerator: Projection, - rightKeyGenerator: Projection, - keyOrdering: Ordering[InternalRow], - leftIter: RowIterator, - numLeftRows: LongSQLMetric, - rightIter: RowIterator, - numRightRows: LongSQLMetric, - boundCondition: InternalRow => Boolean, - leftNullRow: InternalRow, - rightNullRow: InternalRow) { - private[this] val joinedRow: JoinedRow = new JoinedRow() - private[this] var leftRow: InternalRow = _ - private[this] var leftRowKey: InternalRow = _ - private[this] var rightRow: InternalRow = _ - private[this] var rightRowKey: InternalRow = _ - - private[this] var leftIndex: Int = 0 - private[this] var rightIndex: Int = 0 - private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] - private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] - private[this] var leftMatched: BitSet = new BitSet(1) - private[this] var rightMatched: BitSet = new BitSet(1) - - advancedLeft() - advancedRight() - - // --- Private methods -------------------------------------------------------------------------- - - /** - * Advance the left iterator and compute the new row's join key. - * @return true if the left iterator returned a row and false otherwise. - */ - private def advancedLeft(): Boolean = { - if (leftIter.advanceNext()) { - leftRow = leftIter.getRow - leftRowKey = leftKeyGenerator(leftRow) - numLeftRows += 1 - true - } else { - leftRow = null - leftRowKey = null - false - } - } - - /** - * Advance the right iterator and compute the new row's join key. - * @return true if the right iterator returned a row and false otherwise. - */ - private def advancedRight(): Boolean = { - if (rightIter.advanceNext()) { - rightRow = rightIter.getRow - rightRowKey = rightKeyGenerator(rightRow) - numRightRows += 1 - true - } else { - rightRow = null - rightRowKey = null - false - } - } - - /** - * Populate the left and right buffers with rows matching the provided key. - * This consumes rows from both iterators until their keys are different from the matching key. - */ - private def findMatchingRows(matchingKey: InternalRow): Unit = { - leftMatches.clear() - rightMatches.clear() - leftIndex = 0 - rightIndex = 0 - - while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { - leftMatches += leftRow.copy() - advancedLeft() - } - while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { - rightMatches += rightRow.copy() - advancedRight() - } - - if (leftMatches.size <= leftMatched.capacity) { - leftMatched.clear() - } else { - leftMatched = new BitSet(leftMatches.size) - } - if (rightMatches.size <= rightMatched.capacity) { - rightMatched.clear() - } else { - rightMatched = new BitSet(rightMatches.size) - } - } - - /** - * Scan the left and right buffers for the next valid match. - * - * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers. - * If a left row has no valid matches on the right, or a right row has no valid matches on the - * left, then the row is joined with the null row and the result is considered a valid match. - * - * @return true if a valid match is found, false otherwise. - */ - private def scanNextInBuffered(): Boolean = { - while (leftIndex < leftMatches.size) { - while (rightIndex < rightMatches.size) { - joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) - if (boundCondition(joinedRow)) { - leftMatched.set(leftIndex) - rightMatched.set(rightIndex) - rightIndex += 1 - return true - } - rightIndex += 1 - } - rightIndex = 0 - if (!leftMatched.get(leftIndex)) { - // the left row has never matched any right row, join it with null row - joinedRow(leftMatches(leftIndex), rightNullRow) - leftIndex += 1 - return true - } - leftIndex += 1 - } - - while (rightIndex < rightMatches.size) { - if (!rightMatched.get(rightIndex)) { - // the right row has never matched any left row, join it with null row - joinedRow(leftNullRow, rightMatches(rightIndex)) - rightIndex += 1 - return true - } - rightIndex += 1 - } - - // There are no more valid matches in the left and right buffers - false - } - - // --- Public methods -------------------------------------------------------------------------- - - def getJoinedRow(): JoinedRow = joinedRow - - def advanceNext(): Boolean = { - // If we already buffered some matching rows, use them directly - if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { - if (scanNextInBuffered()) { - return true - } - } - - if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) { - joinedRow(leftRow.copy(), rightNullRow) - advancedLeft() - true - } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) { - joinedRow(leftNullRow, rightRow.copy()) - advancedRight() - true - } else if (leftRow != null && rightRow != null) { - // Both rows are present and neither have null values, - // so we populate the buffers with rows matching the next key - val comp = keyOrdering.compare(leftRowKey, rightRowKey) - if (comp <= 0) { - findMatchingRows(leftRowKey.copy()) - } else { - findMatchingRows(rightRowKey.copy()) - } - scanNextInBuffered() - true - } else { - // Both iterators have been consumed - false - } - } -} - -private class FullOuterIterator( - smjScanner: SortMergeFullOuterJoinScanner, - resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric - ) extends RowIterator { - private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() - - override def advanceNext(): Boolean = { - val r = smjScanner.advanceNext() - if (r) numRows += 1 - r - } - - override def getRow: InternalRow = resultProj(joinedRow) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala new file mode 100644 index 0000000000000..9643b52f96544 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.exchange.ShuffleExchange + + +/** + * Take the first `limit` elements and collect them to a single partition. + * + * This operator will be used when a logical `Limit` operation is the final operator in an + * logical plan, which happens when the user is collecting results back to the driver. + */ +case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = SinglePartition + override def executeCollect(): Array[InternalRow] = child.executeTake(limit) + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + protected override def doExecute(): RDD[InternalRow] = { + val shuffled = new ShuffledRowRDD( + ShuffleExchange.prepareShuffleDependency( + child.execute(), child.output, SinglePartition, serializer)) + shuffled.mapPartitionsInternal(_.take(limit)) + } +} + +/** + * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. + */ +trait BaseLimit extends UnaryNode with CodegenSupport { + val limit: Int + override def output: Seq[Attribute] = child.output + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + iter.take(limit) + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val stopEarly = ctx.freshName("stopEarly") + ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + + ctx.addNewFunction("shouldStop", s""" + @Override + protected boolean shouldStop() { + return !currentRows.isEmpty() || $stopEarly; + } + """) + val countTerm = ctx.freshName("count") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + s""" + | if ($countTerm < $limit) { + | $countTerm += 1; + | ${consume(ctx, input)} + | } else { + | $stopEarly = true; + | } + """.stripMargin + } +} + +/** + * Take the first `limit` elements of each child partition, but do not collect or shuffle them. + */ +case class LocalLimit(limit: Int, child: SparkPlan) extends BaseLimit { + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + +/** + * Take the first `limit` elements of the child's single output partition. + */ +case class GlobalLimit(limit: Int, child: SparkPlan) extends BaseLimit { + override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil +} + +/** + * Take the first limit elements as defined by the sortOrder, and do projection if needed. + * This is logically equivalent to having a Limit operator after a [[Sort]] operator, + * or having a [[Project]] operator between them. + * This could have been named TopK, but Spark's top operator does the opposite in ordering + * so we name it TakeOrdered to avoid confusion. + */ +case class TakeOrderedAndProject( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = { + projectList.map(_.map(_.toAttribute)).getOrElse(child.output) + } + + override def outputPartitioning: Partitioning = SinglePartition + + override def executeCollect(): Array[InternalRow] = { + val ord = new LazilyGeneratedOrdering(sortOrder, child.output) + val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + data.map(r => proj(r).copy()) + } else { + data + } + } + + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + + protected override def doExecute(): RDD[InternalRow] = { + val ord = new LazilyGeneratedOrdering(sortOrder, child.output) + val localTopK: RDD[InternalRow] = { + child.execute().map(_.copy()).mapPartitions { iter => + org.apache.spark.util.collection.Utils.takeOrdered(iter, limit)(ord) + } + } + val shuffled = new ShuffledRowRDD( + ShuffleExchange.prepareShuffleDependency( + localTopK, child.output, SinglePartition, serializer)) + shuffled.mapPartitions { iter => + val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + topK.map(r => proj(r)) + } else { + topK + } + } + } + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def simpleString: String = { + val orderByString = sortOrder.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + + s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala deleted file mode 100644 index 52dcb9e43c4e8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} - -/** - * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of - * `buildSide`. The actual work of this node is defined in [[HashJoinNode]]. - */ -case class BinaryHashJoinNode( - conf: SQLConf, - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: LocalNode, - right: LocalNode) - extends BinaryLocalNode(conf) with HashJoinNode { - - protected override val (streamedNode, streamedKeys) = buildSide match { - case BuildLeft => (right, rightKeys) - case BuildRight => (left, leftKeys) - } - - private val (buildNode, buildKeys) = buildSide match { - case BuildLeft => (left, leftKeys) - case BuildRight => (right, rightKeys) - } - - override def output: Seq[Attribute] = left.output ++ right.output - - private def buildSideKeyGenerator: Projection = { - // We are expecting the data types of buildKeys and streamedKeys are the same. - assert(buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)) - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildNode.output) - } else { - newMutableProjection(buildKeys, buildNode.output)() - } - } - - protected override def doOpen(): Unit = { - buildNode.open() - val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) - // We have built the HashedRelation. So, close buildNode. - buildNode.close() - - streamedNode.open() - // Set the HashedRelation used by the HashJoinNode. - withHashedRelation(hashedRelation) - } - - override def close(): Unit = { - // Please note that we do not need to call the close method of our buildNode because - // it has been called in this.open. - streamedNode.close() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala deleted file mode 100644 index cd1c86516ec5f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} - -/** - * A [[HashJoinNode]] for broadcast join. It takes a streamedNode and a broadcast - * [[HashedRelation]]. The actual work of this node is defined in [[HashJoinNode]]. - */ -case class BroadcastHashJoinNode( - conf: SQLConf, - streamedKeys: Seq[Expression], - streamedNode: LocalNode, - buildSide: BuildSide, - buildOutput: Seq[Attribute], - hashedRelation: Broadcast[HashedRelation]) - extends UnaryLocalNode(conf) with HashJoinNode { - - override val child = streamedNode - - // Because we do not pass in the buildNode, we take the output of buildNode to - // create the inputSet properly. - override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput) - - override def output: Seq[Attribute] = buildSide match { - case BuildRight => streamedNode.output ++ buildOutput - case BuildLeft => buildOutput ++ streamedNode.output - } - - protected override def doOpen(): Unit = { - streamedNode.open() - // Set the HashedRelation used by the HashJoinNode. - withHashedRelation(hashedRelation.value) - } - - override def close(): Unit = { - streamedNode.close() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala deleted file mode 100644 index b31c5a863832e..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection} - -case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { - - override def output: Seq[Attribute] = child.output - - private[this] var convertToSafe: Projection = _ - - override def open(): Unit = { - child.open() - convertToSafe = FromUnsafeProjection(child.schema) - } - - override def next(): Boolean = child.next() - - override def fetch(): InternalRow = convertToSafe(child.fetch()) - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala deleted file mode 100644 index de2f4e661ab44..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection} - -case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) { - - override def output: Seq[Attribute] = child.output - - private[this] var convertToUnsafe: Projection = _ - - override def open(): Unit = { - child.open() - convertToUnsafe = UnsafeProjection.create(child.schema) - } - - override def next(): Boolean = child.next() - - override def fetch(): InternalRow = convertToUnsafe(child.fetch()) - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala deleted file mode 100644 index 2aff156d18b54..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection} - -case class ExpandNode( - conf: SQLConf, - projections: Seq[Seq[Expression]], - output: Seq[Attribute], - child: LocalNode) extends UnaryLocalNode(conf) { - - assert(projections.size > 0) - - private[this] var result: InternalRow = _ - private[this] var idx: Int = _ - private[this] var input: InternalRow = _ - private[this] var groups: Array[Projection] = _ - - override def open(): Unit = { - child.open() - groups = projections.map(ee => newProjection(ee, child.output)).toArray - idx = groups.length - } - - override def next(): Boolean = { - if (idx >= groups.length) { - if (child.next()) { - input = child.fetch() - idx = 0 - } else { - return false - } - } - result = groups(idx)(input) - idx += 1 - true - } - - override def fetch(): InternalRow = result - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala deleted file mode 100644 index dd1113b6726cf..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate - - -case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode) - extends UnaryLocalNode(conf) { - - private[this] var predicate: (InternalRow) => Boolean = _ - - override def output: Seq[Attribute] = child.output - - override def open(): Unit = { - child.open() - predicate = GeneratePredicate.generate(condition, child.output) - } - - override def next(): Boolean = { - var found = false - while (!found && child.next()) { - found = predicate.apply(child.fetch()) - } - found - } - - override def fetch(): InternalRow = child.fetch() - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala deleted file mode 100644 index b1dc719ca8508..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.joins._ - -/** - * An abstract node for sharing common functionality among different implementations of - * inner hash equi-join, notably [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]]. - * - * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]]. - */ -trait HashJoinNode { - - self: LocalNode => - - protected def streamedKeys: Seq[Expression] - protected def streamedNode: LocalNode - protected def buildSide: BuildSide - - private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: Seq[InternalRow] = _ - private[this] var currentMatchPosition: Int = -1 - - private[this] var joinRow: JoinedRow = _ - private[this] var resultProjection: (InternalRow) => InternalRow = _ - - private[this] var hashed: HashedRelation = _ - private[this] var joinKeys: Projection = _ - - protected def isUnsafeMode: Boolean = { - (codegenEnabled && - unsafeEnabled && - UnsafeProjection.canSupport(schema) && - UnsafeProjection.canSupport(streamedKeys)) - } - - private def streamSideKeyGenerator: Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedNode.output) - } else { - newMutableProjection(streamedKeys, streamedNode.output)() - } - } - - /** - * Sets the HashedRelation used by this node. This method needs to be called after - * before the first `next` gets called. - */ - protected def withHashedRelation(hashedRelation: HashedRelation): Unit = { - hashed = hashedRelation - } - - /** - * Custom open implementation to be overridden by subclasses. - */ - protected def doOpen(): Unit - - override def open(): Unit = { - doOpen() - joinRow = new JoinedRow - resultProjection = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } - joinKeys = streamSideKeyGenerator - } - - override def next(): Boolean = { - currentMatchPosition += 1 - if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) { - fetchNextMatch() - } else { - true - } - } - - /** - * Populate `currentHashMatches` with build-side rows matching the next streamed row. - * @return whether matches are found such that subsequent calls to `fetch` are valid. - */ - private def fetchNextMatch(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamedNode.next()) { - currentStreamedRow = streamedNode.fetch() - val key = joinKeys(currentStreamedRow) - if (!key.anyNull) { - currentHashMatches = hashed.get(key) - } - } - - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } - - override def fetch(): InternalRow = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - resultProjection(ret) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala deleted file mode 100644 index 740d485f8d9e6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import scala.collection.mutable - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute - -case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) - extends BinaryLocalNode(conf) { - - override def output: Seq[Attribute] = left.output - - private[this] var leftRows: mutable.HashSet[InternalRow] = _ - - private[this] var currentRow: InternalRow = _ - - override def open(): Unit = { - left.open() - leftRows = mutable.HashSet[InternalRow]() - while (left.next()) { - leftRows += left.fetch().copy() - } - left.close() - right.open() - } - - override def next(): Boolean = { - currentRow = null - while (currentRow == null && right.next()) { - currentRow = right.fetch() - if (!leftRows.contains(currentRow)) { - currentRow = null - } - } - currentRow != null - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = { - left.close() - right.close() - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala deleted file mode 100644 index 401b10a5ed307..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute - - -case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) { - - private[this] var count = 0 - - override def output: Seq[Attribute] = child.output - - override def open(): Unit = child.open() - - override def close(): Unit = child.close() - - override def fetch(): InternalRow = child.fetch() - - override def next(): Boolean = { - if (count < limit) { - count += 1 - child.next() - } else { - false - } - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala deleted file mode 100644 index f96b62a67a254..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ /dev/null @@ -1,224 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import scala.util.control.NonFatal - -import org.apache.spark.Logging -import org.apache.spark.sql.{SQLConf, Row} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.types.StructType - -/** - * A local physical operator, in the form of an iterator. - * - * Before consuming the iterator, open function must be called. - * After consuming the iterator, close function must be called. - */ -abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { - - protected val codegenEnabled: Boolean = conf.codegenEnabled - - protected val unsafeEnabled: Boolean = conf.unsafeEnabled - - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") - - /** - * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume - * any input data. - * - * Implementations of this must also call the `prepare()` function of its children. - */ - def prepare(): Unit = children.foreach(_.prepare()) - - /** - * Initializes the iterator state. Must be called before calling `next()`. - * - * Implementations of this must also call the `open()` function of its children. - */ - def open(): Unit - - /** - * Advances the iterator to the next tuple. Returns true if there is at least one more tuple. - */ - def next(): Boolean - - /** - * Returns the current tuple. - */ - def fetch(): InternalRow - - /** - * Closes the iterator and releases all resources. It should be idempotent. - * - * Implementations of this must also call the `close()` function of its children. - */ - def close(): Unit - - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true - - /** - * Returns the content through the [[Iterator]] interface. - */ - final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this) - - /** - * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. - */ - final def collect(): Seq[Row] = { - val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) - val result = new scala.collection.mutable.ArrayBuffer[Row] - open() - try { - while (next()) { - result += converter.apply(fetch()).asInstanceOf[Row] - } - } finally { - close() - } - result - } - - protected def newProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) - } - } - - protected def newMutableProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): () => MutableProjection = { - log.debug( - s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } - } else { - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } - - protected def newPredicate( - expression: Expression, - inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled) { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } - } else { - InterpretedPredicate.create(expression, inputSchema) - } - } -} - - -abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) { - override def children: Seq[LocalNode] = Seq.empty -} - - -abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) { - - def child: LocalNode - - override def children: Seq[LocalNode] = Seq(child) -} - -abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { - - def left: LocalNode - - def right: LocalNode - - override def children: Seq[LocalNode] = Seq(left, right) -} - -/** - * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface. - */ -private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { - private var nextRow: InternalRow = _ - - override def hasNext: Boolean = { - if (nextRow == null) { - val res = localNode.next() - if (res) { - nextRow = localNode.fetch() - } - res - } else { - true - } - } - - override def next(): InternalRow = { - if (hasNext) { - val res = nextRow - nextRow = null - res - } else { - throw new NoSuchElementException - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala deleted file mode 100644 index 7321fc66b4dde..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.util.collection.{BitSet, CompactBuffer} - -case class NestedLoopJoinNode( - conf: SQLConf, - left: LocalNode, - right: LocalNode, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryLocalNode(conf) { - - override def output: Seq[Attribute] = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new IllegalArgumentException( - s"NestedLoopJoin should not take $x as the JoinType") - } - } - - private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } - - private[this] var currentRow: InternalRow = _ - - private[this] var iterator: Iterator[InternalRow] = _ - - override def open(): Unit = { - val (streamed, build) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - build.open() - val buildRelation = new CompactBuffer[InternalRow] - while (build.next()) { - buildRelation += build.fetch().copy() - } - build.close() - - val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - val joinedRow = new JoinedRow - val matchedBuildTuples = new BitSet(buildRelation.size) - val resultProj = genResultProjection - streamed.open() - - // streamedRowMatches also contains null rows if using outer join - val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow => - val matchedRows = new CompactBuffer[InternalRow] - - var i = 0 - var streamRowMatched = false - - // Scan the build relation to look for matches for each streamed row - while (i < buildRelation.size) { - val buildRow = buildRelation(i) - buildSide match { - case BuildRight => joinedRow(streamedRow, buildRow) - case BuildLeft => joinedRow(buildRow, streamedRow) - } - if (boundCondition(joinedRow)) { - matchedRows += resultProj(joinedRow).copy() - streamRowMatched = true - matchedBuildTuples.set(i) - } - i += 1 - } - - // If this row had no matches and we're using outer join, join it with the null rows - if (!streamRowMatched) { - (joinType, buildSide) match { - case (LeftOuter | FullOuter, BuildRight) => - matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() - case (RightOuter | FullOuter, BuildLeft) => - matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() - case _ => - } - } - - matchedRows.iterator - } - - // If we're using outer join, find rows on the build side that didn't match anything - // and join them with the null row - lazy val unmatchedBuildRows: Iterator[InternalRow] = { - var i = 0 - buildRelation.filter { row => - val r = !matchedBuildTuples.get(i) - i += 1 - r - }.iterator - } - iterator = (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - streamedRowMatches ++ - unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) } - case (LeftOuter | FullOuter, BuildLeft) => - streamedRowMatches ++ - unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) } - case _ => streamedRowMatches - } - } - - override def next(): Boolean = { - if (iterator.hasNext) { - currentRow = iterator.next() - true - } else { - false - } - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = { - left.close() - right.close() - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala deleted file mode 100644 index 11529d6dd9b83..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} - - -case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) - extends UnaryLocalNode(conf) { - - private[this] var project: UnsafeProjection = _ - - override def output: Seq[Attribute] = projectList.map(_.toAttribute) - - override def open(): Unit = { - project = UnsafeProjection.create(projectList, child.output) - child.open() - } - - override def next(): Boolean = child.next() - - override def fetch(): InternalRow = { - project.apply(child.fetch()) - } - - override def close(): Unit = child.close() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala deleted file mode 100644 index 793700803f216..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} - - -/** - * Sample the dataset. - * - * @param conf the SQLConf - * @param lowerBound Lower-bound of the sampling probability (usually 0.0) - * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled - * will be ub - lb. - * @param withReplacement Whether to sample with replacement. - * @param seed the random seed - * @param child the LocalNode - */ -case class SampleNode( - conf: SQLConf, - lowerBound: Double, - upperBound: Double, - withReplacement: Boolean, - seed: Long, - child: LocalNode) extends UnaryLocalNode(conf) { - - override def output: Seq[Attribute] = child.output - - private[this] var iterator: Iterator[InternalRow] = _ - - private[this] var currentRow: InternalRow = _ - - override def open(): Unit = { - child.open() - val sampler = - if (withReplacement) { - // Disable gap sampling since the gap sampling method buffers two rows internally, - // requiring us to copy the row, which is more expensive than the random number generator. - new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false) - } else { - new BernoulliCellSampler[InternalRow](lowerBound, upperBound) - } - sampler.setSeed(seed) - iterator = sampler.sample(child.asIterator) - } - - override def next(): Boolean = { - if (iterator.hasNext) { - currentRow = iterator.next() - true - } else { - false - } - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = child.close() - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala deleted file mode 100644 index b8467f6ae58e0..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute - -/** - * An operator that scans some local data collection in the form of Scala Seq. - */ -case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow]) - extends LeafLocalNode(conf) { - - private[this] var iterator: Iterator[InternalRow] = _ - private[this] var currentRow: InternalRow = _ - - override def open(): Unit = { - iterator = data.iterator - } - - override def next(): Boolean = { - if (iterator.hasNext) { - currentRow = iterator.next() - true - } else { - false - } - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = { - // Do nothing - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala deleted file mode 100644 index ae672fbca8d83..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.util.BoundedPriorityQueue - -case class TakeOrderedAndProjectNode( - conf: SQLConf, - limit: Int, - sortOrder: Seq[SortOrder], - projectList: Option[Seq[NamedExpression]], - child: LocalNode) extends UnaryLocalNode(conf) { - - private[this] var projection: Option[Projection] = _ - private[this] var ord: InterpretedOrdering = _ - private[this] var iterator: Iterator[InternalRow] = _ - private[this] var currentRow: InternalRow = _ - - override def output: Seq[Attribute] = { - val projectOutput = projectList.map(_.map(_.toAttribute)) - projectOutput.getOrElse(child.output) - } - - override def open(): Unit = { - child.open() - projection = projectList.map(new InterpretedProjection(_, child.output)) - ord = new InterpretedOrdering(sortOrder, child.output) - // Priority keeps the largest elements, so let's reverse the ordering. - val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) - while (child.next()) { - queue += child.fetch() - } - // Close it eagerly since we don't need it. - child.close() - iterator = queue.toArray.sorted(ord).iterator - } - - override def next(): Boolean = { - if (iterator.hasNext) { - val _currentRow = iterator.next() - currentRow = projection match { - case Some(p) => p(_currentRow) - case None => _currentRow - } - true - } else { - false - } - } - - override def fetch(): InternalRow = currentRow - - override def close(): Unit = child.close() - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala deleted file mode 100644 index 0f2b8303e7372..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute - -case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) { - - override def output: Seq[Attribute] = children.head.output - - private[this] var currentChild: LocalNode = _ - - private[this] var nextChildIndex: Int = _ - - override def open(): Unit = { - currentChild = children.head - currentChild.open() - nextChildIndex = 1 - } - - private def advanceToNextChild(): Boolean = { - var found = false - var exit = false - while (!exit && !found) { - if (currentChild != null) { - currentChild.close() - } - if (nextChildIndex >= children.size) { - found = false - exit = true - } else { - currentChild = children(nextChildIndex) - nextChildIndex += 1 - currentChild.open() - found = currentChild.next() - } - } - found - } - - override def close(): Unit = { - if (currentChild != null) { - currentChild.close() - } - } - - override def fetch(): InternalRow = currentChild.fetch() - - override def next(): Boolean = { - if (currentChild.next()) { - true - } else { - advanceToNextChild() - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala new file mode 100644 index 0000000000000..2708219ad3485 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about a SQL Metric. + */ +@DeveloperApi +class SQLMetricInfo( + val name: String, + val accumulatorId: Long, + val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1c253e3942e95..7fa13907295b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.metric +import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext} +import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.util.Utils -import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} /** * Create a layer for specialized metric. We cannot add `@specialized` to @@ -27,8 +28,15 @@ import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. */ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( - name: String, val param: SQLMetricParam[R, T]) - extends Accumulable[R, T](param.zero, param, Some(name), true) { + name: String, + val param: SQLMetricParam[R, T]) + extends Accumulable[R, T](param.zero, param, Some(name), internal = true) { + + // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later + override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo(id, Some(name), update, value, isInternal, countFailedValues, + Some(SQLMetrics.ACCUM_IDENTIFIER)) + } def reset(): Unit = { this.value = param.zero @@ -73,6 +81,14 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr // Although there is a boxing here, it's fine because it's only called in SQLListener override def value: Long = _value + + // Needed for SQLListenerSuite + override def equals(other: Any): Boolean = { + other match { + case o: LongSQLMetricValue => value == o.value + case _ => false + } + } } /** @@ -104,21 +120,62 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } +private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) + +private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + +private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.msDurationToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + private[sql] object SQLMetrics { + // Identifier for distinguishing SQL metrics from other accumulators + private[sql] val ACCUM_IDENTIFIER = "sql" + private def createLongMetric( sc: SparkContext, name: String, - stringValue: Seq[Long] => String, - initialValue: Long): LongSQLMetric = { - val param = new LongSQLMetricParam(stringValue, initialValue) + param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) + // This is an internal accumulator so we need to register it explicitly. + Accumulators.register(acc) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, _.sum.toString, 0L) + createLongMetric(sc, name, LongSQLMetricParam) } /** @@ -126,31 +183,34 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { - val stringValue = (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) + createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam) + } + + def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = { + // The final result of this metric in physical operator UI may looks like: + // duration(min, med, max): + // 5s (800ms, 1s, 2s) + createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam) + } + + def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { + val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) + val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam) + val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam) + val metricParam = metricParamName match { + case `longSQLMetricParam` => LongSQLMetricParam + case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam + case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam + } + metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. */ - val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) + val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala new file mode 100644 index 0000000000000..d2ab18ef0e189 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.language.existentials + +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.ObjectType + +/** + * Takes the input row from child and turns it into object using the given deserializer expression. + * The output of this operator is a single-field safe row containing the deserialized object. + */ +case class DeserializeToObject( + deserializer: Alias, + child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = ExpressionCanonicalizer.execute( + BindReferences.bindReference(deserializer, child.output)) + ctx.currentVars = input + val resultVars = bound.gen(ctx) :: Nil + consume(ctx, resultVars) + } + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output) + iter.map(projection) + } + } +} + +/** + * Takes the input object from child and turns in into unsafe row using the given serializer + * expression. The output of its child must be a single-field row containing the input object. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = serializer.map { expr => + ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output)) + } + ctx.currentVars = input + val resultVars = bound.map(_.gen(ctx)) + consume(ctx, resultVars) + } + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val projection = UnsafeProjection.create(serializer) + iter.map(projection) + } + } +} + +/** + * Helper functions for physical operators that work with user defined objects. + */ +trait ObjectOperator extends SparkPlan { + def generateToObject(objExpr: Expression, inputSchema: Seq[Attribute]): InternalRow => Any = { + val objectProjection = GenerateSafeProjection.generate(objExpr :: Nil, inputSchema) + (i: InternalRow) => objectProjection(i).get(0, objExpr.dataType) + } + + def generateToRow(serializer: Seq[Expression]): Any => InternalRow = { + val outputProjection = if (serializer.head.dataType.isInstanceOf[ObjectType]) { + GenerateSafeProjection.generate(serializer) + } else { + GenerateUnsafeProjection.generate(serializer) + } + val inputType = serializer.head.collect { case b: BoundReference => b.dataType }.head + val outputRow = new SpecificMutableRow(inputType :: Nil) + (o: Any) => { + outputRow(0) = o + outputProjection(outputRow) + } + } +} + +/** + * Applies the given function to each input row and encodes the result. + */ +case class MapPartitions( + func: Iterator[Any] => Iterator[Any], + deserializer: Expression, + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val getObject = generateToObject(deserializer, child.output) + val outputObject = generateToRow(serializer) + func(iter.map(getObject)).map(outputObject) + } + } +} + +/** + * Applies the given function to each input row and encodes the result. + * + * Note that, each serializer expression needs the result object which is returned by the given + * function, as input. This operator uses some tricks to make sure we only calculate the result + * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with + * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of + * a project while explain. + */ +case class MapElements( + func: AnyRef, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val (funcClass, methodName) = func match { + case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" + case _ => classOf[Any => Any] -> "apply" + } + val funcObj = Literal.create(func, ObjectType(funcClass)) + val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType + val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer)) + + val bound = ExpressionCanonicalizer.execute( + BindReferences.bindReference(callFunc, child.output)) + ctx.currentVars = input + val evaluated = bound.gen(ctx) + + val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType) + val outputFields = serializer.map(_ transform { + case _: BoundReference => resultObj + }) + val resultVars = outputFields.map(_.gen(ctx)) + s""" + ${evaluated.code} + ${consume(ctx, resultVars)} + """ + } + + override protected def doExecute(): RDD[InternalRow] = { + val callFunc: Any => Any = func match { + case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i) + case _ => func.asInstanceOf[Any => Any] + } + child.execute().mapPartitionsInternal { iter => + val getObject = generateToObject(deserializer, child.output) + val outputObject = generateToRow(serializer) + iter.map(row => outputObject(callFunc(getObject(row)))) + } + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + +/** + * Applies the given function to each input row, appending the encoded result at the end of the row. + */ +case class AppendColumns( + func: Any => Any, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator { + + override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute) + + private def newColumnSchema = serializer.map(_.toAttribute).toStructType + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val getObject = generateToObject(deserializer, child.output) + val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) + val outputObject = generateToRow(serializer) + + iter.map { row => + val newColumns = outputObject(func(getObject(row))) + + // This operates on the assumption that we always serialize the result... + combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow + } + } + } +} + +/** + * Groups the input rows together and calls the function with each group and an iterator containing + * all elements in the group. The result of this function is encoded and flattened before + * being output. + */ +case class MapGroups( + func: (Any, Iterator[Any]) => TraversableOnce[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + serializer: Seq[NamedExpression], + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: SparkPlan) extends UnaryNode with ObjectOperator { + + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + + val getKey = generateToObject(keyDeserializer, groupingAttributes) + val getValue = generateToObject(valueDeserializer, dataAttributes) + val outputObject = generateToRow(serializer) + + grouped.flatMap { case (key, rowIter) => + val result = func( + getKey(key), + rowIter.map(getValue)) + result.map(outputObject) + } + } + } +} + +/** + * Co-groups the data from left and right children, and calls the function with each group and 2 + * iterators containing all elements in the group from left and right side. + * The result of this function is encoded and flattened before being output. + */ +case class CoGroup( + func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], + keyDeserializer: Expression, + leftDeserializer: Expression, + rightDeserializer: Expression, + serializer: Seq[NamedExpression], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with ObjectOperator { + + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + + override protected def doExecute(): RDD[InternalRow] = { + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) + val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) + + val getKey = generateToObject(keyDeserializer, leftGroup) + val getLeft = generateToObject(leftDeserializer, leftAttr) + val getRight = generateToObject(rightDeserializer, rightAttr) + val outputObject = generateToRow(serializer) + + new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { + case (key, leftResult, rightResult) => + val result = func( + getKey(key), + leftResult.map(getLeft), + rightResult.map(getRight)) + result.map(outputObject) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 28fa231e722d0..c912734bba9e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -19,5 +19,7 @@ package org.apache.spark.sql /** * The physical execution component of Spark SQL. Note that this is a private package. + * All classes in catalyst are considered an internal API to Spark SQL and are subject + * to change between minor releases. */ package object execution diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala deleted file mode 100644 index d611b0011da16..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ /dev/null @@ -1,411 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution - -import java.io.OutputStream -import java.util.{List => JList, Map => JMap} - -import scala.collection.JavaConverters._ - -import net.razorvine.pickle._ - -import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} -import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData, ArrayData} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. - */ -private[spark] case class PythonUDF( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType, - children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging { - - override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - - override def nullable: Boolean = true -} - -/** - * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated - * alone in a batch. - * - * This has the limitation that the input to the Python UDF is not allowed include attributes from - * multiple child operators. - */ -private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Skip EvaluatePython nodes. - case plan: EvaluatePython => plan - - case plan: LogicalPlan if plan.resolved => - // Extract any PythonUDFs from the current operator. - val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) - if (udfs.isEmpty) { - // If there aren't any, we are done. - plan - } else { - // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) - // If there is more than one, we will add another evaluation operator in a subsequent pass. - udfs.find(_.resolved) match { - case Some(udf) => - var evaluation: EvaluatePython = null - - // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartesian - // product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluatePython(udf, child) - evaluation - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - child - } - } - - assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") - - // Trim away the new UDF value if it was only used for filtering or something. - logical.Project( - plan.output, - plan.transformExpressions { - case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute - }.withNewChildren(newChildren)) - - case None => - // If there is no Python UDF that is resolved, skip this round. - plan - } - } - } -} - -object EvaluatePython { - def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = - new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) - - def takeAndServe(df: DataFrame, n: Int): Int = { - registerPicklers() - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) - PythonRDD.serveIterator(iter, s"serve-DataFrame") - } - - /** - * Helper for converting from Catalyst type to java type suitable for Pyrolite. - */ - def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (row: InternalRow, struct: StructType) => - val values = new Array[Any](row.numFields) - var i = 0 - while (i < row.numFields) { - values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) - i += 1 - } - new GenericInternalRowWithSchema(values, struct) - - case (a: ArrayData, array: ArrayType) => - val values = new java.util.ArrayList[Any](a.numElements()) - a.foreach(array.elementType, (_, e) => { - values.add(toJava(e, array.elementType)) - }) - values - - case (map: MapData, mt: MapType) => - val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(mt.keyType, mt.valueType, (k, v) => { - jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType)) - }) - jmap - - case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) - - case (d: Decimal, _) => d.toJavaBigDecimal - - case (s: UTF8String, StringType) => s.toString - - case (other, _) => other - } - - /** - * Converts `obj` to the type specified by the data type, or returns null if the type of obj is - * unexpected. Because Python doesn't enforce the type. - */ - def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (c: Boolean, BooleanType) => c - - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte - - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort - - case (c: Int, IntegerType) => c - case (c: Long, IntegerType) => c.toInt - - case (c: Int, LongType) => c.toLong - case (c: Long, LongType) => c - - case (c: Double, FloatType) => c.toFloat - - case (c: Double, DoubleType) => c - - case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) - - case (c: Int, DateType) => c - - case (c: Long, TimestampType) => c - - case (c: String, StringType) => UTF8String.fromString(c) - case (c, StringType) => - // If we get here, c is not a string. Call toString on it. - UTF8String.fromString(c.toString) - - case (c: String, BinaryType) => c.getBytes("utf-8") - case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c - - case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) - - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) - - case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val keyValues = c.asScala.toSeq - val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray - val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray - ArrayBasedMapData(keys, values) - - case (c, StructType(fields)) if c.getClass.isArray => - new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map { - case (e, f) => fromJava(e, f.dataType) - }) - - case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) - - // all other unexpected type should be null, or we will have runtime exception - // TODO(davies): we could improve this by try to cast the object to expected type - case (c, _) => null - } - - - private val module = "pyspark.sql.types" - - /** - * Pickler for StructType - */ - private class StructTypePickler extends IObjectPickler { - - private val cls = classOf[StructType] - - def register(): Unit = { - Pickler.registerCustomPickler(cls, this) - } - - def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { - out.write(Opcodes.GLOBAL) - out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8")) - val schema = obj.asInstanceOf[StructType] - pickler.save(schema.json) - out.write(Opcodes.TUPLE1) - out.write(Opcodes.REDUCE) - } - } - - /** - * Pickler for InternalRow - */ - private class RowPickler extends IObjectPickler { - - private val cls = classOf[GenericInternalRowWithSchema] - - // register this to Pickler and Unpickler - def register(): Unit = { - Pickler.registerCustomPickler(this.getClass, this) - Pickler.registerCustomPickler(cls, this) - } - - def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { - if (obj == this) { - out.write(Opcodes.GLOBAL) - out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8")) - } else { - // it will be memorized by Pickler to save some bytes - pickler.save(this) - val row = obj.asInstanceOf[GenericInternalRowWithSchema] - // schema should always be same object for memoization - pickler.save(row.schema) - out.write(Opcodes.TUPLE1) - out.write(Opcodes.REDUCE) - - out.write(Opcodes.MARK) - var i = 0 - while (i < row.values.size) { - pickler.save(row.values(i)) - i += 1 - } - out.write(Opcodes.TUPLE) - out.write(Opcodes.REDUCE) - } - } - } - - private[this] var registered = false - /** - * This should be called before trying to serialize any above classes un cluster mode, - * this should be put in the closure - */ - def registerPicklers(): Unit = { - synchronized { - if (!registered) { - SerDeUtil.initialize() - new StructTypePickler().register() - new RowPickler().register() - registered = true - } - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. - */ - def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { - rdd.mapPartitions { iter => - registerPicklers() // let it called in executor - new SerDeUtil.AutoBatchedPickler(iter) - } - } -} - -/** - * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. - */ -case class EvaluatePython( - udf: PythonUDF, - child: LogicalPlan, - resultAttribute: AttributeReference) - extends logical.UnaryNode { - - def output: Seq[Attribute] = child.output :+ resultAttribute - - // References should not include the produced attribute. - override def references: AttributeSet = udf.references -} - -/** - * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * - * Python evaluation works by sending the necessary (projected) input data via a socket to an - * external Python process, and combine the result from the Python process with the original row. - * - * For each row we send to Python, we also put it in a queue. For each output row from Python, - * we drain the queue to find the original input row. Note that if the Python process is way too - * slow, this could lead to the queue growing unbounded and eventually run out of memory. - */ -case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) - extends SparkPlan { - - def children: Seq[SparkPlan] = child :: Nil - - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - - protected override def doExecute(): RDD[InternalRow] = { - val inputRDD = child.execute().map(_.copy()) - val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) - val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - - inputRDD.mapPartitions { iter => - EvaluatePython.registerPicklers() // register pickler for Row - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. - val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() - - val pickle = new Pickler - val currentRow = newMutableProjection(udf.children, child.output)() - val fields = udf.children.map(_.dataType) - val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - - // Input iterator to Python: input rows are grouped so we send them in batches to Python. - // For each row, add it to the queue. - val inputIterator = iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - queue.add(row) - EvaluatePython.toJava(currentRow(row), schema) - }.toArray - pickle.dumps(toBePickled) - } - - val context = TaskContext.get() - - // Output iterator for results from Python. - val outputIterator = new PythonRunner( - udf.command, - udf.envVars, - udf.pythonIncludes, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator, - bufferSize, - reuseWorker - ).compute(inputIterator, context.partitionId(), context) - - val unpickle = new Unpickler - val row = new GenericMutableRow(1) - val joined = new JoinedRow - - outputIterator.flatMap { pickedResult => - val unpickledBatch = unpickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - }.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - joined(queue.poll(), row) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala new file mode 100644 index 0000000000000..c9ab40a0a9abf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -0,0 +1,149 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import net.razorvine.pickle.{Pickler, Unpickler} + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{DataType, StructField, StructType} + + +/** + * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. + */ +case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + + def children: Seq[SparkPlan] = child :: Nil + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + + inputRDD.mapPartitions { iter => + EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + + val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + + // flatten all the arguments + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + val projection = newMutableProjection(allInputs, child.output)() + val schema = StructType(dataTypes.map(dt => StructField("", dt))) + val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) + + // enable memo iff we serialize the row with schema (schema and class should be memorized) + val pickle = new Pickler(needConversion) + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => + val toBePickled = inputRows.map { inputRow => + queue.add(inputRow) + val row = projection(inputRow) + if (needConversion) { + EvaluatePython.toJava(row, schema) + } else { + // fast path for these types that does not need conversion in Python + val fields = new Array[Any](row.numFields) + var i = 0 + while (i < row.numFields) { + val dt = dataTypes(i) + fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) + i += 1 + } + fields + } + }.toArray + pickle.dumps(toBePickled) + } + + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets) + .compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler + val mutableRow = new GenericMutableRow(1) + val joined = new JoinedRow + val resultType = if (udfs.length == 1) { + udfs.head.dataType + } else { + StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) + } + val resultProj = UnsafeProjection.create(output, output) + + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + val row = if (udfs.length == 1) { + // fast path for single UDF + mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow + } else { + EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + } + resultProj(joined(queue.poll(), row)) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala new file mode 100644 index 0000000000000..3b05e29e52bd3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -0,0 +1,255 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import java.io.OutputStream +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} + +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object EvaluatePython { + def takeAndServe(df: DataFrame, n: Int): Int = { + registerPicklers() + df.withNewExecutionId { + val iter = new SerDeUtil.AutoBatchedPickler( + df.queryExecution.executedPlan.executeTake(n).iterator.map { row => + EvaluatePython.toJava(row, df.schema) + }) + PythonRDD.serveIterator(iter, s"serve-DataFrame") + } + } + + def needConversionInPython(dt: DataType): Boolean = dt match { + case DateType | TimestampType => true + case _: StructType => true + case _: UserDefinedType[_] => true + case ArrayType(elementType, _) => needConversionInPython(elementType) + case MapType(keyType, valueType, _) => + needConversionInPython(keyType) || needConversionInPython(valueType) + case _ => false + } + + /** + * Helper for converting from Catalyst type to java type suitable for Pyrolite. + */ + def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + case (null, _) => null + + case (row: InternalRow, struct: StructType) => + val values = new Array[Any](row.numFields) + var i = 0 + while (i < row.numFields) { + values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) + i += 1 + } + new GenericRowWithSchema(values, struct) + + case (a: ArrayData, array: ArrayType) => + val values = new java.util.ArrayList[Any](a.numElements()) + a.foreach(array.elementType, (_, e) => { + values.add(toJava(e, array.elementType)) + }) + values + + case (map: MapData, mt: MapType) => + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + map.foreach(mt.keyType, mt.valueType, (k, v) => { + jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType)) + }) + jmap + + case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType) + + case (d: Decimal, _) => d.toJavaBigDecimal + + case (s: UTF8String, StringType) => s.toString + + case (other, _) => other + } + + /** + * Converts `obj` to the type specified by the data type, or returns null if the type of obj is + * unexpected. Because Python doesn't enforce the type. + */ + def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + case (null, _) => null + + case (c: Boolean, BooleanType) => c + + case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte + + case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + + case (c: Int, IntegerType) => c + case (c: Long, IntegerType) => c.toInt + + case (c: Int, LongType) => c.toLong + case (c: Long, LongType) => c + + case (c: Double, FloatType) => c.toFloat + + case (c: Double, DoubleType) => c + + case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) + + case (c: Int, DateType) => c + + case (c: Long, TimestampType) => c + + case (c, StringType) => UTF8String.fromString(c.toString) + + case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8) + case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + + case (c: java.util.List[_], ArrayType(elementType, _)) => + new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) + + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => + val keyValues = c.asScala.toSeq + val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray + val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray + ArrayBasedMapData(keys, values) + + case (c, StructType(fields)) if c.getClass.isArray => + val array = c.asInstanceOf[Array[_]] + if (array.length != fields.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${fields.length} fields are required while ${array.length} values are provided." + ) + } + new GenericInternalRow(array.zip(fields).map { + case (e, f) => fromJava(e, f.dataType) + }) + + case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) + + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + case (c, _) => null + } + + private val module = "pyspark.sql.types" + + /** + * Pickler for StructType + */ + private class StructTypePickler extends IObjectPickler { + + private val cls = classOf[StructType] + + def register(): Unit = { + Pickler.registerCustomPickler(cls, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + out.write(Opcodes.GLOBAL) + out.write( + (module + "\n" + "_parse_datatype_json_string" + "\n").getBytes(StandardCharsets.UTF_8)) + val schema = obj.asInstanceOf[StructType] + pickler.save(schema.json) + out.write(Opcodes.TUPLE1) + out.write(Opcodes.REDUCE) + } + } + + /** + * Pickler for external row. + */ + private class RowPickler extends IObjectPickler { + + private val cls = classOf[GenericRowWithSchema] + + // register this to Pickler and Unpickler + def register(): Unit = { + Pickler.registerCustomPickler(this.getClass, this) + Pickler.registerCustomPickler(cls, this) + } + + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + if (obj == this) { + out.write(Opcodes.GLOBAL) + out.write( + (module + "\n" + "_create_row_inbound_converter" + "\n").getBytes(StandardCharsets.UTF_8)) + } else { + // it will be memorized by Pickler to save some bytes + pickler.save(this) + val row = obj.asInstanceOf[GenericRowWithSchema] + // schema should always be same object for memoization + pickler.save(row.schema) + out.write(Opcodes.TUPLE1) + out.write(Opcodes.REDUCE) + + out.write(Opcodes.MARK) + var i = 0 + while (i < row.values.length) { + pickler.save(row.values(i)) + i += 1 + } + out.write(Opcodes.TUPLE) + out.write(Opcodes.REDUCE) + } + } + } + + private[this] var registered = false + + /** + * This should be called before trying to serialize any above classes un cluster mode, + * this should be put in the closure + */ + def registerPicklers(): Unit = { + synchronized { + if (!registered) { + SerDeUtil.initialize() + new StructTypePickler().register() + new RowPickler().register() + registered = true + } + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { + rdd.mapPartitions { iter => + registerPicklers() // let it called in executor + new SerDeUtil.AutoBatchedPickler(iter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala new file mode 100644 index 0000000000000..d72b3d347d0f6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -0,0 +1,114 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.SparkPlan + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * Only extracts the PythonUDFs that could be evaluated in Python (the single child is PythonUDFs + * or all the children could be evaluated in JVM). + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] { + + private def hasPythonUDF(e: Expression): Boolean = { + e.find(_.isInstanceOf[PythonUDF]).isDefined + } + + private def canEvaluateInPython(e: PythonUDF): Boolean = { + e.children match { + // single PythonUDF child could be chained and evaluated in Python + case Seq(u: PythonUDF) => canEvaluateInPython(u) + // Python UDF can't be evaluated directly in JVM + case children => !children.exists(hasPythonUDF) + } + } + + private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf) + case e => e.children.flatMap(collectEvaluatableUDF) + } + + def apply(plan: SparkPlan): SparkPlan = plan transformUp { + case plan: SparkPlan => extract(plan) + } + + /** + * Extract all the PythonUDFs from the current operator. + */ + def extract(plan: SparkPlan): SparkPlan = { + val udfs = plan.expressions.flatMap(collectEvaluatableUDF) + if (udfs.isEmpty) { + // If there aren't any, we are done. + plan + } else { + val attributeMap = mutable.HashMap[PythonUDF, Expression]() + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Pick the UDF we are going to evaluate + val validUdfs = udfs.filter { case udf => + // Check to make sure that the UDF can be evaluated with only the input of this child. + udf.references.subsetOf(child.outputSet) + } + if (validUdfs.nonEmpty) { + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => + AttributeReference(s"pythonUDF$i", u.dataType)() + } + val evaluation = BatchPythonEvaluation(validUdfs, child.output ++ resultAttrs, child) + attributeMap ++= validUdfs.zip(resultAttrs) + evaluation + } else { + child + } + } + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + udfs.filterNot(attributeMap.contains).foreach { udf => + if (udf.references.subsetOf(plan.inputSet)) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.") + } + } + + val rewritten = plan.transformExpressions { + case p: PythonUDF if attributeMap.contains(p) => + attributeMap(p) + }.withNewChildren(newChildren) + + // extract remaining python UDFs recursively + val newPlan = extract(rewritten) + if (newPlan.output != plan.output) { + // Trim away the new UDF value if it was only used for filtering or something. + execution.Project(plan.output, newPlan) + } else { + newPlan + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala new file mode 100644 index 0000000000000..59d7e8dd6dffb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python.PythonFunction +import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} +import org.apache.spark.sql.types.DataType + +/** + * A serialized version of a Python lambda function. + */ +case class PythonUDF( + name: String, + func: PythonFunction, + dataType: DataType, + children: Seq[Expression]) + extends Expression with Unevaluable with NonSQLExpression { + + override def toString: String = s"$name(${children.mkString(", ")})" + + override def nullable: Boolean = true +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala new file mode 100644 index 0000000000000..d301874c223d1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python.PythonFunction +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.Column +import org.apache.spark.sql.types.DataType + +/** + * A user-defined Python function. This is used by the Python API. + */ +case class UserDefinedPythonFunction( + name: String, + func: PythonFunction, + dataType: DataType) { + + def builder(e: Seq[Expression]): PythonUDF = { + PythonUDF(name, func, dataType, e) + } + + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ + def apply(exprs: Column*): Column = { + val udf = builder(exprs.map(_.expr)) + Column(udf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala deleted file mode 100644 index 0e601cd2cab5d..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * Converts Java-object-based rows into [[UnsafeRow]]s. - */ -case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { - - require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") - - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToUnsafe = UnsafeProjection.create(child.schema) - iter.map(convertToUnsafe) - } - } -} - -/** - * Converts [[UnsafeRow]]s back into Java-object-based rows. - */ -case class ConvertToSafe(child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType)) - iter.map(convertToSafe) - } - } -} - -private[sql] object EnsureRowFormats extends Rule[SparkPlan] { - - private def onlyHandlesSafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && !operator.canProcessUnsafeRows - - private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessUnsafeRows && !operator.canProcessSafeRows - - private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && operator.canProcessUnsafeRows - - override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { - case operator: SparkPlan if onlyHandlesSafeRows(operator) => - if (operator.children.exists(_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if onlyHandlesUnsafeRows(operator) => - if (operator.children.exists(!_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => - if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, - // convert everything unsafe rows if all the schema of them are support by UnsafeRow - if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } - } - } - } else { - operator - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala deleted file mode 100644 index 1a3832a698b61..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// This file defines various sort operators. -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/** - * Performs a sort, spilling to disk as needed. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( - TaskContext.get(), ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r.copy(), null))) - val baseIterator = sorter.iterator.map(_._1) - val context = TaskContext.get() - context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) - // TODO(marmbrus): The complex type signature below thwarts inference for no reason. - CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * Optimized version of [[Sort]] that operates on binary data (implemented as part of - * Project Tungsten). - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will - * spill every `frequency` records. - */ - -case class TungstenSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends UnaryNode { - - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - override private[sql] lazy val metrics = Map( - "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), - "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - - protected override def doExecute(): RDD[InternalRow] = { - val schema = child.schema - val childOutput = child.output - - val dataSize = longMetric("dataSize") - val spillSize = longMetric("spillSize") - - child.execute().mapPartitions { iter => - val ordering = newOrdering(sortOrder, childOutput) - - // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - // The generator for prefix - val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = { - prefixProjection.apply(row).getLong(0) - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( - schema, ordering, prefixComparator, prefixComputer, pageSize) - if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } - - // Remember spill data size of this task before execute this operator so that we can - // figure out how many bytes we spilled for this operator. - val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled - - val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) - - dataSize += sorter.getPeakMemoryUsage - spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore - - TaskContext.get().internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) - sortedIterator - } - } - -} - -object TungstenSort { - /** - * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. - */ - def supportsSchema(schema: StructType): Boolean = { - UnsafeExternalRowSorter.supportsSchema(schema) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index db463029aedf7..8c2231335c789 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.stat import scala.collection.mutable.{Map => MutableMap} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, Column, DataFrame} private[sql] object FrequentItems extends Logging { @@ -121,6 +121,6 @@ private[sql] object FrequentItems extends Logging { StructField(v._1 + "_freqItems", ArrayType(v._2, false)) } val schema = StructType(outputCols).toAttributes - new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) + Dataset.ofRows(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 00231d65a7d54..d603f63a08501 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.stat -import org.apache.spark.Logging -import org.apache.spark.sql.{Row, Column, DataFrame} -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.expressions.{Cast, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -27,9 +29,316 @@ import org.apache.spark.unsafe.types.UTF8String private[sql] object StatFunctions extends Logging { + import QuantileSummaries.Stats + + /** + * Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass. + * + * The result of this algorithm has the following deterministic bound: + * If the DataFrame has N elements and if we request the quantile at probability `p` up to error + * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank + * of `x` is close to (p * N). + * More precisely, + * + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). + * + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). + * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient + * Online Computation of Quantile Summaries]] by Greenwald and Khanna. + * + * @param df the dataframe + * @param cols numerical columns of the dataframe + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. + * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (>= 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. + * + * @return for each column, returns the requested approximations + */ + def multipleApproxQuantiles( + df: DataFrame, + cols: Seq[String], + probabilities: Seq[Double], + relativeError: Double): Seq[Seq[Double]] = { + val columns: Seq[Column] = cols.map { colName => + val field = df.schema(colName) + require(field.dataType.isInstanceOf[NumericType], + s"Quantile calculation for column $colName with data type ${field.dataType}" + + " is not supported.") + Column(Cast(Column(colName).expr, DoubleType)) + } + val emptySummaries = Array.fill(cols.size)( + new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError)) + + // Note that it works more or less by accident as `rdd.aggregate` is not a pure function: + // this function returns the same array as given in the input (because `aggregate` reuses + // the same argument). + def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = { + var i = 0 + while (i < summaries.length) { + summaries(i) = summaries(i).insert(row.getDouble(i)) + i += 1 + } + summaries + } + + def merge( + sum1: Array[QuantileSummaries], + sum2: Array[QuantileSummaries]): Array[QuantileSummaries] = { + sum1.zip(sum2).map { case (s1, s2) => s1.compress().merge(s2.compress()) } + } + val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge) + + summaries.map { summary => probabilities.map(summary.query) } + } + + /** + * Helper class to compute approximate quantile summary. + * This implementation is based on the algorithm proposed in the paper: + * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael + * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670) + * + * In order to optimize for speed, it maintains an internal buffer of the last seen samples, + * and only inserts them after crossing a certain size threshold. This guarantees a near-constant + * runtime complexity compared to the original algorithm. + * + * @param compressThreshold the compression threshold. + * After the internal buffer of statistics crosses this size, it attempts to compress the + * statistics together. + * @param relativeError the target relative error. + * It is uniform across the complete range of values. + * @param sampled a buffer of quantile statistics. + * See the G-K article for more details. + * @param count the count of all the elements *inserted in the sampled buffer* + * (excluding the head buffer) + * @param headSampled a buffer of latest samples seen so far + */ + class QuantileSummaries( + val compressThreshold: Int, + val relativeError: Double, + val sampled: ArrayBuffer[Stats] = ArrayBuffer.empty, + private[stat] var count: Long = 0L, + val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty) extends Serializable { + + import QuantileSummaries._ + + /** + * Returns a summary with the given observation inserted into the summary. + * This method may either modify in place the current summary (and return the same summary, + * modified in place), or it may create a new summary from scratch it necessary. + * @param x the new observation to insert into the summary + */ + def insert(x: Double): QuantileSummaries = { + headSampled.append(x) + if (headSampled.size >= defaultHeadSize) { + this.withHeadBufferInserted + } else { + this + } + } + + /** + * Inserts an array of (unsorted samples) in a batch, sorting the array first to traverse + * the summary statistics in a single batch. + * + * This method does not modify the current object and returns if necessary a new copy. + * + * @return a new quantile summary object. + */ + private def withHeadBufferInserted: QuantileSummaries = { + if (headSampled.isEmpty) { + return this + } + var currentCount = count + val sorted = headSampled.toArray.sorted + val newSamples: ArrayBuffer[Stats] = new ArrayBuffer[Stats]() + // The index of the next element to insert + var sampleIdx = 0 + // The index of the sample currently being inserted. + var opsIdx: Int = 0 + while(opsIdx < sorted.length) { + val currentSample = sorted(opsIdx) + // Add all the samples before the next observation. + while(sampleIdx < sampled.size && sampled(sampleIdx).value <= currentSample) { + newSamples.append(sampled(sampleIdx)) + sampleIdx += 1 + } + + // If it is the first one to insert, of if it is the last one + currentCount += 1 + val delta = + if (newSamples.isEmpty || (sampleIdx == sampled.size && opsIdx == sorted.length - 1)) { + 0 + } else { + math.floor(2 * relativeError * currentCount).toInt + } + + val tuple = Stats(currentSample, 1, delta) + newSamples.append(tuple) + opsIdx += 1 + } + + // Add all the remaining existing samples + while(sampleIdx < sampled.size) { + newSamples.append(sampled(sampleIdx)) + sampleIdx += 1 + } + new QuantileSummaries(compressThreshold, relativeError, newSamples, currentCount) + } + + /** + * Returns a new summary that compresses the summary statistics and the head buffer. + * + * This implements the COMPRESS function of the GK algorithm. It does not modify the object. + * + * @return a new summary object with compressed statistics + */ + def compress(): QuantileSummaries = { + // Inserts all the elements first + val inserted = this.withHeadBufferInserted + assert(inserted.headSampled.isEmpty) + assert(inserted.count == count + headSampled.size) + val compressed = + compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + } + + private def shallowCopy: QuantileSummaries = { + new QuantileSummaries(compressThreshold, relativeError, sampled, count, headSampled) + } + + /** + * Merges two (compressed) summaries together. + * + * Returns a new summary. + */ + def merge(other: QuantileSummaries): QuantileSummaries = { + require(headSampled.isEmpty, "Current buffer needs to be compressed before merge") + require(other.headSampled.isEmpty, "Other buffer needs to be compressed before merge") + if (other.count == 0) { + this.shallowCopy + } else if (count == 0) { + other.shallowCopy + } else { + // Merge the two buffers. + // The GK algorithm is a bit unclear about it, but it seems there is no need to adjust the + // statistics during the merging: the invariants are still respected after the merge. + // TODO: could replace full sort by ordered merge, the two lists are known to be sorted + // already. + val res = (sampled ++ other.sampled).sortBy(_.value) + val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) + new QuantileSummaries( + other.compressThreshold, other.relativeError, comp, other.count + count) + } + } + + /** + * Runs a query for a given quantile. + * The result follows the approximation guarantees detailed above. + * The query can only be run on a compressed summary: you need to call compress() before using + * it. + * + * @param quantile the target quantile + * @return + */ + def query(quantile: Double): Double = { + require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]") + require(headSampled.isEmpty, + "Cannot operate on an uncompressed summary, call compress() first") + + if (quantile <= relativeError) { + return sampled.head.value + } + + if (quantile >= 1 - relativeError) { + return sampled.last.value + } + + // Target rank + val rank = math.ceil(quantile * count).toInt + val targetError = math.ceil(relativeError * count) + // Minimum rank at current sample + var minRank = 0 + var i = 1 + while (i < sampled.size - 1) { + val curSample = sampled(i) + minRank += curSample.g + val maxRank = minRank + curSample.delta + if (maxRank - targetError <= rank && rank <= minRank + targetError) { + return curSample.value + } + i += 1 + } + sampled.last.value + } + } + + object QuantileSummaries { + // TODO(tjhunter) more tuning could be done one the constants here, but for now + // the main cost of the algorithm is accessing the data in SQL. + /** + * The default value for the compression threshold. + */ + val defaultCompressThreshold: Int = 10000 + + /** + * The size of the head buffer. + */ + val defaultHeadSize: Int = 50000 + + /** + * The default value for the relative error (1%). + * With this value, the best extreme percentiles that can be approximated are 1% and 99%. + */ + val defaultRelativeError: Double = 0.01 + + /** + * Statistics from the Greenwald-Khanna paper. + * @param value the sampled value + * @param g the minimum rank jump from the previous value's minimum rank + * @param delta the maximum span of the rank. + */ + case class Stats(value: Double, g: Int, delta: Int) + + private def compressImmut( + currentSamples: IndexedSeq[Stats], + mergeThreshold: Double): ArrayBuffer[Stats] = { + val res: ArrayBuffer[Stats] = ArrayBuffer.empty + if (currentSamples.isEmpty) { + return res + } + // Start for the last element, which is always part of the set. + // The head contains the current new head, that may be merged with the current element. + var head = currentSamples.last + var i = currentSamples.size - 2 + // Do not compress the last element + while (i >= 1) { + // The current sample: + val sample1 = currentSamples(i) + // Do we need to compress? + if (sample1.g + head.g + head.delta < mergeThreshold) { + // Do not insert yet, just merge the current element into the head. + head = head.copy(g = head.g + sample1.g) + } else { + // Prepend the current head, and keep the current sample as target for merging. + res.prepend(head) + head = sample1 + } + i -= 1 + } + res.prepend(head) + // If necessary, add the minimum element: + res.prepend(currentSamples.head) + res + } + } + /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { - val counts = collectStatisticalData(df, cols) + val counts = collectStatisticalData(df, cols, "correlation") counts.Ck / math.sqrt(counts.MkX * counts.MkY) } @@ -73,13 +382,14 @@ private[sql] object StatFunctions extends Logging { def cov: Double = Ck / (count - 1) } - private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = { - require(cols.length == 2, "Currently cov supports calculating the covariance " + + private def collectStatisticalData(df: DataFrame, cols: Seq[String], + functionName: String): CovarianceCounter = { + require(cols.length == 2, s"Currently $functionName calculation is supported " + "between two columns.") cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") - require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " + - s"with dataType ${data.get.dataType} not supported.") + require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + + s"for columns with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( @@ -98,7 +408,7 @@ private[sql] object StatFunctions extends Logging { * @return the covariance of the two columns. */ private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { - val counts = collectStatisticalData(df, cols) + val counts = collectStatisticalData(df, cols, "covariance") counts.cov } @@ -144,6 +454,6 @@ private[sql] object StatFunctions extends Logging { } val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + Dataset.ofRows(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala new file mode 100644 index 0000000000000..1f25eb8fc5223 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Batch.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.DataFrame + +/** + * Used to pass a batch of data through a streaming query execution along with an indication + * of progress in the stream. + */ +class Batch(val end: Offset, val data: DataFrame) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala new file mode 100644 index 0000000000000..729c8462fed65 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * An ordered collection of offsets, used to track the progress of processing data from one or more + * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance + * vector clock that must progress linearly forward. + */ +case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { + /** + * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, + * or greater than the specified object. + */ + override def compareTo(other: Offset): Int = other match { + case otherComposite: CompositeOffset if otherComposite.offsets.size == offsets.size => + val comparisons = offsets.zip(otherComposite.offsets).map { + case (Some(a), Some(b)) => a compareTo b + case (None, None) => 0 + case (None, _) => -1 + case (_, None) => 1 + } + val nonZeroSigns = comparisons.map(sign).filter(_ != 0).toSet + nonZeroSigns.size match { + case 0 => 0 // if both empty or only 0s + case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s) + case _ => // there are both 1s and -1s + throw new IllegalArgumentException( + s"Invalid comparison between non-linear histories: $this <=> $other") + } + case _ => + throw new IllegalArgumentException(s"Cannot compare $this <=> $other") + } + + private def sign(num: Int): Int = num match { + case i if i < 0 => -1 + case i if i == 0 => 0 + case i if i > 0 => 1 + } + + /** + * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of + * sources. + * + * This method is typically used to associate a serialized offset with actual sources (which + * cannot be serialized). + */ + def toStreamProgress(sources: Seq[Source]): StreamProgress = { + assert(sources.size == offsets.size) + new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } + } + + override def toString: String = + offsets.map(_.map(_.toString).getOrElse("-")).mkString("[", ", ", "]") +} + +object CompositeOffset { + /** + * Returns a [[CompositeOffset]] with a variable sequence of offsets. + * `nulls` in the sequence are converted to `None`s. + */ + def fill(offsets: Offset*): CompositeOffset = { + CompositeOffset(offsets.map(Option(_))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala new file mode 100644 index 0000000000000..b1d24b6cfc0bd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} +import org.apache.spark.sql.util.ContinuousQueryListener +import org.apache.spark.sql.util.ContinuousQueryListener._ +import org.apache.spark.util.ListenerBus + +/** + * A bus to forward events to [[ContinuousQueryListener]]s. This one will wrap received + * [[ContinuousQueryListener.Event]]s as WrappedContinuousQueryListenerEvents and send them to the + * Spark listener bus. It also registers itself with Spark listener bus, so that it can receive + * WrappedContinuousQueryListenerEvents, unwrap them as ContinuousQueryListener.Events and + * dispatch them to ContinuousQueryListener. + */ +class ContinuousQueryListenerBus(sparkListenerBus: LiveListenerBus) + extends SparkListener with ListenerBus[ContinuousQueryListener, ContinuousQueryListener.Event] { + + sparkListenerBus.addListener(this) + + /** + * Post a ContinuousQueryListener event to the Spark listener bus asynchronously. This event will + * be dispatched to all ContinuousQueryListener in the thread of the Spark listener bus. + */ + def post(event: ContinuousQueryListener.Event) { + event match { + case s: QueryStarted => + postToAll(s) + case _ => + sparkListenerBus.post(new WrappedContinuousQueryListenerEvent(event)) + } + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case WrappedContinuousQueryListenerEvent(e) => + postToAll(e) + case _ => + } + } + + override protected def doPostEvent( + listener: ContinuousQueryListener, + event: ContinuousQueryListener.Event): Unit = { + event match { + case queryStarted: QueryStarted => + listener.onQueryStarted(queryStarted) + case queryProgress: QueryProgress => + listener.onQueryProgress(queryProgress) + case queryTerminated: QueryTerminated => + listener.onQueryTerminated(queryTerminated) + case _ => + } + } + + /** + * Wrapper for StreamingListenerEvent as SparkListenerEvent so that it can be posted to Spark + * listener bus. + */ + private case class WrappedContinuousQueryListenerEvent( + streamingListenerEvent: ContinuousQueryListener.Event) extends SparkListenerEvent { + + // Do not log streaming events in event log as history server does not support these events. + protected[spark] override def logEvent: Boolean = false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala new file mode 100644 index 0000000000000..6921ae584dd84 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.sources.FileFormat + +object FileStreamSink { + // The name of the subdirectory that is used to store metadata about which files are valid. + val metadataDir = "_spark_metadata" +} + +/** + * A sink that writes out results to parquet files. Each batch is written out to a unique + * directory. After all of the files in a batch have been successfully written, the list of + * file paths is appended to the log atomically. In the case of partial failures, some duplicate + * data may be present in the target directory, but only one copy of each file will be present + * in the log. + */ +class FileStreamSink( + sqlContext: SQLContext, + path: String, + fileFormat: FileFormat) extends Sink with Logging { + + private val basePath = new Path(path) + private val logPath = new Path(basePath, FileStreamSink.metadataDir) + private val fileLog = new HDFSMetadataLog[Seq[String]](sqlContext, logPath.toUri.toString) + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + if (fileLog.get(batchId).isDefined) { + logInfo(s"Skipping already committed batch $batchId") + } else { + val files = writeFiles(data) + if (fileLog.add(batchId, files)) { + logInfo(s"Committed batch $batchId") + } else { + logWarning(s"Race while writing batch $batchId") + } + } + } + + /** Writes the [[DataFrame]] to a UUID-named dir, returning the list of files paths. */ + private def writeFiles(data: DataFrame): Seq[String] = { + val ctx = sqlContext + val outputDir = path + val format = fileFormat + val schema = data.schema + + val file = new Path(basePath, UUID.randomUUID().toString).toUri.toString + data.write.parquet(file) + sqlContext.read + .schema(data.schema) + .parquet(file) + .inputFiles + .map(new Path(_)) + .filterNot(_.getName.startsWith("_")) + .map(_.toUri.toString) + } + + override def toString: String = s"FileSink[$path]" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala new file mode 100644 index 0000000000000..1b70055f346b3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.collection.OpenHashSet + +/** + * A very simple source that reads text files from the given directory as they appear. + * + * TODO Clean up the metadata files periodically + */ +class FileStreamSource( + sqlContext: SQLContext, + metadataPath: String, + path: String, + dataSchema: Option[StructType], + providerName: String, + dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging { + + private val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + private val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataPath) + private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) + + private val seenFiles = new OpenHashSet[String] + metadataLog.get(None, Some(maxBatchId)).foreach { case (batchId, files) => + files.foreach(seenFiles.add) + } + + /** Returns the schema of the data from this source */ + override lazy val schema: StructType = { + dataSchema.getOrElse { + val filesPresent = fetchAllFiles() + if (filesPresent.isEmpty) { + if (providerName == "text") { + // Add a default schema for "text" + new StructType().add("value", StringType) + } else { + throw new IllegalArgumentException("No schema specified") + } + } else { + // There are some existing files. Use them to infer the schema. + dataFrameBuilder(filesPresent.toArray).schema + } + } + } + + /** + * Returns the maximum offset that can be retrieved from the source. + * + * `synchronized` on this method is for solving race conditions in tests. In the normal usage, + * there is no race here, so the cost of `synchronized` should be rare. + */ + private def fetchMaxOffset(): LongOffset = synchronized { + val filesPresent = fetchAllFiles() + val newFiles = new ArrayBuffer[String]() + filesPresent.foreach { file => + if (!seenFiles.contains(file)) { + logDebug(s"new file: $file") + newFiles.append(file) + seenFiles.add(file) + } else { + logDebug(s"old file: $file") + } + } + + if (newFiles.nonEmpty) { + maxBatchId += 1 + metadataLog.add(maxBatchId, newFiles) + } + + new LongOffset(maxBatchId) + } + + /** + * For test only. Run `func` with the internal lock to make sure when `func` is running, + * the current offset won't be changed and no new batch will be emitted. + */ + def withBatchingLocked[T](func: => T): T = synchronized { + func + } + + /** Return the latest offset in the source */ + def currentOffset: LongOffset = synchronized { + new LongOffset(maxBatchId) + } + + /** + * Returns the next batch of data that is available after `start`, if any is available. + */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startId = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + val endId = end.asInstanceOf[LongOffset].offset + + assert(startId <= endId) + val files = metadataLog.get(Some(startId + 1), Some(endId)).map(_._2).flatten + logInfo(s"Processing ${files.length} files from ${startId + 1}:$endId") + logDebug(s"Streaming ${files.mkString(", ")}") + dataFrameBuilder(files) + + } + + private def fetchAllFiles(): Seq[String] = { + val startTime = System.nanoTime() + val files = fs.listStatus(new Path(path)) + .filterNot(_.getPath.getName.startsWith("_")) + .map(_.getPath.toUri.toString) + val endTime = System.nanoTime() + logDebug(s"Listed ${files.size} in ${(endTime.toDouble - startTime) / 1000000}ms") + files + } + + override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1) + + override def toString: String = s"FileSource[$path]" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala new file mode 100644 index 0000000000000..9663fee18d364 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -0,0 +1,342 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.streaming + +import java.io.{FileNotFoundException, IOException} +import java.nio.ByteBuffer +import java.util.{ConcurrentModificationException, EnumSet, UUID} + +import scala.reflect.ClassTag + +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.permission.FsPermission + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.SQLContext + + +/** + * A [[MetadataLog]] implementation based on HDFS. [[HDFSMetadataLog]] uses the specified `path` + * as the metadata storage. + * + * When writing a new batch, [[HDFSMetadataLog]] will firstly write to a temp file and then rename + * it to the final batch file. If the rename step fails, there must be multiple writers and only + * one of them will succeed and the others will fail. + * + * Note: [[HDFSMetadataLog]] doesn't support S3-like file systems as they don't guarantee listing + * files in a directory always shows the latest files. + */ +class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) + extends MetadataLog[T] + with Logging { + + import HDFSMetadataLog._ + + private val metadataPath = new Path(path) + private val fileManager = createFileManager() + + if (!fileManager.exists(metadataPath)) { + fileManager.mkdirs(metadataPath) + } + + /** + * A `PathFilter` to filter only batch files + */ + private val batchFilesFilter = new PathFilter { + override def accept(path: Path): Boolean = try { + path.getName.toLong + true + } catch { + case _: NumberFormatException => false + } + } + + private val serializer = new JavaSerializer(sqlContext.sparkContext.conf).newInstance() + + private def batchFile(batchId: Long): Path = { + new Path(metadataPath, batchId.toString) + } + + override def add(batchId: Long, metadata: T): Boolean = { + get(batchId).map(_ => false).getOrElse { + // Only write metadata when the batch has not yet been written. + val buffer = serializer.serialize(metadata) + try { + writeBatch(batchId, JavaUtils.bufferToArray(buffer)) + true + } catch { + case e: IOException if "java.lang.InterruptedException" == e.getMessage => + // create may convert InterruptedException to IOException. Let's convert it back to + // InterruptedException so that this failure won't crash StreamExecution + throw new InterruptedException("Creating file is interrupted") + } + } + } + + /** + * Write a batch to a temp file then rename it to the batch file. + * + * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a + * valid behavior, we still need to prevent it from destroying the files. + */ + private def writeBatch(batchId: Long, bytes: Array[Byte]): Unit = { + // Use nextId to create a temp file + var nextId = 0 + while (true) { + val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp") + try { + val output = fileManager.create(tempPath) + try { + output.write(bytes) + } finally { + output.close() + } + try { + // Try to commit the batch + // It will fail if there is an existing file (someone has committed the batch) + logDebug(s"Attempting to write log #${batchFile(batchId)}") + fileManager.rename(tempPath, batchFile(batchId)) + return + } catch { + case e: IOException if isFileAlreadyExistsException(e) => + // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch. + // So throw an exception to tell the user this is not a valid behavior. + throw new ConcurrentModificationException( + s"Multiple HDFSMetadataLog are using $path", e) + case e: FileNotFoundException => + // Sometimes, "create" will succeed when multiple writers are calling it at the same + // time. However, only one writer can call "rename" successfully, others will get + // FileNotFoundException because the first writer has removed it. + throw new ConcurrentModificationException( + s"Multiple HDFSMetadataLog are using $path", e) + } + } catch { + case e: IOException if isFileAlreadyExistsException(e) => + // Failed to create "tempPath". There are two cases: + // 1. Someone is creating "tempPath" too. + // 2. This is a restart. "tempPath" has already been created but not moved to the final + // batch file (not committed). + // + // For both cases, the batch has not yet been committed. So we can retry it. + // + // Note: there is a potential risk here: if HDFSMetadataLog A is running, people can use + // the same metadata path to create "HDFSMetadataLog" and fail A. However, this is not a + // big problem because it requires the attacker must have the permission to write the + // metadata path. In addition, the old Streaming also have this issue, people can create + // malicious checkpoint files to crash a Streaming application too. + nextId += 1 + } finally { + fileManager.delete(tempPath) + } + } + } + + private def isFileAlreadyExistsException(e: IOException): Boolean = { + e.isInstanceOf[FileAlreadyExistsException] || + // Old Hadoop versions don't throw FileAlreadyExistsException. Although it's fixed in + // HADOOP-9361, we still need to support old Hadoop versions. + (e.getMessage != null && e.getMessage.startsWith("File already exists: ")) + } + + override def get(batchId: Long): Option[T] = { + val batchMetadataFile = batchFile(batchId) + if (fileManager.exists(batchMetadataFile)) { + val input = fileManager.open(batchMetadataFile) + val bytes = IOUtils.toByteArray(input) + Some(serializer.deserialize[T](ByteBuffer.wrap(bytes))) + } else { + logDebug(s"Unable to find batch $batchMetadataFile") + None + } + } + + override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, T)] = { + val files = fileManager.list(metadataPath, batchFilesFilter) + val batchIds = files + .map(_.getPath.getName.toLong) + .filter { batchId => + (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) + } + batchIds.sorted.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { + case (batchId, metadataOption) => + (batchId, metadataOption.get) + } + } + + override def getLatest(): Option[(Long, T)] = { + val batchIds = fileManager.list(metadataPath, batchFilesFilter) + .map(_.getPath.getName.toLong) + .sorted + .reverse + for (batchId <- batchIds) { + val batch = get(batchId) + if (batch.isDefined) { + return Some((batchId, batch.get)) + } + } + None + } + + private def createFileManager(): FileManager = { + val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + try { + new FileContextManager(metadataPath, hadoopConf) + } catch { + case e: UnsupportedFileSystemException => + logWarning("Could not use FileContext API for managing metadata log file. The log may be" + + "inconsistent under failures.", e) + new FileSystemManager(metadataPath, hadoopConf) + } + } +} + +object HDFSMetadataLog { + + /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ + trait FileManager { + + /** List the files in a path that matches a filter. */ + def list(path: Path, filter: PathFilter): Array[FileStatus] + + /** Make directory at the give path and all its parent directories as needed. */ + def mkdirs(path: Path): Unit + + /** Whether path exists */ + def exists(path: Path): Boolean + + /** Open a file for reading, or throw exception if it does not exist. */ + def open(path: Path): FSDataInputStream + + /** Create path, or throw exception if it already exists */ + def create(path: Path): FSDataOutputStream + + /** + * Atomically rename path, or throw exception if it cannot be done. + * Should throw FileNotFoundException if srcPath does not exist. + * Should throw FileAlreadyExistsException if destPath already exists. + */ + def rename(srcPath: Path, destPath: Path): Unit + + /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ + def delete(path: Path): Unit + } + + /** + * Default implementation of FileManager using newer FileContext API. + */ + class FileContextManager(path: Path, hadoopConf: Configuration) extends FileManager { + private val fc = if (path.toUri.getScheme == null) { + FileContext.getFileContext(hadoopConf) + } else { + FileContext.getFileContext(path.toUri, hadoopConf) + } + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fc.util.listStatus(path, filter) + } + + override def rename(srcPath: Path, destPath: Path): Unit = { + fc.rename(srcPath, destPath) + } + + override def mkdirs(path: Path): Unit = { + fc.mkdir(path, FsPermission.getDirDefault, true) + } + + override def open(path: Path): FSDataInputStream = { + fc.open(path) + } + + override def create(path: Path): FSDataOutputStream = { + fc.create(path, EnumSet.of(CreateFlag.CREATE)) + } + + override def exists(path: Path): Boolean = { + fc.util().exists(path) + } + + override def delete(path: Path): Unit = { + try { + fc.delete(path, true) + } catch { + case e: FileNotFoundException => + // ignore if file has already been deleted + } + } + } + + /** + * Implementation of FileManager using older FileSystem API. Note that this implementation + * cannot provide atomic renaming of paths, hence can lead to consistency issues. This + * should be used only as a backup option, when FileContextManager cannot be used. + */ + class FileSystemManager(path: Path, hadoopConf: Configuration) extends FileManager { + private val fs = path.getFileSystem(hadoopConf) + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fs.listStatus(path, filter) + } + + /** + * Rename a path. Note that this implementation is not atomic. + * @throws FileNotFoundException if source path does not exist. + * @throws FileAlreadyExistsException if destination path already exists. + * @throws IOException if renaming fails for some unknown reason. + */ + override def rename(srcPath: Path, destPath: Path): Unit = { + if (!fs.exists(srcPath)) { + throw new FileNotFoundException(s"Source path does not exist: $srcPath") + } + if (fs.exists(destPath)) { + throw new FileAlreadyExistsException(s"Destination path already exists: $destPath") + } + if (!fs.rename(srcPath, destPath)) { + throw new IOException(s"Failed to rename $srcPath to $destPath") + } + } + + override def mkdirs(path: Path): Unit = { + fs.mkdirs(path, FsPermission.getDirDefault) + } + + override def open(path: Path): FSDataInputStream = { + fs.open(path) + } + + override def create(path: Path): FSDataOutputStream = { + fs.create(path, false) + } + + override def exists(path: Path): Boolean = { + fs.exists(path) + } + + override def delete(path: Path): Unit = { + try { + fs.delete(path, true) + } catch { + case e: FileNotFoundException => + // ignore if file has already been deleted + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala new file mode 100644 index 0000000000000..aaced49dd16ce --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode} + +/** + * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] + * plan incrementally. Possibly preserving state in between each execution. + */ +class IncrementalExecution( + ctx: SQLContext, + logicalPlan: LogicalPlan, + checkpointLocation: String, + currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) { + + // TODO: make this always part of planning. + val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil + + // Modified planner with stateful operations. + override def planner: SparkPlanner = + new SparkPlanner( + sqlContext.sparkContext, + sqlContext.conf, + stateStrategy) + + /** + * Records the current id for a given stateful operator in the query plan as the `state` + * preperation walks the query plan. + */ + private var operatorId = 0 + + /** Locates save/restore pairs surrounding aggregation. */ + val state = new Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan transform { + case StateStoreSave(keys, None, + UnaryNode(agg, + StateStoreRestore(keys2, None, child))) => + val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1) + operatorId += 1 + + StateStoreSave( + keys, + Some(stateId), + agg.withNewChildren( + StateStoreRestore( + keys, + Some(stateId), + child) :: Nil)) + } + } + + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala new file mode 100644 index 0000000000000..bb176408d8f59 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * A simple offset for sources that produce a single linear stream of data. + */ +case class LongOffset(offset: Long) extends Offset { + + override def compareTo(other: Offset): Int = other match { + case l: LongOffset => offset.compareTo(l.offset) + case _ => + throw new IllegalArgumentException(s"Invalid comparison of $getClass with ${other.getClass}") + } + + def +(increment: Long): LongOffset = new LongOffset(offset + increment) + def -(decrement: Long): LongOffset = new LongOffset(offset - decrement) + + override def toString: String = s"#$offset" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala new file mode 100644 index 0000000000000..cc70e1d314d1d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala @@ -0,0 +1,51 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.streaming + +/** + * A general MetadataLog that supports the following features: + * + * - Allow the user to store a metadata object for each batch. + * - Allow the user to query the latest batch id. + * - Allow the user to query the metadata object of a specified batch id. + * - Allow the user to query metadata objects in a range of batch ids. + */ +trait MetadataLog[T] { + + /** + * Store the metadata for the specified batchId and return `true` if successful. If the batchId's + * metadata has already been stored, this method will return `false`. + */ + def add(batchId: Long, metadata: T): Boolean + + /** + * Return the metadata for the specified batchId if it's stored. Otherwise, return None. + */ + def get(batchId: Long): Option[T] + + /** + * Return metadata for batches between startId (inclusive) and endId (inclusive). If `startId` is + * `None`, just return all batches before endId (inclusive). + */ + def get(startId: Option[Long], endId: Option[Long]): Array[(Long, T)] + + /** + * Return the latest batch Id and its metadata if exist. + */ + def getLatest(): Option[(Long, T)] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala new file mode 100644 index 0000000000000..0f5d6445b1e2b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +/** + * A offset is a monotonically increasing metric used to track progress in the computation of a + * stream. An [[Offset]] must be comparable, and the result of `compareTo` must be consistent + * with `equals` and `hashcode`. + */ +trait Offset extends Serializable { + + /** + * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, + * or greater than the specified object. + */ + def compareTo(other: Offset): Int + + def >(other: Offset): Boolean = compareTo(other) > 0 + def <(other: Offset): Boolean = compareTo(other) < 0 + def <=(other: Offset): Boolean = compareTo(other) <= 0 + def >=(other: Offset): Boolean = compareTo(other) >= 0 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala new file mode 100644 index 0000000000000..25015d58f75ab --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.DataFrame + +/** + * An interface for systems that can collect the results of a streaming query. In order to preserve + * exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same + * batch. + */ +trait Sink { + + /** + * Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if + * this method is called more than once with the same batchId (which will happen in the case of + * failures), then `data` should only be added once. + */ + def addBatch(batchId: Long, data: DataFrame): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala new file mode 100644 index 0000000000000..6457f928ed887 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType + +/** + * A source of continually arriving data for a streaming query. A [[Source]] must have a + * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark + * will regularly query each [[Source]] to see if any more data is available. + */ +trait Source { + + /** Returns the schema of the data from this source */ + def schema: StructType + + /** Returns the maximum available offset for this source. */ + def getOffset: Option[Offset] + + /** + * Returns the data that is is between the offsets (`start`, `end`]. When `start` is `None` then + * the batch should begin with the first available record. This method must always return the + * same data for a particular `start` and `end` pair. + */ + def getBatch(start: Option[Offset], end: Offset): DataFrame +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala new file mode 100644 index 0000000000000..595774761cffe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.SparkPlan + +/** Used to identify the state store for a given operator. */ +case class OperatorStateId( + checkpointLocation: String, + operatorId: Long, + batchId: Long) + +/** + * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should + * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + */ +trait StatefulOperator extends SparkPlan { + def stateId: Option[OperatorStateId] + + protected def getStateId: OperatorStateId = attachTree(this) { + stateId.getOrElse { + throw new IllegalStateException("State location not present for execution") + } + } +} + +/** + * For each input tuple, the key is calculated and the value from the [[StateStore]] is added + * to the stream (in addition to the input tuple) if present. + */ +case class StateStoreRestore( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + row +: savedState.toSeq + } + } + } + override def output: Seq[Attribute] = child.output +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. + */ +case class StateStoreSave( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + new Iterator[InternalRow] { + private[this] val baseIterator = iter + private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + row + } + } + } + } + + override def output: Seq[Attribute] = child.output +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala new file mode 100644 index 0000000000000..87dd27a2b1aed --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal + +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.ContinuousQueryListener +import org.apache.spark.sql.util.ContinuousQueryListener._ +import org.apache.spark.util.UninterruptibleThread + +/** + * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. + * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any + * [[Source]] present in the query plan. Whenever new data arrives, a [[QueryExecution]] is created + * and the results are committed transactionally to the given [[Sink]]. + */ +class StreamExecution( + override val sqlContext: SQLContext, + override val name: String, + checkpointRoot: String, + private[sql] val logicalPlan: LogicalPlan, + val sink: Sink, + val trigger: Trigger) extends ContinuousQuery with Logging { + + /** An monitor used to wait/notify when batches complete. */ + private val awaitBatchLock = new Object + private val startLatch = new CountDownLatch(1) + private val terminationLatch = new CountDownLatch(1) + + /** + * Tracks how much data we have processed and committed to the sink or state store from each + * input source. + */ + @volatile + private[sql] var committedOffsets = new StreamProgress + + /** + * Tracks the offsets that are available to be processed, but have not yet be committed to the + * sink. + */ + @volatile + private var availableOffsets = new StreamProgress + + /** The current batchId or -1 if execution has not yet been initialized. */ + private var currentBatchId: Long = -1 + + /** All stream sources present the query plan. */ + private val sources = + logicalPlan.collect { case s: StreamingExecutionRelation => s.source } + + /** A list of unique sources in the query plan. */ + private val uniqueSources = sources.distinct + + private val triggerExecutor = trigger match { + case t: ProcessingTime => ProcessingTimeExecutor(t) + } + + /** Defines the internal state of execution */ + @volatile + private var state: State = INITIALIZED + + @volatile + private[sql] var lastExecution: QueryExecution = null + + @volatile + private[sql] var streamDeathCause: ContinuousQueryException = null + + /** The thread that runs the micro-batches of this stream. */ + private[sql] val microBatchThread = + new UninterruptibleThread(s"stream execution thread for $name") { + override def run(): Unit = { runBatches() } + } + + /** + * A write-ahead-log that records the offsets that are present in each batch. In order to ensure + * that a given batch will always consist of the same data, we write to this log *before* any + * processing is done. Thus, the Nth record in this log indicated data that is currently being + * processed and the N-1th entry indicates which offsets have been durably committed to the sink. + */ + private val offsetLog = + new HDFSMetadataLog[CompositeOffset](sqlContext, checkpointFile("offsets")) + + /** Whether the query is currently active or not */ + override def isActive: Boolean = state == ACTIVE + + /** Returns current status of all the sources. */ + override def sourceStatuses: Array[SourceStatus] = { + val localAvailableOffsets = availableOffsets + sources.map(s => new SourceStatus(s.toString, localAvailableOffsets.get(s))).toArray + } + + /** Returns current status of the sink. */ + override def sinkStatus: SinkStatus = + new SinkStatus(sink.toString, committedOffsets.toCompositeOffset(sources)) + + /** Returns the [[ContinuousQueryException]] if the query was terminated by an exception. */ + override def exception: Option[ContinuousQueryException] = Option(streamDeathCause) + + /** Returns the path of a file with `name` in the checkpoint directory. */ + private def checkpointFile(name: String): String = + new Path(new Path(checkpointRoot), name).toUri.toString + + /** + * Starts the execution. This returns only after the thread has started and [[QueryStarted]] event + * has been posted to all the listeners. + */ + private[sql] def start(): Unit = { + microBatchThread.setDaemon(true) + microBatchThread.start() + startLatch.await() // Wait until thread started and QueryStart event has been posted + } + + /** + * Repeatedly attempts to run batches as data arrives. + * + * Note that this method ensures that [[QueryStarted]] and [[QueryTerminated]] events are posted + * such that listeners are guaranteed to get a start event before a termination. Furthermore, this + * method also ensures that [[QueryStarted]] event is posted before the `start()` method returns. + */ + private def runBatches(): Unit = { + try { + // Mark ACTIVE and then post the event. QueryStarted event is synchronously sent to listeners, + // so must mark this as ACTIVE first. + state = ACTIVE + postEvent(new QueryStarted(this)) // Assumption: Does not throw exception. + + // Unblock starting thread + startLatch.countDown() + + // While active, repeatedly attempt to run batches. + SQLContext.setActive(sqlContext) + populateStartOffsets() + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + triggerExecutor.execute(() => { + if (isActive) { + if (dataAvailable) runBatch() + constructNextBatch() + true + } else { + false + } + }) + } catch { + case _: InterruptedException if state == TERMINATED => // interrupted by stop() + case NonFatal(e) => + streamDeathCause = new ContinuousQueryException( + this, + s"Query $name terminated with exception: ${e.getMessage}", + e, + Some(committedOffsets.toCompositeOffset(sources))) + logError(s"Query $name terminated with error", e) + } finally { + state = TERMINATED + sqlContext.streams.notifyQueryTermination(StreamExecution.this) + postEvent(new QueryTerminated(this)) + terminationLatch.countDown() + } + } + + /** + * Populate the start offsets to start the execution at the current offsets stored in the sink + * (i.e. avoid reprocessing data that we have already processed). This function must be called + * before any processing occurs and will populate the following fields: + * - currentBatchId + * - committedOffsets + * - availableOffsets + */ + private def populateStartOffsets(): Unit = { + offsetLog.getLatest() match { + case Some((batchId, nextOffsets)) => + logInfo(s"Resuming continuous query, starting with batch $batchId") + currentBatchId = batchId + 1 + availableOffsets = nextOffsets.toStreamProgress(sources) + logDebug(s"Found possibly uncommitted offsets $availableOffsets") + + offsetLog.get(batchId - 1).foreach { + case lastOffsets => + committedOffsets = lastOffsets.toStreamProgress(sources) + logDebug(s"Resuming with committed offsets: $committedOffsets") + } + + case None => // We are starting this stream for the first time. + logInfo(s"Starting new continuous query.") + currentBatchId = 0 + constructNextBatch() + } + } + + /** + * Returns true if there is any new data available to be processed. + */ + private def dataAvailable: Boolean = { + availableOffsets.exists { + case (source, available) => + committedOffsets + .get(source) + .map(committed => committed < available) + .getOrElse(true) + } + } + + /** + * Queries all of the sources to see if any new data is available. When there is new data the + * batchId counter is incremented and a new log entry is written with the newest offsets. + */ + private def constructNextBatch(): Unit = { + // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). + // If we interrupt some thread running Shell.runCommand, we may hit this issue. + // As "FileStreamSource.getOffset" will create a file using HDFS API and call "Shell.runCommand" + // to set the file permission, we should not interrupt "microBatchThread" when running this + // method. See SPARK-14131. + // + // Check to see what new data is available. + val newData = microBatchThread.runUninterruptibly { + uniqueSources.flatMap(s => s.getOffset.map(o => s -> o)) + } + availableOffsets ++= newData + + val hasNewData = awaitBatchLock.synchronized { + if (dataAvailable) { + true + } else { + noNewData = true + false + } + } + if (hasNewData) { + // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). + // If we interrupt some thread running Shell.runCommand, we may hit this issue. + // As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set + // the file permission, we should not interrupt "microBatchThread" when running this method. + // See SPARK-14131. + microBatchThread.runUninterruptibly { + assert( + offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), + s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") + } + currentBatchId += 1 + logInfo(s"Committed offsets for batch $currentBatchId.") + } else { + awaitBatchLock.synchronized { + // Wake up any threads that are waiting for the stream to progress. + awaitBatchLock.notifyAll() + } + } + } + + /** + * Processes any data available between `availableOffsets` and `committedOffsets`. + */ + private def runBatch(): Unit = { + val startTime = System.nanoTime() + + // TODO: Move this to IncrementalExecution. + + // Request unprocessed data from all sources. + val newData = availableOffsets.flatMap { + case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => + val current = committedOffsets.get(source) + val batch = source.getBatch(current, available) + logDebug(s"Retrieving data from $source: $current -> $available") + Some(source -> batch) + case _ => None + }.toMap + + // A list of attributes that will need to be updated. + var replacements = new ArrayBuffer[(Attribute, Attribute)] + // Replace sources in the logical plan with data that has arrived since the last batch. + val withNewSources = logicalPlan transform { + case StreamingExecutionRelation(source, output) => + newData.get(source).map { data => + val newPlan = data.logicalPlan + assert(output.size == newPlan.output.size, + s"Invalid batch: ${output.mkString(",")} != ${newPlan.output.mkString(",")}") + replacements ++= output.zip(newPlan.output) + newPlan + }.getOrElse { + LocalRelation(output) + } + } + + // Rewire the plan to use the new attributes that were returned by the source. + val replacementMap = AttributeMap(replacements) + val newPlan = withNewSources transformAllExpressions { + case a: Attribute if replacementMap.contains(a) => replacementMap(a) + } + + val optimizerStart = System.nanoTime() + lastExecution = + new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId) + lastExecution.executedPlan + val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 + logDebug(s"Optimized batch in ${optimizerTime}ms") + + val nextBatch = + new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema)) + sink.addBatch(currentBatchId - 1, nextBatch) + + awaitBatchLock.synchronized { + // Wake up any threads that are waiting for the stream to progress. + awaitBatchLock.notifyAll() + } + + val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 + logInfo(s"Completed up to $availableOffsets in ${batchTime}ms") + // Update committed offsets. + committedOffsets ++= availableOffsets + postEvent(new QueryProgress(this)) + } + + private def postEvent(event: ContinuousQueryListener.Event) { + sqlContext.streams.postListenerEvent(event) + } + + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state = TERMINATED + if (microBatchThread.isAlive) { + microBatchThread.interrupt() + microBatchThread.join() + } + logInfo(s"Query $name was stopped") + } + + /** + * Blocks the current thread until processing for data from the given `source` has reached at + * least the given `Offset`. This method is indented for use primarily when writing tests. + */ + def awaitOffset(source: Source, newOffset: Offset): Unit = { + def notDone = { + val localCommittedOffsets = committedOffsets + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) < newOffset + } + + while (notDone) { + logInfo(s"Waiting until $newOffset at $source") + awaitBatchLock.synchronized { awaitBatchLock.wait(100) } + } + logDebug(s"Unblocked at $newOffset for $source") + } + + /** A flag to indicate that a batch has completed with no new data available. */ + @volatile private var noNewData = false + + override def processAllAvailable(): Unit = awaitBatchLock.synchronized { + noNewData = false + while (true) { + awaitBatchLock.wait(10000) + if (streamDeathCause != null) { + throw streamDeathCause + } + if (noNewData) { + return + } + } + } + + override def awaitTermination(): Unit = { + if (state == INITIALIZED) { + throw new IllegalStateException("Cannot wait for termination on a query that has not started") + } + terminationLatch.await() + if (streamDeathCause != null) { + throw streamDeathCause + } + } + + override def awaitTermination(timeoutMs: Long): Boolean = { + if (state == INITIALIZED) { + throw new IllegalStateException("Cannot wait for termination on a query that has not started") + } + require(timeoutMs > 0, "Timeout has to be positive") + terminationLatch.await(timeoutMs, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } else { + !isActive + } + } + + override def toString: String = { + s"Continuous Query - $name [state = $state]" + } + + def toDebugString: String = { + val deathCauseStr = if (streamDeathCause != null) { + "Error:\n" + stackTraceToString(streamDeathCause.cause) + } else "" + s""" + |=== Continuous Query === + |Name: $name + |Current Offsets: $committedOffsets + | + |Current State: $state + |Thread State: ${microBatchThread.getState} + | + |Logical Plan: + |$logicalPlan + | + |$deathCauseStr + """.stripMargin + } + + trait State + case object INITIALIZED extends State + case object ACTIVE extends State + case object TERMINATED extends State +} + +private[sql] object StreamExecution { + private val nextId = new AtomicInteger() + + def nextName: String = s"query-${nextId.getAndIncrement}" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala new file mode 100644 index 0000000000000..b8d69b18450cf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.sources.{FileCatalog, Partition} +import org.apache.spark.sql.types.StructType + +class StreamFileCatalog(sqlContext: SQLContext, path: Path) extends FileCatalog with Logging { + val metadataDirectory = new Path(path, FileStreamSink.metadataDir) + logInfo(s"Reading streaming file log from $metadataDirectory") + val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataDirectory.toUri.toString) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + + override def paths: Seq[Path] = path :: Nil + + override def partitionSpec(): PartitionSpec = PartitionSpec(StructType(Nil), Nil) + + /** + * Returns all valid files grouped into partitions when the data is partitioned. If the data is + * unpartitioned, this will return a single partition with not partition values. + * + * @param filters the filters used to prune which partitions are returned. These filters must + * only refer to partition columns and this method will only return files + * where these predicates are guaranteed to evaluate to `true`. Thus, these + * filters will not need to be evaluated again on the returned data. + */ + override def listFiles(filters: Seq[Expression]): Seq[Partition] = + Partition(InternalRow.empty, allFiles()) :: Nil + + override def getStatus(path: Path): Array[FileStatus] = fs.listStatus(path) + + override def refresh(): Unit = {} + + override def allFiles(): Seq[FileStatus] = { + fs.listStatus(metadataLog.get(None, None).flatMap(_._2).map(new Path(_))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala new file mode 100644 index 0000000000000..405a5f0387a7e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.{immutable, GenTraversableOnce} + +/** + * A helper class that looks like a Map[Source, Offset]. + */ +class StreamProgress( + val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset]) + extends scala.collection.immutable.Map[Source, Offset] { + + private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = { + CompositeOffset(source.map(get)) + } + + override def toString: String = + baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") + + override def +[B1 >: Offset](kv: (Source, B1)): Map[Source, B1] = baseMap + kv + + override def get(key: Source): Option[Offset] = baseMap.get(key) + + override def iterator: Iterator[(Source, Offset)] = baseMap.iterator + + override def -(key: Source): Map[Source, Offset] = baseMap - key + + def ++(updates: GenTraversableOnce[(Source, Offset)]): StreamProgress = { + new StreamProgress(baseMap ++ updates) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala new file mode 100644 index 0000000000000..d2872e49ce28a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.execution.datasources.DataSource + +object StreamingRelation { + def apply(dataSource: DataSource): StreamingRelation = { + val (name, schema) = dataSource.sourceSchema() + StreamingRelation(dataSource, name, schema.toAttributes) + } +} + +/** + * Used to link a streaming [[DataSource]] into a + * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating + * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]]. + * It should be used to create [[Source]] and converted to [[StreamingExecutionRelation]] when + * passing to [StreamExecution]] to run a query. + */ +case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) + extends LeafNode { + override def toString: String = sourceName +} + +/** + * Used to link a streaming [[Source]] of data into a + * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. + */ +case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode { + override def toString: String = source.toString +} + +object StreamingExecutionRelation { + def apply(source: Source): StreamingExecutionRelation = { + StreamingExecutionRelation(source, source.schema.toAttributes) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala new file mode 100644 index 0000000000000..a1132d510685c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.ProcessingTime +import org.apache.spark.util.{Clock, SystemClock} + +trait TriggerExecutor { + + /** + * Execute batches using `batchRunner`. If `batchRunner` runs `false`, terminate the execution. + */ + def execute(batchRunner: () => Boolean): Unit +} + +/** + * A trigger executor that runs a batch every `intervalMs` milliseconds. + */ +case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock()) + extends TriggerExecutor with Logging { + + private val intervalMs = processingTime.intervalMs + + override def execute(batchRunner: () => Boolean): Unit = { + while (true) { + val batchStartTimeMs = clock.getTimeMillis() + val terminated = !batchRunner() + if (intervalMs > 0) { + val batchEndTimeMs = clock.getTimeMillis() + val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs + if (batchElapsedTimeMs > intervalMs) { + notifyBatchFallingBehind(batchElapsedTimeMs) + } + if (terminated) { + return + } + clock.waitTillTime(nextBatchTime(batchEndTimeMs)) + } else { + if (terminated) { + return + } + } + } + } + + /** Called when a batch falls behind. Expose for test only */ + def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { + logWarning("Current batch is falling behind. The trigger interval is " + + s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds") + } + + /** Return the next multiple of intervalMs */ + def nextBatchTime(now: Long): Long = { + (now - 1) / intervalMs * intervalMs + intervalMs + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala new file mode 100644 index 0000000000000..3820968324bfe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.types.StructType + +object MemoryStream { + protected val currentBlockId = new AtomicInteger(0) + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) +} + +/** + * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] + * is primarily intended for use in unit tests as it can only replay data when the object is still + * available. + */ +case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends Source with Logging { + protected val encoder = encoderFor[A] + protected val logicalPlan = StreamingExecutionRelation(this) + protected val output = logicalPlan.output + + @GuardedBy("this") + protected val batches = new ArrayBuffer[Dataset[A]] + + @GuardedBy("this") + protected var currentOffset: LongOffset = new LongOffset(-1) + + def schema: StructType = encoder.schema + + def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { + Dataset(sqlContext, logicalPlan) + } + + def toDF()(implicit sqlContext: SQLContext): DataFrame = { + Dataset.ofRows(sqlContext, logicalPlan) + } + + def addData(data: A*): Offset = { + addData(data.toTraversable) + } + + def addData(data: TraversableOnce[A]): Offset = { + import sqlContext.implicits._ + val ds = data.toVector.toDS() + logDebug(s"Adding ds: $ds") + this.synchronized { + currentOffset = currentOffset + 1 + batches.append(ds) + currentOffset + } + } + + override def toString: String = s"MemoryStream[${output.mkString(",")}]" + + override def getOffset: Option[Offset] = synchronized { + if (batches.isEmpty) { + None + } else { + Some(currentOffset) + } + } + + /** + * Returns the next batch of data that is available after `start`, if any is available. + */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startOrdinal = + start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 + val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 + val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) } + + logDebug( + s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") + newBlocks + .map(_.toDF()) + .reduceOption(_ union _) + .getOrElse { + sys.error("No data selected!") + } + } +} + +/** + * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit + * tests and does not provide durability. + */ +class MemorySink(val schema: StructType) extends Sink with Logging { + /** An order list of batches that have been written to this [[Sink]]. */ + @GuardedBy("this") + private val batches = new ArrayBuffer[Array[Row]]() + + /** Returns all rows that are stored in this [[Sink]]. */ + def allData: Seq[Row] = synchronized { + batches.flatten + } + + def lastBatch: Seq[Row] = synchronized { batches.last } + + def toDebugString: String = synchronized { + batches.zipWithIndex.map { case (b, i) => + val dataStr = try b.mkString(" ") catch { + case NonFatal(e) => "[Error converting to string]" + } + s"$i: $dataStr" + }.mkString("\n") + } + + override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { + if (batchId == batches.size) { + logDebug(s"Committing batch $batchId") + batches.append(data.collect()) + } else { + logDebug(s"Skipping already committed batch: $batchId") + } + } +} + +/** + * Used to query the data that has been written into a [[MemorySink]]. + */ +case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { + def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala new file mode 100644 index 0000000000000..3335755fd3b67 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.{DataInputStream, DataOutputStream, IOException} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random +import scala.util.control.NonFatal + +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.io.LZ4CompressionCodec +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + + +/** + * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed + * by files in a HDFS-compatible file system. All updates to the store has to be done in sets + * transactionally, and each set of updates increments the store's version. These versions can + * be used to re-execute the updates (by retries in RDD operations) on the correct version of + * the store, and regenerate the store version. + * + * Usage: + * To update the data in the state store, the following order of operations are needed. + * + * // get the right store + * - val store = StateStore.get( + * StateStoreId(checkpointLocation, operatorId, partitionId), ..., version, ...) + * - store.put(...) + * - store.remove(...) + * - store.commit() // commits all the updates to made; the new version will be returned + * - store.iterator() // key-value data after last commit as an iterator + * - store.updates() // updates made in the last commit as an iterator + * + * Fault-tolerance model: + * - Every set of updates is written to a delta file before committing. + * - The state store is responsible for managing, collapsing and cleaning up of delta files. + * - Multiple attempts to commit the same version of updates may overwrite each other. + * Consistency guarantees depend on whether multiple attempts have the same updates and + * the overwrite semantics of underlying file system. + * - Background maintenance of files ensures that last versions of the store is always recoverable + * to ensure re-executed RDD operations re-apply updates on the correct past version of the + * store. + */ +private[state] class HDFSBackedStateStoreProvider( + val id: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + storeConf: StateStoreConf, + hadoopConf: Configuration + ) extends StateStoreProvider with Logging { + + type MapType = java.util.HashMap[UnsafeRow, UnsafeRow] + + /** Implementation of [[StateStore]] API which is backed by a HDFS-compatible file system */ + class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType) + extends StateStore { + + /** Trait and classes representing the internal state of the store */ + trait STATE + case object UPDATING extends STATE + case object COMMITTED extends STATE + case object ABORTED extends STATE + + private val newVersion = version + 1 + private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") + private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) + + private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]() + + @volatile private var state: STATE = UPDATING + @volatile private var finalDeltaFile: Path = null + + override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id + + override def get(key: UnsafeRow): Option[UnsafeRow] = { + Option(mapToUpdate.get(key)) + } + + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or aborted") + + val isNewKey = !mapToUpdate.containsKey(key) + mapToUpdate.put(key, value) + + Option(allUpdates.get(key)) match { + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added already, keep it marked as added + allUpdates.put(key, ValueAdded(key, value)) + case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => + // Value existed in previous version and updated/removed, mark it as updated + allUpdates.put(key, ValueUpdated(key, value)) + case None => + // There was no prior update, so mark this as added or updated according to its presence + // in previous version. + val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value) + allUpdates.put(key, update) + } + writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) + } + + /** Remove keys that match the following condition */ + override def remove(condition: UnsafeRow => Boolean): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or aborted") + val keyIter = mapToUpdate.keySet().iterator() + while (keyIter.hasNext) { + val key = keyIter.next + if (condition(key)) { + keyIter.remove() + + Option(allUpdates.get(key)) match { + case Some(ValueUpdated(_, _)) | None => + // Value existed in previous version and maybe was updated, mark removed + allUpdates.put(key, KeyRemoved(key)) + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added, should not appear in updates + allUpdates.remove(key) + case Some(KeyRemoved(_)) => + // Remove already in update map, no need to change + } + writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key)) + } + } + } + + /** Commit all the updates that have been made to the store, and return the new version. */ + override def commit(): Long = { + verify(state == UPDATING, "Cannot commit after already committed or aborted") + + try { + finalizeDeltaFile(tempDeltaFileStream) + finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + state = COMMITTED + logInfo(s"Committed version $newVersion for $this") + newVersion + } catch { + case NonFatal(e) => + throw new IllegalStateException( + s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e) + } + } + + /** Abort all the updates made on this store. This store will not be usable any more. */ + override def abort(): Unit = { + verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") + + state = ABORTED + if (tempDeltaFileStream != null) { + tempDeltaFileStream.close() + } + if (tempDeltaFile != null && fs.exists(tempDeltaFile)) { + fs.delete(tempDeltaFile, true) + } + logInfo("Aborted") + } + + /** + * Get an iterator of all the store data. + * This can be called only after committing all the updates made in the current thread. + */ + override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { + verify(state == COMMITTED, + "Cannot get iterator of store data before committing or after aborting") + HDFSBackedStateStoreProvider.this.iterator(newVersion) + } + + /** + * Get an iterator of all the updates made to the store in the current version. + * This can be called only after committing all the updates made in the current thread. + */ + override def updates(): Iterator[StoreUpdate] = { + verify(state == COMMITTED, + "Cannot get iterator of updates before committing or after aborting") + allUpdates.values().asScala.toIterator + } + + /** + * Whether all updates have been committed + */ + override private[state] def hasCommitted: Boolean = { + state == COMMITTED + } + } + + /** Get the state store for making updates to create a new `version` of the store. */ + override def getStore(version: Long): StateStore = synchronized { + require(version >= 0, "Version cannot be less than 0") + val newMap = new MapType() + if (version > 0) { + newMap.putAll(loadMap(version)) + } + val store = new HDFSBackedStateStore(version, newMap) + logInfo(s"Retrieved version $version of $this for update") + store + } + + /** Do maintenance backing data files, including creating snapshots and cleaning up old files */ + override def doMaintenance(): Unit = { + try { + doSnapshot() + cleanup() + } catch { + case NonFatal(e) => + logWarning(s"Error performing snapshot and cleaning up $this") + } + } + + override def toString(): String = { + s"StateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + } + + /* Internal classes and methods */ + + private val loadedMaps = new mutable.HashMap[Long, MapType] + private val baseDir = + new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}") + private val fs = baseDir.getFileSystem(hadoopConf) + private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + + initialize() + + private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) + + /** Commit a set of updates to the store with the given new version */ + private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + synchronized { + val finalDeltaFile = deltaFile(newVersion) + fs.rename(tempDeltaFile, finalDeltaFile) + loadedMaps.put(newVersion, map) + finalDeltaFile + } + } + + /** + * Get iterator of all the data of the latest version of the store. + * Note that this will look up the files to determined the latest known version. + */ + private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + val versionsInFiles = fetchFiles().map(_.version).toSet + val versionsLoaded = loadedMaps.keySet + val allKnownVersions = versionsInFiles ++ versionsLoaded + if (allKnownVersions.nonEmpty) { + loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x => + (x.getKey, x.getValue) + } + } else Iterator.empty + } + + /** Get iterator of a specific version of the store */ + private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized { + loadMap(version).entrySet().iterator().asScala.map { x => + (x.getKey, x.getValue) + } + } + + /** Initialize the store provider */ + private def initialize(): Unit = { + if (!fs.exists(baseDir)) { + fs.mkdirs(baseDir) + } else { + if (!fs.isDirectory(baseDir)) { + throw new IllegalStateException( + s"Cannot use ${id.checkpointLocation} for storing state data for $this as " + + s"$baseDir already exists and is not a directory") + } + } + } + + /** Load the required version of the map data from the backing files */ + private def loadMap(version: Long): MapType = { + if (version <= 0) return new MapType + synchronized { loadedMaps.get(version) }.getOrElse { + val mapFromFile = readSnapshotFile(version).getOrElse { + val prevMap = loadMap(version - 1) + val newMap = new MapType(prevMap) + newMap.putAll(prevMap) + updateFromDeltaFile(version, newMap) + newMap + } + loadedMaps.put(version, mapFromFile) + mapFromFile + } + } + + private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = { + + def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + val valueBytes = value.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + + def writeRemove(key: UnsafeRow): Unit = { + val keyBytes = key.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(-1) + } + + update match { + case ValueAdded(key, value) => + writeUpdate(key, value) + case ValueUpdated(key, value) => + writeUpdate(key, value) + case KeyRemoved(key) => + writeRemove(key) + } + } + + private def finalizeDeltaFile(output: DataOutputStream): Unit = { + output.writeInt(-1) // Write this magic number to signify end of file + output.close() + } + + private def updateFromDeltaFile(version: Long, map: MapType): Unit = { + val fileToRead = deltaFile(version) + if (!fs.exists(fileToRead)) { + throw new IllegalStateException( + s"Error reading delta file $fileToRead of $this: $fileToRead does not exist") + } + var input: DataInputStream = null + try { + input = decompressStream(fs.open(fileToRead)) + var eof = false + + while(!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new IOException( + s"Error reading delta file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + map.remove(keyRow) + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, valueRow) + } + } + } + } finally { + if (input != null) input.close() + } + logInfo(s"Read delta file for version $version of $this from $fileToRead") + } + + private def writeSnapshotFile(version: Long, map: MapType): Unit = { + val fileToWrite = snapshotFile(version) + var output: DataOutputStream = null + Utils.tryWithSafeFinally { + output = compressStream(fs.create(fileToWrite, false)) + val iter = map.entrySet().iterator() + while(iter.hasNext) { + val entry = iter.next() + val keyBytes = entry.getKey.getBytes() + val valueBytes = entry.getValue.getBytes() + output.writeInt(keyBytes.size) + output.write(keyBytes) + output.writeInt(valueBytes.size) + output.write(valueBytes) + } + output.writeInt(-1) + } { + if (output != null) output.close() + } + logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") + } + + private def readSnapshotFile(version: Long): Option[MapType] = { + val fileToRead = snapshotFile(version) + if (!fs.exists(fileToRead)) return None + + val map = new MapType() + var input: DataInputStream = null + + try { + input = decompressStream(fs.open(fileToRead)) + var eof = false + + while (!eof) { + val keySize = input.readInt() + if (keySize == -1) { + eof = true + } else if (keySize < 0) { + throw new IOException( + s"Error reading snapshot file $fileToRead of $this: key size cannot be $keySize") + } else { + val keyRowBuffer = new Array[Byte](keySize) + ByteStreams.readFully(input, keyRowBuffer, 0, keySize) + + val keyRow = new UnsafeRow(keySchema.fields.length) + keyRow.pointTo(keyRowBuffer, keySize) + + val valueSize = input.readInt() + if (valueSize < 0) { + throw new IOException( + s"Error reading snapshot file $fileToRead of $this: value size cannot be $valueSize") + } else { + val valueRowBuffer = new Array[Byte](valueSize) + ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) + val valueRow = new UnsafeRow(valueSchema.fields.length) + valueRow.pointTo(valueRowBuffer, valueSize) + map.put(keyRow, valueRow) + } + } + } + logInfo(s"Read snapshot file for version $version of $this from $fileToRead") + Some(map) + } finally { + if (input != null) input.close() + } + } + + + /** Perform a snapshot of the store to allow delta files to be consolidated */ + private def doSnapshot(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val lastVersion = files.last.version + val deltaFilesForLastVersion = + filesForVersion(files, lastVersion).filter(_.isSnapshot == false) + synchronized { loadedMaps.get(lastVersion) } match { + case Some(map) => + if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { + writeSnapshotFile(lastVersion, map) + } + case None => + // The last map is not loaded, probably some other instance is in charge + } + + } + } catch { + case NonFatal(e) => + logWarning(s"Error doing snapshots for $this", e) + } + } + + /** + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. + */ + private[state] def cleanup(): Unit = { + try { + val files = fetchFiles() + if (files.nonEmpty) { + val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain + if (earliestVersionToRetain > 0) { + val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head + synchronized { + val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq + mapsToRemove.foreach(loadedMaps.remove) + } + files.filter(_.version < earliestFileToRetain.version).foreach { f => + fs.delete(f.path, true) + } + logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this") + } + } + } catch { + case NonFatal(e) => + logWarning(s"Error cleaning up files for $this", e) + } + } + + /** Files needed to recover the given version of the store */ + private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = { + require(version >= 0) + require(allFiles.exists(_.version == version)) + + val latestSnapshotFileBeforeVersion = allFiles + .filter(_.isSnapshot == true) + .takeWhile(_.version <= version) + .lastOption + val deltaBatchFiles = latestSnapshotFileBeforeVersion match { + case Some(snapshotFile) => + + val deltaFiles = allFiles.filter { file => + file.version > snapshotFile.version && file.version <= version + } + verify( + deltaFiles.size == version - snapshotFile.version, + s"Unexpected list of delta files for version $version for $this: $deltaFiles" + ) + deltaFiles + + case None => + allFiles.takeWhile(_.version <= version) + } + latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles + } + + /** Fetch all the files that back the store */ + private def fetchFiles(): Seq[StoreFile] = { + val files: Seq[FileStatus] = try { + fs.listStatus(baseDir) + } catch { + case _: java.io.FileNotFoundException => + Seq.empty + } + val versionToFiles = new mutable.HashMap[Long, StoreFile] + files.foreach { status => + val path = status.getPath + val nameParts = path.getName.split("\\.") + if (nameParts.size == 2) { + val version = nameParts(0).toLong + nameParts(1).toLowerCase match { + case "delta" => + // ignore the file otherwise, snapshot file already exists for that batch id + if (!versionToFiles.contains(version)) { + versionToFiles.put(version, StoreFile(version, path, isSnapshot = false)) + } + case "snapshot" => + versionToFiles.put(version, StoreFile(version, path, isSnapshot = true)) + case _ => + logWarning(s"Could not identify file $path for $this") + } + } + } + val storeFiles = versionToFiles.values.toSeq.sortBy(_.version) + logDebug(s"Current set of files for $this: $storeFiles") + storeFiles + } + + private def compressStream(outputStream: DataOutputStream): DataOutputStream = { + val compressed = new LZ4CompressionCodec(sparkConf).compressedOutputStream(outputStream) + new DataOutputStream(compressed) + } + + private def decompressStream(inputStream: DataInputStream): DataInputStream = { + val compressed = new LZ4CompressionCodec(sparkConf).compressedInputStream(inputStream) + new DataInputStream(compressed) + } + + private def deltaFile(version: Long): Path = { + new Path(baseDir, s"$version.delta") + } + + private def snapshotFile(version: Long): Path = { + new Path(baseDir, s"$version.snapshot") + } + + private def verify(condition: => Boolean, msg: String): Unit = { + if (!condition) { + throw new IllegalStateException(msg) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala new file mode 100644 index 0000000000000..952150632519e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.{ScheduledFuture, TimeUnit} + +import scala.collection.mutable +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ThreadUtils + + +/** Unique identifier for a [[StateStore]] */ +case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int) + + +/** + * Base trait for a versioned key-value store used for streaming aggregations + */ +trait StateStore { + + /** Unique identifier of the store */ + def id: StateStoreId + + /** Version of the data in this store before committing updates. */ + def version: Long + + /** Get the current value of a key. */ + def get(key: UnsafeRow): Option[UnsafeRow] + + /** Put a new value for a key. */ + def put(key: UnsafeRow, value: UnsafeRow): Unit + + /** + * Remove keys that match the following condition. + */ + def remove(condition: UnsafeRow => Boolean): Unit + + /** + * Commit all the updates that have been made to the store, and return the new version. + */ + def commit(): Long + + /** Abort all the updates that have been made to the store. */ + def abort(): Unit + + /** + * Iterator of store data after a set of updates have been committed. + * This can be called only after committing all the updates made in the current thread. + */ + def iterator(): Iterator[(UnsafeRow, UnsafeRow)] + + /** + * Iterator of the updates that have been committed. + * This can be called only after committing all the updates made in the current thread. + */ + def updates(): Iterator[StoreUpdate] + + /** + * Whether all updates have been committed + */ + private[state] def hasCommitted: Boolean +} + + +/** Trait representing a provider of a specific version of a [[StateStore]]. */ +trait StateStoreProvider { + + /** Get the store with the existing version. */ + def getStore(version: Long): StateStore + + /** Optional method for providers to allow for background maintenance */ + def doMaintenance(): Unit = { } +} + + +/** Trait representing updates made to a [[StateStore]]. */ +sealed trait StoreUpdate + +case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate + +case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate + +case class KeyRemoved(key: UnsafeRow) extends StoreUpdate + + +/** + * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores + * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null), + * it also runs a periodic background task to do maintenance on the loaded stores. For each + * store, it uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of + * the store is the active instance. Accordingly, it either keeps it loaded and performs + * maintenance, or unloads the store. + */ +private[state] object StateStore extends Logging { + + val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval" + val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 + + private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() + private val maintenanceTaskExecutor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task") + + @volatile private var maintenanceTask: ScheduledFuture[_] = null + @volatile private var _coordRef: StateStoreCoordinatorRef = null + + /** Get or create a store associated with the id. */ + def get( + storeId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + version: Long, + storeConf: StateStoreConf, + hadoopConf: Configuration): StateStore = { + require(version >= 0) + val storeProvider = loadedProviders.synchronized { + startMaintenanceIfNeeded() + val provider = loadedProviders.getOrElseUpdate( + storeId, + new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf)) + reportActiveStoreInstance(storeId) + provider + } + storeProvider.getStore(version) + } + + /** Unload a state store provider */ + def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized { + loadedProviders.remove(storeId) + } + + /** Whether a state store provider is loaded or not */ + def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized { + loadedProviders.contains(storeId) + } + + /** Unload and stop all state store providers */ + def stop(): Unit = loadedProviders.synchronized { + loadedProviders.clear() + _coordRef = null + if (maintenanceTask != null) { + maintenanceTask.cancel(false) + maintenanceTask = null + } + logInfo("StateStore stopped") + } + + /** Start the periodic maintenance task if not already started and if Spark active */ + private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized { + val env = SparkEnv.get + if (maintenanceTask == null && env != null) { + val periodMs = env.conf.getTimeAsMs( + MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s") + val runnable = new Runnable { + override def run(): Unit = { doMaintenance() } + } + maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate( + runnable, periodMs, periodMs, TimeUnit.MILLISECONDS) + logInfo("State Store maintenance task started") + } + } + + /** + * Execute background maintenance task in all the loaded store providers if they are still + * the active instances according to the coordinator. + */ + private def doMaintenance(): Unit = { + logDebug("Doing maintenance") + loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => + try { + if (verifyIfStoreInstanceActive(id)) { + provider.doMaintenance() + } else { + unload(id) + logInfo(s"Unloaded $provider") + } + } catch { + case NonFatal(e) => + logWarning(s"Error managing $provider") + } + } + } + + private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { + try { + val host = SparkEnv.get.blockManager.blockManagerId.host + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) + logDebug(s"Reported that the loaded instance $storeId is active") + } catch { + case NonFatal(e) => + logWarning(s"Error reporting active instance of $storeId") + } + } + + private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { + try { + val executorId = SparkEnv.get.blockManager.blockManagerId.executorId + val verified = + coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) + logDebug(s"Verified whether the loaded instance $storeId is active: $verified" ) + verified + } catch { + case NonFatal(e) => + logWarning(s"Error verifying active instance of $storeId") + false + } + } + + private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized { + val env = SparkEnv.get + if (env != null) { + if (_coordRef == null) { + _coordRef = StateStoreCoordinatorRef.forExecutor(env) + } + logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") + Some(_coordRef) + } else { + _coordRef = null + None + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala new file mode 100644 index 0000000000000..e55f63a6c8db8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.internal.SQLConf + +/** A class that contains configuration parameters for [[StateStore]]s. */ +private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { + + def this() = this(new SQLConf) + + import SQLConf._ + + val minDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + + val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) +} + +private[streaming] object StateStoreConf { + val empty = new StateStoreConf() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala new file mode 100644 index 0000000000000..e418217238cca --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.collection.mutable + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.util.RpcUtils + +/** Trait representing all messages to [[StateStoreCoordinator]] */ +private sealed trait StateStoreCoordinatorMessage extends Serializable + +/** Classes representing messages */ +private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String) + extends StateStoreCoordinatorMessage + +private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String) + extends StateStoreCoordinatorMessage + +private case class GetLocation(storeId: StateStoreId) + extends StateStoreCoordinatorMessage + +private case class DeactivateInstances(storeRootLocation: String) + extends StateStoreCoordinatorMessage + +private object StopCoordinator + extends StateStoreCoordinatorMessage + +/** Helper object used to create reference to [[StateStoreCoordinator]]. */ +private[sql] object StateStoreCoordinatorRef extends Logging { + + private val endpointName = "StateStoreCoordinator" + + /** + * Create a reference to a [[StateStoreCoordinator]] + */ + def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + try { + val coordinator = new StateStoreCoordinator(env.rpcEnv) + val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) + logInfo("Registered StateStoreCoordinator endpoint") + new StateStoreCoordinatorRef(coordinatorRef) + } catch { + case e: IllegalArgumentException => + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + logDebug("Retrieved existing StateStoreCoordinator endpoint") + new StateStoreCoordinatorRef(rpcEndpointRef) + } + } + + def forExecutor(env: SparkEnv): StateStoreCoordinatorRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv) + logDebug("Retrieved existing StateStoreCoordinator endpoint") + new StateStoreCoordinatorRef(rpcEndpointRef) + } +} + +/** + * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of + * [[StateStore]]s across all the executors, and get their locations for job scheduling. + */ +private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { + + private[state] def reportActiveInstance( + storeId: StateStoreId, + host: String, + executorId: String): Unit = { + rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId)) + } + + /** Verify whether the given executor has the active instance of a state store */ + private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = { + rpcEndpointRef.askWithRetry[Boolean](VerifyIfInstanceActive(storeId, executorId)) + } + + /** Get the location of the state store */ + private[state] def getLocation(storeId: StateStoreId): Option[String] = { + rpcEndpointRef.askWithRetry[Option[String]](GetLocation(storeId)) + } + + /** Deactivate instances related to a set of operator */ + private[state] def deactivateInstances(storeRootLocation: String): Unit = { + rpcEndpointRef.askWithRetry[Boolean](DeactivateInstances(storeRootLocation)) + } + + private[state] def stop(): Unit = { + rpcEndpointRef.askWithRetry[Boolean](StopCoordinator) + } +} + + +/** + * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, + * and get their locations for job scheduling. + */ +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] + + override def receive: PartialFunction[Any, Unit] = { + case ReportActiveInstance(id, host, executorId) => + instances.put(id, ExecutorCacheTaskLocation(host, executorId)) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case VerifyIfInstanceActive(id, execId) => + val response = instances.get(id) match { + case Some(location) => location.executorId == execId + case None => false + } + context.reply(response) + + case GetLocation(id) => + context.reply(instances.get(id).map(_.toString)) + + case DeactivateInstances(loc) => + val storeIdsToRemove = + instances.keys.filter(_.checkpointLocation == loc).toSeq + instances --= storeIdsToRemove + context.reply(true) + + case StopCoordinator => + stop() // Stop before replying to ensure that endpoint name has been deregistered + context.reply(true) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala new file mode 100644 index 0000000000000..d708486d8ea0b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +/** + * An RDD that allows computations to be executed against [[StateStore]]s. It + * uses the [[StateStoreCoordinator]] to get the locations of loaded state stores + * and use that as the preferred locations. + */ +class StateStoreRDD[T: ClassTag, U: ClassTag]( + dataRDD: RDD[T], + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType, + storeConf: StateStoreConf, + @transient private val storeCoordinator: Option[StateStoreCoordinatorRef]) + extends RDD[U](dataRDD) { + + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + private val confBroadcast = dataRDD.context.broadcast( + new SerializableConfiguration(dataRDD.context.hadoopConfiguration)) + + override protected def getPartitions: Array[Partition] = dataRDD.partitions + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + storeCoordinator.flatMap(_.getLocation(storeId)).toSeq + } + + override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { + var store: StateStore = null + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + store = StateStore.get( + storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + val inputIter = dataRDD.iterator(partition, ctxt) + storeUpdateFunction(store, inputIter) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala new file mode 100644 index 0000000000000..9b6d0918e29c1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.StructType + +package object state { + + implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { + + /** Map each partition of a RDD along with data in a [[StateStore]]. */ + def mapPartitionsWithStateStore[U: ClassTag]( + sqlContext: SQLContext, + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType)( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { + + mapPartitionsWithStateStore( + checkpointLocation, + operatorId, + storeVersion, + keySchema, + valueSchema, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) + } + + /** Map each partition of a RDD along with data in a [[StateStore]]. */ + private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( + checkpointLocation: String, + operatorId: Long, + storeVersion: Long, + keySchema: StructType, + valueSchema: StructType, + storeConf: StateStoreConf, + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + new StateStoreRDD( + dataRDD, + cleanedF, + checkpointLocation, + operatorId, + storeVersion, + keySchema, + valueSchema, + storeConf, + storeCoordinator) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala new file mode 100644 index 0000000000000..4b3091ba22c60 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.{expressions, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType + +/** + * A subquery that will return only one row and one column. + * + * This is the physical copy of ScalarSubquery to be used inside SparkPlan. + */ +case class ScalarSubquery( + @transient executedPlan: SparkPlan, + exprId: ExprId) + extends SubqueryExpression { + + override def query: LogicalPlan = throw new UnsupportedOperationException + override def withNewPlan(plan: LogicalPlan): SubqueryExpression = { + throw new UnsupportedOperationException + } + override def plan: SparkPlan = Subquery(simpleString, executedPlan) + + override def dataType: DataType = executedPlan.schema.fields.head.dataType + override def nullable: Boolean = true + override def toString: String = s"subquery#${exprId.id}" + + // the first column in first row from `query`. + private var result: Any = null + + def updateResult(v: Any): Unit = { + result = v + } + + override def eval(input: InternalRow): Any = result + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + Literal.create(result, dataType).genCode(ctx, ev) + } +} + +/** + * Plans scalar subqueries from that are present in the given [[SparkPlan]]. + */ +case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { + def apply(plan: SparkPlan): SparkPlan = { + plan.transformAllExpressions { + case subquery: expressions.ScalarSubquery => + val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan + ScalarSubquery(executedPlan, subquery.exprId) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 49646a99d68c8..e96fb9f7550a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -24,7 +24,7 @@ import scala.xml.Node import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with Logging { @@ -55,6 +55,12 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } _content } + content ++= + UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) } } @@ -118,14 +124,12 @@ private[ui] abstract class ExecutionTable( {failedJobs} }} - {detailCell(executionUIData.physicalPlanDescription)} } private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details.nonEmpty) { - + +details ++ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5a072de400b6a..5ae9e916adae1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,11 +19,34 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ -import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} -import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.ui.SparkUI + +@DeveloperApi +case class SparkListenerSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) + extends SparkListenerEvent + +private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { + + override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { + List(new SQLHistoryListener(conf, sparkUI)) + } +} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -117,8 +140,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) + for ((taskId, stageId, stageAttemptID, accumUpdates) <- executorMetricsUpdate.accumUpdates) { + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, accumUpdates, finishTask = false) } } @@ -136,27 +159,26 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskMetrics, - finishTask = true) + if (taskEnd.taskMetrics != null) { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskMetrics.accumulatorUpdates(), + finishTask = true) + } } /** * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. */ - private def updateTaskAccumulatorValues( + protected def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - metrics: TaskMetrics, + accumulatorUpdates: Seq[AccumulableInfo], finishTask: Boolean): Unit = { - if (metrics == null) { - return - } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -174,9 +196,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -185,7 +207,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) + attemptId = 0, finished = finishTask, accumulatorUpdates) } } case None => @@ -193,38 +215,40 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - def onExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - physicalPlanGraph: SparkPlanGraph, - time: Long): Unit = { - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - - val executionUIData = new SQLExecutionUIData(executionId, description, details, - physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - } - - def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) => + val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = physicalPlanGraph.allNodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + val executionUIData = new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + physicalPlanGraph, + sqlPlanMetrics.toMap, + time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. + } } } + case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -265,8 +289,10 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi for (stageId <- executionUIData.stages; stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable; taskMetrics <- stageMetrics.taskIdToMetricUpdates.values; - accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield { - accumulatorUpdate + accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield { + assert(accumulatorUpdate.update.isDefined, s"accumulator update from " + + s"task did not have a partial value: ${accumulatorUpdate.name}") + (accumulatorUpdate.id, accumulatorUpdate.update.get) } }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => @@ -289,6 +315,48 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } + +/** + * A [[SQLListener]] for rendering the SQL UI in the history server. + */ +private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) + extends SQLListener(conf) { + + private var sqlTabAttached = false + + override def onExecutorMetricsUpdate(u: SparkListenerExecutorMetricsUpdate): Unit = { + // Do nothing; these events are not logged + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskInfo.accumulables.flatMap { a => + // Filter out accumulators that are not SQL metrics + // For now we assume all SQL metrics are Long's that have been JSON serialized as String's + if (a.metadata == Some(SQLMetrics.ACCUM_IDENTIFIER)) { + val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L)) + Some(a.copy(update = Some(newValue))) + } else { + None + } + }, + finishTask = true) + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case _: SparkListenerSQLExecutionStart => + if (!sqlTabAttached) { + new SQLTab(this, sparkUI) + sqlTabAttached = true + } + super.onOtherEvent(event) + case _ => super.onOtherEvent(event) + } +} + /** * Represent all necessary data for an execution that will be used in Web UI. */ @@ -350,4 +418,4 @@ private[ui] class SQLStageMetrics( private[ui] class SQLTaskMetrics( val attemptId: Long, // TODO not used yet var finished: Boolean, - var accumulatorUpdates: Map[Long, Any]) + var accumulatorUpdates: Seq[AccumulableInfo]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 9c27944d42fc6..e8675ce749a2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.ui -import java.util.concurrent.atomic.AtomicInteger - -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -35,13 +33,5 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" - - private val nextTabId = new AtomicInteger(0) - - private def nextTabName: String = { - val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL$nextId" - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index f1fce5478a3fe..c6fcb6956c274 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,10 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.sql.execution.{SparkPlanInfo, WholeStageCodegen} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * A graph used for storing information of an executionPlan of DataFrame. @@ -41,6 +43,16 @@ private[ui] case class SparkPlanGraph( dotFile.append("}") dotFile.toString() } + + /** + * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen. + */ + val allNodes: Seq[SparkPlanGraphNode] = { + nodes.flatMap { + case cluster: SparkPlanGraphCluster => cluster.nodes :+ cluster + case node => Seq(node) + } + } } private[sql] object SparkPlanGraph { @@ -48,32 +60,73 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. */ - def apply(plan: SparkPlan): SparkPlanGraph = { + def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + val exchanges = mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]() + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, exchanges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - plan: SparkPlan, + planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], - edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - SQLPlanMetric(metric.name.getOrElse(key), metric.id, - metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) - } - val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) - nodes += node - val childrenNodes = plan.children.map( - child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) - for (child <- childrenNodes) { - edges += SparkPlanGraphEdge(child.id, node.id) + edges: mutable.ArrayBuffer[SparkPlanGraphEdge], + parent: SparkPlanGraphNode, + subgraph: SparkPlanGraphCluster, + exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = { + planInfo.nodeName match { + case "WholeStageCodegen" => + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) + } + + val cluster = new SparkPlanGraphCluster( + nodeIdGenerator.getAndIncrement(), + planInfo.nodeName, + planInfo.simpleString, + mutable.ArrayBuffer[SparkPlanGraphNode](), + metrics) + nodes += cluster + + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges) + case "InputAdapter" => + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) + case "Subquery" if subgraph != null => + // Subquery should not be included in WholeStageCodegen + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges) + case "ReusedExchange" => + // Point to the re-used exchange + val node = exchanges(planInfo.children.head) + edges += SparkPlanGraphEdge(node.id, parent.id) + case name => + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) + } + val node = new SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + if (subgraph == null) { + nodes += node + } else { + subgraph.nodes += node + } + if (name.contains("Exchange")) { + exchanges += planInfo -> node + } + + if (parent != null) { + edges += SparkPlanGraphEdge(node.id, parent.id) + } + planInfo.children.foreach( + buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph, exchanges)) } - node } } @@ -84,30 +137,69 @@ private[sql] object SparkPlanGraph { * @param name the name of this SparkPlan node * @param metrics metrics that this SparkPlan node will track */ -private[ui] case class SparkPlanGraphNode( - id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { +private[ui] class SparkPlanGraphNode( + val id: Long, + val name: String, + val desc: String, + val metadata: Map[String, String], + val metrics: Seq[SQLPlanMetric]) { def makeDotNode(metricsValue: Map[Long, String]): String = { - val values = { - for (metric <- metrics; - value <- metricsValue.get(metric.accumulatorId)) yield { - metric.name + ": " + value - } + val builder = new mutable.StringBuilder(name) + + val values = for { + metric <- metrics + value <- metricsValue.get(metric.accumulatorId) + } yield { + metric.name + ": " + value } - val label = if (values.isEmpty) { - name + + if (values.nonEmpty) { + // If there are metrics, display each entry in a separate line. + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. + builder ++= "\n \n" + builder ++= values.mkString("\n") + } + + s""" $id [label="${StringEscapeUtils.escapeJava(builder.toString())}"];""" + } +} + +/** + * Represent a tree of SparkPlan for WholeStageCodegen. + */ +private[ui] class SparkPlanGraphCluster( + id: Long, + name: String, + desc: String, + val nodes: mutable.ArrayBuffer[SparkPlanGraphNode], + metrics: Seq[SQLPlanMetric]) + extends SparkPlanGraphNode(id, name, desc, Map.empty, metrics) { + + override def makeDotNode(metricsValue: Map[Long, String]): String = { + val duration = metrics.filter(_.name.startsWith(WholeStageCodegen.PIPELINE_DURATION_METRIC)) + val labelStr = if (duration.nonEmpty) { + require(duration.length == 1) + val id = duration(0).accumulatorId + if (metricsValue.contains(duration(0).accumulatorId)) { + name + "\n\n" + metricsValue.get(id).get } else { - // If there are metrics, display all metrics in a separate line. We should use an escaped - // "\n" here to follow the dot syntax. - // - // Note: whitespace between two "\n"s is to create an empty line between the name of - // SparkPlan and metrics. If removing it, it won't display the empty line in UI. - name + "\\n \\n" + values.mkString("\\n") + name } - s""" $id [label="$label"];""" + } else { + name + } + s""" + | subgraph cluster${id} { + | label="${StringEscapeUtils.escapeJava(labelStr)}"; + | ${nodes.map(_.makeDotNode(metricsValue)).mkString(" \n")} + | } + """.stripMargin } } + /** * Represent an edge in the SparkPlan tree. `fromId` is the parent node id, and `toId` is the child * node id. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala new file mode 100644 index 0000000000000..baae9dd2d5e3e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression + +/** + * A base class for user-defined aggregations, which can be used in [[Dataset]] operations to take + * all of the elements of a group and reduce them to a single value. + * + * For example, the following aggregator extracts an `int` from a specific class and adds them up: + * {{{ + * case class Data(i: Int) + * + * val customSummer = new Aggregator[Data, Int, Int] { + * def zero: Int = 0 + * def reduce(b: Int, a: Data): Int = b + a.i + * def merge(b1: Int, b2: Int): Int = b1 + b2 + * def finish(r: Int): Int = r + * }.toColumn() + * + * val ds: Dataset[Data] = ... + * val aggregated = ds.select(customSummer) + * }}} + * + * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird + * + * @tparam IN The input type for the aggregation. + * @tparam BUF The type of the intermediate value of the reduction. + * @tparam OUT The type of the final output result. + * @since 1.6.0 + */ +abstract class Aggregator[-IN, BUF, OUT] extends Serializable { + + /** + * A zero value for this aggregation. Should satisfy the property that any b + zero = b. + * @since 1.6.0 + */ + def zero: BUF + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + * @since 1.6.0 + */ + def reduce(b: BUF, a: IN): BUF + + /** + * Merge two intermediate values. + * @since 1.6.0 + */ + def merge(b1: BUF, b2: BUF): BUF + + /** + * Transform the output of the reduction. + * @since 1.6.0 + */ + def finish(reduction: BUF): OUT + + /** + * Specifies the [[Encoder]] for the intermediate value type. + * @since 2.0.0 + */ + def bufferEncoder: Encoder[BUF] + + /** + * Specifies the [[Encoder]] for the final ouput value type. + * @since 2.0.0 + */ + def outputEncoder: Encoder[OUT] + + /** + * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]]. + * operations. + * @since 1.6.0 + */ + def toColumn: TypedColumn[IN, OUT] = { + implicit val bEncoder = bufferEncoder + implicit val cEncoder = outputEncoder + + val expr = + AggregateExpression( + TypedAggregateExpression(this), + Complete, + isDistinct = false) + + new TypedColumn[IN, OUT](expr, encoderFor[OUT]) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala new file mode 100644 index 0000000000000..bd35d19aa20bb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -0,0 +1,48 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions +import org.apache.spark.sql.types.DataType + +/** + * A user-defined function. To create one, use the `udf` functions in [[functions]]. + * As an example: + * {{{ + * // Defined a UDF that returns true or false based on some numeric score. + * val predict = udf((score: Double) => if (score > 0.5) true else false) + * + * // Projects a column that adds a prediction column based on the score column. + * df.select( predict(df("score")) ) + * }}} + * + * @since 1.3.0 + */ +@Experimental +case class UserDefinedFunction protected[sql] ( + f: AnyRef, + dataType: DataType, + inputTypes: Option[Seq[DataType]]) { + + def apply(exprs: Column*): Column = { + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index e9b60841fc28c..350c2836461e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -42,7 +42,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { spec.partitionBy(colName, colNames : _*) } @@ -51,7 +51,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { spec.partitionBy(cols : _*) } @@ -60,7 +60,7 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { spec.orderBy(colName, colNames : _*) } @@ -69,7 +69,7 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { spec.orderBy(cols : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 8b9247adea200..d716da2668675 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.types.BooleanType -import org.apache.spark.sql.{Column, catalyst} +import org.apache.spark.sql.{catalyst, Column} import org.apache.spark.sql.catalyst.expressions._ - /** * :: Experimental :: * A window specification that defines the partitioning, ordering, and frame boundaries. @@ -41,7 +39,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { partitionBy((colName +: colNames).map(Column(_)): _*) } @@ -50,7 +48,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { new WindowSpec(cols.map(_.expr), orderSpec, frame) } @@ -59,7 +57,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { orderBy((colName +: colNames).map(Column(_)): _*) } @@ -68,7 +66,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { val sortOrder: Seq[SortOrder] = cols.map { col => col.expr match { @@ -140,41 +138,7 @@ class WindowSpec private[sql]( * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. */ private[sql] def withAggregate(aggregate: Column): Column = { - val windowExpr = aggregate.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction( - "first_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction( - "last_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"$x is not supported in window operation.") - } - new Column(windowExpr) + val spec = WindowSpecDefinition(partitionSpec, orderSpec, frame) + new Column(WindowExpression(aggregate.expr, spec)) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala new file mode 100644 index 0000000000000..d0eb190afd036 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions.scala + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.aggregate._ + +/** + * :: Experimental :: + * Type-safe functions available for [[Dataset]] operations in Scala. + * + * Java users should use [[org.apache.spark.sql.expressions.java.typed]]. + * + * @since 2.0.0 + */ +@Experimental +// scalastyle:off +object typed { + // scalastyle:on + + // Note: whenever we update this file, we should update the corresponding Java version too. + // The reason we have separate files for Java and Scala is because in the Scala version, we can + // use tighter types (primitive types) for return types, whereas in the Java version we can only + // use boxed primitive types. + // For example, avg in the Scala veresion returns Scala primitive Double, whose bytecode + // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double. + + // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders. + private val implicits = new SQLImplicits { + override protected def _sqlContext: SQLContext = null + } + + import implicits._ + + /** + * Average aggregate function. + * + * @since 2.0.0 + */ + def avg[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedAverage(f).toColumn + + /** + * Count aggregate function. + * + * @since 2.0.0 + */ + def count[IN](f: IN => Any): TypedColumn[IN, Long] = new TypedCount(f).toColumn + + /** + * Sum aggregate function for floating point (double) type. + * + * @since 2.0.0 + */ + def sum[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedSumDouble[IN](f).toColumn + + /** + * Sum aggregate function for integral (long, i.e. 64 bit integer) type. + * + * @since 2.0.0 + */ + def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn + + // TODO: + // stddevOf: Double + // varianceOf: Double + // approxCountDistinct: Long + + // minOf: T + // maxOf: T + + // firstOf: T + // lastOf: T +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 258afadc76951..48925910ac8cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} -import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.types._ -import org.apache.spark.annotation.Experimental /** * :: Experimental :: @@ -106,10 +106,10 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def apply(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression2( + AggregateExpression( ScalaUDAF(exprs.map(_.expr), this), Complete, isDistinct = false) @@ -120,10 +120,10 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Creates a [[Column]] for this UDAF using the distinct values of the given * [[Column]]s as input arguments. */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression2( + AggregateExpression( ScalaUDAF(exprs.map(_.expr), this), Complete, isDistinct = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c70c965a9b04c..223122300dbb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -18,17 +18,22 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.execution.SparkSqlParser +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils + /** * :: Experimental :: * Functions available for [[DataFrame]]. @@ -51,7 +56,13 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + private def withExpr(expr: Expression): Column = Column(expr) + + private def withAggregateFunction( + func: AggregateFunction, + isDistinct: Boolean = false): Column = { + Column(func.toAggregateExpression(isDistinct)) + } /** * Returns a [[Column]] based on the given column name. @@ -128,7 +139,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + def approxCountDistinct(e: Column): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr) + } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -141,14 +154,20 @@ object functions { /** * Aggregate function: returns the approximate number of distinct items in a group. * + * @param rsd maximum estimation error allowed (default = 0.05) + * * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr, rsd, 0, 0) + } /** * Aggregate function: returns the approximate number of distinct items in a group. * + * @param rsd maximum estimation error allowed (default = 0.05) + * * @group agg_funcs * @since 1.3.0 */ @@ -162,7 +181,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = Average(e.expr) + def avg(e: Column): Column = withAggregateFunction { Average(e.expr) } /** * Aggregate function: returns the average of the values in a group. @@ -172,14 +191,55 @@ object functions { */ def avg(columnName: String): Column = avg(Column(columnName)) + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(e: Column): Column = callUDF("collect_list", e) + + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(columnName: String): Column = collect_list(Column(columnName)) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(e: Column): Column = callUDF("collect_set", e) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(columnName: String): Column = collect_set(Column(columnName)) + /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. * * @group agg_funcs * @since 1.6.0 */ - def corr(column1: Column, column2: Column): Column = + def corr(column1: Column, column2: Column): Column = withAggregateFunction { Corr(column1.expr, column2.expr) + } /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. @@ -187,8 +247,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def corr(columnName1: String, columnName2: String): Column = + def corr(columnName1: String, columnName2: String): Column = { corr(Column(columnName1), Column(columnName2)) + } /** * Aggregate function: returns the number of items in a group. @@ -196,10 +257,12 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = e.expr match { - // Turn count(*) into count(1) - case s: Star => Count(Literal(1)) - case _ => Count(e.expr) + def count(e: Column): Column = withAggregateFunction { + e.expr match { + // Turn count(*) into count(1) + case s: Star => Count(Literal(1)) + case _ => Count(e.expr) + } } /** @@ -208,7 +271,8 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(columnName: String): Column = count(Column(columnName)) + def count(columnName: String): TypedColumn[Any, Long] = + count(Column(columnName)).as(ExpressionEncoder[Long]()) /** * Aggregate function: returns the number of distinct items in a group. @@ -217,8 +281,9 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = - CountDistinct((expr +: exprs).map(_.expr)) + def countDistinct(expr: Column, exprs: Column*): Column = { + withAggregateFunction(Count.apply((expr +: exprs).map(_.expr)), isDistinct = true) + } /** * Aggregate function: returns the number of distinct items in a group. @@ -230,45 +295,202 @@ object functions { def countDistinct(columnName: String, columnNames: String*): Column = countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) + /** + * Aggregate function: returns the population covariance for two columns. + * + * @group agg_funcs + * @since 2.0.0 + */ + def covar_pop(column1: Column, column2: Column): Column = withAggregateFunction { + CovPopulation(column1.expr, column2.expr) + } + + /** + * Aggregate function: returns the population covariance for two columns. + * + * @group agg_funcs + * @since 2.0.0 + */ + def covar_pop(columnName1: String, columnName2: String): Column = { + covar_pop(Column(columnName1), Column(columnName2)) + } + + /** + * Aggregate function: returns the sample covariance for two columns. + * + * @group agg_funcs + * @since 2.0.0 + */ + def covar_samp(column1: Column, column2: Column): Column = withAggregateFunction { + CovSample(column1.expr, column2.expr) + } + + /** + * Aggregate function: returns the sample covariance for two columns. + * + * @group agg_funcs + * @since 2.0.0 + */ + def covar_samp(columnName1: String, columnName2: String): Column = { + covar_samp(Column(columnName1), Column(columnName2)) + } + /** * Aggregate function: returns the first value in a group. * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { + new First(e.expr, Literal(ignoreNulls)) + } + + /** + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def first(columnName: String, ignoreNulls: Boolean): Column = { + first(Column(columnName), ignoreNulls) + } + + /** + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * * @group agg_funcs * @since 1.3.0 */ - def first(e: Column): Column = First(e.expr) + def first(e: Column): Column = first(e, ignoreNulls = false) /** * Aggregate function: returns the first value of a column in a group. * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * * @group agg_funcs * @since 1.3.0 */ def first(columnName: String): Column = first(Column(columnName)) + /** + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping(e: Column): Column = Column(Grouping(e.expr)) + + /** + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping(columnName: String): Column = grouping(Column(columnName)) + + /** + * Aggregate function: returns the level of grouping, equals to + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns should match with grouping columns exactly, or empty (means all the + * grouping columns). + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr))) + + /** + * Aggregate function: returns the level of grouping, equals to + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns should match with grouping columns exactly. + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping_id(colName: String, colNames: String*): Column = { + grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*) + } + /** * Aggregate function: returns the kurtosis of the values in a group. * * @group agg_funcs * @since 1.6.0 */ - def kurtosis(e: Column): Column = Kurtosis(e.expr) + def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } + + /** + * Aggregate function: returns the kurtosis of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) /** * Aggregate function: returns the last value in a group. * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { + new Last(e.expr, Literal(ignoreNulls)) + } + + /** + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ + def last(columnName: String, ignoreNulls: Boolean): Column = { + last(Column(columnName), ignoreNulls) + } + + /** + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = Last(e.expr) + def last(e: Column): Column = last(e, ignoreNulls = false) /** * Aggregate function: returns the last value of the column in a group. * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * * @group agg_funcs * @since 1.3.0 */ - def last(columnName: String): Column = last(Column(columnName)) + def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false) /** * Aggregate function: returns the maximum value of the expression in a group. @@ -276,7 +498,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = Max(e.expr) + def max(e: Column): Column = withAggregateFunction { Max(e.expr) } /** * Aggregate function: returns the maximum value of the column in a group. @@ -310,7 +532,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = Min(e.expr) + def min(e: Column): Column = withAggregateFunction { Min(e.expr) } /** * Aggregate function: returns the minimum value of the column in a group. @@ -326,7 +548,23 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def skewness(e: Column): Column = Skewness(e.expr) + def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } + + /** + * Aggregate function: returns the skewness of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(columnName: String): Column = skewness(Column(columnName)) + + /** + * Aggregate function: alias for [[stddev_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** * Aggregate function: alias for [[stddev_samp]]. @@ -334,16 +572,25 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = StddevSamp(e.expr) + def stddev(columnName: String): Column = stddev(Column(columnName)) /** - * Aggregate function: returns the unbiased sample standard deviation of + * Aggregate function: returns the sample standard deviation of * the expression in a group. * * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = StddevSamp(e.expr) + def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + + /** + * Aggregate function: returns the sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) /** * Aggregate function: returns the population standard deviation of @@ -352,7 +599,16 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(e: Column): Column = StddevPop(e.expr) + def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) /** * Aggregate function: returns the sum of all values in the expression. @@ -360,7 +616,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = Sum(e.expr) + def sum(e: Column): Column = withAggregateFunction { Sum(e.expr) } /** * Aggregate function: returns the sum of all values in the given column. @@ -376,7 +632,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true) /** * Aggregate function: returns the sum of distinct values in the expression. @@ -392,7 +648,23 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = VarianceSamp(e.expr) + def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + + /** + * Aggregate function: alias for [[var_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(columnName: String): Column = variance(Column(columnName)) + + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: returns the unbiased variance of the values in a group. @@ -400,7 +672,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(e: Column): Column = VarianceSamp(e.expr) + def var_samp(columnName: String): Column = var_samp(Column(columnName)) /** * Aggregate function: returns the population variance of the values in a group. @@ -408,7 +680,15 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_pop(e: Column): Column = VariancePop(e.expr) + def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } + + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(columnName: String): Column = var_pop(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions @@ -423,15 +703,10 @@ object functions { * cumeDist(x) = number of values before (and including) x / N * }}} * - * - * This is equivalent to the CUME_DIST function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def cumeDist(): Column = { - UnresolvedWindowFunction("cume_dist", Nil) - } + def cume_dist(): Column = withExpr { new CumeDist } /** * Window function: returns the rank of rows within a window partition, without any gaps. @@ -441,14 +716,10 @@ object functions { * and had three people tie for second place, you would say that all three were in second * place and that the next person came in third. * - * This is equivalent to the DENSE_RANK function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def denseRank(): Column = { - UnresolvedWindowFunction("dense_rank", Nil) - } + def dense_rank(): Column = withExpr { new DenseRank } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -460,9 +731,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int): Column = { - lag(e, offset, null) - } + def lag(e: Column, offset: Int): Column = lag(e, offset, null) /** * Window function: returns the value that is `offset` rows before the current row, and @@ -474,9 +743,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(columnName: String, offset: Int): Column = { - lag(columnName, offset, null) - } + def lag(columnName: String, offset: Int): Column = lag(columnName, offset, null) /** * Window function: returns the value that is `offset` rows before the current row, and @@ -502,8 +769,8 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int, defaultValue: Any): Column = { - UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr { + Lag(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -516,9 +783,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(columnName: String, offset: Int): Column = { - lead(columnName, offset, null) - } + def lead(columnName: String, offset: Int): Column = { lead(columnName, offset, null) } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -530,9 +795,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int): Column = { - lead(e, offset, null) - } + def lead(e: Column, offset: Int): Column = { lead(e, offset, null) } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -558,8 +821,8 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int, defaultValue: Any): Column = { - UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr { + Lead(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -572,9 +835,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def ntile(n: Int): Column = { - UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) - } + def ntile(n: Int): Column = withExpr { new NTile(Literal(n)) } /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. @@ -587,11 +848,9 @@ object functions { * This is equivalent to the PERCENT_RANK function in SQL. * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def percentRank(): Column = { - UnresolvedWindowFunction("percent_rank", Nil) - } + def percent_rank(): Column = withExpr { new PercentRank } /** * Window function: returns the rank of rows within a window partition. @@ -606,21 +865,15 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def rank(): Column = { - UnresolvedWindowFunction("rank", Nil) - } + def rank(): Column = withExpr { new Rank } /** * Window function: returns a sequential number starting at 1 within a window partition. * - * This is equivalent to the ROW_NUMBER function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def rowNumber(): Column = { - UnresolvedWindowFunction("row_number", Nil) - } + def row_number(): Column = withExpr { RowNumber() } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -632,7 +885,7 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def abs(e: Column): Column = Abs(e.expr) + def abs(e: Column): Column = withExpr { Abs(e.expr) } /** * Creates a new array column. The input columns must all have the same data type. @@ -641,7 +894,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + def array(cols: Column*): Column = withExpr { CreateArray(cols.map(_.expr)) } /** * Creates a new array column. The input columns must all have the same data type. @@ -649,10 +902,22 @@ object functions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { array((colName +: colNames).map(col) : _*) } + /** + * Creates a new map column. The input columns must be grouped as key-value pairs, e.g. + * (key1, value1, key2, value2, ...). The key columns must all have the same data type, and can't + * be null. The value columns must all have the same data type. + * + * @group normal_funcs + * @since 2.0 + */ + @scala.annotation.varargs + def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } + /** * Marks a DataFrame as small enough for use in broadcast joins. * @@ -666,7 +931,7 @@ object functions { * @since 1.5.0 */ def broadcast(df: DataFrame): DataFrame = { - DataFrame(df.sqlContext, BroadcastHint(df.logicalPlan)) + Dataset.ofRows(df.sqlContext, BroadcastHint(df.logicalPlan)) } /** @@ -679,22 +944,31 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs + * @since 1.6.0 */ - def inputFileName(): Column = InputFileName() + def input_file_name(): Column = withExpr { InputFileName() } /** * Return true iff the column is NaN. * * @group normal_funcs - * @since 1.5.0 + * @since 1.6.0 */ - def isNaN(e: Column): Column = IsNaN(e.expr) + def isnan(e: Column): Column = withExpr { IsNaN(e.expr) } + + /** + * Return true iff the column is null. + * + * @group normal_funcs + * @since 1.6.0 + */ + def isnull(e: Column): Column = withExpr { IsNull(e.expr) } /** * A column expression that generates monotonically increasing 64-bit integers. @@ -711,7 +985,24 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID() + def monotonicallyIncreasingId(): Column = monotonically_increasing_id() + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + * @since 1.6.0 + */ + def monotonically_increasing_id(): Column = withExpr { MonotonicallyIncreasingID() } /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -721,7 +1012,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def nanvl(col1: Column, col2: Column): Column = NaNvl(col1.expr, col2.expr) + def nanvl(col1: Column, col2: Column): Column = withExpr { NaNvl(col1.expr, col2.expr) } /** * Unary minus, i.e. negate the expression. @@ -757,10 +1048,12 @@ object functions { /** * Generate a random column with i.i.d. samples from U[0.0, 1.0]. * + * Note that this is indeterministic when data partitions are not fixed. + * * @group normal_funcs * @since 1.4.0 */ - def rand(seed: Long): Column = Rand(seed) + def rand(seed: Long): Column = withExpr { Rand(seed) } /** * Generate a random column with i.i.d. samples from U[0.0, 1.0]. @@ -773,10 +1066,12 @@ object functions { /** * Generate a column with i.i.d. samples from the standard normal distribution. * + * Note that this is indeterministic when data partitions are not fixed. + * * @group normal_funcs * @since 1.4.0 */ - def randn(seed: Long): Column = Randn(seed) + def randn(seed: Long): Column = withExpr { Randn(seed) } /** * Generate a column with i.i.d. samples from the standard normal distribution. @@ -792,9 +1087,9 @@ object functions { * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def sparkPartitionId(): Column = SparkPartitionID() + def spark_partition_id(): Column = withExpr { SparkPartitionID() } /** * Computes the square root of the specified float value. @@ -802,7 +1097,7 @@ object functions { * @group math_funcs * @since 1.3.0 */ - def sqrt(e: Column): Column = Sqrt(e.expr) + def sqrt(e: Column): Column = withExpr { Sqrt(e.expr) } /** * Computes the square root of the specified float value. @@ -823,9 +1118,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def struct(cols: Column*): Column = { - CreateStruct(cols.map(_.expr)) - } + def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) } /** * Creates a new struct column that composes multiple input columns. @@ -833,6 +1126,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { struct((colName +: colNames).map(col) : _*) } @@ -858,8 +1152,8 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = { - CaseWhen(Seq(condition.expr, lit(value).expr)) + def when(condition: Column, value: Any): Column = withExpr { + CaseWhen(Seq((condition.expr, lit(value).expr))) } /** @@ -868,7 +1162,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + def bitwiseNOT(e: Column): Column = withExpr { BitwiseNot(e.expr) } /** * Parses the expression string into the column that it represents, similar to @@ -880,7 +1174,9 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) + def expr(expr: String): Column = { + Column(SparkSqlParser.parseExpression(expr)) + } ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions @@ -893,7 +1189,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def acos(e: Column): Column = Acos(e.expr) + def acos(e: Column): Column = withExpr { Acos(e.expr) } /** * Computes the cosine inverse of the given column; the returned angle is in the range @@ -911,7 +1207,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def asin(e: Column): Column = Asin(e.expr) + def asin(e: Column): Column = withExpr { Asin(e.expr) } /** * Computes the sine inverse of the given column; the returned angle is in the range @@ -928,7 +1224,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan(e: Column): Column = Atan(e.expr) + def atan(e: Column): Column = withExpr { Atan(e.expr) } /** * Computes the tangent inverse of the given column. @@ -945,7 +1241,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -982,7 +1278,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -1000,7 +1296,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + def atan2(l: Double, r: Column): Column = atan2(lit(l), r) /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -1018,7 +1314,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def bin(e: Column): Column = Bin(e.expr) + def bin(e: Column): Column = withExpr { Bin(e.expr) } /** * An expression that returns the string representation of the binary value of the given long @@ -1035,7 +1331,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cbrt(e: Column): Column = Cbrt(e.expr) + def cbrt(e: Column): Column = withExpr { Cbrt(e.expr) } /** * Computes the cube-root of the given column. @@ -1051,7 +1347,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def ceil(e: Column): Column = Ceil(e.expr) + def ceil(e: Column): Column = withExpr { Ceil(e.expr) } /** * Computes the ceiling of the given column. @@ -1067,8 +1363,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def conv(num: Column, fromBase: Int, toBase: Int): Column = + def conv(num: Column, fromBase: Int, toBase: Int): Column = withExpr { Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + } /** * Computes the cosine of the given value. @@ -1076,7 +1373,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cos(e: Column): Column = Cos(e.expr) + def cos(e: Column): Column = withExpr { Cos(e.expr) } /** * Computes the cosine of the given column. @@ -1092,7 +1389,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cosh(e: Column): Column = Cosh(e.expr) + def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** * Computes the hyperbolic cosine of the given column. @@ -1108,7 +1405,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def exp(e: Column): Column = Exp(e.expr) + def exp(e: Column): Column = withExpr { Exp(e.expr) } /** * Computes the exponential of the given column. @@ -1124,7 +1421,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def expm1(e: Column): Column = Expm1(e.expr) + def expm1(e: Column): Column = withExpr { Expm1(e.expr) } /** * Computes the exponential of the given column. @@ -1140,7 +1437,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def factorial(e: Column): Column = Factorial(e.expr) + def factorial(e: Column): Column = withExpr { Factorial(e.expr) } /** * Computes the floor of the given value. @@ -1148,7 +1445,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def floor(e: Column): Column = Floor(e.expr) + def floor(e: Column): Column = withExpr { Floor(e.expr) } /** * Computes the floor of the given column. @@ -1166,7 +1463,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = { + def greatest(exprs: Column*): Column = withExpr { require(exprs.length > 1, "greatest requires at least 2 arguments.") Greatest(exprs.map(_.expr)) } @@ -1189,7 +1486,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def hex(column: Column): Column = Hex(column.expr) + def hex(column: Column): Column = withExpr { Hex(column.expr) } /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number @@ -1198,7 +1495,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = Unhex(column.expr) + def unhex(column: Column): Column = withExpr { Unhex(column.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1206,7 +1503,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + def hypot(l: Column, r: Column): Column = withExpr { Hypot(l.expr, r.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1239,7 +1536,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + def hypot(l: Column, r: Double): Column = hypot(l, lit(r)) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1255,7 +1552,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + def hypot(l: Double, r: Column): Column = hypot(lit(l), r) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1273,7 +1570,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = { + def least(exprs: Column*): Column = withExpr { require(exprs.length > 1, "least requires at least 2 arguments.") Least(exprs.map(_.expr)) } @@ -1296,7 +1593,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(e: Column): Column = Log(e.expr) + def log(e: Column): Column = withExpr { Log(e.expr) } /** * Computes the natural logarithm of the given column. @@ -1312,7 +1609,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + def log(base: Double, a: Column): Column = withExpr { Logarithm(lit(base).expr, a.expr) } /** * Returns the first argument-base logarithm of the second argument. @@ -1328,7 +1625,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log10(e: Column): Column = Log10(e.expr) + def log10(e: Column): Column = withExpr { Log10(e.expr) } /** * Computes the logarithm of the given value in base 10. @@ -1344,7 +1641,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log1p(e: Column): Column = Log1p(e.expr) + def log1p(e: Column): Column = withExpr { Log1p(e.expr) } /** * Computes the natural logarithm of the given column plus one. @@ -1360,7 +1657,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def log2(expr: Column): Column = Log2(expr.expr) + def log2(expr: Column): Column = withExpr { Log2(expr.expr) } /** * Computes the logarithm of the given value in base 2. @@ -1376,7 +1673,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + def pow(l: Column, r: Column): Column = withExpr { Pow(l.expr, r.expr) } /** * Returns the value of the first argument raised to the power of the second argument. @@ -1408,7 +1705,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + def pow(l: Column, r: Double): Column = pow(l, lit(r)) /** * Returns the value of the first argument raised to the power of the second argument. @@ -1424,7 +1721,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + def pow(l: Double, r: Column): Column = pow(lit(l), r) /** * Returns the value of the first argument raised to the power of the second argument. @@ -1440,7 +1737,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + def pmod(dividend: Column, divisor: Column): Column = withExpr { + Pmod(dividend.expr, divisor.expr) + } /** * Returns the double value that is closest in value to the argument and @@ -1449,7 +1748,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def rint(e: Column): Column = Rint(e.expr) + def rint(e: Column): Column = withExpr { Rint(e.expr) } /** * Returns the double value that is closest in value to the argument and @@ -1466,7 +1765,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column): Column = round(e.expr, 0) + def round(e: Column): Column = round(e, 0) /** * Round the value of `e` to `scale` decimal places if `scale` >= 0 @@ -1475,35 +1774,38 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } /** - * Shift the the given value numBits left. If the given value is a long value, this function + * Shift the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. * * @group math_funcs * @since 1.5.0 */ - def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + def shiftLeft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) } /** - * Shift the the given value numBits right. If the given value is a long value, it will return + * Shift the given value numBits right. If the given value is a long value, it will return * a long value else it will return an integer value. * * @group math_funcs * @since 1.5.0 */ - def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + def shiftRight(e: Column, numBits: Int): Column = withExpr { + ShiftRight(e.expr, lit(numBits).expr) + } /** - * Unsigned shift the the given value numBits right. If the given value is a long value, + * Unsigned shift the given value numBits right. If the given value is a long value, * it will return a long value else it will return an integer value. * * @group math_funcs * @since 1.5.0 */ - def shiftRightUnsigned(e: Column, numBits: Int): Column = + def shiftRightUnsigned(e: Column, numBits: Int): Column = withExpr { ShiftRightUnsigned(e.expr, lit(numBits).expr) + } /** * Computes the signum of the given value. @@ -1511,7 +1813,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def signum(e: Column): Column = Signum(e.expr) + def signum(e: Column): Column = withExpr { Signum(e.expr) } /** * Computes the signum of the given column. @@ -1527,7 +1829,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sin(e: Column): Column = Sin(e.expr) + def sin(e: Column): Column = withExpr { Sin(e.expr) } /** * Computes the sine of the given column. @@ -1543,7 +1845,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sinh(e: Column): Column = Sinh(e.expr) + def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** * Computes the hyperbolic sine of the given column. @@ -1559,7 +1861,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tan(e: Column): Column = Tan(e.expr) + def tan(e: Column): Column = withExpr { Tan(e.expr) } /** * Computes the tangent of the given column. @@ -1575,7 +1877,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tanh(e: Column): Column = Tanh(e.expr) + def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** * Computes the hyperbolic tangent of the given column. @@ -1591,7 +1893,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def toDegrees(e: Column): Column = ToDegrees(e.expr) + def toDegrees(e: Column): Column = withExpr { ToDegrees(e.expr) } /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. @@ -1607,7 +1909,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def toRadians(e: Column): Column = ToRadians(e.expr) + def toRadians(e: Column): Column = withExpr { ToRadians(e.expr) } /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. @@ -1628,7 +1930,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def md5(e: Column): Column = Md5(e.expr) + def md5(e: Column): Column = withExpr { Md5(e.expr) } /** * Calculates the SHA-1 digest of a binary column and returns the value @@ -1637,7 +1939,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def sha1(e: Column): Column = Sha1(e.expr) + def sha1(e: Column): Column = withExpr { Sha1(e.expr) } /** * Calculates the SHA-2 family of hash functions of a binary column and @@ -1652,7 +1954,7 @@ object functions { def sha2(e: Column, numBits: Int): Column = { require(Seq(0, 224, 256, 384, 512).contains(numBits), s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") - Sha2(e.expr, lit(numBits).expr) + withExpr { Sha2(e.expr, lit(numBits).expr) } } /** @@ -1662,7 +1964,18 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def crc32(e: Column): Column = Crc32(e.expr) + def crc32(e: Column): Column = withExpr { Crc32(e.expr) } + + /** + * Calculates the hash code of given columns, and returns the result as an int column. + * + * @group misc_funcs + * @since 2.0 + */ + @scala.annotation.varargs + def hash(cols: Column*): Column = withExpr { + new Murmur3Hash(cols.map(_.expr)) + } ////////////////////////////////////////////////////////////////////////////////////////////// // String functions @@ -1675,7 +1988,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ascii(e: Column): Column = Ascii(e.expr) + def ascii(e: Column): Column = withExpr { Ascii(e.expr) } /** * Computes the BASE64 encoding of a binary column and returns it as a string column. @@ -1684,7 +1997,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def base64(e: Column): Column = Base64(e.expr) + def base64(e: Column): Column = withExpr { Base64(e.expr) } /** * Concatenates multiple input string columns together into a single string column. @@ -1693,7 +2006,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } /** * Concatenates multiple input string columns together into a single string column, @@ -1703,7 +2016,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat_ws(sep: String, exprs: Column*): Column = { + def concat_ws(sep: String, exprs: Column*): Column = withExpr { ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) } @@ -1715,7 +2028,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) + def decode(value: Column, charset: String): Column = withExpr { + Decode(value.expr, lit(charset).expr) + } /** * Computes the first argument into a binary from a string using the provided character set @@ -1725,7 +2040,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) + def encode(value: Column, charset: String): Column = withExpr { + Encode(value.expr, lit(charset).expr) + } /** * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, @@ -1737,7 +2054,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + def format_number(x: Column, d: Int): Column = withExpr { + FormatNumber(x.expr, lit(d).expr) + } /** * Formats the arguments in printf-style and returns the result as a string column. @@ -1746,7 +2065,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def format_string(format: String, arguments: Column*): Column = { + def format_string(format: String, arguments: Column*): Column = withExpr { FormatString((lit(format) +: arguments).map(_.expr): _*) } @@ -1759,7 +2078,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def initcap(e: Column): Column = InitCap(e.expr) + def initcap(e: Column): Column = withExpr { InitCap(e.expr) } /** * Locate the position of the first occurrence of substr column in the given string. @@ -1771,7 +2090,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) + def instr(str: Column, substring: String): Column = withExpr { + StringInstr(str.expr, lit(substring).expr) + } /** * Computes the length of a given string or binary column. @@ -1779,7 +2100,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def length(e: Column): Column = Length(e.expr) + def length(e: Column): Column = withExpr { Length(e.expr) } /** * Converts a string column to lower case. @@ -1787,14 +2108,14 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def lower(e: Column): Column = Lower(e.expr) + def lower(e: Column): Column = withExpr { Lower(e.expr) } /** * Computes the Levenshtein distance of the two given string columns. * @group string_funcs * @since 1.5.0 */ - def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + def levenshtein(l: Column, r: Column): Column = withExpr { Levenshtein(l.expr, r.expr) } /** * Locate the position of the first occurrence of substr. @@ -1804,7 +2125,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column): Column = { + def locate(substr: String, str: Column): Column = withExpr { new StringLocate(lit(substr).expr, str.expr) } @@ -1817,7 +2138,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column, pos: Int): Column = { + def locate(substr: String, str: Column, pos: Int): Column = withExpr { StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } @@ -1827,7 +2148,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def lpad(str: Column, len: Int, pad: String): Column = { + def lpad(str: Column, len: Int, pad: String): Column = withExpr { StringLPad(str.expr, lit(len).expr, lit(pad).expr) } @@ -1837,7 +2158,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ltrim(e: Column): Column = StringTrimLeft(e.expr) + def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } /** * Extract a specific(idx) group identified by a java regex, from the specified string column. @@ -1845,7 +2166,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = withExpr { RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) } @@ -1855,7 +2176,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + def regexp_replace(e: Column, pattern: String, replacement: String): Column = withExpr { RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) } @@ -1866,7 +2187,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def unbase64(e: Column): Column = UnBase64(e.expr) + def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } /** * Right-padded with pad to a length of len. @@ -1874,7 +2195,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: String): Column = { + def rpad(str: Column, len: Int, pad: String): Column = withExpr { StringRPad(str.expr, lit(len).expr, lit(pad).expr) } @@ -1884,7 +2205,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def repeat(str: Column, n: Int): Column = { + def repeat(str: Column, n: Int): Column = withExpr { StringRepeat(str.expr, lit(n).expr) } @@ -1894,9 +2215,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def reverse(str: Column): Column = { - StringReverse(str.expr) - } + def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } /** * Trim the spaces from right end for the specified string value. @@ -1904,7 +2223,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rtrim(e: Column): Column = StringTrimRight(e.expr) + def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } /** * * Return the soundex code for the specified expression. @@ -1912,16 +2231,16 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def soundex(e: Column): Column = SoundEx(e.expr) + def soundex(e: Column): Column = withExpr { SoundEx(e.expr) } /** * Splits str around pattern (pattern is a regular expression). - * NOTE: pattern is a string represent the regular expression. + * NOTE: pattern is a string representation of the regular expression. * * @group string_funcs * @since 1.5.0 */ - def split(str: Column, pattern: String): Column = { + def split(str: Column, pattern: String): Column = withExpr { StringSplit(str.expr, lit(pattern).expr) } @@ -1933,8 +2252,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def substring(str: Column, pos: Int, len: Int): Column = + def substring(str: Column, pos: Int, len: Int): Column = withExpr { Substring(str.expr, lit(pos).expr, lit(len).expr) + } /** * Returns the substring from string str before count occurrences of the delimiter delim. @@ -1944,20 +2264,22 @@ object functions { * * @group string_funcs */ - def substring_index(str: Column, delim: String, count: Int): Column = + def substring_index(str: Column, delim: String, count: Int): Column = withExpr { SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + } /** * Translate any character in the src by a character in replaceString. - * The characters in replaceString is corresponding to the characters in matchingString. - * The translate will happen when any character in the string matching with the character - * in the matchingString. + * The characters in replaceString correspond to the characters in matchingString. + * The translate will happen when any character in the string matches the character + * in the `matchingString`. * * @group string_funcs * @since 1.5.0 */ - def translate(src: Column, matchingString: String, replaceString: String): Column = + def translate(src: Column, matchingString: String, replaceString: String): Column = withExpr { StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + } /** * Trim the spaces from both ends for the specified string column. @@ -1965,7 +2287,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def trim(e: Column): Column = StringTrim(e.expr) + def trim(e: Column): Column = withExpr { StringTrim(e.expr) } /** * Converts a string column to upper case. @@ -1973,7 +2295,7 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def upper(e: Column): Column = Upper(e.expr) + def upper(e: Column): Column = withExpr { Upper(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // DateTime functions @@ -1985,8 +2307,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def add_months(startDate: Column, numMonths: Int): Column = + def add_months(startDate: Column, numMonths: Int): Column = withExpr { AddMonths(startDate.expr, Literal(numMonths)) + } /** * Returns the current date as a date column. @@ -1994,7 +2317,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_date(): Column = CurrentDate() + def current_date(): Column = withExpr { CurrentDate() } /** * Returns the current timestamp as a timestamp column. @@ -2002,7 +2325,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_timestamp(): Column = CurrentTimestamp() + def current_timestamp(): Column = withExpr { CurrentTimestamp() } /** * Converts a date/timestamp/string to a value of string in the format specified by the date @@ -2017,71 +2340,72 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def date_format(dateExpr: Column, format: String): Column = + def date_format(dateExpr: Column, format: String): Column = withExpr { DateFormatClass(dateExpr.expr, Literal(format)) + } /** * Returns the date that is `days` days after `start` * @group datetime_funcs * @since 1.5.0 */ - def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + def date_add(start: Column, days: Int): Column = withExpr { DateAdd(start.expr, Literal(days)) } /** * Returns the date that is `days` days before `start` * @group datetime_funcs * @since 1.5.0 */ - def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + def date_sub(start: Column, days: Int): Column = withExpr { DateSub(start.expr, Literal(days)) } /** * Returns the number of days from `start` to `end`. * @group datetime_funcs * @since 1.5.0 */ - def datediff(end: Column, start: Column): Column = DateDiff(end.expr, start.expr) + def datediff(end: Column, start: Column): Column = withExpr { DateDiff(end.expr, start.expr) } /** * Extracts the year as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def year(e: Column): Column = Year(e.expr) + def year(e: Column): Column = withExpr { Year(e.expr) } /** * Extracts the quarter as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def quarter(e: Column): Column = Quarter(e.expr) + def quarter(e: Column): Column = withExpr { Quarter(e.expr) } /** * Extracts the month as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def month(e: Column): Column = Month(e.expr) + def month(e: Column): Column = withExpr { Month(e.expr) } /** * Extracts the day of the month as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def dayofmonth(e: Column): Column = DayOfMonth(e.expr) + def dayofmonth(e: Column): Column = withExpr { DayOfMonth(e.expr) } /** * Extracts the day of the year as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def dayofyear(e: Column): Column = DayOfYear(e.expr) + def dayofyear(e: Column): Column = withExpr { DayOfYear(e.expr) } /** * Extracts the hours as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def hour(e: Column): Column = Hour(e.expr) + def hour(e: Column): Column = withExpr { Hour(e.expr) } /** * Given a date column, returns the last day of the month which the given date belongs to. @@ -2091,21 +2415,23 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def last_day(e: Column): Column = LastDay(e.expr) + def last_day(e: Column): Column = withExpr { LastDay(e.expr) } /** * Extracts the minutes as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def minute(e: Column): Column = Minute(e.expr) + def minute(e: Column): Column = withExpr { Minute(e.expr) } /* * Returns number of months between dates `date1` and `date2`. * @group datetime_funcs * @since 1.5.0 */ - def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + def months_between(date1: Column, date2: Column): Column = withExpr { + MonthsBetween(date1.expr, date2.expr) + } /** * Given a date column, returns the first date which is later than the value of the date column @@ -2120,21 +2446,23 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) + def next_day(date: Column, dayOfWeek: String): Column = withExpr { + NextDay(date.expr, lit(dayOfWeek).expr) + } /** * Extracts the seconds as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def second(e: Column): Column = Second(e.expr) + def second(e: Column): Column = withExpr { Second(e.expr) } /** * Extracts the week number as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def weekofyear(e: Column): Column = WeekOfYear(e.expr) + def weekofyear(e: Column): Column = withExpr { WeekOfYear(e.expr) } /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -2143,7 +2471,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + def from_unixtime(ut: Column): Column = withExpr { + FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -2152,14 +2482,18 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + def from_unixtime(ut: Column, f: String): Column = withExpr { + FromUnixTime(ut.expr, Literal(f)) + } /** * Gets current Unix timestamp in seconds. * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + def unix_timestamp(): Column = withExpr { + UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), @@ -2167,7 +2501,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + def unix_timestamp(s: Column): Column = withExpr { + UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Convert time string with given pattern @@ -2176,7 +2512,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + def unix_timestamp(s: Column, p: String): Column = withExpr {UnixTimestamp(s.expr, Literal(p)) } /** * Converts the column into DateType. @@ -2184,7 +2520,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def to_date(e: Column): Column = ToDate(e.expr) + def to_date(e: Column): Column = withExpr { ToDate(e.expr) } /** * Returns date truncated to the unit specified by the format. @@ -2195,34 +2531,174 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + def trunc(date: Column, format: String): Column = withExpr { + TruncDate(date.expr, Literal(format)) + } /** * Assumes given timestamp is UTC and converts to given timezone. * @group datetime_funcs * @since 1.5.0 */ - def from_utc_timestamp(ts: Column, tz: String): Column = - FromUTCTimestamp(ts.expr, Literal(tz).expr) + def from_utc_timestamp(ts: Column, tz: String): Column = withExpr { + FromUTCTimestamp(ts.expr, Literal(tz)) + } /** * Assumes given timestamp is in given timezone and converts to UTC. * @group datetime_funcs * @since 1.5.0 */ - def to_utc_timestamp(ts: Column, tz: String): Column = ToUTCTimestamp(ts.expr, Literal(tz).expr) + def to_utc_timestamp(ts: Column, tz: String): Column = withExpr { + ToUTCTimestamp(ts.expr, Literal(tz)) + } + + /** + * Bucketize rows into one or more time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The following example takes the average stock price for + * a one minute window every 10 seconds starting 5 seconds after the hour: + * + * {{{ + * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute", "10 seconds", "5 seconds"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:05-09:01:05 + * 09:00:15-09:01:15 + * 09:00:25-09:01:25 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time column must be of TimestampType. + * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. + * A new window will be generated every `slideDuration`. Must be less than + * or equal to the `windowDuration`. Check + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration + * identifiers. + * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start + * window intervals. For example, in order to have hourly tumbling windows that + * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide + * `startTime` as `15 minutes`. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window( + timeColumn: Column, + windowDuration: String, + slideDuration: String, + startTime: String): Column = { + withExpr { + TimeWindow(timeColumn.expr, windowDuration, slideDuration, startTime) + }.as("window") + } + + + /** + * Bucketize rows into one or more time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. + * The following example takes the average stock price for a one minute window every 10 seconds: + * + * {{{ + * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute", "10 seconds"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:00-09:01:00 + * 09:00:10-09:01:10 + * 09:00:20-09:01:20 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time column must be of TimestampType. + * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. + * A new window will be generated every `slideDuration`. Must be less than + * or equal to the `windowDuration`. Check + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { + window(timeColumn, windowDuration, slideDuration, "0 second") + } + + /** + * Generates tumbling time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. + * The following example takes the average stock price for a one minute tumbling window: + * + * {{{ + * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:00-09:01:00 + * 09:01:00-09:02:00 + * 09:02:00-09:03:00 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time column must be of TimestampType. + * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window(timeColumn: Column, windowDuration: String): Column = { + window(timeColumn, windowDuration, windowDuration, "0 second") + } ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns true if the array contain the value + * Returns true if the array contains `value` * @group collection_funcs * @since 1.5.0 */ - def array_contains(column: Column, value: Any): Column = + def array_contains(column: Column, value: Any): Column = withExpr { ArrayContains(column.expr, Literal(value)) + } /** * Creates a new row for each element in the given array or map column. @@ -2230,7 +2706,30 @@ object functions { * @group collection_funcs * @since 1.3.0 */ - def explode(e: Column): Column = Explode(e.expr) + def explode(e: Column): Column = withExpr { Explode(e.expr) } + + /** + * Extracts json object from a json string based on json path specified, and returns json string + * of the extracted json object. It will return null if the input json string is invalid. + * + * @group collection_funcs + * @since 1.6.0 + */ + def get_json_object(e: Column, path: String): Column = withExpr { + GetJsonObject(e.expr, lit(path).expr) + } + + /** + * Creates a new row for a json column according to the given field names. + * + * @group collection_funcs + * @since 1.6.0 + */ + @scala.annotation.varargs + def json_tuple(json: Column, fields: String*): Column = withExpr { + require(fields.nonEmpty, "at least 1 field name should be given.") + JsonTuple(json.expr +: fields.map(Literal.apply)) + } /** * Returns length of array or map. @@ -2238,7 +2737,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = Size(e.expr) + def size(e: Column): Column = withExpr { Size(e.expr) } /** * Sorts the input array for the given column in ascending order, @@ -2256,12 +2755,13 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) + def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off + // scalastyle:off line.size.limit + // scalastyle:off parameter.number /* Use the following code to generate: (0 to 10).map { x => @@ -2277,30 +2777,11 @@ object functions { * @since 1.3.0 */ def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val inputTypes = Try($inputTypes).getOrElse(Nil) + val inputTypes = Try($inputTypes).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) }""") } - (0 to 10).map { x => - val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") - val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") - println(s""" - /** - * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUDF(f, returnType, Seq($argsInUDF)) - }""") - } - } */ /** * Defines a user-defined function of 0 arguments as user-defined function (UDF). @@ -2310,7 +2791,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - val inputTypes = Try(Nil).getOrElse(Nil) + val inputTypes = Try(Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2322,7 +2803,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2334,7 +2815,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2346,7 +2827,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2358,7 +2839,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2370,7 +2851,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2382,7 +2863,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2394,7 +2875,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2406,7 +2887,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2418,7 +2899,7 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2430,157 +2911,27 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } - ////////////////////////////////////////////////////////////////////////////////////////////////// - - /** - * Call a Scala function of 0 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUDF(f, returnType, Seq()) - } - - /** - * Call a Scala function of 1 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr)) - } + // scalastyle:on parameter.number + // scalastyle:on line.size.limit /** - * Call a Scala function of 2 arguments as user-defined function (UDF). This requires - * you to specify the return data type. + * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must + * specify the output data type, and there is no automatic input type coercion. * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) - } - - /** - * Call a Scala function of 3 arguments as user-defined function (UDF). This requires - * you to specify the return data type. + * @param f A closure in Scala + * @param dataType The output data type of the UDF * * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @since 2.0.0 */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + UserDefinedFunction(f, dataType, None) } - /** - * Call a Scala function of 4 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) - } - - /** - * Call a Scala function of 5 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) - } - - /** - * Call a Scala function of 6 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) - } - - /** - * Call a Scala function of 7 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) - } - - /** - * Call a Scala function of 8 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) - } - - /** - * Call a Scala function of 9 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) - } - - /** - * Call a Scala function of 10 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) - } - - // scalastyle:on - /** * Call an user-defined function. * Example: @@ -2597,36 +2948,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def callUDF(udfName: String, cols: Column*): Column = { + def callUDF(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } - /** - * Call an user-defined function. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") - * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUDF", $"value")) - * }}} - * - * @group udf_funcs - * @since 1.4.0 - * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF - */ - @deprecated("Use callUDF", "1.5.0") - def callUdf(udfName: String, cols: Column*): Column = { - // Note: we avoid using closures here because on file systems that are case-insensitive, the - // compiled class file for the closure here will conflict with the one in callUDF (upper case). - val exprs = new Array[Expression](cols.size) - var i = 0 - while (i < cols.size) { - exprs(i) = cols(i).expr - i += 1 - } - UnresolvedFunction(udfName, exprs, isDistinct = false) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala new file mode 100644 index 0000000000000..058df1e3c19a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.sql.RuntimeConfig + +/** + * Implementation for [[RuntimeConfig]]. + */ +class RuntimeConfigImpl extends RuntimeConfig { + + private val conf = new SQLConf + + private val hadoopConf = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, String]()) + + override def set(key: String, value: String): RuntimeConfig = { + conf.setConfString(key, value) + this + } + + override def set(key: String, value: Boolean): RuntimeConfig = set(key, value.toString) + + override def set(key: String, value: Long): RuntimeConfig = set(key, value.toString) + + @throws[NoSuchElementException]("if the key is not set") + override def get(key: String): String = conf.getConfString(key) + + override def getOption(key: String): Option[String] = { + try Option(get(key)) catch { + case _: NoSuchElementException => None + } + } + + override def unset(key: String): Unit = conf.unsetConf(key) + + override def setHadoop(key: String, value: String): RuntimeConfig = { + hadoopConf.put(key, value) + this + } + + @throws[NoSuchElementException]("if the key is not set") + override def getHadoop(key: String): String = hadoopConf.synchronized { + if (hadoopConf.containsKey(key)) { + hadoopConf.get(key) + } else { + throw new NoSuchElementException(key) + } + } + + override def getHadoopOption(key: String): Option[String] = { + try Option(getHadoop(key)) catch { + case _: NoSuchElementException => None + } + } + + override def unsetHadoop(key: String): Unit = hadoopConf.remove(key) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala new file mode 100644 index 0000000000000..20d9a285483f0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -0,0 +1,692 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{NoSuchElementException, Properties} + +import scala.collection.JavaConverters._ +import scala.collection.immutable + +import org.apache.parquet.hadoop.ParquetOutputCommitter + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.catalyst.CatalystConf + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines the configuration options for Spark SQL. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +object SQLConf { + + private val sqlConfEntries = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, ConfigEntry[_]]()) + + private def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized { + require(!sqlConfEntries.containsKey(entry.key), + s"Duplicate SQLConfigEntry. ${entry.key} has been registered") + sqlConfEntries.put(entry.key, entry) + } + + private[sql] object SQLConfigBuilder { + + def apply(key: String): ConfigBuilder = new ConfigBuilder(key).onCreate(register) + + } + + val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts") + .doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " + + "When set to false, only one SQLContext/HiveContext is allowed to be created " + + "through the constructor (new SQLContexts/HiveContexts created through newSession " + + "method is allowed). Please note that this conf needs to be set in Spark Conf. Once " + + "a SQLContext/HiveContext has been created, changing the value of this conf will not " + + "have effect.") + .booleanConf + .createWithDefault(true) + + val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed") + .internal() + .doc("When set to true Spark SQL will automatically select a compression codec for each " + + "column based on statistics of the data.") + .booleanConf + .createWithDefault(true) + + val COLUMN_BATCH_SIZE = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.batchSize") + .internal() + .doc("Controls the size of batches for columnar caching. Larger batch sizes can improve " + + "memory utilization and compression, but risk OOMs when caching data.") + .intConf + .createWithDefault(10000) + + val IN_MEMORY_PARTITION_PRUNING = + SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.partitionPruning") + .internal() + .doc("When true, enable partition pruning for in-memory columnar tables.") + .booleanConf + .createWithDefault(true) + + val PREFER_SORTMERGEJOIN = SQLConfigBuilder("spark.sql.join.preferSortMergeJoin") + .internal() + .doc("When true, prefer sort merge join over shuffle hash join.") + .booleanConf + .createWithDefault(true) + + val AUTO_BROADCASTJOIN_THRESHOLD = SQLConfigBuilder("spark.sql.autoBroadcastJoinThreshold") + .doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " + + "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + + "Note that currently statistics are only supported for Hive Metastore tables where the " + + "commandANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") + .intConf + .createWithDefault(10 * 1024 * 1024) + + val DEFAULT_SIZE_IN_BYTES = SQLConfigBuilder("spark.sql.defaultSizeInBytes") + .internal() + .doc("The default table size used in query planning. By default, it is set to a larger " + + "value than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. That is to say " + + "by default the optimizer will not choose to broadcast a table unless it knows for sure " + + "its size is small enough.") + .longConf + .createWithDefault(-1) + + val SHUFFLE_PARTITIONS = SQLConfigBuilder("spark.sql.shuffle.partitions") + .doc("The default number of partitions to use when shuffling data for joins or aggregations.") + .intConf + .createWithDefault(200) + + val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = + SQLConfigBuilder("spark.sql.adaptive.shuffle.targetPostShuffleInputSize") + .doc("The target post-shuffle input size in bytes of a task.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(64 * 1024 * 1024) + + val ADAPTIVE_EXECUTION_ENABLED = SQLConfigBuilder("spark.sql.adaptive.enabled") + .doc("When true, enable adaptive query execution.") + .booleanConf + .createWithDefault(false) + + val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = + SQLConfigBuilder("spark.sql.adaptive.minNumPostShufflePartitions") + .internal() + .doc("The advisory minimal number of post-shuffle partitions provided to " + + "ExchangeCoordinator. This setting is used in our test to make sure we " + + "have enough parallelism to expose issues that will not be exposed with a " + + "single partition. When the value is a non-positive value, this setting will " + + "not be provided to ExchangeCoordinator.") + .intConf + .createWithDefault(-1) + + val SUBEXPRESSION_ELIMINATION_ENABLED = + SQLConfigBuilder("spark.sql.subexpressionElimination.enabled") + .internal() + .doc("When true, common subexpressions will be eliminated.") + .booleanConf + .createWithDefault(true) + + val CASE_SENSITIVE = SQLConfigBuilder("spark.sql.caseSensitive") + .doc("Whether the query analyzer should be case sensitive or not.") + .booleanConf + .createWithDefault(true) + + val PARQUET_SCHEMA_MERGING_ENABLED = SQLConfigBuilder("spark.sql.parquet.mergeSchema") + .doc("When true, the Parquet data source merges schemas collected from all data files, " + + "otherwise the schema is picked from the summary file or a random data file " + + "if no summary file is available.") + .booleanConf + .createWithDefault(false) + + val PARQUET_SCHEMA_RESPECT_SUMMARIES = SQLConfigBuilder("spark.sql.parquet.respectSummaryFiles") + .doc("When true, we make assumption that all part-files of Parquet are consistent with " + + "summary files and we will ignore them when merging schema. Otherwise, if this is " + + "false, which is the default, we will merge all part-files. This should be considered " + + "as expert-only option, and shouldn't be enabled before knowing what it means exactly.") + .booleanConf + .createWithDefault(false) + + val PARQUET_BINARY_AS_STRING = SQLConfigBuilder("spark.sql.parquet.binaryAsString") + .doc("Some other Parquet-producing systems, in particular Impala and older versions of " + + "Spark SQL, do not differentiate between binary data and strings when writing out the " + + "Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " + + "compatibility with these systems.") + .booleanConf + .createWithDefault(false) + + val PARQUET_INT96_AS_TIMESTAMP = SQLConfigBuilder("spark.sql.parquet.int96AsTimestamp") + .doc("Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " + + "Spark would also store Timestamp as INT96 because we need to avoid precision lost of the " + + "nanoseconds field. This flag tells Spark SQL to interpret INT96 data as a timestamp to " + + "provide compatibility with these systems.") + .booleanConf + .createWithDefault(true) + + val PARQUET_CACHE_METADATA = SQLConfigBuilder("spark.sql.parquet.cacheMetadata") + .doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.") + .booleanConf + .createWithDefault(true) + + val PARQUET_COMPRESSION = SQLConfigBuilder("spark.sql.parquet.compression.codec") + .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + + "uncompressed, snappy, gzip, lzo.") + .stringConf + .transform(_.toLowerCase()) + .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .createWithDefault("snappy") + + val PARQUET_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.parquet.filterPushdown") + .doc("Enables Parquet filter push-down optimization when set to true.") + .booleanConf + .createWithDefault(true) + + val PARQUET_WRITE_LEGACY_FORMAT = SQLConfigBuilder("spark.sql.parquet.writeLegacyFormat") + .doc("Whether to follow Parquet's format specification when converting Parquet schema to " + + "Spark SQL schema and vice versa.") + .booleanConf + .createWithDefault(false) + + val PARQUET_OUTPUT_COMMITTER_CLASS = SQLConfigBuilder("spark.sql.parquet.output.committer.class") + .doc("The output committer class used by Parquet. The specified class needs to be a " + + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + + "option must be set in Hadoop Configuration. 2. This option overrides " + + "\"spark.sql.sources.outputCommitterClass\".") + .stringConf + .createWithDefault(classOf[ParquetOutputCommitter].getName) + + val PARQUET_VECTORIZED_READER_ENABLED = + SQLConfigBuilder("spark.sql.parquet.enableVectorizedReader") + .doc("Enables vectorized parquet decoding.") + .booleanConf + .createWithDefault(true) + + val ORC_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.orc.filterPushdown") + .doc("When true, enable filter pushdown for ORC files.") + .booleanConf + .createWithDefault(false) + + val HIVE_VERIFY_PARTITION_PATH = SQLConfigBuilder("spark.sql.hive.verifyPartitionPath") + .doc("When true, check all the partition paths under the table\'s root directory " + + "when reading data stored in HDFS.") + .booleanConf + .createWithDefault(false) + + val HIVE_METASTORE_PARTITION_PRUNING = + SQLConfigBuilder("spark.sql.hive.metastorePartitionPruning") + .doc("When true, some predicates will be pushed down into the Hive metastore so that " + + "unmatching partitions can be eliminated earlier.") + .booleanConf + .createWithDefault(false) + + val NATIVE_VIEW = SQLConfigBuilder("spark.sql.nativeView") + .internal() + .doc("When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " + + "Note that this function is experimental and should ony be used when you are using " + + "non-hive-compatible tables written by Spark SQL. The SQL string used to create " + + "view should be fully qualified, i.e. use `tbl1`.`col1` instead of `*` whenever " + + "possible, or you may get wrong result.") + .booleanConf + .createWithDefault(true) + + val CANONICAL_NATIVE_VIEW = SQLConfigBuilder("spark.sql.nativeView.canonical") + .internal() + .doc("When this option and spark.sql.nativeView are both true, Spark SQL tries to handle " + + "CREATE VIEW statement using SQL query string generated from view definition logical " + + "plan. If the logical plan doesn't have a SQL representation, we fallback to the " + + "original native view implementation.") + .booleanConf + .createWithDefault(true) + + val COLUMN_NAME_OF_CORRUPT_RECORD = SQLConfigBuilder("spark.sql.columnNameOfCorruptRecord") + .doc("The name of internal column for storing raw/un-parsed JSON records that fail to parse.") + .stringConf + .createWithDefault("_corrupt_record") + + val BROADCAST_TIMEOUT = SQLConfigBuilder("spark.sql.broadcastTimeout") + .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") + .intConf + .createWithDefault(5 * 60) + + // This is only used for the thriftserver + val THRIFTSERVER_POOL = SQLConfigBuilder("spark.sql.thriftserver.scheduler.pool") + .doc("Set a Fair Scheduler pool for a JDBC client session.") + .stringConf + .createOptional + + val THRIFTSERVER_UI_STATEMENT_LIMIT = + SQLConfigBuilder("spark.sql.thriftserver.ui.retainedStatements") + .doc("The number of SQL statements kept in the JDBC/ODBC web UI history.") + .intConf + .createWithDefault(200) + + val THRIFTSERVER_UI_SESSION_LIMIT = SQLConfigBuilder("spark.sql.thriftserver.ui.retainedSessions") + .doc("The number of SQL client sessions kept in the JDBC/ODBC web UI history.") + .intConf + .createWithDefault(200) + + // This is used to set the default data source + val DEFAULT_DATA_SOURCE_NAME = SQLConfigBuilder("spark.sql.sources.default") + .doc("The default data source to use in input/output.") + .stringConf + .createWithDefault("org.apache.spark.sql.parquet") + + // This is used to control the when we will split a schema's JSON string to multiple pieces + // in order to fit the JSON string in metastore's table property (by default, the value has + // a length restriction of 4000 characters). We will split the JSON string of a schema + // to its length exceeds the threshold. + val SCHEMA_STRING_LENGTH_THRESHOLD = + SQLConfigBuilder("spark.sql.sources.schemaStringLengthThreshold") + .doc("The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.") + .internal() + .intConf + .createWithDefault(4000) + + val PARTITION_DISCOVERY_ENABLED = SQLConfigBuilder("spark.sql.sources.partitionDiscovery.enabled") + .doc("When true, automatically discover data partitions.") + .booleanConf + .createWithDefault(true) + + val PARTITION_COLUMN_TYPE_INFERENCE = + SQLConfigBuilder("spark.sql.sources.partitionColumnTypeInference.enabled") + .doc("When true, automatically infer the data types for partitioned columns.") + .booleanConf + .createWithDefault(true) + + val PARTITION_MAX_FILES = + SQLConfigBuilder("spark.sql.sources.maxConcurrentWrites") + .doc("The maximum number of concurrent files to open before falling back on sorting when " + + "writing out files using dynamic partitioning.") + .intConf + .createWithDefault(1) + + val BUCKETING_ENABLED = SQLConfigBuilder("spark.sql.sources.bucketing.enabled") + .doc("When false, we will treat bucketed table as normal table") + .booleanConf + .createWithDefault(true) + + val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal") + .doc("When true, the ordinal numbers are treated as the position in the select list. " + + "When false, the ordinal numbers in order/sort By clause are ignored.") + .booleanConf + .createWithDefault(true) + + val GROUP_BY_ORDINAL = SQLConfigBuilder("spark.sql.groupByOrdinal") + .doc("When true, the ordinal numbers in group by clauses are treated as the position " + + "in the select list. When false, the ordinal numbers are ignored.") + .booleanConf + .createWithDefault(true) + + // The output committer class used by HadoopFsRelation. The specified class needs to be a + // subclass of org.apache.hadoop.mapreduce.OutputCommitter. + // + // NOTE: + // + // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. + // 2. This option can be overridden by "spark.sql.parquet.output.committer.class". + val OUTPUT_COMMITTER_CLASS = + SQLConfigBuilder("spark.sql.sources.outputCommitterClass").internal().stringConf.createOptional + + val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = + SQLConfigBuilder("spark.sql.sources.parallelPartitionDiscovery.threshold") + .doc("The degree of parallelism for schema merging and partition discovery of " + + "Parquet data sources.") + .intConf + .createWithDefault(32) + + // Whether to perform eager analysis when constructing a dataframe. + // Set to false when debugging requires the ability to look at invalid query plans. + val DATAFRAME_EAGER_ANALYSIS = SQLConfigBuilder("spark.sql.eagerAnalysis") + .internal() + .doc("When true, eagerly applies query analysis on DataFrame operations.") + .booleanConf + .createWithDefault(true) + + // Whether to automatically resolve ambiguity in join conditions for self-joins. + // See SPARK-6231. + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = + SQLConfigBuilder("spark.sql.selfJoinAutoResolveAmbiguity") + .internal() + .booleanConf + .createWithDefault(true) + + // Whether to retain group by columns or not in GroupedData.agg. + val DATAFRAME_RETAIN_GROUP_COLUMNS = SQLConfigBuilder("spark.sql.retainGroupColumns") + .internal() + .booleanConf + .createWithDefault(true) + + val DATAFRAME_PIVOT_MAX_VALUES = SQLConfigBuilder("spark.sql.pivotMaxValues") + .doc("When doing a pivot without specifying values for the pivot column this is the maximum " + + "number of (distinct) values that will be collected without error.") + .intConf + .createWithDefault(10000) + + val RUN_SQL_ON_FILES = SQLConfigBuilder("spark.sql.runSQLOnFiles") + .internal() + .doc("When true, we could use `datasource`.`path` as table in SQL query.") + .booleanConf + .createWithDefault(true) + + val WHOLESTAGE_CODEGEN_ENABLED = SQLConfigBuilder("spark.sql.codegen.wholeStage") + .internal() + .doc("When true, the whole stage (of multiple operators) will be compiled into single java" + + " method.") + .booleanConf + .createWithDefault(true) + + val WHOLESTAGE_MAX_NUM_FIELDS = SQLConfigBuilder("spark.sql.codegen.maxFields") + .internal() + .doc("The maximum number of fields (including nested fields) that will be supported before" + + " deactivating whole-stage codegen.") + .intConf + .createWithDefault(200) + + val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes") + .doc("The maximum number of bytes to pack into a single partition when reading files.") + .longConf + .createWithDefault(128 * 1024 * 1024) // parquet.block.size + + val FILES_OPEN_COST_IN_BYTES = SQLConfigBuilder("spark.sql.files.openCostInBytes") + .internal() + .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" + + " the same time. This is used when putting multiple files into a partition. It's better to" + + " over estimated, then the partitions with small files will be faster than partitions with" + + " bigger files (which is scheduled first).") + .longConf + .createWithDefault(4 * 1024 * 1024) + + val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse") + .internal() + .doc("When true, the planner will try to find out duplicated exchanges and re-use them.") + .booleanConf + .createWithDefault(true) + + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = + SQLConfigBuilder("spark.sql.streaming.stateStore.minDeltasForSnapshot") + .internal() + .doc("Minimum number of state store delta files that needs to be generated before they " + + "consolidated into snapshots.") + .intConf + .createWithDefault(10) + + val STATE_STORE_MIN_VERSIONS_TO_RETAIN = + SQLConfigBuilder("spark.sql.streaming.stateStore.minBatchesToRetain") + .internal() + .doc("Minimum number of versions of a state store's data to retain after cleaning.") + .intConf + .createWithDefault(2) + + val CHECKPOINT_LOCATION = SQLConfigBuilder("spark.sql.streaming.checkpointLocation") + .doc("The default location for storing checkpoint data for continuously executing queries.") + .stringConf + .createOptional + + // TODO: This is still WIP and shouldn't be turned on without extensive test coverage + val COLUMNAR_AGGREGATE_MAP_ENABLED = SQLConfigBuilder("spark.sql.codegen.aggregate.map.enabled") + .internal() + .doc("When true, aggregate with keys use an in-memory columnar map to speed up execution.") + .booleanConf + .createWithDefault(false) + + object Deprecated { + val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2" + val TUNGSTEN_ENABLED = "spark.sql.tungsten.enabled" + val CODEGEN_ENABLED = "spark.sql.codegen" + val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" + val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = "spark.sql.parquet.enableUnsafeRowRecordReader" + } +} + +/** + * A class that enables the setting and getting of mutable config parameters/hints. + * + * In the presence of a SQLContext, these can be set and queried by passing SET commands + * into Spark SQL's query functions (i.e. sql()). Otherwise, users of this class can + * modify the hints by programmatically calling the setters and getters of this class. + * + * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). + */ +private[sql] class SQLConf extends Serializable with CatalystConf with Logging { + import SQLConf._ + + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ + @transient protected[spark] val settings = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, String]()) + + /** ************************ Spark SQL Params/Hints ******************* */ + + def checkpointLocation: String = getConf(CHECKPOINT_LOCATION) + + def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) + + def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES) + + def useCompression: Boolean = getConf(COMPRESS_CACHED) + + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) + + def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) + + def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) + + def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) + + def targetPostShuffleInputSize: Long = + getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) + + def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + + def minNumPostShufflePartitions: Int = + getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) + + def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + + def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) + + def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + + def nativeView: Boolean = getConf(NATIVE_VIEW) + + def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + + def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) + + def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) + + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) + + def subexpressionEliminationEnabled: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_ENABLED) + + def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) + + def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) + + def defaultSizeInBytes: Long = + getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) + + def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) + + def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) + + def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) + + def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) + + def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) + + def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) + + def partitionDiscoveryEnabled(): Boolean = + getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) + + def partitionColumnTypeInferenceEnabled(): Boolean = + getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) + + def parallelPartitionDiscoveryThreshold: Int = + getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) + + def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) + + // Do not use a value larger than 4000 as the default value of this property. + // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. + def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) + + def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) + + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = + getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) + + def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) + + def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + + def columnarAggregateMapEnabled: Boolean = getConf(COLUMNAR_AGGREGATE_MAP_ENABLED) + + override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) + + override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + /** ********************** SQLConf functionality methods ************ */ + + /** Set Spark SQL configuration properties. */ + def setConf(props: Properties): Unit = settings.synchronized { + props.asScala.foreach { case (k, v) => setConfString(k, v) } + } + + /** Set the given Spark SQL configuration property using a `string` value. */ + def setConfString(key: String, value: String): Unit = { + require(key != null, "key cannot be null") + require(value != null, s"value cannot be null for key: $key") + val entry = sqlConfEntries.get(key) + if (entry != null) { + // Only verify configs in the SQLConf object + entry.valueConverter(value) + } + setConfWithCheck(key, value) + } + + /** Set the given Spark SQL configuration property. */ + def setConf[T](entry: ConfigEntry[T], value: T): Unit = { + require(entry != null, "entry cannot be null") + require(value != null, s"value cannot be null for key: ${entry.key}") + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + setConfWithCheck(entry.key, entry.stringConverter(value)) + } + + /** Return the value of Spark SQL configuration property for the given key. */ + @throws[NoSuchElementException]("if key is not set") + def getConfString(key: String): String = { + Option(settings.get(key)). + orElse { + // Try to use the default value + Option(sqlConfEntries.get(key)).map(_.defaultValueString) + }. + getOrElse(throw new NoSuchElementException(key)) + } + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the + * desired one. + */ + def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue) + } + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue` in [[ConfigEntry]]. + */ + def getConf[T](entry: ConfigEntry[T]): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue). + getOrElse(throw new NoSuchElementException(entry.key)) + } + + /** + * Return the value of an optional Spark SQL configuration property for the given key. If the key + * is not set yet, throw an exception. + */ + def getConf[T](entry: OptionalConfigEntry[T]): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.rawValueConverter). + getOrElse(throw new NoSuchElementException(entry.key)) + } + + /** + * Return the `string` value of Spark SQL configuration property for the given key. If the key is + * not set yet, return `defaultValue`. + */ + def getConfString(key: String, defaultValue: String): String = { + val entry = sqlConfEntries.get(key) + if (entry != null && defaultValue != "") { + // Only verify configs in the SQLConf object + entry.valueConverter(defaultValue) + } + Option(settings.get(key)).getOrElse(defaultValue) + } + + /** + * Return all the configuration properties that have been set (i.e. not the default). + * This creates a new copy of the config properties in the form of a Map. + */ + def getAllConfs: immutable.Map[String, String] = + settings.synchronized { settings.asScala.toMap } + + /** + * Return all the configuration definitions that have been defined in [[SQLConf]]. Each + * definition contains key, defaultValue and doc. + */ + def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { + sqlConfEntries.values.asScala.filter(_.isPublic).map { entry => + (entry.key, entry.defaultValueString, entry.doc) + }.toSeq + } + + private def setConfWithCheck(key: String, value: String): Unit = { + if (key.startsWith("spark.") && !key.startsWith("spark.sql.")) { + logWarning(s"Attempt to set non-Spark SQL config in SQLConf: key = $key, value = $value") + } + settings.put(key, value) + } + + def unsetConf(key: String): Unit = { + settings.remove(key) + } + + private[spark] def unsetConf(entry: ConfigEntry[_]): Unit = { + settings.remove(entry.key) + } + + def clear(): Unit = { + settings.clear() + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala new file mode 100644 index 0000000000000..69e3358d4eb9e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, PreInsertCastAndRename, ResolveDataSource} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} +import org.apache.spark.sql.util.ExecutionListenerManager + +/** + * A class that holds all session-specific state in a given [[SQLContext]]. + */ +private[sql] class SessionState(ctx: SQLContext) { + + // Note: These are all lazy vals because they depend on each other (e.g. conf) and we + // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs. + + /** + * SQL-specific key-value configurations. + */ + lazy val conf = new SQLConf + + lazy val experimentalMethods = new ExperimentalMethods + + /** + * Internal catalog for managing functions registered by the user. + */ + lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() + + /** + * Internal catalog for managing table and database states. + */ + lazy val catalog = + new SessionCatalog( + ctx.externalCatalog, + ctx.functionResourceLoader, + functionRegistry, + conf) + + /** + * Interface exposed to the user for registering user-defined functions. + */ + lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry) + + /** + * Logical query plan analyzer for resolving unresolved attributes and relations. + */ + lazy val analyzer: Analyzer = { + new Analyzer(catalog, conf) { + override val extendedResolutionRules = + PreInsertCastAndRename :: + DataSourceAnalysis :: + (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) + + override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog)) + } + } + + /** + * Logical query plan optimizer. + */ + lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods) + + /** + * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + */ + lazy val sqlParser: ParserInterface = SparkSqlParser + + /** + * Planner that converts optimized logical plans to physical plans. + */ + def planner: SparkPlanner = + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) + + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + */ + lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager + + /** + * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. + */ + lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java b/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java new file mode 100644 index 0000000000000..1e801cb6ee2a4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * All classes in this package are considered an internal API to Spark and + * are subject to change between minor releases. + */ +package org.apache.spark.sql.internal; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala new file mode 100644 index 0000000000000..c2394f42e552d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * All classes in this package are considered an internal API to Spark and + * are subject to change between minor releases. + */ +package object internal diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala new file mode 100644 index 0000000000000..467d8d62d1b7f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types.{DataType, MetadataBuilder} + +/** + * AggregatedDialect can unify multiple dialects into one virtual Dialect. + * Dialects are tried in order, and the first dialect that does not return a + * neutral element will will. + * + * @param dialects List of dialects. + */ +private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { + + require(dialects.nonEmpty) + + override def canHandle(url : String): Boolean = + dialects.map(_.canHandle(url)).reduce(_ && _) + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala new file mode 100644 index 0000000000000..f12b6ca9d6ad2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types.{BooleanType, DataType, StringType} + +private object DB2Dialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala new file mode 100644 index 0000000000000..84f68e779c38c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private object DerbyDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.REAL) Option(FloatType) else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case ByteType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case BooleanType => Option(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL + case t: DecimalType if t.precision > 31 => + Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 88ae83957a708..948106fd062a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.Connection -import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types._ /** * :: DeveloperApi :: @@ -53,7 +53,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi -abstract class JdbcDialect { +abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. * @param url the jdbc url. @@ -99,6 +99,15 @@ abstract class JdbcDialect { s"SELECT * FROM $table WHERE 1=0" } + /** + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param properties The connection properties. This is passed through from the relation. + */ + def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + } + } /** @@ -115,11 +124,10 @@ abstract class JdbcDialect { @DeveloperApi object JdbcDialects { - private var dialects = List[JdbcDialect]() - /** * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. - * Readding an existing dialect will cause a move-to-front. + * Reading an existing dialect will cause a move-to-front. + * * @param dialect The new dialect. */ def registerDialect(dialect: JdbcDialect) : Unit = { @@ -128,18 +136,21 @@ object JdbcDialects { /** * Unregister a dialect. Does nothing if the dialect is not registered. + * * @param dialect The jdbc dialect. */ def unregisterDialect(dialect : JdbcDialect) : Unit = { dialects = dialects.filterNot(_ == dialect) } + private[this] var dialects = List[JdbcDialect]() + registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) registerDialect(MsSqlServerDialect) registerDialect(DerbyDialect) - + registerDialect(OracleDialect) /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -155,163 +166,8 @@ object JdbcDialects { } /** - * :: DeveloperApi :: - * AggregatedDialect can unify multiple dialects into one virtual Dialect. - * Dialects are tried in order, and the first dialect that does not return a - * neutral element will will. - * @param dialects List of dialects. - */ -@DeveloperApi -class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { - - require(dialects.nonEmpty) - - override def canHandle(url : String): Boolean = - dialects.map(_.canHandle(url)).reduce(_ && _) - - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = { - dialects.flatMap(_.getJDBCType(dt)).headOption - } -} - -/** - * :: DeveloperApi :: * NOOP dialect object, always returning the neutral element. */ -@DeveloperApi -case object NoopDialect extends JdbcDialect { +private object NoopDialect extends JdbcDialect { override def canHandle(url : String): Boolean = true } - -/** - * :: DeveloperApi :: - * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write. - */ -@DeveloperApi -case object PostgresDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - Option(BinaryType) - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("json")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) { - Option(StringType) - } else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) - case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) - case _ => None - } - - override def getTableExistsQuery(table: String): String = { - s"SELECT 1 FROM $table LIMIT 1" - } - -} - -/** - * :: DeveloperApi :: - * Default mysql dialect to read bit/bitsets correctly. - */ -@DeveloperApi -case object MySQLDialect extends JdbcDialect { - override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { - // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as - // byte arrays instead of longs. - md.putLong("binarylong", 1) - Option(LongType) - } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - Option(BooleanType) - } else None - } - - override def quoteIdentifier(colName: String): String = { - s"`$colName`" - } - - override def getTableExistsQuery(table: String): String = { - s"SELECT 1 FROM $table LIMIT 1" - } -} - -/** - * :: DeveloperApi :: - * Default DB2 dialect, mapping string/boolean on write to valid DB2 types. - * By default string, and boolean gets mapped to db2 invalid types TEXT, and BIT(1). - */ -@DeveloperApi -case object DB2Dialect extends JdbcDialect { - - override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) - case BooleanType => Some(JdbcType("CHAR(1)", java.sql.Types.CHAR)) - case _ => None - } -} - -/** - * :: DeveloperApi :: - * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. - */ -@DeveloperApi -case object MsSqlServerDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (typeName.contains("datetimeoffset")) { - // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients - Option(StringType) - } else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) - case _ => None - } -} - -/** - * :: DeveloperApi :: - * Default Apache Derby dialect, mapping real on read - * and string/byte/short/boolean/decimal on write. - */ -@DeveloperApi -case object DerbyDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.REAL) Option(FloatType) else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) - case ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) - case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) - // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL - case (t: DecimalType) if (t.precision > 31) => - Some(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) - case _ => None - } - -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala new file mode 100644 index 0000000000000..3eb722b070d5d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types._ + + +private object MsSqlServerDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Option(StringType) + } else { + None + } + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala new file mode 100644 index 0000000000000..e1717049f383d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuilder} + +private case object MySQLDialect extends JdbcDialect { + + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. + md.putLong("binarylong", 1) + Option(LongType) + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + Option(BooleanType) + } else None + } + + override def quoteIdentifier(colName: String): String = { + s"`$colName`" + } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala new file mode 100644 index 0000000000000..46b3877a7cab3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private case object OracleDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + if (sqlType == Types.NUMERIC && size == 0) { + // This is sub-optimal as we have to pick a precision/scale in advance whereas the data + // in Oracle is allowed to have different precision/scale for each value. + Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + } else { + None + } + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala new file mode 100644 index 0000000000000..2d6c3974a833e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.{Connection, Types} + +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.types._ + + +private object PostgresDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + Some(BinaryType) + } else if (sqlType == Types.OTHER) { + Some(StringType) + } else if (sqlType == Types.ARRAY) { + val scale = md.build.getLong("scale").toInt + // postgres array type names start with underscore + toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_)) + } else None + } + + private def toCatalystType( + typeName: String, + precision: Int, + scale: Int): Option[DataType] = typeName match { + case "bool" => Some(BooleanType) + case "bit" => Some(BinaryType) + case "int2" => Some(ShortType) + case "int4" => Some(IntegerType) + case "int8" | "oid" => Some(LongType) + case "float4" => Some(FloatType) + case "money" | "float8" => Some(DoubleType) + case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" => + Some(StringType) + case "bytea" => Some(BinaryType) + case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType) + case "date" => Some(DateType) + case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale)) + case _ => None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("TEXT", Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) + case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT)) + case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) + case t: DecimalType => Some( + JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC)) + case ArrayType(et, _) if et.isInstanceOf[AtomicType] => + getJDBCType(et).map(_.databaseTypeDefinition) + .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) + .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) + case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt"); + case _ => None + } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + super.beforeFetch(connection, properties) + + // According to the postgres jdbc documentation we need to be in autocommit=false if we actually + // want to have fetchsize be non 0 (all the rows). This allows us to not have to cache all the + // rows inside the driver when fetching. + // + // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor + // + if (properties.getOrElse("fetchsize", "0").toInt > 0) { + connection.setAutoCommit(false) + } + + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index a9c600b139b18..97e35bb10407e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -42,10 +42,5 @@ package object sql { @DeveloperApi type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] - /** - * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. - * @deprecated As of 1.3.0, replaced by `DataFrame`. - */ - @deprecated("use DataFrame", "1.3.0") - type SchemaRDD = DataFrame + type DataFrame = Dataset[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 3780cbbcc9631..9130e77ea5724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -82,7 +82,24 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter * * @since 1.3.0 */ -case class In(attribute: String, values: Array[Any]) extends Filter +case class In(attribute: String, values: Array[Any]) extends Filter { + override def hashCode(): Int = { + var h = attribute.hashCode + values.foreach { v => + h *= 41 + h += v.hashCode() + } + h + } + override def equals(o: Any): Boolean = o match { + case In(a, vs) => + a == attribute && vs.length == values.length && vs.zip(values).forall(x => x._1 == x._2) + case _ => false + } + override def toString: String = { + s"In($attribute, [${values.mkString(",")}]" + } +} /** * A filter that evaluates to `true` iff the attribute evaluates to null. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index e296d631f0f30..4b9bf8daae37c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -22,20 +22,23 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} +import org.apache.spark.sql.execution.FileRelation +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection.BitSet /** * ::DeveloperApi:: @@ -55,7 +58,7 @@ trait DataSourceRegister { * overridden by children to provide a nice alias for the data source. For example: * * {{{ - * override def format(): String = "parquet" + * override def shortName(): String = "parquet" * }}} * * @since 1.5.0 @@ -123,43 +126,33 @@ trait SchemaRelationProvider { } /** - * ::Experimental:: - * Implemented by objects that produce relations for a specific kind of data source - * with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a - * USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined - * schema, and an optional list of partition columns, this interface is used to pass in the - * parameters specified by a user. - * - * Users may specify the fully qualified class name of a given data source. When that class is - * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for - * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the - * data source 'org.apache.spark.sql.json.DefaultSource' - * - * A new instance of this class will be instantiated each time a DDL call is made. - * - * The difference between a [[RelationProvider]] and a [[HadoopFsRelationProvider]] is - * that users need to provide a schema and a (possibly empty) list of partition columns when - * using a [[HadoopFsRelationProvider]]. A relation provider can inherits both [[RelationProvider]], - * and [[HadoopFsRelationProvider]] if it can support schema inference, user-specified - * schemas, and accessing partitioned relations. - * - * @since 1.4.0 + * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. */ -@Experimental -trait HadoopFsRelationProvider { - /** - * Returns a new base relation with the given parameters, a user defined schema, and a list of - * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity - * is enforced by the Map that is passed to the function. - * - * @param dataSchema Schema of data columns (i.e., columns that are not partition columns). - */ - def createRelation( +trait StreamSourceProvider { + + /** Returns the name and schema of the source that can be used to continually read data. */ + def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) + + def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source +} + +/** + * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. + */ +trait StreamSinkProvider { + def createSink( sqlContext: SQLContext, - paths: Array[String], - dataSchema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink } /** @@ -168,19 +161,19 @@ trait HadoopFsRelationProvider { @DeveloperApi trait CreatableRelationProvider { /** - * Creates a relation with the given parameters based on the contents of the given - * DataFrame. The mode specifies the expected behavior of createRelation when - * data already exists. - * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. - * Append mode means that when saving a DataFrame to a data source, if data already exists, - * contents of the DataFrame are expected to be appended to existing data. - * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, - * existing data is expected to be overwritten by the contents of the DataFrame. - * ErrorIfExists mode means that when saving a DataFrame to a data source, - * if data already exists, an exception is expected to be thrown. - * - * @since 1.3.0 - */ + * Creates a relation with the given parameters based on the contents of the given + * DataFrame. The mode specifies the expected behavior of createRelation when + * data already exists. + * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. + * Append mode means that when saving a DataFrame to a data source, if data already exists, + * contents of the DataFrame are expected to be appended to existing data. + * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, + * existing data is expected to be overwritten by the contents of the DataFrame. + * ErrorIfExists mode means that when saving a DataFrame to a data source, + * if data already exists, an exception is expected to be thrown. + * + * @since 1.3.0 + */ def createRelation( sqlContext: SQLContext, mode: SaveMode, @@ -235,9 +228,11 @@ abstract class BaseRelation { def needConversion: Boolean = true /** - * Given an array of [[Filter]]s, returns an array of [[Filter]]s that this data source relation - * cannot handle. Spark SQL will apply all returned [[Filter]]s against rows returned by this - * data source relation. + * Returns the list of [[Filter]]s that this datasource may not be able to handle. + * These returned [[Filter]]s will be evaluated by Spark SQL after data is output by a scan. + * By default, this function will return all filters, as it is always safe to + * double evaluate a [[Filter]]. However, specific implementations can override this function to + * avoid double filtering when they are capable of processing a filter internally. * * @since 1.6.0 */ @@ -345,10 +340,13 @@ abstract class OutputWriterFactory extends Serializable { * @param dataSchema Schema of the rows to be written. Partition columns are not included in the * schema if the relation being written is partitioned. * @param context The Hadoop MapReduce task context. - * * @since 1.4.0 */ - def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter + private[sql] def newInstance( + path: String, + bucketId: Option[Int], // TODO: This doesn't belong here... + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter } /** @@ -391,181 +389,286 @@ abstract class OutputWriter { } /** - * ::Experimental:: - * A [[BaseRelation]] that provides much of the common code required for relations that store their - * data to an HDFS compatible filesystem. - * - * For the read path, similar to [[PrunedFilteredScan]], it can eliminate unneeded columns and - * filter using selected predicates before producing an RDD containing all matching tuples as - * [[Row]] objects. In addition, when reading from Hive style partitioned tables stored in file - * systems, it's able to discover partitioning information from the paths of input directories, and - * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]] - * must override one of the three `buildScan` methods to implement the read path. - * - * For the write path, it provides the ability to write to both non-partitioned and partitioned - * tables. Directory layout of the partitioned tables is compatible with Hive. - * - * @constructor This constructor is for internal uses only. The [[PartitionSpec]] argument is for - * implementing metastore table conversion. - * - * @param maybePartitionSpec An [[HadoopFsRelation]] can be created with an optional - * [[PartitionSpec]], so that partition discovery can be skipped. - * - * @since 1.4.0 + * Acts as a container for all of the metadata required to read from a datasource. All discovery, + * resolution and merging logic for schemas and partitions has been removed. + * + * @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise + * this relation. + * @param partitionSchema The schema of the columns (if any) that are used to partition the relation + * @param dataSchema The schema of any remaining columns. Note that if any partition columns are + * present in the actual data files as well, they are preserved. + * @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values). + * @param fileFormat A file format that can be used to read and write the data in files. + * @param options Configuration used when reading / writing data. */ -@Experimental -abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) - extends BaseRelation with FileRelation with Logging { +case class HadoopFsRelation( + sqlContext: SQLContext, + location: FileCatalog, + partitionSchema: StructType, + dataSchema: StructType, + bucketSpec: Option[BucketSpec], + fileFormat: FileFormat, + options: Map[String, String]) extends BaseRelation with FileRelation { + + val schema: StructType = { + val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet + StructType(dataSchema ++ partitionSchema.filterNot { column => + dataSchemaColumnNames.contains(column.name.toLowerCase) + }) + } - override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") + def partitionSchemaOption: Option[StructType] = + if (partitionSchema.isEmpty) None else Some(partitionSchema) + def partitionSpec: PartitionSpec = location.partitionSpec() - def this() = this(None) + def refresh(): Unit = location.refresh() - private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + override def toString: String = + s"HadoopFiles" - private val codegenEnabled = sqlContext.conf.codegenEnabled + /** Returns the list of files that will be read when scanning this relation. */ + override def inputFiles: Array[String] = + location.allFiles().map(_.getPath.toUri.toString).toArray - private var _partitionSpec: PartitionSpec = _ + override def sizeInBytes: Long = location.allFiles().map(_.getLen).sum +} - private class FileStatusCache { - var leafFiles = mutable.Map.empty[Path, FileStatus] +/** + * Used to read and write data stored in files to/from the [[InternalRow]] format. + */ +trait FileFormat { + /** + * When possible, this method should return the schema of the given `files`. When the format + * does not support inference, or no valid files are given should return None. In these cases + * Spark will require that user specify the schema manually. + */ + def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] - var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] + /** + * Prepares a read job and returns a potentially updated data source option [[Map]]. This method + * can be useful for collecting necessary global information for scanning input data. + */ + def prepareRead( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Map[String, String] = options - private def listLeafFiles(paths: Array[String]): Set[FileStatus] = { - if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { - HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) - } else { - val statuses = paths.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) - }.filterNot { status => - val name = status.getPath.getName - name.toLowerCase == "_temporary" || name.startsWith(".") - } + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + */ + def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory - val (dirs, files) = statuses.partition(_.isDir) + /** + * Returns whether this format support returning columnar batch or not. + * + * TODO: we should just have different traits for the different formats. + */ + def supportBatch(sqlContext: SQLContext, dataSchema: StructType): Boolean = { + false + } - if (dirs.isEmpty) { - files.toSet - } else { - files.toSet ++ listLeafFiles(dirs.map(_.getPath.toString)) - } - } - } + /** + * Returns a function that can be used to read a single file in as an Iterator of InternalRow. + * + * @param dataSchema The global data schema. It can be either specified by the user, or + * reconciled/merged from all underlying data files. If any partition columns + * are contained in the files, they are preserved in this schema. + * @param partitionSchema The schema of the partition column row that will be present in each + * PartitionedFile. These columns should be appended to the rows that + * are produced by the iterator. + * @param requiredSchema The schema of the data that should be output for each row. This may be a + * subset of the columns that are present in the file if column pruning has + * occurred. + * @param filters A set of filters than can optionally be used to reduce the number of rows output + * @param options A set of string -> string configuration options. + * @return + */ + def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + // TODO: Remove this default implementation when the other formats have been ported + // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. + throw new UnsupportedOperationException(s"buildReader is not supported for $this") + } +} - def refresh(): Unit = { - val files = listLeafFiles(paths) +/** + * A collection of data files from a partitioned relation, along with the partition values in the + * form of an [[InternalRow]]. + */ +case class Partition(values: InternalRow, files: Seq[FileStatus]) - leafFiles.clear() - leafDirToChildrenFiles.clear() +/** + * An interface for objects capable of enumerating the files that comprise a relation as well + * as the partitioning characteristics of those files. + */ +trait FileCatalog { + def paths: Seq[Path] + + def partitionSpec(): PartitionSpec + + /** + * Returns all valid files grouped into partitions when the data is partitioned. If the data is + * unpartitioned, this will return a single partition with not partition values. + * + * @param filters the filters used to prune which partitions are returned. These filters must + * only refer to partition columns and this method will only return files + * where these predicates are guaranteed to evaluate to `true`. Thus, these + * filters will not need to be evaluated again on the returned data. + */ + def listFiles(filters: Seq[Expression]): Seq[Partition] - leafFiles ++= files.map(f => f.getPath -> f).toMap - leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) + def allFiles(): Seq[FileStatus] + + def getStatus(path: Path): Array[FileStatus] + + def refresh(): Unit +} + +/** + * A file catalog that caches metadata gathered by scanning all the files present in `paths` + * recursively. + * + * @param parameters as set of options to control discovery + * @param paths a list of paths to scan + * @param partitionSchema an optional partition schema that will be use to provide types for the + * discovered partitions + */ +class HDFSFileCatalog( + val sqlContext: SQLContext, + val parameters: Map[String, String], + val paths: Seq[Path], + val partitionSchema: Option[StructType]) + extends FileCatalog with Logging { + + private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + + var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] + var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] + var cachedPartitionSpec: PartitionSpec = _ + + def partitionSpec(): PartitionSpec = { + if (cachedPartitionSpec == null) { + cachedPartitionSpec = inferPartitioning(partitionSchema) } - } - private lazy val fileStatusCache = { - val cache = new FileStatusCache - cache.refresh() - cache + cachedPartitionSpec } - protected def cachedLeafStatuses(): Set[FileStatus] = { - fileStatusCache.leafFiles.values.toSet - } + refresh() - final private[sql] def partitionSpec: PartitionSpec = { - if (_partitionSpec == null) { - _partitionSpec = maybePartitionSpec - .flatMap { - case spec if spec.partitions.nonEmpty => - Some(spec.copy(partitionColumns = spec.partitionColumns.asNullable)) - case _ => - None - } - .orElse { - // We only know the partition columns and their data types. We need to discover - // partition values. - userDefinedPartitionColumns.map { partitionSchema => - val spec = discoverPartitions() - val partitionColumnTypes = spec.partitionColumns.map(_.dataType) - val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => - Literal.create(values.get(i, dt), dt) - } - val castedValues = partitionSchema.zip(literals).map { case (field, literal) => - Cast(literal, field.dataType).eval() - } - p.copy(values = InternalRow.fromSeq(castedValues)) - } - PartitionSpec(partitionSchema, castedPartitions) - } - } - .getOrElse { - if (sqlContext.conf.partitionDiscoveryEnabled()) { - discoverPartitions() - } else { - PartitionSpec(StructType(Nil), Array.empty[Partition]) - } - } + override def listFiles(filters: Seq[Expression]): Seq[Partition] = { + if (partitionSpec().partitionColumns.isEmpty) { + Partition(InternalRow.empty, allFiles().filterNot(_.getPath.getName startsWith "_")) :: Nil + } else { + prunePartitions(filters, partitionSpec()).map { + case PartitionDirectory(values, path) => + Partition( + values, + getStatus(path).filterNot(_.getPath.getName startsWith "_")) + } } - _partitionSpec } - /** - * Base paths of this relation. For partitioned relations, it should be either root directories - * of all partition directories. - * - * @since 1.4.0 - */ - def paths: Array[String] + protected def prunePartitions( + predicates: Seq[Expression], + partitionSpec: PartitionSpec): Seq[PartitionDirectory] = { + val PartitionSpec(partitionColumns, partitions) = partitionSpec + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } - override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + if (partitionPruningPredicates.nonEmpty) { + val predicate = partitionPruningPredicates.reduce(expressions.And) - override def sizeInBytes: Long = cachedLeafStatuses().map(_.getLen).sum + val boundPredicate = InterpretedPredicate.create(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) - /** - * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically - * discovered. Note that they should always be nullable. - * - * @since 1.4.0 - */ - final def partitionColumns: StructType = - userDefinedPartitionColumns.getOrElse(partitionSpec.partitionColumns) + val selected = partitions.filter { + case PartitionDirectory(values, _) => boundPredicate(values) + } + logInfo { + val total = partitions.length + val selectedSize = selected.length + val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 + s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." + } - /** - * Optional user defined partition columns. - * - * @since 1.4.0 - */ - def userDefinedPartitionColumns: Option[StructType] = None + selected + } else { + partitions + } + } + + def allFiles(): Seq[FileStatus] = leafFiles.values.toSeq + + def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path) + + private def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { + if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { + HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) + } else { + val statuses = paths.flatMap { path => + val fs = path.getFileSystem(hadoopConf) + logInfo(s"Listing $path on driver") + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(path, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(path)).getOrElse(Array.empty) + } + }.filterNot { status => + val name = status.getPath.getName + HadoopFsRelation.shouldFilterOut(name) + } - private[sql] def refresh(): Unit = { - fileStatusCache.refresh() - if (sqlContext.conf.partitionDiscoveryEnabled()) { - _partitionSpec = discoverPartitions() + val (dirs, files) = statuses.partition(_.isDirectory) + + // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) + if (dirs.isEmpty) { + mutable.LinkedHashSet(files: _*) + } else { + mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath)) + } } } - private def discoverPartitions(): PartitionSpec = { + def inferPartitioning(schema: Option[StructType]): PartitionSpec = { // We use leaf dirs containing data files to discover the schema. - val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq - userDefinedPartitionColumns match { + val leafDirs = leafDirToChildrenFiles.keys.toSeq + schema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( - leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = false) + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = false, + basePaths = basePaths) // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => Cast( - Literal.create(row.getString(i), StringType), + Literal.create(row.getUTF8String(i), StringType), userProvidedSchema.fields(i).dataType).eval() }: _*) } @@ -573,235 +676,72 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio PartitionSpec(userProvidedSchema, spec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) - case _ => - // user did not provide a partitioning schema - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled()) + PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled(), + basePaths = basePaths) } } /** - * Schema of this relation. It consists of columns appearing in [[dataSchema]] and all partition - * columns not appearing in [[dataSchema]]. - * - * @since 1.4.0 + * Contains a set of paths that are considered as the base dirs of the input datasets. + * The partitioning discovery logic will make sure it will stop when it reaches any + * base path. By default, the paths of the dataset provided by users will be base paths. + * For example, if a user uses `sqlContext.read.parquet("/path/something=true/")`, the base path + * will be `/path/something=true/`, and the returned DataFrame will not contain a column of + * `something`. If users want to override the basePath. They can set `basePath` in the options + * to pass the new base path to the data source. + * For the above example, if the user-provided base path is `/path/`, the returned + * DataFrame will have the column of `something`. */ - override lazy val schema: StructType = { - val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionColumns.filterNot { column => - dataSchemaColumnNames.contains(column.name.toLowerCase) - }) - } - - final private[sql] def buildInternalScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val inputStatuses = inputPaths.flatMap { input => - val path = new Path(input) - - // First assumes `input` is a directory path, and tries to get all files contained in it. - fileStatusCache.leafDirToChildrenFiles.getOrElse( - path, - // Otherwise, `input` might be a file path - fileStatusCache.leafFiles.get(path).toArray - ).filter { status => - val name = status.getPath.getName - !name.startsWith("_") && !name.startsWith(".") - } + private def basePaths: Set[Path] = { + val userDefinedBasePath = parameters.get("basePath").map(basePath => Set(new Path(basePath))) + userDefinedBasePath.getOrElse { + // If the user does not provide basePath, we will just use paths. + paths.toSet + }.map { hdfsPath => + // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). + val fs = hdfsPath.getFileSystem(hadoopConf) + hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) } - - buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf) - } - - /** - * Specifies schema of actual data files. For partitioned relations, if one or more partitioned - * columns are contained in the data files, they should also appear in `dataSchema`. - * - * @since 1.4.0 - */ - def dataSchema: StructType - - /** - * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within - * this relation. For partitioned relations, this method is called for each selected partition, - * and builds an `RDD[Row]` containing all rows within that single partition. - * - * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the - * relation. For a partitioned relation, it contains paths of all data files in a single - * selected partition. - * - * @since 1.4.0 - */ - def buildScan(inputFiles: Array[FileStatus]): RDD[Row] = { - throw new UnsupportedOperationException( - "At least one buildScan() method should be overridden to read the relation.") } - /** - * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within - * this relation. For partitioned relations, this method is called for each selected partition, - * and builds an `RDD[Row]` containing all rows within that single partition. - * - * @param requiredColumns Required columns. - * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the - * relation. For a partitioned relation, it contains paths of all data files in a single - * selected partition. - * - * @since 1.4.0 - */ - // TODO Tries to eliminate the extra Catalyst-to-Scala conversion when `needConversion` is true - // - // PR #7626 separated `Row` and `InternalRow` completely. One of the consequences is that we can - // no longer treat an `InternalRow` containing Catalyst values as a `Row`. Thus we have to - // introduce another row value conversion for data sources whose `needConversion` is true. - def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { - // Yeah, to workaround serialization... - val dataSchema = this.dataSchema - val codegenEnabled = this.codegenEnabled - val needConversion = this.needConversion - - val requiredOutput = requiredColumns.map { col => - val field = dataSchema(col) - BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) - }.toSeq - - val rdd: RDD[Row] = buildScan(inputFiles) - val converted: RDD[InternalRow] = - if (needConversion) { - RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) - } else { - rdd.asInstanceOf[RDD[InternalRow]] - } + def refresh(): Unit = { + val files = listLeafFiles(paths) - converted.mapPartitions { rows => - val buildProjection = if (codegenEnabled) { - GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) - } else { - () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) - } + leafFiles.clear() + leafDirToChildrenFiles.clear() - val projectedRows = { - val mutableProjection = buildProjection() - rows.map(r => mutableProjection(r)) - } + leafFiles ++= files.map(f => f.getPath -> f) + leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) - if (needConversion) { - val requiredSchema = StructType(requiredColumns.map(dataSchema(_))) - val toScala = CatalystTypeConverters.createToScalaConverter(requiredSchema) - projectedRows.map(toScala(_).asInstanceOf[Row]) - } else { - projectedRows - } - }.asInstanceOf[RDD[Row]] + cachedPartitionSpec = null } - /** - * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within - * this relation. For partitioned relations, this method is called for each selected partition, - * and builds an `RDD[Row]` containing all rows within that single partition. - * - * @param requiredColumns Required columns. - * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction - * of all `filters`. The pushed down filters are currently purely an optimization as they - * will all be evaluated again. This means it is safe to use them with methods that produce - * false positives such as filtering partitions based on a bloom filter. - * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the - * relation. For a partitioned relation, it contains paths of all data files in a single - * selected partition. - * - * @since 1.4.0 - */ - def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputFiles: Array[FileStatus]): RDD[Row] = { - buildScan(requiredColumns, inputFiles) + override def equals(other: Any): Boolean = other match { + case hdfs: HDFSFileCatalog => paths.toSet == hdfs.paths.toSet + case _ => false } - /** - * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within - * this relation. For partitioned relations, this method is called for each selected partition, - * and builds an `RDD[Row]` containing all rows within that single partition. - * - * Note: This interface is subject to change in future. - * - * @param requiredColumns Required columns. - * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction - * of all `filters`. The pushed down filters are currently purely an optimization as they - * will all be evaluated again. This means it is safe to use them with methods that produce - * false positives such as filtering partitions based on a bloom filter. - * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the - * relation. For a partitioned relation, it contains paths of all data files in a single - * selected partition. - * @param broadcastedConf A shared broadcast Hadoop Configuration, which can be used to reduce the - * overhead of broadcasting the Configuration for every Hadoop RDD. - * - * @since 1.4.0 - */ - private[sql] def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { - buildScan(requiredColumns, filters, inputFiles) - } + override def hashCode(): Int = paths.toSet.hashCode() +} - /** - * For a non-partitioned relation, this method builds an `RDD[InternalRow]` containing all rows - * within this relation. For partitioned relations, this method is called for each selected - * partition, and builds an `RDD[InternalRow]` containing all rows within that single partition. - * - * Note: - * - * 1. Rows contained in the returned `RDD[InternalRow]` are assumed to be `UnsafeRow`s. - * 2. This interface is subject to change in future. - * - * @param requiredColumns Required columns. - * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction - * of all `filters`. The pushed down filters are currently purely an optimization as they - * will all be evaluated again. This means it is safe to use them with methods that produce - * false positives such as filtering partitions based on a bloom filter. - * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the - * relation. For a partitioned relation, it contains paths of all data files in a single - * selected partition. - * @param broadcastedConf A shared broadcast Hadoop Configuration, which can be used to reduce the - * overhead of broadcasting the Configuration for every Hadoop RDD. - */ - private[sql] def buildInternalScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val requiredSchema = StructType(requiredColumns.map(dataSchema.apply)) - val internalRows = { - val externalRows = buildScan(requiredColumns, filters, inputFiles, broadcastedConf) - execution.RDDConversions.rowToRowRdd(externalRows, requiredSchema.map(_.dataType)) - } +/** + * Helper methods for gathering metadata from HDFS. + */ +private[sql] object HadoopFsRelation extends Logging { - internalRows.mapPartitions { iterator => - val unsafeProjection = UnsafeProjection.create(requiredSchema) - iterator.map(unsafeProjection) - } + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // TODO: We should try to filter out all files/dirs starting with "." or "_". + // The only reason that we are not doing it now is that Parquet needs to find those + // metadata files from leaf files returned by this methods. We should refactor + // this logic to not mix metadata files with data files. + pathName == "_SUCCESS" || pathName == "_temporary" || pathName.startsWith(".") } - /** - * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can - * be put here. For example, user defined output committer can be configured here - * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. - * - * Note that the only side effect expected here is mutating `job` via its setters. Especially, - * Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states - * may cause unexpected behaviors. - * - * @since 1.4.0 - */ - def prepareJobForWrite(job: Job): OutputWriterFactory -} - -private[sql] object HadoopFsRelation extends Logging { // We don't filter files/directories whose name start with "_" except "_temporary" here, as // specific data sources may take advantages over them (e.g. Parquet _metadata and // _common_metadata files). "_temporary" directories are explicitly ignored since failed @@ -810,11 +750,21 @@ private[sql] object HadoopFsRelation extends Logging { def listLeafFiles(fs: FileSystem, status: FileStatus): Array[FileStatus] = { logInfo(s"Listing ${status.getPath}") val name = status.getPath.getName.toLowerCase - if (name == "_temporary" || name.startsWith(".")) { + if (shouldFilterOut(name)) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + val statuses = + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDirectory) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } + statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) } } @@ -832,31 +782,32 @@ private[sql] object HadoopFsRelation extends Logging { accessTime: Long) def listLeafFilesInParallel( - paths: Array[String], + paths: Seq[Path], hadoopConf: Configuration, - sparkContext: SparkContext): Set[FileStatus] = { + sparkContext: SparkContext): mutable.LinkedHashSet[FileStatus] = { logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") val serializableConfiguration = new SerializableConfiguration(hadoopConf) - val fakeStatuses = sparkContext.parallelize(paths).flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(serializableConfiguration.value) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - Try(listLeafFiles(fs, fs.getFileStatus(qualified))).getOrElse(Array.empty) + val serializedPaths = paths.map(_.toString) + + val fakeStatuses = sparkContext.parallelize(serializedPaths).map(new Path(_)).flatMap { path => + val fs = path.getFileSystem(serializableConfiguration.value) + Try(listLeafFiles(fs, fs.getFileStatus(path))).getOrElse(Array.empty) }.map { status => FakeFileStatus( status.getPath.toString, status.getLen, - status.isDir, + status.isDirectory, status.getReplication, status.getBlockSize, status.getModificationTime, status.getAccessTime) }.collect() - fakeStatuses.map { f => + val hadoopFakeStatuses = fakeStatuses.map { f => new FileStatus( f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)) - }.toSet + } + mutable.LinkedHashSet(hadoopFakeStatuses: _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 8d4854b698ed7..695a5ad78adc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ /** @@ -26,7 +26,12 @@ import org.apache.spark.sql.types._ * @param y y coordinate */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) -private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable +private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def equals(other: Any): Boolean = other match { + case that: ExamplePoint => this.x == that.x && this.y == that.y + case _ => false + } +} /** * User-defined type for [[ExamplePoint]]. @@ -37,14 +42,11 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" - override def serialize(obj: Any): GenericArrayData = { - obj match { - case p: ExamplePoint => - val output = new Array[Any](2) - output(0) = p.x - output(1) = p.y - new GenericArrayData(output) - } + override def serialize(p: ExamplePoint): GenericArrayData = { + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) } override def deserialize(datum: Any): ExamplePoint = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala new file mode 100644 index 0000000000000..ba1facf11b7d5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.ContinuousQuery +import org.apache.spark.sql.util.ContinuousQueryListener._ + +/** + * :: Experimental :: + * Interface for listening to events related to [[ContinuousQuery ContinuousQueries]]. + * @note The methods are not thread-safe as they may be called from different threads. + */ +@Experimental +abstract class ContinuousQueryListener { + + /** + * Called when a query is started. + * @note This is called synchronously with + * [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.startStream()`]], + * that is, `onQueryStart` will be called on all listeners before + * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. Please + * don't block this method as it will block your query. + */ + def onQueryStarted(queryStarted: QueryStarted): Unit + + /** + * Called when there is some status update (ingestion rate updated, etc.) + * + * @note This method is asynchronous. The status in [[ContinuousQuery]] will always be + * latest no matter when this method is called. Therefore, the status of [[ContinuousQuery]] + * may be changed before/when you process the event. E.g., you may find [[ContinuousQuery]] + * is terminated when you are processing [[QueryProgress]]. + */ + def onQueryProgress(queryProgress: QueryProgress): Unit + + /** Called when a query is stopped, with or without error */ + def onQueryTerminated(queryTerminated: QueryTerminated): Unit +} + + +/** + * :: Experimental :: + * Companion object of [[ContinuousQueryListener]] that defines the listener events. + */ +@Experimental +object ContinuousQueryListener { + + /** Base type of [[ContinuousQueryListener]] events */ + trait Event + + /** Event representing the start of a query */ + class QueryStarted private[sql](val query: ContinuousQuery) extends Event + + /** Event representing any progress updates in a query */ + class QueryProgress private[sql](val query: ContinuousQuery) extends Event + + /** Event representing that termination of a query */ + class QueryTerminated private[sql](val query: ContinuousQuery) extends Event +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 909a8abd225b8..3cae5355eecc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -18,37 +18,39 @@ package org.apache.spark.sql.util import java.util.concurrent.locks.ReentrantReadWriteLock + import scala.collection.mutable.ListBuffer +import scala.util.control.NonFatal import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution - /** + * :: Experimental :: * The interface of query execution listener that can be used to analyze execution metrics. * - * Note that implementations should guarantee thread-safety as they will be used in a non - * thread-safe way. + * Note that implementations should guarantee thread-safety as they can be invoked by + * multiple different threads. */ @Experimental trait QueryExecutionListener { /** * A callback function that will be called when a query executed successfully. - * Implementations should guarantee thread-safe. + * Note that this can be invoked by multiple different threads. * - * @param funcName the name of the action that triggered this query. + * @param funcName name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. - * @param duration the execution time for this query in nanoseconds. + * @param durationNs the execution time for this query in nanoseconds. */ @DeveloperApi - def onSuccess(funcName: String, qe: QueryExecution, duration: Long) + def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit /** * A callback function that will be called when a query execution failed. - * Implementations should guarantee thread-safe. + * Note that this can be invoked by multiple different threads. * * @param funcName the name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, @@ -56,34 +58,20 @@ trait QueryExecutionListener { * @param exception the exception that failed this query. */ @DeveloperApi - def onFailure(funcName: String, qe: QueryExecution, exception: Exception) + def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit } -@Experimental -class ExecutionListenerManager extends Logging { - private[this] val listeners = ListBuffer.empty[QueryExecutionListener] - private[this] val lock = new ReentrantReadWriteLock() - - /** Acquires a read lock on the cache for the duration of `f`. */ - private def readLock[A](f: => A): A = { - val rl = lock.readLock() - rl.lock() - try f finally { - rl.unlock() - } - } - /** Acquires a write lock on the cache for the duration of `f`. */ - private def writeLock[A](f: => A): A = { - val wl = lock.writeLock() - wl.lock() - try f finally { - wl.unlock() - } - } +/** + * :: Experimental :: + * + * Manager for [[QueryExecutionListener]]. See [[org.apache.spark.sql.SQLContext.listenerManager]]. + */ +@Experimental +class ExecutionListenerManager private[sql] () extends Logging { /** - * Registers the specified QueryExecutionListener. + * Registers the specified [[QueryExecutionListener]]. */ @DeveloperApi def register(listener: QueryExecutionListener): Unit = writeLock { @@ -91,7 +79,7 @@ class ExecutionListenerManager extends Logging { } /** - * Unregisters the specified QueryExecutionListener. + * Unregisters the specified [[QueryExecutionListener]]. */ @DeveloperApi def unregister(listener: QueryExecutionListener): Unit = writeLock { @@ -99,38 +87,59 @@ class ExecutionListenerManager extends Logging { } /** - * clears out all registered QueryExecutionListeners. + * Removes all the registered [[QueryExecutionListener]]. */ @DeveloperApi def clear(): Unit = writeLock { listeners.clear() } - private[sql] def onSuccess( - funcName: String, - qe: QueryExecution, - duration: Long): Unit = readLock { - withErrorHandling { listener => - listener.onSuccess(funcName, qe, duration) + private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + readLock { + withErrorHandling { listener => + listener.onSuccess(funcName, qe, duration) + } } } - private[sql] def onFailure( - funcName: String, - qe: QueryExecution, - exception: Exception): Unit = readLock { - withErrorHandling { listener => - listener.onFailure(funcName, qe, exception) + private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + readLock { + withErrorHandling { listener => + listener.onFailure(funcName, qe, exception) + } } } + private[this] val listeners = ListBuffer.empty[QueryExecutionListener] + + /** A lock to prevent updating the list of listeners while we are traversing through them. */ + private[this] val lock = new ReentrantReadWriteLock() + private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = { for (listener <- listeners) { try { f(listener) } catch { - case e: Exception => logWarning("error executing query execution listener", e) + case NonFatal(e) => logWarning("Error executing query execution listener", e) } } } + + /** Acquires a read lock on the cache for the duration of `f`. */ + private def readLock[A](f: => A): A = { + val rl = lock.readLock() + rl.lock() + try f finally { + rl.unlock() + } + } + + /** Acquires a write lock on the cache for the duration of `f`. */ + private def writeLock[A](f: => A): A = { + val wl = lock.writeLock() + wl.lock() + try f finally { + wl.unlock() + } + } } diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java index ee327827903e5..8de0b06b162c6 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java @@ -129,6 +129,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfA } @Override + @SuppressWarnings(value="unchecked") public AvroArrayOfArray build() { try { AvroArrayOfArray record = new AvroArrayOfArray(); diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java index 727f6a7bf733e..29f3109f83a15 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java @@ -129,6 +129,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArr } @Override + @SuppressWarnings(value="unchecked") public AvroMapOfArray build() { try { AvroMapOfArray record = new AvroMapOfArray(); diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java index 934793f42f9c9..c5522ed1e53e5 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java @@ -182,6 +182,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNulla } @Override + @SuppressWarnings(value="unchecked") public AvroNonNullableArrays build() { try { AvroNonNullableArrays record = new AvroNonNullableArrays(); diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index a7bf4841919c5..f84e3f2d61efb 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -182,6 +182,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Build } @Override + @SuppressWarnings(value="unchecked") public Nested build() { try { Nested record = new Nested(); diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index ef12d193f916c..46fc608398ccf 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -235,6 +235,7 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroC } @Override + @SuppressWarnings(value="unchecked") public ParquetAvroCompat build() { try { ParquetAvroCompat record = new ParquetAvroCompat(); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 7b50aad4ad498..189cc3972c9ba 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -32,7 +32,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -107,15 +107,15 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.applySchema(rowRDD, schema); + Dataset df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); + List actual = sqlContext.sql("SELECT * FROM people").collectAsList(); - List expected = new ArrayList(2); + List expected = new ArrayList<>(2); expected.add(RowFactory.create("Michael", 29)); expected.add(RowFactory.create("Yin", 28)); - Assert.assertEquals(expected, Arrays.asList(actual)); + Assert.assertEquals(expected, actual); } @Test @@ -143,14 +143,15 @@ public Row call(Person person) { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.applySchema(rowRDD, schema); + Dataset df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - @Override - public String call(Row row) { - return row.getString(0) + "_" + row.get(1); - } - }).collect(); + List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD() + .map(new Function() { + @Override + public String call(Row row) { + return row.getString(0) + "_" + row.get(1); + } + }).collect(); List expected = new ArrayList<>(2); expected.add("Michael_29"); @@ -198,14 +199,14 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - DataFrame df1 = sqlContext.read().json(jsonRDD); + Dataset df1 = sqlContext.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); + Dataset df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 49f516e86d754..1eb680dc4c029 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -18,10 +18,14 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.net.URISyntaxException; +import java.net.URL; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.ArrayList; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -34,10 +38,12 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; -import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import org.apache.spark.util.sketch.CountMinSketch; +import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; +import org.apache.spark.util.sketch.BloomFilter; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -61,8 +67,15 @@ public void tearDown() { @Test public void testExecution() { - DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(1, df.select("key").collect()[0].get(0)); + Dataset df = context.table("testData").filter("key = 1"); + Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0)); + } + + @Test + public void testCollectAndTake() { + Dataset df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Assert.assertEquals(3, df.select("key").collectAsList().size()); + Assert.assertEquals(2, df.select("key").takeAsList(2).size()); } /** @@ -70,7 +83,7 @@ public void testExecution() { */ @Test public void testVarargMethods() { - DataFrame df = context.table("testData"); + Dataset df = context.table("testData"); df.toDF("key1", "value1"); @@ -99,7 +112,7 @@ public void testVarargMethods() { df.select(coalesce(col("key"))); // Varargs with mathfunctions - DataFrame df2 = context.table("testData2"); + Dataset df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); @@ -113,7 +126,7 @@ public void testVarargMethods() { @Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly - DataFrame df = context.table("testData"); + Dataset df = context.table("testData"); df.show(); df.show(1000); } @@ -141,7 +154,7 @@ public List getD() { } } - void validateDataFrameWithBeans(Bean bean, DataFrame df) { + void validateDataFrameWithBeans(Bean bean, Dataset df) { StructType schema = df.schema(); Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), schema.apply("a")); @@ -181,7 +194,7 @@ void validateDataFrameWithBeans(Bean bean, DataFrame df) { public void testCreateDataFrameFromLocalJavaBeans() { Bean bean = new Bean(); List data = Arrays.asList(bean); - DataFrame df = context.createDataFrame(data, Bean.class); + Dataset df = context.createDataFrame(data, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -189,7 +202,7 @@ public void testCreateDataFrameFromLocalJavaBeans() { public void testCreateDataFrameFromJavaBeans() { Bean bean = new Bean(); JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); - DataFrame df = context.createDataFrame(rdd, Bean.class); + Dataset df = context.createDataFrame(rdd, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -197,9 +210,22 @@ public void testCreateDataFrameFromJavaBeans() { public void testCreateDataFromFromList() { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List rows = Arrays.asList(RowFactory.create(0)); - DataFrame df = context.createDataFrame(rows, schema); - Row[] result = df.collect(); - Assert.assertEquals(1, result.length); + Dataset df = context.createDataFrame(rows, schema); + List result = df.collectAsList(); + Assert.assertEquals(1, result.size()); + } + + @Test + public void testCreateStructTypeFromList(){ + List fields1 = new ArrayList<>(); + fields1.add(new StructField("id", DataTypes.StringType, true, Metadata.empty())); + StructType schema1 = StructType$.MODULE$.apply(fields1); + Assert.assertEquals(0, schema1.fieldIndex("id")); + + List fields2 = + Arrays.asList(new StructField("id", DataTypes.StringType, true, Metadata.empty())); + StructType schema2 = StructType$.MODULE$.apply(fields2); + Assert.assertEquals(0, schema2.fieldIndex("id")); } private static final Comparator crosstabRowComparator = new Comparator() { @@ -213,14 +239,14 @@ public int compare(Row row1, Row row2) { @Test public void testCrosstab() { - DataFrame df = context.table("testData2"); - DataFrame crosstab = df.stat().crosstab("a", "b"); + Dataset df = context.table("testData2"); + Dataset crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); - Assert.assertEquals("1", columnNames[1]); - Assert.assertEquals("2", columnNames[2]); - Row[] rows = crosstab.collect(); - Arrays.sort(rows, crosstabRowComparator); + Assert.assertEquals("2", columnNames[1]); + Assert.assertEquals("1", columnNames[2]); + List rows = crosstab.collectAsList(); + Collections.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); @@ -232,32 +258,143 @@ public void testCrosstab() { @Test public void testFrequentItems() { - DataFrame df = context.table("testData2"); + Dataset df = context.table("testData2"); String[] cols = {"a"}; - DataFrame results = df.stat().freqItems(cols, 0.2); - Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); + Dataset results = df.stat().freqItems(cols, 0.2); + Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1)); } @Test public void testCorrelation() { - DataFrame df = context.table("testData2"); + Dataset df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { - DataFrame df = context.table("testData2"); + Dataset df = context.table("testData2"); Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test public void testSampleBy() { - DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); - DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); - Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; - Assert.assertArrayEquals(expected, actual); + Dataset df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + List actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); + Assert.assertEquals(0, actual.get(0).getLong(0)); + Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8); + Assert.assertEquals(1, actual.get(1).getLong(0)); + Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13); + } + + @Test + public void pivot() { + Dataset df = context.table("courseSales"); + List actual = df.groupBy("year") + .pivot("course", Arrays.asList("dotNET", "Java")) + .agg(sum("earnings")).orderBy("year").collectAsList(); + + Assert.assertEquals(2012, actual.get(0).getInt(0)); + Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01); + + Assert.assertEquals(2013, actual.get(1).getInt(0)); + Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01); + } + + private String getResource(String resource) { + try { + // The following "getResource" has different behaviors in SBT and Maven. + // When running in Jenkins, the file path may contain "@" when there are multiple + // SparkPullRequestBuilders running in the same worker + // (e.g., /home/jenkins/workspace/SparkPullRequestBuilder@2) + // When running in SBT, "@" in the file path will be returned as "@", however, + // when running in Maven, "@" will be encoded as "%40". + // Therefore, we convert it to URI then call "getPath" to decode it back so that it can both + // work both in SBT and Maven. + URL url = Thread.currentThread().getContextClassLoader().getResource(resource); + return url.toURI().getPath(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Test + public void testGenericLoad() { + Dataset df1 = context.read().format("text").load(getResource("text-suite.txt")); + Assert.assertEquals(4L, df1.count()); + + Dataset df2 = context.read().format("text").load( + getResource("text-suite.txt"), + getResource("text-suite2.txt")); + Assert.assertEquals(5L, df2.count()); + } + + @Test + public void testTextLoad() { + Dataset ds1 = context.read().text(getResource("text-suite.txt")); + Assert.assertEquals(4L, ds1.count()); + + Dataset ds2 = context.read().text( + getResource("text-suite.txt"), + getResource("text-suite2.txt")); + Assert.assertEquals(5L, ds2.count()); + } + + @Test + public void testCountMinSketch() { + Dataset df = context.range(1000); + + CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); + Assert.assertEquals(sketch1.totalCount(), 1000); + Assert.assertEquals(sketch1.depth(), 10); + Assert.assertEquals(sketch1.width(), 20); + + CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42); + Assert.assertEquals(sketch2.totalCount(), 1000); + Assert.assertEquals(sketch2.depth(), 10); + Assert.assertEquals(sketch2.width(), 20); + + CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42); + Assert.assertEquals(sketch3.totalCount(), 1000); + Assert.assertEquals(sketch3.relativeError(), 0.001, 1e-4); + Assert.assertEquals(sketch3.confidence(), 0.99, 5e-3); + + CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42); + Assert.assertEquals(sketch4.totalCount(), 1000); + Assert.assertEquals(sketch4.relativeError(), 0.001, 1e-4); + Assert.assertEquals(sketch4.confidence(), 0.99, 5e-3); + } + + @Test + public void testBloomFilter() { + Dataset df = context.range(1000); + + BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); + Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); + for (int i = 0; i < 1000; i++) { + Assert.assertTrue(filter1.mightContain(i)); + } + + BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03); + Assert.assertTrue(filter2.expectedFpp() - 0.03 < 1e-3); + for (int i = 0; i < 1000; i++) { + Assert.assertTrue(filter2.mightContain(i * 3)); + } + + BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5); + Assert.assertTrue(filter3.bitSize() == 64 * 5); + for (int i = 0; i < 1000; i++) { + Assert.assertTrue(filter3.mightContain(i)); + } + + BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5); + Assert.assertTrue(filter4.bitSize() == 64 * 5); + for (int i = 0; i < 1000; i++) { + Assert.assertTrue(filter4.mightContain(i * 3)); + } } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java new file mode 100644 index 0000000000000..5abd62cbc245b --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -0,0 +1,819 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.*; + +import com.google.common.base.Objects; +import org.junit.rules.ExpectedException; +import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; +import scala.Tuple5; + +import org.junit.*; + +import org.apache.spark.Accumulator; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.catalyst.encoders.OuterScopes; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaDatasetSuite implements Serializable { + private transient JavaSparkContext jsc; + private transient TestSQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + private Tuple2 tuple2(T1 t1, T2 t2) { + return new Tuple2<>(t1, t2); + } + + @Test + public void testCollect() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + List collected = ds.collectAsList(); + Assert.assertEquals(Arrays.asList("hello", "world"), collected); + } + + @Test + public void testTake() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + List collected = ds.takeAsList(1); + Assert.assertEquals(Arrays.asList("hello"), collected); + } + + @Test + public void testToLocalIterator() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + Iterator iter = ds.toLocalIterator(); + Assert.assertEquals("hello", iter.next()); + Assert.assertEquals("world", iter.next()); + Assert.assertFalse(iter.hasNext()); + } + + @Test + public void testCommonOperation() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + Assert.assertEquals("hello", ds.first()); + + Dataset filtered = ds.filter(new FilterFunction() { + @Override + public boolean call(String v) throws Exception { + return v.startsWith("h"); + } + }); + Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); + + + Dataset mapped = ds.map(new MapFunction() { + @Override + public Integer call(String v) throws Exception { + return v.length(); + } + }, Encoders.INT()); + Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); + + Dataset parMapped = ds.mapPartitions(new MapPartitionsFunction() { + @Override + public Iterator call(Iterator it) { + List ls = new LinkedList<>(); + while (it.hasNext()) { + ls.add(it.next().toUpperCase(Locale.ENGLISH)); + } + return ls.iterator(); + } + }, Encoders.STRING()); + Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList()); + + Dataset flatMapped = ds.flatMap(new FlatMapFunction() { + @Override + public Iterator call(String s) { + List ls = new LinkedList<>(); + for (char c : s.toCharArray()) { + ls.add(String.valueOf(c)); + } + return ls.iterator(); + } + }, Encoders.STRING()); + Assert.assertEquals( + Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"), + flatMapped.collectAsList()); + } + + @Test + public void testForeach() { + final Accumulator accum = jsc.accumulator(0); + List data = Arrays.asList("a", "b", "c"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + + ds.foreach(new ForeachFunction() { + @Override + public void call(String s) throws Exception { + accum.add(1); + } + }); + Assert.assertEquals(3, accum.value().intValue()); + } + + @Test + public void testReduce() { + List data = Arrays.asList(1, 2, 3); + Dataset ds = context.createDataset(data, Encoders.INT()); + + int reduced = ds.reduce(new ReduceFunction() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 + v2; + } + }); + Assert.assertEquals(6, reduced); + } + + @Test + public void testGroupBy() { + List data = Arrays.asList("a", "foo", "bar"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + KeyValueGroupedDataset grouped = ds.groupByKey( + new MapFunction() { + @Override + public Integer call(String v) throws Exception { + return v.length(); + } + }, + Encoders.INT()); + + Dataset mapped = grouped.mapGroups(new MapGroupsFunction() { + @Override + public String call(Integer key, Iterator values) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, Encoders.STRING()); + + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList())); + + Dataset flatMapped = grouped.flatMapGroups( + new FlatMapGroupsFunction() { + @Override + public Iterator call(Integer key, Iterator values) { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + Encoders.STRING()); + + Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); + + Dataset> reduced = grouped.reduceGroups(new ReduceFunction() { + @Override + public String call(String v1, String v2) throws Exception { + return v1 + v2; + } + }); + + Assert.assertEquals( + asSet(tuple2(1, "a"), tuple2(3, "foobar")), + toSet(reduced.collectAsList())); + + List data2 = Arrays.asList(2, 6, 10); + Dataset ds2 = context.createDataset(data2, Encoders.INT()); + KeyValueGroupedDataset grouped2 = ds2.groupByKey( + new MapFunction() { + @Override + public Integer call(Integer v) throws Exception { + return v / 2; + } + }, + Encoders.INT()); + + Dataset cogrouped = grouped.cogroup( + grouped2, + new CoGroupFunction() { + @Override + public Iterator call(Integer key, Iterator left, Iterator right) { + StringBuilder sb = new StringBuilder(key.toString()); + while (left.hasNext()) { + sb.append(left.next()); + } + sb.append("#"); + while (right.hasNext()) { + sb.append(right.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + Encoders.STRING()); + + Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList())); + } + + @Test + public void testSelect() { + List data = Arrays.asList(2, 6); + Dataset ds = context.createDataset(data, Encoders.INT()); + + Dataset> selected = ds.select( + expr("value + 1"), + col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING())); + + Assert.assertEquals( + Arrays.asList(tuple2(3, "2"), tuple2(7, "6")), + selected.collectAsList()); + } + + @Test + public void testSetOperation() { + List data = Arrays.asList("abc", "abc", "xyz"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + + Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); + + List data2 = Arrays.asList("xyz", "foo", "foo"); + Dataset ds2 = context.createDataset(data2, Encoders.STRING()); + + Dataset intersected = ds.intersect(ds2); + Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); + + Dataset unioned = ds.union(ds2).union(ds); + Assert.assertEquals( + Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"), + unioned.collectAsList()); + + Dataset subtracted = ds.except(ds2); + Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); + } + + private static Set toSet(List records) { + return new HashSet<>(records); + } + + @SafeVarargs + @SuppressWarnings("varargs") + private static Set asSet(T... records) { + return toSet(Arrays.asList(records)); + } + + @Test + public void testJoin() { + List data = Arrays.asList(1, 2, 3); + Dataset ds = context.createDataset(data, Encoders.INT()).as("a"); + List data2 = Arrays.asList(2, 3, 4); + Dataset ds2 = context.createDataset(data2, Encoders.INT()).as("b"); + + Dataset> joined = + ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); + Assert.assertEquals( + Arrays.asList(tuple2(2, 2), tuple2(3, 3)), + joined.collectAsList()); + } + + @Test + public void testTupleEncoder() { + Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); + List> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); + Dataset> ds2 = context.createDataset(data2, encoder2); + Assert.assertEquals(data2, ds2.collectAsList()); + + Encoder> encoder3 = + Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); + List> data3 = + Arrays.asList(new Tuple3<>(1, 2L, "a")); + Dataset> ds3 = context.createDataset(data3, encoder3); + Assert.assertEquals(data3, ds3.collectAsList()); + + Encoder> encoder4 = + Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); + List> data4 = + Arrays.asList(new Tuple4<>(1, "b", 2L, "a")); + Dataset> ds4 = context.createDataset(data4, encoder4); + Assert.assertEquals(data4, ds4.collectAsList()); + + Encoder> encoder5 = + Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(), + Encoders.BOOLEAN()); + List> data5 = + Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true)); + Dataset> ds5 = + context.createDataset(data5, encoder5); + Assert.assertEquals(data5, ds5.collectAsList()); + } + + @Test + public void testNestedTupleEncoder() { + // test ((int, string), string) + Encoder, String>> encoder = + Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); + List, String>> data = + Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); + Dataset, String>> ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + + // test (int, (string, string, long)) + Encoder>> encoder2 = + Encoders.tuple(Encoders.INT(), + Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG())); + List>> data2 = + Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L))); + Dataset>> ds2 = + context.createDataset(data2, encoder2); + Assert.assertEquals(data2, ds2.collectAsList()); + + // test (int, ((string, long), string)) + Encoder, String>>> encoder3 = + Encoders.tuple(Encoders.INT(), + Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING())); + List, String>>> data3 = + Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); + Dataset, String>>> ds3 = + context.createDataset(data3, encoder3); + Assert.assertEquals(data3, ds3.collectAsList()); + } + + @Test + public void testPrimitiveEncoder() { + Encoder> encoder = + Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(), + Encoders.FLOAT()); + List> data = + Arrays.asList(new Tuple5<>( + 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), + Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); + Dataset> ds = + context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + public static class KryoSerializable { + String value; + + KryoSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + + return this.value.equals(((KryoSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + public static class JavaSerializable implements Serializable { + String value; + + JavaSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + + return this.value.equals(((JavaSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + @Test + public void testKryoEncoder() { + Encoder encoder = Encoders.kryo(KryoSerializable.class); + List data = Arrays.asList( + new KryoSerializable("hello"), new KryoSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testJavaEncoder() { + Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); + List data = Arrays.asList( + new JavaSerializable("hello"), new JavaSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testRandomSplit() { + List data = Arrays.asList("hello", "world", "from", "spark"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + double[] arraySplit = {1, 2, 3}; + + List> randomSplit = ds.randomSplitAsList(arraySplit, 1); + Assert.assertEquals("wrong number of splits", randomSplit.size(), 3); + } + + /** + * For testing error messages when creating an encoder on a private class. This is done + * here since we cannot create truly private classes in Scala. + */ + private static class PrivateClassTest { } + + @Test(expected = UnsupportedOperationException.class) + public void testJavaEncoderErrorMessageForPrivateClass() { + Encoders.javaSerialization(PrivateClassTest.class); + } + + @Test(expected = UnsupportedOperationException.class) + public void testKryoEncoderErrorMessageForPrivateClass() { + Encoders.kryo(PrivateClassTest.class); + } + + public static class SimpleJavaBean implements Serializable { + private boolean a; + private int b; + private byte[] c; + private String[] d; + private List e; + private List f; + + public boolean isA() { + return a; + } + + public void setA(boolean a) { + this.a = a; + } + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public byte[] getC() { + return c; + } + + public void setC(byte[] c) { + this.c = c; + } + + public String[] getD() { + return d; + } + + public void setD(String[] d) { + this.d = d; + } + + public List getE() { + return e; + } + + public void setE(List e) { + this.e = e; + } + + public List getF() { + return f; + } + + public void setF(List f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean that = (SimpleJavaBean) o; + + if (a != that.a) return false; + if (b != that.b) return false; + if (!Arrays.equals(c, that.c)) return false; + if (!Arrays.equals(d, that.d)) return false; + if (!e.equals(that.e)) return false; + return f.equals(that.f); + } + + @Override + public int hashCode() { + int result = (a ? 1 : 0); + result = 31 * result + b; + result = 31 * result + Arrays.hashCode(c); + result = 31 * result + Arrays.hashCode(d); + result = 31 * result + e.hashCode(); + result = 31 * result + f.hashCode(); + return result; + } + } + + public static class SimpleJavaBean2 implements Serializable { + private Timestamp a; + private Date b; + private java.math.BigDecimal c; + + public Timestamp getA() { return a; } + + public void setA(Timestamp a) { this.a = a; } + + public Date getB() { return b; } + + public void setB(Date b) { this.b = b; } + + public java.math.BigDecimal getC() { return c; } + + public void setC(java.math.BigDecimal c) { this.c = c; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean2 that = (SimpleJavaBean2) o; + + if (!a.equals(that.a)) return false; + if (!b.equals(that.b)) return false; + return c.equals(that.c); + } + + @Override + public int hashCode() { + int result = a.hashCode(); + result = 31 * result + b.hashCode(); + result = 31 * result + c.hashCode(); + return result; + } + } + + public static class NestedJavaBean implements Serializable { + private SimpleJavaBean a; + + public SimpleJavaBean getA() { + return a; + } + + public void setA(SimpleJavaBean a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NestedJavaBean that = (NestedJavaBean) o; + + return a.equals(that.a); + } + + @Override + public int hashCode() { + return a.hashCode(); + } + } + + @Test + public void testJavaBeanEncoder() { + OuterScopes.addOuterScope(this); + SimpleJavaBean obj1 = new SimpleJavaBean(); + obj1.setA(true); + obj1.setB(3); + obj1.setC(new byte[]{1, 2}); + obj1.setD(new String[]{"hello", null}); + obj1.setE(Arrays.asList("a", "b")); + obj1.setF(Arrays.asList(100L, null, 200L)); + SimpleJavaBean obj2 = new SimpleJavaBean(); + obj2.setA(false); + obj2.setB(30); + obj2.setC(new byte[]{3, 4}); + obj2.setD(new String[]{null, "world"}); + obj2.setE(Arrays.asList("x", "y")); + obj2.setF(Arrays.asList(300L, null, 400L)); + + List data = Arrays.asList(obj1, obj2); + Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds.collectAsList()); + + NestedJavaBean obj3 = new NestedJavaBean(); + obj3.setA(obj1); + + List data2 = Arrays.asList(obj3); + Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Assert.assertEquals(data2, ds2.collectAsList()); + + Row row1 = new GenericRow(new Object[]{ + true, + 3, + new byte[]{1, 2}, + new String[]{"hello", null}, + Arrays.asList("a", "b"), + Arrays.asList(100L, null, 200L)}); + Row row2 = new GenericRow(new Object[]{ + false, + 30, + new byte[]{3, 4}, + new String[]{null, "world"}, + Arrays.asList("x", "y"), + Arrays.asList(300L, null, 400L)}); + StructType schema = new StructType() + .add("a", BooleanType, false) + .add("b", IntegerType, false) + .add("c", BinaryType) + .add("d", createArrayType(StringType)) + .add("e", createArrayType(StringType)) + .add("f", createArrayType(LongType)); + Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + .as(Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds3.collectAsList()); + } + + @Test + public void testJavaBeanEncoder2() { + // This is a regression test of SPARK-12404 + OuterScopes.addOuterScope(this); + SimpleJavaBean2 obj = new SimpleJavaBean2(); + obj.setA(new Timestamp(0)); + obj.setB(new Date(0)); + obj.setC(java.math.BigDecimal.valueOf(1)); + Dataset ds = + context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + ds.collect(); + } + + public static class SmallBean implements Serializable { + private String a; + + private int b; + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public String getA() { + return a; + } + + public void setA(String a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SmallBean smallBean = (SmallBean) o; + return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a); + } + + @Override + public int hashCode() { + return Objects.hashCode(a, b); + } + } + + public static class NestedSmallBean implements Serializable { + private SmallBean f; + + public SmallBean getF() { + return f; + } + + public void setF(SmallBean f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NestedSmallBean that = (NestedSmallBean) o; + return Objects.equal(f, that.f); + } + + @Override + public int hashCode() { + return Objects.hashCode(f); + } + } + + @Rule + public transient ExpectedException nullabilityCheck = ExpectedException.none(); + + @Test + public void testRuntimeNullabilityCheck() { + OuterScopes.addOuterScope(this); + + StructType schema = new StructType() + .add("f", new StructType() + .add("a", StringType, true) + .add("b", IntegerType, true), true); + + // Shouldn't throw runtime exception since it passes nullability check. + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", 1 + }) + }); + + Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + SmallBean smallBean = new SmallBean(); + smallBean.setA("hello"); + smallBean.setB(1); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + nestedSmallBean.setF(smallBean); + + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + { + Row row = new GenericRow(new Object[] { null }); + + Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + nullabilityCheck.expect(RuntimeException.class); + nullabilityCheck.expectMessage("Null value appeared in non-nullable field"); + + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", null + }) + }); + + Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + ds.collect(); + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java new file mode 100644 index 0000000000000..0e49f871de5c4 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.util.Arrays; + +import scala.Tuple2; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.sql.expressions.java.typed; + +/** + * Suite for testing the aggregate functionality of Datasets in Java. + */ +public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { + @Test + public void testTypedAggregationAnonClass() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + + Dataset> agged = grouped.agg(new IntSumOf().toColumn()); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + + Dataset> agged2 = grouped.agg(new IntSumOf().toColumn()) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); + Assert.assertEquals( + Arrays.asList( + new Tuple2<>("a", 3), + new Tuple2<>("b", 3)), + agged2.collectAsList()); + } + + static class IntSumOf extends Aggregator, Integer, Integer> { + @Override + public Integer zero() { + return 0; + } + + @Override + public Integer reduce(Integer l, Tuple2 t) { + return l + t._2(); + } + + @Override + public Integer merge(Integer b1, Integer b2) { + return b1 + b2; + } + + @Override + public Integer finish(Integer reduction) { + return reduction; + } + + @Override + public Encoder bufferEncoder() { + return Encoders.INT(); + } + + @Override + public Encoder outputEncoder() { + return Encoders.INT(); + } + } + + @Test + public void testTypedAggregationAverage() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.avg( + new MapFunction, Double>() { + public Double call(Tuple2 value) throws Exception { + return (double)(value._2() * 2); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationCount() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.count( + new MapFunction, Object>() { + public Object call(Tuple2 value) throws Exception { + return value; + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumDouble() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sum( + new MapFunction, Double>() { + public Double call(Tuple2 value) throws Exception { + return (double)value._2(); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumLong() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sumLong( + new MapFunction, Long>() { + public Long call(Tuple2 value) throws Exception { + return (long)value._2(); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java new file mode 100644 index 0000000000000..7863177093c15 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; + +import org.junit.After; +import org.junit.Before; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.test.TestSQLContext; + +/** + * Common test base shared across this and Java8DatasetAggregatorSuite. + */ +public class JavaDatasetAggregatorSuiteBase implements Serializable { + protected transient JavaSparkContext jsc; + protected transient TestSQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + protected Tuple2 tuple2(T1 t1, T2 t2) { + return new Tuple2<>(t1, t2); + } + + protected KeyValueGroupedDataset> generateGroupedDataset() { + Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); + List> data = + Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); + Dataset> ds = context.createDataset(data, encoder); + + return ds.groupByKey( + new MapFunction, String>() { + @Override + public String call(Tuple2 value) throws Exception { + return value._1(); + } + }, + Encoders.STRING()); + } +} + diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 9e241f20987c0..9e65158eb0a33 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -40,11 +40,10 @@ public class JavaSaveLoadSuite { private transient JavaSparkContext sc; private transient SQLContext sqlContext; - String originalDefaultSource; File path; - DataFrame df; + Dataset df; - private static void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(Dataset actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -57,7 +56,6 @@ public void setUp() throws IOException { sqlContext = new SQLContext(_sc); sc = new JavaSparkContext(_sc); - originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); if (path.exists()) { @@ -85,7 +83,7 @@ public void saveAndLoad() { Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); - DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); + Dataset loadedDF = sqlContext.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -98,7 +96,7 @@ public void saveAndLoadWithSchema() { List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); + Dataset loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); } diff --git a/sql/core/src/test/resources/bool.csv b/sql/core/src/test/resources/bool.csv new file mode 100644 index 0000000000000..94b2d49506e0d --- /dev/null +++ b/sql/core/src/test/resources/bool.csv @@ -0,0 +1,5 @@ +bool +"True" +"False" + +"true" diff --git a/sql/core/src/test/resources/cars-alternative.csv b/sql/core/src/test/resources/cars-alternative.csv new file mode 100644 index 0000000000000..646f7c456c866 --- /dev/null +++ b/sql/core/src/test/resources/cars-alternative.csv @@ -0,0 +1,5 @@ +year|make|model|comment|blank +'2012'|'Tesla'|'S'| 'No comment'| + +1997|Ford|E350|'Go get one now they are going fast'| +2015|Chevy|Volt diff --git a/sql/core/src/test/resources/cars-malformed.csv b/sql/core/src/test/resources/cars-malformed.csv new file mode 100644 index 0000000000000..cfa378c01f1d9 --- /dev/null +++ b/sql/core/src/test/resources/cars-malformed.csv @@ -0,0 +1,6 @@ +~ All the rows here are malformed having tokens more than the schema (header). +year,make,model,comment,blank +"2012","Tesla","S","No comment",,null,null + +1997,Ford,E350,"Go get one now they are going fast",,null,null +2015,Chevy,,,, diff --git a/sql/core/src/test/resources/cars-null.csv b/sql/core/src/test/resources/cars-null.csv new file mode 100644 index 0000000000000..130c0b40bbe78 --- /dev/null +++ b/sql/core/src/test/resources/cars-null.csv @@ -0,0 +1,6 @@ +year,make,model,comment,blank +"2012","Tesla","S",null, + +1997,Ford,E350,"Go get one now they are going fast", +null,Chevy,Volt + diff --git a/sql/core/src/test/resources/cars-unbalanced-quotes.csv b/sql/core/src/test/resources/cars-unbalanced-quotes.csv new file mode 100644 index 0000000000000..5ea39fcbfadcc --- /dev/null +++ b/sql/core/src/test/resources/cars-unbalanced-quotes.csv @@ -0,0 +1,4 @@ +year,make,model,comment,blank +"2012,Tesla,S,No comment +1997,Ford,E350,Go get one now they are going fast" +"2015,"Chevy",Volt, diff --git a/sql/core/src/test/resources/cars.csv b/sql/core/src/test/resources/cars.csv new file mode 100644 index 0000000000000..40ded573ade5c --- /dev/null +++ b/sql/core/src/test/resources/cars.csv @@ -0,0 +1,7 @@ + +year,make,model,comment,blank +"2012","Tesla","S","No comment", + +1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt + diff --git a/sql/core/src/test/resources/cars.tsv b/sql/core/src/test/resources/cars.tsv new file mode 100644 index 0000000000000..a7bfa9a91f961 --- /dev/null +++ b/sql/core/src/test/resources/cars.tsv @@ -0,0 +1,4 @@ +year make model price comment blank +2012 Tesla S "80,000.65" +1997 Ford E350 35,000 "Go get one now they are going fast" +2015 Chevy Volt 5,000.10 diff --git a/sql/core/src/test/resources/cars_iso-8859-1.csv b/sql/core/src/test/resources/cars_iso-8859-1.csv new file mode 100644 index 0000000000000..c51b6c59010f0 --- /dev/null +++ b/sql/core/src/test/resources/cars_iso-8859-1.csv @@ -0,0 +1,6 @@ +yearmakemodelcommentblank +"2012""Tesla""S""No comment" + +1997FordE350"Go get one now they are oing fast" +2015ChevyVolt + diff --git a/sql/core/src/test/resources/comments.csv b/sql/core/src/test/resources/comments.csv new file mode 100644 index 0000000000000..6275be7285b36 --- /dev/null +++ b/sql/core/src/test/resources/comments.csv @@ -0,0 +1,6 @@ +~ Version 1.0 +~ Using a non-standard comment char to test CSV parser defaults are overridden +1,2,3,4,5.01,2015-08-20 15:57:00 +6,7,8,9,0,2015-08-21 16:58:01 +~0,9,8,7,6,2015-08-22 17:59:02 +1,2,3,4,5,2015-08-23 18:00:42 diff --git a/sql/core/src/test/resources/dec-in-fixed-len.parquet b/sql/core/src/test/resources/dec-in-fixed-len.parquet new file mode 100644 index 0000000000000..6ad37d5639511 Binary files /dev/null and b/sql/core/src/test/resources/dec-in-fixed-len.parquet differ diff --git a/sql/core/src/test/resources/disable_comments.csv b/sql/core/src/test/resources/disable_comments.csv new file mode 100644 index 0000000000000..304d406e4d980 --- /dev/null +++ b/sql/core/src/test/resources/disable_comments.csv @@ -0,0 +1,2 @@ +#1,2,3 +4,5,6 diff --git a/core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE b/sql/core/src/test/resources/empty.csv old mode 100755 new mode 100644 similarity index 100% rename from core/src/test/resources/spark-events/local-1422981759269/APPLICATION_COMPLETE rename to sql/core/src/test/resources/empty.csv diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 12fb128149d32..e53cb1f4e681d 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -23,6 +23,7 @@ log4j.appender.CA=org.apache.log4j.ConsoleAppender log4j.appender.CA.layout=org.apache.log4j.PatternLayout log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n log4j.appender.CA.Threshold = WARN +log4j.appender.CA.follow = true #File Appender diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv new file mode 100644 index 0000000000000..02d29cabf95f2 --- /dev/null +++ b/sql/core/src/test/resources/simple_sparse.csv @@ -0,0 +1,5 @@ +A,B,C,D +1,,, +,1,, +,,1, +,,,1 diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/text-suite2.txt new file mode 100644 index 0000000000000..f9d498c80493c --- /dev/null +++ b/sql/core/src/test/resources/text-suite2.txt @@ -0,0 +1 @@ +This is another file for testing multi path loading. diff --git a/sql/core/src/test/resources/unescaped-quotes.csv b/sql/core/src/test/resources/unescaped-quotes.csv new file mode 100644 index 0000000000000..7c68055575de0 --- /dev/null +++ b/sql/core/src/test/resources/unescaped-quotes.csv @@ -0,0 +1,2 @@ +"a"b,ccc,ddd +ab,cc"c,ddd" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index dbcb011f603f7..82b79c791db40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,38 +17,38 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.execution.Exchange -import org.apache.spark.sql.execution.PhysicalRDD - import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.PhysicalRDD +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.storage.{StorageLevel, RDDBlockId} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} private case class BigData(s: String) -class CachedTableSuite extends QueryTest with SharedSQLContext { +class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext { import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan - executedPlan.collect { + val plan = sqlContext.table(tableName).queryExecution.sparkPlan + plan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id case _ => - fail(s"Table $tableName is not cached\n" + executedPlan) + fail(s"Table $tableName is not cached\n" + plan) }.head } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)) + maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0))) + maybeBlock.nonEmpty } test("withColumn doesn't invalidate cached dataframe") { @@ -280,7 +280,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { sql("CACHE TABLE testData") sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => - val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum + val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) } } @@ -289,7 +289,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { testData.select('key).registerTempTable("t1") sqlContext.table("t1") sqlContext.dropTempTable("t1") - intercept[NoSuchTableException](sqlContext.table("t1")) + intercept[AnalysisException](sqlContext.table("t1")) } test("Drops cached temporary table") { @@ -301,7 +301,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { assert(sqlContext.isCached("t2")) sqlContext.dropTempTable("t1") - intercept[NoSuchTableException](sqlContext.table("t1")) + intercept[AnalysisException](sqlContext.table("t1")) assert(!sqlContext.isCached("t2")) } @@ -359,11 +359,11 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { * Verifies that the plan for `df` contains `expected` number of Exchange operators. */ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { - assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected) + assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected) } test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { - val table3x = testData.unionAll(testData).unionAll(testData) + val table3x = testData.union(testData).union(testData) table3x.registerTempTable("testData3x") sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable") @@ -375,53 +375,135 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) sqlContext.uncacheTable("orderedTable") + sqlContext.dropTempTable("orderedTable") // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. for (numPartitions <- 1 until 10 by 4) { - testData.repartition(numPartitions, $"key").registerTempTable("t1") - testData2.repartition(numPartitions, $"a").registerTempTable("t2") + withTempTable("t1", "t2") { + testData.repartition(numPartitions, $"key").registerTempTable("t1") + testData2.repartition(numPartitions, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + // Joining them should result in no exchanges. + verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) + checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), + sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + + // Grouping on the partition key should result in no exchanges + verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) + checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), + sql("SELECT count(*) FROM testData GROUP BY key")) + + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + } + + // Distribute the tables into non-matching number of partitions. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"key").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") sqlContext.cacheTable("t1") sqlContext.cacheTable("t2") - // Joining them should result in no exchanges. - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) - checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), - sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } - // Grouping on the partition key should result in no exchanges - verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) - checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), - sql("SELECT count(*) FROM testData GROUP BY key")) + // One side of join is not partitioned in the desired way. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(6, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) sqlContext.uncacheTable("t1") sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") } - // Distribute the tables into non-matching number of partitions. Need to shuffle. - testData.repartition(6, $"key").registerTempTable("t1") - testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(12, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 12) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } - // One side of join is not partitioned in the desired way. Need to shuffle. - testData.repartition(6, $"value").registerTempTable("t1") - testData2.repartition(6, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + // One side of join is not partitioned in the desired way. Since the number of partitions of + // the side that has already partitioned is smaller than the side that is not partitioned, + // we shuffle both side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 2) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + + // repartition's column ordering is different from group by column ordering. + // But they use the same set of columns. + withTempTable("t1") { + testData.repartition(6, $"value", $"key").registerTempTable("t1") + sqlContext.cacheTable("t1") + + val query = sql("SELECT value, key from t1 group by key, value") + verifyNumExchanges(query, 0) + checkAnswer( + query, + testData.distinct().select($"value", $"key")) + sqlContext.uncacheTable("t1") + } + + // repartition's column ordering is different from join condition's column ordering. + // We will still shuffle because hashcodes of a row depend on the column ordering. + // If we do not shuffle, we may actually partition two tables in totally two different way. + // See PartitioningSuite for more details. + withTempTable("t1", "t2") { + val df1 = testData + df1.repartition(6, $"value", $"key").registerTempTable("t1") + val df2 = testData2.select($"a", $"b".cast("string")) + df2.repartition(6, $"a", $"b").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = + sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa559c9c64005..351b03b38bad1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.{Project, TungstenProject} +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -105,10 +105,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Row("a") :: Nil) } - test("alias") { + test("alias and name") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") assert(df.select(df("a").alias("b")).columns.head === "b") + assert(df.select(df("a").name("b")).columns.head === "b") } test("as propagates metadata") { @@ -298,7 +299,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - testData.select(isNaN($"a"), isNaN($"b")), + testData.select(isnan($"a"), isnan($"b")), Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( @@ -349,7 +350,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) } - test("!==") { + test("=!=") { val nullData = sqlContext.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: @@ -368,6 +369,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer( nullData.filter($"a" <=> $"b"), Row(1, 1) :: Row(null, null) :: Nil) + + val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize( + Row("abc") :: + Row(null) :: + Row("xyz") :: Nil), + StructType(Seq(StructField("a", StringType, true)))) + + checkAnswer( + nullData2.filter($"a" <=> null), + Row(null) :: Nil) + } test(">") { @@ -563,28 +575,32 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { df.select(monotonicallyIncreasingId()), Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil ) + checkAnswer( + df.select(expr("monotonically_increasing_id()")), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) } - test("sparkPartitionId") { + test("spark_partition_id") { // Make sure we have 2 partitions, each with 2 records. val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( - df.select(sparkPartitionId()), + df.select(spark_partition_id()), Row(0) :: Row(0) :: Row(1) :: Row(1) :: Nil ) } - test("InputFileName") { + test("input_file_name") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name()) .head.getString(0) assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(inputFileName()).limit(1), Row("")) + checkAnswer(data.select(input_file_name()).limit(1), Row("")) } } @@ -614,9 +630,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { - val projects = df.queryExecution.executedPlan.collect { - case project: Project => project - case tungstenProject: TungstenProject => tungstenProject + val projects = df.queryExecution.sparkPlan.collect { + case tungstenProject: Project => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 2e679e7bc4e0a..7d96ef6fe0a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType +case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -60,6 +63,120 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("rollup") { + checkAnswer( + courseSales.rollup("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + courseSales.cube("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + + val df0 = sqlContext.sparkContext.parallelize(Seq( + Fact(20151123, 18, 35, "room1", 18.6), + Fact(20151123, 18, 35, "room2", 22.4), + Fact(20151123, 18, 36, "room1", 17.4), + Fact(20151123, 18, 36, "room2", 25.6))).toDF() + + val cube0 = df0.cube("date", "hour", "minute", "room_name").agg(Map("temp" -> "avg")) + assert(cube0.where("date IS NULL").count > 0) + } + + test("grouping and grouping_id") { + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping("course"), grouping("year"), grouping_id("course", "year")), + Row("Java", 2012, 0, 0, 0) :: + Row("Java", 2013, 0, 0, 0) :: + Row("Java", null, 0, 1, 1) :: + Row("dotNET", 2012, 0, 0, 0) :: + Row("dotNET", 2013, 0, 0, 0) :: + Row("dotNET", null, 0, 1, 1) :: + Row(null, 2012, 1, 0, 2) :: + Row(null, 2013, 1, 0, 2) :: + Row(null, null, 1, 1, 3) :: Nil + ) + + intercept[AnalysisException] { + courseSales.groupBy().agg(grouping("course")).explain() + } + intercept[AnalysisException] { + courseSales.groupBy().agg(grouping_id("course")).explain() + } + } + + test("grouping/grouping_id inside window function") { + + val w = Window.orderBy(sum("earnings")) + checkAnswer( + courseSales.cube("course", "year") + .agg(sum("earnings"), + grouping_id("course", "year"), + rank().over(Window.partitionBy(grouping_id("course", "year")).orderBy(sum("earnings")))), + Row("Java", 2012, 20000.0, 0, 2) :: + Row("Java", 2013, 30000.0, 0, 3) :: + Row("Java", null, 50000.0, 1, 1) :: + Row("dotNET", 2012, 15000.0, 0, 1) :: + Row("dotNET", 2013, 48000.0, 0, 4) :: + Row("dotNET", null, 63000.0, 1, 2) :: + Row(null, 2012, 35000.0, 2, 1) :: + Row(null, 2013, 78000.0, 2, 2) :: + Row(null, null, 113000.0, 3, 1) :: Nil + ) + } + + test("rollup overlapping columns") { + checkAnswer( + testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.rollup("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, null, 9) :: Nil + ) + } + + test("cube overlapping columns") { + checkAnswer( + testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, 1, 3) :: Row(null, 2, 0) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.cube("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, 1, 3) :: Row(null, 2, 6) + :: Row(null, null, 9) :: Nil + ) + } + test("spark.sql.retainGroupColumns config") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -81,6 +198,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("agg without groups and functions") { + checkAnswer( + testData2.agg(lit(1)), + Row(1) + ) + } + test("average") { checkAnswer( testData2.agg(avg('a), mean('a)), @@ -133,7 +257,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("count") { - assert(testData2.count() === testData2.map(_ => 1).count()) + assert(testData2.count() === testData2.rdd.map(_ => 1).count()) checkAnswer( testData2.agg(count('a), sumDistinct('a)), // non-partial @@ -162,6 +286,31 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("multiple column distinct count") { + val df1 = Seq( + ("a", "b", "c"), + ("a", "b", "c"), + ("a", "b", "d"), + ("x", "y", "z"), + ("x", "q", null.asInstanceOf[String])) + .toDF("key1", "key2", "key3") + + checkAnswer( + df1.agg(countDistinct('key1, 'key2)), + Row(3) + ) + + checkAnswer( + df1.agg(countDistinct('key1, 'key2, 'key3)), + Row(3) + ) + + checkAnswer( + df1.groupBy('key1).agg(countDistinct('key2, 'key3)), + Seq(Row("a", 2), Row("x", 1)) + ) + } + test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( @@ -170,10 +319,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("stddev") { - val testData2ADev = math.sqrt(4 / 5.0) + val testData2ADev = math.sqrt(4.0 / 5.0) checkAnswer( testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) + checkAnswer( + testData2.agg(stddev("a"), stddev_pop("a"), stddev_samp("a")), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } test("zero stddev") { @@ -219,17 +371,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), + var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) checkAnswer( input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), expr("variance(a)"), expr("var_samp(a)"), expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) } test("null moments") { @@ -237,7 +395,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) checkAnswer( emptyTableData.agg( @@ -246,6 +404,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 09f7b507670c9..72f676e6225ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -41,6 +41,21 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { test("UDF on array") { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") - df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() + df.select(array($"a").as("s")).select(f($"s".getItem(0))).collect() + } + + test("UDF on map") { + val f = udf((a: String) => a) + val df = Seq("a" -> 1).toDF("a", "b") + df.select(map($"a", $"b").as("s")).select(f($"s".getItem("a"))).collect() + } + + test("SPARK-12477 accessing null element in array field") { + val df = sparkContext.parallelize(Seq((Seq("val1", null, "val2"), + Seq(Some(1), None, Some(2))))).toDF("s", "i") + val nullStringRow = df.selectExpr("s[1]").collect()(0) + assert(nullStringRow == org.apache.spark.sql.Row(null)) + val nullIntRow = df.selectExpr("i[1]").collect()(0) + assert(nullIntRow == org.apache.spark.sql.Row(null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3a3f19af1473b..746e25a0c3ec5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.nio.charset.StandardCharsets + import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -42,15 +44,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val expectedType = ArrayType(IntegerType, containsNull = false) assert(row.schema(0).dataType === expectedType) - assert(row.getAs[Seq[Int]](0) === Seq(0, 2)) + assert(row.getSeq[Int](0) === Seq(0, 2)) } - // Turn this on once we add a rule to the analyzer to throw a friendly exception - ignore("array: throw exception if putting columns of different types into an array") { - val df = Seq((0, "str")).toDF("a", "b") - intercept[AnalysisException] { - df.select(array("a", "b")) - } + test("map with column expressions") { + val df = Seq(1 -> "a").toDF("a", "b") + val row = df.select(map($"a" + 1, $"b")).first() + + val expectedType = MapType(IntegerType, StringType, valueContainsNull = true) + assert(row.schema(0).dataType === expectedType) + assert(row.getMap[Int, String](0) === Map(2 -> "a")) } test("struct with column name") { @@ -167,12 +170,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("misc sha1 function") { - val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + val df = Seq(("ABC", "ABC".getBytes(StandardCharsets.UTF_8))).toDF("a", "b") checkAnswer( df.select(sha1($"a"), sha1($"b")), Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) - val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + val dfEmpty = Seq(("", "".getBytes(StandardCharsets.UTF_8))).toDF("a", "b") checkAnswer( dfEmpty.selectExpr("sha1(a)", "sha1(b)"), Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) @@ -308,10 +311,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null)) ) - val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b") - assert(intercept[AnalysisException] { - df2.selectExpr("sort_array(a)").collect() - }.getMessage().contains("does not support sorting array of type array")) + val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b") + checkAnswer( + df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"), + Seq( + Row( + Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)), + Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null))) + ) val df3 = Seq(("xxx", "x")).toDF("a", "b") assert(intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 56ad71ea4f487..067a62d011ec4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -42,17 +44,45 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - sorted columns not in join's outputSet") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1) + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2) + val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3) + + checkAnswer( + df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2") + .orderBy('str_sort.asc, 'str.asc), + Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil) + + checkAnswer( + df2.join(df3, $"df2.int" === $"df3.int", "inner") + .select($"df2.int", $"df3.int").orderBy($"df2.str".desc), + Row(5, 5) :: Row(1, 1) :: Nil) + } + test("join - join using multiple columns and specifying join type") { - val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") - val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "inner"), + Row(1, "1", 2, 3) :: Nil) checkAnswer( df.join(df2, Seq("int", "str"), "left"), - Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Nil) checkAnswer( df.join(df2, Seq("int", "str"), "right"), - Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + Row(1, "1", 2, 3) :: Row(5, "5", null, 6) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "outer"), + Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Row(5, "5", null, 6) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "left_semi"), + Row(1, "1", 2) :: Nil) } test("join - join using self join") { @@ -111,14 +141,67 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") // equijoin - should be converted into broadcast join - val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan + val plan1 = df1.join(broadcast(df2), "key").queryExecution.sparkPlan assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1) // no join key -- should not be a broadcast join - val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan + val plan2 = df1.join(broadcast(df2)).queryExecution.sparkPlan assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0) // planner should not crash without a join - broadcast(df1).queryExecution.executedPlan + broadcast(df1).queryExecution.sparkPlan + + // SPARK-12275: no physical plan for BroadcastHint in some condition + withTempPath { path => + df1.write.parquet(path.getCanonicalPath) + val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + assert(df1.join(broadcast(pf1)).count() === 4) + } + } + + test("join - outer join conversion") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + + // outer -> left + val outerJoin2Left = df.join(df2, $"a.int" === $"b.int", "outer").where($"a.int" === 3) + assert(outerJoin2Left.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, LeftOuter, _) => j }.size === 1) + checkAnswer( + outerJoin2Left, + Row(3, 4, "3", null, null, null) :: Nil) + + // outer -> right + val outerJoin2Right = df.join(df2, $"a.int" === $"b.int", "outer").where($"b.int" === 5) + assert(outerJoin2Right.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, RightOuter, _) => j }.size === 1) + checkAnswer( + outerJoin2Right, + Row(null, null, null, 5, 6, "5") :: Nil) + + // outer -> inner + val outerJoin2Inner = df.join(df2, $"a.int" === $"b.int", "outer"). + where($"a.int" === 1 && $"b.int2" === 3) + assert(outerJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + outerJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) + + // right -> inner + val rightJoin2Inner = df.join(df2, $"a.int" === $"b.int", "right").where($"a.int" === 1) + assert(rightJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + rightJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) + + // left -> inner + val leftJoin2Inner = df.join(df2, $"a.int" === $"b.int", "left").where($"b.int2" === 3) + assert(leftJoin2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, _, Inner, _) => j }.size === 1) + checkAnswer( + leftJoin2Inner, + Row(1, 2, "1", 1, 3, "1") :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index e34875471f093..18e04c24a4b9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -141,26 +141,36 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } test("fill with map") { - val df = Seq[(String, String, java.lang.Long, java.lang.Double, java.lang.Boolean)]( - (null, null, null, null, null)).toDF("a", "b", "c", "d", "e") - checkAnswer( - df.na.fill(Map( - "a" -> "test", - "c" -> 1, - "d" -> 2.2, - "e" -> false - )), - Row("test", null, 1, 2.2, false)) - - // Test Java version - checkAnswer( - df.na.fill(Map( - "a" -> "test", - "c" -> 1, - "d" -> 2.2, - "e" -> false - ).asJava), - Row("test", null, 1, 2.2, false)) + val df = Seq[(String, String, java.lang.Integer, java.lang.Long, + java.lang.Float, java.lang.Double, java.lang.Boolean)]( + (null, null, null, null, null, null, null)) + .toDF("stringFieldA", "stringFieldB", "integerField", "longField", + "floatField", "doubleField", "booleanField") + + val fillMap = Map( + "stringFieldA" -> "test", + "integerField" -> 1, + "longField" -> 2L, + "floatField" -> 3.3f, + "doubleField" -> 4.4d, + "booleanField" -> false) + + val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false) + + checkAnswer(df.na.fill(fillMap), expectedRow) + checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version + + // Ensure replacement values are cast to the column data type. + checkAnswer(df.na.fill(Map( + "integerField" -> 1d, + "longField" -> 2d, + "floatField" -> 3d, + "doubleField" -> 4d)), + Row(null, null, 1, 2L, 3f, 4d, null)) + + // Ensure column types do not change. Columns that have null values replaced + // will no longer be flagged as nullable, so do not compare schemas directly. + assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType)) } test("replace") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala new file mode 100644 index 0000000000000..368aa5cd141f0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class DataFramePivotSuite extends QueryTest with SharedSQLContext{ + import testImplicits._ + + test("pivot courses with literals") { + checkAnswer( + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) + .agg(sum($"earnings")), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } + + test("pivot year with literals") { + checkAnswer( + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with literals and multiple aggregations") { + checkAnswer( + courseSales.groupBy($"year") + .pivot("course", Seq("dotNET", "Java")) + .agg(sum($"earnings"), avg($"earnings")), + Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: + Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil + ) + } + + test("pivot year with string values (cast)") { + checkAnswer( + courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot year with int values") { + checkAnswer( + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with no values") { + // Note Java comes before dotNet in sorted order + checkAnswer( + courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), + Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil + ) + } + + test("pivot year with no values") { + checkAnswer( + courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot max values enforced") { + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + intercept[AnalysisException]( + courseSales.groupBy("year").pivot("course") + ) + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) + } + + test("pivot with UnresolvedFunction") { + checkAnswer( + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) + .agg("earnings" -> "sum"), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 6524abcf5e97f..0ea7727e45029 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,8 +19,13 @@ package org.apache.spark.sql import java.util.Random +import org.scalatest.Matchers._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.functions.col import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DoubleType class DataFrameStatSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -41,7 +46,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), - Seq(16, 23, 88, 100).map(Row(_)) + Seq(3, 17, 27, 58, 62).map(Row(_)) ) } @@ -52,7 +57,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + assert(splits.reduce((a, b) => a.union(b)).sort("id").collect().toList == data.collect().toList, "incomplete or wrong split") val s = splits.map(_.count()) @@ -62,6 +67,28 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } } + test("randomSplit on reordered partitions") { + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val data = + sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + + // Verify that the splits don't overlap + assert(splits(0).intersect(splits(1)).collect().isEmpty) + + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") @@ -98,6 +125,33 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(decimalRes) < 1e-12) } + test("approximate quantile") { + val n = 1000 + val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles") + + val q1 = 0.5 + val q2 = 0.8 + val epsilons = List(0.1, 0.05, 0.001) + + for (epsilon <- epsilons) { + val Array(single1) = df.stat.approxQuantile("singles", Array(q1), epsilon) + val Array(double2) = df.stat.approxQuantile("doubles", Array(q2), epsilon) + // Also make sure there is no regression by computing multiple quantiles at once. + val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon) + val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon) + + val error_single = 2 * 1000 * epsilon + val error_double = 2 * 2000 * epsilon + + assert(math.abs(single1 - q1 * n) < error_single) + assert(math.abs(double2 - 2 * q2 * n) < error_double) + assert(math.abs(s1 - q1 * n) < error_single) + assert(math.abs(s2 - q2 * n) < error_single) + assert(math.abs(d1 - 2 * q1 * n) < error_double) + assert(math.abs(d2 - 2 * q2 * n) < error_double) + } + } + test("crosstab") { val rng = new Random() val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) @@ -186,6 +240,98 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), - Seq(Row(0, 5), Row(1, 8))) + Seq(Row(0, 6), Row(1, 11))) + } + + // This test case only verifies that `DataFrame.countMinSketch()` methods do return + // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in + // `CountMinSketchSuite` in project spark-sketch. + test("countMinSketch") { + val df = sqlContext.range(1000) + + val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42) + assert(sketch1.totalCount() === 1000) + assert(sketch1.depth() === 10) + assert(sketch1.width() === 20) + + val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42) + assert(sketch2.totalCount() === 1000) + assert(sketch2.depth() === 10) + assert(sketch2.width() === 20) + + val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42) + assert(sketch3.totalCount() === 1000) + assert(sketch3.relativeError() === 0.001) + assert(sketch3.confidence() === 0.99 +- 5e-3) + + val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42) + assert(sketch4.totalCount() === 1000) + assert(sketch4.relativeError() === 0.001 +- 1e04) + assert(sketch4.confidence() === 0.99 +- 5e-3) + + intercept[IllegalArgumentException] { + df.select('id cast DoubleType as 'id) + .stat + .countMinSketch('id, depth = 10, width = 20, seed = 42) + } + } + + // This test only verifies some basic requirements, more correctness tests can be found in + // `BloomFilterSuite` in project spark-sketch. + test("Bloom filter") { + val df = sqlContext.range(1000) + + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(filter1.mightContain)) + + val filter2 = df.stat.bloomFilter($"id" * 3, 1000, 0.03) + assert(filter2.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(i => filter2.mightContain(i * 3))) + + val filter3 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter3.bitSize() == 64 * 5) + assert(0.until(1000).forall(filter3.mightContain)) + + val filter4 = df.stat.bloomFilter($"id" * 3, 1000, 64 * 5) + assert(filter4.bitSize() == 64 * 5) + assert(0.until(1000).forall(i => filter4.mightContain(i * 3))) + } +} + + +class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Logging { + + // Turn on this test if you want to test the performance of approximate quantiles. + ignore("computing quantiles should not take much longer than describe()") { + val df = sqlContext.range(5000000L).toDF("col1").cache() + def seconds(f: => Any): Double = { + // Do some warmup + logDebug("warmup...") + for (i <- 1 to 10) { + df.count() + f + } + logDebug("execute...") + // Do it 10 times and report median + val times = (1 to 10).map { i => + val start = System.nanoTime() + f + val end = System.nanoTime() + (end - start) / 1e9 + } + logDebug("execute done") + times.sum / times.length.toDouble + } + + logDebug("*** Normal describe ***") + val t1 = seconds { df.describe() } + logDebug(s"T1 = $t1") + logDebug("*** Just quantiles ***") + val t2 = seconds { + StatFunctions.multipleApproxQuantiles(df, Seq("col1"), Seq(0.1, 0.25, 0.5, 0.75, 0.9), 0.01) + } + logDebug(s"T1 = $t1, T2 = $t2") } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 84a616d0b9081..e953a6e8ef0c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.io.File +import java.nio.charset.StandardCharsets import scala.language.postfixOps import scala.util.Random @@ -25,35 +26,29 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation -import org.apache.spark.sql.execution.Exchange +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestData.TestData2 +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} +import org.apache.spark.sql.test.SQLTestData.TestData2 import org.apache.spark.sql.types._ class DataFrameSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("analysis error should be eagerly reported") { - // Eager analysis. - withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { - intercept[Exception] { testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) - } - intercept[Exception] { - testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) - } - intercept[Exception] { - testData.groupBy($"abcd").agg(Map("key" -> "sum")) - } + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) } - - // No more eager analysis once the flag is turned off - withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") { - testData.select('nonExistentName) + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) } } @@ -71,7 +66,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Nil) } - test("invalid plan toString, debug mode") { + ignore("invalid plan toString, debug mode") { // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ @@ -98,6 +93,36 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.collect().toSeq) } + test("union all") { + val unionDF = testData.union(testData).union(testData) + .union(testData).union(testData) + + // Before optimizer, Union should be combined. + assert(unionDF.queryExecution.analyzed.collect { + case j: Union if j.children.size == 5 => j }.size === 1) + + checkAnswer( + unionDF.agg(avg('key), max('key), min('key), sum('key)), + Row(50.5, 100, 1, 25250) :: Nil + ) + } + + test("union should union DataFrames with UDTs (SPARK-13410)") { + val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))) + val schema1 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) + val schema2 = StructType(Array(StructField("label", IntegerType, false), + StructField("point", new ExamplePointUDT(), false))) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + + checkAnswer( + df1.union(df2).orderBy("label"), + Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))) + ) + } + test("empty data frame") { assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(sqlContext.emptyDataFrame.count() === 0) @@ -141,22 +166,62 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("SPARK-8930: explode should fail with a meaningful message if it takes a star") { + test("Star Expansion - CreateStruct and CreateArray") { + val structDf = testData2.select("a", "b").as("record") + // CreateStruct and CreateArray in aggregateExpressions + assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) + assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1))) + + // CreateStruct and CreateArray in project list (unresolved alias) + assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) + assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1)) + + // CreateStruct and CreateArray in project list (alias) + assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1))) + assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1)) + } + + test("Star Expansion - hash") { + val structDf = testData2.select("a", "b").as("record") + checkAnswer( + structDf.groupBy($"a", $"b").agg(min(hash($"a", $"*"))), + structDf.groupBy($"a", $"b").agg(min(hash($"a", $"a", $"b")))) + + checkAnswer( + structDf.groupBy($"a", $"b").agg(hash($"a", $"*")), + structDf.groupBy($"a", $"b").agg(hash($"a", $"a", $"b"))) + + checkAnswer( + structDf.select(hash($"*")), + structDf.select(hash($"record.*"))) + + checkAnswer( + structDf.select(hash($"a", $"*")), + structDf.select(hash($"a", $"record.*"))) + } + + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") val e = intercept[AnalysisException] { df.explode($"*") { case Row(prefix: String, csv: String) => csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq }.queryExecution.assertAnalyzed() } - assert(e.getMessage.contains( - "Cannot explode *, explode can only be applied on a specific column.")) + assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) - df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }.queryExecution.assertAnalyzed() + checkAnswer( + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }, + Row("1", "1,2", "1:1") :: + Row("1", "1,2", "1:2") :: + Row("2", "4", "2:4") :: + Row("3", "7,8,9", "3:7") :: + Row("3", "7,8,9", "3:8") :: + Row("3", "7,8,9", "3:9") :: Nil) } - test("explode alias and star") { + test("Star Expansion - explode alias and star") { val df = Seq((Array("a"), 1)).toDF("a", "b") checkAnswer( @@ -164,6 +229,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("a", Seq("a"), 1) :: Nil) } + test("sort after generate with join=true") { + val df = Seq((Array("a"), 1)).toDF("a", "b") + + checkAnswer( + df.select($"*", explode($"a").as("c")).sortWithinPartitions("b", "c"), + Row(Seq("a"), 1, "a") :: Nil) + } + test("selectExpr") { checkAnswer( testData.selectExpr("abs(key)", "value"), @@ -176,6 +249,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select("key").collect().toSeq) } + test("selectExpr with udtf") { + val df = Seq((Map("1" -> 1), 1)).toDF("a", "b") + checkAnswer( + df.selectExpr("explode(a)"), + Row("1", 1) :: Nil) + } + test("filterExpr") { val res = testData.collect().filter(_.getInt(0) > 90).toSeq checkAnswer(testData.filter("key > 90"), res) @@ -301,6 +381,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( mapData.toDF().limit(1), mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) + + // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake + checkAnswer( + sqlContext.range(2).toDF().limit(2147483638), + Row(0) :: Row(1) :: Nil + ) } test("except") { @@ -322,6 +408,48 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + } + + test("intersect - nullability") { + val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(_.nullable == false)) + + val df1 = nonNullableInts.intersect(nullInts) + checkAnswer(df1, Row(1) :: Row(3) :: Nil) + assert(df1.schema.forall(_.nullable == false)) + + val df2 = nullInts.intersect(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(3) :: Nil) + assert(df2.schema.forall(_.nullable == false)) + + val df3 = nullInts.intersect(nullInts) + checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + assert(df3.schema.forall(_.nullable == true)) + + val df4 = nonNullableInts.intersect(nonNullableInts) + checkAnswer(df4, Row(1) :: Row(3) :: Nil) + assert(df4.schema.forall(_.nullable == false)) } test("udf") { @@ -334,15 +462,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("deprecated callUdf in SQLContext") { - val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") - val sqlctx = df.sqlContext - sqlctx.udf.register("simpleUdf", (v: Int) => v * v) - checkAnswer( - df.select($"id", callUdf("simpleUdf", $"value")), - Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) - } - test("callUDF in SQLContext") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") val sqlctx = df.sqlContext @@ -378,6 +497,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("value")) } + test("drop columns using drop") { + val src = Seq((0, 2, 3)).toDF("a", "b", "c") + val df = src.drop("a", "b") + checkAnswer(df, Row(3)) + assert(df.schema.map(_.name) === Seq("c")) + } + test("drop unknown column (no-op)") { val df = testData.drop("random") checkAnswer( @@ -510,8 +636,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val jsonDF = sqlContext.read.json(jsonDir) assert(parquetDF.inputFiles.nonEmpty) - val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted - val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).toSet.toArray.sorted + val unioned = jsonDF.union(parquetDF).inputFiles.sorted + val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).distinct.sorted assert(unioned === allFiles) } } @@ -526,7 +652,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ - ||_1 | + ||value | |+---------------------+ ||1 | ||111111111111111111111| @@ -534,7 +660,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { |""".stripMargin assert(df.showString(10, false) === expectedAnswerForFalse) val expectedAnswerForTrue = """+--------------------+ - || _1| + || value| |+--------------------+ || 1| ||11111111111111111...| @@ -578,6 +704,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: binary") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = """+-------+----------------+ + || _1| _2| + |+-------+----------------+ + ||[31 32]| [41 42 43 2E]| + ||[33 34]|[31 32 33 34 36]| + |+-------+----------------+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + test("showString: minimum column width") { val df = Seq( (1, 1), @@ -621,11 +762,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-6899: type should match when using codegen") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - } + checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -744,6 +881,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) + + // using the default slice number + val res12 = sqlContext.range(3, 15, 3).select("id") + assert(res12.count == 4) + assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) } test("SPARK-8621: support empty string column name") { @@ -781,7 +923,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .write.format("parquet").save("temp") } assert(e.getMessage.contains("Duplicate column(s)")) - assert(e.getMessage.contains("parquet")) assert(e.getMessage.contains("column1")) assert(!e.getMessage.contains("column2")) @@ -792,7 +933,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .write.format("json").save("temp") } assert(f.getMessage.contains("Duplicate column(s)")) - assert(f.getMessage.contains("JSON")) assert(f.getMessage.contains("column1")) assert(f.getMessage.contains("column3")) assert(!f.getMessage.contains("column2")) @@ -835,7 +975,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") + Dataset.ofRows(sqlContext, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } @@ -844,31 +984,16 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { - // Make sure we can pass this test for both codegen mode and interpreted mode. - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - val df = testData.select(rand(33)) - assert(df.showString(5) == df.showString(5)) - } - - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - val df = testData.select(rand(33)) - assert(df.showString(5) == df.showString(5)) - } + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) // We will reuse the same Expression object for LocalRelation. - val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) - assert(df.showString(5) == df.showString(5)) + val df1 = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df1.showString(5) == df1.showString(5)) } test("SPARK-8609: local DataFrame with random columns should return same value after sort") { - // Make sure we can pass this test for both codegen mode and interpreted mode. - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) - } - - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) - } + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) // We will reuse the same Expression object for LocalRelation. val df = (1 to 10).map(Tuple1.apply).toDF() @@ -886,6 +1011,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(expected === actual) } + test("Sorting columns are not in Filter and Project") { + checkAnswer( + upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc), + Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil) + } + test("SPARK-9323: DataFrame.orderBy should support nested column name") { val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) @@ -916,7 +1047,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)), + checkAnswer(sqlContext.read.format("json").load(dir1, dir2), Row(1, 22) :: Row(2, 23) :: Nil) checkAnswer(sqlContext.read.format("json").load(dir1), @@ -924,12 +1055,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + test("Alias uses internally generated names 'aggOrder' and 'havingCondition'") { val df = Seq(1 -> 2).toDF("i", "j") - val query = df.groupBy('i) - .agg(max('j).as("aggOrdering")) + val query1 = df.groupBy('i) + .agg(max('j).as("aggOrder")) .orderBy(sum('j)) - checkAnswer(query, Row(1, 2)) + checkAnswer(query1, Row(1, 2)) + + // In the plan, there are two attributes having the same name 'havingCondition' + // One is a user-provided alias name; another is an internally generated one. + val query2 = df.groupBy('i) + .agg(max('j).as("havingCondition")) + .where(sum('j) > 0) + .orderBy('havingCondition.asc) + checkAnswer(query2, Row(1, 2)) } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { @@ -965,7 +1104,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - val union = df1.unionAll(df2) + val union = df1.union(df2) checkAnswer( union.filter('i < rand(7) * 10), expected(union) @@ -995,6 +1134,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-10743: keep the name of expression if possible when do cast") { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") + assert(df.select($"src.i".cast(StringType).cast(IntegerType)).columns.head === "i") } test("SPARK-11301: fix case sensitivity for filter on partitioned columns") { @@ -1013,14 +1153,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyNonExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => { + case agg: TungstenAggregate => atFirstAgg = !atFirstAgg - } - case _ => { + case _ => if (atFirstAgg) { fail("Should not have operators between the two aggregations") } - } } } @@ -1030,13 +1168,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => { + case agg: TungstenAggregate => if (atFirstAgg) { fail("Should not have back to back Aggregates") } atFirstAgg = true - } - case e: Exchange => atFirstAgg = false + case e: ShuffleExchange => atFirstAgg = false case _ => } } @@ -1071,17 +1208,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // Walk each partition and verify that it is sorted descending and does not contain all // the values. df4.rdd.foreachPartition { p => - var previousValue: Int = -1 - var allSequential: Boolean = true - p.foreach { r => - val v: Int = r.getInt(1) - if (previousValue != -1) { - if (previousValue < v) throw new SparkException("Partition is not ordered.") - if (v + 1 != previousValue) allSequential = false + // Skip empty partition + if (p.hasNext) { + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach { r => + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue < v) throw new SparkException("Partition is not ordered.") + if (v + 1 != previousValue) allSequential = false + } + previousValue = v } - previousValue = v + if (allSequential) throw new SparkException("Partition should not be globally ordered") } - if (allSequential) throw new SparkException("Partition should not be globally ordered") } // Distribute and order by with multiple order bys @@ -1102,8 +1242,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } // Distribute into one partition and order by. This partition should contain all the values. - val df6 = data.repartition(1, $"a").sortWithinPartitions($"b".asc) - // Walk each partition and verify that it is sorted descending and not globally sorted. + val df6 = data.repartition(1, $"a").sortWithinPartitions("b") + // Walk each partition and verify that it is sorted ascending and not globally sorted. df6.rdd.foreachPartition { p => var previousValue: Int = -1 var allSequential: Boolean = true @@ -1128,4 +1268,165 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } } + + // This test case is to verify a bug when making a new instance of LogicalRDD. + test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) + val df = sqlContext.createDataFrame( + rdd, + new StructType().add("f1", IntegerType).add("f2", IntegerType), + needsConversion = false).select($"F1", $"f2".as("f2")) + val df1 = df.as("a") + val df2 = df.as("b") + checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) + } + } + + test("SPARK-10656: completely support special chars") { + val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") + checkAnswer(df.select(df("*")), Row(1, "a")) + checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a")) + } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val df = sparkContext.parallelize(Seq( + new java.lang.Integer(22) -> "John", + null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") + + // passing null into the UDF that could handle it + val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { + (i: java.lang.Integer) => if (i == null) -10 else null + } + checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) + + sqlContext.udf.register("boxedUDF", + (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) + checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) + + val primitiveUDF = udf((i: Int) => i * 2) + checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) + } + + test("SPARK-12398 truncated toString") { + val df1 = Seq((1L, "row1")).toDF("id", "name") + assert(df1.toString() === "[id: bigint, name: string]") + + val df2 = Seq((1L, "c2", false)).toDF("c1", "c2", "c3") + assert(df2.toString === "[c1: bigint, c2: string ... 1 more field]") + + val df3 = Seq((1L, "c2", false, 10)).toDF("c1", "c2", "c3", "c4") + assert(df3.toString === "[c1: bigint, c2: string ... 2 more fields]") + + val df4 = Seq((1L, Tuple2(1L, "val"))).toDF("c1", "c2") + assert(df4.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string>]") + + val df5 = Seq((1L, Tuple2(1L, "val"), 20.0)).toDF("c1", "c2", "c3") + assert(df5.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 1 more field]") + + val df6 = Seq((1L, Tuple2(1L, "val"), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert(df6.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 2 more fields]") + + val df7 = Seq((1L, Tuple3(1L, "val", 2), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df7.toString === + "[c1: bigint, c2: struct<_1: bigint, _2: string ... 1 more field> ... 2 more fields]") + + val df8 = Seq((1L, Tuple7(1L, "val", 2, 3, 4, 5, 6), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df8.toString === + "[c1: bigint, c2: struct<_1: bigint, _2: string ... 5 more fields> ... 2 more fields]") + + val df9 = + Seq((1L, Tuple4(1L, Tuple4(1L, 2L, 3L, 4L), 2L, 3L), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df9.toString === + "[c1: bigint, c2: struct<_1: bigint," + + " _2: struct<_1: bigint," + + " _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more fields]") + + } + + test("reuse exchange") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { + val df = sqlContext.range(100).toDF() + val join = df.join(df, "id") + val plan = join.queryExecution.executedPlan + checkAnswer(join, df) + assert( + join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + assert(join.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 1) + val broadcasted = broadcast(join) + val join2 = join.join(broadcasted, "id").join(broadcasted, "id") + checkAnswer(join2, df) + assert( + join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + assert( + join2.queryExecution.executedPlan.collect { case e: BroadcastExchange => true }.size === 1) + assert( + join2.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 4) + } + } + + test("sameResult() on aggregate") { + val df = sqlContext.range(100) + val agg1 = df.groupBy().count() + val agg2 = df.groupBy().count() + // two aggregates with different ExprId within them should have same result + assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan)) + val agg3 = df.groupBy().sum() + assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan)) + val df2 = sqlContext.range(101) + val agg4 = df2.groupBy().count() + assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan)) + } + + test("SPARK-12512: support `.` in column name for withColumn()") { + val df = Seq("a" -> "b").toDF("col.a", "col.b") + checkAnswer(df.select(df("*")), Row("a", "b")) + checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b")) + checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c")) + } + + test("SPARK-12841: cast in filter") { + checkAnswer( + Seq(1 -> "a").toDF("i", "j").filter($"i".cast(StringType) === "1"), + Row(1, "a")) + } + + test("SPARK-12982: Add table name validation in temp table registration") { + val df = Seq("foo", "bar").map(Tuple1.apply).toDF("col") + // invalid table name test as below + intercept[AnalysisException](df.registerTempTable("t~")) + // valid table name test as below + df.registerTempTable("table1") + // another invalid table name test as below + intercept[AnalysisException](df.registerTempTable("#$@sum")) + // another invalid table name test as below + intercept[AnalysisException](df.registerTempTable("table!#")) + } + + test("assertAnalyzed shouldn't replace original stack trace") { + val e = intercept[AnalysisException] { + sqlContext.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) + } + + assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName) + } + + test("SPARK-13774: Check error message for non existent path without globbed paths") { + val e = intercept[AnalysisException] (sqlContext.read.format("csv"). + load("/xyz/file2", "/xyz/file21", "/abc/files555", "a")).getMessage() + assert(e.startsWith("Path does not exist")) + } + + test("SPARK-13774: Check error message for not existent globbed paths") { + val e = intercept[AnalysisException] (sqlContext.read.format("text"). + load( "/xyz/*")).getMessage() + assert(e.startsWith("Path does not exist")) + + val e1 = intercept[AnalysisException] (sqlContext.read.json("/mnt/*/*-xyz.json").rdd). + getMessage() + assert(e1.startsWith("Path does not exist")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala new file mode 100644 index 0000000000000..06584ec21e2f8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StringType + +class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { + + import testImplicits._ + + override def beforeEach(): Unit = { + super.beforeEach() + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + } + + override def afterEach(): Unit = { + super.beforeEach() + TimeZone.setDefault(null) + } + + test("tumbling window groupBy statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1)) + ) + } + + test("tumbling window groupBy statement with startTime") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"id") + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1))) + } + + test("tumbling window with multi-column projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + + test("sliding window grouping") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "3 seconds", "0 second")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + // 2016-03-27 19:39:27 UTC -> 4 bins + // 2016-03-27 19:39:34 UTC -> 3 bins + // 2016-03-27 19:39:56 UTC -> 3 bins + Seq( + Row("2016-03-27 19:39:18", "2016-03-27 19:39:28", 1), + Row("2016-03-27 19:39:21", "2016-03-27 19:39:31", 1), + Row("2016-03-27 19:39:24", "2016-03-27 19:39:34", 1), + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 2), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:33", "2016-03-27 19:39:43", 1), + Row("2016-03-27 19:39:48", "2016-03-27 19:39:58", 1), + Row("2016-03-27 19:39:51", "2016-03-27 19:40:01", 1), + Row("2016-03-27 19:39:54", "2016-03-27 19:40:04", 1)) + ) + } + + test("sliding window projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value"), + // 2016-03-27 19:39:27 UTC -> 4 bins + // 2016-03-27 19:39:34 UTC -> 3 bins + // 2016-03-27 19:39:56 UTC -> 3 bins + Seq(Row(4), Row(4), Row(4), Row(4), Row(1), Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("windowing combined with explode expression") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"window.start".asc).select("value"), + // first window exploded to two rows for "a", and "b", second window exploded to 3 rows + Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("null timestamps") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + checkDataset( + df.select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select("value") + .as[Int], + 1, 2) // null columns are dropped + } + + test("time window joins") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + val df2 = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "othervalue") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value").join( + df2.select(window($"time", "10 seconds"), $"othervalue"), Seq("window")) + .groupBy("window") + .agg((sum("value") + sum("othervalue")).as("total")) + .orderBy($"window.start".asc).select("total"), + Seq(Row(4), Row(8))) + } + + test("negative timestamps") { + val df4 = Seq( + ("1970-01-01 00:00:02", 1), + ("1970-01-01 00:00:12", 2)).toDF("time", "value") + checkAnswer( + df4.select(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("1969-12-31 23:59:55", "1970-01-01 00:00:05", 1), + Row("1970-01-01 00:00:05", "1970-01-01 00:00:15", 2)) + ) + } + + test("multiple time windows in a single operator throws nice exception") { + val df = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "value") + val e = intercept[AnalysisException] { + df.select(window($"time", "10 second"), window($"time", "15 second")).collect() + } + assert(e.getMessage.contains( + "Multiple time window expressions would result in a cartesian product")) + } + + test("aliased windows") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(window($"time", "10 seconds").as("time_window"), $"value") + .orderBy($"time_window.start".asc) + .select("value"), + Seq(Row(1), Row(2)) + ) + } + + test("millisecond precision sliding windows") { + val df = Seq( + ("2016-03-27 09:00:00.41", 3), + ("2016-03-27 09:00:00.62", 6), + ("2016-03-27 09:00:00.715", 8)).toDF("time", "value") + checkAnswer( + df.groupBy(window($"time", "200 milliseconds", "40 milliseconds", "0 milliseconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"counts"), + Seq( + Row("2016-03-27 09:00:00.24", "2016-03-27 09:00:00.44", 1), + Row("2016-03-27 09:00:00.28", "2016-03-27 09:00:00.48", 1), + Row("2016-03-27 09:00:00.32", "2016-03-27 09:00:00.52", 1), + Row("2016-03-27 09:00:00.36", "2016-03-27 09:00:00.56", 1), + Row("2016-03-27 09:00:00.4", "2016-03-27 09:00:00.6", 1), + Row("2016-03-27 09:00:00.44", "2016-03-27 09:00:00.64", 1), + Row("2016-03-27 09:00:00.48", "2016-03-27 09:00:00.68", 1), + Row("2016-03-27 09:00:00.52", "2016-03-27 09:00:00.72", 2), + Row("2016-03-27 09:00:00.56", "2016-03-27 09:00:00.76", 2), + Row("2016-03-27 09:00:00.6", "2016-03-27 09:00:00.8", 2), + Row("2016-03-27 09:00:00.64", "2016-03-27 09:00:00.84", 1), + Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1)) + ) + } + + private def withTempTable(f: String => Unit): Unit = { + val tableName = "temp" + Seq( + ("2016-03-27 19:39:34", 1), + ("2016-03-27 19:39:56", 2), + ("2016-03-27 19:39:27", 4)).toDF("time", "value").registerTempTable(tableName) + try { + f(tableName) + } finally { + sqlContext.dropTempTable(tableName) + } + } + + test("time window in SQL with single string expression") { + withTempTable { table => + checkAnswer( + sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + } + + test("time window in SQL with with two expressions") { + withTempTable { table => + checkAnswer( + sqlContext.sql( + s"""select window(time, "10 seconds", 10000000), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + } + + test("time window in SQL with with three expressions") { + withTempTable { table => + checkAnswer( + sqlContext.sql( + s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1), + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4), + Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2) + ) + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 7ae12a7895f7e..68e99d6a6b816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -31,52 +31,46 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("test simple types") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") - assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) - } + val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) } test("test struct type") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val struct = Row(1, 2L, 3.0F, 3.0) - val data = sparkContext.parallelize(Seq(Row(1, struct))) + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sparkContext.parallelize(Seq(Row(1, struct))) - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType)) + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) - val df = sqlContext.createDataFrame(data, schema) - assert(df.select("b").first() === Row(struct)) - } + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) } test("test nested struct type") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val innerStruct = Row(1, "abcd") - val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") - val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType) - .add("b5", new StructType() - .add("b5a", IntegerType) - .add("b5b", StringType)) - .add("b6", StringType)) + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) - val df = sqlContext.createDataFrame(data, schema) - assert(df.select("b").first() === Row(outerStruct)) - } + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala new file mode 100644 index 0000000000000..2bcbb1983f7ac --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DataType, LongType, StructType} + +class DataFrameWindowSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("reuse window partitionBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("reuse window orderBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.orderBy("value").partitionBy("key") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("lead") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) + } + + test("lag") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) + } + + test("lead with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), + Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) + } + + test("lag with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), + Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) + } + + test("rank functions in unspecific window") { + val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + max("key").over(Window.partitionBy("value").orderBy("key")), + min("key").over(Window.partitionBy("value").orderBy("key")), + mean("key").over(Window.partitionBy("value").orderBy("key")), + count("key").over(Window.partitionBy("value").orderBy("key")), + sum("key").over(Window.partitionBy("value").orderBy("key")), + ntile(2).over(Window.partitionBy("value").orderBy("key")), + row_number().over(Window.partitionBy("value").orderBy("key")), + dense_rank().over(Window.partitionBy("value").orderBy("key")), + rank().over(Window.partitionBy("value").orderBy("key")), + cume_dist().over(Window.partitionBy("value").orderBy("key")), + percent_rank().over(Window.partitionBy("value").orderBy("key"))), + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) + } + + test("aggregation and rows between") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) + } + + test("aggregation and range between") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), + Row(2.0d), Row(2.0d))) + } + + test("aggregation and rows between with unbounded") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), + last("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), + Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), + Row(4, 4, 4, 4))) + } + + test("aggregation and range between with unbounded") { + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) + .equalTo("2") + .as("last_v"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) + .as("avg_key1"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) + .as("avg_key2"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) + .as("avg_key3") + ), + Seq(Row(3, null, 3.0d, 4.0d, 3.0d), + Row(5, false, 4.0d, 5.0d, 5.0d), + Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), + Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), + Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), + Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) + } + + test("reverse sliding range frame") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window. + partitionBy($"category"). + orderBy($"revenue".desc). + rangeBetween(-2000L, 1000L) + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse unbounded range frame") { + val df = Seq(1, 2, 4, 3, 2, 1). + map(Tuple1.apply). + toDF("value") + val window = Window.orderBy($"value".desc) + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Long.MinValue, 1)), + sum($"value").over(window.rangeBetween(1, Long.MaxValue))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: + Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("statistical functions") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.partitionBy($"key") + checkAnswer( + df.select( + $"key", + var_pop($"value").over(window), + var_samp($"value").over(window), + approxCountDistinct($"value").over(window)), + Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) + ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) + } + + test("window function with aggregates") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.groupBy($"key") + .agg( + sum($"value"), + sum(sum($"value")).over(window) - sum($"value")), + Seq(Row("a", 6, 9), Row("b", 9, 6))) + } + + test("window function with udaf") { + val udaf = new UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) + } + val df = Seq( + ("a", 1, 1), + ("a", 1, 5), + ("a", 2, 10), + ("a", 2, -1), + ("b", 4, 7), + ("b", 3, 8), + ("b", 2, 4)) + .toDF("key", "a", "b") + val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L) + checkAnswer( + df.select( + $"key", + $"a", + $"b", + udaf($"a", $"b").over(window)), + Seq( + Row("a", 1, 1, 6), + Row("a", 1, 5, 6), + Row("a", 2, 10, 24), + Row("a", 2, -1, 24), + Row("b", 4, 7, 60), + Row("b", 3, 8, 32), + Row("b", 2, 4, 8))) + } + + test("null inputs") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)) + .toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.select( + $"key", + $"value", + avg(lit(null)).over(window), + sum(lit(null)).over(window)), + Seq( + Row("a", 1, null, null), + Row("a", 1, null, null), + Row("a", 2, null, null), + Row("a", 2, null, null), + Row("b", 4, null, null), + Row("b", 3, null, null), + Row("b", 2, null, null))) + } + + test("last/first with ignoreNulls") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, nullStr), + ("b", 1, nullStr), + ("b", 2, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order") + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, null, null, null, null, null, null), + Row("a", 1, null, null, "x", "x", "x", "x"), + Row("a", 2, null, null, "x", "y", "y", "y"), + Row("a", 3, null, null, "x", "z", "z", "z"), + Row("a", 4, null, null, "x", null, null, "z"), + Row("b", 1, null, null, null, null, null, null), + Row("b", 2, null, null, null, null, null, null))) + } + + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { + val src = Seq((0, 3, 5)).toDF("a", "b", "c") + .withColumn("Data", struct("a", "b")) + .drop("a") + .drop("b") + val winSpec = Window.partitionBy("Data.a", "Data.b").orderBy($"c".desc) + val df = src.select($"*", max("c").over(winSpec) as "max") + checkAnswer(df, Row(5, Row(0, 3), 5)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala new file mode 100644 index 0000000000000..3a7215ee39728 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { + override def zero: (Long, Long) = (0, 0) + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + override def finish(reduction: (Long, Long)): (Long, Long) = reduction + override def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)] + override def outputEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)] +} + + +case class AggData(a: Int, b: String) + +object ClassInputAgg extends Aggregator[AggData, Int, Int] { + override def zero: Int = 0 + override def reduce(b: Int, a: AggData): Int = b + a.a + override def finish(reduction: Int): Int = reduction + override def merge(b1: Int, b2: Int): Int = b1 + b2 + override def bufferEncoder: Encoder[Int] = Encoders.scalaInt + override def outputEncoder: Encoder[Int] = Encoders.scalaInt +} + + +object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { + override def zero: (Int, AggData) = 0 -> AggData(0, "0") + override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) + override def finish(reduction: (Int, AggData)): Int = reduction._1 + override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = + (b1._1 + b2._1, b1._2) + override def bufferEncoder: Encoder[(Int, AggData)] = Encoders.product[(Int, AggData)] + override def outputEncoder: Encoder[Int] = Encoders.scalaInt +} + + +object NameAgg extends Aggregator[AggData, String, String] { + def zero: String = "" + def reduce(b: String, a: AggData): String = a.b + b + def merge(b1: String, b2: String): String = b1 + b2 + def finish(r: String): String = r + override def bufferEncoder: Encoder[String] = Encoders.STRING + override def outputEncoder: Encoder[String] = Encoders.STRING +} + + +class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) + extends Aggregator[IN, OUT, OUT] { + + private val numeric = implicitly[Numeric[OUT]] + override def zero: OUT = numeric.zero + override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) + override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) + override def finish(reduction: OUT): OUT = reduction + override def bufferEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] + override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] +} + + +class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + test("typed aggregation: TypedAggregator") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDataset( + ds.groupByKey(_._1).agg(typed.sum(_._2)), + ("a", 30.0), ("b", 3.0), ("c", 1.0)) + } + + test("typed aggregation: TypedAggregator, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDataset( + ds.groupByKey(_._1).agg( + typed.sum(_._2), + expr("sum(_2)").as[Long], + count("*")), + ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L)) + } + + test("typed aggregation: complex result type") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkDataset( + ds.groupByKey(_._1).agg( + expr("avg(_2)").as[Double], + ComplexResultAgg.toColumn), + ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) + } + + test("typed aggregation: in project list") { + val ds = Seq(1, 3, 2, 5).toDS() + + checkDataset( + ds.select(typed.sum((i: Int) => i)), + 11.0) + checkDataset( + ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)), + 11.0 -> 22.0) + } + + test("typed aggregation: class input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkDataset( + ds.select(ClassInputAgg.toColumn), + 3) + } + + test("typed aggregation: class input with reordering") { + val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData] + + checkDataset( + ds.select(ClassInputAgg.toColumn), + 1) + + checkDataset( + ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn), + (1.0, 1)) + + checkDataset( + ds.groupByKey(_.b).agg(ClassInputAgg.toColumn), + ("one", 1)) + } + + test("typed aggregation: complex input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkDataset( + ds.select(ComplexBufferAgg.toColumn), + 2 + ) + + checkDataset( + ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), + (1.5, 2)) + + checkDataset( + ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn), + ("one", 1), ("two", 1)) + } + + test("typed aggregate: avg, count, sum") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + checkDataset( + ds.groupByKey(_._1).agg( + typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)), + ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)) + } + + test("generic typed sum") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + checkDataset( + ds.groupByKey(_._1) + .agg(new ParameterizedTypeSum[(String, Int), Double](_._2.toDouble).toColumn), + ("a", 4.0), ("b", 3.0)) + + checkDataset( + ds.groupByKey(_._1) + .agg(new ParameterizedTypeSum((x: (String, Int)) => x._2.toInt).toColumn), + ("a", 4), ("b", 3)) + } + + test("SPARK-12555 - result should not be corrupted after input columns are reordered") { + val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData] + + checkDataset( + ds.groupByKey(_.a).agg(NameAgg.toColumn), + (1279869254, "Some String")) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala new file mode 100644 index 0000000000000..ae9fb80c68f42 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkContext +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.Benchmark + +/** + * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions. + */ +object DatasetBenchmark { + + case class Data(l: Long, s: String) + + def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("back-to-back map", numRows) + val func = (d: Data) => Data(d.l + 1, d.s) + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = rdd.map(func) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.select($"l" + 1 as "l") + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + + def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("back-to-back filter", numRows) + val func = (d: Data, i: Int) => d.l % (100L + i) == 0L + val funcs = 0.until(numChains).map { i => + (d: Data) => func(d, i) + } + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = rdd.filter(funcs(i)) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.filter($"l" % (100L + i) === 0L) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] + var i = 0 + while (i < numChains) { + res = res.filter(funcs(i)) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + + object ComplexAggregator extends Aggregator[Data, Data, Long] { + override def zero: Data = Data(0, "") + + override def reduce(b: Data, a: Data): Data = Data(b.l + a.l, "") + + override def finish(reduction: Data): Long = reduction.l + + override def merge(b1: Data, b2: Data): Data = Data(b1.l + b2.l, "") + + override def bufferEncoder: Encoder[Data] = Encoders.product[Data] + + override def outputEncoder: Encoder[Long] = Encoders.scalaLong + } + + def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("aggregate", numRows) + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD sum") { iter => + rdd.aggregate(0L)(_ + _.l, _ + _) + } + + benchmark.addCase("DataFrame sum") { iter => + df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset sum using Aggregator") { iter => + df.as[Data].select(typed.sumLong((d: Data) => d.l)).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset complex Aggregator") { iter => + df.as[Data].select(ComplexAggregator.toColumn).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + + def main(args: Array[String]): Unit = { + val sparkContext = new SparkContext("local[*]", "Dataset benchmark") + val sqlContext = new SQLContext(sparkContext) + + val numRows = 100000000 + val numChains = 10 + + val benchmark = backToBackMap(sqlContext, numRows, numChains) + val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) + val benchmark3 = aggregate(sqlContext, numRows) + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + RDD 1935 / 2105 51.7 19.3 1.0X + DataFrame 756 / 799 132.3 7.6 2.6X + Dataset 7359 / 7506 13.6 73.6 0.3X + */ + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + RDD 1974 / 2036 50.6 19.7 1.0X + DataFrame 103 / 127 967.4 1.0 19.1X + Dataset 4343 / 4477 23.0 43.4 0.5X + */ + benchmark2.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + RDD sum 2130 / 2166 46.9 21.3 1.0X + DataFrame sum 92 / 128 1085.3 0.9 23.1X + Dataset sum using Aggregator 4111 / 4282 24.3 41.1 0.5X + Dataset complex Aggregator 8782 / 9036 11.4 87.8 0.2X + */ + benchmark3.run() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala new file mode 100644 index 0000000000000..942cc09b6d58e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class DatasetCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("persist and unpersist") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assertCached(cached) + // Check result. + checkDataset( + cached, + 2, 3, 4) + // Drop the cache. + cached.unpersist() + assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + } + + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + assertCached(ds1) + ds2.persist() + assertCached(ds2) + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkDataset(joined, ("2", 2)) + assertCached(joined, 2) + + ds1.unpersist() + assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + ds2.unpersist() + assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupByKey(_._1) + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkDataset( + agged.filter(_._1 == "b"), + ("b", 3)) + assertCached(agged.filter(_._1 == "b")) + + ds.unpersist() + assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + agged.unpersist() + assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 32443557fb8e0..ff022b2dc45ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -23,19 +23,23 @@ import org.apache.spark.sql.test.SharedSQLContext case class IntClass(value: Int) +package object packageobject { + case class PackageClass(value: Int) +} + class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("toDS") { val data = Seq(1, 2, 3, 4, 5, 6) - checkAnswer( + checkDataset( data.toDS(), data: _*) } test("as case class / collect") { val ds = Seq(1, 2, 3).toDS().as[IntClass] - checkAnswer( + checkDataset( ds, IntClass(1), IntClass(2), IntClass(3)) @@ -44,14 +48,14 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("map") { val ds = Seq(1, 2, 3).toDS() - checkAnswer( + checkDataset( ds.map(_ + 1), 2, 3, 4) } test("filter") { val ds = Seq(1, 2, 3, 4).toDS() - checkAnswer( + checkDataset( ds.filter(_ % 2 == 0), 2, 4) } @@ -59,7 +63,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("foreach") { val ds = Seq(1, 2, 3).toDS() val acc = sparkContext.accumulator(0) - ds.foreach(acc +=) + ds.foreach(acc += _) assert(acc.value == 6) } @@ -75,29 +79,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { assert(ds.reduce(_ + _) == 6) } - test("fold") { - val ds = Seq(1, 2, 3).toDS() - assert(ds.fold(0)(_ + _) == 6) - } - test("groupBy function, keys") { val ds = Seq(1, 2, 3, 4, 5).toDS() - val grouped = ds.groupBy(_ % 2) - checkAnswer( + val grouped = ds.groupByKey(_ % 2) + checkDataset( grouped.keys, 0, 1) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() - val grouped = ds.groupBy(_ % 2) + val grouped = ds.groupByKey(_ % 2) val agged = grouped.mapGroups { case (g, iter) => val name = if (g == 0) "even" else "odd" - Iterator((name, iter.size)) + (name, iter.size) } - checkAnswer( + checkDataset( agged, ("even", 5), ("odd", 6)) } + + test("groupBy function, flatMap") { + val ds = Seq("a", "b", "c", "xyz", "hello").toDS() + val grouped = ds.groupByKey(_.length) + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) } + + checkDataset( + agged, + "1", "abc", "3", "xyz", "5", "hello") + } + + test("Arrays and Lists") { + checkDataset(Seq(Seq(1)).toDS(), Seq(1)) + checkDataset(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong)) + checkDataset(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble)) + checkDataset(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat)) + checkDataset(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte)) + checkDataset(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort)) + checkDataset(Seq(Seq(true)).toDS(), Seq(true)) + checkDataset(Seq(Seq("test")).toDS(), Seq("test")) + checkDataset(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1))) + + checkDataset(Seq(Array(1)).toDS(), Array(1)) + checkDataset(Seq(Array(1.toLong)).toDS(), Array(1.toLong)) + checkDataset(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble)) + checkDataset(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat)) + checkDataset(Seq(Array(1.toByte)).toDS(), Array(1.toByte)) + checkDataset(Seq(Array(1.toShort)).toDS(), Array(1.toShort)) + checkDataset(Seq(Array(true)).toDS(), Array(true)) + checkDataset(Seq(Array("test")).toDS(), Array("test")) + checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) + } + + test("package objects") { + import packageobject._ + checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3e9b621cfd67f..d074535bf6265 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -17,40 +17,89 @@ package org.apache.spark.sql +import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.sql.{Date, Timestamp} + import scala.language.postfixOps +import org.apache.spark.sql.catalyst.encoders.OuterScopes +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext - -case class ClassData(a: String, b: Int) +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("toDS") { - val data = Seq(("a", 1) , ("b", 2), ("c", 3)) - checkAnswer( + val data = Seq(("a", 1), ("b", 2), ("c", 3)) + checkDataset( data.toDS(), data: _*) } test("toDS with RDD") { val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS() - checkAnswer( + checkDataset( ds.mapPartitions(_ => Iterator(1)), 1, 1, 1) } + test("range") { + assert(sqlContext.range(10).map(_ + 1).reduce(_ + _) == 55) + assert(sqlContext.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(sqlContext.range(0, 10).map(_ + 1).reduce(_ + _) == 55) + assert(sqlContext.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(sqlContext.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55) + assert(sqlContext.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + } + + test("SPARK-12404: Datatype Helper Serializability") { + val ds = sparkContext.parallelize(( + new Timestamp(0), + new Date(0), + java.math.BigDecimal.valueOf(1), + scala.math.BigDecimal(1)) :: Nil).toDS() + + ds.collect() + } + + test("collect, first, and take should use encoders for serialization") { + val item = NonSerializableCaseClass("abcd") + val ds = Seq(item).toDS() + assert(ds.collect().head == item) + assert(ds.collectAsList().get(0) == item) + assert(ds.first() == item) + assert(ds.take(1).head == item) + assert(ds.takeAsList(1).get(0) == item) + assert(ds.toLocalIterator().next() === item) + } + + test("coalesce, repartition") { + val data = (1 to 100).map(i => ClassData(i.toString, i)) + val ds = data.toDS() + + assert(ds.repartition(10).rdd.partitions.length == 10) + checkDataset( + ds.repartition(10), + data: _*) + + assert(ds.coalesce(1).rdd.partitions.length == 1) + checkDataset( + ds.coalesce(1), + data: _*) + } + test("as tuple") { val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") - checkAnswer( + checkDataset( data.as[(String, Int)], ("a", 1), ("b", 2)) } test("as case class / collect") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] - checkAnswer( + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] + checkDataset( ds, ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) assert(ds.collect().head == ClassData("a", 1)) @@ -61,23 +110,59 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))) } + test("as case class - take") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) + } + test("map") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - checkAnswer( + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + checkDataset( ds.map(v => (v._1, v._2 + 1)), ("a", 2), ("b", 3), ("c", 4)) } + test("map with type change with the exact matched number of attributes") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + + checkDataset( + ds.map(identity[(String, Int)]) + .as[OtherTuple] + .map(identity[OtherTuple]), + OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3)) + } + + test("map with type change with less attributes") { + val ds = Seq(("a", 1, 3), ("b", 2, 4), ("c", 3, 5)).toDS() + + checkDataset( + ds.as[OtherTuple] + .map(identity[OtherTuple]), + OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3)) + } + + test("map and group by with class data") { + // We inject a group by here to make sure this test case is future proof + // when we implement better pipelining and local execution mode. + val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() + .map(c => ClassData(c.a, c.b + 1)) + .groupByKey(p => p).count() + + checkDataset( + ds, + (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) + } + test("select") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - checkAnswer( + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + checkDataset( ds.select(expr("_2 + 1").as[Int]), 2, 3, 4) } test("select 2") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - checkAnswer( + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + checkDataset( ds.select( expr("_1").as[String], expr("_2").as[Int]) : Dataset[(String, Int)], @@ -85,8 +170,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and tuple") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - checkAnswer( + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + checkDataset( ds.select( expr("_1").as[String], expr("struct(_2, _2)").as[(Int, Int)]), @@ -94,8 +179,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and class") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - checkAnswer( + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + checkDataset( ds.select( expr("_1").as[String], expr("named_struct('a', _1, 'b', _2)").as[ClassData]), @@ -103,7 +188,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("select 2, primitive and class, fields reordered") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDecoding( ds.select( expr("_1").as[String], @@ -112,59 +197,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("filter") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - checkAnswer( + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + checkDataset( ds.filter(_._1 == "b"), ("b", 2)) } test("foreach") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.accumulator(0) ds.foreach(v => acc += v._2) assert(acc.value == 6) } test("foreachPartition") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.accumulator(0) ds.foreachPartition(_.foreach(v => acc += v._2)) assert(acc.value == 6) } test("reduce") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) } - test("fold") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) - } - test("joinWith, flat schema") { val ds1 = Seq(1, 2, 3).toDS().as("a") val ds2 = Seq(1, 2).toDS().as("b") - checkAnswer( - ds1.joinWith(ds2, $"a.value" === $"b.value"), + checkDataset( + ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"), (1, 1), (2, 2)) } - test("joinWith, expression condition") { - val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() - val ds2 = Seq(("a", 1), ("b", 2)).toDS() - - checkAnswer( - ds1.joinWith(ds2, $"_1" === $"a"), - (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2))) + test("joinWith, expression condition, outer join") { + val nullInteger = null.asInstanceOf[Integer] + val nullString = null.asInstanceOf[String] + val ds1 = Seq(ClassNullableData("a", 1), + ClassNullableData("c", 3)).toDS() + val ds2 = Seq(("a", new Integer(1)), + ("b", new Integer(2))).toDS() + + checkDataset( + ds1.joinWith(ds2, $"_1" === $"a", "outer"), + (ClassNullableData("a", 1), ("a", new Integer(1))), + (ClassNullableData("c", 3), (nullString, nullInteger)), + (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) } test("joinWith tuple with primitive, expression") { val ds1 = Seq(1, 1, 2).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() - checkAnswer( + checkDataset( ds1.joinWith(ds2, $"value" === $"_2"), (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))) } @@ -183,98 +269,420 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b") val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c") - checkAnswer( + checkDataset( ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"), ((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))) - } test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() - val grouped = ds.groupBy(v => (1, v._2)) - checkAnswer( + val grouped = ds.groupByKey(v => (1, v._2)) + checkDataset( grouped.keys, (1, 1)) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g._1, iter.map(_._2).sum)) - } + val grouped = ds.groupByKey(v => (v._1, "word")) + val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) } - checkAnswer( + checkDataset( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns, mapGroups") { + test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g.getString(0), iter.map(_._2).sum)) + val grouped = ds.groupByKey(v => (v._1, "word")) + val agged = grouped.flatMapGroups { case (g, iter) => + Iterator(g._1, iter.map(_._2).sum.toString) } - checkAnswer( + checkDataset( agged, - ("a", 30), ("b", 3), ("c", 1)) + "a", "30", "b", "3", "c", "1") } - test("groupBy columns asKey, mapGroups") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + test("groupBy function, reduce") { + val ds = Seq("abc", "xyz", "hello").toDS() + val agged = ds.groupByKey(_.length).reduceGroups(_ + _) - checkAnswer( + checkDataset( agged, - ("a", 30), ("b", 3), ("c", 1)) + 3 -> "abcxyz", 5 -> "hello") + } + + test("groupBy single field class, count") { + val ds = Seq("abc", "xyz", "hello").toDS() + val count = ds.groupByKey(s => Tuple1(s.length)).count() + + checkDataset( + count, + (Tuple1(3), 2L), (Tuple1(5), 1L) + ) } - test("groupBy columns asKey tuple, mapGroups") { + test("typed aggregation: expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } - checkAnswer( - agged, - (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) + checkDataset( + ds.groupByKey(_._1).agg(sum("_2").as[Long]), + ("a", 30L), ("b", 3L), ("c", 1L)) } - test("groupBy columns asKey class, mapGroups") { + test("typed aggregation: expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } - checkAnswer( - agged, - (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1)) + checkDataset( + ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), + ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L)) + } + + test("typed aggregation: expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDataset( + ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), + ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDataset( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double]), + ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0)) } test("cogroup") { val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() - val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + val cogrouped = ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, data1, data2) => Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString)) } - checkAnswer( + checkDataset( cogrouped, 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") } + test("cogroup with complex data") { + val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS() + val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS() + val cogrouped = ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, data1, data2) => + Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString)) + } + + checkDataset( + cogrouped, + 1 -> "a", 2 -> "bc", 3 -> "d") + } + + test("sample with replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkDataset( + data.sample(withReplacement = true, 0.05, seed = 13), + 5, 10, 52, 73) + } + + test("sample without replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkDataset( + data.sample(withReplacement = false, 0.05, seed = 13), + 3, 17, 27, 58, 62) + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") - checkAnswer(joined, ("2", 2)) + checkDataset(joined, ("2", 2)) } + + test("self join") { + val ds = Seq("1", "2").toDS().as("a") + val joined = ds.joinWith(ds, lit(true)) + checkDataset(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) + } + + test("toString") { + val ds = Seq((1, 2)).toDS() + assert(ds.toString == "[_1: int, _2: int]") + } + + test("showString: Kryo encoder") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + + val expectedAnswer = """+-----------+ + || value| + |+-----------+ + ||KryoData(1)| + ||KryoData(2)| + |+-----------+ + |""".stripMargin + assert(ds.showString(10) === expectedAnswer) + } + + test("Kryo encoder") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + + assert(ds.groupByKey(p => p).count().collect().toSet == + Set((KryoData(1), 1L), (KryoData(2), 1L))) + } + + test("Kryo encoder self join") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (KryoData(1), KryoData(1)), + (KryoData(1), KryoData(2)), + (KryoData(2), KryoData(1)), + (KryoData(2), KryoData(2)))) + } + + test("Java encoder") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + + assert(ds.groupByKey(p => p).count().collect().toSeq == + Seq((JavaData(1), 1L), (JavaData(2), 1L))) + } + + test("Java encoder self join") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (JavaData(1), JavaData(1)), + (JavaData(1), JavaData(2)), + (JavaData(2), JavaData(1)), + (JavaData(2), JavaData(2)))) + } + + test("SPARK-11894: Incorrect results are returned when using null") { + val nullInt = null.asInstanceOf[java.lang.Integer] + val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + + checkDataset( + ds1.joinWith(ds2, lit(true)), + ((nullInt, "1"), (nullInt, "1")), + ((new java.lang.Integer(22), "2"), (nullInt, "1")), + ((nullInt, "1"), (new java.lang.Integer(22), "2")), + ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + } + + test("change encoder with compatible schema") { + val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] + assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) + } + + test("verify mismatching field names fail with a good error") { + val ds = Seq(ClassData("a", 1)).toDS() + val e = intercept[AnalysisException] { + ds.as[ClassData2] + } + assert(e.getMessage.contains("cannot resolve '`c`' given input columns: [a, b]"), e.getMessage) + } + + test("runtime nullability check") { + val schema = StructType(Seq( + StructField("f", StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), nullable = true) + )) + + def buildDataset(rows: Row*): Dataset[NestedStruct] = { + val rowRDD = sqlContext.sparkContext.parallelize(rows) + sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + } + + checkDataset( + buildDataset(Row(Row("hello", 1))), + NestedStruct(ClassData("hello", 1)) + ) + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null))) + + val message = intercept[RuntimeException] { + buildDataset(Row(Row("hello", null))).collect() + }.getMessage + + assert(message.contains("Null value appeared in non-nullable field")) + } + + test("SPARK-12478: top level null field") { + val ds0 = Seq(NestedStruct(null)).toDS() + checkDataset(ds0, NestedStruct(null)) + checkAnswer(ds0.toDF(), Row(null)) + + val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS() + checkDataset(ds1, DeepNestedStruct(NestedStruct(null))) + checkAnswer(ds1.toDF(), Row(Row(null))) + } + + test("support inner class in Dataset") { + val outer = new OuterClass + OuterScopes.addOuterScope(outer) + val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS() + checkDataset(ds.map(_.a), "1", "2") + } + + test("grouping key and grouped value has field with same name") { + val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS() + val agged = ds.groupByKey(d => ClassNullableData(d.a, null)).mapGroups { + case (key, values) => key.a + values.map(_.b).sum + } + + checkDataset(agged, "a3") + } + + test("cogroup's left and right side has field with same name") { + val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS() + val cogrouped = left.groupByKey(_.a).cogroup(right.groupByKey(_.a)) { + case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum) + } + + checkDataset(cogrouped, "a13", "b24") + } + + test("give nice error message when the real number of fields doesn't match encoder schema") { + val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + + val message = intercept[AnalysisException] { + ds.as[(String, Int, Long)] + }.message + assert(message == + "Try to map struct to Tuple3, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:int,_3:bigint>") + + val message2 = intercept[AnalysisException] { + ds.as[Tuple1[String]] + }.message + assert(message2 == + "Try to map struct to Tuple1, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string>") + } + + test("SPARK-13440: Resolving option fields") { + val df = Seq(1, 2, 3).toDS() + val ds = df.as[Option[Int]] + checkDataset( + ds.filter(_ => true), + Some(1), Some(2), Some(3)) + } + + test("SPARK-13540 Dataset of nested class defined in Scala object") { + checkDataset( + Seq(OuterObject.InnerClass("foo")).toDS(), + OuterObject.InnerClass("foo")) + } + + test("SPARK-14000: case class with tuple type field") { + checkDataset( + Seq(TupleClass((1, "a"))).toDS(), + TupleClass(1, "a") + ) + } + + test("isStreaming returns false for static Dataset") { + val data = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + assert(!data.isStreaming, "static Dataset returned true for 'isStreaming'.") + } + + test("isStreaming returns true for streaming Dataset") { + val data = MemoryStream[Int].toDS() + assert(data.isStreaming, "streaming Dataset returned false for 'isStreaming'.") + } + + test("isStreaming returns true after static and streaming Dataset join") { + val static = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b") + val streaming = MemoryStream[Int].toDS().toDF("b") + val df = streaming.join(static, Seq("b")) + assert(df.isStreaming, "streaming Dataset returned false for 'isStreaming'.") + } + + test("SPARK-14554: Dataset.map may generate wrong java code for wide table") { + val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + // Make sure the generated code for this plan can compile and execute. + checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*) + } +} + +case class OtherTuple(_1: String, _2: Int) + +case class TupleClass(data: (Int, String)) + +class OuterClass extends Serializable { + case class InnerClass(a: String) +} + +object OuterObject { + case class InnerClass(a: String) +} + +case class ClassData(a: String, b: Int) +case class ClassData2(c: String, d: Int) +case class ClassNullableData(a: String, b: Integer) + +case class NestedStruct(f: ClassData) +case class DeepNestedStruct(f: NestedStruct) + +/** + * A class used to test serialization using encoders. This class throws exceptions when using + * Java serialization -- so the only way it can be "serialized" is through our encoders. + */ +case class NonSerializableCaseClass(value: String) extends Externalizable { + override def readExternal(in: ObjectInput): Unit = { + throw new UnsupportedOperationException + } + + override def writeExternal(out: ObjectOutput): Unit = { + throw new UnsupportedOperationException + } +} + +/** Used to test Kryo encoder. */ +class KryoData(val a: Int) { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[KryoData].a + } + override def hashCode: Int = a + override def toString: String = s"KryoData($a)" +} + +object KryoData { + def apply(a: Int): KryoData = new KryoData(a) +} + +/** Used to test Java encoder. */ +class JavaData(val a: Int) extends Serializable { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[JavaData].a + } + override def hashCode: Int = a + override def toString: String = s"JavaData($a)" +} + +object JavaData { + def apply(a: Int): JavaData = new JavaData(a) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 9080c53c491ac..f7aa3b747ae5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -38,15 +38,21 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } - // This is a bad test. SPARK-9196 will fix it and re-enable it. - ignore("function current_timestamp") { + test("function current_timestamp and now") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value - checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), - Row(true)) - assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( - 0).getTime - System.currentTimeMillis()) < 5000) + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) + + // Current timestamp should return the current timestamp ... + val before = System.currentTimeMillis + val got = sql("SELECT CURRENT_TIMESTAMP()").collect().head.getTimestamp(0).getTime + val after = System.currentTimeMillis + assert(got >= before && got <= after) + + // Now alias + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = NOW()"""), Row(true)) } val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") @@ -442,6 +448,30 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + + val now = sql("select unix_timestamp()").collect().head.getLong(0) + checkAnswer(sql(s"select cast ($now as timestamp)"), Row(new java.util.Date(now * 1000))) + } + + test("to_unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.selectExpr("to_unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("to_unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) } test("datediff") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 78a98798eff64..b1987c690811d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -15,25 +15,26 @@ * limitations under the License. */ -package test.org.apache.spark.sql +package org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} -import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.{Row, Strategy, QueryTest} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.unsafe.types.UTF8String case class FastOperator(output: Seq[Attribute]) extends SparkPlan { override protected def doExecute(): RDD[InternalRow] = { val str = Literal("so fast").value val row = new GenericInternalRow(Array[Any](str)) - sparkContext.parallelize(Seq(row)) + val unsafeProj = UnsafeProjection.create(schema) + val unsafeRow = unsafeProj(row).copy() + sparkContext.parallelize(Seq(unsafeRow)) } + override def producedAttributes: AttributeSet = outputSet override def children: Seq[SparkPlan] = Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a9ca46cab067d..a87a41c12664f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -36,135 +37,102 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) assert(planned.size === 1) } - def assertJoin(sqlString: String, c: Class[_]): Any = { + def assertJoin(pair: (String, Class[_])): Any = { + val (sqlString, c) = pair val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { - case j: ShuffledHashJoin => j - case j: ShuffledHashOuterJoin => j - case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j - case j: BroadcastHashOuterJoin => j - case j: LeftSemiJoinBNL => j + case j: ShuffledHashJoin => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j - case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j - case j: SortMergeOuterJoin => j } assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") + if (operators.head.getClass != c) { + fail(s"$sqlString expected operator: $c, but got ${operators.head}\n physical: \n$physical") } } test("join operator selection") { sqlContext.cacheManager.clearCache() - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), - ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), - ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]), - ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]), - ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", - classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), + ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", + classOf[CartesianProduct]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + classOf[SortMergeJoin]), + ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoin]) + ).foreach(assertJoin) } } - test("SortMergeJoin shouldn't work on unsortable columns") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - Seq( - ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } - } +// ignore("SortMergeJoin shouldn't work on unsortable columns") { +// Seq( +// ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) +// ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } +// } test("broadcasted hash join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") - for (sortMergeJoinEnabled <- Seq(true, false)) { - withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { - Seq( - ("SELECT * FROM testData join testData2 ON key = a", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } - } - } + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach(assertJoin) sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", - classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } + sql("CACHE TABLE testData2") + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]) + ).foreach(assertJoin) sql("UNCACHE TABLE testData") } @@ -172,7 +140,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.planner.EquiJoinSelection(join) + val planned = sqlContext.sessionState.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -219,7 +187,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("big inner join, 4 matches per row") { - val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) + val bigData = testData.union(testData).union(testData).union(testData) val bigDataX = bigData.as("x") val bigDataY = bigData.as("y") @@ -279,16 +247,17 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer( sql( """ - |SELECT l.N, count(*) - |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY l.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: Nil) + |SELECT l.N, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """. + stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) checkAnswer( sql( @@ -343,7 +312,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - Row(null, 6)) + Row(null, + 6)) checkAnswer( sql( @@ -352,7 +322,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - Row(1, 1) :: + Row(1 + , 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -361,8 +332,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") + upperCaseData.where('N <= 4).registerTempTable("`left`") + upperCaseData.where('N >= 3).registerTempTable("`right`") val left = UnresolvedRelation(TableIdentifier("left"), None) val right = UnresolvedRelation(TableIdentifier("right"), None) @@ -377,7 +348,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"), + left.join(right, ($"left.N" === $"right.N") && ($"left.N" =!= 3), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -387,7 +358,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"), + left.join(right, ($"left.N" === $"right.N") && ($"right.N" =!= 3), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -396,14 +367,16 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) - // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. + // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join + // operator. checkAnswer( sql( """ - |SELECT l.a, count(*) - |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY l.a - """.stripMargin), + |SELECT l.a, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """. + stripMargin), Row(null, 10)) checkAnswer( @@ -413,7 +386,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - Row(1, 1) :: + Row + (1, 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -428,7 +402,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - Row(1, 1) :: + Row(1 + , 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -439,32 +414,30 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer( sql( """ - |SELECT r.a, count(*) - |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY r.a - """.stripMargin), + |SELECT r.a, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """. + stripMargin), Row(null, 10)) } - test("broadcasted left semi join operator selection") { + test("broadcasted existence join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) - } + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData ANT JOIN testData2 ON key = a", classOf[BroadcastHashJoin]) + ).foreach(assertJoin) } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) - } + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]) + ).foreach(assertJoin) } sql("UNCACHE TABLE testData") @@ -487,9 +460,9 @@ class JoinSuite extends QueryTest with SharedSQLContext { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[LeftSemiJoinHash]), + classOf[ShuffledHashJoin]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", - classOf[LeftSemiJoinBNL]), + classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", @@ -516,7 +489,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) checkAnswer( sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index e3531d0d6d799..1391c9d57ff7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -39,25 +39,52 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { ("6", "[invalid JSON string]") :: Nil + test("function get_json_object - null") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: + Nil + + checkAnswer( + df.select($"key", functions.get_json_object($"jstring", "$.f1"), + functions.get_json_object($"jstring", "$.f2"), + functions.get_json_object($"jstring", "$.f3"), + functions.get_json_object($"jstring", "$.f4"), + functions.get_json_object($"jstring", "$.f5")), + expected) + } + test("json_tuple select") { val df: DataFrame = tuples.toDF("key", "jstring") - val expected = Row("1", Row("value1", "value2", "3", null, "5.23")) :: - Row("2", Row("value12", "2", "value3", "4.01", null)) :: - Row("3", Row("value13", "2", "value33", "value44", "5.01")) :: - Row("4", Row(null, null, null, null, null)) :: - Row("5", Row("", null, null, null, null)) :: - Row("6", Row(null, null, null, null, null)) :: + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: Nil - checkAnswer(df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), expected) + checkAnswer( + df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")), + expected) + + checkAnswer( + df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), + expected) } test("json_tuple filter and group") { val df: DataFrame = tuples.toDF("key", "jstring") val expr = df - .selectExpr("json_tuple(jstring, 'f1', 'f2') as jt") - .where($"jt.c0".isNotNull) - .groupBy($"jt.c1") + .select(functions.json_tuple($"jstring", "f1", "f2")) + .where($"c0".isNotNull) + .groupBy($"c1") .count() val expected = Row(null, 1) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 5688f46e5e3d4..bb54c525cb76d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -import org.apache.spark.sql.catalyst.TableIdentifier class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { import testImplicits._ @@ -33,7 +33,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true) } test("get all tables") { @@ -45,20 +46,22 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true) assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } - test("getting all Tables with a database name has no impact on returned table names") { + test("getting all tables with a database name has no impact on returned table names") { checkAnswer( - sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables("default").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in default").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true) assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 58f982c2bc932..f5a67fd782d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.nio.charset.StandardCharsets + import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} import org.apache.spark.sql.test.SharedSQLContext @@ -212,7 +214,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) ) - val pi = 3.1415 + val pi = "3.1415" checkAnswer( sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), @@ -262,9 +264,9 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { test("unhex") { val data = Seq(("1C", "737472696E67")).toDF("a", "b") checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) - checkAnswer(data.select(unhex('b)), Row("string".getBytes)) + checkAnswer(data.select(unhex('b)), Row("string".getBytes(StandardCharsets.UTF_8))) checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) - checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) + checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes(StandardCharsets.UTF_8))) checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) checkAnswer(data.selectExpr("""unhex("G123")"""), Row(null)) } @@ -367,6 +369,16 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { checkAnswer( input.toDF("key", "value").selectExpr("abs(key) a").sort("a"), input.map(pair => Row(pair._2))) + + checkAnswer( + sql("select abs(0), abs(-1), abs(123), abs(-9223372036854775807), abs(9223372036854775807)"), + Row(0, 1, 123, 9223372036854775807L, 9223372036854775807L) + ) + + checkAnswer( + sql("select abs(0.0), abs(-3.14159265), abs(3.14159265)"), + Row(BigDecimal("0.0"), BigDecimal("3.14159265"), BigDecimal("3.14159265")) + ) } test("log2") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 0e8fcb6a858b1..0b5a92c256e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql -import org.apache.spark._ import org.scalatest.BeforeAndAfterAll +import org.apache.spark._ +import org.apache.spark.sql.internal.SQLConf + class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalActiveSQLContext: Option[SQLContext] = _ @@ -27,11 +29,11 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { private var sparkConf: SparkConf = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActiveContextOption() + originalActiveSQLContext = SQLContext.getActive() originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() sparkConf = new SparkConf(false) .setMaster("local[*]") @@ -89,10 +91,9 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { testNewSession(rootSQLContext) testNewSession(rootSQLContext) testCreatingNewSQLContext(allowMultipleSQLContexts) - - SQLContext.clearInstantiatedContext(rootSQLContext) } finally { sc.stop() + SQLContext.clearInstantiatedContext() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala new file mode 100644 index 0000000000000..0d18a645f6790 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ + +import org.apache.spark.SparkFunSuite + +class ProcessingTimeSuite extends SparkFunSuite { + + test("create") { + assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000) + assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000) + assert(ProcessingTime("1 minute").intervalMs === 60 * 1000) + assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000) + + intercept[IllegalArgumentException] { ProcessingTime(null: String) } + intercept[IllegalArgumentException] { ProcessingTime("") } + intercept[IllegalArgumentException] { ProcessingTime("invalid") } + intercept[IllegalArgumentException] { ProcessingTime("1 month") } + intercept[IllegalArgumentException] { ProcessingTime("1 year") } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 3c174efe73ffe..23a0ce215ff3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -20,11 +20,20 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ +import scala.util.control.NonFatal +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.streaming.MemoryPlan +import org.apache.spark.sql.types.ObjectType abstract class QueryTest extends PlanTest { @@ -65,12 +74,12 @@ abstract class QueryTest extends PlanTest { * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead * which performs a subset of the checks done by this function. */ - protected def checkAnswer[T : Encoder]( - ds: => Dataset[T], + protected def checkDataset[T]( + ds: Dataset[T], expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } @@ -83,18 +92,27 @@ abstract class QueryTest extends PlanTest { fail( s""" |Exception collecting dataset as objects - |${ds.encoder} - |${ds.encoder.constructExpression.treeString} + |${ds.resolvedTEncoder} + |${ds.resolvedTEncoder.deserializer.treeString} |${ds.queryExecution} """.stripMargin, e) } - if (decoded != expectedAnswer.toSet) { + // Handle the case where the return type is an array + val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false) + def normalEquality = decoded == expectedAnswer.toSet + def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet + def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq) + + if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) { + val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted + val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted + + val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") fail( s"""Decoded objects do not match expected objects: - |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted} - |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted} - |${ds.encoder.constructExpression.treeString} + |$comparison + |${ds.resolvedTEncoder.deserializer.treeString} """.stripMargin) } } @@ -107,19 +125,23 @@ abstract class QueryTest extends PlanTest { protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { val analyzedDF = try df catch { case ae: AnalysisException => - val currentValue = sqlContext.conf.dataFrameEagerAnalysis - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - val partiallyAnalzyedPlan = df.queryExecution.analyzed - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue) - fail( - s""" - |Failed to analyze query: $ae - |$partiallyAnalzyedPlan - | - |${stackTraceToString(ae)} - |""".stripMargin) + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${stackTraceToString(ae)} + |""".stripMargin) + } else { + throw ae + } } + checkJsonFormat(analyzedDF) + + assertEmptyMissingInput(df) + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -161,9 +183,9 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + * Asserts that a given [[Dataset]] will be executed using the given number of cached results. */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached @@ -174,6 +196,109 @@ abstract class QueryTest extends PlanTest { s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + planWithCaching) } + + private def checkJsonFormat(df: DataFrame): Unit = { + val logicalPlan = df.queryExecution.analyzed + // bypass some cases that we can't handle currently. + logicalPlan.transform { + case _: ObjectOperator => return + case _: LogicalRelation => return + case _: MemoryPlan => return + }.transformAllExpressions { + case a: ImperativeAggregate => return + case _: TypedAggregateExpression => return + case Literal(_, _: ObjectType) => return + } + + // bypass hive tests before we fix all corner cases in hive module. + if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return + + val jsonString = try { + logicalPlan.toJSON + } catch { + case NonFatal(e) => + fail( + s""" + |Failed to parse logical plan to JSON: + |${logicalPlan.treeString} + """.stripMargin, e) + } + + // scala function is not serializable to JSON, use null to replace them so that we can compare + // the plans later. + val normalized1 = logicalPlan.transformAllExpressions { + case udf: ScalaUDF => udf.copy(function = null) + case gen: UserDefinedGenerator => gen.copy(function = null) + } + + // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains + // these non-serializable stuff, and use these original ones to replace the null-placeholders + // in the logical plans parsed from JSON. + var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l } + var localRelations = logicalPlan.collect { case l: LocalRelation => l } + var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i } + + val jsonBackPlan = try { + TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) + } catch { + case NonFatal(e) => + fail( + s""" + |Failed to rebuild the logical plan from JSON: + |${logicalPlan.treeString} + | + |${logicalPlan.prettyJson} + """.stripMargin, e) + } + + val normalized2 = jsonBackPlan transformDown { + case l: LogicalRDD => + val origin = logicalRDDs.head + logicalRDDs = logicalRDDs.drop(1) + LogicalRDD(l.output, origin.rdd)(sqlContext) + case l: LocalRelation => + val origin = localRelations.head + localRelations = localRelations.drop(1) + l.copy(data = origin.data) + case l: InMemoryRelation => + val origin = inMemoryRelations.head + inMemoryRelations = inMemoryRelations.drop(1) + InMemoryRelation( + l.output, + l.useCompression, + l.batchSize, + l.storageLevel, + origin.child, + l.tableName)( + origin.cachedColumnBuffers, + l._statistics, + origin._batchStats) + } + + assert(logicalRDDs.isEmpty) + assert(localRelations.isEmpty) + assert(inMemoryRelations.isEmpty) + + if (normalized1 != normalized2) { + fail( + s""" + |== FAIL: the logical plan parsed from json does not match the original one === + |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** + * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. + */ + def assertEmptyMissingInput(query: Dataset[_]): Unit = { + assert(query.queryExecution.analyzed.missingInput.isEmpty, + s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}") + assert(query.queryExecution.optimizedPlan.missingInput.isEmpty, + s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}") + assert(query.queryExecution.executedPlan.missingInput.isEmpty, + s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}") + } } object QueryTest { @@ -188,27 +313,7 @@ object QueryTest { def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { - Row.fromSeq(row.toSeq.map { - case null => null - case d: java.math.BigDecimal => BigDecimal(d) - // Convert array to Seq for easy equality check. - case b: Array[_] => b.toSeq - case r: Row => prepareRow(r) - case o => o - }) - } - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - val converted: Seq[Row] = answer.map(prepareRow) - if (!isSorted) converted.sortBy(_.toString()) else converted - } val sparkAnswer = try df.collect().toSeq catch { case e: Exception => val errorMessage = @@ -222,22 +327,56 @@ object QueryTest { return Some(errorMessage) } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = + sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => s""" |Results do not match for query: |${df.queryExecution} |== Results == - |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - return Some(errorMessage) + |$results + """.stripMargin } + } + - return None + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + + def sameRows( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): Option[String] = { + if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + val errorMessage = + s""" + |== Results == + |${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + None } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 3ba14d7602a62..4552eb6ce00a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala deleted file mode 100644 index 3d2bd236ceead..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext} - - -class SQLConfSuite extends QueryTest with SharedSQLContext { - private val testKey = "test.key.0" - private val testVal = "test.val.0" - - test("propagate from spark conf") { - // We create a new context here to avoid order dependence with other tests that might call - // clear(). - val newContext = new SQLContext(sparkContext) - assert(newContext.getConf("spark.sql.testkey", "false") === "true") - } - - test("programmatic ways of basic setting and getting") { - // Set a conf first. - sqlContext.setConf(testKey, testVal) - // Clear the conf. - sqlContext.conf.clear() - // After clear, only overrideConfs used by unit test should be in the SQLConf. - assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) - - sqlContext.setConf(testKey, testVal) - assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) - - // Tests SQLConf as accessed from a SQLContext is mutable after - // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) - - sqlContext.conf.clear() - } - - test("parse SQL set commands") { - sqlContext.conf.clear() - sql(s"set $testKey=$testVal") - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - - sql("set some.property=20") - assert(sqlContext.getConf("some.property", "0") === "20") - sql("set some.property = 40") - assert(sqlContext.getConf("some.property", "0") === "40") - - val key = "spark.sql.key" - val vs = "val0,val_1,val2.3,my_table" - sql(s"set $key=$vs") - assert(sqlContext.getConf(key, "0") === vs) - - sql(s"set $key=") - assert(sqlContext.getConf(key, "0") === "") - - sqlContext.conf.clear() - } - - test("deprecated property") { - sqlContext.conf.clear() - val original = sqlContext.conf.numShufflePartitions - try{ - sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(sqlContext.conf.numShufflePartitions === 10) - } finally { - sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") - } - } - - test("invalid conf value") { - sqlContext.conf.clear() - val e = intercept[IllegalArgumentException] { - sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") - } - assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 1994dacfc4dfa..2f62ad4850dee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -18,8 +18,15 @@ package org.apache.spark.sql import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf -class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ +class SQLContextSuite extends SparkFunSuite with SharedSparkContext { + + object DummyRule extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } test("getOrCreate instantiates SQLContext") { val sqlContext = SQLContext.getOrCreate(sc) @@ -65,4 +72,17 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ session2.sql("select myadd(1, 2)").explain() } } + + test("Catalyst optimization passes are modifiable at runtime") { + val sqlContext = SQLContext.getOrCreate(sc) + sqlContext.experimental.extraOptimizations = Seq(DummyRule) + assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule)) + } + + test("SQLContext can access `spark.sql.*` configs") { + sc.conf.set("spark.sql.with.or.without.you", "my love") + val sqlContext = new SQLContext(sc) + assert(sqlContext.getConf("spark.sql.with.or.without.you") == "my love") + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3de277a79a52c..cdd404d699a71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,19 +21,18 @@ import java.math.MathContext import java.sql.Timestamp import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.DefaultParserDialect -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ -/** A SQL Dialect for testing purpose, and it can not be nested type */ -class MyDialect extends DefaultParserDialect - class SQLQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -57,8 +56,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("show functions") { - checkAnswer(sql("SHOW functions"), - FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + def getFunctions(pattern: String): Seq[Row] = { + StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern) + .map(Row(_)) + } + checkAnswer(sql("SHOW functions"), getFunctions("*")) + Seq("^c*", "*e$", "log*", "*date*").foreach { pattern => + // For the pattern part, only '*' and '|' are allowed as wildcards. + // For '*', we need to replace it to '.*'. + checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) + } } test("describe functions") { @@ -79,7 +86,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "Extended Usage") checkExistence(sql("describe functioN abcadf"), true, - "Function: abcadf is not found.") + "Function: abcadf not found.") + } + + test("SPARK-14415: All functions should have own descriptions") { + for (f <- sqlContext.sessionState.functionRegistry.listFunction()) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { + checkExistence(sql(s"describe function `$f`"), false, "To be added.") + } + } } test("SPARK-6743: no columns from cache") { @@ -147,23 +162,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .count(), Row(24, 1) :: Row(14, 1) :: Nil) } - test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(sparkContext) - newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) - assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) - assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) - } - - test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(sparkContext) - newContext.sql("SET spark.sql.dialect=MyTestClass") - intercept[DialectException] { - newContext.sql("SELECT 1") - } - // test if the dialect set back to DefaultSQLDialect - assert(newContext.getSQLDialect().getClass === classOf[DefaultParserDialect]) - } - test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), @@ -237,40 +235,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-8828 sum should return null if all input values are null") { - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) } private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { val df = sql(sqlText) // First, check if we have GeneratedAggregate. - val hasGeneratedAgg = df.queryExecution.executedPlan + val hasGeneratedAgg = df.queryExecution.sparkPlan .collect { case _: aggregate.TungstenAggregate => true } .nonEmpty if (!hasGeneratedAgg) { @@ -285,12 +259,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("aggregation with codegen") { - val originalValue = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) // Prepare a table that we can group some rows. sqlContext.table("testData") - .unionAll(sqlContext.table("testData")) - .unionAll(sqlContext.table("testData")) + .union(sqlContext.table("testData")) + .union(sqlContext.table("testData")) .registerTempTable("testData3x") try { @@ -340,13 +312,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) - // STDDEV - testCodeGen( - "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", - (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) - testCodeGen( - "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", - Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. testCodeGen( """ @@ -367,11 +332,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", - Row(null, null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -506,29 +470,100 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } - test("literal in agg grouping expressions") { - def literalInAggTest(): Unit = { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + test("Group By Ordinal - basic") { + checkAnswer( + sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"), + sql("SELECT a, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) + // duplicate group-by columns + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + } + + test("Group By Ordinal - non aggregate expressions") { + checkAnswer( + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) + + checkAnswer( + sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) + } + + test("Group By Ordinal - non-foldable constant expression") { + checkAnswer( + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"), + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + } + + test("Group By Ordinal - alias") { + checkAnswer( + sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) + + checkAnswer( + sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) + } + + test("Group By Ordinal - constants") { + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) + } + + test("Group By Ordinal - negative cases") { + intercept[UnresolvedException[Aggregate]] { + sql("SELECT a, b FROM testData2 GROUP BY -1") + } + + intercept[UnresolvedException[Aggregate]] { + sql("SELECT a, b FROM testData2 GROUP BY 3") } - literalInAggTest() - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - literalInAggTest() + var e = intercept[UnresolvedException[Aggregate]]( + sql("SELECT SUM(a) FROM testData2 GROUP BY 1")) + assert(e.getMessage contains + "Invalid call to Group by position: the '1'th column in the select contains " + + "an aggregate function") + + e = intercept[UnresolvedException[Aggregate]]( + sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1")) + assert(e.getMessage contains + "Invalid call to Group by position: the '1'th column in the select contains " + + "an aggregate function") + + var ae = intercept[AnalysisException]( + sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2")) + assert(ae.getMessage contains + "nondeterministic expression rand(0) should not appear in grouping expression") + + ae = intercept[AnalysisException]( + sql("SELECT * FROM testData2 GROUP BY a, b, 1")) + assert(ae.getMessage contains + "Group by position: star is not allowed to use in the select list " + + "when using ordinals in group by") + } + + test("Group By Ordinal: spark.sql.groupByOrdinal=false") { + withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") { + // If spark.sql.groupByOrdinal=false, ignore the position number. + intercept[AnalysisException] { + sql("SELECT a, sum(b) FROM testData2 GROUP BY 1") + } + // '*' is not allowed to use in the select list when users specify ordinals in group by + checkAnswer( + sql("SELECT * FROM testData2 GROUP BY a, b, 1"), + sql("SELECT * FROM testData2 GROUP BY a, b")) } } @@ -598,12 +633,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } - } - test("limit") { checkAnswer( sql("SELECT * FROM testData LIMIT 10"), @@ -633,7 +662,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("Allow only a single WITH clause per query") { - intercept[RuntimeException] { + intercept[AnalysisException] { sql( "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } @@ -649,8 +678,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("from follow multiple brackets") { checkAnswer(sql( """ - |select key from ((select * from testData limit 1) - | union all (select * from testData limit 1)) x limit 1 + |select key from ((select * from testData) + | union all (select * from testData)) x limit 1 """.stripMargin), Row(1) ) @@ -663,7 +692,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql( """ |select key from - | (select * from testData limit 1 union all select * from testData limit 1) x + | (select * from testData union all select * from testData) x | limit 1 """.stripMargin), Row(1) @@ -696,13 +725,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("approximate count distinct") { checkAnswer( - sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"), + sql("SELECT APPROX_COUNT_DISTINCT(a) FROM testData2"), Row(3)) } test("approximate count distinct with user provided standard deviation") { checkAnswer( - sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"), + sql("SELECT APPROX_COUNT_DISTINCT(a, 0.04) FROM testData2"), Row(3)) } @@ -838,14 +867,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-11111 null-safe join should not use cartesian product") { val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)") - val cp = df.queryExecution.executedPlan.collect { + val cp = df.queryExecution.sparkPlan.collect { case cp: CartesianProduct => cp } assert(cp.isEmpty, "should not use CartesianProduct for null-safe join") - val smj = df.queryExecution.executedPlan.collect { + val smj = df.queryExecution.sparkPlan.collect { case smj: SortMergeJoin => smj + case j: BroadcastHashJoin => j } - assert(smj.size > 0, "should use SortMergeJoin") + assert(smj.size > 0, "should use SortMergeJoin or BroadcastHashJoin") checkAnswer(df, Row(100) :: Nil) } @@ -853,7 +883,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) .registerTempTable("subset1") - sql("SELECT DISTINCT n FROM lowerCaseData") + sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n ASC") .limit(2) .registerTempTable("subset2") checkAnswer( @@ -1048,7 +1078,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SET commands with illegal or inappropriate argument") { sqlContext.conf.clear() - // Set negative mapred.reduce.tasks for automatically determing + // Set negative mapred.reduce.tasks for automatically determining // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) @@ -1239,11 +1269,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) + sql("SELECT 0.3"), Row(BigDecimal(0.3)) ) checkAnswer( - sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) + sql("SELECT -0.8"), Row(BigDecimal(-0.8)) ) checkAnswer( @@ -1288,7 +1318,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) checkAnswer( - sql("SELECT +6.8"), Row(BigDecimal(6.8)) + sql("SELECT +6.8e0"), Row(6.8d) ) checkAnswer( @@ -1456,12 +1486,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4699 case sensitivity SQL query") { - sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) - val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) - rdd.toDF().registerTempTable("testTable1") - checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) + val orig = sqlContext.getConf(SQLConf.CASE_SENSITIVE) + try { + sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) + val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + rdd.toDF().registerTempTable("testTable1") + checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) + } finally { + sqlContext.setConf(SQLConf.CASE_SENSITIVE, orig) + } } test("SPARK-6145: ORDER BY test for nested fields") { @@ -1624,16 +1658,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { e.message.contains("Cannot save interval data type into external storage") }) - def checkIntervalParseError(s: String): Unit = { - val e = intercept[AnalysisException] { - sql(s) - } - e.message.contains("at least one time unit should be given for interval literal") + val e1 = intercept[AnalysisException] { + sql("select interval") } - - checkIntervalParseError("select interval") + assert(e1.message.contains("at least one time unit should be given for interval literal")) // Currently we don't yet support nanosecond - checkIntervalParseError("select interval 23 nanosecond") + val e2 = intercept[AnalysisException] { + sql("select interval 23 nanosecond") + } + assert(e2.message.contains("No interval can be constructed")) } test("SPARK-8945: add and subtract expressions for interval type") { @@ -1655,12 +1688,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("aggregation with codegen updates peak execution memory") { - withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { - testCodeGen( - "SELECT key, count(value) FROM testData GROUP BY key", - (1 to 100).map(i => Row(i, 1))) - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) } } @@ -1683,15 +1714,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10215 Div of Decimal returns null") { - val d = Decimal(1.12321) + val d = Decimal(1.12321).toBigDecimal val df = Seq((d, 1)).toDF("a", "b") checkAnswer( df.selectExpr("b * a / b"), - Seq(Row(d.toBigDecimal))) + Seq(Row(d))) checkAnswer( df.selectExpr("b * a / b / b"), - Seq(Row(d.toBigDecimal))) + Seq(Row(d))) checkAnswer( df.selectExpr("b * a + b"), Seq(Row(BigDecimal(2.12321)))) @@ -1700,7 +1731,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal(0.12321)))) checkAnswer( df.selectExpr("b * a * b"), - Seq(Row(d.toBigDecimal))) + Seq(Row(d))) } test("precision smaller than scale") { @@ -1715,7 +1746,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("external sorting updates peak execution memory") { AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { - sortTest() + sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() } } @@ -1738,7 +1769,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .format("parquet") .save(path) - val message = intercept[AnalysisException] { + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { sqlContext.sql( s""" |CREATE TEMPORARY TABLE db.t @@ -1748,9 +1780,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |) """.stripMargin) }.getMessage - assert(message.contains("Specifying database name or other qualifiers are not allowed")) - // If you use backticks to quote the name of a temporary table having dot in it. + // If you use backticks to quote the name then it's OK. sqlContext.sql( s""" |CREATE TEMPORARY TABLE `db.t` @@ -1782,7 +1813,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("run sql directly on files") { - val df = sqlContext.range(100) + val df = sqlContext.range(100).toDF() withTempPath(f => { df.write.json(f.getCanonicalPath) checkAnswer(sql(s"select id from json.`${f.getCanonicalPath}`"), @@ -1796,27 +1827,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e1 = intercept[AnalysisException] { sql("select * from in_valid_table") } - assert(e1.message.contains("Table not found")) + assert(e1.message.contains("Table or View not found")) val e2 = intercept[AnalysisException] { - sql("select * from no_db.no_table") + sql("select * from no_db.no_table").show() } - assert(e2.message.contains("Table not found")) + assert(e2.message.contains("Table or View not found")) val e3 = intercept[AnalysisException] { sql("select * from json.invalid_file") } - assert(e3.message.contains("No input paths specified")) + assert(e3.message.contains("Path does not exist")) } test("SortMergeJoin returns wrong results when using UnsafeRows") { // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. // This bug will be triggered when Tungsten is enabled and there are multiple // SortMergeJoin operators executed in the same task. - val confs = - SQLConf.SORTMERGE_JOIN.key -> "true" :: - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: - SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + val confs = SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: Nil withSQLConf(confs: _*) { val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") val df2 = @@ -2001,4 +2029,443 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) } } + + test("Star Expansion - group by") { + withSQLConf("spark.sql.retainGroupColumns" -> "false") { + checkAnswer( + testData2.groupBy($"a", $"b").agg($"*"), + sql("SELECT * FROM testData2 group by a, b")) + } + } + + test("Common subexpression elimination") { + // TODO: support subexpression elimination in whole stage codegen + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + } + } + + test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { + // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the + // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of + // UNION was incorrectly considered non-nullable: + checkAnswer( + sql("""SELECT count(v) FROM ( + | SELECT v FROM ( + | SELECT 'foo' AS v UNION ALL + | SELECT NULL AS v + | ) my_union WHERE isnull(v) + |) my_subview""".stripMargin), + Seq(Row(0))) + } + + test("SPARK-10707: nullability should be correctly propagated through set operations (2)") { + // This test uses RAND() to stop column pruning for Union and checks the resulting isnull + // value. This would produce an incorrect result before the fix in SPARK-10707 because the "v" + // column of the union was considered non-nullable. + checkAnswer( + sql( + """ + |SELECT a FROM ( + | SELECT ISNULL(v) AS a, RAND() FROM ( + | SELECT 'foo' AS v UNION ALL SELECT null AS v + | ) my_union + |) my_view + """.stripMargin), + Row(false) :: Row(true) :: Nil) + } + + test("rollup") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" + + " order by course, year"), + Row(null, null, 113000.0) :: + Row("Java", null, 50000.0) :: + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("dotNET", null, 63000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: Nil + ) + } + + test("grouping sets when aggregate functions containing groupBy columns") { + checkAnswer( + sql("select course, sum(earnings) as sum from courseSales group by course, earnings " + + "grouping sets((), (course), (course, earnings)) " + + "order by course, sum"), + Row(null, 113000.0) :: + Row("Java", 20000.0) :: + Row("Java", 30000.0) :: + Row("Java", 50000.0) :: + Row("dotNET", 5000.0) :: + Row("dotNET", 10000.0) :: + Row("dotNET", 48000.0) :: + Row("dotNET", 63000.0) :: Nil + ) + + checkAnswer( + sql("select course, sum(earnings) as sum, grouping_id(course, earnings) from courseSales " + + "group by course, earnings grouping sets((), (course), (course, earnings)) " + + "order by course, sum"), + Row(null, 113000.0, 3) :: + Row("Java", 20000.0, 0) :: + Row("Java", 30000.0, 0) :: + Row("Java", 50000.0, 1) :: + Row("dotNET", 5000.0, 0) :: + Row("dotNET", 10000.0, 0) :: + Row("dotNET", 48000.0, 0) :: + Row("dotNET", 63000.0, 1) :: Nil + ) + } + + test("cube") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("grouping sets") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(course, year)"), + Row("Java", null, 50000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: Nil + ) + + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(course)"), + Row("Java", null, 50000.0) :: + Row("dotNET", null, 63000.0) :: Nil + ) + + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(year)"), + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: Nil + ) + } + + test("grouping and grouping_id") { + checkAnswer( + sql("select course, year, grouping(course), grouping(year), grouping_id(course, year)" + + " from courseSales group by cube(course, year)"), + Row("Java", 2012, 0, 0, 0) :: + Row("Java", 2013, 0, 0, 0) :: + Row("Java", null, 0, 1, 1) :: + Row("dotNET", 2012, 0, 0, 0) :: + Row("dotNET", 2013, 0, 0, 0) :: + Row("dotNET", null, 0, 1, 1) :: + Row(null, 2012, 1, 0, 2) :: + Row(null, 2013, 1, 0, 2) :: + Row(null, null, 1, 1, 3) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year, grouping(course) from courseSales group by course, year") + } + assert(error.getMessage contains "grouping() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year, grouping_id(course, year) from courseSales group by course, year") + } + assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year, grouping__id from courseSales group by cube(course, year)") + } + assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") + } + + test("grouping and grouping_id in having") { + checkAnswer( + sql("select course, year from courseSales group by cube(course, year)" + + " having grouping(year) = 1 and grouping_id(course, year) > 0"), + Row("Java", null) :: + Row("dotNET", null) :: + Row(null, null) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " having grouping(course) > 0") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " having grouping_id(course, year) > 0") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by cube(course, year)" + + " having grouping__id > 0") + } + assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") + } + + test("grouping and grouping_id in sort") { + checkAnswer( + sql("select course, year, grouping(course), grouping(year) from courseSales" + + " group by cube(course, year) order by grouping_id(course, year), course, year"), + Row("Java", 2012, 0, 0) :: + Row("Java", 2013, 0, 0) :: + Row("dotNET", 2012, 0, 0) :: + Row("dotNET", 2013, 0, 0) :: + Row("Java", null, 0, 1) :: + Row("dotNET", null, 0, 1) :: + Row(null, 2012, 1, 0) :: + Row(null, 2013, 1, 0) :: + Row(null, null, 1, 1) :: Nil + ) + + checkAnswer( + sql("select course, year, grouping_id(course, year) from courseSales" + + " group by cube(course, year) order by grouping(course), grouping(year), course, year"), + Row("Java", 2012, 0) :: + Row("Java", 2013, 0) :: + Row("dotNET", 2012, 0) :: + Row("dotNET", 2013, 0) :: + Row("Java", null, 1) :: + Row("dotNET", null, 1) :: + Row(null, 2012, 2) :: + Row(null, 2013, 2) :: + Row(null, null, 3) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " order by grouping(course)") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " order by grouping_id(course, year)") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by cube(course, year)" + + " order by grouping__id") + } + assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") + } + + test("filter on a grouping column that is not presented in SELECT") { + checkAnswer( + sql("select count(1) from (select 1 as a) t group by a having a > 0"), + Row(1) :: Nil) + } + + test("SPARK-13056: Null in map value causes NPE") { + val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") + withTempTable("maptest") { + df.registerTempTable("maptest") + // local optimization will by pass codegen code, so we should keep the filter `key=1` + checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring")) + checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null)) + } + } + + test("hash function") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + withTempTable("tbl") { + df.registerTempTable("tbl") + checkAnswer( + df.select(hash($"i", $"j")), + sql("SELECT hash(i, j) from tbl") + ) + } + } + + test("order by ordinal number") { + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 DESC"), + sql("SELECT * FROM testData2 ORDER BY a DESC")) + // If the position is not an integer, ignore it. + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY b ASC")) + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC")) + checkAnswer( + sql("SELECT * FROM testData2 SORT BY 1 DESC, 2"), + sql("SELECT * FROM testData2 SORT BY a DESC, b ASC")) + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"), + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) + } + + test("order by ordinal number - negative cases") { + intercept[UnresolvedException[SortOrder]] { + sql("SELECT * FROM testData2 ORDER BY 0") + } + intercept[UnresolvedException[SortOrder]] { + sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC") + } + intercept[UnresolvedException[SortOrder]] { + sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC") + } + } + + test("order by ordinal number with conf spark.sql.orderByOrdinal=false") { + withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") { + // If spark.sql.orderByOrdinal=false, ignore the position number. + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY b ASC")) + } + } + + test("natural join") { + val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1") + val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2") + withTempTable("nt1", "nt2") { + df1.registerTempTable("nt1") + df2.registerTempTable("nt2") + checkAnswer( + sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), + Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) + + checkAnswer( + sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), + Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) + + checkAnswer( + sql("SELECT count(*) FROM nt1 natural full outer join nt2"), + Row(4) :: Nil) + } + } + + test("join with using clause") { + val df1 = Seq(("r1c1", "r1c2", "t1r1c3"), + ("r2c1", "r2c2", "t1r2c3"), ("r3c1x", "r3c2", "t1r3c3")).toDF("c1", "c2", "c3") + val df2 = Seq(("r1c1", "r1c2", "t2r1c3"), + ("r2c1", "r2c2", "t2r2c3"), ("r3c1y", "r3c2", "t2r3c3")).toDF("c1", "c2", "c3") + val df3 = Seq((null, "r1c2", "t3r1c3"), + ("r2c1", "r2c2", "t3r2c3"), ("r3c1y", "r3c2", "t3r3c3")).toDF("c1", "c2", "c3") + withTempTable("t1", "t2", "t3") { + df1.registerTempTable("t1") + df2.registerTempTable("t2") + df3.registerTempTable("t3") + // inner join with one using column + checkAnswer( + sql("SELECT * FROM t1 join t2 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: Nil) + + // inner join with two using columns + checkAnswer( + sql("SELECT * FROM t1 join t2 using (c1, c2)"), + Row("r1c1", "r1c2", "t1r1c3", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "t2r2c3") :: Nil) + + // Left outer join with one using column. + checkAnswer( + sql("SELECT * FROM t1 left join t2 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: + Row("r3c1x", "r3c2", "t1r3c3", null, null) :: Nil) + + // Right outer join with one using column. + checkAnswer( + sql("SELECT * FROM t1 right join t2 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: + Row("r3c1y", null, null, "r3c2", "t2r3c3") :: Nil) + + // Full outer join with one using column. + checkAnswer( + sql("SELECT * FROM t1 full outer join t2 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: + Row("r3c1x", "r3c2", "t1r3c3", null, null) :: + Row("r3c1y", null, + null, "r3c2", "t2r3c3") :: Nil) + + // Full outer join with null value in join column. + checkAnswer( + sql("SELECT * FROM t1 full outer join t3 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", null, null) :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t3r2c3") :: + Row("r3c1x", "r3c2", "t1r3c3", null, null) :: + Row("r3c1y", null, null, "r3c2", "t3r3c3") :: + Row(null, null, null, "r1c2", "t3r1c3") :: Nil) + + // Self join with using columns. + checkAnswer( + sql("SELECT * FROM t1 join t1 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t1r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t1r2c3") :: + Row("r3c1x", "r3c2", "t1r3c3", "r3c2", "t1r3c3") :: Nil) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala new file mode 100644 index 0000000000000..6ccc99fe179d7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -0,0 +1,514 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.Thread.UncaughtExceptionHandler + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.language.experimental.macros +import scala.reflect.ClassTag +import scala.util.Random +import scala.util.control.NonFatal + +import org.scalatest.Assertions +import org.scalatest.concurrent.{Eventually, Timeouts} +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.Span +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.util.Utils + +/** + * A framework for implementing tests for streaming queries and sources. + * + * A test consists of a set of steps (expressed as a `StreamAction`) that are executed in order, + * blocking as necessary to let the stream catch up. For example, the following adds some data to + * a stream, blocking until it can verify that the correct values are eventually produced. + * + * {{{ + * val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(_ + 1) + + testStream(mapped)( + AddData(inputData, 1, 2, 3), + CheckAnswer(2, 3, 4)) + * }}} + * + * Note that while we do sleep to allow the other thread to progress without spinning, + * `StreamAction` checks should not depend on the amount of time spent sleeping. Instead they + * should check the actual progress of the stream before verifying the required test condition. + * + * Currently it is assumed that all streaming queries will eventually complete in 10 seconds to + * avoid hanging forever in the case of failures. However, individual suites can change this + * by overriding `streamingTimeout`. + */ +trait StreamTest extends QueryTest with Timeouts { + + implicit class RichSource(s: Source) { + def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingExecutionRelation(s)) + + def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingExecutionRelation(s)) + } + + /** How long to wait for an active stream to catch up when checking a result. */ + val streamingTimeout = 10.seconds + + /** A trait for actions that can be performed while testing a streaming DataFrame. */ + trait StreamAction + + /** A trait to mark actions that require the stream to be actively running. */ + trait StreamMustBeRunning + + /** + * Adds the given data to the stream. Subsequent check answers will block until this data has + * been processed. + */ + object AddData { + def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + AddDataMemory(source, data) + } + + /** A trait that can be extended when testing other sources. */ + trait AddData extends StreamAction { + def source: Source + + /** + * Called to trigger adding the data. Should return the offset that will denote when this + * new data has been processed. + */ + def addData(): Offset + } + + case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + override def toString: String = s"AddData to $source: ${data.mkString(",")}" + + override def addData(): Offset = { + source.addData(data) + } + } + + /** + * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. + * This operation automatically blocks until all added data has been processed. + */ + object CheckAnswer { + def apply[A : Encoder](data: A*): CheckAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false) + } + + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false) + } + + /** + * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. + * This operation automatically blocks until all added data has been processed. + */ + object CheckLastBatch { + def apply[A : Encoder](data: A*): CheckAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true) + } + + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true) + } + + case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + } + + /** Stops the stream. It must currently be running. */ + case object StopStream extends StreamAction with StreamMustBeRunning + + /** Starts the stream, resuming if data has already been processed. It must not be running. */ + case object StartStream extends StreamAction + + /** Signals that a failure is expected and should not kill the test. */ + case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction { + val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] + override def toString(): String = s"ExpectFailure[${causeClass.getCanonicalName}]" + } + + /** Assert that a body is true */ + class Assert(condition: => Boolean, val message: String = "") extends StreamAction { + def run(): Unit = { Assertions.assert(condition) } + override def toString: String = s"Assert(, $message)" + } + + object Assert { + def apply(condition: => Boolean, message: String = ""): Assert = new Assert(condition, message) + def apply(message: String)(body: => Unit): Assert = new Assert( { body; true }, message) + def apply(body: => Unit): Assert = new Assert( { body; true }, "") + } + + /** Assert that a condition on the active query is true */ + class AssertOnQuery(val condition: StreamExecution => Boolean, val message: String) + extends StreamAction { + override def toString: String = s"AssertOnQuery(, $message)" + } + + object AssertOnQuery { + def apply(condition: StreamExecution => Boolean, message: String = ""): AssertOnQuery = { + new AssertOnQuery(condition, message) + } + + def apply(message: String)(condition: StreamExecution => Boolean): AssertOnQuery = { + new AssertOnQuery(condition, message) + } + } + + /** + * Executes the specified actions on the given streaming DataFrame and provides helpful + * error messages in the case of failures or incorrect answers. + * + * Note that if the stream is not explicitly started before an action that requires it to be + * running then it will be automatically started before performing any other actions. + */ + def testStream(_stream: Dataset[_])(actions: StreamAction*): Unit = { + val stream = _stream.toDF() + var pos = 0 + var currentPlan: LogicalPlan = stream.logicalPlan + var currentStream: StreamExecution = null + var lastStream: StreamExecution = null + val awaiting = new mutable.HashMap[Source, Offset]() + val sink = new MemorySink(stream.schema) + + @volatile + var streamDeathCause: Throwable = null + + // If the test doesn't manually start the stream, we do it automatically at the beginning. + val startedManually = + actions.takeWhile(!_.isInstanceOf[StreamMustBeRunning]).contains(StartStream) + val startedTest = if (startedManually) actions else StartStream +: actions + + def testActions = actions.zipWithIndex.map { + case (a, i) => + if ((pos == i && startedManually) || (pos == (i + 1) && !startedManually)) { + "=> " + a.toString + } else { + " " + a.toString + } + }.mkString("\n") + + def currentOffsets = + if (currentStream != null) currentStream.committedOffsets.toString else "not started" + + def threadState = + if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + + def testState = + s""" + |== Progress == + |$testActions + | + |== Stream == + |Stream state: $currentOffsets + |Thread state: $threadState + |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + | + |== Sink == + |${sink.toDebugString} + | + |== Plan == + |${if (currentStream != null) currentStream.lastExecution else ""} + """.stripMargin + + def verify(condition: => Boolean, message: String): Unit = { + if (!condition) { + failTest(message) + } + } + + def eventually[T](message: String)(func: => T): T = { + try { + Eventually.eventually(Timeout(streamingTimeout)) { + func + } + } catch { + case NonFatal(e) => + failTest(message, e) + } + } + + def failTest(message: String, cause: Throwable = null) = { + + // Recursively pretty print a exception with truncated stacktrace and internal cause + def exceptionToString(e: Throwable, prefix: String = ""): String = { + val base = s"$prefix${e.getMessage}" + + e.getStackTrace.take(10).mkString(s"\n$prefix", s"\n$prefix\t", "\n") + if (e.getCause != null) { + base + s"\n$prefix\tCaused by: " + exceptionToString(e.getCause, s"$prefix\t") + } else { + base + } + } + val c = Option(cause).map(exceptionToString(_)) + val m = if (message != null && message.size > 0) Some(message) else None + fail( + s""" + |${(m ++ c).mkString(": ")} + |$testState + """.stripMargin) + } + + val testThread = Thread.currentThread() + val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + + try { + startedTest.foreach { action => + action match { + case StartStream => + verify(currentStream == null, "stream already running") + lastStream = currentStream + currentStream = + sqlContext + .streams + .startQuery( + StreamExecution.nextName, + metadataRoot, + stream, + sink) + .asInstanceOf[StreamExecution] + currentStream.microBatchThread.setUncaughtExceptionHandler( + new UncaughtExceptionHandler { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + streamDeathCause = e + testThread.interrupt() + } + }) + + case StopStream => + verify(currentStream != null, "can not stop a stream that is not running") + try failAfter(streamingTimeout) { + currentStream.stop() + verify(!currentStream.microBatchThread.isAlive, + s"microbatch thread not stopped") + verify(!currentStream.isActive, + "query.isActive() is false even after stopping") + verify(currentStream.exception.isEmpty, + s"query.exception() is not empty after clean stop: " + + currentStream.exception.map(_.toString()).getOrElse("")) + } catch { + case _: InterruptedException => + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while stopping and waiting for microbatchthread to terminate.") + case t: Throwable => + failTest("Error while stopping stream", t) + } finally { + lastStream = currentStream + currentStream = null + } + + case ef: ExpectFailure[_] => + verify(currentStream != null, "can not expect failure when stream is not running") + try failAfter(streamingTimeout) { + val thrownException = intercept[ContinuousQueryException] { + currentStream.awaitTermination() + } + eventually("microbatch thread not stopped after termination with failure") { + assert(!currentStream.microBatchThread.isAlive) + } + verify(thrownException.query.eq(currentStream), + s"incorrect query reference in exception") + verify(currentStream.exception === Some(thrownException), + s"incorrect exception returned by query.exception()") + + val exception = currentStream.exception.get + verify(exception.cause.getClass === ef.causeClass, + "incorrect cause in exception returned by query.exception()\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") + } catch { + case _: InterruptedException => + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while waiting for failure") + case t: Throwable => + failTest("Error while checking stream failure", t) + } finally { + lastStream = currentStream + currentStream = null + streamDeathCause = null + } + + case a: AssertOnQuery => + verify(currentStream != null || lastStream != null, + "cannot assert when not stream has been started") + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + + case a: Assert => + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify({ a.run(); true }, s"Assert failed: ${a.message}") + + case a: AddData => + awaiting.put(a.source, a.addData()) + + case CheckAnswerRows(expectedAnswer, lastOnly) => + verify(currentStream != null, "stream not running") + + // Block until all data added has been processed + awaiting.foreach { case (source, offset) => + failAfter(streamingTimeout) { + currentStream.awaitOffset(source, offset) + } + } + + val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch { + case e: Exception => + failTest("Exception while getting data from sink", e) + } + + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } + } + pos += 1 + } + } catch { + case _: InterruptedException if streamDeathCause != null => + failTest("Stream Thread Died") + case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out waiting for stream") + } finally { + if (currentStream != null && currentStream.microBatchThread.isAlive) { + currentStream.stop() + } + } + } + + /** + * Creates a stress test that randomly starts/stops/adds data/checks the result. + * + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result. + * @param addData and add data action that adds the given numbers to the stream, encoding them + * as needed + */ + def runStressTest( + ds: Dataset[Int], + addData: Seq[Int] => StreamAction, + iterations: Int = 100): Unit = { + implicit val intEncoder = ExpressionEncoder[Int]() + var dataPos = 0 + var running = true + val actions = new ArrayBuffer[StreamAction]() + + def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) } + + def addRandomData() = { + val numItems = Random.nextInt(10) + val data = dataPos until (dataPos + numItems) + dataPos += numItems + actions += addData(data) + } + + (1 to iterations).foreach { i => + val rand = Random.nextDouble() + if(!running) { + rand match { + case r if r < 0.7 => // AddData + addRandomData() + + case _ => // StartStream + actions += StartStream + running = true + } + } else { + rand match { + case r if r < 0.1 => + addCheck() + + case r if r < 0.7 => // AddData + addRandomData() + + case _ => // StopStream + addCheck() + actions += StopStream + running = false + } + } + } + if(!running) { actions += StartStream } + addCheck() + testStream(ds)(actions: _*) + } + + + object AwaitTerminationTester { + + trait ExpectedBehavior + + /** Expect awaitTermination to not be blocked */ + case object ExpectNotBlocked extends ExpectedBehavior + + /** Expect awaitTermination to get blocked */ + case object ExpectBlocked extends ExpectedBehavior + + /** Expect awaitTermination to throw an exception */ + case class ExpectException[E <: Exception]()(implicit val t: ClassTag[E]) + extends ExpectedBehavior + + private val DEFAULT_TEST_TIMEOUT = 1 second + + def test( + expectedBehavior: ExpectedBehavior, + awaitTermFunc: () => Unit, + testTimeout: Span = DEFAULT_TEST_TIMEOUT + ): Unit = { + + expectedBehavior match { + case ExpectNotBlocked => + withClue("Got blocked when expected non-blocking.") { + failAfter(testTimeout) { + awaitTermFunc() + } + } + + case ExpectBlocked => + withClue("Was not blocked when expected.") { + intercept[TestFailedDueToTimeoutException] { + failAfter(testTimeout) { + awaitTermFunc() + } + } + } + + case e: ExpectException[_] => + val thrownException = + withClue(s"Did not throw ${e.t.runtimeClass.getSimpleName} when expected.") { + intercept[ContinuousQueryException] { + failAfter(testTimeout) { + awaitTermFunc() + } + } + } + assert(thrownException.cause.getClass === e.t.runtimeClass, + "exception of incorrect type was throw") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index e2090b0a83ce7..6809f26968836 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -272,12 +272,12 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("initcap function") { - val df = Seq(("ab", "a B")).toDF("l", "r") + val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z") checkAnswer( - df.select(initcap($"l"), initcap($"r")), Row("Ab", "A B")) + df.select(initcap($"x"), initcap($"y"), initcap($"z")), Row("Ab", "A B", "Spark")) checkAnswer( - df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B")) + df.selectExpr("InitCap(x)", "InitCap(y)", "InitCap(z)"), Row("Ab", "A B", "Spark")) } test("number format function") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala new file mode 100644 index 0000000000000..21b19fe7df8b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class SubquerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("simple uncorrelated scalar subquery") { + assertResult(Array(Row(1))) { + sql("select (select 1 as b) as b").collect() + } + + assertResult(Array(Row(3))) { + sql("select (select (select 1) + 1) + 1").collect() + } + + // string type + assertResult(Array(Row("s"))) { + sql("select (select 's' as s) as b").collect() + } + } + + test("uncorrelated scalar subquery in CTE") { + assertResult(Array(Row(1))) { + sql("with t2 as (select 1 as b, 2 as c) " + + "select a from (select 1 as a union all select 2 as a) t " + + "where a = (select max(b) from t2) ").collect() + } + } + + test("uncorrelated scalar subquery should return null if there is 0 rows") { + assertResult(Array(Row(null))) { + sql("select (select 's' as s limit 0) as b").collect() + } + } + + test("runtime error when the number of rows is greater than 1") { + val error2 = intercept[RuntimeException] { + sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect() + } + assert(error2.getMessage.contains( + "more than one row returned by a subquery used as an expression")) + } + + test("uncorrelated scalar subquery on a DataFrame generated query") { + val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value") + df.registerTempTable("subqueryData") + + assertResult(Array(Row(4))) { + sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1").collect() + } + + assertResult(Array(Row(-3))) { + sql("select -(select max(key) from subqueryData)").collect() + } + + assertResult(Array(Row(null))) { + sql("select (select value from subqueryData limit 0)").collect() + } + + assertResult(Array(Row("two"))) { + sql("select (select min(value) from subqueryData" + + " where key = (select max(key) from subqueryData) - 1)").collect() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 6e510f0b8aff4..ec950332c5f63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.ScalaUDF -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -85,7 +83,8 @@ class UDFSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } - assert(e.getMessage.contains("undefined function")) + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("a_function_that_does_not_exist")) } test("Simple UDF") { @@ -194,106 +193,59 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } - private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = { - val udfs = df.queryExecution.optimizedPlan.collect { - case p: logical.Project => p.projectList.flatMap { - case e => e.collect { - case udf: ScalaUDF => udf - } - } - }.flatten - assert(udfs.length === expectedNumUDFs) - } - - test("foldable udf") { - import org.apache.spark.sql.functions._ - - val myUDF = udf((x: Int) => x + 1) - - { - val df = sql("SELECT 1 as a") - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 0) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("nondeterministic udf: using UDFRegistration") { - import org.apache.spark.sql.functions._ - - val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1) - sqlContext.udf.register("plusOne2", myUDF.nondeterministic) - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), callUDF("plusOne1", col("a")).as("b")) - .select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), callUDF("plusOne2", col("a")).as("b")) - .select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("nondeterministic udf: using udf function") { - import org.apache.spark.sql.functions._ - - val myUDF = udf((x: Int) => x + 1) - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - - { - // nondeterministicUDF will not be foldable. - val df = sql("SELECT 1 as a") - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("override a registered udf") { - sqlContext.udf.register("intExpected", (x: Int) => x) - assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) - - sqlContext.udf.register("intExpected", (x: Int) => x + 1) - assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2) + test("udf in different types") { + sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) + sqlContext.udf.register("decimalDataFunc", + (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) }) + sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) + sqlContext.udf.register("arrayDataFunc", + (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) }) + sqlContext.udf.register("mapDataFunc", + (data: scala.collection.Map[Int, String]) => { data }) + sqlContext.udf.register("complexDataFunc", + (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } ) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(key, value) AS t from testData) tmp").toDF(), + testData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT decimalDataFunc(a, b) AS t FROM decimalData) tmp + """.stripMargin).toDF(), decimalData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT binaryDataFunc(a, b) AS t FROM binaryData) tmp + """.stripMargin).toDF(), binaryData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT arrayDataFunc(data, nestedData) AS t FROM arrayData) tmp + """.stripMargin).toDF(), arrayData.toDF()) + checkAnswer( + sql(""" + | SELECT mapDataFunc(data) AS t FROM mapData + """.stripMargin).toDF(), mapData.toDF()) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp + """.stripMargin).toDF(), complexData.select("m", "a", "b")) + } + + test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { + val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + + // Without the fix, this will fail because we fail to cast data type of b to string + // because myUDF does not know its input data type. With the fix, this query should not + // fail. + checkAnswer( + testData2.select(myUDF($"a", $"b").as("t")), + testData2.selectExpr("struct(a, b)")) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) tmp").toDF(), + testData2) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 00f1526576cc5..a32763db054f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -34,8 +34,8 @@ class UnsafeRowSuite extends SparkFunSuite { test("UnsafeRow Java serialization") { // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data val data = new Array[Byte](1024) - val row = new UnsafeRow - row.pointTo(data, 1, 16) + val row = new UnsafeRow(1) + row.pointTo(data, 16) row.setLong(0, 19285) val ser = new JavaSerializer(new SparkConf).newInstance() @@ -47,8 +47,8 @@ class UnsafeRowSuite extends SparkFunSuite { test("UnsafeRow Kryo serialization") { // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data val data = new Array[Byte](1024) - val row = new UnsafeRow - row.pointTo(data, 1, 16) + val row = new UnsafeRow(1) + row.pointTo(data, 16) row.setLong(0, 19285) val ser = new KryoSerializer(new SparkConf).newInstance() @@ -86,11 +86,10 @@ class UnsafeRowSuite extends SparkFunSuite { offheapRowPage.getBaseOffset, arrayBackedUnsafeRow.getSizeInBytes ) - val offheapUnsafeRow: UnsafeRow = new UnsafeRow() + val offheapUnsafeRow: UnsafeRow = new UnsafeRow(3) offheapUnsafeRow.pointTo( offheapRowPage.getBaseObject, offheapRowPage.getBaseOffset, - 3, // num fields arrayBackedUnsafeRow.getSizeInBytes ) assert(offheapUnsafeRow.getBaseObject === null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a229e5814df89..8c4afb605b01f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,22 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} - import scala.beans.{BeanInfo, BeanProperty} -import com.clearspring.analytics.stream.cardinality.HyperLogLog - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils -import org.apache.spark.util.collection.OpenHashSet - @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { @@ -52,11 +45,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): ArrayData = { - obj match { - case features: MyDenseVector => - new GenericArrayData(features.data.map(_.asInstanceOf[Any])) - } + override def serialize(features: MyDenseVector): ArrayData = { + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) } override def deserialize(datum: Any): MyDenseVector = { @@ -69,6 +59,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] private[spark] override def asNullable: MyDenseVectorUDT = this + + override def equals(other: Any): Boolean = other match { + case _: MyDenseVectorUDT => true + case _ => false + } } class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { @@ -134,25 +129,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } - test("HyperLogLogUDT") { - val hyperLogLogUDT = HyperLogLogUDT - val hyperLogLog = new HyperLogLog(0.4) - (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) - - val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) - assert(actual.cardinality() === hyperLogLog.cardinality()) - assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) - } - - test("OpenHashSetUDT") { - val openHashSetUDT = new OpenHashSetUDT(IntegerType) - val set = new OpenHashSet[Int] - (1 to 10).foreach(i => set.add(i)) - - val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) - assert(actual.iterator.toSet === set.iterator.toSet) - } - test("UDTs with JSON") { val data = Seq( "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}", @@ -176,7 +152,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT test("SPARK-10472 UserDefinedType.typeName") { assert(IntegerType.typeName === "integer") assert(new MyDenseVectorUDT().typeName === "mydensevector") - assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") } test("Catalyst type converter null handling for UDTs") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala new file mode 100644 index 0000000000000..d23f19c480633 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -0,0 +1,684 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.HashMap + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.AggregateHashMap +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{IntegerType, LongType, StructType} +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure whole stage codegen performance. + * To run this: + * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" + */ +class BenchmarkWholeStageCodegen extends SparkFunSuite { + lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + .set("spark.sql.shuffle.partitions", "1") + .set("spark.sql.autoBroadcastJoinThreshold", "1") + lazy val sc = SparkContext.getOrCreate(conf) + lazy val sqlContext = SQLContext.getOrCreate(sc) + + def runBenchmark(name: String, values: Long)(f: => Unit): Unit = { + val benchmark = new Benchmark(name, values) + + Seq(false, true).foreach { enabled => + benchmark.addCase(s"$name codegen=$enabled") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString) + f + } + } + + benchmark.run() + } + + // These benchmark are skipped in normal build + ignore("range/filter/sum") { + val N = 500L << 20 + runBenchmark("rang/filter/sum", N) { + sqlContext.range(N).filter("(id & 1) = 1").groupBy().sum().collect() + } + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X + rang/filter/sum codegen=true 897 / 1022 584.6 1.7 16.4X + */ + } + + ignore("range/limit/sum") { + val N = 500L << 20 + runBenchmark("range/limit/sum", N) { + sqlContext.range(N).limit(1000000).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X + range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X + */ + } + + ignore("range/sample/sum") { + val N = 500 << 20 + runBenchmark("range/sample/sum", N) { + sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/sample/sum codegen=false 53888 / 56592 9.7 102.8 1.0X + range/sample/sum codegen=true 41614 / 42607 12.6 79.4 1.3X + */ + + runBenchmark("range/sample/sum", N) { + sqlContext.range(N).sample(false, 0.01).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/sample/sum codegen=false 12982 / 13384 40.4 24.8 1.0X + range/sample/sum codegen=true 7074 / 7383 74.1 13.5 1.8X + */ + } + + ignore("stat functions") { + val N = 100L << 20 + + runBenchmark("stddev", N) { + sqlContext.range(N).groupBy().agg("id" -> "stddev").collect() + } + + runBenchmark("kurtosis", N) { + sqlContext.range(N).groupBy().agg("id" -> "kurtosis").collect() + } + + + /** + Using ImperativeAggregate (as implemented in Spark 1.6): + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + stddev w/o codegen 2019.04 10.39 1.00 X + stddev w codegen 2097.29 10.00 0.96 X + kurtosis w/o codegen 2108.99 9.94 0.96 X + kurtosis w codegen 2090.69 10.03 0.97 X + + Using DeclarativeAggregate: + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + stddev codegen=false 5630 / 5776 18.0 55.6 1.0X + stddev codegen=true 1259 / 1314 83.0 12.0 4.5X + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + kurtosis codegen=false 14847 / 15084 7.0 142.9 1.0X + kurtosis codegen=true 1652 / 2124 63.0 15.9 9.0X + */ + } + + ignore("aggregate with keys") { + val N = 20 << 20 + + val benchmark = new Benchmark("Aggregate w keys", N) + def f(): Unit = sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + + benchmark.addCase(s"codegen = F") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = F") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = T") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "true") + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + codegen = F 2219 / 2392 9.4 105.8 1.0X + codegen = T hashmap = F 1330 / 1466 15.8 63.4 1.7X + codegen = T hashmap = T 384 / 518 54.7 18.3 5.8X + */ + } + + ignore("broadcast hash join") { + val N = 20 << 20 + val M = 1 << 16 + val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) + + runBenchmark("Join w long", N) { + sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X + Join w long codegen=true 321 / 371 65.3 15.3 9.3X + */ + + runBenchmark("Join w long duplicated", N) { + val dim = broadcast(sqlContext.range(M).selectExpr("cast(id/10 as long) as k")) + sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count() + } + + /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X + Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X + */ + + val dim2 = broadcast(sqlContext.range(M) + .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v")) + + runBenchmark("Join w 2 ints", N) { + sqlContext.range(N).join(dim2, + (col("id") % M).cast(IntegerType) === col("k1") + && (col("id") % M).cast(IntegerType) === col("k2")).count() + } + + /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X + Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X + */ + + val dim3 = broadcast(sqlContext.range(M) + .selectExpr("id as k1", "id as k2", "cast(id as string) as v")) + + runBenchmark("Join w 2 longs", N) { + sqlContext.range(N).join(dim3, + (col("id") % M) === col("k1") && (col("id") % M) === col("k2")) + .count() + } + + /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w 2 longs codegen=false 5905 / 6123 3.6 281.6 1.0X + Join w 2 longs codegen=true 2230 / 2529 9.4 106.3 2.6X + */ + + val dim4 = broadcast(sqlContext.range(M) + .selectExpr("cast(id/10 as long) as k1", "cast(id/10 as long) as k2")) + + runBenchmark("Join w 2 longs duplicated", N) { + sqlContext.range(N).join(dim4, + (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) + .count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w 2 longs duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w 2 longs duplicated codegen=false 6420 / 6587 3.3 306.1 1.0X + Join w 2 longs duplicated codegen=true 2080 / 2139 10.1 99.2 3.1X + */ + + runBenchmark("outer join w long", N) { + sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "left").count() + } + + /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X + outer join w long codegen=true 261 / 276 80.5 12.4 11.7X + */ + + runBenchmark("semi join w long", N) { + sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count() + } + + /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X + semi join w long codegen=true 237 / 244 88.3 11.3 8.1X + */ + } + + ignore("sort merge join") { + val N = 2 << 20 + runBenchmark("merge join", N) { + val df1 = sqlContext.range(N).selectExpr(s"id * 2 as k1") + val df2 = sqlContext.range(N).selectExpr(s"id * 3 as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + merge join codegen=false 1588 / 1880 1.3 757.1 1.0X + merge join codegen=true 1477 / 1531 1.4 704.2 1.1X + */ + + runBenchmark("sort merge join", N) { + val df1 = sqlContext.range(N) + .selectExpr(s"(id * 15485863) % ${N*10} as k1") + val df2 = sqlContext.range(N) + .selectExpr(s"(id * 15485867) % ${N*10} as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + sort merge join codegen=false 3626 / 3667 0.6 1728.9 1.0X + sort merge join codegen=true 3405 / 3438 0.6 1623.8 1.1X + */ + } + + ignore("shuffle hash join") { + val N = 4 << 20 + sqlContext.setConf("spark.sql.shuffle.partitions", "2") + sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000") + sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false") + runBenchmark("shuffle hash join", N) { + val df1 = sqlContext.range(N).selectExpr(s"id as k1") + val df2 = sqlContext.range(N / 5).selectExpr(s"id * 3 as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X + shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X + */ + } + + ignore("cube") { + val N = 5 << 20 + + runBenchmark("cube", N) { + sqlContext.range(N).selectExpr("id", "id % 1000 as k1", "id & 256 as k2") + .cube("k1", "k2").sum("id").collect() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + cube: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + cube codegen=false 3188 / 3392 1.6 608.2 1.0X + cube codegen=true 1239 / 1394 4.2 236.3 2.6X + */ + } + + ignore("hash and BytesToBytesMap") { + val N = 20 << 20 + + val benchmark = new Benchmark("BytesToBytesMap", N) + + benchmark.addCase("UnsafeRowhash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < N) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashUnsafeWords( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42) + s += h + i += 1 + } + } + + benchmark.addCase("murmur3 hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = Murmur3_x86_32.hashLong(i, 42) + key.setInt(0, h) + s += h + i += 1 + } + } + + benchmark.addCase("fast hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = i % p + if (h < 0) { + h += p + } + key.setInt(0, h) + s += h + i += 1 + } + } + + benchmark.addCase("arrayEqual") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + var s = 0 + while (i < N) { + key.setInt(0, i % 1000) + if (key.equals(value)) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (Long)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + map.put(i.toLong, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.get(i % 100000) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (two ints) ") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + val key = (i.toLong << 32) + Integer.rotateRight(i, 15) + map.put(key, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (UnsafeRow)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[UnsafeRow, UnsafeRow]() + while (i < 65536) { + key.setInt(0, i) + value.setInt(0, i) + map.put(key, value.copy()) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + key.setInt(0, i % 100000) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + + Seq(false, true).foreach { optimized => + benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new LongToUnsafeRowMap(taskMemoryManager, 64) + while (i < 65536) { + value.setInt(0, i) + val key = i % 100000 + map.append(key, value) + i += 1 + } + if (optimized) { + map.optimize() + } + var s = 0 + i = 0 + while (i < N) { + val key = i % 100000 + if (map.getValue(key, value) != null) { + s += 1 + } + i += 1 + } + } + } + + Seq("off", "on").foreach { heap => + benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") + .set("spark.memory.offHeap.size", "102400000"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var i = 0 + val numKeys = 65536 + while (i < numKeys) { + key.setInt(0, i % 65536) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 65536, 42)) + if (!loc.isDefined) { + loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + i += 1 + } + i = 0 + var s = 0 + while (i < N) { + key.setInt(0, i % 100000) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 100000, 42)) + if (loc.isDefined) { + s += 1 + } + i += 1 + } + } + } + + benchmark.addCase("Aggregate HashMap") { iter => + var i = 0 + val numKeys = 65536 + val schema = new StructType() + .add("key", LongType) + .add("value", LongType) + val map = new AggregateHashMap(schema) + while (i < numKeys) { + val row = map.findOrInsert(i.toLong) + row.setLong(1, row.getLong(1) + 1) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.find(i % 100000) != -1) { + s += 1 + } + i += 1 + } + } + + /** + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + UnsafeRow hash 267 / 284 78.4 12.8 1.0X + murmur3 hash 102 / 129 205.5 4.9 2.6X + fast hash 79 / 96 263.8 3.8 3.4X + arrayEqual 164 / 172 128.2 7.8 1.6X + Java HashMap (Long) 321 / 399 65.4 15.3 0.8X + Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X + Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X + LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X + LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X + BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X + BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X + Aggregate HashMap 121 / 131 173.3 5.8 2.2X + */ + benchmark.run() + } + + ignore("collect") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect", N) + benchmark.addCase("collect 1 million") { iter => + sqlContext.range(N).collect() + } + benchmark.addCase("collect 2 millions") { iter => + sqlContext.range(N * 2).collect() + } + benchmark.addCase("collect 4 millions") { iter => + sqlContext.range(N * 4).collect() + } + benchmark.run() + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + collect 1 million 439 / 654 2.4 418.7 1.0X + collect 2 millions 961 / 1907 1.1 916.4 0.5X + collect 4 millions 3193 / 3895 0.3 3044.7 0.1X + */ + } + + ignore("collect limit") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect limit", N) + benchmark.addCase("collect limit 1 million") { iter => + sqlContext.range(N * 4).limit(N).collect() + } + benchmark.addCase("collect limit 2 millions") { iter => + sqlContext.range(N * 4).limit(N * 2).collect() + } + benchmark.run() + + /** + model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) + collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + collect limit 1 million 833 / 1284 1.3 794.4 1.0X + collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X + */ + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 25f2f5caeed15..01d485ce2d713 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql._ -import org.apache.spark.{SparkFunSuite, SparkContext, SparkConf, MapOutputStatistics} class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -30,11 +32,11 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalInstantiatedSQLContext: Option[SQLContext] = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActiveContextOption() + originalActiveSQLContext = SQLContext.getActive() originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() } override protected def afterAll(): Unit = { @@ -260,6 +262,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { .set("spark.driver.allowMultipleContexts", "true") .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") .set( SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, targetNumPostShufflePartitions.toString) @@ -296,13 +299,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = agg.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -310,7 +313,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -318,7 +321,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 1536, minNumPostShufflePartitions) + withSQLContext(test, 2000, minNumPostShufflePartitions) } test(s"determining the number of reducers: join operator$testNameNote") { @@ -339,7 +342,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { sqlContext .range(0, 1000) .selectExpr("id % 500 as key", "id as value") - .unionAll(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + .union(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value")) checkAnswer( join, expectedAnswer.collect()) @@ -347,13 +350,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -361,7 +364,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -403,13 +406,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 4) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -421,7 +424,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 6144, minNumPostShufflePartitions) + withSQLContext(test, 6644, minNumPostShufflePartitions) } test(s"determining the number of reducers: complex query 2$testNameNote") { @@ -455,13 +458,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 3) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 911d12e93e503..17f2343cf971e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -18,18 +18,87 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} +import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange} +import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder + import testImplicits._ test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), + plan => ShuffleExchange(SinglePartition, plan), input.map(Row.fromTuple) ) } + + test("compatible BroadcastMode") { + val mode1 = IdentityBroadcastMode + val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) + val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) + + assert(mode1.compatibleWith(mode1)) + assert(!mode1.compatibleWith(mode2)) + assert(!mode2.compatibleWith(mode1)) + assert(mode2.compatibleWith(mode2)) + assert(!mode2.compatibleWith(mode3)) + assert(mode3.compatibleWith(mode3)) + } + + test("BroadcastExchange same result") { + val df = sqlContext.range(10) + val plan = df.queryExecution.executedPlan + val output = plan.output + assert(plan sameResult plan) + + val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan) + val hashMode = HashedRelationBroadcastMode(output) + val exchange2 = BroadcastExchange(hashMode, plan) + val hashMode2 = + HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) + val exchange3 = BroadcastExchange(hashMode2, plan) + val exchange4 = ReusedExchange(output, exchange3) + + assert(exchange1 sameResult exchange1) + assert(exchange2 sameResult exchange2) + assert(exchange3 sameResult exchange3) + assert(exchange4 sameResult exchange4) + + assert(!exchange1.sameResult(exchange2)) + assert(!exchange2.sameResult(exchange3)) + assert(!exchange3.sameResult(exchange4)) + assert(exchange4 sameResult exchange3) + } + + test("ShuffleExchange same result") { + val df = sqlContext.range(10) + val plan = df.queryExecution.executedPlan + val output = plan.output + assert(plan sameResult plan) + + val part1 = HashPartitioning(output, 1) + val exchange1 = ShuffleExchange(part1, plan) + val exchange2 = ShuffleExchange(part1, plan) + val part2 = HashPartitioning(output, 2) + val exchange3 = ShuffleExchange(part2, plan) + val part3 = HashPartitioning(output ++ output, 2) + val exchange4 = ShuffleExchange(part3, plan) + val exchange5 = ReusedExchange(output, exchange4) + + assert(exchange1 sameResult exchange1) + assert(exchange2 sameResult exchange2) + assert(exchange3 sameResult exchange3) + assert(exchange4 sameResult exchange4) + assert(exchange5 sameResult exchange5) + + assert(exchange1 sameResult exchange2) + assert(!exchange2.sameResult(exchange3)) + assert(!exchange3.sameResult(exchange4)) + assert(!exchange4.sameResult(exchange5)) + assert(exchange5 sameResult exchange4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index e7a08481cfa80..6f10e4b80577a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} class GroupedIteratorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2076c573b56c1..bdbcf842ca47d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,51 +18,41 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row, SQLConf} +import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ - class PlannerSuite extends SharedSQLContext { import testImplicits._ setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val planner = sqlContext.planner + val planner = sqlContext.sessionState.planner import planner._ - val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) + val plannedOption = Aggregation(query).headOption val planned = plannedOption.getOrElse( fail(s"Could query play aggregation query $query. Is it an aggregation query?")) val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - // For the new aggregation code path, there will be three aggregate operator for + // For the new aggregation code path, there will be four aggregate operator for // distinct aggregations. assert( - aggregations.size == 2 || aggregations.size == 3, + aggregations.size == 2 || aggregations.size == 4, s"The plan of query $query does not have partial aggregations.") } - test("unions are collapsed") { - val planner = sqlContext.planner - import planner._ - val query = testData.unionAll(testData).unionAll(testData).logicalPlan - val planned = BasicOperators(query).head - val logicalUnions = query collect { case u: logical.Union => u } - val physicalUnions = planned collect { case u: execution.Union => u } - - assert(logicalUnions.size === 2) - assert(physicalUnions.size === 1) - } - test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed testPartialAggregationPlan(query) @@ -94,13 +84,13 @@ class PlannerSuite extends SharedSQLContext { """ |SELECT l.a, l.b |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) - """.stripMargin).queryExecution.executedPlan + """.stripMargin).queryExecution.sparkPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + assert(sortMergeJoins.isEmpty, "Should not use sort merge join") } } @@ -147,39 +137,66 @@ class PlannerSuite extends SharedSQLContext { val a = testData.as("a") val b = sqlContext.table("tiny").as("b") - val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join") sqlContext.clearCache() } } } - test("efficient limit -> project -> sort") { - { - val query = - testData.select('key, 'value).sort('key).limit(2).logicalPlan - val planned = sqlContext.planner.TakeOrderedAndProject(query) - assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) - assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) - } + test("SPARK-11390 explain should print PushedFilters of PhysicalRDD") { + withTempPath { file => + val path = file.getCanonicalPath + testData.write.parquet(path) + val df = sqlContext.read.parquet(path) + sqlContext.registerDataFrameAsTable(df, "testPushed") - { - // We need to make sure TakeOrderedAndProject's output is correct when we push a project - // into it. - val query = - testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan - val planned = sqlContext.planner.TakeOrderedAndProject(query) - assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) - assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) + withTempTable("testPushed") { + val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan + assert(exp.toString.contains("PushedFilters: [IsNotNull(key), EqualTo(key,15)]")) + } } } + test("efficient terminal limit -> sort should use TakeOrderedAndProject") { + val query = testData.select('key, 'value).sort('key).limit(2) + val planned = query.queryExecution.executedPlan + assert(planned.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.output === testData.select('key, 'value).logicalPlan.output) + } + + test("terminal limit -> project -> sort should use TakeOrderedAndProject") { + val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2) + val planned = query.queryExecution.executedPlan + assert(planned.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.output === testData.select('value, 'key).logicalPlan.output) + } + + test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { + val query = testData.select('value).limit(2) + val planned = query.queryExecution.sparkPlan + assert(planned.isInstanceOf[CollectLimit]) + assert(planned.output === testData.select('value).logicalPlan.output) + } + + test("TakeOrderedAndProject can appear in the middle of plans") { + val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3) + val planned = query.queryExecution.executedPlan + assert(planned.find(_.isInstanceOf[TakeOrderedAndProject]).isDefined) + } + + test("CollectLimit can appear in the middle of a plan when caching is used") { + val query = testData.select('key, 'value).limit(2).cache() + val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] + assert(planned.child.isInstanceOf[CollectLimit]) + } + test("PartitioningCollection") { withTempTable("normal", "small", "tiny") { testData.registerTempTable("normal") @@ -197,9 +214,9 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (small.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: Exchange => exchange + case exchange: ShuffleExchange => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } { @@ -212,15 +229,27 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (normal.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: Exchange => exchange + case exchange: ShuffleExchange => exchange }.length - assert(numExchanges === 3) + assert(numExchanges === 5) } } } } + test("collapse adjacent repartitions") { + val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) + def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length + assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) + assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1) + doubleRepartitioned.queryExecution.optimizedPlan match { + case r: Repartition => + assert(r.numPartitions === 5) + assert(r.shuffle === false) + } + } + // --- Unit tests of EnsureRequirements --------------------------------------------------------- // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, @@ -266,9 +295,9 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -286,7 +315,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) } @@ -304,9 +333,9 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -324,9 +353,9 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -347,9 +376,9 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(outputOrdering, outputOrdering) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } @@ -363,9 +392,9 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") } } @@ -379,9 +408,9 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.nonEmpty) { + if (outputPlan.collect { case s: Sort => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") } } @@ -396,14 +425,99 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA, orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") } } + test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = ShuffleExchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { + fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") + } + } + + test("EnsureRequirements does not eliminate Exchange with different partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + // Number of partitions differ + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = ShuffleExchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { + fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") + } + } + // --------------------------------------------------------------------------------------------- + + test("Reuse exchanges") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val shuffle = ShuffleExchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val inputPlan = SortMergeJoin( + Literal(1) :: Nil, + Literal(1) :: Nil, + Inner, + None, + shuffle, + shuffle) + + val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan) + if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) { + fail(s"Should re-use the shuffle:\n$outputPlan") + } + if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) { + fail(s"Should have only one shuffle:\n$outputPlan") + } + + // nested exchanges + val inputPlan2 = SortMergeJoin( + Literal(1) :: Nil, + Literal(1) :: Nil, + Inner, + None, + ShuffleExchange(finalPartitioning, inputPlan), + ShuffleExchange(finalPartitioning, inputPlan)) + + val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2) + if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) { + fail(s"Should re-use the two shuffles:\n$outputPlan2") + } + if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) { + fail(s"Should have only two shuffles:\n$outputPlan") + } + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala new file mode 100644 index 0000000000000..2963a856d15cf --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + + +/** + * A reference sort implementation used to compare against our normal sort. + */ +case class ReferenceSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( + TaskContext.get(), ordering = Some(ordering)) + sorter.insertAll(iterator.map(r => (r.copy(), null))) + val baseIterator = sorter.iterator.map(_._1) + val context = TaskContext.get() + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala deleted file mode 100644 index b3fceeab64cfe..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} -import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{ArrayType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { - - private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { - case c: ConvertToUnsafe => c - case c: ConvertToSafe => c - } - - private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(outputsUnsafe.outputsUnsafeRows) - - test("planner should insert unsafe->safe conversions when required") { - val plan = Limit(10, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) - } - - test("filter can process unsafe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("filter can process safe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("execute() fails an assertion if inputs rows are of different formats") { - val e = intercept[AssertionError] { - Union(Seq(outputsSafe, outputsUnsafe)).execute() - } - assert(e.getMessage.contains("format")) - } - - test("union requires all of its input rows' formats to agree") { - val plan = Union(Seq(outputsSafe, outputsUnsafe)) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("union can process safe rows") { - val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("union can process unsafe rows") { - val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("round trip with ConvertToUnsafe and ConvertToSafe") { - val input = Seq(("hello", 1), ("world", 2)) - checkAnswer( - sqlContext.createDataFrame(input), - plan => ConvertToSafe(ConvertToUnsafe(plan)), - input.map(Row.fromTuple) - ) - } - - test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(sqlContext) - val schema = ArrayType(StringType) - val rows = (1 to 100).map { i => - InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) - } - val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) - - val plan = - DummyPlan( - ConvertToSafe( - ConvertToUnsafe(relation))) - assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) - } -} - -case class DummyPlan(child: SparkPlan) extends UnaryNode { - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some - // values gotten from the incoming rows. - // we cache all strings here to make sure we have deep copied UTF8String inside incoming - // safe InternalRow. - val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] - iter.foreach { row => - strings += row.getArray(0).getUTF8String(0) - } - strings.map(InternalRow(_)).iterator - } - } - - override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index 63639681ef80a..c9f517ca34296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution import java.util.Properties -import scala.collection.parallel.CompositeThrowable - import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.SQLContext @@ -51,6 +49,20 @@ class SQLExecutionSuite extends SparkFunSuite { } } + test("concurrent query execution with fork-join pool (SPARK-13747)") { + val sc = new SparkContext("local[*]", "test") + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + try { + // Should not throw IllegalArgumentException + (1 to 100).par.foreach { _ => + sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() + } + } finally { + sc.stop() + } + } + /** * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 847c188a30333..778477660e169 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -17,15 +17,22 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.AccumulatorSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +/** + * Test sorting. Many of the test cases generate random data and compares the sorted result with one + * sorted by a reference implementation ([[ReferenceSort]]). + */ class SortSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder + import testImplicits.newProductEncoder + import testImplicits.localSeqToDatasetHolder - // This test was originally added as an example of how to use [[SparkPlanTest]]; - // it's not designed to be a comprehensive test of ExternalSort. test("basic sorting using ExternalSort") { val input = Seq( @@ -36,14 +43,66 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + (child: SparkPlan) => Sort('a.asc :: 'b.asc :: Nil, global = true, child = child), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + (child: SparkPlan) => Sort('b.asc :: 'a.asc :: Nil, global = true, child = child), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } + + test("sort followed by limit") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => GlobalLimit(10, Sort('a.asc :: Nil, global = true, child = child)), + (child: SparkPlan) => GlobalLimit(10, ReferenceSort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkThatPlansAgree( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + Sort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + ReferenceSort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + + test("sorting updates peak execution memory") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child = child), + (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + + // Test sorting on different data types + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1000)(randomDataGenerator()) + val inputDf = sqlContext.createDataFrame( + sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + checkThatPlansAgree( + inputDf, + p => Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23), + ReferenceSort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 8549a6a0f6643..38318740a5119 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.execution import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.test.SQLTestUtils @@ -232,10 +231,8 @@ object SparkPlanTest { } private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - // A very simple resolver to make writing tests easier. In contrast to the real resolver - // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = sqlContext.prepareForExecution.execute( - outputPlan transform { + val execution = new QueryExecution(sqlContext, null) { + override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap plan transformExpressions { @@ -244,8 +241,8 @@ object SparkPlanTest { sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) } } - ) - resolvedPlan.executeCollectPublic().toSeq + } + execution.executedPlan.executeCollectPublic().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala new file mode 100644 index 0000000000000..a4c6d072f33a8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + + +class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { + + private var rand: Random = _ + private var seed: Long = 0 + + protected override def beforeAll(): Unit = { + super.beforeAll() + seed = System.currentTimeMillis() + rand = new Random(seed) + } + + private def generateRandomInputData(): DataFrame = { + val schema = new StructType() + .add("a", IntegerType, nullable = false) + .add("b", IntegerType, nullable = false) + val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) + sqlContext.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) + } + + /** + * Adds a no-op filter to the child plan in order to prevent executeCollect() from being + * called directly on the child plan. + */ + private def noOpFilter(plan: SparkPlan): SparkPlan = Filter(Literal(true), plan) + + val limit = 250 + val sortOrder = 'a.desc :: 'b.desc :: Nil + + test("TakeOrderedAndProject.doExecute without project") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter(TakeOrderedAndProject(limit, sortOrder, None, input)), + input => + GlobalLimit(limit, + LocalLimit(limit, + Sort(sortOrder, true, input))), + sortAnswers = false) + } + } + + test("TakeOrderedAndProject.doExecute with project") { + withClue(s"seed = $seed") { + checkThatPlansAgree( + generateRandomInputData(), + input => + noOpFilter(TakeOrderedAndProject(limit, sortOrder, Some(Seq(input.output.last)), input)), + input => + GlobalLimit(limit, + LocalLimit(limit, + Project(Seq(input.output.last), + Sort(sortOrder, true, input)))), + sortAnswers = false) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala deleted file mode 100644 index 7a0f0dfd2b7f1..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.Random - -import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ - -/** - * A test suite that generates randomized data to test the [[TungstenSort]] operator. - */ -class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) - } - - override def afterAll(): Unit = { - try { - sqlContext.conf.unsetConf(SQLConf.CODEGEN_ENABLED) - } finally { - super.afterAll() - } - } - - test("sort followed by limit") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } - - test("sorting does not crash for large inputs") { - val sortOrder = 'a.asc :: Nil - val stringLength = 1024 * 1024 * 2 - checkThatPlansAgree( - Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - - test("sorting updates peak execution memory") { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), - (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), - sortAnswers = false) - } - } - - // Test sorting on different data types - for ( - dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); - nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); - randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) - ) { - test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = sqlContext.createDataFrame( - sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), - StructType(StructField("a", dataType, nullable = true) :: Nil) - ) - assert(TungstenSort.supportsSchema(inputDf.schema)) - checkThatPlansAgree( - inputDf, - plan => ConvertToSafe( - TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7ceaee38d131b..c1555114e8b3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution -import scala.util.control.NonFatal +import java.util.Properties + import scala.collection.mutable -import scala.util.{Try, Random} +import scala.util.{Random, Try} +import scala.util.control.NonFatal import org.scalatest.Matchers -import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -61,7 +63,7 @@ class UnsafeFixedWidthAggregationMapSuite } test(name) { - val conf = new SparkConf().set("spark.unsafe.offHeap", "false") + val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) @@ -71,8 +73,8 @@ class UnsafeFixedWidthAggregationMapSuite taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = null, - internalAccumulators = Seq.empty)) + localProperties = new Properties, + metricsSystem = null)) try { f diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 7b80963ec8708..03d4be8ee528e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution +import java.util.Properties + import scala.util.Random import org.apache.spark._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -40,8 +42,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val rand = new Random(42) for (i <- 0 until 6) { - val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes) - val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes) + val keySchema = RandomDataGenerator.randomSchema(rand, rand.nextInt(10) + 1, keyTypes) + val valueSchema = RandomDataGenerator.randomSchema(rand, rand.nextInt(10) + 1, valueTypes) testKVSorter(keySchema, valueSchema, spill = i > 3) } @@ -109,7 +111,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { pageSize: Long, spill: Boolean): Unit = { val memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")) + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")) val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -117,11 +119,11 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { taskAttemptId = 98456, attemptNumber = 0, taskMemoryManager = taskMemMgr, - metricsSystem = null, - internalAccumulators = Seq.empty)) + localProperties = new Properties, + metricsSystem = null)) val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, pageSize) + keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, pageSize) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 09e258299de5a..01687877eeed6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.execution -import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} +import java.util.Properties -import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark._ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD -import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.Utils import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ -import org.apache.spark._ - +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.Utils /** * used to test close InputStream in UnsafeRowSerializer @@ -115,12 +114,12 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) val taskContext = new TaskContextImpl( - 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc)) + 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.createAll(sc)) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, partitioner = Some(new HashPartitioner(10)), - serializer = Some(new UnsafeRowSerializer(numFields = 1))) + serializer = new UnsafeRowSerializer(numFields = 1)) // Ensure we spilled something and have to merge them later assert(sorter.numSpills === 0) @@ -128,7 +127,6 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { assert(sorter.numSpills > 0) // Merging spilled files should not throw assertion error - taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } { // Clean up @@ -156,7 +154,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { new ShuffleDependency[Int, InternalRow, InternalRow]( rowsRDD, new PartitionIdPassthrough(2), - Some(new UnsafeRowSerializer(2))) + new UnsafeRowSerializer(2)) val shuffled = new ShuffledRowRDD(dependency) shuffled.count() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala new file mode 100644 index 0000000000000..8efd9de29eb0f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} + +class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { + + test("range/filter should be combined") { + val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1") + val plan = df.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined) + assert(df.collect() === Array(Row(2))) + } + + test("Aggregate should be included in WholeStageCodegen") { + val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id"))) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) + assert(df.collect() === Array(Row(9, 4.5))) + } + + test("Aggregate with grouping keys should be included in WholeStageCodegen") { + val df = sqlContext.range(3).groupBy("id").count().orderBy("id") + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) + assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + } + + test("BroadcastHashJoin should be included in WholeStageCodegen") { + val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) + val schema = new StructType().add("k", IntegerType).add("v", StringType) + val smallDF = sqlContext.createDataFrame(rdd, schema) + val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) + assert(df.queryExecution.executedPlan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[BroadcastHashJoin]).isDefined) + assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) + } + + test("Sort should be included in WholeStageCodegen") { + val df = sqlContext.range(3, 0, -1).toDF().sort(col("id")) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined) + assert(df.collect() === Array(Row(1), Row(2), Row(3))) + } + + test("MapElements should be included in WholeStageCodegen") { + import testImplicits._ + + val ds = sqlContext.range(10).map(_.toString) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) + assert(ds.collect() === 0.until(10).map(_.toString).toArray) + } + + test("typed filter should be included in WholeStageCodegen") { + val ds = sqlContext.range(10).filter(_ % 2 == 0) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined) + assert(ds.collect() === Array(0, 2, 4, 6, 8)) + } + + test("back-to-back typed filter should be included in WholeStageCodegen") { + val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined) + assert(ds.collect() === Array(0, 6)) + } + + test("simple typed UDAF should be included in WholeStageCodegen") { + import testImplicits._ + + val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS() + .groupByKey(_._1).agg(typed.sum(_._2)) + + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) + assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index 89a664001bdd2..b2d04f7c5a6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow @@ -50,7 +50,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -86,7 +86,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = new DecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala similarity index 90% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 63bc39bfa0307..052f4cbaebc8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -15,17 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.charset.StandardCharsets +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ -import org.apache.spark.{Logging, SparkFunSuite} - class ColumnTypeSuite extends SparkFunSuite with Logging { private val DEFAULT_BUFFER_SIZE = 512 @@ -35,7 +36,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( - NULL-> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, + NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) @@ -68,7 +69,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(LONG, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(DOUBLE, Double.MaxValue, 8) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(STRING, "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length) checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index a5882f7870e37..1529313dfbd51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String @@ -60,6 +60,7 @@ object ColumnarTestUtils { case MAP(_) => ArrayBasedMapData( Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) + case _ => throw new IllegalArgumentException(s"Unknown column type $columnType") }).asInstanceOf[JvmType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala similarity index 90% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 6265e40a0a07b..50c8745a288f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -15,8 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{QueryTest, Row} @@ -31,7 +32,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() test("simple columnar query") { - val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -48,7 +49,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("projection") { - val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -57,7 +58,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -125,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("decimal type") { // Casting is required here because ScalaReflection can't capture decimal precision information. val df = (1 to 10) - .map(i => Tuple1(Decimal(i, 15, 10))) + .map(i => Tuple1(Decimal(i, 15, 10).toJavaBigDecimal)) .toDF("dec") .select($"dec" cast DecimalType(15, 10)) @@ -160,7 +161,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sparkContext.parallelize((1 to 10000), 10).map { i => Row( s"str${i}: test cache.", - s"binary${i}: test cache.".getBytes("UTF-8"), + s"binary${i}: test cache.".getBytes(StandardCharsets.UTF_8), null, i % 2 == 0, i.toByte, @@ -219,4 +220,14 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(data.count() === 10) assert(data.filter($"s" === "3").count() === 1) } + + test("SPARK-14138: Generated SpecificColumnarIterator can exceed JVM size limit for cached DF") { + val length1 = 3999 + val columnTypes1 = List.fill(length1)(IntegerType) + val columnarIterator1 = GenerateColumnAccessor.generate(columnTypes1) + + val length2 = 10000 + val columnTypes2 = List.fill(length2)(IntegerType) + val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index aa1605fee8c73..dc22d3e8e4d3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( @@ -38,7 +38,7 @@ object TestNullableColumnAccessor { } class NullableColumnAccessorSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index 91404577832a0..cdd4551d64b50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -15,11 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) @@ -36,7 +36,7 @@ object TestNullableColumnBuilder { } class NullableColumnBuilderSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala similarity index 88% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 6b7401464f46f..4f185ed283ce6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -15,14 +15,21 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar + +import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite -import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + +class PartitionBatchPruningSuite + extends SparkFunSuite + with BeforeAndAfterEach + with SharedSQLContext { + import testImplicits._ private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize @@ -32,30 +39,41 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - - val pruningData = sparkContext.makeRDD((1 to 100).map { key => - val string = if (((key - 1) / 10) % 2 == 0) null else key.toString - TestData(key, string) - }, 5).toDF() - pruningData.registerTempTable("pruningData") - // Enable in-memory partition pruning sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - sqlContext.cacheTable("pruningData") } override protected def afterAll(): Unit = { try { sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - sqlContext.uncacheTable("pruningData") } finally { super.afterAll() } } + override protected def beforeEach(): Unit = { + super.beforeEach() + // This creates accumulators, which get cleaned up after every single test, + // so we need to do this before every test. + val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val string = if (((key - 1) / 10) % 2 == 0) null else key.toString + TestData(key, string) + }, 5).toDF() + pruningData.registerTempTable("pruningData") + sqlContext.cacheTable("pruningData") + } + + override protected def afterEach(): Unit = { + try { + sqlContext.uncacheTable("pruningData") + } finally { + super.afterEach() + } + } + // Comparisons checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1)) checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1)) @@ -114,7 +132,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { df.collect().map(_(0)).toArray } - val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect { + val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect { case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index 9a2948c59ba42..f67e9c7dae278 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala new file mode 100644 index 0000000000000..1aadd700d7443 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar.compression + +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.charset.StandardCharsets + +import org.apache.commons.lang3.RandomStringUtils +import org.apache.commons.math3.distribution.LogNormalDistribution + +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING} +import org.apache.spark.sql.types.AtomicType +import org.apache.spark.util.Benchmark +import org.apache.spark.util.Utils._ + +/** + * Benchmark to decoders using various compression schemes. + */ +object CompressionSchemeBenchmark extends AllCompressionSchemes { + + private[this] def allocateLocal(size: Int): ByteBuffer = { + ByteBuffer.allocate(size).order(ByteOrder.nativeOrder) + } + + private[this] def genLowerSkewData() = { + val rng = new LogNormalDistribution(0.0, 0.01) + () => rng.sample + } + + private[this] def genHigherSkewData() = { + val rng = new LogNormalDistribution(0.0, 1.0) + () => rng.sample + } + + private[this] def prepareEncodeInternal[T <: AtomicType]( + count: Int, + tpe: NativeColumnType[T], + supportedScheme: CompressionScheme, + input: ByteBuffer): ((ByteBuffer, ByteBuffer) => ByteBuffer, Double, ByteBuffer) = { + assert(supportedScheme.supports(tpe)) + + def toRow(d: Any) = new GenericInternalRow(Array[Any](d)) + val encoder = supportedScheme.encoder(tpe) + for (i <- 0 until count) { + encoder.gatherCompressibilityStats(toRow(tpe.extract(input)), 0) + } + input.rewind() + + val compressedSize = if (encoder.compressedSize == 0) { + input.remaining() + } else { + encoder.compressedSize + } + + (encoder.compress, encoder.compressionRatio, allocateLocal(4 + compressedSize)) + } + + private[this] def runEncodeBenchmark[T <: AtomicType]( + name: String, + iters: Int, + count: Int, + tpe: NativeColumnType[T], + input: ByteBuffer): Unit = { + val benchmark = new Benchmark(name, iters * count) + + schemes.filter(_.supports(tpe)).map { scheme => + val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) + val label = s"${getFormattedClassName(scheme)}(${compressionRatio.formatted("%.3f")})" + + benchmark.addCase(label)({ i: Int => + for (n <- 0L until iters) { + compressFunc(input, buf) + input.rewind() + buf.rewind() + } + }) + } + + benchmark.run() + } + + private[this] def runDecodeBenchmark[T <: AtomicType]( + name: String, + iters: Int, + count: Int, + tpe: NativeColumnType[T], + input: ByteBuffer): Unit = { + val benchmark = new Benchmark(name, iters * count) + + schemes.filter(_.supports(tpe)).map { scheme => + val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) + val compressedBuf = compressFunc(input, buf) + val label = s"${getFormattedClassName(scheme)}" + + input.rewind() + + benchmark.addCase(label)({ i: Int => + val rowBuf = new GenericMutableRow(1) + + for (n <- 0L until iters) { + compressedBuf.rewind.position(4) + val decoder = scheme.decoder(compressedBuf, tpe) + while (decoder.hasNext) { + decoder.next(rowBuf, 0) + } + } + }) + } + + benchmark.run() + } + + def bitEncodingBenchmark(iters: Int): Unit = { + val count = 65536 + val testData = allocateLocal(count * BOOLEAN.defaultSize) + + val g = { + val rng = genLowerSkewData() + () => (rng().toInt % 2).toByte + } + for (i <- 0 until count) { + testData.put(i * BOOLEAN.defaultSize, g()) + } + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // BOOLEAN Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 3 / 4 19300.2 0.1 1.0X + // RunLengthEncoding(2.491) 923 / 939 72.7 13.8 0.0X + // BooleanBitSet(0.125) 359 / 363 187.1 5.3 0.0X + runEncodeBenchmark("BOOLEAN Encode", iters, count, BOOLEAN, testData) + + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // BOOLEAN Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 129 / 136 519.8 1.9 1.0X + // RunLengthEncoding 613 / 623 109.4 9.1 0.2X + // BooleanBitSet 1196 / 1222 56.1 17.8 0.1X + runDecodeBenchmark("BOOLEAN Decode", iters, count, BOOLEAN, testData) + } + + def shortEncodingBenchmark(iters: Int): Unit = { + val count = 65536 + val testData = allocateLocal(count * SHORT.defaultSize) + + val g1 = genLowerSkewData() + for (i <- 0 until count) { + testData.putShort(i * SHORT.defaultSize, g1().toShort) + } + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 6 / 7 10971.4 0.1 1.0X + // RunLengthEncoding(1.510) 1526 / 1542 44.0 22.7 0.0X + runEncodeBenchmark("SHORT Encode (Lower Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 811 / 837 82.8 12.1 1.0X + // RunLengthEncoding 1219 / 1266 55.1 18.2 0.7X + runDecodeBenchmark("SHORT Decode (Lower Skew)", iters, count, SHORT, testData) + + val g2 = genHigherSkewData() + for (i <- 0 until count) { + testData.putShort(i * SHORT.defaultSize, g2().toShort) + } + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 7 / 7 10112.4 0.1 1.0X + // RunLengthEncoding(2.009) 1623 / 1661 41.4 24.2 0.0X + runEncodeBenchmark("SHORT Encode (Higher Skew)", iters, count, SHORT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // SHORT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 818 / 827 82.0 12.2 1.0X + // RunLengthEncoding 1202 / 1237 55.8 17.9 0.7X + runDecodeBenchmark("SHORT Decode (Higher Skew)", iters, count, SHORT, testData) + } + + def intEncodingBenchmark(iters: Int): Unit = { + val count = 65536 + val testData = allocateLocal(count * INT.defaultSize) + + val g1 = genLowerSkewData() + for (i <- 0 until count) { + testData.putInt(i * INT.defaultSize, g1().toInt) + } + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 18 / 19 3716.4 0.3 1.0X + // RunLengthEncoding(1.001) 1992 / 2056 33.7 29.7 0.0X + // DictionaryEncoding(0.500) 723 / 739 92.8 10.8 0.0X + // IntDelta(0.250) 368 / 377 182.2 5.5 0.0X + runEncodeBenchmark("INT Encode (Lower Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 821 / 845 81.8 12.2 1.0X + // RunLengthEncoding 1246 / 1256 53.9 18.6 0.7X + // DictionaryEncoding 757 / 766 88.6 11.3 1.1X + // IntDelta 680 / 689 98.7 10.1 1.2X + runDecodeBenchmark("INT Decode (Lower Skew)", iters, count, INT, testData) + + val g2 = genHigherSkewData() + for (i <- 0 until count) { + testData.putInt(i * INT.defaultSize, g2().toInt) + } + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 17 / 19 3888.4 0.3 1.0X + // RunLengthEncoding(1.339) 2127 / 2148 31.5 31.7 0.0X + // DictionaryEncoding(0.501) 960 / 972 69.9 14.3 0.0X + // IntDelta(0.250) 362 / 366 185.5 5.4 0.0X + runEncodeBenchmark("INT Encode (Higher Skew)", iters, count, INT, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // INT Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 838 / 884 80.1 12.5 1.0X + // RunLengthEncoding 1287 / 1311 52.1 19.2 0.7X + // DictionaryEncoding 844 / 859 79.5 12.6 1.0X + // IntDelta 764 / 784 87.8 11.4 1.1X + runDecodeBenchmark("INT Decode (Higher Skew)", iters, count, INT, testData) + } + + def longEncodingBenchmark(iters: Int): Unit = { + val count = 65536 + val testData = allocateLocal(count * LONG.defaultSize) + + val g1 = genLowerSkewData() + for (i <- 0 until count) { + testData.putLong(i * LONG.defaultSize, g1().toLong) + } + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Encode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 37 / 38 1804.8 0.6 1.0X + // RunLengthEncoding(0.748) 2065 / 2094 32.5 30.8 0.0X + // DictionaryEncoding(0.250) 950 / 962 70.6 14.2 0.0X + // LongDelta(0.125) 475 / 482 141.2 7.1 0.1X + runEncodeBenchmark("LONG Encode (Lower Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Decode (Lower Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 888 / 894 75.5 13.2 1.0X + // RunLengthEncoding 1301 / 1311 51.6 19.4 0.7X + // DictionaryEncoding 887 / 904 75.7 13.2 1.0X + // LongDelta 693 / 735 96.8 10.3 1.3X + runDecodeBenchmark("LONG Decode (Lower Skew)", iters, count, LONG, testData) + + val g2 = genHigherSkewData() + for (i <- 0 until count) { + testData.putLong(i * LONG.defaultSize, g2().toLong) + } + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Encode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 34 / 35 1963.9 0.5 1.0X + // RunLengthEncoding(0.999) 2260 / 3021 29.7 33.7 0.0X + // DictionaryEncoding(0.251) 1270 / 1438 52.8 18.9 0.0X + // LongDelta(0.125) 496 / 509 135.3 7.4 0.1X + runEncodeBenchmark("LONG Encode (Higher Skew)", iters, count, LONG, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // LONG Decode (Higher Skew): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 965 / 1494 69.5 14.4 1.0X + // RunLengthEncoding 1350 / 1378 49.7 20.1 0.7X + // DictionaryEncoding 892 / 924 75.2 13.3 1.1X + // LongDelta 817 / 847 82.2 12.2 1.2X + runDecodeBenchmark("LONG Decode (Higher Skew)", iters, count, LONG, testData) + } + + def stringEncodingBenchmark(iters: Int): Unit = { + val count = 65536 + val strLen = 8 + val tableSize = 16 + val testData = allocateLocal(count * (4 + strLen)) + + val g = { + val dataTable = (0 until tableSize).map(_ => RandomStringUtils.randomAlphabetic(strLen)) + val rng = genHigherSkewData() + () => dataTable(rng().toInt % tableSize) + } + for (i <- 0 until count) { + testData.putInt(strLen) + testData.put(g().getBytes(StandardCharsets.UTF_8)) + } + testData.rewind() + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // STRING Encode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough(1.000) 56 / 57 1197.9 0.8 1.0X + // RunLengthEncoding(0.893) 4892 / 4937 13.7 72.9 0.0X + // DictionaryEncoding(0.167) 2968 / 2992 22.6 44.2 0.0X + runEncodeBenchmark("STRING Encode", iters, count, STRING, testData) + + // Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + // STRING Decode: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + // ------------------------------------------------------------------------------------------- + // PassThrough 2422 / 2449 27.7 36.1 1.0X + // RunLengthEncoding 2885 / 3018 23.3 43.0 0.8X + // DictionaryEncoding 2716 / 2752 24.7 40.5 0.9X + runDecodeBenchmark("STRING Decode", iters, count, STRING, testData) + } + + def main(args: Array[String]): Unit = { + bitEncodingBenchmark(1024) + shortEncodingBenchmark(1024) + intEncodingBenchmark(1024) + longEncodingBenchmark(1024) + stringEncodingBenchmark(1024) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index acfab6586c0d1..830ca0294e1b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 2111e9fbe62cb..988a577a7b4d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 67ec08f594a43..95642e93ae9f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { @@ -100,11 +100,11 @@ class RunLengthEncodingSuite extends SparkFunSuite { } test(s"$RunLengthEncoding with $typeName: simple case") { - skeleton(2, Seq(0 -> 2, 1 ->2)) + skeleton(2, Seq(0 -> 2, 1 -> 2)) } test(s"$RunLengthEncoding with $typeName: run length == 1") { - skeleton(2, Seq(0 -> 1, 1 ->1)) + skeleton(2, Seq(0 -> 1, 1 -> 1)) } test(s"$RunLengthEncoding with $typeName: single long run") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5268dfe0aa03e..5e078f251375a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types.AtomicType class TestCompressibleColumnBuilder[T <: AtomicType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala new file mode 100644 index 0000000000000..d6ccaf93488e7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -0,0 +1,699 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.SparkSqlParser +import org.apache.spark.sql.types._ + +// TODO: merge this with DDLSuite (SPARK-14441) +class DDLCommandSuite extends PlanTest { + private val parser = SparkSqlParser + + private def assertUnsupported(sql: String): Unit = { + val e = intercept[AnalysisException] { + parser.parsePlan(sql) + } + assert(e.getMessage.toLowerCase.contains("operation not allowed")) + } + + test("create database") { + val sql = + """ + |CREATE DATABASE IF NOT EXISTS database_name + |COMMENT 'database_comment' LOCATION '/home/user/db' + |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') + """.stripMargin + val parsed = parser.parsePlan(sql) + val expected = CreateDatabase( + "database_name", + ifNotExists = true, + Some("/home/user/db"), + Some("database_comment"), + Map("a" -> "a", "b" -> "b", "c" -> "c")) + comparePlans(parsed, expected) + } + + test("drop database") { + val sql1 = "DROP DATABASE IF EXISTS database_name RESTRICT" + val sql2 = "DROP DATABASE IF EXISTS database_name CASCADE" + val sql3 = "DROP SCHEMA IF EXISTS database_name RESTRICT" + val sql4 = "DROP SCHEMA IF EXISTS database_name CASCADE" + // The default is restrict=true + val sql5 = "DROP DATABASE IF EXISTS database_name" + // The default is ifExists=false + val sql6 = "DROP DATABASE database_name" + val sql7 = "DROP DATABASE database_name CASCADE" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val parsed3 = parser.parsePlan(sql3) + val parsed4 = parser.parsePlan(sql4) + val parsed5 = parser.parsePlan(sql5) + val parsed6 = parser.parsePlan(sql6) + val parsed7 = parser.parsePlan(sql7) + + val expected1 = DropDatabase( + "database_name", + ifExists = true, + cascade = false) + val expected2 = DropDatabase( + "database_name", + ifExists = true, + cascade = true) + val expected3 = DropDatabase( + "database_name", + ifExists = false, + cascade = false) + val expected4 = DropDatabase( + "database_name", + ifExists = false, + cascade = true) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected1) + comparePlans(parsed4, expected2) + comparePlans(parsed5, expected1) + comparePlans(parsed6, expected3) + comparePlans(parsed7, expected4) + } + + test("alter database set dbproperties") { + // ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) + val sql1 = "ALTER DATABASE database_name SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')" + val sql2 = "ALTER SCHEMA database_name SET DBPROPERTIES ('a'='a')" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = AlterDatabaseProperties( + "database_name", + Map("a" -> "a", "b" -> "b", "c" -> "c")) + val expected2 = AlterDatabaseProperties( + "database_name", + Map("a" -> "a")) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("describe database") { + // DESCRIBE DATABASE [EXTENDED] db_name; + val sql1 = "DESCRIBE DATABASE EXTENDED db_name" + val sql2 = "DESCRIBE DATABASE db_name" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = DescribeDatabase( + "db_name", + extended = true) + val expected2 = DescribeDatabase( + "db_name", + extended = false) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("create function") { + val sql1 = + """ + |CREATE TEMPORARY FUNCTION helloworld as + |'com.matthewrathbone.example.SimpleUDFExample' USING JAR '/path/to/jar1', + |JAR '/path/to/jar2' + """.stripMargin + val sql2 = + """ + |CREATE FUNCTION hello.world as + |'com.matthewrathbone.example.SimpleUDFExample' USING ARCHIVE '/path/to/archive', + |FILE '/path/to/file' + """.stripMargin + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val expected1 = CreateFunction( + None, + "helloworld", + "com.matthewrathbone.example.SimpleUDFExample", + Seq(("jar", "/path/to/jar1"), ("jar", "/path/to/jar2")), + isTemp = true) + val expected2 = CreateFunction( + Some("hello"), + "world", + "com.matthewrathbone.example.SimpleUDFExample", + Seq(("archive", "/path/to/archive"), ("file", "/path/to/file")), + isTemp = false) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("drop function") { + val sql1 = "DROP TEMPORARY FUNCTION helloworld" + val sql2 = "DROP TEMPORARY FUNCTION IF EXISTS helloworld" + val sql3 = "DROP FUNCTION hello.world" + val sql4 = "DROP FUNCTION IF EXISTS hello.world" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val parsed3 = parser.parsePlan(sql3) + val parsed4 = parser.parsePlan(sql4) + + val expected1 = DropFunction( + None, + "helloworld", + ifExists = false, + isTemp = true) + val expected2 = DropFunction( + None, + "helloworld", + ifExists = true, + isTemp = true) + val expected3 = DropFunction( + Some("hello"), + "world", + ifExists = false, + isTemp = false) + val expected4 = DropFunction( + Some("hello"), + "world", + ifExists = true, + isTemp = false) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } + + // ALTER TABLE table_name RENAME TO new_table_name; + // ALTER VIEW view_name RENAME TO new_view_name; + test("alter table/view: rename table/view") { + val sql_table = "ALTER TABLE table_name RENAME TO new_table_name" + val sql_view = sql_table.replace("TABLE", "VIEW") + val parsed_table = parser.parsePlan(sql_table) + val parsed_view = parser.parsePlan(sql_view) + val expected_table = AlterTableRename( + TableIdentifier("table_name", None), + TableIdentifier("new_table_name", None), + isView = false) + val expected_view = AlterTableRename( + TableIdentifier("table_name", None), + TableIdentifier("new_table_name", None), + isView = true) + comparePlans(parsed_table, expected_table) + comparePlans(parsed_view, expected_view) + } + + // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + test("alter table/view: alter table/view properties") { + val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + + "'comment' = 'new_comment')" + val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" + val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" + val sql1_view = sql1_table.replace("TABLE", "VIEW") + val sql2_view = sql2_table.replace("TABLE", "VIEW") + val sql3_view = sql3_table.replace("TABLE", "VIEW") + + val parsed1_table = parser.parsePlan(sql1_table) + val parsed2_table = parser.parsePlan(sql2_table) + val parsed3_table = parser.parsePlan(sql3_table) + val parsed1_view = parser.parsePlan(sql1_view) + val parsed2_view = parser.parsePlan(sql2_view) + val parsed3_view = parser.parsePlan(sql3_view) + + val tableIdent = TableIdentifier("table_name", None) + val expected1_table = AlterTableSetProperties( + tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false) + val expected2_table = AlterTableUnsetProperties( + tableIdent, Seq("comment", "test"), ifExists = false, isView = false) + val expected3_table = AlterTableUnsetProperties( + tableIdent, Seq("comment", "test"), ifExists = true, isView = false) + val expected1_view = expected1_table.copy(isView = true) + val expected2_view = expected2_table.copy(isView = true) + val expected3_view = expected3_table.copy(isView = true) + + comparePlans(parsed1_table, expected1_table) + comparePlans(parsed2_table, expected2_table) + comparePlans(parsed3_table, expected3_table) + comparePlans(parsed1_view, expected1_view) + comparePlans(parsed2_view, expected2_view) + comparePlans(parsed3_view, expected3_view) + } + + test("alter table: SerDe properties") { + val sql1 = "ALTER TABLE table_name SET SERDE 'org.apache.class'" + val sql2 = + """ + |ALTER TABLE table_name SET SERDE 'org.apache.class' + |WITH SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') + """.stripMargin + val sql3 = + """ + |ALTER TABLE table_name SET SERDEPROPERTIES ('columns'='foo,bar', + |'field.delim' = ',') + """.stripMargin + val sql4 = + """ + |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |country='us') SET SERDE 'org.apache.class' WITH SERDEPROPERTIES ('columns'='foo,bar', + |'field.delim' = ',') + """.stripMargin + val sql5 = + """ + |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |country='us') SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') + """.stripMargin + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val parsed3 = parser.parsePlan(sql3) + val parsed4 = parser.parsePlan(sql4) + val parsed5 = parser.parsePlan(sql5) + val tableIdent = TableIdentifier("table_name", None) + val expected1 = AlterTableSerDeProperties( + tableIdent, Some("org.apache.class"), None, None) + val expected2 = AlterTableSerDeProperties( + tableIdent, + Some("org.apache.class"), + Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), + None) + val expected3 = AlterTableSerDeProperties( + tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None) + val expected4 = AlterTableSerDeProperties( + tableIdent, + Some("org.apache.class"), + Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), + Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + val expected5 = AlterTableSerDeProperties( + tableIdent, + None, + Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), + Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + comparePlans(parsed5, expected5) + } + + // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION partition_spec + // [LOCATION 'location1'] partition_spec [LOCATION 'location2'] ...; + test("alter table: add partition") { + val sql1 = + """ + |ALTER TABLE table_name ADD IF NOT EXISTS PARTITION + |(dt='2008-08-08', country='us') LOCATION 'location1' PARTITION + |(dt='2009-09-09', country='uk') + """.stripMargin + val sql2 = "ALTER TABLE table_name ADD PARTITION (dt='2008-08-08') LOCATION 'loc'" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = AlterTableAddPartition( + TableIdentifier("table_name", None), + Seq( + (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), + (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), + ifNotExists = true) + val expected2 = AlterTableAddPartition( + TableIdentifier("table_name", None), + Seq((Map("dt" -> "2008-08-08"), Some("loc"))), + ifNotExists = false) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + // ALTER VIEW view_name ADD [IF NOT EXISTS] PARTITION partition_spec PARTITION partition_spec ...; + test("alter view: add partition") { + val sql1 = + """ + |ALTER VIEW view_name ADD IF NOT EXISTS PARTITION + |(dt='2008-08-08', country='us') PARTITION + |(dt='2009-09-09', country='uk') + """.stripMargin + // different constant types in partitioning spec + val sql2 = + """ + |ALTER VIEW view_name ADD PARTITION + |(col1=NULL, cOL2='f', col3=5, COL4=true) + """.stripMargin + + intercept[ParseException] { + parser.parsePlan(sql1) + } + intercept[ParseException] { + parser.parsePlan(sql2) + } + } + + test("alter table: rename partition") { + val sql = + """ + |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') + |RENAME TO PARTITION (dt='2008-09-09', country='uk') + """.stripMargin + val parsed = parser.parsePlan(sql) + val expected = AlterTableRenamePartition( + TableIdentifier("table_name", None), + Map("dt" -> "2008-08-08", "country" -> "us"), + Map("dt" -> "2008-09-09", "country" -> "uk")) + comparePlans(parsed, expected) + } + + test("alter table: exchange partition (not supported)") { + assertUnsupported( + """ + |ALTER TABLE table_name_1 EXCHANGE PARTITION + |(dt='2008-08-08', country='us') WITH TABLE table_name_2 + """.stripMargin) + } + + // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE] + // ALTER VIEW table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] + test("alter table/view: drop partitions") { + val sql1_table = + """ + |ALTER TABLE table_name DROP IF EXISTS PARTITION + |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') + """.stripMargin + val sql2_table = + """ + |ALTER TABLE table_name DROP PARTITION + |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') PURGE + """.stripMargin + val sql1_view = sql1_table.replace("TABLE", "VIEW") + // Note: ALTER VIEW DROP PARTITION does not support PURGE + val sql2_view = sql2_table.replace("TABLE", "VIEW").replace("PURGE", "") + + val parsed1_table = parser.parsePlan(sql1_table) + val e = intercept[ParseException] { + parser.parsePlan(sql2_table) + } + assert(e.getMessage.contains("Operation not allowed")) + + intercept[ParseException] { + parser.parsePlan(sql1_view) + } + intercept[ParseException] { + parser.parsePlan(sql2_view) + } + + val tableIdent = TableIdentifier("table_name", None) + val expected1_table = AlterTableDropPartition( + tableIdent, + Seq( + Map("dt" -> "2008-08-08", "country" -> "us"), + Map("dt" -> "2009-09-09", "country" -> "uk")), + ifExists = true) + + comparePlans(parsed1_table, expected1_table) + } + + test("alter table: archive partition (not supported)") { + assertUnsupported("ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')") + } + + test("alter table: unarchive partition (not supported)") { + assertUnsupported("ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')") + } + + test("alter table: set file format") { + val sql1 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + + "OUTPUTFORMAT 'test' SERDE 'test'" + val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + + "SET FILEFORMAT PARQUET" + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val tableIdent = TableIdentifier("table_name", None) + val expected1 = AlterTableSetFileFormat( + tableIdent, + None, + List("test", "test", "test"), + None)(sql1) + val expected2 = AlterTableSetFileFormat( + tableIdent, + Some(Map("dt" -> "2008-08-08", "country" -> "us")), + Seq(), + Some("PARQUET"))(sql2) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("alter table: set location") { + val sql1 = "ALTER TABLE table_name SET LOCATION 'new location'" + val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + + "SET LOCATION 'new location'" + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val tableIdent = TableIdentifier("table_name", None) + val expected1 = AlterTableSetLocation( + tableIdent, + None, + "new location") + val expected2 = AlterTableSetLocation( + tableIdent, + Some(Map("dt" -> "2008-08-08", "country" -> "us")), + "new location") + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("alter table: touch (not supported)") { + assertUnsupported("ALTER TABLE table_name TOUCH") + assertUnsupported("ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')") + } + + test("alter table: compact (not supported)") { + assertUnsupported("ALTER TABLE table_name COMPACT 'compaction_type'") + assertUnsupported( + """ + |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') + |COMPACT 'MAJOR' + """.stripMargin) + } + + test("alter table: concatenate (not supported)") { + assertUnsupported("ALTER TABLE table_name CONCATENATE") + assertUnsupported( + "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE") + } + + test("alter table: change column name/type/position/comment") { + val sql1 = "ALTER TABLE table_name CHANGE col_old_name col_new_name INT" + val sql2 = + """ + |ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT + |COMMENT 'col_comment' FIRST CASCADE + """.stripMargin + val sql3 = + """ + |ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT + |COMMENT 'col_comment' AFTER column_name RESTRICT + """.stripMargin + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val parsed3 = parser.parsePlan(sql3) + val tableIdent = TableIdentifier("table_name", None) + val expected1 = AlterTableChangeCol( + tableName = tableIdent, + partitionSpec = None, + oldColName = "col_old_name", + newColName = "col_new_name", + dataType = IntegerType, + comment = None, + afterColName = None, + restrict = false, + cascade = false)(sql1) + val expected2 = AlterTableChangeCol( + tableName = tableIdent, + partitionSpec = None, + oldColName = "col_old_name", + newColName = "col_new_name", + dataType = IntegerType, + comment = Some("col_comment"), + afterColName = None, + restrict = false, + cascade = true)(sql2) + val expected3 = AlterTableChangeCol( + tableName = tableIdent, + partitionSpec = None, + oldColName = "col_old_name", + newColName = "col_new_name", + dataType = IntegerType, + comment = Some("col_comment"), + afterColName = Some("column_name"), + restrict = true, + cascade = false)(sql3) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + } + + test("alter table: add/replace columns") { + val sql1 = + """ + |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') + |ADD COLUMNS (new_col1 INT COMMENT 'test_comment', new_col2 LONG + |COMMENT 'test_comment2') CASCADE + """.stripMargin + val sql2 = + """ + |ALTER TABLE table_name REPLACE COLUMNS (new_col1 INT + |COMMENT 'test_comment', new_col2 LONG COMMENT 'test_comment2') RESTRICT + """.stripMargin + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val meta1 = new MetadataBuilder().putString("comment", "test_comment").build() + val meta2 = new MetadataBuilder().putString("comment", "test_comment2").build() + val tableIdent = TableIdentifier("table_name", None) + val expected1 = AlterTableAddCol( + tableIdent, + Some(Map("dt" -> "2008-08-08", "country" -> "us")), + StructType(Seq( + StructField("new_col1", IntegerType, nullable = true, meta1), + StructField("new_col2", LongType, nullable = true, meta2))), + restrict = false, + cascade = true)(sql1) + val expected2 = AlterTableReplaceCol( + tableIdent, + None, + StructType(Seq( + StructField("new_col1", IntegerType, nullable = true, meta1), + StructField("new_col2", LongType, nullable = true, meta2))), + restrict = true, + cascade = false)(sql2) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("show databases") { + val sql1 = "SHOW DATABASES" + val sql2 = "SHOW DATABASES LIKE 'defau*'" + val parsed1 = parser.parsePlan(sql1) + val expected1 = ShowDatabasesCommand(None) + val parsed2 = parser.parsePlan(sql2) + val expected2 = ShowDatabasesCommand(Some("defau*")) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("show tblproperties") { + val parsed1 = parser.parsePlan("SHOW TBLPROPERTIES tab1") + val expected1 = ShowTablePropertiesCommand(TableIdentifier("tab1", None), None) + val parsed2 = parser.parsePlan("SHOW TBLPROPERTIES tab1('propKey1')") + val expected2 = ShowTablePropertiesCommand(TableIdentifier("tab1", None), Some("propKey1")) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("unsupported operations") { + intercept[ParseException] { + parser.parsePlan("DROP TABLE tab PURGE") + } + intercept[ParseException] { + parser.parsePlan("DROP TABLE tab FOR REPLICATION('eventid')") + } + intercept[ParseException] { + parser.parsePlan("CREATE VIEW testView AS SELECT id FROM tab") + } + intercept[ParseException] { + parser.parsePlan("ALTER VIEW testView AS SELECT id FROM tab") + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE EXTERNAL TABLE parquet_tab2(c1 INT, c2 STRING) + |TBLPROPERTIES('prop1Key '= "prop1Val", ' `prop2Key` '= "prop2Val") + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE EXTERNAL TABLE oneToTenDef + |USING org.apache.spark.sql.sources + |OPTIONS (from '1', to '10') + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData") + } + } + + test("SPARK-14383: DISTRIBUTE and UNSET as non-keywords") { + val sql = "SELECT distribute, unset FROM x" + val parsed = parser.parsePlan(sql) + assert(parsed.isInstanceOf[Project]) + } + + test("drop table") { + val tableName1 = "db.tab" + val tableName2 = "tab" + + val parsed1 = parser.parsePlan(s"DROP TABLE $tableName1") + val parsed2 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName1") + val parsed3 = parser.parsePlan(s"DROP TABLE $tableName2") + val parsed4 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName2") + + val expected1 = + DropTable(TableIdentifier("tab", Option("db")), ifExists = false, isView = false) + val expected2 = + DropTable(TableIdentifier("tab", Option("db")), ifExists = true, isView = false) + val expected3 = + DropTable(TableIdentifier("tab", None), ifExists = false, isView = false) + val expected4 = + DropTable(TableIdentifier("tab", None), ifExists = true, isView = false) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } + + test("drop view") { + val viewName1 = "db.view" + val viewName2 = "view" + + val parsed1 = parser.parsePlan(s"DROP VIEW $viewName1") + val parsed2 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName1") + val parsed3 = parser.parsePlan(s"DROP VIEW $viewName2") + val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2") + + val expected1 = + DropTable(TableIdentifier("view", Option("db")), ifExists = false, isView = true) + val expected2 = + DropTable(TableIdentifier("view", Option("db")), ifExists = true, isView = true) + val expected3 = + DropTable(TableIdentifier("view", None), ifExists = false, isView = true) + val expected4 = + DropTable(TableIdentifier("view", None), ifExists = true, isView = true) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala new file mode 100644 index 0000000000000..9ffffa0bdd6e7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -0,0 +1,719 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import java.io.File + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec +import org.apache.spark.sql.test.SharedSQLContext + +class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { + private val escapedIdentifier = "`(.+)`".r + + override def afterEach(): Unit = { + try { + // drop all databases, tables and functions after each test + sqlContext.sessionState.catalog.reset() + } finally { + super.afterEach() + } + } + + /** + * Strip backticks, if any, from the string. + */ + private def cleanIdentifier(ident: String): String = { + ident match { + case escapedIdentifier(i) => i + case plainIdent => plainIdent + } + } + + private def assertUnsupported(query: String): Unit = { + val e = intercept[AnalysisException] { + sql(query) + } + assert(e.getMessage.toLowerCase.contains("operation not allowed")) + } + + private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { + if (expectException) intercept[AnalysisException] { body } else body + } + + private def createDatabase(catalog: SessionCatalog, name: String): Unit = { + catalog.createDatabase(CatalogDatabase(name, "", "", Map()), ignoreIfExists = false) + } + + private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { + catalog.createTable(CatalogTable( + identifier = name, + tableType = CatalogTableType.EXTERNAL_TABLE, + storage = CatalogStorageFormat(None, None, None, None, Map()), + schema = Seq()), ignoreIfExists = false) + } + + private def createTablePartition( + catalog: SessionCatalog, + spec: TablePartitionSpec, + tableName: TableIdentifier): Unit = { + val part = CatalogTablePartition(spec, CatalogStorageFormat(None, None, None, None, Map())) + catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) + } + + test("Create/Drop Database") { + val catalog = sqlContext.sessionState.catalog + + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db", + Map.empty)) + sql(s"DROP DATABASE $dbName CASCADE") + assert(!catalog.databaseExists(dbNameWithoutBackTicks)) + } finally { + catalog.reset() + } + } + } + + test("Create Database - database already exists") { + val catalog = sqlContext.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db", + Map.empty)) + + val message = intercept[AnalysisException] { + sql(s"CREATE DATABASE $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' already exists.")) + } finally { + catalog.reset() + } + } + } + + test("Alter/Describe Database") { + val catalog = sqlContext.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + val location = + System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db" + sql(s"CREATE DATABASE $dbName") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "") :: Nil) + + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) + + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) + } finally { + catalog.reset() + } + } + } + + test("Drop/Alter/Describe Database - database does not exists") { + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) + + var message = intercept[AnalysisException] { + sql(s"DROP DATABASE $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + message = intercept[AnalysisException] { + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + message = intercept[AnalysisException] { + sql(s"DESCRIBE DATABASE EXTENDED $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + sql(s"DROP DATABASE IF EXISTS $dbName") + } + } + + // TODO: test drop database in restrict mode + + test("alter table: rename") { + val catalog = sqlContext.sessionState.catalog + val tableIdent1 = TableIdentifier("tab1", Some("dbx")) + val tableIdent2 = TableIdentifier("tab2", Some("dbx")) + createDatabase(catalog, "dbx") + createDatabase(catalog, "dby") + createTable(catalog, tableIdent1) + assert(catalog.listTables("dbx") == Seq(tableIdent1)) + sql("ALTER TABLE dbx.tab1 RENAME TO dbx.tab2") + assert(catalog.listTables("dbx") == Seq(tableIdent2)) + catalog.setCurrentDatabase("dbx") + // rename without explicitly specifying database + sql("ALTER TABLE tab2 RENAME TO tab1") + assert(catalog.listTables("dbx") == Seq(tableIdent1)) + // table to rename does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.does_not_exist RENAME TO dbx.tab2") + } + // destination database is different + intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 RENAME TO dby.tab2") + } + } + + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assert(catalog.getTableMetadata(tableIdent).properties.isEmpty) + // set table properties + sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')") + assert(catalog.getTableMetadata(tableIdent).properties == + Map("andrew" -> "or14", "kor" -> "bel")) + // set table properties without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')") + assert(catalog.getTableMetadata(tableIdent).properties == + Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist SET TBLPROPERTIES ('winner' = 'loser')") + } + // throw exception for datasource tables + convertToDatasourceTable(catalog, tableIdent) + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 SET TBLPROPERTIES ('sora' = 'bol')") + } + assert(e.getMessage.contains("datasource")) + } + + test("alter table: unset properties") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + // unset table properties + sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan')") + sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')") + assert(catalog.getTableMetadata(tableIdent).properties == Map("p" -> "an", "c" -> "lan")) + // unset table properties without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')") + assert(catalog.getTableMetadata(tableIdent).properties == Map("c" -> "lan")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist UNSET TBLPROPERTIES ('c' = 'lan')") + } + // property to unset does not exist + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('c', 'xyz')") + } + assert(e.getMessage.contains("xyz")) + // property to unset does not exist, but "IF EXISTS" is specified + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')") + assert(catalog.getTableMetadata(tableIdent).properties.isEmpty) + // throw exception for datasource tables + convertToDatasourceTable(catalog, tableIdent) + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('sora')") + } + assert(e1.getMessage.contains("datasource")) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: bucketing is not supported") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assertUnsupported("ALTER TABLE dbx.tab1 CLUSTERED BY (blood, lemon, grape) INTO 11 BUCKETS") + assertUnsupported("ALTER TABLE dbx.tab1 CLUSTERED BY (fuji) SORTED BY (grape) INTO 5 BUCKETS") + assertUnsupported("ALTER TABLE dbx.tab1 NOT CLUSTERED") + assertUnsupported("ALTER TABLE dbx.tab1 NOT SORTED") + } + + test("alter table: skew is not supported") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assertUnsupported("ALTER TABLE dbx.tab1 SKEWED BY (dt, country) ON " + + "(('2008-08-08', 'us'), ('2009-09-09', 'uk'), ('2010-10-10', 'cn'))") + assertUnsupported("ALTER TABLE dbx.tab1 SKEWED BY (dt, country) ON " + + "(('2008-08-08', 'us'), ('2009-09-09', 'uk')) STORED AS DIRECTORIES") + assertUnsupported("ALTER TABLE dbx.tab1 NOT SKEWED") + assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") + } + + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: add partition is not supported for views") { + assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition is not supported for views") { + assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") + } + + test("alter table: rename partition") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3)) + sql("ALTER TABLE dbx.tab1 PARTITION (a='1') RENAME TO PARTITION (a='100')") + sql("ALTER TABLE dbx.tab1 PARTITION (b='2') RENAME TO PARTITION (b='200')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "100"), Map("b" -> "200"), part3)) + // rename without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 PARTITION (a='100') RENAME TO PARTITION (a='10')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "10"), Map("b" -> "200"), part3)) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") + } + // partition to rename does not exist + intercept[AnalysisException] { + sql("ALTER TABLE tab1 PARTITION (x='300') RENAME TO PARTITION (x='333')") + } + } + + test("show tables") { + withTempTable("show1a", "show2b") { + sql( + """ + |CREATE TEMPORARY TABLE show1a + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + | + |) + """.stripMargin) + sql( + """ + |CREATE TEMPORARY TABLE show2b + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + checkAnswer( + sql("SHOW TABLES IN default 'show1*'"), + Row("show1a", true) :: Nil) + + checkAnswer( + sql("SHOW TABLES IN default 'show1*|show2*'"), + Row("show1a", true) :: + Row("show2b", true) :: Nil) + + checkAnswer( + sql("SHOW TABLES 'show1*|show2*'"), + Row("show1a", true) :: + Row("show2b", true) :: Nil) + + assert( + sql("SHOW TABLES").count() >= 2) + assert( + sql("SHOW TABLES IN default").count() >= 2) + } + } + + test("show databases") { + sql("CREATE DATABASE showdb1A") + sql("CREATE DATABASE showdb2B") + + assert( + sql("SHOW DATABASES").count() >= 2) + + checkAnswer( + sql("SHOW DATABASES LIKE '*db1A'"), + Row("showdb1A") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE 'showdb1A'"), + Row("showdb1A") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE '*db1A|*db2B'"), + Row("showdb1A") :: + Row("showdb2B") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE 'non-existentdb'"), + Nil) + } + + test("drop table - temporary table") { + val catalog = sqlContext.sessionState.catalog + sql( + """ + |CREATE TEMPORARY TABLE tab1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"))) + sql("DROP TABLE tab1") + assert(catalog.listTables("default") == Nil) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + + private def testDropTable(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.listTables("dbx") == Seq(tableIdent)) + sql("DROP TABLE dbx.tab1") + assert(catalog.listTables("dbx") == Nil) + sql("DROP TABLE IF EXISTS dbx.tab1") + // no exception will be thrown + sql("DROP TABLE dbx.tab1") + } + + test("drop view in SQLContext") { + // SQLContext does not support create view. Log an error message, if tab1 does not exists + sql("DROP VIEW tab1") + + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assert(catalog.listTables("dbx") == Seq(tableIdent)) + + val e = intercept[AnalysisException] { + sql("DROP VIEW dbx.tab1") + } + assert( + e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) + } + + private def convertToDatasourceTable( + catalog: SessionCatalog, + tableIdent: TableIdentifier): Unit = { + catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( + properties = Map("spark.sql.sources.provider" -> "csv"))) + } + + private def testSetLocation(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val partSpec = Map("a" -> "1") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, partSpec, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, partSpec).storage.serdeProperties.isEmpty) + // Verify that the location is set to the expected string + def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { + val storageFormat = spec + .map { s => catalog.getPartition(tableIdent, s).storage } + .getOrElse { catalog.getTableMetadata(tableIdent).storage } + if (isDatasourceTable) { + if (spec.isDefined) { + assert(storageFormat.serdeProperties.isEmpty) + assert(storageFormat.locationUri.isEmpty) + } else { + assert(storageFormat.serdeProperties.get("path") === Some(expected)) + assert(storageFormat.locationUri === Some(expected)) + } + } else { + assert(storageFormat.locationUri === Some(expected)) + } + } + // set table location + sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") + verifyLocation("/path/to/your/lovely/heart") + // set table partition location + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 PARTITION (a='1') SET LOCATION '/path/to/part/ways'") + } + verifyLocation("/path/to/part/ways", Some(partSpec)) + // set table location without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") + verifyLocation("/swanky/steak/place") + // set table partition location without explicitly specifying database + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 PARTITION (a='1') SET LOCATION 'vienna'") + } + verifyLocation("vienna", Some(partSpec)) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'") + } + // partition to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 PARTITION (b='2') SET LOCATION '/mister/spark'") + } + } + + private def testSetSerde(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) + // set table serde and/or properties (should fail on datasource tables) + if (isDatasourceTable) { + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 SET SERDE 'whatever'") + } + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") + } + assert(e1.getMessage.contains("datasource")) + assert(e2.getMessage.contains("datasource")) + } else { + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'") + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.jadoop")) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.madoop")) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == + Map("k" -> "v", "kay" -> "vee")) + } + // set serde properties only + sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == + Map("k" -> "vvv", "kay" -> "vee")) + // set things without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')") + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == + Map("k" -> "vvv", "kay" -> "veee")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')") + } + } + + private def testAddPartitions(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + val part4 = Map("d" -> "4") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + + "PARTITION (b='2') LOCATION 'paris' PARTITION (c='3')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Some("paris")) + assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + } + // add partitions without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (d='4')") + } + // partition to add already exists + intercept[AnalysisException] { + sql("ALTER TABLE tab1 ADD PARTITION (d='4')") + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + } + } + + private def testDropPartitions(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + val part4 = Map("d" -> "4") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + createTablePartition(catalog, part4, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (d='4'), PARTITION (c='3')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) + } + // drop partitions without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (b='2')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (b='2')") + } + // partition to drop does not exist + intercept[AnalysisException] { + sql("ALTER TABLE tab1 DROP PARTITION (x='300')") + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (x='300')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala new file mode 100644 index 0000000000000..dac56d393647b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} +import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.execution.DataSourceScan +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.Utils + +class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { + import testImplicits._ + + test("unpartitioned table, single partition") { + val table = + createTable( + files = Seq( + "file1" -> 1, + "file2" -> 1, + "file3" -> 1, + "file4" -> 1, + "file5" -> 1, + "file6" -> 1, + "file7" -> 1, + "file8" -> 1, + "file9" -> 1, + "file10" -> 1)) + + checkScan(table.select('c1)) { partitions => + // 10 one byte files should fit in a single partition with 10 files. + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 10, "when checking partition 1") + // 1 byte files are too small to split so we should read the whole thing. + assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 1) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + + test("unpartitioned table, multiple partitions") { + val table = + createTable( + files = Seq( + "file1" -> 5, + "file2" -> 5, + "file3" -> 5)) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "11", + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { + checkScan(table.select('c1)) { partitions => + // 5 byte files should be laid out [(5, 5), (5)] + assert(partitions.size == 2, "when checking partitions") + assert(partitions(0).files.size == 2, "when checking partition 1") + assert(partitions(1).files.size == 1, "when checking partition 2") + + // 5 byte files are too small to split so we should read the whole thing. + assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 5) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + } + + test("Unpartitioned table, large file that gets split") { + val table = + createTable( + files = Seq( + "file1" -> 15, + "file2" -> 3)) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10", + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { + checkScan(table.select('c1)) { partitions => + // Files should be laid out [(0-10), (10-15, 4)] + assert(partitions.size == 2, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") + assert(partitions(1).files.size == 2, "when checking partition 2") + + // Start by reading 10 bytes of the first file + assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 10) + + // Second partition reads the remaining 5 + assert(partitions(1).files.head.start == 10) + assert(partitions(1).files.head.length == 5) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + } + + test("Unpartitioned table, many files that get split") { + val table = + createTable( + files = Seq( + "file1" -> 2, + "file2" -> 2, + "file3" -> 1, + "file4" -> 1, + "file5" -> 1, + "file6" -> 1)) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { + checkScan(table.select('c1)) { partitions => + // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] + assert(partitions.size == 4, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") + assert(partitions(1).files.size == 2, "when checking partition 2") + assert(partitions(2).files.size == 2, "when checking partition 3") + assert(partitions(3).files.size == 1, "when checking partition 4") + + // First partition reads (file1) + assert(partitions(0).files(0).start == 0) + assert(partitions(0).files(0).length == 2) + + // Second partition reads (file2, file3) + assert(partitions(1).files(0).start == 0) + assert(partitions(1).files(0).length == 2) + assert(partitions(1).files(1).start == 0) + assert(partitions(1).files(1).length == 1) + + // Third partition reads (file4, file5) + assert(partitions(2).files(0).start == 0) + assert(partitions(2).files(0).length == 1) + assert(partitions(2).files(1).start == 0) + assert(partitions(2).files(1).length == 1) + + // Final partition reads (file6) + assert(partitions(3).files(0).start == 0) + assert(partitions(3).files(0).length == 1) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + } + + test("partitioned table") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + // Only one file should be read. + checkScan(table.where("p1 = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when files in partition 1") + } + // We don't need to reevaluate filters that are only on partitions. + checkDataFilters(Set.empty) + + // Only one file should be read. + checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when checking files in partition 1") + assert(partitions.head.files.head.partitionValues.getInt(0) == 1, + "when checking partition values") + } + // Only the filters that do not contain the partition column should be pushed down + checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) + } + + test("partitioned table - case insensitive") { + withSQLConf("spark.sql.caseSensitive" -> "false") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + // Only one file should be read. + checkScan(table.where("P1 = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when files in partition 1") + } + // We don't need to reevaluate filters that are only on partitions. + checkDataFilters(Set.empty) + + // Only one file should be read. + checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when checking files in partition 1") + assert(partitions.head.files.head.partitionValues.getInt(0) == 1, + "when checking partition values") + } + // Only the filters that do not contain the partition column should be pushed down + checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) + } + } + + test("partitioned table - after scan filters") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") + // Filter on data only are advisory so we have to reevaluate. + assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1")) + // Need to evalaute filters that are not pushed down. + assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2")) + // Don't reevaluate partition only filters. + assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1"))) + } + + test("bucketed table") { + val table = + createTable( + files = Seq( + "p1=1/file1_0000" -> 1, + "p1=1/file2_0000" -> 1, + "p1=1/file3_0002" -> 1, + "p1=2/file4_0002" -> 1, + "p1=2/file5_0000" -> 1, + "p1=2/file6_0000" -> 1, + "p1=2/file7_0000" -> 1), + buckets = 3) + + // No partition pruning + checkScan(table) { partitions => + assert(partitions.size == 3) + assert(partitions(0).files.size == 5) + assert(partitions(1).files.size == 0) + assert(partitions(2).files.size == 2) + } + + // With partition pruning + checkScan(table.where("p1=2")) { partitions => + assert(partitions.size == 3) + assert(partitions(0).files.size == 3) + assert(partitions(1).files.size == 0) + assert(partitions(2).files.size == 1) + } + } + + // Helpers for checking the arguments passed to the FileFormat. + + protected val checkPartitionSchema = + checkArgument("partition schema", _.partitionSchema, _: StructType) + protected val checkDataSchema = + checkArgument("data schema", _.dataSchema, _: StructType) + protected val checkDataFilters = + checkArgument("data filters", _.filters.toSet, _: Set[Filter]) + + /** Helper for building checks on the arguments passed to the reader. */ + protected def checkArgument[T](name: String, arg: LastArguments.type => T, expected: T): Unit = { + if (arg(LastArguments) != expected) { + fail( + s""" + |Wrong $name + |expected: $expected + |actual: ${arg(LastArguments)} + """.stripMargin) + } + } + + /** Returns a resolved expression for `str` in the context of `df`. */ + def resolve(df: DataFrame, str: String): Expression = { + df.select(expr(str)).queryExecution.analyzed.expressions.head.children.head + } + + /** Returns a set with all the filters present in the physical plan. */ + def getPhysicalFilters(df: DataFrame): ExpressionSet = { + ExpressionSet( + df.queryExecution.executedPlan.collect { + case execution.Filter(f, _) => splitConjunctivePredicates(f) + }.flatten) + } + + /** Plans the query and calls the provided validation function with the planned partitioning. */ + def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = { + val fileScan = df.queryExecution.executedPlan.collect { + case scan: DataSourceScan if scan.rdd.isInstanceOf[FileScanRDD] => + scan.rdd.asInstanceOf[FileScanRDD] + }.headOption.getOrElse { + fail(s"No FileScan in query\n${df.queryExecution}") + } + + func(fileScan.filePartitions) + } + + /** + * Constructs a new table given a list of file names and sizes expressed in bytes. The table + * is written out in a temporary directory and any nested directories in the files names + * are automatically created. + * + * When `buckets` is > 0 the returned [[DataFrame]] will have metadata specifying that number of + * buckets. However, it is the responsibility of the caller to assign files to each bucket + * by appending the bucket id to the file names. + */ + def createTable( + files: Seq[(String, Int)], + buckets: Int = 0): DataFrame = { + val tempDir = Utils.createTempDir() + files.foreach { + case (name, size) => + val file = new File(tempDir, name) + assert(file.getParentFile.exists() || file.getParentFile.mkdirs()) + util.stringToFile(file, "*" * size) + } + + val df = sqlContext.read + .format(classOf[TestFileFormat].getName) + .load(tempDir.getCanonicalPath) + + if (buckets > 0) { + val bucketed = df.queryExecution.analyzed transform { + case l @ LogicalRelation(r: HadoopFsRelation, _, _) => + l.copy(relation = + r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) + } + Dataset.ofRows(sqlContext, bucketed) + } else { + df + } + } +} + +/** Holds the last arguments passed to [[TestFileFormat]]. */ +object LastArguments { + var partitionSchema: StructType = _ + var dataSchema: StructType = _ + var filters: Seq[Filter] = _ + var options: Map[String, String] = _ +} + +/** A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing. */ +class TestFileFormat extends FileFormat { + + override def toString: String = "TestFileFormat" + + /** + * When possible, this method should return the schema of the given `files`. When the format + * does not support inference, or no valid files are given should return None. In these cases + * Spark will require that user specify the schema manually. + */ + override def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = + Some( + StructType(Nil) + .add("c1", IntegerType) + .add("c2", IntegerType)) + + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + */ + override def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new NotImplementedError("JUST FOR TESTING") + } + + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + + // Record the arguments so they can be checked in the test case. + LastArguments.partitionSchema = partitionSchema + LastArguments.dataSchema = requiredSchema + LastArguments.filters = filters + LastArguments.options = options + + (file: PartitionedFile) => { Iterator.empty } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala new file mode 100644 index 0000000000000..297731c70c151 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.{File, FilenameFilter} + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { + + test("sizeInBytes should be the total size of all files") { + withTempDir{ dir => + dir.delete() + sqlContext.range(1000).write.parquet(dir.toString) + // ignore hidden files + val allFiles = dir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + !name.startsWith(".") + } + }) + val totalSize = allFiles.map(_.length()).sum + val df = sqlContext.read.parquet(dir.toString) + assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala new file mode 100644 index 0000000000000..23d422635b0a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class CSVInferSchemaSuite extends SparkFunSuite { + + test("String fields types are inferred correctly from null types") { + assert(CSVInferSchema.inferField(NullType, "") == NullType) + assert(CSVInferSchema.inferField(NullType, null) == NullType) + assert(CSVInferSchema.inferField(NullType, "100000000000") == LongType) + assert(CSVInferSchema.inferField(NullType, "60") == IntegerType) + assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType) + assert(CSVInferSchema.inferField(NullType, "test") == StringType) + assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType) + assert(CSVInferSchema.inferField(NullType, "True") == BooleanType) + assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType) + } + + test("String fields types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(LongType, "1.0") == DoubleType) + assert(CSVInferSchema.inferField(LongType, "test") == StringType) + assert(CSVInferSchema.inferField(IntegerType, "1.0") == DoubleType) + assert(CSVInferSchema.inferField(DoubleType, null) == DoubleType) + assert(CSVInferSchema.inferField(DoubleType, "test") == StringType) + assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType) + assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType) + assert(CSVInferSchema.inferField(LongType, "True") == BooleanType) + assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType) + assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType) + } + + test("Timestamp field types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14") == StringType) + assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType) + assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType) + } + + test("Boolean fields types are inferred correctly from other types") { + assert(CSVInferSchema.inferField(LongType, "Fale") == StringType) + assert(CSVInferSchema.inferField(DoubleType, "TRUEe") == StringType) + } + + test("Type arrays are merged to highest common type") { + assert( + CSVInferSchema.mergeRowTypes(Array(StringType), + Array(DoubleType)).deep == Array(StringType).deep) + assert( + CSVInferSchema.mergeRowTypes(Array(IntegerType), + Array(LongType)).deep == Array(LongType).deep) + assert( + CSVInferSchema.mergeRowTypes(Array(DoubleType), + Array(LongType)).deep == Array(DoubleType).deep) + } + + test("Null fields are handled properly when a nullValue is specified") { + assert(CSVInferSchema.inferField(NullType, "null", "null") == NullType) + assert(CSVInferSchema.inferField(StringType, "null", "null") == StringType) + assert(CSVInferSchema.inferField(LongType, "null", "null") == LongType) + assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType) + assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) + assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) + assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType) + } + + test("Merging Nulltypes should yield Nulltype.") { + val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + assert(mergedNullTypes.deep == Array(NullType).deep) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala new file mode 100644 index 0000000000000..aaeecef5f37fc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.SparkFunSuite + +/** + * test cases for StringIteratorReader + */ +class CSVParserSuite extends SparkFunSuite { + + private def readAll(iter: Iterator[String]) = { + val reader = new StringIteratorReader(iter) + var c: Int = -1 + val read = new scala.collection.mutable.StringBuilder() + do { + c = reader.read() + read.append(c.toChar) + } while (c != -1) + + read.dropRight(1).toString + } + + private def readBufAll(iter: Iterator[String], bufSize: Int) = { + val reader = new StringIteratorReader(iter) + val cbuf = new Array[Char](bufSize) + val read = new scala.collection.mutable.StringBuilder() + + var done = false + do { // read all input one cbuf at a time + var numRead = 0 + var n = 0 + do { // try to fill cbuf + var off = 0 + var len = cbuf.length + n = reader.read(cbuf, off, len) + + if (n != -1) { + off += n + len -= n + } + + assert(len >= 0 && len <= cbuf.length) + assert(off >= 0 && off <= cbuf.length) + read.appendAll(cbuf.take(n)) + } while (n > 0) + if(n != -1) { + numRead += n + } else { + done = true + } + } while (!done) + + read.toString + } + + test("Hygiene") { + val reader = new StringIteratorReader(List("").toIterator) + assert(reader.ready === true) + assert(reader.markSupported === false) + intercept[IllegalArgumentException] { reader.skip(1) } + intercept[IllegalArgumentException] { reader.mark(1) } + intercept[IllegalArgumentException] { reader.reset() } + } + + test("Regular case") { + val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"") + val read = readAll(input.toIterator) + assert(read === input.mkString("\n") ++ "\n") + } + + test("Empty iter") { + val input = List[String]() + val read = readAll(input.toIterator) + assert(read === "") + } + + test("Embedded new line") { + val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"") + val read = readAll(input.toIterator) + assert(read === input.mkString("\n") ++ "\n") + } + + test("Buffer Regular case") { + val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"") + val output = input.mkString("\n") ++ "\n" + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, i) + assert(read === output) + } + } + + test("Buffer Empty iter") { + val input = List[String]() + val output = "" + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, 1) + assert(read === "") + } + } + + test("Buffer Embedded new line") { + val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"") + val output = input.mkString("\n") ++ "\n" + for(i <- 1 to output.length + 5) { + val read = readBufAll(input.toIterator, 1) + assert(read === output) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala new file mode 100644 index 0000000000000..9baae80f15981 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File +import java.nio.charset.UnsupportedCharsetException +import java.sql.Timestamp + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.SequenceFile.CompressionType +import org.apache.hadoop.io.compress.GzipCodec + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types._ + +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + private val carsFile = "cars.csv" + private val carsMalformedFile = "cars-malformed.csv" + private val carsFile8859 = "cars_iso-8859-1.csv" + private val carsTsvFile = "cars.tsv" + private val carsAltFile = "cars-alternative.csv" + private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv" + private val carsNullFile = "cars-null.csv" + private val emptyFile = "empty.csv" + private val commentsFile = "comments.csv" + private val disableCommentsFile = "disable_comments.csv" + private val boolFile = "bool.csv" + private val simpleSparseFile = "simple_sparse.csv" + private val unescapedQuotesFile = "unescaped-quotes.csv" + + private def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + + /** Verifies data and schema. */ + private def verifyCars( + df: DataFrame, + withHeader: Boolean, + numCars: Int = 3, + numFields: Int = 5, + checkHeader: Boolean = true, + checkValues: Boolean = true, + checkTypes: Boolean = false): Unit = { + + val numColumns = numFields + val numRows = if (withHeader) numCars else numCars + 1 + // schema + assert(df.schema.fieldNames.length === numColumns) + assert(df.count === numRows) + + if (checkHeader) { + if (withHeader) { + assert(df.schema.fieldNames === Array("year", "make", "model", "comment", "blank")) + } else { + assert(df.schema.fieldNames === Array("C0", "C1", "C2", "C3", "C4")) + } + } + + if (checkValues) { + val yearValues = List("2012", "1997", "2015") + val actualYears = if (!withHeader) "year" :: yearValues else yearValues + val years = if (withHeader) df.select("year").collect() else df.select("C0").collect() + + years.zipWithIndex.foreach { case (year, index) => + if (checkTypes) { + assert(year === Row(actualYears(index).toInt)) + } else { + assert(year === Row(actualYears(index))) + } + } + } + } + + test("simple csv test") { + val cars = sqlContext + .read + .format("csv") + .option("header", "false") + .load(testFile(carsFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + + test("simple csv test with calling another function to load") { + val cars = sqlContext + .read + .option("header", "false") + .csv(testFile(carsFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + + test("simple csv test with type inference") { + val cars = sqlContext + .read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(carsFile)) + + verifyCars(cars, withHeader = true, checkTypes = true) + } + + test("test inferring booleans") { + val result = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(boolFile)) + + val expectedSchema = StructType(List( + StructField("bool", BooleanType, nullable = true))) + assert(result.schema === expectedSchema) + } + + test("test with alternative delimiter and quote") { + val cars = sqlContext.read + .format("csv") + .options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true")) + .load(testFile(carsAltFile)) + + verifyCars(cars, withHeader = true) + } + + test("parse unescaped quotes with maxCharsPerColumn") { + val rows = sqlContext.read + .format("csv") + .option("maxCharsPerColumn", "4") + .load(testFile(unescapedQuotesFile)) + + val expectedRows = Seq(Row("\"a\"b", "ccc", "ddd"), Row("ab", "cc\"c", "ddd\"")) + + checkAnswer(rows, expectedRows) + } + + test("bad encoding name") { + val exception = intercept[UnsupportedCharsetException] { + sqlContext + .read + .format("csv") + .option("charset", "1-9588-osi") + .load(testFile(carsFile8859)) + } + + assert(exception.getMessage.contains("1-9588-osi")) + } + + test("test different encoding") { + // scalastyle:off + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable USING csv + |OPTIONS (path "${testFile(carsFile8859)}", header "true", + |charset "iso-8859-1", delimiter "þ") + """.stripMargin.replaceAll("\n", " ")) + // scalastyle:on + + verifyCars(sqlContext.table("carsTable"), withHeader = true) + } + + test("test aliases sep and encoding for delimiter and charset") { + // scalastyle:off + val cars = sqlContext + .read + .format("csv") + .option("header", "true") + .option("encoding", "iso-8859-1") + .option("sep", "þ") + .load(testFile(carsFile8859)) + // scalastyle:on + + verifyCars(cars, withHeader = true) + } + + test("DDL test with tab separated file") { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable USING csv + |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") + """.stripMargin.replaceAll("\n", " ")) + + verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) + } + + test("DDL test parsing decimal type") { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, priceTag decimal, + | comments string, grp string) + |USING csv + |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") + """.stripMargin.replaceAll("\n", " ")) + + assert( + sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) + } + + test("test for DROPMALFORMED parsing mode") { + val cars = sqlContext.read + .format("csv") + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) + + assert(cars.select("year").collect().size === 2) + } + + test("test for FAILFAST parsing mode") { + val exception = intercept[SparkException]{ + sqlContext.read + .format("csv") + .options(Map("header" -> "true", "mode" -> "failfast")) + .load(testFile(carsFile)).collect() + } + + assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + } + + test("test for tokens more than the fields in the schema") { + val cars = sqlContext + .read + .format("csv") + .option("header", "false") + .option("comment", "~") + .load(testFile(carsMalformedFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } + + test("test with null quote character") { + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .option("quote", "") + .load(testFile(carsUnbalancedQuotesFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + + } + + test("test with empty file and known schema") { + val result = sqlContext.read + .format("csv") + .schema(StructType(List(StructField("column", StringType, false)))) + .load(testFile(emptyFile)) + + assert(result.collect.size === 0) + assert(result.schema.fieldNames.size === 1) + } + + test("DDL test with empty file") { + sqlContext.sql(s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, comments string, grp string) + |USING csv + |OPTIONS (path "${testFile(emptyFile)}", header "false") + """.stripMargin.replaceAll("\n", " ")) + + assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) + } + + test("DDL test with schema") { + sqlContext.sql(s""" + |CREATE TEMPORARY TABLE carsTable + |(yearMade double, makeName string, modelName string, comments string, blank string) + |USING csv + |OPTIONS (path "${testFile(carsFile)}", header "true") + """.stripMargin.replaceAll("\n", " ")) + + val cars = sqlContext.table("carsTable") + verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false) + assert( + cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank")) + } + + test("save csv") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .option("header", "true") + .csv(csvDir) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } + + test("save csv with quote") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("quote", "\"") + .save(csvDir) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .option("quote", "\"") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } + + test("commented lines in CSV data") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "~", "header" -> "false")) + .load(testFile(commentsFile)) + .collect() + + val expected = + Seq(Seq("1", "2", "3", "4", "5.01", "2015-08-20 15:57:00"), + Seq("6", "7", "8", "9", "0", "2015-08-21 16:58:01"), + Seq("1", "2", "3", "4", "5", "2015-08-23 18:00:42")) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("inferring schema with commented lines in CSV data") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true")) + .load(testFile(commentsFile)) + .collect() + + val expected = + Seq(Seq(1, 2, 3, 4, 5.01D, Timestamp.valueOf("2015-08-20 15:57:00")), + Seq(6, 7, 8, 9, 0, Timestamp.valueOf("2015-08-21 16:58:01")), + Seq(1, 2, 3, 4, 5, Timestamp.valueOf("2015-08-23 18:00:42"))) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("setting comment to null disables comment support") { + val results = sqlContext.read + .format("csv") + .options(Map("comment" -> "", "header" -> "false")) + .load(testFile(disableCommentsFile)) + .collect() + + val expected = + Seq( + Seq("#1", "2", "3"), + Seq("4", "5", "6")) + + assert(results.toSeq.map(_.toSeq) === expected) + } + + test("nullable fields with user defined null value of \"null\"") { + + // year,make,model,comment,blank + val dataSchema = StructType(List( + StructField("year", IntegerType, nullable = true), + StructField("make", StringType, nullable = false), + StructField("model", StringType, nullable = false), + StructField("comment", StringType, nullable = true), + StructField("blank", StringType, nullable = true))) + val cars = sqlContext.read + .format("csv") + .schema(dataSchema) + .options(Map("header" -> "true", "nullValue" -> "null")) + .load(testFile(carsNullFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + val results = cars.collect() + assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) + assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) + } + + test("save csv with compression codec option") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("compression", "gZiP") + .save(csvDir) + + val compressedFiles = new File(csvDir).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".csv.gz"))) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } + + test("SPARK-13543 Write the output as uncompressed via option()") { + val clonedConf = new Configuration(hadoopConfiguration) + hadoopConfiguration.set("mapreduce.output.fileoutputformat.compress", "true") + hadoopConfiguration + .set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) + hadoopConfiguration + .set("mapreduce.output.fileoutputformat.compress.codec", classOf[GzipCodec].getName) + hadoopConfiguration.set("mapreduce.map.output.compress", "true") + hadoopConfiguration.set("mapreduce.map.output.compress.codec", classOf[GzipCodec].getName) + withTempDir { dir => + try { + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("compression", "none") + .save(csvDir) + + val compressedFiles = new File(csvDir).listFiles() + assert(compressedFiles.exists(!_.getName.endsWith(".csv.gz"))) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + + test("Schema inference correctly identifies the datatype when data is sparse.") { + val df = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(simpleSparseFile)) + + assert( + df.schema.fields.map(field => field.dataType).deep == + Array(IntegerType, IntegerType, IntegerType, IntegerType).deep) + } + + test("old csv data source name works") { + val cars = sqlContext + .read + .format("com.databricks.spark.csv") + .option("header", "false") + .load(testFile(carsFile)) + + verifyCars(cars, withHeader = false, checkTypes = false) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala new file mode 100644 index 0000000000000..5702a1b4ea1f7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import java.math.BigDecimal +import java.util.Locale + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class CSVTypeCastSuite extends SparkFunSuite { + + test("Can parse decimal type values") { + val stringValues = Seq("10.05", "1,000.01", "158,058,049.001") + val decimalValues = Seq(10.05, 1000.01, 158058049.001) + val decimalType = new DecimalType() + + stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => + val decimalValue = new BigDecimal(decimalVal.toString) + assert(CSVTypeCast.castTo(strVal, decimalType) === + Decimal(decimalValue, decimalType.precision, decimalType.scale)) + } + } + + test("Can parse escaped characters") { + assert(CSVTypeCast.toChar("""\t""") === '\t') + assert(CSVTypeCast.toChar("""\r""") === '\r') + assert(CSVTypeCast.toChar("""\b""") === '\b') + assert(CSVTypeCast.toChar("""\f""") === '\f') + assert(CSVTypeCast.toChar("""\"""") === '\"') + assert(CSVTypeCast.toChar("""\'""") === '\'') + assert(CSVTypeCast.toChar("""\u0000""") === '\u0000') + } + + test("Does not accept delimiter larger than one character") { + val exception = intercept[IllegalArgumentException]{ + CSVTypeCast.toChar("ab") + } + assert(exception.getMessage.contains("cannot be more than one character")) + } + + test("Throws exception for unsupported escaped characters") { + val exception = intercept[IllegalArgumentException]{ + CSVTypeCast.toChar("""\1""") + } + assert(exception.getMessage.contains("Unsupported special character for delimiter")) + } + + test("Nullable types are handled") { + assert(CSVTypeCast.castTo("", IntegerType, nullable = true) == null) + } + + test("String type should always return the same as the input") { + assert(CSVTypeCast.castTo("", StringType, nullable = true) == UTF8String.fromString("")) + assert(CSVTypeCast.castTo("", StringType, nullable = false) == UTF8String.fromString("")) + } + + test("Throws exception for empty string with non null type") { + val exception = intercept[NumberFormatException]{ + CSVTypeCast.castTo("", IntegerType, nullable = false) + } + assert(exception.getMessage.contains("For input string: \"\"")) + } + + test("Types are cast correctly") { + assert(CSVTypeCast.castTo("10", ByteType) == 10) + assert(CSVTypeCast.castTo("10", ShortType) == 10) + assert(CSVTypeCast.castTo("10", IntegerType) == 10) + assert(CSVTypeCast.castTo("10", LongType) == 10) + assert(CSVTypeCast.castTo("1.00", FloatType) == 1.0) + assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0) + assert(CSVTypeCast.castTo("true", BooleanType) == true) + val timestamp = "2015-01-01 00:00:00" + assert(CSVTypeCast.castTo(timestamp, TimestampType) == + DateTimeUtils.stringToTime(timestamp).getTime * 1000L) + assert(CSVTypeCast.castTo("2015-01-01", DateType) == + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime("2015-01-01").getTime)) + } + + test("Float and Double Types are cast correctly with Locale") { + val originalLocale = Locale.getDefault + try { + val locale : Locale = new Locale("fr", "FR") + Locale.setDefault(locale) + assert(CSVTypeCast.castTo("1,00", FloatType) == 1.0) + assert(CSVTypeCast.castTo("1,00", DoubleType) == 1.0) + } finally { + Locale.setDefault(originalLocale) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala new file mode 100644 index 0000000000000..1742df31bba9a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Test cases for various [[JSONOptions]]. + */ +class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { + + test("allowComments off") { + val str = """{'name': /* hello */ 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowComments on") { + val str = """{'name': /* hello */ 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowComments", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowSingleQuotes off") { + val str = """{'name': 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowSingleQuotes on") { + val str = """{'name': 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowUnquotedFieldNames off") { + val str = """{name: 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowUnquotedFieldNames on") { + val str = """{name: 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowNumericLeadingZeros off") { + val str = """{"age": 0018}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowNumericLeadingZeros on") { + val str = """{"age": 0018}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd) + + assert(df.schema.head.name == "age") + assert(df.first().getLong(0) == 18) + } + + // The following two tests are not really working - need to look into Jackson's + // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. + ignore("allowNonNumericNumbers off") { + val str = """{"age": NaN}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + ignore("allowNonNumericNumbers on") { + val str = """{"age": NaN}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd) + + assert(df.schema.head.name == "age") + assert(df.first().getDouble(0).isNaN) + } + + test("allowBackslashEscapingAnyCharacter off") { + val str = """{"name": "Cazen Lee", "price": "\$10"}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowBackslashEscapingAnyCharacter on") { + val str = """{"name": "Cazen Lee", "price": "\$10"}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.schema.last.name == "price") + assert(df.first().getString(0) == "Cazen Lee") + assert(df.first().getString(1) == "$10") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 28b8f02bdf87f..e17340c70b7e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -18,20 +18,32 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} +import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ + import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD -import org.scalactic.Tolerance._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.hadoop.io.SequenceFile.CompressionType +import org.apache.hadoop.io.compress.GzipCodec +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -54,7 +66,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Utils.tryWithResource(factory.createParser(writer.toString)) { parser => parser.nextToken() - JacksonParser.convertField(factory, parser, dataType) + JacksonParser.convertRootField(factory, parser, dataType) } } @@ -74,9 +86,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber * 1000L)), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong * 1000L)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), @@ -197,7 +209,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructType( StructField("f1", IntegerType, true) :: StructField("f2", IntegerType, true) :: Nil), - StructType(StructField("f1", LongType, true) :: Nil) , + StructType(StructField("f1", LongType, true) :: Nil), StructType( StructField("f1", LongType, true) :: StructField("f2", IntegerType, true) :: Nil)) @@ -571,35 +583,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonDF.registerTempTable("jsonTable") } - test("jsonFile should be based on JSONRelation") { - val dir = Utils.createTempDir() - dir.delete() - val path = dir.getCanonicalFile.toURI.toString - sparkContext.parallelize(1 to 100) - .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = sqlContext.read.option("samplingRatio", "0.49").json(path) - - val analyzed = jsonDF.queryExecution.analyzed - assert( - analyzed.isInstanceOf[LogicalRelation], - "The DataFrame returned by jsonFile should be based on LogicalRelation.") - val relation = analyzed.asInstanceOf[LogicalRelation].relation - assert( - relation.isInstanceOf[JSONRelation], - "The DataFrame returned by jsonFile should be based on JSONRelation.") - assert(relation.asInstanceOf[JSONRelation].paths === Array(path)) - assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) - - val schema = StructType(StructField("a", LongType, true) :: Nil) - val logicalRelation = - sqlContext.read.schema(schema).json(path) - .queryExecution.analyzed.asInstanceOf[LogicalRelation] - val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] - assert(relationWithSchema.paths === Array(path)) - assert(relationWithSchema.schema === schema) - assert(relationWithSchema.samplingRatio > 0.99) - } - test("Loading a JSON dataset from a text file") { val dir = Utils.createTempDir() dir.delete() @@ -762,6 +745,100 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Loading a JSON dataset prefersDecimal returns schema with float types as BigDecimal") { + val jsonDF = sqlContext.read.option("prefersDecimal", "true").json(primitiveFieldAndType) + + val expectedSchema = StructType( + StructField("bigInteger", DecimalType(20, 0), true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DecimalType(17, -292), true) :: + StructField("integer", LongType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row(BigDecimal("92233720368547758070"), + true, + BigDecimal("1.7976931348623157E308"), + 10, + 21474836470L, + null, + "this is a simple string.") + ) + } + + test("Find compatible types even if inferred DecimalType is not capable of other IntegralType") { + val mixedIntegerAndDoubleRecords = sparkContext.parallelize( + """{"a": 3, "b": 1.1}""" :: + s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil) + val jsonDF = sqlContext.read + .option("prefersDecimal", "true") + .json(mixedIntegerAndDoubleRecords) + + // The values in `a` field will be decimals as they fit in decimal. For `b` field, + // they will be doubles as `1.0E-39D` does not fit. + val expectedSchema = StructType( + StructField("a", DecimalType(21, 1), true) :: + StructField("b", DoubleType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + checkAnswer( + jsonDF, + Row(BigDecimal("3"), 1.1D) :: + Row(BigDecimal("3.1"), 1.0E-39D) :: Nil + ) + } + + test("Infer big integers correctly even when it does not fit in decimal") { + val jsonDF = sqlContext.read + .json(bigIntegerRecords) + + // The value in `a` field will be a double as it does not fit in decimal. For `b` field, + // it will be a decimal as `92233720368547758070`. + val expectedSchema = StructType( + StructField("a", DoubleType, true) :: + StructField("b", DecimalType(20, 0), true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + checkAnswer(jsonDF, Row(1.0E38D, BigDecimal("92233720368547758070"))) + } + + test("Infer floating-point values correctly even when it does not fit in decimal") { + val jsonDF = sqlContext.read + .option("prefersDecimal", "true") + .json(floatingValueRecords) + + // The value in `a` field will be a double as it does not fit in decimal. For `b` field, + // it will be a decimal as `0.01` by having a precision equal to the scale. + val expectedSchema = StructType( + StructField("a", DoubleType, true) :: + StructField("b", DecimalType(2, 2), true):: Nil) + + assert(expectedSchema === jsonDF.schema) + checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal(0.01))) + + val mergedJsonDF = sqlContext.read + .option("prefersDecimal", "true") + .json(floatingValueRecords ++ bigIntegerRecords) + + val expectedMergedSchema = StructType( + StructField("a", DoubleType, true) :: + StructField("b", DecimalType(22, 2), true):: Nil) + + assert(expectedMergedSchema === mergedJsonDF.schema) + checkAnswer( + mergedJsonDF, + Row(1.0E-39D, BigDecimal(0.01)) :: + Row(1.0E38D, BigDecimal("92233720368547758070")) :: Nil + ) + } + test("Loading a JSON dataset from a text file with SQL") { val dir = Utils.createTempDir() dir.delete() @@ -847,7 +924,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") checkAnswer( - sql("select map from jsonWithSimpleMap"), + sql("select `map` from jsonWithSimpleMap"), Row(Map("a" -> 1)) :: Row(Map("b" -> 2)) :: Row(Map("c" -> 3)) :: @@ -856,7 +933,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) checkAnswer( - sql("select map['c'] from jsonWithSimpleMap"), + sql("select `map`['c'] from jsonWithSimpleMap"), Row(null) :: Row(null) :: Row(3) :: @@ -875,7 +952,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { jsonWithComplexMap.registerTempTable("jsonWithComplexMap") checkAnswer( - sql("select map from jsonWithComplexMap"), + sql("select `map` from jsonWithComplexMap"), Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) :: Row(Map("b" -> Row(null, 2))) :: Row(Map("c" -> Row(Seq(), 4))) :: @@ -885,7 +962,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) checkAnswer( - sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), + sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"), Row(Seq(1, 2, 3, null), null) :: Row(null, null) :: Row(null, 4) :: @@ -953,7 +1030,56 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } - test("Corrupt records") { + test("Corrupt records: FAILFAST mode") { + val schema = StructType( + StructField("a", StringType, true) :: Nil) + // `FAILFAST` mode should throw an exception for corrupt records. + val exceptionOne = intercept[SparkException] { + sqlContext.read + .option("mode", "FAILFAST") + .json(corruptRecords) + .collect() + } + assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) + + val exceptionTwo = intercept[SparkException] { + sqlContext.read + .option("mode", "FAILFAST") + .schema(schema) + .json(corruptRecords) + .collect() + } + assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode: {")) + } + + test("Corrupt records: DROPMALFORMED mode") { + val schemaOne = StructType( + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + val schemaTwo = StructType( + StructField("a", StringType, true) :: Nil) + // `DROPMALFORMED` mode should skip corrupt records + val jsonDFOne = sqlContext.read + .option("mode", "DROPMALFORMED") + .json(corruptRecords) + checkAnswer( + jsonDFOne, + Row("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + assert(jsonDFOne.schema === schemaOne) + + val jsonDFTwo = sqlContext.read + .option("mode", "DROPMALFORMED") + .schema(schemaTwo) + .json(corruptRecords) + checkAnswer( + jsonDFTwo, + Row("str_a_4") :: Nil) + assert(jsonDFTwo.schema === schemaTwo) + } + + test("Corrupt records: PERMISSIVE mode") { // Test if we can query corrupt records. withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { withTempTable("jsonTable") { @@ -1007,6 +1133,27 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-13953 Rename the corrupt record field via option") { + val jsonDF = sqlContext.read + .option("columnNameOfCorruptRecord", "_malformed") + .json(corruptRecords) + val schema = StructType( + StructField("_malformed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonDF.schema) + checkAnswer( + jsonDF.selectExpr("a", "b", "c", "_malformed"), + Row(null, null, null, "{") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + } + test("SPARK-4068: nulls in arrays") { val jsonDF = sqlContext.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") @@ -1086,7 +1233,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") val jsonDF = sqlContext.read.json(primitiveFieldAndType) - val primTable = sqlContext.read.json(jsonDF.toJSON) + val primTable = sqlContext.read.json(jsonDF.toJSON.rdd) primTable.registerTempTable("primitiveTable") checkAnswer( sql("select * from primitiveTable"), @@ -1099,7 +1246,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) val complexJsonDF = sqlContext.read.json(complexFieldAndType1) - val compTable = sqlContext.read.json(complexJsonDF.toJSON) + val compTable = sqlContext.read.json(complexJsonDF.toJSON.rdd) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1163,76 +1310,33 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("JSONRelation equality test") { - val relation0 = new JSONRelation( - Some(empty), - 1.0, - false, - Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) - val logicalRelation0 = LogicalRelation(relation0) - val relation1 = new JSONRelation( - Some(singleRow), - 1.0, - false, - Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) - val logicalRelation1 = LogicalRelation(relation1) - val relation2 = new JSONRelation( - Some(singleRow), - 0.5, - false, - Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) - val logicalRelation2 = LogicalRelation(relation2) - val relation3 = new JSONRelation( - Some(singleRow), - 1.0, - false, - Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(sqlContext) - val logicalRelation3 = LogicalRelation(relation3) - - assert(relation0 !== relation1) - assert(!logicalRelation0.sameResult(logicalRelation1), - s"$logicalRelation0 and $logicalRelation1 should be considered not having the same result.") - - assert(relation1 === relation2) - assert(logicalRelation1.sameResult(logicalRelation2), - s"$logicalRelation1 and $logicalRelation2 should be considered having the same result.") - - assert(relation1 !== relation3) - assert(!logicalRelation1.sameResult(logicalRelation3), - s"$logicalRelation1 and $logicalRelation3 should be considered not having the same result.") - - assert(relation2 !== relation3) - assert(!logicalRelation2.sameResult(logicalRelation3), - s"$logicalRelation2 and $logicalRelation3 should be considered not having the same result.") - withTempPath(dir => { val path = dir.getCanonicalFile.toURI.toString sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val d1 = ResolvedDataSource( + val d1 = DataSource( sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], - provider = classOf[DefaultSource].getCanonicalName, - options = Map("path" -> path)) + bucketSpec = None, + className = classOf[DefaultSource].getCanonicalName, + options = Map("path" -> path)).resolveRelation() - val d2 = ResolvedDataSource( + val d2 = DataSource( sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], - provider = classOf[DefaultSource].getCanonicalName, - options = Map("path" -> path)) + bucketSpec = None, + className = classOf[DefaultSource].getCanonicalName, + options = Map("path" -> path)).resolveRelation() assert(d1 === d2) }) } test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = InferSchema(empty, 1.0, "") + val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map())) assert(StructType(Seq()) === emptySchema) } @@ -1256,7 +1360,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-8093 Erase empty structs") { - val emptySchema = InferSchema(emptyRecords, 1.0, "") + val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map())) assert(StructType(Seq()) === emptySchema) } @@ -1325,7 +1429,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val constantValues = Seq( - "a string in binary".getBytes("UTF-8"), + "a string in binary".getBytes(StandardCharsets.UTF_8), null, true, 1.toByte, @@ -1393,4 +1497,186 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Setting it twice as the name of the propery has changed between hadoop versions. + hadoopConfiguration.setClass( + "mapred.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + + test("SPARK-12057 additional corrupt records do not throw exceptions") { + // Test if we can query corrupt records. + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("dummy", StringType, true) :: Nil) + + { + // We need to make sure we can infer the schema. + val jsonDF = sqlContext.read.json(additionalCorruptRecords) + assert(jsonDF.schema === schema) + } + + { + val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords) + jsonDF.registerTempTable("jsonTable") + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + sql( + """ + |SELECT dummy, _unparsed + |FROM jsonTable + """.stripMargin), + Row("test", null) :: + Row(null, """[1,2,3]""") :: + Row(null, """":"test", "a":1}""") :: + Row(null, """42""") :: + Row(null, """ ","ian":"test"}""") :: Nil + ) + } + } + } + } + + test("Parse JSON rows having an array type and a struct type in the same field.") { + withTempDir { dir => + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + arrayAndStructRecords.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + val schema = + StructType( + StructField("a", StructType( + StructField("b", StringType) :: Nil + )) :: Nil) + val jsonDF = sqlContext.read.schema(schema).json(path) + assert(jsonDF.count() == 2) + } + } + + test("SPARK-12872 Support to specify the option for compression codec") { + withTempDir { dir => + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + val jsonDF = sqlContext.read.json(path) + val jsonDir = new File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write + .format("json") + .option("compression", "gZiP") + .save(jsonDir) + + val compressedFiles = new File(jsonDir).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".json.gz"))) + + val jsonCopy = sqlContext.read + .format("json") + .load(jsonDir) + + assert(jsonCopy.count == jsonDF.count) + val jsonCopySome = jsonCopy.selectExpr("string", "long", "boolean") + val jsonDFSome = jsonDF.selectExpr("string", "long", "boolean") + checkAnswer(jsonCopySome, jsonDFSome) + } + } + + test("SPARK-13543 Write the output as uncompressed via option()") { + val clonedConf = new Configuration(hadoopConfiguration) + hadoopConfiguration.set("mapreduce.output.fileoutputformat.compress", "true") + hadoopConfiguration + .set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) + hadoopConfiguration + .set("mapreduce.output.fileoutputformat.compress.codec", classOf[GzipCodec].getName) + hadoopConfiguration.set("mapreduce.map.output.compress", "true") + hadoopConfiguration.set("mapreduce.map.output.compress.codec", classOf[GzipCodec].getName) + withTempDir { dir => + try { + val dir = Utils.createTempDir() + dir.delete() + + val path = dir.getCanonicalPath + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + val jsonDF = sqlContext.read.json(path) + val jsonDir = new File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write + .format("json") + .option("compression", "none") + .save(jsonDir) + + val compressedFiles = new File(jsonDir).listFiles() + assert(compressedFiles.exists(!_.getName.endsWith(".json.gz"))) + + val jsonCopy = sqlContext.read + .format("json") + .load(jsonDir) + + assert(jsonCopy.count == jsonDF.count) + val jsonCopySome = jsonCopy.selectExpr("string", "long", "boolean") + val jsonDFSome = jsonDF.selectExpr("string", "long", "boolean") + checkAnswer(jsonCopySome, jsonDFSome) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + + test("Casting long as timestamp") { + withTempTable("jsonTable") { + val schema = (new StructType).add("ts", TimestampType) + val jsonDF = sqlContext.read.schema(schema).json(timestampAsLong) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select ts from jsonTable"), + Row(java.sql.Timestamp.valueOf("2016-01-02 03:04:05")) + ) + } + } + + test("wide nested json table") { + val nested = (1 to 100).map { i => + s""" + |"c$i": $i + """.stripMargin + }.mkString(", ") + val json = s""" + |{"a": [{$nested}], "b": [{$nested}]} + """.stripMargin + val rdd = sqlContext.sparkContext.makeRDD(Seq(json)) + val df = sqlContext.read.json(rdd) + assert(df.schema.size === 2) + df.collect() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 713d1da1cb515..2873c6a881bef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -188,6 +188,14 @@ private[json] trait TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) + def additionalCorruptRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + """{"dummy":"test"}""" :: + """[1,2,3]""" :: + """":"test", "a":1}""" :: + """42""" :: + """ ","ian":"test"}""" :: Nil) + def emptyRecords: RDD[String] = sqlContext.sparkContext.parallelize( """{""" :: @@ -197,6 +205,22 @@ private[json] trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) + def timestampAsLong: RDD[String] = + sqlContext.sparkContext.parallelize( + """{"ts":1451732645}""" :: Nil) + + def arrayAndStructRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + """{"a": {"b": 1}}""" :: + """{"a": []}""" :: Nil) + + def floatingValueRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil) + + def bigIntegerRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil) lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 36b929ee1f409..f98ea8c5aeb80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.io.File import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -59,7 +59,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared .setLongColumn(i.toLong * 10) .setFloatColumn(i.toFloat + 0.1f) .setDoubleColumn(i.toDouble + 0.2d) - .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes(StandardCharsets.UTF_8))) .setStringColumn(s"val_$i") .build()) } @@ -74,7 +74,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared i.toLong * 10, i.toFloat + 0.1f, i.toDouble + 0.2d, - s"val_$i".getBytes("UTF-8"), + s"val_$i".getBytes(StandardCharsets.UTF_8), s"val_$i") }) } @@ -103,7 +103,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared .setMaybeLongColumn(i.toLong * 10) .setMaybeFloatColumn(i.toFloat + 0.1f) .setMaybeDoubleColumn(i.toDouble + 0.2d) - .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes(StandardCharsets.UTF_8))) .setMaybeStringColumn(s"val_$i") .build() } @@ -124,7 +124,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared i.toLong * 10, i.toFloat + 0.1f, i.toDouble + 0.2d, - s"val_$i".getBytes("UTF-8"), + s"val_$i".getBytes(StandardCharsets.UTF_8), s"val_$i") } }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 0835bd123049b..4217c81ff3e24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapA import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala new file mode 100644 index 0000000000000..88fcfce0ec1bc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.test.SharedSQLContext + +// TODO: this needs a lot more testing but it's currently not easy to test with the parquet +// writer abstractions. Revisit. +class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext { + import testImplicits._ + + val ROW = ((1).toByte, 2, 3L, "abc") + val NULL_ROW = ( + null.asInstanceOf[java.lang.Byte], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[String]) + + test("All Types Dictionary") { + (1 :: 1000 :: Nil).foreach { n => { + withTempPath { dir => + List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new VectorizedParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getByte(i) == 1) + assert(batch.column(1).getInt(i) == 2) + assert(batch.column(2).getLong(i) == 3) + assert(batch.column(3).getUTF8String(i).toString == "abc") + i += 1 + } + reader.close() + } + }} + } + + test("All Types Null") { + (1 :: 100 :: Nil).foreach { n => { + withTempPath { dir => + val data = List.fill(n)(NULL_ROW).toDF + data.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new VectorizedParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).isNullAt(i)) + assert(batch.column(1).isNullAt(i)) + assert(batch.column(2).isNullAt(i)) + assert(batch.column(3).isNullAt(i)) + i += 1 + } + reader.close() + }} + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c24c9f025dad7..51183e970d965 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,15 +17,22 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.parquet.filter2.predicate.Operators._ +import java.nio.charset.StandardCharsets + import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.FilterApi._ +import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -50,27 +57,31 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - val analyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation, _)) => filters - }.flatten - assert(analyzedPredicate.nonEmpty) - - val selectedFilters = analyzedPredicate.flatMap(DataSourceStrategy.translateFilter) - assert(selectedFilters.nonEmpty) - - selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach { f => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(relation: HadoopFsRelation, _, _)) => + maybeRelation = Some(relation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) - assert(f.getClass === filterClass) + maybeFilter.exists(_.getClass === filterClass) } + checker(stripSparkFilter(query), expected) } - checker(query, expected) } } @@ -91,7 +102,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (implicit df: DataFrame): Unit = { def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { - df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } @@ -111,7 +122,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 === true, classOf[Eq[_]], true) checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) - checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) + checkFilterPredicate('_1 =!= true, classOf[NotEq[_]], false) } } @@ -122,7 +133,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) @@ -148,7 +159,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) @@ -174,7 +185,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) @@ -200,7 +211,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) @@ -229,7 +240,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1") checkFilterPredicate( - '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + '_1 =!= "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") @@ -251,7 +262,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // See https://issues.apache.org/jira/browse/SPARK-11153 ignore("filter pushdown - binary") { implicit class IntToBinary(int: Int) { - def b: Array[Byte] = int.toString.getBytes("UTF-8") + def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) } withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => @@ -263,7 +274,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) checkBinaryFilterPredicate( - '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + '_1 =!= 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) @@ -294,7 +305,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(path).filter("part = 1"), + sqlContext.read.parquet(dir.getCanonicalPath).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } @@ -311,12 +322,53 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(path).filter("a > 0 and (part = 0 or a > 1)"), + sqlContext.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), (2 to 3).map(i => Row(i, i.toString, 1))) } } } + test("SPARK-12231: test the filter and empty project in partitioned DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}" + (1 to 3).map(i => (i, i + 1, i + 2, i + 3)).toDF("a", "b", "c", "d"). + write.partitionBy("a").parquet(path) + + // The filter "a > 1 or b < 2" will not get pushed down, and the projection is empty, + // this query will throw an exception since the project from combinedFilter expect + // two projection while the + val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + + assert(df1.filter("a > 1 or b < 2").count() == 2) + } + } + } + + test("SPARK-12231: test the new projection in partitioned DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}" + (1 to 3).map(i => (i, i + 1, i + 2, i + 3)).toDF("a", "b", "c", "d"). + write.partitionBy("a").parquet(path) + + // test the generate new projection case + // when projects != partitionAndNormalColumnProjs + + val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + + checkAnswer( + df1.filter("a > 1 or b > 2").orderBy("a").selectExpr("a", "b", "c", "d"), + (2 to 3).map(i => Row(i, i + 1, i + 2, i + 3))) + } + } + } + + test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") { import testImplicits._ @@ -330,9 +382,167 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. This is a Parquet issue (PARQUET-389). + val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") checkAnswer( - sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"), - (1 to 1).map(i => Row(i, i.toString, null))) + df, + Row(1, "1", null)) + + // The fields "a" and "c" only exist in one Parquet file. + assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathThree = s"${dir.getCanonicalPath}/table3" + df.write.parquet(pathThree) + + // We will remove the temporary metadata when writing Parquet file. + val schema = sqlContext.read.parquet(pathThree).schema + assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + val pathFour = s"${dir.getCanonicalPath}/table4" + val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + dfStruct.select(struct("a").as("s")).write.parquet(pathFour) + + val pathFive = s"${dir.getCanonicalPath}/table5" + val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") + dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) + + // If the "s.c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. + val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1") + .selectExpr("s") + checkAnswer(dfStruct3, Row(Row(null, 1))) + + // The fields "s.a" and "s.c" only exist in one Parquet file. + val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] + assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathSix = s"${dir.getCanonicalPath}/table6" + dfStruct3.write.parquet(pathSix) + + // We will remove the temporary metadata when writing Parquet file. + val forPathSix = sqlContext.read.parquet(pathSix).schema + assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + // sanity test: make sure optional metadata field is not wrongly set. + val pathSeven = s"${dir.getCanonicalPath}/table7" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven) + val pathEight = s"${dir.getCanonicalPath}/table8" + (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) + + val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") + checkAnswer( + df2, + Row(1, "1")) + + // The fields "a" and "b" exist in both two Parquet files. No metadata is set. + assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField)) + assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField)) + } + } + } + + // The unsafe row RecordReader does not support row by row filtering so run it with it disabled. + test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + val df = sqlContext.read.parquet(path).filter("a = 2") + + // The result should be single row. + // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + assert(stripSparkFilter(df).count == 1) + } + } + } + } + + test("SPARK-12218: 'Not' is included in Parquet filter pushdown") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.parquet(path) + + checkAnswer( + sqlContext.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + + checkAnswer( + sqlContext.read.parquet(path).where("not (a = 2 and b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + } + } + } + + test("SPARK-12218 Converting conjunctions into Parquet filter predicates") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = true), + StructField("c", DoubleType, nullable = true) + )) + + assertResult(Some(and( + lt(intColumn("a"), 10: Integer), + gt(doubleColumn("c"), 1.5: java.lang.Double))) + ) { + ParquetFilters.createFilter( + schema, + sources.And( + sources.LessThan("a", 10), + sources.GreaterThan("c", 1.5D))) + } + + assertResult(None) { + ParquetFilters.createFilter( + schema, + sources.And( + sources.LessThan("a", 10), + sources.StringContains("b", "prefix"))) + } + + assertResult(None) { + ParquetFilters.createFilter( + schema, + sources.Not( + sources.And( + sources.GreaterThan("a", 1), + sources.StringContains("b", "prefix")))) + } + } + + test("SPARK-11164: test the parquet filter in") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path) + + // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + val df = sqlContext.read.parquet(path).where("b in (0,2)") + assert(stripSparkFilter(df).count == 3) + + val df1 = sqlContext.read.parquet(path).where("not (b in (1))") + assert(stripSparkFilter(df1).count == 3) + + val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)") + assert(stripSparkFilter(df2).count == 2) + + val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)") + assert(stripSparkFilter(df3).count == 4) + + val df4 = sqlContext.read.parquet(path).where("not (a <= 2)") + assert(stripSparkFilter(df4).count == 3) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f14b2886a9ecb..581095d3dc1c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,28 +17,30 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.util.Collections - import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.column.{Encoding, ParquetProperties} import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -91,6 +93,35 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-11694 Parquet logical types are not being tested properly") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 a(INT_8); + | required int32 b(INT_16); + | required int32 c(DATE); + | required int32 d(DECIMAL(1,0)); + | required int64 e(DECIMAL(10,0)); + | required binary f(UTF8); + | required binary g(ENUM); + | required binary h(DECIMAL(32,0)); + | required fixed_len_byte_array(32) i(DECIMAL(32,0)); + |} + """.stripMargin) + + val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), + DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0)) + + withTempPath { location => + val path = new Path(location.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + readParquetFile(path.toString)(df => { + val sparkTypes = df.schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) + }) + } + } + test("string") { val data = (1 to 4).map(i => Tuple1(i.toString)) // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL @@ -114,7 +145,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + readParquetFile(dir.getCanonicalPath) { df => { + checkAnswer(df, data.collect().toSeq) + }} } } } @@ -130,7 +163,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { dir => val data = makeDateRDD() data.write.parquet(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) + readParquetFile(dir.getCanonicalPath) { df => + checkAnswer(df, data.collect().toSeq) + } } } @@ -206,6 +241,43 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-10113 Support for unsigned Parquet logical types") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 c(UINT_32); + |} + """.stripMargin) + + withTempPath { location => + val path = new Path(location.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + val errorMessage = intercept[Throwable] { + sqlContext.read.parquet(path.toString).printSchema() + }.toString + assert(errorMessage.contains("Parquet type not supported")) + } + } + + test("SPARK-11692 Support for Parquet logical types, JSON and BSON (embedded types)") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required binary a(JSON); + | required binary b(BSON); + |} + """.stripMargin) + + val expectedSparkTypes = Seq(StringType, BinaryType) + + withTempPath { location => + val path = new Path(location.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) + } + } + test("compression codec") { def compressionCodecFor(path: String, codecName: String): String = { val codecs = for { @@ -270,9 +342,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempDir { dir => val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) - checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i => - Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - }) + readParquetFile(path.toString) { df => + checkAnswer(df, (0 until 10).map { i => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) + } } } @@ -298,7 +371,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withParquetFile((1 to 10).map(i => (i, i.toString))) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) - checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple)) + readParquetFile(file) { df => + checkAnswer(df, newData.map(Row.fromTuple)) + } } } @@ -307,7 +382,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) - checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple)) + readParquetFile(file) { df => + checkAnswer(df, data.map(Row.fromTuple)) + } } } @@ -327,7 +404,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) - checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple)) + readParquetFile(file) { df => + checkAnswer(df, (data ++ newData).map(Row.fromTuple)) + } } } @@ -350,75 +429,22 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { """.stripMargin) withTempPath { location => - val extraMetadata = Collections.singletonMap( - CatalystReadSupport.SPARK_METADATA_KEY, sparkSchema.toString) - val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) val path = new Path(location.getCanonicalPath) - - ParquetFileWriter.writeMetadataFile( - sparkContext.hadoopConfiguration, - path, - Collections.singletonList( - new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) - - assertResult(sqlContext.read.parquet(path.toString).schema) { - StructType( - StructField("a", BooleanType, nullable = false) :: - StructField("b", IntegerType, nullable = false) :: - Nil) - } - } - } - - test("SPARK-6352 DirectParquetOutputCommitter") { - val clonedConf = new Configuration(hadoopConfiguration) - - // Write to a parquet file and let it fail. - // _temporary should be missing if direct output committer works. - try { - hadoopConfiguration.set("spark.sql.parquet.output.committer.class", - classOf[DirectParquetOutputCommitter].getCanonicalName) - sqlContext.udf.register("div0", (x: Int) => x / 0) - withTempPath { dir => - intercept[org.apache.spark.SparkException] { - sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) - } - val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(hadoopConfiguration) - assert(!fs.exists(path)) - } - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - } - } - - test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { - val clonedConf = new Configuration(hadoopConfiguration) - - // Write to a parquet file and let it fail. - // _temporary should be missing if direct output committer works. - try { - hadoopConfiguration.set("spark.sql.parquet.output.committer.class", - "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") - sqlContext.udf.register("div0", (x: Int) => x / 0) - withTempPath { dir => - intercept[org.apache.spark.SparkException] { - sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf, extraMetadata) + + readParquetFile(path.toString) { df => + assertResult(df.schema) { + StructType( + StructField("a", BooleanType, nullable = true) :: + StructField("b", IntegerType, nullable = true) :: + Nil) } - val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(hadoopConfiguration) - assert(!fs.exists(path)) } - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } - test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => val clonedConf = new Configuration(hadoopConfiguration) @@ -489,24 +515,184 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-11044 Parquet writer version fixed as version1 ") { + // For dictionary encoding, Parquet changes the encoding types according to its writer + // version. So, this test checks one of the encoding types in order to ensure that + // the file is written with writer version2. + withTempPath { dir => + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Write a Parquet file with writer version2. + hadoopConfiguration.set(ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_2_0.toString) + + // By default, dictionary encoding is enabled from Parquet 1.2.0 but + // it is enabled just in case. + hadoopConfiguration.setBoolean(ParquetOutputFormat.ENABLE_DICTIONARY, true) + val path = s"${dir.getCanonicalPath}/part-r-0.parquet" + sqlContext.range(1 << 16).selectExpr("(id % 4) AS i") + .coalesce(1).write.mode("overwrite").parquet(path) + + val blockMetadata = readFooter(new Path(path), hadoopConfiguration).getBlocks.asScala.head + val columnChunkMetadata = blockMetadata.getColumns.asScala.head + + // If the file is written with version2, this should include + // Encoding.RLE_DICTIONARY type. For version1, it is Encoding.PLAIN_DICTIONARY + assert(columnChunkMetadata.getEncodings.contains(Encoding.RLE_DICTIONARY)) + } finally { + // Manually clear the hadoop configuration for other tests. + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + + test("null and non-null strings") { + // Create a dataset where the first values are NULL and then some non-null values. The + // number of non-nulls needs to be bigger than the ParquetReader batch size. + val data: Dataset[String] = sqlContext.range(200).map (i => + if (i < 150) null + else "a" + ) + val df = data.toDF("col") + assert(df.agg("col" -> "count").collect().head.getLong(0) == 50) + + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/data" + df.write.parquet(path) + + readParquetFile(path) { df2 => + assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50) + } + } + } + test("read dictionary encoded decimals written as INT32") { - checkAnswer( - // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-i32.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i32.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + } + } } test("read dictionary encoded decimals written as INT64") { - checkAnswer( - // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-i64.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i64.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + } + } + } + + test("read dictionary encoded decimals written as FIXED_LEN_BYTE_ARRAY") { + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-fixed-len.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + } + } } - // TODO Adds test case for reading dictionary encoded decimals written as `FIXED_LEN_BYTE_ARRAY` - // The Parquet writer version Spark 1.6 and prior versions use is `PARQUET_1_0`, which doesn't - // provide dictionary encoding support for `FIXED_LEN_BYTE_ARRAY`. Should add a test here once - // we upgrade to `PARQUET_2_0`. + test("SPARK-12589 copy() on rows returned from reader works for strings") { + withTempPath { dir => + val data = (1, "abc") ::(2, "helloabcde") :: Nil + data.toDF().write.parquet(dir.getCanonicalPath) + var hash1: Int = 0 + var hash2: Int = 0 + (false :: true :: Nil).foreach { v => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> v.toString) { + val df = sqlContext.read.parquet(dir.getCanonicalPath) + val rows = df.queryExecution.toRdd.map(_.copy()).collect() + val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow]) + if (!v) { + hash1 = unsafeRows(0).hashCode() + hash2 = unsafeRows(1).hashCode() + } else { + assert(hash1 == unsafeRows(0).hashCode()) + assert(hash2 == unsafeRows(1).hashCode()) + } + } + } + } + } + + test("VectorizedParquetRecordReader - direct path read") { + val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString)) + withTempPath { dir => + sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); + { + val reader = new VectorizedParquetRecordReader + try { + reader.initialize(file, null) + val result = mutable.ArrayBuffer.empty[(Int, String)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue.asInstanceOf[InternalRow] + val v = (row.getInt(0), row.getString(1)) + result += v + } + assert(data == result) + } finally { + reader.close() + } + } + + // Project just one column + { + val reader = new VectorizedParquetRecordReader + try { + reader.initialize(file, ("_2" :: Nil).asJava) + val result = mutable.ArrayBuffer.empty[(String)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue.asInstanceOf[InternalRow] + result += row.getString(0) + } + assert(data.map(_._2) == result) + } finally { + reader.close() + } + } + + // Project columns in opposite order + { + val reader = new VectorizedParquetRecordReader + try { + reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) + val result = mutable.ArrayBuffer.empty[(String, Int)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue.asInstanceOf[InternalRow] + val v = (row.getString(0), row.getInt(1)) + result += v + } + assert(data.map { x => (x._2, x._1) } == result) + } finally { + reader.close() + } + } + + // Empty projection + { + val reader = new VectorizedParquetRecordReader + try { + reader.initialize(file, List[String]().asJava) + var result = 0 + while (reader.nextKeyValue()) { + result += 1 + } + assert(result == data.length) + } finally { + reader.close() + } + } + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 61cc0da50865c..f875b54cd6649 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,7 +29,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -66,7 +68,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + parsePartitions(paths.map(new Path(_)), defaultPartitionName, true, Set.empty[Path]) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -76,7 +78,37 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10/b=20", "hdfs://host:9000/path/_temporary/path") - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/"))) + + // Valid + paths = Seq( + "hdfs://host:9000/path/something=true/table/", + "hdfs://host:9000/path/something=true/table/_temporary", + "hdfs://host:9000/path/something=true/table/a=10/b=20", + "hdfs://host:9000/path/something=true/table/_temporary/path") + + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/something=true/table"))) + + // Valid + paths = Seq( + "hdfs://host:9000/path/table=true/", + "hdfs://host:9000/path/table=true/_temporary", + "hdfs://host:9000/path/table=true/a=10/b=20", + "hdfs://host:9000/path/table=true/_temporary/path") + + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/table=true"))) // Invalid paths = Seq( @@ -85,7 +117,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/path1") exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/"))) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -101,19 +137,24 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/tmp/tables/nonPartitionedTable2") exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/tmp/tables/"))) } assert(exception.getMessage().contains("Conflicting directory structures detected")) } test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - assert(expected === parsePartition(new Path(path), defaultPartitionName, true)._1) + val actual = parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path])._1 + assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName, true) + parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path]) }.getMessage assert(message.contains(expected)) @@ -152,8 +193,17 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } test("parse partitions") { - def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec) + def check( + paths: Seq[String], + spec: PartitionSpec, + rootPaths: Set[Path] = Set.empty[Path]): Unit = { + val actualSpec = + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + rootPaths) + assert(actualSpec === spec) } check(Seq( @@ -232,7 +282,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec) + val actualSpec = + parsePartitions(paths.map(new Path(_)), defaultPartitionName, false, Set.empty[Path]) + assert(actualSpec === spec) } check(Seq( @@ -513,7 +565,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: ParquetRelation, _) => + case LogicalRelation(relation: HadoopFsRelation, _, _) => assert(relation.partitionSpec === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") @@ -590,6 +642,93 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } + test("SPARK-11678: Partition discovery stops at the root path of the dataset") { + withTempPath { dir => + val tablePath = new File(dir, "key=value") + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + } + + withTempPath { dir => + val path = new File(dir, "key=value") + val tablePath = new File(path, "table") + + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + } + } + + test("use basePath to specify the root dir of a partitioned table.") { + withTempPath { dir => + val tablePath = new File(dir, "table") + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + val twoPartitionsDF = + sqlContext + .read + .option("basePath", tablePath.getCanonicalPath) + .parquet( + s"${tablePath.getCanonicalPath}/b=1", + s"${tablePath.getCanonicalPath}/b=2") + + checkAnswer(twoPartitionsDF, df.filter("b != 3")) + + intercept[AssertionError] { + sqlContext + .read + .parquet( + s"${tablePath.getCanonicalPath}/b=1", + s"${tablePath.getCanonicalPath}/b=2") + } + } + } + + test("_SUCCESS should not break partitioning discovery") { + Seq(1, 32).foreach { threshold => + // We have two paths to list files, one at driver side, another one that we use + // a Spark job. We need to test both ways. + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> threshold.toString) { + withTempPath { dir => + val tablePath = new File(dir, "table") + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1", "_SUCCESS")) + Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1", "_SUCCESS")) + Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1/d=1", "_SUCCESS")) + checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + } + } + } + } + test("listConflictingPartitionColumns") { def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 70fae32b7e7a1..7d206e7bc443d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -22,9 +22,10 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,11 +46,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("appending") { val data = (0 until 10).map(i => (i, i.toString)) sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") - withParquetTable(data, "t") { + // Query appends, don't test with both read modes. + withParquetTable(data, "t", false) { sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(TableIdentifier("tmp")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("tmp"), ignoreIfNotExists = true) } test("overwriting") { @@ -59,7 +62,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(TableIdentifier("tmp")) + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("tmp"), ignoreIfNotExists = true) } test("self-join") { @@ -69,7 +73,8 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext (maybeInt, i.toString) } - withParquetTable(data, "t") { + // TODO: vectorized doesn't work here because it requires UnsafeRows + withParquetTable(data, "t", false) { val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") val queryOutput = selfJoin.queryExecution.analyzed.output @@ -252,6 +257,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-11997 parquet with null partition values") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(1, 3) + .selectExpr("if(id % 2 = 0, null, id) AS n", "id") + .write.partitionBy("n").parquet(path) + + checkAnswer( + sqlContext.read.parquet(path).filter("n is null"), + Row(2, null)) + } + } + // This test case is ignored because of parquet-mr bug PARQUET-370 ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { withTempPath { dir => @@ -561,6 +579,16 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext assert(CatalystReadSupport.expandUDT(schema) === expected) } + + test("read/write wide table") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + df.write.mode(SaveMode.Overwrite).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df) + } + } } object TestingUDT { @@ -574,14 +602,11 @@ object TestingUDT { .add("b", LongType, nullable = false) .add("c", DoubleType, nullable = false) - override def serialize(obj: Any): Any = { + override def serialize(n: NestedStruct): Any = { val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) - obj match { - case n: NestedStruct => - row.setInt(0, n.a) - row.setLong(1, n.b) - row.setDouble(2, n.c) - } + row.setInt(0, n.a) + row.setLong(1, n.b) + row.setDouble(2, n.c) } override def userClass: Class[NestedStruct] = classOf[NestedStruct] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala new file mode 100644 index 0000000000000..cef541f0444be --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure parquet read performance. + * To run this: + * spark-submit --class --jars + */ +object ParquetReadBenchmark { + val conf = new SparkConf() + conf.set("spark.sql.parquet.compression.codec", "snappy") + val sc = new SparkContext("local[1]", "test-sql-context", conf) + val sqlContext = new SQLContext(sc) + + // Set default configs. Individual cases will change them if necessary. + sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(sqlContext.dropTempTable) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) + } + } + } + + def intScanBenchmark(values: Int): Unit = { + // Benchmarks running through spark sql. + val sqlBenchmark = new Benchmark("SQL Single Int Column Scan", values) + // Benchmarks driving reader component directly. + val parquetReaderBenchmark = new Benchmark("Parquet Reader Single Int Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql("select cast(id as INT) as id from t1") + .write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + + sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => + sqlContext.sql("select sum(id) from tempTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet MR") { iter => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + sqlContext.sql("select sum(id) from tempTable").collect() + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray + // Driving the parquet reader in batch mode directly. + parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => + var sum = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + val col = batch.column(0) + while (reader.nextBatch()) { + val numRows = batch.numRows() + var i = 0 + while (i < numRows) { + if (!col.isNullAt(i)) sum += col.getInt(i) + i += 1 + } + } + } finally { + reader.close() + } + } + } + + // Decoding in vectorized but having the reader return rows. + parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => + var sum = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val it = batch.rowIterator() + while (it.hasNext) { + val record = it.next() + if (!record.isNullAt(0)) sum += record.getInt(0) + } + } + } finally { + reader.close() + } + } + } + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + SQL Parquet Vectorized 215 / 262 73.0 13.7 1.0X + SQL Parquet MR 1946 / 2083 8.1 123.7 0.1X + */ + sqlBenchmark.run() + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + ParquetReader Vectorized 123 / 152 127.8 7.8 1.0X + ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 0.7X + */ + parquetReaderBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + withTempPath { dir => + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") + .write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + + val benchmark = new Benchmark("Int and String Scan", values) + + benchmark.addCase("SQL Parquet Vectorized") { iter => + sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + } + + benchmark.addCase("SQL Parquet MR") { iter => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + SQL Parquet Vectorized 628 / 720 16.7 59.9 1.0X + SQL Parquet MR 1905 / 2239 5.5 181.7 0.3X + */ + benchmark.run() + } + } + } + + def stringDictionaryScanBenchmark(values: Int): Unit = { + withTempPath { dir => + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") + .write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + + val benchmark = new Benchmark("String Dictionary", values) + + benchmark.addCase("SQL Parquet Vectorized") { iter => + sqlContext.sql("select sum(length(c1)) from tempTable").collect + } + + benchmark.addCase("SQL Parquet MR") { iter => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + sqlContext.sql("select sum(length(c1)) from tempTable").collect + } + } + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + SQL Parquet Vectorized 329 / 337 31.9 31.4 1.0X + SQL Parquet MR 1131 / 1325 9.3 107.8 0.3X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + withTempPath { dir => + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql("select id % 2 as p, cast(id as INT) as id from t1") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + + val benchmark = new Benchmark("Partitioned Table", values) + + benchmark.addCase("Read data column") { iter => + sqlContext.sql("select sum(id) from tempTable").collect + } + + benchmark.addCase("Read partition column") { iter => + sqlContext.sql("select sum(p) from tempTable").collect + } + + benchmark.addCase("Read both columns") { iter => + sqlContext.sql("select sum(p), sum(id) from tempTable").collect + } + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Read data column 191 / 250 82.1 12.2 1.0X + Read partition column 82 / 86 192.4 5.2 2.3X + Read both columns 220 / 248 71.5 14.0 0.9X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + withTempPath { dir => + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") + .write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + + val benchmark = new Benchmark("String with Nulls Scan", values) + + benchmark.addCase("SQL Parquet Vectorized") { iter => + sqlContext.sql("select sum(length(c2)) from tempTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray + benchmark.addCase("PR Vectorized") { num => + var sum = 0 + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader + try { + reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val rowIterator = batch.rowIterator() + while (rowIterator.hasNext) { + val row = rowIterator.next() + val value = row.getUTF8String(0) + if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() + } + } + } finally { + reader.close() + } + } + } + + benchmark.addCase("PR Vectorized (Null Filtering)") { num => + var sum = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader + try { + reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) + val batch = reader.resultBatch() + batch.filterNullsInColumn(0) + batch.filterNullsInColumn(1) + while (reader.nextBatch()) { + val rowIterator = batch.rowIterator() + while (rowIterator.hasNext) { + sum += rowIterator.next().getUTF8String(0).numBytes() + } + } + } finally { + reader.close() + } + } + } + + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + String with Nulls Scan (0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + SQL Parquet Vectorized 1229 / 1648 8.5 117.2 1.0X + PR Vectorized 833 / 846 12.6 79.4 1.5X + PR Vectorized (Null Filtering) 732 / 782 14.3 69.8 1.7X + + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + String with Nulls Scan (50%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + SQL Parquet Vectorized 995 / 1053 10.5 94.9 1.0X + PR Vectorized 732 / 772 14.3 69.8 1.4X + PR Vectorized (Null Filtering) 725 / 790 14.5 69.1 1.4X + + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + String with Nulls Scan (95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + SQL Parquet Vectorized 326 / 333 32.2 31.1 1.0X + PR Vectorized 190 / 200 55.1 18.2 1.7X + PR Vectorized (Null Filtering) 168 / 172 62.2 16.1 1.9X + */ + + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + intScanBenchmark(1024 * 1024 * 15) + intStringScanBenchmark(1024 * 1024 * 10) + stringDictionaryScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 60fa81b1ab819..90e3d50714ef3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -260,7 +261,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { int96AsTimestamp = true, writeLegacyParquetFormat = true) - testSchemaInference[Tuple1[Pair[Int, String]]]( + testSchemaInference[Tuple1[(Int, String)]]( "struct", """ |message root { @@ -449,6 +450,35 @@ class ParquetSchemaSuite extends ParquetSchemaTest { }.getMessage.contains("detected conflicting schemas")) } + test("schema merging failure error message") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.parquet(s"$path/p=1") + sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + + val message = intercept[SparkException] { + sqlContext.read.option("mergeSchema", "true").parquet(path).schema + }.getMessage + + assert(message.contains("Failed merging schema of file")) + } + + // test for second merging (after read Parquet schema in parallel done) + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.parquet(s"$path/p=1") + sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + + sqlContext.sparkContext.conf.set("spark.default.parallelism", "20") + + val message = intercept[SparkException] { + sqlContext.read.option("mergeSchema", "true").parquet(path).schema + }.getMessage + + assert(message.contains("Failed merging schema:")) + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 8ffb01fc5b584..e8c524e9e550d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -26,12 +26,14 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.parquet.format.converter.ParquetMetadataConverter -import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} +import org.apache.parquet.schema.MessageType +import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, SQLConf, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -42,6 +44,20 @@ import org.apache.spark.sql.{DataFrame, SQLConf, SaveMode} */ private[sql] trait ParquetTest extends SQLTestUtils { + /** + * Reads the parquet file at `path` + */ + protected def readParquetFile(path: String, testVectorized: Boolean = true) + (f: DataFrame => Unit) = { + (true :: false :: Nil).foreach { vectorized => + if (!vectorized || testVectorized) { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + f(sqlContext.read.parquet(path.toString)) + } + } + } + } + /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. @@ -60,9 +76,9 @@ private[sql] trait ParquetTest extends SQLTestUtils { * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] - (data: Seq[T]) + (data: Seq[T], testVectorized: Boolean = true) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.read.parquet(path))) + withParquetFile(data)(path => readParquetFile(path.toString, testVectorized)(f)) } /** @@ -71,9 +87,9 @@ private[sql] trait ParquetTest extends SQLTestUtils { * Parquet file will be dropped/deleted after `f` returns. */ protected def withParquetTable[T <: Product: ClassTag: TypeTag] - (data: Seq[T], tableName: String) + (data: Seq[T], tableName: String, testVectorized: Boolean = true) (f: => Unit): Unit = { - withParquetDataFrame(data) { df => + withParquetDataFrame(data, testVectorized) { df => sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } @@ -117,6 +133,21 @@ private[sql] trait ParquetTest extends SQLTestUtils { ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) } + /** + * This is an overloaded version of `writeMetadata` above to allow writing customized + * Parquet schema. + */ + protected def writeMetadata( + parquetSchema: MessageType, path: Path, configuration: Configuration, + extraMetadata: Map[String, String] = Map.empty[String, String]): Unit = { + val extraMetadataAsJava = extraMetadata.asJava + val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" + val fileMetadata = new FileMetaData(parquetSchema, extraMetadataAsJava, createdBy) + val parquetMetadata = new ParquetMetadata(fileMetadata, Seq.empty[BlockMetaData].asJava) + val footer = new Footer(path, parquetMetadata) + ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) + } + protected def readAllFootersWithoutSummaryFiles( path: Path, configuration: Configuration): Seq[Footer] = { val fs = path.getFileSystem(configuration) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 0a2306c06646c..47330f1db369e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution.datasources.text +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.SequenceFile.CompressionType +import org.apache.hadoop.io.compress.GzipCodec + +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.util.Utils - class TextSuite extends QueryTest with SharedSQLContext { test("reading text file") { @@ -30,16 +37,16 @@ class TextSuite extends QueryTest with SharedSQLContext { } test("SQLContext.read.text() API") { - verifyFrame(sqlContext.read.text(testFile)) + verifyFrame(sqlContext.read.text(testFile).toDF()) } - test("writing") { - val df = sqlContext.read.text(testFile) + test("SPARK-12562 verify write.text() can handle column name beyond `value`") { + val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf") val tempFile = Utils.createTempDir() tempFile.delete() df.write.text(tempFile.getCanonicalPath) - verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath)) + verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath).toDF()) Utils.deleteRecursively(tempFile) } @@ -58,6 +65,53 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-13503 Support to specify the option for compression codec for TEXT") { + val testDf = sqlContext.read.text(testFile) + val extensionNameMap = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz") + extensionNameMap.foreach { + case (codecName, extension) => + val tempDir = Utils.createTempDir() + val tempDirPath = tempDir.getAbsolutePath + testDf.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath) + val compressedFiles = new File(tempDirPath).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(s".txt$extension"))) + verifyFrame(sqlContext.read.text(tempDirPath).toDF()) + } + + val errMsg = intercept[IllegalArgumentException] { + val tempDirPath = Utils.createTempDir().getAbsolutePath + testDf.write.option("compression", "illegal").mode(SaveMode.Overwrite).text(tempDirPath) + } + assert(errMsg.getMessage.contains("Codec [illegal] is not available. " + + "Known codecs are")) + } + + test("SPARK-13543 Write the output as uncompressed via option()") { + val clonedConf = new Configuration(hadoopConfiguration) + hadoopConfiguration.set("mapreduce.output.fileoutputformat.compress", "true") + hadoopConfiguration + .set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) + hadoopConfiguration + .set("mapreduce.output.fileoutputformat.compress.codec", classOf[GzipCodec].getName) + hadoopConfiguration.set("mapreduce.map.output.compress", "true") + hadoopConfiguration.set("mapreduce.map.output.compress.codec", classOf[GzipCodec].getName) + withTempDir { dir => + try { + val testDf = sqlContext.read.text(testFile) + val tempDir = Utils.createTempDir() + val tempDirPath = tempDir.getAbsolutePath + testDf.write.option("compression", "none").mode(SaveMode.Overwrite).text(tempDirPath) + val compressedFiles = new File(tempDirPath).listFiles() + assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz"))) + verifyFrame(sqlContext.read.text(tempDirPath).toDF()) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString } @@ -65,7 +119,7 @@ class TextSuite extends QueryTest with SharedSQLContext { /** Verifies data and schema. */ private def verifyFrame(df: DataFrame): Unit = { // schema - assert(df.schema == new StructType().add("text", StringType)) + assert(df.schema == new StructType().add("value", StringType)) // verify content val data = df.collect() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 22189477d277d..8aa0114d98d74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -19,10 +19,23 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData.TestData class DebuggingSuite extends SparkFunSuite with SharedSQLContext { test("DataFrame.debug()") { testData.debug() } + + test("Dataset.debug()") { + import testImplicits._ + testData.as[TestData].debug() + } + + test("debugCodegen") { + val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan) + assert(res.contains("Subtree 1 / 2")) + assert(res.contains("Subtree 2 / 2")) + assert(res.contains("Object[]")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index dcbfdca71acb6..babe7ef70f99d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,11 +22,12 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ -import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} /** - * Test various broadcast join operators with unsafe enabled. + * Test various broadcast join operators. * * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered @@ -45,8 +46,6 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { .setAppName("testing") val sc = new SparkContext(conf) sqlContext = new SQLContext(sc) - sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { @@ -64,7 +63,8 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = df3.queryExecution.executedPlan + val plan = + EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } @@ -75,11 +75,11 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { } test("unsafe broadcast hash outer join updates peak execution memory") { - testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer") } test("unsafe broadcast left semi join updates peak execution memory") { - testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast left semi join", "leftsemi") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala new file mode 100644 index 0000000000000..8cdfa8afd098a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + private lazy val conditionNEQ = { + And((left.col("a") < right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testExistenceJoin( + testName: String, + joinType: JoinType, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, + expectedAnswer: Seq[Row]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + ShuffledHashJoin( + leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastHashJoin( + leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition))), + expectedAnswer, + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition))), + expectedAnswer, + sortAnswers = true) + } + } + } + + testExistenceJoin( + "basic test for left semi join", + LeftSemi, + left, + right, + condition, + Seq(Row(2, 1.0), Row(2, 1.0))) + + testExistenceJoin( + "basic test for left semi non equal join", + LeftSemi, + left, + right, + conditionNEQ, + Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0))) + + testExistenceJoin( + "basic test for anti join", + LeftAnti, + left, + right, + condition, + Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) + + testExistenceJoin( + "basic test for anti non equal join", + LeftAnti, + left, + right, + conditionNEQ, + Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index e5fd9e277fc61..371a9ed617d65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,76 +19,43 @@ package org.apache.spark.sql.execution.joins import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.collection.CompactBuffer - class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { - // Key is simply the record itself - private val keyProjection = new Projection { - override def apply(row: InternalRow): InternalRow = row - } - - test("GeneralHashedRelation") { - val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") - val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) - assert(hashed.isInstanceOf[GeneralHashedRelation]) - - assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) - assert(hashed.get(InternalRow(10)) === null) - - val data2 = CompactBuffer[InternalRow](data(2)) - data2 += data(2) - assert(hashed.get(data(2)) === data2) - assert(numDataRows.value.value === data.length) - } - - test("UniqueKeyHashedRelation") { - val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") - val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) - assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - - assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) - assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2))) - assert(hashed.get(InternalRow(10)) === null) - - val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) === data(0)) - assert(uniqHashed.getValue(data(1)) === data(1)) - assert(uniqHashed.getValue(data(2)) === data(2)) - assert(uniqHashed.getValue(InternalRow(10)) === null) - assert(numDataRows.value.value === data.length) - } + val mm = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) - val unsafeData = data.map(toUnsafe(_).copy()).toArray + val unsafeData = data.map(toUnsafe(_).copy()) + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) - assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) + assert(hashed.get(unsafeData(1)).toArray === Array(unsafeData(1))) assert(hashed.get(toUnsafe(InternalRow(10))) === null) val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) data2 += unsafeData(2).copy() - assert(hashed.get(unsafeData(2)) === data2) + assert(hashed.get(unsafeData(2)).toArray === data2.toArray) val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) @@ -97,11 +64,10 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) val hashed2 = new UnsafeHashedRelation() hashed2.readExternal(in) - assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) - assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed2.get(unsafeData(0)).toArray === Array(unsafeData(0))) + assert(hashed2.get(unsafeData(1)).toArray === Array(unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) - assert(hashed2.get(unsafeData(2)) === data2) - assert(numDataRows.value.value === data.length) + assert(hashed2.get(unsafeData(2)).toArray === data2) val os2 = new ByteArrayOutputStream() val out2 = new ObjectOutputStream(os2) @@ -113,10 +79,17 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { } test("test serialization empty hash map") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val binaryMap = new BytesToBytesMap(taskMemoryManager, 1, 1) val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) - val hashed = new UnsafeHashedRelation( - new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + val hashed = new UnsafeHashedRelation(1, binaryMap) hashed.writeExternal(out) out.flush() val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) @@ -134,4 +107,46 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { out2.flush() assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } + + test("LongToUnsafeRowMap") { + val unsafeProj = UnsafeProjection.create( + Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) + val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) + val key = Seq(BoundReference(0, IntegerType, false)) + val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) + assert(longRelation.keyIsUnique) + (0 until 100).foreach { i => + val row = longRelation.getValue(i) + assert(row.getInt(0) === i) + assert(row.getInt(1) === i + 1) + } + + val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm) + assert(!longRelation2.keyIsUnique) + (0 until 100).foreach { i => + val rows = longRelation2.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) + } + + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + longRelation2.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val relation = new LongHashedRelation() + relation.readExternal(in) + assert(!relation.keyIsUnique) + (0 until 100).foreach { i => + val rows = relation.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 066c16e535c76..3cb3ef1ffa2f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,17 +17,20 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder + import testImplicits.newProductEncoder + import testImplicits.localSeqToDatasetHolder private lazy val myUpperCaseData = sqlContext.createDataFrame( sparkContext.parallelize(Seq( @@ -88,9 +91,15 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, side: BuildSide) = { - val broadcastHashJoin = - execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) - boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + val broadcastJoin = joins.BroadcastHashJoin( + leftKeys, + rightKeys, + Inner, + side, + boundCondition, + leftPlan, + rightPlan) + EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) } def makeShuffledHashJoin( @@ -101,10 +110,10 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan, side: BuildSide) = { val shuffledHashJoin = - execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + joins.ShuffledHashJoin(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) val filteredJoin = boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin) } def makeSortMergeJoin( @@ -114,9 +123,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan) = { val sortMergeJoin = - execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) - val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) + joins.SortMergeJoin(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) + EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { @@ -177,6 +185,33 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + + test(s"$testName using CartesianProduct") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + Filter(condition(), CartesianProduct(left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildLeft, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildRight, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } testInnerJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 09e0237a7cc50..4cacb20aa0791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} -import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { @@ -74,24 +76,34 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { ExtractEquiJoinKeys.unapply(join) } - test(s"$testName using ShuffledHashOuterJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + if (joinType != FullOuter) { + test(s"$testName using ShuffledHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext.sessionState.conf).apply( + ShuffledHashJoin( + leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } if (joinType != FullOuter) { - test(s"$testName using BroadcastHashOuterJoin") { + test(s"$testName using BroadcastHashJoin") { + val buildSide = joinType match { + case LeftOuter => BuildRight + case RightOuter => BuildLeft + case _ => fail(s"Unsupported join type $joinType") + } extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + BroadcastHashJoin( + leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -99,17 +111,35 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using SortMergeOuterJoin") { + test(s"$testName using SortMergeJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + EnsureRequirements(sqlContext.sessionState.conf).apply( + SortMergeJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } } } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } // --- Basic outer joins ------------------------------------------------------------------------ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala deleted file mode 100644 index 3afd762942bcf..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.{SQLConf, DataFrame, Row} -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} -import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} - -class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { - - private lazy val left = sqlContext.createDataFrame( - sparkContext.parallelize(Seq( - Row(1, 2.0), - Row(1, 2.0), - Row(2, 1.0), - Row(2, 1.0), - Row(3, 3.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("a", IntegerType).add("b", DoubleType)) - - private lazy val right = sqlContext.createDataFrame( - sparkContext.parallelize(Seq( - Row(2, 3.0), - Row(2, 3.0), - Row(3, 2.0), - Row(4, 1.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("c", IntegerType).add("d", DoubleType)) - - private lazy val condition = { - And((left.col("a") === right.col("c")).expr, - LessThan(left.col("b").expr, right.col("d").expr)) - } - - // Note: the input dataframes and expression must be evaluated lazily because - // the SQLContext should be used only within a test to keep SQL tests stable - private def testLeftSemiJoin( - testName: String, - leftRows: => DataFrame, - rightRows: => DataFrame, - condition: => Expression, - expectedAnswer: Seq[Product]): Unit = { - - def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join) - } - - test(s"$testName using LeftSemiJoinHash") { - extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext).apply( - LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using BroadcastLeftSemiJoinHash") { - extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using LeftSemiJoinBNL") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - testLeftSemiJoin( - "basic test", - left, - right, - condition, - Seq( - (2, 1.0), - (2, 1.0) - ) - ) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala deleted file mode 100644 index efc3227dd60d8..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation - -/** - * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. - */ -private[local] case class DummyNode( - output: Seq[Attribute], - relation: LocalRelation, - conf: SQLConf) - extends LocalNode(conf) { - - import DummyNode._ - - private var index: Int = CLOSED - private val input: Seq[InternalRow] = relation.data - - def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { - this(output, LocalRelation.fromProduct(output, data), conf) - } - - def isOpen: Boolean = index != CLOSED - - override def children: Seq[LocalNode] = Seq.empty - - override def open(): Unit = { - index = -1 - } - - override def next(): Boolean = { - index += 1 - index < input.size - } - - override def fetch(): InternalRow = { - assert(index >= 0 && index < input.size) - input(index) - } - - override def close(): Unit = { - index = CLOSED - } -} - -private object DummyNode { - val CLOSED: Int = Int.MinValue -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala deleted file mode 100644 index bbd94d8da2d11..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.dsl.expressions._ - - -class ExpandNodeSuite extends LocalNodeTest { - - private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { - val inputNode = new DummyNode(kvIntAttributes, inputData) - val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) - val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) - val resolvedNode = resolveExpressions(expandNode) - val expectedOutput = { - val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } - val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } - firstHalf ++ secondHalf - } - val actualOutput = resolvedNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput.toSet === expectedOutput.toSet) - } - - test("empty") { - testExpand() - } - - test("basic") { - testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala deleted file mode 100644 index 4eadce646d379..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.dsl.expressions._ - - -class FilterNodeSuite extends LocalNodeTest { - - private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { - val cond = 'k % 2 === 0 - val inputNode = new DummyNode(kvIntAttributes, inputData) - val filterNode = new FilterNode(conf, cond, inputNode) - val resolvedNode = resolveExpressions(filterNode) - val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } - val actualOutput = resolvedNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - - test("empty") { - testFilter() - } - - test("basic") { - testFilter((1 to 100).map { i => (i, i) }.toArray) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala deleted file mode 100644 index 8c2e78b2a9db7..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.mockito.Mockito.{mock, when} - -import org.apache.spark.broadcast.TorrentBroadcast -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression} -import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} - -class HashJoinNodeSuite extends LocalNodeTest { - - // Test all combinations of the two dimensions: with/out unsafe and build sides - private val maybeUnsafeAndCodegen = Seq(false, true) - private val buildSides = Seq(BuildLeft, BuildRight) - maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => - buildSides.foreach { buildSide => - testJoin(unsafeAndCodegen, buildSide) - } - } - - /** - * Builds a [[HashedRelation]] based on a resolved `buildKeys` - * and a resolved `buildNode`. - */ - private def buildHashedRelation( - conf: SQLConf, - buildKeys: Seq[Expression], - buildNode: LocalNode): HashedRelation = { - - val isUnsafeMode = - conf.codegenEnabled && - conf.unsafeEnabled && - UnsafeProjection.canSupport(buildKeys) - - val buildSideKeyGenerator = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildNode.output) - } else { - new InterpretedMutableProjection(buildKeys, buildNode.output) - } - - buildNode.prepare() - buildNode.open() - val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) - buildNode.close() - - hashedRelation - } - - /** - * Test inner hash join with varying degrees of matches. - */ - private def testJoin( - unsafeAndCodegen: Boolean, - buildSide: BuildSide): Unit = { - val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" - val testNamePrefix = s"$simpleOrUnsafe / $buildSide" - val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray - val conf = new SQLConf - conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) - conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) - - // Actual test body - def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { - val rightInputMap = rightInput.toMap - val leftNode = new DummyNode(joinNameAttributes, leftInput) - val rightNode = new DummyNode(joinNicknameAttributes, rightInput) - val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => { - val binaryHashJoinNode = - BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2) - resolveExpressions(binaryHashJoinNode) - } - val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => { - val leftKeys = Seq('id1.attr) - val rightKeys = Seq('id2.attr) - // Figure out the build side and stream side. - val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match { - case BuildLeft => (node1, leftKeys, node2, rightKeys) - case BuildRight => (node2, rightKeys, node1, leftKeys) - } - // Resolve the expressions of the build side and then create a HashedRelation. - val resolvedBuildNode = resolveExpressions(buildNode) - val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode) - val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode) - val broadcastHashedRelation = mock(classOf[TorrentBroadcast[HashedRelation]]) - when(broadcastHashedRelation.value).thenReturn(hashedRelation) - - val hashJoinNode = - BroadcastHashJoinNode( - conf, - streamedKeys, - streamedNode, - buildSide, - resolvedBuildNode.output, - broadcastHashedRelation) - resolveExpressions(hashJoinNode) - } - - val expectedOutput = leftInput - .filter { case (k, _) => rightInputMap.contains(k) } - .map { case (k, v) => (k, v, k, rightInputMap(k)) } - - Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode => - val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode - val hashJoinNode = makeUnsafeNode(leftNode, rightNode) - - val actualOutput = hashJoinNode.collect().map { row => - // (id, name, id, nickname) - (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) - } - assert(actualOutput === expectedOutput) - } - } - - test(s"$testNamePrefix: empty") { - runTest(Array.empty, Array.empty) - runTest(someData, Array.empty) - runTest(Array.empty, someData) - } - - test(s"$testNamePrefix: no matches") { - val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray - runTest(someData, Array.empty) - runTest(Array.empty, someData) - runTest(someData, someIrrelevantData) - runTest(someIrrelevantData, someData) - } - - test(s"$testNamePrefix: partial matches") { - val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray - runTest(someData, someOtherData) - runTest(someOtherData, someData) - } - - test(s"$testNamePrefix: full matches") { - val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray - runTest(someData, someSuperRelevantData) - runTest(someSuperRelevantData, someData) - } - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala deleted file mode 100644 index c0ad2021b204a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - - -class IntersectNodeSuite extends LocalNodeTest { - - test("basic") { - val n = 100 - val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray - val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray - val leftNode = new DummyNode(kvIntAttributes, leftData) - val rightNode = new DummyNode(kvIntAttributes, rightData) - val intersectNode = new IntersectNode(conf, leftNode, rightNode) - val expectedOutput = leftData.intersect(rightData) - val actualOutput = intersectNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala deleted file mode 100644 index fb790636a3689..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - - -class LimitNodeSuite extends LocalNodeTest { - - private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { - val inputNode = new DummyNode(kvIntAttributes, inputData) - val limitNode = new LimitNode(conf, limit, inputNode) - val expectedOutput = inputData.take(limit) - val actualOutput = limitNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - - test("empty") { - testLimit() - } - - test("basic") { - testLimit((1 to 100).map { i => (i, i) }.toArray, 20) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala deleted file mode 100644 index 0d1ed99eec6cd..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - - -class LocalNodeSuite extends LocalNodeTest { - private val data = (1 to 100).map { i => (i, i) }.toArray - - test("basic open, next, fetch, close") { - val node = new DummyNode(kvIntAttributes, data) - assert(!node.isOpen) - node.open() - assert(node.isOpen) - data.foreach { case (k, v) => - assert(node.next()) - // fetch should be idempotent - val fetched = node.fetch() - assert(node.fetch() === fetched) - assert(node.fetch() === fetched) - assert(node.fetch().numFields === 2) - assert(node.fetch().getInt(0) === k) - assert(node.fetch().getInt(1) === v) - } - assert(!node.next()) - node.close() - assert(!node.isOpen) - } - - test("asIterator") { - val node = new DummyNode(kvIntAttributes, data) - val iter = node.asIterator - node.open() - data.foreach { case (k, v) => - // hasNext should be idempotent - assert(iter.hasNext) - assert(iter.hasNext) - val item = iter.next() - assert(item.numFields === 2) - assert(item.getInt(0) === k) - assert(item.getInt(1) === v) - } - intercept[NoSuchElementException] { - iter.next() - } - node.close() - } - - test("collect") { - val node = new DummyNode(kvIntAttributes, data) - node.open() - val collected = node.collect() - assert(collected.size === data.size) - assert(collected.forall(_.size === 2)) - assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) - node.close() - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala deleted file mode 100644 index 615c417093612..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference} -import org.apache.spark.sql.types.{IntegerType, StringType} - - -class LocalNodeTest extends SparkFunSuite { - - protected val conf: SQLConf = new SQLConf - protected val kvIntAttributes = Seq( - AttributeReference("k", IntegerType)(), - AttributeReference("v", IntegerType)()) - protected val joinNameAttributes = Seq( - AttributeReference("id1", IntegerType)(), - AttributeReference("name", StringType)()) - protected val joinNicknameAttributes = Seq( - AttributeReference("id2", IntegerType)(), - AttributeReference("nickname", StringType)()) - - /** - * Wrap a function processing two [[LocalNode]]s such that: - * (1) all input rows are automatically converted to unsafe rows - * (2) all output rows are automatically converted back to safe rows - */ - protected def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } - - /** - * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. - */ - protected def resolveExpressions(outputNode: LocalNode): LocalNode = { - outputNode transform { - case node: LocalNode => - val inputMap = node.output.map { a => (a.name, a) }.toMap - node transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - } - - /** - * Resolve all expressions in `expressions` based on the `output` of `localNode`. - * It assumes that all expressions in the `localNode` are resolved. - */ - protected def resolveExpressions( - expressions: Seq[Expression], - localNode: LocalNode): Seq[Expression] = { - require(localNode.expressions.forall(_.resolved)) - val inputMap = localNode.output.map { a => (a.name, a) }.toMap - expressions.map { expression => - expression.transformUp { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala deleted file mode 100644 index 40299d9d5ee37..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} - - -class NestedLoopJoinNodeSuite extends LocalNodeTest { - - // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types - private val maybeUnsafeAndCodegen = Seq(false, true) - private val buildSides = Seq(BuildLeft, BuildRight) - private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) - maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => - buildSides.foreach { buildSide => - joinTypes.foreach { joinType => - testJoin(unsafeAndCodegen, buildSide, joinType) - } - } - } - - /** - * Test outer nested loop joins with varying degrees of matches. - */ - private def testJoin( - unsafeAndCodegen: Boolean, - buildSide: BuildSide, - joinType: JoinType): Unit = { - val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" - val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" - val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray - val conf = new SQLConf - conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) - conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) - - // Actual test body - def runTest( - joinType: JoinType, - leftInput: Array[(Int, String)], - rightInput: Array[(Int, String)]): Unit = { - val leftNode = new DummyNode(joinNameAttributes, leftInput) - val rightNode = new DummyNode(joinNicknameAttributes, rightInput) - val cond = 'id1 === 'id2 - val makeNode = (node1: LocalNode, node2: LocalNode) => { - resolveExpressions( - new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) - } - val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode - val hashJoinNode = makeUnsafeNode(leftNode, rightNode) - val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) - val actualOutput = hashJoinNode.collect().map { row => - // (id, name, id, nickname) - (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) - } - assert(actualOutput.toSet === expectedOutput.toSet) - } - - test(s"$testNamePrefix: empty") { - runTest(joinType, Array.empty, Array.empty) - } - - test(s"$testNamePrefix: no matches") { - val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray - runTest(joinType, someData, Array.empty) - runTest(joinType, Array.empty, someData) - runTest(joinType, someData, someIrrelevantData) - runTest(joinType, someIrrelevantData, someData) - } - - test(s"$testNamePrefix: partial matches") { - val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray - runTest(joinType, someData, someOtherData) - runTest(joinType, someOtherData, someData) - } - - test(s"$testNamePrefix: full matches") { - val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } - runTest(joinType, someData, someSuperRelevantData) - runTest(joinType, someSuperRelevantData, someData) - } - } - - /** - * Helper method to generate the expected output of a test based on the join type. - */ - private def generateExpectedOutput( - leftInput: Array[(Int, String)], - rightInput: Array[(Int, String)], - joinType: JoinType): Array[(Int, String, Int, String)] = { - joinType match { - case LeftOuter => - val rightInputMap = rightInput.toMap - leftInput.map { case (k, v) => - val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) - val rightValue = rightInputMap.getOrElse(k, null) - (k, v, rightKey, rightValue) - } - - case RightOuter => - val leftInputMap = leftInput.toMap - rightInput.map { case (k, v) => - val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) - val leftValue = leftInputMap.getOrElse(k, null) - (leftKey, leftValue, k, v) - } - - case FullOuter => - val leftInputMap = leftInput.toMap - val rightInputMap = rightInput.toMap - val leftOutput = leftInput.map { case (k, v) => - val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) - val rightValue = rightInputMap.getOrElse(k, null) - (k, v, rightKey, rightValue) - } - val rightOutput = rightInput.map { case (k, v) => - val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) - val leftValue = leftInputMap.getOrElse(k, null) - (leftKey, leftValue, k, v) - } - (leftOutput ++ rightOutput).distinct - - case other => - throw new IllegalArgumentException(s"Join type $other is not applicable") - } - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala deleted file mode 100644 index 02ecb23d34b2f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} -import org.apache.spark.sql.types.{IntegerType, StringType} - - -class ProjectNodeSuite extends LocalNodeTest { - private val pieAttributes = Seq( - AttributeReference("id", IntegerType)(), - AttributeReference("age", IntegerType)(), - AttributeReference("name", StringType)()) - - private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { - val inputNode = new DummyNode(pieAttributes, inputData) - val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) - val projectNode = new ProjectNode(conf, columns, inputNode) - val expectedOutput = inputData.map { case (id, age, name) => (id, name) } - val actualOutput = projectNode.collect().map { case row => - (row.getInt(0), row.getString(1)) - } - assert(actualOutput === expectedOutput) - } - - test("empty") { - testProject() - } - - test("basic") { - testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala deleted file mode 100644 index a3e83bbd51457..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.local - -import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} - - -class SampleNodeSuite extends LocalNodeTest { - - private def testSample(withReplacement: Boolean): Unit = { - val seed = 0L - val lowerb = 0.0 - val upperb = 0.3 - val maybeOut = if (withReplacement) "" else "out" - test(s"with$maybeOut replacement") { - val inputData = (1 to 1000).map { i => (i, i) }.toArray - val inputNode = new DummyNode(kvIntAttributes, inputData) - val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) - val sampler = - if (withReplacement) { - new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) - } else { - new BernoulliCellSampler[(Int, Int)](lowerb, upperb) - } - sampler.setSeed(seed) - val expectedOutput = sampler.sample(inputData.iterator).toArray - val actualOutput = sampleNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - } - - testSample(withReplacement = true) - testSample(withReplacement = false) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala deleted file mode 100644 index 42ebc7bfcaadc..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.local - -import scala.util.Random - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.SortOrder - - -class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { - - private def testTakeOrderedAndProject(desc: Boolean): Unit = { - val limit = 10 - val ascOrDesc = if (desc) "desc" else "asc" - test(ascOrDesc) { - val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray - val inputNode = new DummyNode(kvIntAttributes, inputData) - val firstColumn = inputNode.output(0) - val sortDirection = if (desc) Descending else Ascending - val sortOrder = SortOrder(firstColumn, sortDirection) - val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( - conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) - val expectedOutput = inputData - .map { case (k, _) => k } - .sortBy { k => k * (if (desc) -1 else 1) } - .take(limit) - val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } - assert(actualOutput === expectedOutput) - } - } - - testTakeOrderedAndProject(desc = false) - testTakeOrderedAndProject(desc = true) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala deleted file mode 100644 index 666b0235c061d..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.local - - -class UnionNodeSuite extends LocalNodeTest { - - private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { - val inputNodes = inputData.map { data => - new DummyNode(kvIntAttributes, data) - } - val unionNode = new UnionNode(conf, inputNodes) - val expectedOutput = inputData.flatten - val actualOutput = unionNode.collect().map { case row => - (row.getInt(0), row.getInt(1)) - } - assert(actualOutput === expectedOutput) - } - - test("empty") { - testUnion(Seq(Array.empty)) - testUnion(Seq(Array.empty, Array.empty)) - } - - test("self") { - val data = (1 to 100).map { i => (i, i) }.toArray - testUnion(Seq(data)) - testUnion(Seq(data, data)) - testUnion(Seq(data, data, data)) - } - - test("basic") { - val zero = Array.empty[(Int, Int)] - val one = (1 to 100).map { i => (i, i) }.toArray - val two = (50 to 150).map { i => (i, i) }.toArray - val three = (800 to 900).map { i => (i, i) }.toArray - testUnion(Seq(zero, one, two, three)) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cdd885ba14203..695b1824e8cf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -21,15 +21,17 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5._ +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{JsonProtocol, Utils} class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { @@ -41,22 +43,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { l += 1L l.add(1L) } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } + val cl = BoxingFinder.getClassReader(f.getClass) + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") } test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. val l = sparkContext.accumulator(0L) val f = () => { l += 1L } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") - } + val cl = BoxingFinder.getClassReader(f.getClass) + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") } /** @@ -72,7 +72,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { val previousExecutionIds = sqlContext.listener.executionIdToData.keySet - df.collect() + withSQLConf("spark.sql.codegen.wholeStage" -> "false") { + df.collect() + } sparkContext.listenerBus.waitUntilEmpty(10000) val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) @@ -84,7 +86,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).allNodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => @@ -111,354 +114,164 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } - test("Project metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "false", - SQLConf.TUNGSTEN_ENABLED.key -> "false") { - // Assume the execution plan is - // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) - val df = person.select('name) - testSparkPlanMetrics(df, 1, Map( - 0L ->("Project", Map( - "number of rows" -> 2L))) - ) - } - } - - test("TungstenProject metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = person.select('name) - testSparkPlanMetrics(df, 1, Map( - 0L ->("TungstenProject", Map( - "number of rows" -> 2L))) - ) - } - } - test("Filter metrics") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) val df = person.filter('age < 25) testSparkPlanMetrics(df, 1, Map( 0L -> ("Filter", Map( - "number of input rows" -> 2L, "number of output rows" -> 1L))) ) } - test("Aggregate metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "false", - SQLConf.TUNGSTEN_ENABLED.key -> "false") { - // Assume the execution plan is - // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("Aggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("Aggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) - - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 2L -> ("Aggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("Aggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } + test("WholeStageCodegen metrics") { + // Assume the execution plan is + // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1)) + // TODO: update metrics in generated operators + val ds = sqlContext.range(10).filter('id < 5) + testSparkPlanMetrics(ds.toDF(), 1, Map.empty) } - test("SortBasedAggregate metrics") { - // Because SortBasedAggregate may skip different rows if the number of partitions is different, - // this test should use the deterministic number of partitions. - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> - // SortBasedAggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("SortBasedAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("SortBasedAggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) + test("TungstenAggregate metrics") { + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of output rows" -> 1L))) + ) - // Assume the execution plan is - // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) - // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 3L -> ("SortBasedAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("SortBasedAggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of output rows" -> 3L))) + ) } - test("TungstenAggregate metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) - // -> TungstenAggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("TungstenAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("TungstenAggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) - - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 2L -> ("TungstenAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("TungstenAggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } + test("Sort metrics") { + // Assume the execution plan is + // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) + val ds = sqlContext.range(10).sort('id) + testSparkPlanMetrics(ds.toDF(), 2, Map.empty) } test("SortMergeJoin metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") - testSparkPlanMetrics(df, 1, Map( - 1L -> ("SortMergeJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 4L, - "number of right rows" -> 2L, - "number of output rows" -> 4L))) - ) - } - } - } - - test("SortMergeOuterJoin metrics") { - // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, - // this test should use the deterministic number of partitions. - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") - testSparkPlanMetrics(df, 1, Map( - 1L -> ("SortMergeOuterJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 6L, - "number of right rows" -> 2L, - "number of output rows" -> 8L))) - ) - - val df2 = sqlContext.sql( - "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") - testSparkPlanMetrics(df2, 1, Map( - 1L -> ("SortMergeOuterJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 2L, - "number of right rows" -> 6L, - "number of output rows" -> 8L))) - ) - } - } - } - - test("BroadcastHashJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") - val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { // Assume the execution plan is - // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = df1.join(broadcast(df2), "key") - testSparkPlanMetrics(df, 2, Map( - 1L -> ("BroadcastHashJoin", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, - "number of output rows" -> 2L))) + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of output rows" -> 4L))) ) } } - test("ShuffledHashJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> ShuffledHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") - testSparkPlanMetrics(df, 1, Map( - 1L -> ("ShuffledHashJoin", Map( - "number of left rows" -> 6L, - "number of right rows" -> 2L, - "number of output rows" -> 4L))) - ) - } - } - } - - test("ShuffledHashOuterJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { - val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") - val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + test("SortMergeJoin(outer) metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { // Assume the execution plan is - // ... -> ShuffledHashOuterJoin(nodeId = 0) - val df = df1.join(df2, $"key" === $"key2", "left_outer") + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( - 0L -> ("ShuffledHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 5L))) - ) - - val df3 = df1.join(df2, $"key" === $"key2", "right_outer") - testSparkPlanMetrics(df3, 1, Map( - 0L -> ("ShuffledHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 6L))) + 0L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of output rows" -> 8L))) ) - val df4 = df1.join(df2, $"key" === $"key2", "outer") - testSparkPlanMetrics(df4, 1, Map( - 0L -> ("ShuffledHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 7L))) + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 0L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of output rows" -> 8L))) ) } } - test("BroadcastHashOuterJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") - val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") - // Assume the execution plan is - // ... -> BroadcastHashOuterJoin(nodeId = 0) - val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") - testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 5L))) - ) + test("BroadcastHashJoin metrics") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of output rows" -> 2L))) + ) + } - val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") - testSparkPlanMetrics(df3, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 6L))) - ) - } + test("BroadcastHashJoin(outer) metrics") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashJoin", Map( + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashJoin", Map( + "number of output rows" -> 6L))) + ) } test("BroadcastNestedLoopJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 left JOIN testDataForJoin ON " + - "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") - testSparkPlanMetrics(df, 3, Map( - 1L -> ("BroadcastNestedLoopJoin", Map( - "number of left rows" -> 12L, // left needs to be scanned twice - "number of right rows" -> 2L, - "number of output rows" -> 12L))) - ) - } + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of output rows" -> 12L))) + ) } } test("BroadcastLeftSemiJoinHash metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") - val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") - // Assume the execution plan is - // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) - val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") - testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastLeftSemiJoinHash", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, - "number of output rows" -> 2L))) - ) - } + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of output rows" -> 2L))) + ) } - test("LeftSemiJoinHash metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + test("ShuffledHashJoin metrics") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") // Assume the execution plan is - // ... -> LeftSemiJoinHash(nodeId = 0) + // ... -> ShuffledHashJoin(nodeId = 0) val df = df1.join(df2, $"key" === $"key2", "leftsemi") testSparkPlanMetrics(df, 1, Map( - 0L -> ("LeftSemiJoinHash", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, - "number of output rows" -> 2L))) - ) - } - } - - test("LeftSemiJoinBNL metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") - val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") - // Assume the execution plan is - // ... -> LeftSemiJoinBNL(nodeId = 0) - val df = df1.join(df2, $"key" < $"key2", "leftsemi") - testSparkPlanMetrics(df, 2, Map( - 0L -> ("LeftSemiJoinBNL", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, + 0L -> ("ShuffledHashJoin", Map( "number of output rows" -> 2L))) ) } @@ -473,9 +286,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df = sqlContext.sql( "SELECT * FROM testData2 JOIN testDataForJoin") testSparkPlanMetrics(df, 1, Map( - 1L -> ("CartesianProduct", Map( - "number of left rows" -> 12L, // left needs to be scanned twice - "number of right rows" -> 12L, // right is read 6 times + 0L -> ("CartesianProduct", Map( "number of output rows" -> 12L))) ) } @@ -498,10 +309,32 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val metricValues = sqlContext.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) + assert(metricValues.values.toSeq.exists(_ === "2")) } } + test("metrics can be loaded by history server") { + val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam) + metric += 10L + val metricInfo = metric.toInfo(Some(metric.localValue), None) + metricInfo.update match { + case Some(v: LongSQLMetricValue) => assert(v.value === 10L) + case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}") + case _ => fail("metric update is missing") + } + assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + // After serializing to JSON, the original value type is lost, but we can still + // identify that it's a SQL metric from the metadata + val metricInfoJson = JsonProtocol.accumulableInfoToJson(metricInfo) + val metricInfoDeser = JsonProtocol.accumulableInfoFromJson(metricInfoJson) + metricInfoDeser.update match { + case Some(v: String) => assert(v.toLong === 10L) + case Some(v) => fail(s"deserialized metric value was not a string: ${v.getClass.getName}") + case _ => fail("deserialized metric update is missing") + } + assert(metricInfoDeser.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + } + } private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) @@ -516,7 +349,7 @@ private class BoxingFinder( method: MethodIdentifier[_] = null, val boxingInvokes: mutable.Set[String] = mutable.Set.empty, visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { private val primitiveBoxingClassName = Set("java/lang/Long", @@ -533,11 +366,12 @@ private class BoxingFinder( MethodVisitor = { if (method != null && (method.name != name || method.desc != desc)) { // If method is specified, skip other methods. - return new MethodVisitor(ASM4) {} + return new MethodVisitor(ASM5) {} } - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKESPECIAL && name == "" || op == INVOKESTATIC && name == "valueOf") { if (primitiveBoxingClassName.contains(owner)) { // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) @@ -552,10 +386,9 @@ private class BoxingFinder( if (!visitedMethods.contains(m)) { // Keep track of visited methods to avoid potential infinite cycles visitedMethods += m - BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => - visitedMethods += m - cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) - } + val cl = BoxingFinder.getClassReader(classOfMethodOwner) + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) } } } @@ -565,22 +398,14 @@ private class BoxingFinder( private object BoxingFinder { - def getClassReader(cls: Class[_]): Option[ClassReader] = { + def getClassReader(cls: Class[_]): ClassReader = { val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" val resourceStream = cls.getResourceAsStream(className) val baos = new ByteArrayOutputStream(128) // Copy data over, before delegating to ClassReader - // else we can run out of open file handles. Utils.copyStream(resourceStream, baos, true) - // ASM4 doesn't support Java 8 classes, which requires ASM5. - // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes), - // then ClassReader will throw IllegalArgumentException, - // However, since this is only for testing, it's safe to skip these classes. - try { - Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) - } catch { - case _: IllegalArgumentException => None - } + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala new file mode 100644 index 0000000000000..0a989d026ce1c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.stat + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.stat.StatFunctions.QuantileSummaries + + +class ApproxQuantileSuite extends SparkFunSuite { + + private val r = new Random(1) + private val n = 100 + private val increasing = "increasing" -> (0 until n).map(_.toDouble) + private val decreasing = "decreasing" -> (n until 0 by -1).map(_.toDouble) + private val random = "random" -> Seq.fill(n)(math.ceil(r.nextDouble() * 1000)) + + private def buildSummary( + data: Seq[Double], + epsi: Double, + threshold: Int): QuantileSummaries = { + var summary = new QuantileSummaries(threshold, epsi) + data.foreach { x => + summary = summary.insert(x) + } + summary.compress() + } + + private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { + val approx = summary.query(quant) + // The rank of the approximation. + val rank = data.count(_ < approx) // has to be <, not <= to be exact + val lower = math.floor((quant - summary.relativeError) * data.size) + val upper = math.ceil((quant + summary.relativeError) * data.size) + val msg = + s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" + assert(rank >= lower, msg) + assert(rank <= upper, msg) + } + + for { + (seq_name, data) <- Seq(increasing, decreasing, random) + epsi <- Seq(0.1, 0.0001) + compression <- Seq(1000, 10) + } { + + test(s"Extremas with epsi=$epsi and seq=$seq_name, compression=$compression") { + val s = buildSummary(data, epsi, compression) + val min_approx = s.query(0.0) + assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") + val max_approx = s.query(1.0) + assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") + } + + test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression") { + val s = buildSummary(data, epsi, compression) + assert(s.count == data.size, s"Found count=${s.count} but data size=${data.size}") + checkQuantile(0.9999, data, s) + checkQuantile(0.9, data, s) + checkQuantile(0.5, data, s) + checkQuantile(0.1, data, s) + checkQuantile(0.001, data, s) + } + } + + // Tests for merging procedure + for { + (seq_name, data) <- Seq(increasing, decreasing, random) + epsi <- Seq(0.1, 0.0001) + compression <- Seq(1000, 10) + } { + + val (data1, data2) = { + val l = data.size + data.take(l / 2) -> data.drop(l / 2) + } + + test(s"Merging ordered lists with epsi=$epsi and seq=$seq_name, compression=$compression") { + val s1 = buildSummary(data1, epsi, compression) + val s2 = buildSummary(data2, epsi, compression) + val s = s1.merge(s2) + val min_approx = s.query(0.0) + assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") + val max_approx = s.query(1.0) + assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") + checkQuantile(0.9999, data, s) + checkQuantile(0.9, data, s) + checkQuantile(0.5, data, s) + checkQuantile(0.1, data, s) + checkQuantile(0.001, data, s) + } + + val (data11, data12) = { + data.sliding(2).map(_.head).toSeq -> data.sliding(2).map(_.last).toSeq + } + + test(s"Merging interleaved lists with epsi=$epsi and seq=$seq_name, compression=$compression") { + val s1 = buildSummary(data11, epsi, compression) + val s2 = buildSummary(data12, epsi, compression) + val s = s1.merge(s2) + val min_approx = s.query(0.0) + assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") + val max_approx = s.query(1.0) + assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") + checkQuantile(0.9999, data, s) + checkQuantile(0.9, data, s) + checkQuantile(0.5, data, s) + checkQuantile(0.1, data, s) + checkQuantile(0.001, data, s) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala new file mode 100644 index 0000000000000..13281427045c5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{File, FileNotFoundException, IOException} +import java.net.URI +import java.util.ConcurrentModificationException + +import scala.language.implicitConversions +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.scalatest.concurrent.AsyncAssertions._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.execution.streaming.FakeFileSystem._ +import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} +import org.apache.spark.sql.test.SharedSQLContext + +class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { + + /** To avoid caching of FS objects */ + override protected val sparkConf = + new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + + private implicit def toOption[A](a: A): Option[A] = Option(a) + + test("FileManager: FileContextManager") { + withTempDir { temp => + val path = new Path(temp.getAbsolutePath) + testManager(path, new FileContextManager(path, new Configuration)) + } + } + + test("FileManager: FileSystemManager") { + withTempDir { temp => + val path = new Path(temp.getAbsolutePath) + testManager(path, new FileSystemManager(path, new Configuration)) + } + } + + test("HDFSMetadataLog: basic") { + withTempDir { temp => + val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir + val metadataLog = new HDFSMetadataLog[String](sqlContext, dir.getAbsolutePath) + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + + assert(metadataLog.add(1, "batch1")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(1) === Some("batch1")) + assert(metadataLog.getLatest() === Some(1 -> "batch1")) + assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + + // Adding the same batch does nothing + metadataLog.add(1, "batch1-duplicated") + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(1) === Some("batch1")) + assert(metadataLog.getLatest() === Some(1 -> "batch1")) + assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + } + } + + testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { + sqlContext.sparkContext.hadoopConfiguration.set( + s"fs.$scheme.impl", + classOf[FakeFileSystem].getName) + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](sqlContext, s"$scheme://$temp") + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + + + val metadataLog2 = new HDFSMetadataLog[String](sqlContext, s"$scheme://$temp") + assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.getLatest() === Some(0 -> "batch0")) + assert(metadataLog2.get(None, 0) === Array(0 -> "batch0")) + + } + } + + test("HDFSMetadataLog: restart") { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.add(1, "batch1")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(1) === Some("batch1")) + assert(metadataLog.getLatest() === Some(1 -> "batch1")) + assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + + val metadataLog2 = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.get(1) === Some("batch1")) + assert(metadataLog2.getLatest() === Some(1 -> "batch1")) + assert(metadataLog2.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + } + } + + test("HDFSMetadataLog: metadata directory collision") { + withTempDir { temp => + val waiter = new Waiter + val maxBatchId = 100 + for (id <- 0 until 10) { + new Thread() { + override def run(): Unit = waiter { + val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + try { + var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) + nextBatchId += 1 + while (nextBatchId <= maxBatchId) { + metadataLog.add(nextBatchId, nextBatchId.toString) + nextBatchId += 1 + } + } catch { + case e: ConcurrentModificationException => + // This is expected since there are multiple writers + } finally { + waiter.dismiss() + } + } + }.start() + } + + waiter.await(timeout(10.seconds), dismissals(10)) + val metadataLog = new HDFSMetadataLog[String](sqlContext, temp.getAbsolutePath) + assert(metadataLog.getLatest() === Some(maxBatchId -> maxBatchId.toString)) + assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString))) + } + } + + + def testManager(basePath: Path, fm: FileManager): Unit = { + // Mkdirs + val dir = new Path(s"$basePath/dir/subdir/subsubdir") + assert(!fm.exists(dir)) + fm.mkdirs(dir) + assert(fm.exists(dir)) + fm.mkdirs(dir) + + // List + val acceptAllFilter = new PathFilter { + override def accept(path: Path): Boolean = true + } + val rejectAllFilter = new PathFilter { + override def accept(path: Path): Boolean = false + } + assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) + assert(fm.list(basePath, rejectAllFilter).length === 0) + + // Create + val path = new Path(s"$dir/file") + assert(!fm.exists(path)) + fm.create(path).close() + assert(fm.exists(path)) + intercept[IOException] { + fm.create(path) + } + + // Open and delete + fm.open(path) + fm.delete(path) + assert(!fm.exists(path)) + intercept[IOException] { + fm.open(path) + } + fm.delete(path) // should not throw exception + + // Rename + val path1 = new Path(s"$dir/file1") + val path2 = new Path(s"$dir/file2") + fm.create(path1).close() + assert(fm.exists(path1)) + fm.rename(path1, path2) + intercept[FileNotFoundException] { + fm.rename(path1, path2) + } + val path3 = new Path(s"$dir/file3") + fm.create(path3).close() + assert(fm.exists(path3)) + intercept[FileAlreadyExistsException] { + fm.rename(path2, path3) + } + } +} + +/** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ +class FakeFileSystem extends RawLocalFileSystem { + override def getUri: URI = { + URI.create(s"$scheme:///") + } +} + +object FakeFileSystem { + val scheme = s"HDFSMetadataLogSuite${math.abs(Random.nextInt)}" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala new file mode 100644 index 0000000000000..dd5f92248bf5c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.ProcessingTime +import org.apache.spark.util.ManualClock + +class ProcessingTimeExecutorSuite extends SparkFunSuite { + + test("nextBatchTime") { + val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100)) + assert(processingTimeExecutor.nextBatchTime(1) === 100) + assert(processingTimeExecutor.nextBatchTime(99) === 100) + assert(processingTimeExecutor.nextBatchTime(100) === 100) + assert(processingTimeExecutor.nextBatchTime(101) === 200) + assert(processingTimeExecutor.nextBatchTime(150) === 200) + } + + private def testBatchTermination(intervalMs: Long): Unit = { + var batchCounts = 0 + val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs)) + processingTimeExecutor.execute(() => { + batchCounts += 1 + // If the batch termination works well, batchCounts should be 3 after `execute` + batchCounts < 3 + }) + assert(batchCounts === 3) + } + + test("batch termination") { + testBatchTermination(0) + testBatchTermination(10) + } + + test("notifyBatchFallingBehind") { + val clock = new ManualClock() + @volatile var batchFallingBehindCalled = false + val latch = new CountDownLatch(1) + val t = new Thread() { + override def run(): Unit = { + val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) { + override def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { + batchFallingBehindCalled = true + } + } + processingTimeExecutor.execute(() => { + latch.countDown() + clock.waitTillTime(200) + false + }) + } + } + t.start() + // Wait until the batch is running so that we don't call `advance` too early + assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds") + clock.advance(200) + t.join() + assert(batchFallingBehindCalled === true) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala new file mode 100644 index 0000000000000..a7e32626264cc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.ExecutorCacheTaskLocation + +class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { + + import StateStoreCoordinatorSuite._ + + test("report, verify, getLocation") { + withCoordinatorRef(sc) { coordinatorRef => + val id = StateStoreId("x", 0, 0) + + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.getLocation(id) === None) + + coordinatorRef.reportActiveInstance(id, "hostX", "exec1") + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } + + coordinatorRef.reportActiveInstance(id, "hostX", "exec2") + + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false) + assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true) + + assert( + coordinatorRef.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec2").toString)) + } + } + } + + test("make inactive") { + withCoordinatorRef(sc) { coordinatorRef => + val id1 = StateStoreId("x", 0, 0) + val id2 = StateStoreId("y", 1, 0) + val id3 = StateStoreId("x", 0, 1) + val host = "hostX" + val exec = "exec1" + + coordinatorRef.reportActiveInstance(id1, host, exec) + coordinatorRef.reportActiveInstance(id2, host, exec) + coordinatorRef.reportActiveInstance(id3, host, exec) + + eventually(timeout(5 seconds)) { + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true) + } + + coordinatorRef.deactivateInstances("x") + + assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false) + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true) + assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === false) + + assert(coordinatorRef.getLocation(id1) === None) + assert( + coordinatorRef.getLocation(id2) === + Some(ExecutorCacheTaskLocation(host, exec).toString)) + assert(coordinatorRef.getLocation(id3) === None) + + coordinatorRef.deactivateInstances("y") + assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false) + assert(coordinatorRef.getLocation(id2) === None) + } + } + + test("multiple references have same underlying coordinator") { + withCoordinatorRef(sc) { coordRef1 => + val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env) + + val id = StateStoreId("x", 0, 0) + + coordRef1.reportActiveInstance(id, "hostX", "exec1") + + eventually(timeout(5 seconds)) { + assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true) + assert( + coordRef2.getLocation(id) === + Some(ExecutorCacheTaskLocation("hostX", "exec1").toString)) + } + } + } +} + +object StateStoreCoordinatorSuite { + def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = { + var coordinatorRef: StateStoreCoordinatorRef = null + try { + coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env) + body(coordinatorRef) + } finally { + if (coordinatorRef != null) coordinatorRef.stop() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala new file mode 100644 index 0000000000000..6be94eb24fcf1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File +import java.nio.file.Files + +import scala.util.Random + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.util.{CompletionIterator, Utils} + +class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + + private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + import StateStoreSuite._ + + after { + StateStore.stop() + } + + override def afterAll(): Unit = { + super.afterAll() + Utils.deleteRecursively(new File(tempDir)) + } + + test("versioning and immutability") { + withSpark(new SparkContext(sparkConf)) { sc => + val sqlContext = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 + val rdd1 = + makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + increment) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + + test("recovering from files") { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + implicit val sqlContext = new SQLContext(sc) + makeRDD(sc, Seq("a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) + } + + // Generate RDDs and state store data + withSpark(new SparkContext(sparkConf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + } + } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(sparkConf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } + } + + test("usage with iterators - only gets and only puts") { + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContext = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 + + // Returns an iterator of the incremented value made into the store + def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = { + val resIterator = iter.map { s => + val key = stringToRow(s) + val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val newValue = oldValue + 1 + store.put(key, intToRow(newValue)) + (s, newValue) + } + CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, { + store.commit() + }) + } + + def iteratorOfGets( + store: StateStore, + iter: Iterator[String]): Iterator[(String, Option[Int])] = { + iter.map { s => + val key = stringToRow(s) + val value = store.get(key).map(rowToInt) + (s, value) + } + } + + val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) + + val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) + assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) + + val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) + assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) + } + } + + test("preferred locations using StateStoreCoordinator") { + quietly { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContext = new SQLContext(sc) + val coordinatorRef = sqlContext.streams.stateStoreCoordinator + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") + coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") + + assert( + coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + require(rdd.partitions.length === 2) + + assert( + rdd.preferredLocations(rdd.partitions(0)) === + Seq(ExecutorCacheTaskLocation("host1", "exec1").toString)) + + assert( + rdd.preferredLocations(rdd.partitions(1)) === + Seq(ExecutorCacheTaskLocation("host2", "exec2").toString)) + + rdd.collect() + } + } + } + + test("distributed test") { + quietly { + withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => + implicit val sqlContext = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } + } + + private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = { + sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2) + } + + private val increment = (store: StateStore, iter: Iterator[String]) => { + iter.foreach { s => + val key = stringToRow(s) + val oldValue = store.get(key).map(rowToInt).getOrElse(0) + store.put(key, intToRow(oldValue + 1)) + } + store.commit() + store.iterator().map(rowsToStringInt) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala new file mode 100644 index 0000000000000..dd23925716b06 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -0,0 +1,567 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.io.File + +import scala.collection.mutable +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester { + type MapType = mutable.HashMap[UnsafeRow, UnsafeRow] + + import StateStoreCoordinatorSuite._ + import StateStoreSuite._ + + private val tempDir = Utils.createTempDir().toString + private val keySchema = StructType(Seq(StructField("key", StringType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + after { + StateStore.stop() + } + + test("get, put, remove, commit, and all data iterator") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(provider.latestIterator().isEmpty) + + val store = provider.getStore(0) + assert(!store.hasCommitted) + intercept[IllegalStateException] { + store.iterator() + } + intercept[IllegalStateException] { + store.updates() + } + + // Verify state after updating + put(store, "a", 1) + intercept[IllegalStateException] { + store.iterator() + } + intercept[IllegalStateException] { + store.updates() + } + assert(provider.latestIterator().isEmpty) + + // Make updates, commit and then verify state + put(store, "b", 2) + put(store, "aa", 3) + remove(store, _.startsWith("a")) + assert(store.commit() === 1) + + assert(store.hasCommitted) + assert(rowsToSet(store.iterator()) === Set("b" -> 2)) + assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2)) + assert(fileExists(provider, version = 1, isSnapshot = false)) + + assert(getDataFromFiles(provider) === Set("b" -> 2)) + + // Trying to get newer versions should fail + intercept[Exception] { + provider.getStore(2) + } + intercept[Exception] { + getDataFromFiles(provider, 2) + } + + // New updates to the reloaded store with new version, and does not change old version + val reloadedProvider = new HDFSBackedStateStoreProvider( + store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + val reloadedStore = reloadedProvider.getStore(1) + put(reloadedStore, "c", 4) + assert(reloadedStore.commit() === 2) + assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) + assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2)) + assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) + } + + test("updates iterator with all combos of updates and removes") { + val provider = newStoreProvider() + var currentVersion: Int = 0 + + def withStore(body: StateStore => Unit): Unit = { + val store = provider.getStore(currentVersion) + body(store) + currentVersion += 1 + } + + // New data should be seen in updates as value added, even if they had multiple updates + withStore { store => + put(store, "a", 1) + put(store, "aa", 1) + put(store, "aa", 2) + store.commit() + assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) + assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) + } + + // Multiple updates to same key should be collapsed in the updates as a single value update + // Keys that have not been updated should not appear in the updates + withStore { store => + put(store, "a", 4) + put(store, "a", 6) + store.commit() + assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Keys added, updated and finally removed before commit should not appear in updates + withStore { store => + put(store, "b", 4) // Added, finally removed + put(store, "bb", 5) // Added, updated, finally removed + put(store, "bb", 6) + remove(store, _.startsWith("b")) + store.commit() + assert(updatesToSet(store.updates()) === Set.empty) + assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) + } + + // Removed data should be seen in updates as a key removed + // Removed, but re-added data should be seen in updates as a value update + withStore { store => + remove(store, _.startsWith("a")) + put(store, "a", 10) + store.commit() + assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) + assert(rowsToSet(store.iterator()) === Set("a" -> 10)) + } + } + + test("cancel") { + val provider = newStoreProvider() + val store = provider.getStore(0) + put(store, "a", 1) + store.commit() + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + // cancelUpdates should not change the data in the files + val store1 = provider.getStore(1) + put(store1, "b", 1) + store1.abort() + assert(getDataFromFiles(provider) === Set("a" -> 1)) + } + + test("getStore with unexpected versions") { + val provider = newStoreProvider() + + intercept[IllegalArgumentException] { + provider.getStore(-1) + } + + // Prepare some data in the stoer + val store = provider.getStore(0) + put(store, "a", 1) + assert(store.commit() === 1) + assert(rowsToSet(store.iterator()) === Set("a" -> 1)) + + intercept[IllegalStateException] { + provider.getStore(2) + } + + // Update store version with some data + val store1 = provider.getStore(1) + put(store1, "b", 1) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) + + // Overwrite the version with other data + val store2 = provider.getStore(1) + put(store2, "c", 1) + assert(store2.commit() === 2) + assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) + assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) + } + + test("snapshotting") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + + var currentVersion = 0 + def updateVersionTo(targetVersion: Int): Unit = { + for (i <- currentVersion + 1 to targetVersion) { + val store = provider.getStore(currentVersion) + put(store, "a", i) + store.commit() + currentVersion += 1 + } + require(currentVersion === targetVersion) + } + + updateVersionTo(2) + require(getDataFromFiles(provider) === Set("a" -> 2)) + provider.doMaintenance() // should not generate snapshot files + assert(getDataFromFiles(provider) === Set("a" -> 2)) + + for (i <- 1 to currentVersion) { + assert(fileExists(provider, i, isSnapshot = false)) // all delta files present + assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present + } + + // After version 6, snapshotting should generate one snapshot file + updateVersionTo(6) + require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly") + provider.doMaintenance() // should generate snapshot files + + val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true)) + assert(snapshotVersion.nonEmpty, "snapshot file not generated") + deleteFilesEarlierThanVersion(provider, snapshotVersion.get) + assert( + getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get), + "snapshotting messed up the data of the snapshotted version") + assert( + getDataFromFiles(provider) === Set("a" -> 6), + "snapshotting messed up the data of the final version") + + // After version 20, snapshotting should generate newer snapshot files + updateVersionTo(20) + require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly") + provider.doMaintenance() // do snapshot + + val latestSnapshotVersion = (0 to 20).filter(version => + fileExists(provider, version, isSnapshot = true)).lastOption + assert(latestSnapshotVersion.nonEmpty, "no snapshot file found") + assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated") + + deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get) + assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data") + } + + test("cleaning") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + + for (i <- 1 to 20) { + val store = provider.getStore(i - 1) + put(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + require( + rowsToSet(provider.latestIterator()) === Set("a" -> 20), + "store not updated correctly") + + assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted + + // last couple of versions should be retrievable + assert(getDataFromFiles(provider, 20) === Set("a" -> 20)) + assert(getDataFromFiles(provider, 19) === Set("a" -> 19)) + } + + + test("corrupted file handling") { + val provider = newStoreProvider(minDeltasForSnapshot = 5) + for (i <- 1 to 6) { + val store = provider.getStore(i - 1) + put(store, "a", i) + store.commit() + provider.doMaintenance() // do cleanup + } + val snapshotVersion = (0 to 10).find( version => + fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found")) + + // Corrupt snapshot file and verify that it throws error + assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion)) + corruptFile(provider, snapshotVersion, isSnapshot = true) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion) + } + + // Corrupt delta file and verify that it throws error + assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1))) + corruptFile(provider, snapshotVersion - 1, isSnapshot = false) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion - 1) + } + + // Delete delta file and verify that it throws error + deleteFilesEarlierThanVersion(provider, snapshotVersion) + intercept[Exception] { + getDataFromFiles(provider, snapshotVersion - 1) + } + } + + test("StateStore.get") { + quietly { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, 0, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + + + // Verify that trying to get incorrect versions throw errors + intercept[IllegalArgumentException] { + StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf) + } + assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store + + intercept[IllegalStateException] { + StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + } + + // Increase version of the store + val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) + assert(store0.version === 0) + put(store0, "a", 1) + store0.commit() + + assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) + assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0) + + // Verify that you can remove the store and still reload and use it + StateStore.unload(storeId) + assert(!StateStore.isLoaded(storeId)) + + val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + put(store1, "a", 2) + assert(store1.commit() === 2) + assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) + } + } + + ignore("maintenance") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") + .set("spark.rpc.numRetries", "1") + val opId = 0 + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val storeId = StateStoreId(dir, opId, 0) + val storeConf = StateStoreConf.empty + val hadoopConf = new Configuration() + val provider = new HDFSBackedStateStoreProvider( + storeId, keySchema, valueSchema, storeConf, hadoopConf) + + quietly { + withSpark(new SparkContext(conf)) { sc => + withCoordinatorRef(sc) { coordinatorRef => + for (i <- 1 to 20) { + val store = StateStore.get( + storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) + put(store, "a", i) + store.commit() + } + eventually(timeout(10 seconds)) { + assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") + } + + // Background maintenance should clean up and generate snapshots + eventually(timeout(10 seconds)) { + // Earliest delta file should get cleaned up + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + + // Some snapshots should have been generated + val snapshotVersions = (0 to 20).filter { version => + fileExists(provider, version, isSnapshot = true) + } + assert(snapshotVersions.nonEmpty, "no snapshot file found") + } + + // If driver decides to deactivate all instances of the store, then this instance + // should be unloaded + coordinatorRef.deactivateInstances(dir) + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + + // If some other executor loads the store, then this instance should be unloaded + coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + + // Reload the store and verify + StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + assert(StateStore.isLoaded(storeId)) + } + } + + // Verify if instance is unloaded if SparkContext is stopped + require(SparkEnv.get === null) + eventually(timeout(10 seconds)) { + assert(!StateStore.isLoaded(storeId)) + } + } + } + + def getDataFromFiles( + provider: HDFSBackedStateStoreProvider, + version: Int = -1): Set[(String, Int)] = { + val reloadedProvider = new HDFSBackedStateStoreProvider( + provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) + if (version < 0) { + reloadedProvider.latestIterator().map(rowsToStringInt).toSet + } else { + reloadedProvider.iterator(version).map(rowsToStringInt).toSet + } + } + + def assertMap( + testMapOption: Option[MapType], + expectedMap: Map[String, Int]): Unit = { + assert(testMapOption.nonEmpty, "no map present") + val convertedMap = testMapOption.get.map(rowsToStringInt) + assert(convertedMap === expectedMap) + } + + def fileExists( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Boolean = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.exists + } + + def deleteFilesEarlierThanVersion(provider: HDFSBackedStateStoreProvider, version: Long): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + for (version <- 0 until version.toInt) { + for (isSnapshot <- Seq(false, true)) { + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + if (filePath.exists) filePath.delete() + } + } + } + + def corruptFile( + provider: HDFSBackedStateStoreProvider, + version: Long, + isSnapshot: Boolean): Unit = { + val method = PrivateMethod[Path]('baseDir) + val basePath = provider invokePrivate method() + val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta" + val filePath = new File(basePath.toString, fileName) + filePath.delete() + filePath.createNewFile() + } + + def storeLoaded(storeId: StateStoreId): Boolean = { + val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores) + val loadedStores = StateStore invokePrivate method() + loadedStores.contains(storeId) + } + + def unloadStore(storeId: StateStoreId): Boolean = { + val method = PrivateMethod('remove) + StateStore invokePrivate method(storeId) + } + + def newStoreProvider( + opId: Long = Random.nextLong, + partition: Int = 0, + minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get + ): HDFSBackedStateStoreProvider = { + val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString + val sqlConf = new SQLConf() + sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + new HDFSBackedStateStoreProvider( + StateStoreId(dir, opId, partition), + keySchema, + valueSchema, + new StateStoreConf(sqlConf), + new Configuration()) + } + + def remove(store: StateStore, condition: String => Boolean): Unit = { + store.remove(row => condition(rowToString(row))) + } + + private def put(store: StateStore, key: String, value: Int): Unit = { + store.put(stringToRow(key), intToRow(value)) + } + + private def get(store: StateStore, key: String): Option[Int] = { + store.get(stringToRow(key)).map(rowToInt) + } +} + +private[state] object StateStoreSuite { + + /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */ + trait TestUpdate + case class Added(key: String, value: Int) extends TestUpdate + case class Updated(key: String, value: Int) extends TestUpdate + case class Removed(key: String) extends TestUpdate + + val strProj = UnsafeProjection.create(Array[DataType](StringType)) + val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + + def stringToRow(s: String): UnsafeRow = { + strProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy() + } + + def intToRow(i: Int): UnsafeRow = { + intProj.apply(new GenericInternalRow(Array[Any](i))).copy() + } + + def rowToString(row: UnsafeRow): String = { + row.getUTF8String(0).toString + } + + def rowToInt(row: UnsafeRow): Int = { + row.getInt(0) + } + + def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = { + (rowToInt(row._1), rowToInt(row._2)) + } + + + def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = { + (rowToString(row._1), rowToInt(row._2)) + } + + def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = { + iterator.map(rowsToStringInt).toSet + } + + def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = { + iterator.map { _ match { + case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value)) + case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value)) + case KeyRemoved(key) => Removed(rowToString(key)) + }}.toSet + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index c15aac775096c..09bd7f6e8f0a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,13 +19,17 @@ package org.apache.spark.sql.execution.ui import java.util.Properties -import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} +import org.mockito.Mockito.{mock, when} + +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.ui.SparkUI class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ @@ -67,22 +71,34 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { ) private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { - val metrics = new TaskMetrics - metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) - metrics.updateAccumulators() + val metrics = mock(classOf[TaskMetrics]) + when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) => + new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)), + value = None, internal = true, countFailedValues = true) + }.toSeq) metrics } test("basic") { def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = { - assert(actual === expected.mapValues(_.toString)) + assert(actual.size == expected.size) + expected.foreach { e => + // The values in actual can be SQL metrics meaning that they contain additional formatting + // when converted to string. Verify that they start with the expected value. + // TODO: this is brittle. There is no requirement that the actual string needs to start + // with the accumulator value. + assert(actual.contains(e._1)) + val v = actual.get(e._1).get.trim + assert(v.startsWith(e._2.toString)) + } } val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) + .allNodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -90,13 +106,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) val executionUIData = listener.executionIdToData(0) @@ -113,17 +129,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(listener.getExecutionMetrics(0).isEmpty) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) + // (task id, stage id, stage attempt, accum updates) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) @@ -132,9 +148,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 0, 1, createTaskMetrics(accumulatorUpdates)), - (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) @@ -172,9 +188,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( - // (task id, stage id, stage attempt, metrics) - (0L, 1, 0, createTaskMetrics(accumulatorUpdates)), - (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) + // (task id, stage id, stage attempt, accum updates) + (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()), + (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()) ))) checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) @@ -206,7 +222,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -219,19 +236,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -248,13 +266,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -271,7 +289,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), @@ -288,19 +307,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -326,43 +346,81 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } + test("SPARK-13055: history listener only tracks SQL metrics") { + val listener = new SQLHistoryListener(sparkContext.conf, mock(classOf[SparkUI])) + // We need to post other events for the listener to track our accumulators. + // These are largely just boilerplate unrelated to what we're trying to test. + val df = createTestDataFrame + val executionStart = SparkListenerSQLExecutionStart( + 0, "", "", "", SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0) + val stageInfo = createStageInfo(0, 0) + val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), createProperties(0)) + val stageSubmitted = SparkListenerStageSubmitted(stageInfo) + // This task has both accumulators that are SQL metrics and accumulators that are not. + // The listener should only track the ones that are actually SQL metrics. + val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella") + val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball") + val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None) + val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None) + val taskInfo = createTaskInfo(0, 0) + taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo) + val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null) + listener.onOtherEvent(executionStart) + listener.onJobStart(jobStart) + listener.onStageSubmitted(stageSubmitted) + // Before SPARK-13055, this throws ClassCastException because the history listener would + // assume that the accumulator value is of type Long, but this may not be true for + // accumulators that are not SQL metrics. + listener.onTaskEnd(taskEnd) + val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { stageMetrics => + stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates) + } + // Listener tracks only SQL metrics, not other accumulators + assert(trackedAccums.size === 1) + assert(trackedAccums.head === sqlMetricInfo) + } + } + class SQLListenerMemoryLeakSuite extends SparkFunSuite { test("no memory leak") { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly - .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly - val sc = new SparkContext(conf) - try { - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - // Run 100 successful executions and 100 failed executions. - // Each execution only has one job and one stage. - for (i <- 0 until 100) { - val df = Seq( - (1, 1), - (2, 2) - ).toDF() - df.collect() - try { - df.foreach(_ => throw new RuntimeException("Oops")) - } catch { - case e: SparkException => // This is expected for a failed job + quietly { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + val sc = new SparkContext(conf) + try { + SQLContext.clearSqlListener() + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // Run 100 successful executions and 100 failed executions. + // Each execution only has one job and one stage. + for (i <- 0 until 100) { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + df.collect() + try { + df.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException => // This is expected for a failed job + } } + sc.listenerBus.waitUntilEmpty(10000) + assert(sqlContext.listener.getCompletedExecutions.size <= 50) + assert(sqlContext.listener.getFailedExecutions.size <= 50) + // 50 for successful executions and 50 for failed executions + assert(sqlContext.listener.executionIdToData.size <= 100) + assert(sqlContext.listener.jobIdToExecutionId.size <= 100) + assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + } finally { + sc.stop() } - sc.listenerBus.waitUntilEmpty(10000) - assert(sqlContext.listener.getCompletedExecutions.size <= 50) - assert(sqlContext.listener.getFailedExecutions.size <= 50) - // 50 for successful executions and 50 for failed executions - assert(sqlContext.listener.executionIdToData.size <= 100) - assert(sqlContext.listener.jobIdToExecutionId.size <= 100) - assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) - } finally { - sc.stop() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala new file mode 100644 index 0000000000000..67b3d98c1daed --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets + +import scala.util.Random + +import org.apache.spark.memory.MemoryMode +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.ColumnVector +import org.apache.spark.sql.types.{BinaryType, IntegerType} +import org.apache.spark.unsafe.Platform +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.BitSet + +/** + * Benchmark to low level memory access using different ways to manage buffers. + */ +object ColumnarBatchBenchmark { + + // This benchmark reads and writes an array of ints. + // TODO: there is a big (2x) penalty for a random access API for off heap. + // Note: carefully if modifying this code. It's hard to reason about the JIT. + def intAccess(iters: Long): Unit = { + val count = 8 * 1000 + + // Accessing a java array. + val javaArray = { i: Int => + val data = new Array[Int](count) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + data(i) = i + i += 1 + } + i = 0 + while (i < count) { + sum += data(i) + i += 1 + } + } + } + + // Accessing ByteBuffers + val byteBufferUnsafe = { i: Int => + val data = ByteBuffer.allocate(count * 4) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + Platform.putInt(data.array(), Platform.BYTE_ARRAY_OFFSET + i * 4, i) + i += 1 + } + i = 0 + while (i < count) { + sum += Platform.getInt(data.array(), Platform.BYTE_ARRAY_OFFSET + i * 4) + i += 1 + } + } + } + + // Accessing offheap byte buffers + val directByteBuffer = { i: Int => + val data = ByteBuffer.allocateDirect(count * 4).asIntBuffer() + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + data.put(i) + i += 1 + } + data.rewind() + i = 0 + while (i < count) { + sum += data.get() + i += 1 + } + data.rewind() + } + } + + // Accessing ByteBuffer using the typed APIs + val byteBufferApi = { i: Int => + val data = ByteBuffer.allocate(count * 4) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + data.putInt(i) + i += 1 + } + data.rewind() + i = 0 + while (i < count) { + sum += data.getInt() + i += 1 + } + data.rewind() + } + } + + // Using unsafe memory + val unsafeBuffer = { i: Int => + val data: Long = Platform.allocateMemory(count * 4) + var sum = 0L + for (n <- 0L until iters) { + var ptr = data + var i = 0 + while (i < count) { + Platform.putInt(null, ptr, i) + ptr += 4 + i += 1 + } + ptr = data + i = 0 + while (i < count) { + sum += Platform.getInt(null, ptr) + ptr += 4 + i += 1 + } + } + } + + // Access through the column API with on heap memory + val columnOnHeap = { i: Int => + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + col.putInt(i, i) + i += 1 + } + i = 0 + while (i < count) { + sum += col.getInt(i) + i += 1 + } + } + col.close + } + + // Access through the column API with off heap memory + def columnOffHeap = { i: Int => { + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + col.putInt(i, i) + i += 1 + } + i = 0 + while (i < count) { + sum += col.getInt(i) + i += 1 + } + } + col.close + }} + + // Access by directly getting the buffer backing the column. + val columnOffheapDirect = { i: Int => + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) + var sum = 0L + for (n <- 0L until iters) { + var addr = col.valuesNativeAddress() + var i = 0 + while (i < count) { + Platform.putInt(null, addr, i) + addr += 4 + i += 1 + } + i = 0 + addr = col.valuesNativeAddress() + while (i < count) { + sum += Platform.getInt(null, addr) + addr += 4 + i += 1 + } + } + col.close + } + + // Access by going through a batch of unsafe rows. + val unsafeRowOnheap = { i: Int => + val buffer = new Array[Byte](count * 16) + var sum = 0L + for (n <- 0L until iters) { + val row = new UnsafeRow(1) + var i = 0 + while (i < count) { + row.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET + i * 16, 16) + row.setInt(0, i) + i += 1 + } + i = 0 + while (i < count) { + row.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET + i * 16, 16) + sum += row.getInt(0) + i += 1 + } + } + } + + // Access by going through a batch of unsafe rows. + val unsafeRowOffheap = { i: Int => + val buffer = Platform.allocateMemory(count * 16) + var sum = 0L + for (n <- 0L until iters) { + val row = new UnsafeRow(1) + var i = 0 + while (i < count) { + row.pointTo(null, buffer + i * 16, 16) + row.setInt(0, i) + i += 1 + } + i = 0 + while (i < count) { + row.pointTo(null, buffer + i * 16, 16) + sum += row.getInt(0) + i += 1 + } + } + Platform.freeMemory(buffer) + } + + // Adding values by appending, instead of putting. + val onHeapAppend = { i: Int => + val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + col.appendInt(i) + i += 1 + } + i = 0 + while (i < count) { + sum += col.getInt(i) + i += 1 + } + col.reset() + } + col.close + } + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Int Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + Java Array 248.8 1317.04 1.00 X + ByteBuffer Unsafe 435.6 752.25 0.57 X + ByteBuffer API 1752.0 187.03 0.14 X + DirectByteBuffer 595.4 550.35 0.42 X + Unsafe Buffer 235.2 1393.20 1.06 X + Column(on heap) 189.8 1726.45 1.31 X + Column(off heap) 408.4 802.35 0.61 X + Column(off heap direct) 237.6 1379.12 1.05 X + UnsafeRow (on heap) 414.6 790.35 0.60 X + UnsafeRow (off heap) 487.2 672.58 0.51 X + Column On Heap Append 530.1 618.14 0.59 X + */ + val benchmark = new Benchmark("Int Read/Write", count * iters) + benchmark.addCase("Java Array")(javaArray) + benchmark.addCase("ByteBuffer Unsafe")(byteBufferUnsafe) + benchmark.addCase("ByteBuffer API")(byteBufferApi) + benchmark.addCase("DirectByteBuffer")(directByteBuffer) + benchmark.addCase("Unsafe Buffer")(unsafeBuffer) + benchmark.addCase("Column(on heap)")(columnOnHeap) + benchmark.addCase("Column(off heap)")(columnOffHeap) + benchmark.addCase("Column(off heap direct)")(columnOffheapDirect) + benchmark.addCase("UnsafeRow (on heap)")(unsafeRowOnheap) + benchmark.addCase("UnsafeRow (off heap)")(unsafeRowOffheap) + benchmark.addCase("Column On Heap Append")(onHeapAppend) + benchmark.run() + } + + def booleanAccess(iters: Int): Unit = { + val count = 8 * 1024 + val benchmark = new Benchmark("Boolean Read/Write", iters * count) + benchmark.addCase("Bitset") { i: Int => { + val b = new BitSet(count) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + if (i % 2 == 0) b.set(i) + i += 1 + } + i = 0 + while (i < count) { + if (b.get(i)) sum += 1 + i += 1 + } + } + }} + + benchmark.addCase("Byte Array") { i: Int => { + val b = new Array[Byte](count) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + if (i % 2 == 0) b(i) = 1; + i += 1 + } + i = 0 + while (i < count) { + if (b(i) == 1) sum += 1 + i += 1 + } + } + }} + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Boolean Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + Bitset 895.88 374.54 1.00 X + Byte Array 578.96 579.56 1.55 X + */ + benchmark.run() + } + + def stringAccess(iters: Long): Unit = { + val chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + val random = new Random(0) + + def randomString(min: Int, max: Int): String = { + val len = random.nextInt(max - min) + min + val sb = new StringBuilder(len) + var i = 0 + while (i < len) { + sb.append(chars.charAt(random.nextInt(chars.length()))); + i += 1 + } + return sb.toString + } + + val minString = 3 + val maxString = 32 + val count = 4 * 1000 + + val data = Seq.fill(count)(randomString(minString, maxString)) + .map(_.getBytes(StandardCharsets.UTF_8)).toArray + + def column(memoryMode: MemoryMode) = { i: Int => + val column = ColumnVector.allocate(count, BinaryType, memoryMode) + var sum = 0L + for (n <- 0L until iters) { + var i = 0 + while (i < count) { + column.putByteArray(i, data(i)) + i += 1 + } + i = 0 + while (i < count) { + sum += column.getUTF8String(i).numBytes() + i += 1 + } + column.reset() + } + } + + /* + String Read/Write: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------------- + On Heap 457.0 35.85 1.00 X + Off Heap 1206.0 13.59 0.38 X + */ + val benchmark = new Benchmark("String Read/Write", count * iters) + benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) + benchmark.addCase("Off Heap")(column(MemoryMode.OFF_HEAP)) + benchmark.run + } + + def main(args: Array[String]): Unit = { + intAccess(1024 * 40) + booleanAccess(1024 * 40) + stringAccess(1024 * 4) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala new file mode 100644 index 0000000000000..31b63f2ce13d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -0,0 +1,776 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized + +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.memory.MemoryMode +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.CalendarInterval + +class ColumnarBatchSuite extends SparkFunSuite { + test("Null Apis") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val reference = mutable.ArrayBuffer.empty[Boolean] + + val column = ColumnVector.allocate(1024, IntegerType, memMode) + var idx = 0 + assert(column.anyNullsSet() == false) + + column.putNotNull(idx) + reference += false + idx += 1 + assert(column.anyNullsSet() == false) + + column.putNull(idx) + reference += true + idx += 1 + assert(column.anyNullsSet() == true) + assert(column.numNulls() == 1) + + column.putNulls(idx, 3) + reference += true + reference += true + reference += true + idx += 3 + assert(column.anyNullsSet() == true) + + column.putNotNulls(idx, 4) + reference += false + reference += false + reference += false + reference += false + idx += 4 + assert(column.anyNullsSet() == true) + assert(column.numNulls() == 4) + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.isNullAt(v._2)) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.nullsNativeAddress() + assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) + } + } + column.close + }} + } + + test("Byte Apis") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val reference = mutable.ArrayBuffer.empty[Byte] + + val column = ColumnVector.allocate(1024, ByteType, memMode) + var idx = 0 + + val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray + column.putBytes(idx, 2, values, 0) + reference += 1 + reference += 2 + idx += 2 + + column.putBytes(idx, 3, values, 2) + reference += 3 + reference += 4 + reference += 5 + idx += 3 + + column.putByte(idx, 9) + reference += 9 + idx += 1 + + column.putBytes(idx, 3, 4) + reference += 4 + reference += 4 + reference += 4 + idx += 3 + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getByte(v._2), "MemoryMode" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getByte(null, addr + v._2)) + } + } + }} + } + + test("Int Apis") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Int] + + val column = ColumnVector.allocate(1024, IntegerType, memMode) + var idx = 0 + + val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray + column.putInts(idx, 2, values, 0) + reference += 1 + reference += 2 + idx += 2 + + column.putInts(idx, 3, values, 2) + reference += 3 + reference += 4 + reference += 5 + idx += 3 + + val littleEndian = new Array[Byte](8) + littleEndian(0) = 7 + littleEndian(1) = 1 + littleEndian(4) = 6 + littleEndian(6) = 1 + + column.putIntsLittleEndian(idx, 1, littleEndian, 4) + column.putIntsLittleEndian(idx + 1, 1, littleEndian, 0) + reference += 6 + (1 << 16) + reference += 7 + (1 << 8) + idx += 2 + + column.putIntsLittleEndian(idx, 2, littleEndian, 0) + reference += 7 + (1 << 8) + reference += 6 + (1 << 16) + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextInt() + column.putInt(idx, v) + reference += v + idx += 1 + } else { + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + column.putInts(idx, n, n + 1) + var i = 0 + while (i < n) { + reference += (n + 1) + i += 1 + } + idx += n + } + } + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getInt(v._2), "Seed = " + seed + " Mem Mode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) + } + } + column.close + }} + } + + test("Long Apis") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Long] + + val column = ColumnVector.allocate(1024, LongType, memMode) + var idx = 0 + + val values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray + column.putLongs(idx, 2, values, 0) + reference += 1 + reference += 2 + idx += 2 + + column.putLongs(idx, 3, values, 2) + reference += 3 + reference += 4 + reference += 5 + idx += 3 + + val littleEndian = new Array[Byte](16) + littleEndian(0) = 7 + littleEndian(1) = 1 + littleEndian(8) = 6 + littleEndian(10) = 1 + + column.putLongsLittleEndian(idx, 1, littleEndian, 8) + column.putLongsLittleEndian(idx + 1, 1, littleEndian, 0) + reference += 6 + (1 << 16) + reference += 7 + (1 << 8) + idx += 2 + + column.putLongsLittleEndian(idx, 2, littleEndian, 0) + reference += 7 + (1 << 8) + reference += 6 + (1 << 16) + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextLong() + column.putLong(idx, v) + reference += v + idx += 1 + } else { + + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + column.putLongs(idx, n, n + 1) + var i = 0 + while (i < n) { + reference += (n + 1) + i += 1 + } + idx += n + } + } + + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getLong(v._2), "idx=" + v._2 + + " Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) + } + } + }} + } + + test("Double APIs") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Double] + + val column = ColumnVector.allocate(1024, DoubleType, memMode) + var idx = 0 + + val values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray + column.putDoubles(idx, 2, values, 0) + reference += 1.0 + reference += 2.0 + idx += 2 + + column.putDoubles(idx, 3, values, 2) + reference += 3.0 + reference += 4.0 + reference += 5.0 + idx += 3 + + val buffer = new Array[Byte](16) + Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, 2.234) + Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, 1.123) + + column.putDoubles(idx, 1, buffer, 8) + column.putDoubles(idx + 1, 1, buffer, 0) + reference += 1.123 + reference += 2.234 + idx += 2 + + column.putDoubles(idx, 2, buffer, 0) + reference += 2.234 + reference += 1.123 + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextDouble() + column.putDouble(idx, v) + reference += v + idx += 1 + } else { + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + val v = random.nextDouble() + column.putDoubles(idx, n, v) + var i = 0 + while (i < n) { + reference += v + i += 1 + } + idx += n + } + } + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getDouble(v._2), "Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) + } + } + column.close + }} + } + + test("String APIs") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val reference = mutable.ArrayBuffer.empty[String] + + val column = ColumnVector.allocate(6, BinaryType, memMode) + assert(column.arrayData().elementsAppended == 0) + var idx = 0 + + val values = ("Hello" :: "abc" :: Nil).toArray + column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), + 0, values(0).getBytes(StandardCharsets.UTF_8).length) + reference += values(0) + idx += 1 + assert(column.arrayData().elementsAppended == 5) + + column.putByteArray(idx, values(1).getBytes(StandardCharsets.UTF_8), + 0, values(1).getBytes(StandardCharsets.UTF_8).length) + reference += values(1) + idx += 1 + assert(column.arrayData().elementsAppended == 8) + + // Just put llo + val offset = column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), + 2, values(0).getBytes(StandardCharsets.UTF_8).length - 2) + reference += "llo" + idx += 1 + assert(column.arrayData().elementsAppended == 11) + + // Put the same "ll" at offset. This should not allocate more memory in the column. + column.putArray(idx, offset, 2) + reference += "ll" + idx += 1 + assert(column.arrayData().elementsAppended == 11) + + // Put a long string + val s = "abcdefghijklmnopqrstuvwxyz" + column.putByteArray(idx, (s + s).getBytes(StandardCharsets.UTF_8)) + reference += (s + s) + idx += 1 + assert(column.arrayData().elementsAppended == 11 + (s + s).length) + + reference.zipWithIndex.foreach { v => + assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) + assert(v._1 == column.getUTF8String(v._2).toString, + "MemoryMode" + memMode) + } + + column.reset() + assert(column.arrayData().elementsAppended == 0) + }} + } + + test("Int Array") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) + + // Fill the underlying data with all the arrays back to back. + val data = column.arrayData(); + var i = 0 + while (i < 6) { + data.putInt(i, i) + i += 1 + } + + // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putArray(2, 2, 0) + column.putArray(3, 3, 3) + + val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] + val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] + val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + assert(a1 === Array(0)) + assert(a2 === Array(1, 2)) + assert(a3 === Array.empty[Int]) + assert(a4 === Array(3, 4, 5)) + + // Verify the ArrayData APIs + assert(column.getArray(0).length == 1) + assert(column.getArray(0).getInt(0) == 0) + + assert(column.getArray(1).length == 2) + assert(column.getArray(1).getInt(0) == 1) + assert(column.getArray(1).getInt(1) == 2) + + assert(column.getArray(2).length == 0) + + assert(column.getArray(3).length == 3) + assert(column.getArray(3).getInt(0) == 3) + assert(column.getArray(3).getInt(1) == 4) + assert(column.getArray(3).getInt(2) == 5) + + // Add a longer array which requires resizing + column.reset + val array = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) + assert(data.capacity == 10) + data.reserve(array.length) + assert(data.capacity == array.length * 2) + data.putInts(0, array.length, array, 0) + column.putArray(0, 0, array.length) + assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] + === array) + }} + } + + test("Struct Column") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val schema = new StructType().add("int", IntegerType).add("double", DoubleType) + val column = ColumnVector.allocate(1024, schema, memMode) + + val c1 = column.getChildColumn(0) + val c2 = column.getChildColumn(1) + assert(c1.dataType() == IntegerType) + assert(c2.dataType() == DoubleType) + + c1.putInt(0, 123) + c2.putDouble(0, 3.45) + c1.putInt(1, 456) + c2.putDouble(1, 5.67) + + val s = column.getStruct(0) + assert(s.columns()(0).getInt(0) == 123) + assert(s.columns()(0).getInt(1) == 456) + assert(s.columns()(1).getDouble(0) == 3.45) + assert(s.columns()(1).getDouble(1) == 5.67) + + assert(s.getInt(0) == 123) + assert(s.getDouble(1) == 3.45) + + val s2 = column.getStruct(1) + assert(s2.getInt(0) == 456) + assert(s2.getDouble(1) == 5.67) + }} + } + + test("ColumnarBatch basic") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val schema = new StructType() + .add("intCol", IntegerType) + .add("doubleCol", DoubleType) + .add("intCol2", IntegerType) + .add("string", BinaryType) + + val batch = ColumnarBatch.allocate(schema, memMode) + assert(batch.numCols() == 4) + assert(batch.numRows() == 0) + assert(batch.numValidRows() == 0) + assert(batch.capacity() > 0) + assert(batch.rowIterator().hasNext == false) + + // Add a row [1, 1.1, NULL] + batch.column(0).putInt(0, 1) + batch.column(1).putDouble(0, 1.1) + batch.column(2).putNull(0) + batch.column(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8)) + batch.setNumRows(1) + + // Verify the results of the row. + assert(batch.numCols() == 4) + assert(batch.numRows() == 1) + assert(batch.numValidRows() == 1) + assert(batch.rowIterator().hasNext == true) + assert(batch.rowIterator().hasNext == true) + + assert(batch.column(0).getInt(0) == 1) + assert(batch.column(0).isNullAt(0) == false) + assert(batch.column(1).getDouble(0) == 1.1) + assert(batch.column(1).isNullAt(0) == false) + assert(batch.column(2).isNullAt(0) == true) + assert(batch.column(3).getUTF8String(0).toString == "Hello") + + // Verify the iterator works correctly. + val it = batch.rowIterator() + assert(it.hasNext()) + val row = it.next() + assert(row.getInt(0) == 1) + assert(row.isNullAt(0) == false) + assert(row.getDouble(1) == 1.1) + assert(row.isNullAt(1) == false) + assert(row.isNullAt(2) == true) + assert(batch.column(3).getUTF8String(0).toString == "Hello") + assert(it.hasNext == false) + assert(it.hasNext == false) + + // Filter out the row. + row.markFiltered() + assert(batch.numRows() == 1) + assert(batch.numValidRows() == 0) + assert(batch.rowIterator().hasNext == false) + + // Reset and add 3 rows + batch.reset() + assert(batch.numRows() == 0) + assert(batch.numValidRows() == 0) + assert(batch.rowIterator().hasNext == false) + + // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] + batch.column(0).putNull(0) + batch.column(1).putDouble(0, 2.2) + batch.column(2).putInt(0, 2) + batch.column(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8)) + + batch.column(0).putInt(1, 3) + batch.column(1).putNull(1) + batch.column(2).putInt(1, 3) + batch.column(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8)) + + batch.column(0).putInt(2, 4) + batch.column(1).putDouble(2, 4.4) + batch.column(2).putInt(2, 4) + batch.column(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8)) + batch.setNumRows(3) + + def rowEquals(x: InternalRow, y: Row): Unit = { + assert(x.isNullAt(0) == y.isNullAt(0)) + if (!x.isNullAt(0)) assert(x.getInt(0) == y.getInt(0)) + + assert(x.isNullAt(1) == y.isNullAt(1)) + if (!x.isNullAt(1)) assert(x.getDouble(1) == y.getDouble(1)) + + assert(x.isNullAt(2) == y.isNullAt(2)) + if (!x.isNullAt(2)) assert(x.getInt(2) == y.getInt(2)) + + assert(x.isNullAt(3) == y.isNullAt(3)) + if (!x.isNullAt(3)) assert(x.getString(3) == y.getString(3)) + } + + // Verify + assert(batch.numRows() == 3) + assert(batch.numValidRows() == 3) + val it2 = batch.rowIterator() + rowEquals(it2.next(), Row(null, 2.2, 2, "abc")) + rowEquals(it2.next(), Row(3, null, 3, "")) + rowEquals(it2.next(), Row(4, 4.4, 4, "world")) + assert(!it.hasNext) + + // Filter out some rows and verify + batch.markFiltered(1) + assert(batch.numValidRows() == 2) + val it3 = batch.rowIterator() + rowEquals(it3.next(), Row(null, 2.2, 2, "abc")) + rowEquals(it3.next(), Row(4, 4.4, 4, "world")) + assert(!it.hasNext) + + batch.markFiltered(2) + assert(batch.numValidRows() == 1) + val it4 = batch.rowIterator() + rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) + + batch.close + }} + } + + private def doubleEquals(d1: Double, d2: Double): Boolean = { + if (d1.isNaN && d2.isNaN) { + true + } else { + d1 == d2 + } + } + + private def compareStruct(fields: Seq[StructField], r1: InternalRow, r2: Row, seed: Long) { + fields.zipWithIndex.foreach { v => { + assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed) + if (!r1.isNullAt(v._2)) { + v._1.dataType match { + case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed) + case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed) + case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed) + case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed) + case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed) + case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)), + "Seed = " + seed) + case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)), + "Seed = " + seed) + case t: DecimalType => + val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal + val d2 = r2.getDecimal(v._2) + assert(d1.compare(d2) == 0, "Seed = " + seed) + case StringType => + assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed) + case CalendarIntervalType => + assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval]) + case ArrayType(childType, n) => + val a1 = r1.getArray(v._2).array + val a2 = r2.getList(v._2).toArray + assert(a1.length == a2.length, "Seed = " + seed) + childType match { + case DoubleType => + var i = 0 + while (i < a1.length) { + assert(doubleEquals(a1(i).asInstanceOf[Double], a2(i).asInstanceOf[Double]), + "Seed = " + seed) + i += 1 + } + case FloatType => + var i = 0 + while (i < a1.length) { + assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]), + "Seed = " + seed) + i += 1 + } + case t: DecimalType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal + val d2 = a2(i).asInstanceOf[java.math.BigDecimal] + assert(d1.compare(d2) == 0, "Seed = " + seed) + } + i += 1 + } + case _ => assert(a1 === a2, "Seed = " + seed) + } + case StructType(childFields) => + compareStruct(childFields, r1.getStruct(v._2, fields.length), r2.getStruct(v._2), seed) + case _ => + throw new NotImplementedError("Not implemented " + v._1.dataType) + } + } + }} + } + + test("Convert rows") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val rows = Row(1, 2L, "a", 1.2, 'b'.toByte) :: Row(4, 5L, "cd", 2.3, 'a'.toByte) :: Nil + val schema = new StructType() + .add("i1", IntegerType) + .add("l2", LongType) + .add("string", StringType) + .add("d", DoubleType) + .add("b", ByteType) + + val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) + assert(batch.numRows() == 2) + assert(batch.numCols() == 5) + + val it = batch.rowIterator() + val referenceIt = rows.iterator + while (it.hasNext) { + compareStruct(schema, it.next(), referenceIt.next(), 0) + } + batch.close() + } + }} + + /** + * This test generates a random schema data, serializes it to column batches and verifies the + * results. + */ + def testRandomRows(flatSchema: Boolean, numFields: Int) { + // TODO: Figure out why StringType doesn't work on jenkins. + val types = Array( + BooleanType, ByteType, FloatType, DoubleType, + IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10), + CalendarIntervalType) + val seed = System.nanoTime() + val NUM_ROWS = 200 + val NUM_ITERS = 1000 + val random = new Random(seed) + var i = 0 + while (i < NUM_ITERS) { + val schema = if (flatSchema) { + RandomDataGenerator.randomSchema(random, numFields, types) + } else { + RandomDataGenerator.randomNestedSchema(random, numFields, types) + } + val rows = mutable.ArrayBuffer.empty[Row] + var j = 0 + while (j < NUM_ROWS) { + val row = RandomDataGenerator.randomRow(random, schema) + rows += row + j += 1 + } + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) + assert(batch.numRows() == NUM_ROWS) + + val it = batch.rowIterator() + val referenceIt = rows.iterator + var k = 0 + while (it.hasNext) { + compareStruct(schema, it.next(), referenceIt.next(), seed) + k += 1 + } + batch.close() + }} + i += 1 + } + } + + test("Random flat schema") { + testRandomRows(true, 15) + } + + test("Random nested schema") { + testRandomRows(false, 30) + } + + test("null filtered columns") { + val NUM_ROWS = 10 + val schema = new StructType() + .add("key", IntegerType, nullable = false) + .add("value", StringType, nullable = true) + for (numNulls <- List(0, NUM_ROWS / 2, NUM_ROWS)) { + val rows = mutable.ArrayBuffer.empty[Row] + for (i <- 0 until NUM_ROWS) { + val row = if (i < numNulls) Row.fromSeq(Seq(i, null)) else Row.fromSeq(Seq(i, i.toString)) + rows += row + } + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val batch = ColumnVectorUtils.toBatch(schema, memMode, rows.iterator.asJava) + batch.filterNullsInColumn(1) + batch.setNumRows(NUM_ROWS) + assert(batch.numRows() == NUM_ROWS) + val it = batch.rowIterator() + // Top numNulls rows should be filtered + var k = numNulls + while (it.hasNext) { + assert(it.next().getInt(0) == k) + k += 1 + } + assert(k == NUM_ROWS) + batch.close() + }} + } + } + + test("mutable ColumnarBatch rows") { + val NUM_ITERS = 10 + val types = Array( + BooleanType, FloatType, DoubleType, + IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10)) + for (i <- 0 to NUM_ITERS) { + val random = new Random(System.nanoTime()) + val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types) + val oldRow = RandomDataGenerator.randomRow(random, schema) + val newRow = RandomDataGenerator.randomRow(random, schema) + + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava) + val columnarBatchRow = batch.getRow(0) + newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1)) + compareStruct(schema, columnarBatchRow, newRow, 0) + batch.close() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala new file mode 100644 index 0000000000000..f809e01169355 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RuntimeConfig + +class RuntimeConfigSuite extends SparkFunSuite { + + private def newConf(): RuntimeConfig = new RuntimeConfigImpl + + test("set and get") { + val conf = newConf() + conf + .set("k1", "v1") + .set("k2", 2) + .set("k3", value = false) + + assert(conf.get("k1") == "v1") + assert(conf.get("k2") == "2") + assert(conf.get("k3") == "false") + + intercept[NoSuchElementException] { + conf.get("notset") + } + } + + test("getOption") { + val conf = newConf().set("k1", "v1") + assert(conf.getOption("k1") == Some("v1")) + assert(conf.getOption("notset") == None) + } + + test("unset") { + val conf = newConf().set("k1", "v1") + assert(conf.get("k1") == "v1") + conf.unset("k1") + intercept[NoSuchElementException] { + conf.get("k1") + } + } + + test("set and get hadoop configuration") { + val conf = newConf() + conf + .setHadoop("k1", "v1") + .setHadoop("k2", "v2") + + assert(conf.getHadoop("k1") == "v1") + assert(conf.getHadoop("k2") == "v2") + + intercept[NoSuchElementException] { + conf.get("notset") + } + } + + test("getHadoopOption") { + val conf = newConf().setHadoop("k1", "v1") + assert(conf.getHadoopOption("k1") == Some("v1")) + assert(conf.getHadoopOption("notset") == None) + } + + test("unsetHadoop") { + val conf = newConf().setHadoop("k1", "v1") + assert(conf.getHadoop("k1") == "v1") + conf.unsetHadoop("k1") + intercept[NoSuchElementException] { + conf.getHadoop("k1") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala similarity index 83% rename from sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index 2e33777f14adc..cc6919913948d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -15,10 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf._ +import org.apache.spark.sql.internal.SQLConf._ class SQLConfEntrySuite extends SparkFunSuite { @@ -26,7 +26,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("intConf") { val key = "spark.sql.SQLConfEntrySuite.int" - val confEntry = SQLConfEntry.intConf(key) + val confEntry = SQLConfigBuilder(key).intConf.createWithDefault(1) assert(conf.getConf(confEntry, 5) === 5) conf.setConf(confEntry, 10) @@ -45,7 +45,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("longConf") { val key = "spark.sql.SQLConfEntrySuite.long" - val confEntry = SQLConfEntry.longConf(key) + val confEntry = SQLConfigBuilder(key).longConf.createWithDefault(1L) assert(conf.getConf(confEntry, 5L) === 5L) conf.setConf(confEntry, 10L) @@ -64,7 +64,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("booleanConf") { val key = "spark.sql.SQLConfEntrySuite.boolean" - val confEntry = SQLConfEntry.booleanConf(key) + val confEntry = SQLConfigBuilder(key).booleanConf.createWithDefault(true) assert(conf.getConf(confEntry, false) === false) conf.setConf(confEntry, true) @@ -83,7 +83,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("doubleConf") { val key = "spark.sql.SQLConfEntrySuite.double" - val confEntry = SQLConfEntry.doubleConf(key) + val confEntry = SQLConfigBuilder(key).doubleConf.createWithDefault(1d) assert(conf.getConf(confEntry, 5.0) === 5.0) conf.setConf(confEntry, 10.0) @@ -102,7 +102,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("stringConf") { val key = "spark.sql.SQLConfEntrySuite.string" - val confEntry = SQLConfEntry.stringConf(key) + val confEntry = SQLConfigBuilder(key).stringConf.createWithDefault(null) assert(conf.getConf(confEntry, "abc") === "abc") conf.setConf(confEntry, "abcd") @@ -116,7 +116,10 @@ class SQLConfEntrySuite extends SparkFunSuite { test("enumConf") { val key = "spark.sql.SQLConfEntrySuite.enum" - val confEntry = SQLConfEntry.enumConf(key, v => v, Set("a", "b", "c"), defaultValue = Some("a")) + val confEntry = SQLConfigBuilder(key) + .stringConf + .checkValues(Set("a", "b", "c")) + .createWithDefault("a") assert(conf.getConf(confEntry) === "a") conf.setConf(confEntry, "b") @@ -135,8 +138,10 @@ class SQLConfEntrySuite extends SparkFunSuite { test("stringSeqConf") { val key = "spark.sql.SQLConfEntrySuite.stringSeq" - val confEntry = SQLConfEntry.stringSeqConf("spark.sql.SQLConfEntrySuite.stringSeq", - defaultValue = Some(Nil)) + val confEntry = SQLConfigBuilder(key) + .stringConf + .toSequence + .createWithDefault(Nil) assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c")) conf.setConf(confEntry, Seq("a", "b", "c", "d")) @@ -147,4 +152,12 @@ class SQLConfEntrySuite extends SparkFunSuite { assert(conf.getConfString(key) === "a,b,c,d,e") assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e")) } + + test("duplicate entry") { + val key = "spark.sql.SQLConfEntrySuite.duplicate" + SQLConfigBuilder(key).stringConf.createOptional + intercept[IllegalArgumentException] { + SQLConfigBuilder(key).stringConf.createOptional + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala new file mode 100644 index 0000000000000..e687e6a5cefe9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -0,0 +1,128 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.internal + +import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} + +class SQLConfSuite extends QueryTest with SharedSQLContext { + private val testKey = "test.key.0" + private val testVal = "test.val.0" + + test("propagate from spark conf") { + // We create a new context here to avoid order dependence with other tests that might call + // clear(). + val newContext = new SQLContext(sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") === "true") + } + + test("programmatic ways of basic setting and getting") { + // Set a conf first. + sqlContext.setConf(testKey, testVal) + // Clear the conf. + sqlContext.conf.clear() + // After clear, only overrideConfs used by unit test should be in the SQLConf. + assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + + sqlContext.setConf(testKey, testVal) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) + + // Tests SQLConf as accessed from a SQLContext is mutable after + // the latter is initialized, unlike SparkConf inside a SparkContext. + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) + + sqlContext.conf.clear() + } + + test("parse SQL set commands") { + sqlContext.conf.clear() + sql(s"set $testKey=$testVal") + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + + sql("set some.property=20") + assert(sqlContext.getConf("some.property", "0") === "20") + sql("set some.property = 40") + assert(sqlContext.getConf("some.property", "0") === "40") + + val key = "spark.sql.key" + val vs = "val0,val_1,val2.3,my_table" + sql(s"set $key=$vs") + assert(sqlContext.getConf(key, "0") === vs) + + sql(s"set $key=") + assert(sqlContext.getConf(key, "0") === "") + + sqlContext.conf.clear() + } + + test("deprecated property") { + sqlContext.conf.clear() + val original = sqlContext.conf.numShufflePartitions + try{ + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(sqlContext.conf.numShufflePartitions === 10) + } finally { + sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") + } + } + + test("invalid conf value") { + sqlContext.conf.clear() + val e = intercept[IllegalArgumentException] { + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + } + assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") + } + + test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") { + sqlContext.conf.clear() + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") + assert(sqlContext.conf.targetPostShuffleInputSize === 100) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") + assert(sqlContext.conf.targetPostShuffleInputSize === 1024) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") + assert(sqlContext.conf.targetPostShuffleInputSize === 1048576) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") + assert(sqlContext.conf.targetPostShuffleInputSize === 1073741824) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") + assert(sqlContext.conf.targetPostShuffleInputSize === -1) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MaxValue + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") + } + + intercept[IllegalArgumentException] { + // This value less than Long.MinValue + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") + } + + sqlContext.conf.clear() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d530b1a469ce2..f66deea06589c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,18 +18,25 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.DriverManager +import java.sql.{Date, DriverManager, Timestamp} import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.execution.DataSourceScan +import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD +import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { +class JDBCSuite extends SparkFunSuite + with BeforeAndAfter with PrivateMethodTester with SharedSQLContext { import testImplicits._ val url = "jdbc:h2:mem:testdb0" @@ -164,6 +171,27 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext |OPTIONS (url '$url', dbtable 'TEST.NULLTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement( + "create table test.emp(name TEXT(32) NOT NULL," + + " theid INTEGER, \"Dept\" INTEGER)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('fred', 1, 10)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('mary', 2, null)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('joe ''foo'' \"bar\"', 3, 30)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('kathy', null, null)").executeUpdate() + conn.commit() + + sql( + s""" + |CREATE TEMPORARY TABLE nullparts + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.EMP', user 'testUser', password 'testPass', + |partitionColumn '"Dept"', lowerBound '1', upperBound '4', numPartitions '4') + """.stripMargin.replaceAll("\n", " ")) + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -176,12 +204,65 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + def checkPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is removed in a physical plan and + // the plan only has PhysicalRDD to scan JDBCRelation. + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.DataSourceScan]) + assert(node.child.asInstanceOf[DataSourceScan].nodeName.contains("JDBCRelation")) + df + } + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME <=> 'fred'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) + .collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')")) + .collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'")) + .collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' " + + "AND THEID = 2")).collect().size == 2) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE 'fr%'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE '%ed'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE NAME LIKE '%re%'")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) + assert(checkPushdown(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) + + // This is a test to reflect discussion in SPARK-12218. + // The older versions of spark have this kind of bugs in parquet data source. + val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2 AND NAME != 'mary')") + val df2 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')") + assert(df1.collect.toSet === Set(Row("mary", 2))) + assert(df2.collect.toSet === Set(Row("mary", 2))) + + def checkNotPushdown(df: DataFrame): DataFrame = { + val parentPlan = df.queryExecution.executedPlan + // Check if SparkPlan Filter is not removed in a physical plan because JDBCRDD + // cannot compile given predicates. + assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) + val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.Filter]) + df + } + assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) + assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2) + } + + test("SELECT COUNT(1) WHERE (predicates)") { + // Check if an answer is correct when Filter is removed from operations such as count() which + // does not require any columns. In some data sources, e.g., Parquet, `requiredColumns` in + // org.apache.spark.sql.sources.interfaces is not given in logical plans, but some filters + // are applied for columns with Filter producing wrong results. On the other hand, JDBCRDD + // correctly handles this case by assigning `requiredColumns` properly. See PR 10427 for more + // discussions. + assert(sql("SELECT COUNT(1) FROM foobar WHERE NAME = 'mary'").collect.toSet === Set(Row(1))) } test("SELECT * WHERE (quoted strings)") { @@ -278,6 +359,23 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext .collect().length === 3) } + test("Partitioning on column that might have null values.") { + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + .collect().length === 4) + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + .collect().length === 4) + // partitioning on a nullable quoted column + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + .collect().length === 4) + } + + test("SELECT * on partitioned table with a nullable partition column") { + assert(sql("SELECT * FROM nullparts").collect().size == 4) + } + test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() assert(rows.length === 1) @@ -427,6 +525,32 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(DerbyColumns === Seq(""""abc"""", """"key"""")) } + test("compile filters") { + val compileFilter = PrivateMethod[Option[String]]('compileFilter) + def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) getOrElse("") + assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3") + assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 'abc'))") + assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) + === "(col0 = 0) AND (col1 = 'def')") + assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi"))) + === "(col0 = 2) OR (col1 = 'ghi')") + assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5") + assert(doCompileFilter(LessThan("col3", + Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'") + assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === "col4 < '1983-08-04'") + assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5") + assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3") + assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3") + assert(doCompileFilter(In("col1", Array("jkl"))) === "col1 IN ('jkl')") + assert(doCompileFilter(Not(In("col1", Array("mno", "pqr")))) + === "(NOT (col1 IN ('mno', 'pqr')))") + assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL") + assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL") + assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) + === "((NOT (col0 != 'abc' OR col0 IS NULL OR 'abc' IS NULL) " + + "OR (col0 IS NULL AND 'abc' IS NULL))) AND (col1 = 'def')") + } + test("Dialect unregister") { JdbcDialects.registerDialect(testH2Dialect) JdbcDialects.unregisterDialect(testH2Dialect) @@ -460,6 +584,12 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + assert(Postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4") + assert(Postgres.getJDBCType(DoubleType).map(_.databaseTypeDefinition).get == "FLOAT8") + val errMsg = intercept[IllegalArgumentException] { + Postgres.getJDBCType(ByteType) + } + assert(errMsg.getMessage contains "Unsupported type in postgresql: ByteType") } test("DerbyDialect jdbc type mapping") { @@ -484,4 +614,41 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(h2.getTableExistsQuery(table) == defaultQuery) assert(derby.getTableExistsQuery(table) == defaultQuery) } + + test("Test DataFrame.where for Date and Timestamp") { + // Regression test for bug SPARK-11788 + val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); + val date = java.sql.Date.valueOf("1995-01-01") + val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(rows(0).getAs[java.sql.Timestamp](2) + === java.sql.Timestamp.valueOf("2002-02-20 11:22:33.543543")) + } + + test("test credentials in the properties are not in plan output") { + val df = sql("SELECT * FROM parts") + val explain = ExplainCommand(df.queryExecution.logical, extended = true) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + r => assert(!List("testPass", "testUser").exists(r.toString.contains)) + } + // test the JdbcRelation toString output + df.queryExecution.analyzed.collect { + case r: LogicalRelation => assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE)") + } + } + + test("test credentials in the connection url are not in the plan output") { + val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val explain = ExplainCommand(df.queryExecution.logical, extended = true) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + r => assert(!List("testPass", "testUser").exists(r.toString.contains)) + } + } + + test("SPARK 12941: The data type mapping for StringType to Oracle") { + val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + assert(oracleDialect.getJDBCType(StringType). + map(_.databaseTypeDefinition).get == "VARCHAR2(255)") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 6fc9febe49707..cb88a1c83c999 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -22,7 +22,6 @@ import java.io.{File, IOException} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -105,7 +104,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - val message = intercept[DDLException]{ + val message = intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable @@ -156,7 +155,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { - val message = intercept[DDLException]{ + val message = intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable @@ -173,7 +172,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with } test("a CTAS statement with column definitions is not allowed") { - intercept[DDLException]{ + intercept[AnalysisException]{ sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index af04079ec895a..92061133cd49b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf private[sql] abstract class DataSourceTest extends QueryTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 7541e723029bf..19e34b45bff67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.existentials import org.apache.spark.rdd.RDD -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ - +import org.apache.spark.unsafe.types.UTF8String class FilteredScanSource extends RelationProvider { override def createRelation( @@ -130,7 +130,7 @@ object ColumnsRequired { var set: Set[String] = Set.empty } -class FilteredScanSuite extends DataSourceTest with SharedSQLContext { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext with PredicateHelper { protected override lazy val sql = caseInsensitiveContext.sql _ override def beforeAll(): Unit = { @@ -144,9 +144,6 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { | to '10' |) """.stripMargin) - - // UDF for testing filter push-down - caseInsensitiveContext.udf.register("udf_gt3", (_: Int) > 3) } sqlTest( @@ -258,7 +255,11 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3, Set("a", "b", "c")) testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10, Set("a", "b", "c")) + testPushDown( + "SELECT * FROM oneToTenFiltered WHERE b = 1", + 10, + Set("a", "b", "c"), + Set(EqualTo("b", 1))) testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3, Set("a", "b", "c")) testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4, Set("a", "b", "c")) @@ -276,48 +277,66 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, Set("c")) testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1, Set("c")) - // Columns only referenced by UDF filter must be required, as UDF filters can't be pushed down. - testPushDown("SELECT c FROM oneToTenFiltered WHERE udf_gt3(A)", 10, Set("a", "c")) + // Filters referencing multiple columns are not convertible, all referenced columns must be + // required. + testPushDown("SELECT c FROM oneToTenFiltered WHERE A + b > 9", 10, Set("a", "b", "c")) - // A query with an unconvertible filter, an unhandled filter, and a handled filter. + // A query with an inconvertible filter, an unhandled filter, and a handled filter. testPushDown( """SELECT a | FROM oneToTenFiltered - | WHERE udf_gt3(b) + | WHERE a + b > 9 | AND b < 16 | AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo') - """.stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b")) + """.stripMargin.split("\n").map(_.trim).mkString(" "), + 3, + Set("a", "b"), + Set(LessThan("b", 16))) def testPushDown( - sqlString: String, - expectedCount: Int, - requiredColumnNames: Set[String]): Unit = { + sqlString: String, + expectedCount: Int, + requiredColumnNames: Set[String]): Unit = { + testPushDown(sqlString, expectedCount, requiredColumnNames, Set.empty[Filter]) + } + + def testPushDown( + sqlString: String, + expectedCount: Int, + requiredColumnNames: Set[String], + expectedUnhandledFilters: Set[Filter]): Unit = { + test(s"PushDown Returns $expectedCount: $sqlString") { - val queryExecution = sql(sqlString).queryExecution - val rawPlan = queryExecution.executedPlan.collect { - case p: execution.PhysicalRDD => p - } match { - case Seq(p) => p - case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") - } - val rawCount = rawPlan.execute().count() - assert(ColumnsRequired.set === requiredColumnNames) + // These tests check a particular plan, disable whole stage codegen. + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + try { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.DataSourceScan => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawCount = rawPlan.execute().count() + assert(ColumnsRequired.set === requiredColumnNames) - assert { val table = caseInsensitiveContext.table("oneToTenFiltered") val relation = table.queryExecution.logical.collectFirst { - case LogicalRelation(r, _) => r + case LogicalRelation(r, _, _) => r }.get - // `relation` should be able to handle all pushed filters - relation.unhandledFilters(FiltersPushed.list.toArray).isEmpty - } - - if (rawCount != expectedCount) { - fail( - s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + - s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + - queryExecution) + assert( + relation.unhandledFilters(FiltersPushed.list.toArray).toSet === expectedUnhandledFilters) + + if (rawCount != expectedCount) { + fail( + s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + } finally { + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.defaultValue.get) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 5b70d258d6ce3..5ac39f54b91ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -174,7 +174,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { """.stripMargin) }.getMessage assert( - message.contains("Cannot insert overwrite into table that is also being read from."), + message.contains("Cannot overwrite a path that is also being read from."), "INSERT OVERWRITE to a table while querying it should not be allowed.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index c9791879ec74c..a9b1970a7c393 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -44,7 +44,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { path.delete() val base = sqlContext.range(100) - val df = base.unionAll(base).select($"id", lit(1).as("data")) + val df = base.union(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( @@ -53,4 +53,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { Utils.deleteRecursively(path) } + + test("partitioned columns should appear at the end of schema") { + withTempPath { f => + val path = f.getAbsolutePath + Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) + assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index a89c5f8007e78..62f991fc5dc61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -117,28 +118,35 @@ class PrunedScanSuite extends DataSourceTest with SharedSQLContext { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = sql(sqlString).queryExecution - val rawPlan = queryExecution.executedPlan.collect { - case p: execution.PhysicalRDD => p - } match { - case Seq(p) => p - case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") - } - val rawColumns = rawPlan.output.map(_.name) - val rawOutput = rawPlan.execute().first() - - if (rawColumns != expectedColumns) { - fail( - s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + - s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + - queryExecution) - } - if (rawOutput.numFields != expectedColumns.size) { - fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + // These tests check a particular plan, disable whole stage codegen. + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + try { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.DataSourceScan => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawColumns = rawPlan.output.map(_.name) + val rawOutput = rawPlan.execute().first() + + if (rawColumns != expectedColumns) { + fail( + s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + + if (rawOutput.numFields != expectedColumns.size) { + fail(s"Wrong output row. Got $rawOutput\n$queryExecution") + } + } finally { + caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.defaultValue.get) } } } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 27d1cd92fca1a..94d032f4ee414 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -18,43 +18,62 @@ package org.apache.spark.sql.sources import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.datasources.ResolvedDataSource +import org.apache.spark.sql.execution.datasources.DataSource class ResolvedDataSourceSuite extends SparkFunSuite { + private def getProvidingClass(name: String): Class[_] = + DataSource(sqlContext = null, className = name).providingClass test("jdbc") { assert( - ResolvedDataSource.lookupDataSource("jdbc") === + getProvidingClass("jdbc") === classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) assert( - ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === + getProvidingClass("org.apache.spark.sql.execution.datasources.jdbc") === classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) assert( - ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === + getProvidingClass("org.apache.spark.sql.jdbc") === classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) } test("json") { assert( - ResolvedDataSource.lookupDataSource("json") === + getProvidingClass("json") === classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) assert( - ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === + getProvidingClass("org.apache.spark.sql.execution.datasources.json") === classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) assert( - ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === + getProvidingClass("org.apache.spark.sql.json") === classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) } test("parquet") { assert( - ResolvedDataSource.lookupDataSource("parquet") === + getProvidingClass("parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) assert( - ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === + getProvidingClass("org.apache.spark.sql.execution.datasources.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) assert( - ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === + getProvidingClass("org.apache.spark.sql.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } + + test("error message for unknown data sources") { + val error1 = intercept[ClassNotFoundException] { + getProvidingClass("avro") + } + assert(error1.getMessage.contains("spark-packages")) + + val error2 = intercept[ClassNotFoundException] { + getProvidingClass("com.databricks.spark.avro") + } + assert(error2.getMessage.contains("spark-packages")) + + val error3 = intercept[ClassNotFoundException] { + getProvidingClass("asfdwefasdfasdf") + } + assert(error3.getMessage.contains("spark-packages")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 10d261368993d..bb2c54aa64977 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -21,7 +21,8 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -121,7 +122,7 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA // verify the append mode df.write.mode(SaveMode.Append).json(path.toString) - val df2 = df.unionAll(df) + val df2 = df.union(df) df2.registerTempTable("jsonTable2") checkLoad(df2, "jsonTable2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 12af8068c398f..99f1661ad0d15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -85,6 +85,7 @@ case class AllDataTypesScan( Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -115,6 +116,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -154,6 +156,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { |dateField dAte, |timestampField tiMestamp, |varcharField varchaR(12), + |charField ChaR(18), |arrayFieldSimple Array, |arrayFieldComplex Array>>, |mapFieldSimple MAP, @@ -207,6 +210,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: StructField("varcharField", StringType, true) :: + StructField("charField", StringType, true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,6 +252,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { | dateField, | timestampField, | varcharField, + | charField, | arrayFieldSimple, | arrayFieldComplex, | mapFieldSimple, @@ -334,7 +339,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { test("exceptions") { // Make sure we do throw correct exception when users use a relation provider that - // only implements the RelationProvier or the SchemaRelationProvider. + // only implements the RelationProvider or the SchemaRelationProvider. val schemaNotAllowed = intercept[Exception] { sql( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala new file mode 100644 index 0000000000000..3d69c8a18711b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.concurrent.Future +import scala.util.Random +import scala.util.control.NonFatal + +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.time.Span +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + + import AwaitTerminationTester._ + import testImplicits._ + + override val streamingTimeout = 20.seconds + + before { + assert(sqlContext.streams.active.isEmpty) + sqlContext.streams.resetTerminated() + } + + after { + assert(sqlContext.streams.active.isEmpty) + sqlContext.streams.resetTerminated() + } + + testQuietly("listing") { + val (m1, ds1) = makeDataset + val (m2, ds2) = makeDataset + val (m3, ds3) = makeDataset + + withQueriesOn(ds1, ds2, ds3) { queries => + require(queries.size === 3) + assert(sqlContext.streams.active.toSet === queries.toSet) + val (q1, q2, q3) = (queries(0), queries(1), queries(2)) + + assert(sqlContext.streams.get(q1.name).eq(q1)) + assert(sqlContext.streams.get(q2.name).eq(q2)) + assert(sqlContext.streams.get(q3.name).eq(q3)) + intercept[IllegalArgumentException] { + sqlContext.streams.get("non-existent-name") + } + + q1.stop() + + assert(sqlContext.streams.active.toSet === Set(q2, q3)) + val ex1 = withClue("no error while getting non-active query") { + intercept[IllegalArgumentException] { + sqlContext.streams.get(q1.name) + } + } + assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched") + assert(sqlContext.streams.get(q2.name).eq(q2)) + + m2.addData(0) // q2 should terminate with error + + eventually(Timeout(streamingTimeout)) { + require(!q2.isActive) + require(q2.exception.isDefined) + } + withClue("no error while getting non-active query") { + intercept[IllegalArgumentException] { + sqlContext.streams.get(q2.name).eq(q2) + } + } + + assert(sqlContext.streams.active.toSet === Set(q3)) + } + } + + testQuietly("awaitAnyTermination without timeout and resetTerminated") { + val datasets = Seq.fill(5)(makeDataset._2) + withQueriesOn(datasets: _*) { queries => + require(queries.size === datasets.size) + assert(sqlContext.streams.active.toSet === queries.toSet) + + // awaitAnyTermination should be blocking + testAwaitAnyTermination(ExpectBlocked) + + // Stop a query asynchronously and see if it is reported through awaitAnyTermination + val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false) + testAwaitAnyTermination(ExpectNotBlocked) + require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should be non-blocking + testAwaitAnyTermination(ExpectNotBlocked) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination(ExpectBlocked) + + // Terminate a query asynchronously with exception and see awaitAnyTermination throws + // the exception + val q2 = stopRandomQueryAsync(100 milliseconds, withError = true) + testAwaitAnyTermination(ExpectException[SparkException]) + require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should throw the exception + testAwaitAnyTermination(ExpectException[SparkException]) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination(ExpectBlocked) + + // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws + // the exception + val q3 = stopRandomQueryAsync(10 milliseconds, withError = false) + testAwaitAnyTermination(ExpectNotBlocked) + require(!q3.isActive) + val q4 = stopRandomQueryAsync(10 milliseconds, withError = true) + eventually(Timeout(streamingTimeout)) { require(!q4.isActive) } + // After q4 terminates with exception, awaitAnyTerm should start throwing exception + testAwaitAnyTermination(ExpectException[SparkException]) + } + } + + testQuietly("awaitAnyTermination with timeout and resetTerminated") { + val datasets = Seq.fill(6)(makeDataset._2) + withQueriesOn(datasets: _*) { queries => + require(queries.size === datasets.size) + assert(sqlContext.streams.active.toSet === queries.toSet) + + // awaitAnyTermination should be blocking or non-blocking depending on timeout values + testAwaitAnyTermination( + ExpectBlocked, + awaitTimeout = 4 seconds, + expectedReturnedValue = false, + testBehaviorFor = 2 seconds) + + testAwaitAnyTermination( + ExpectNotBlocked, + awaitTimeout = 50 milliseconds, + expectedReturnedValue = false, + testBehaviorFor = 1 second) + + // Stop a query asynchronously within timeout and awaitAnyTerm should be unblocked + val q1 = stopRandomQueryAsync(stopAfter = 100 milliseconds, withError = false) + testAwaitAnyTermination( + ExpectNotBlocked, + awaitTimeout = 2 seconds, + expectedReturnedValue = true, + testBehaviorFor = 4 seconds) + require(!q1.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should be non-blocking even if timeout is high + testAwaitAnyTermination( + ExpectNotBlocked, awaitTimeout = 4 seconds, expectedReturnedValue = true) + + // Resetting termination should make awaitAnyTermination() blocking again + sqlContext.streams.resetTerminated() + testAwaitAnyTermination( + ExpectBlocked, + awaitTimeout = 4 seconds, + expectedReturnedValue = false, + testBehaviorFor = 1 second) + + // Terminate a query asynchronously with exception within timeout, awaitAnyTermination should + // throws the exception + val q2 = stopRandomQueryAsync(100 milliseconds, withError = true) + testAwaitAnyTermination( + ExpectException[SparkException], + awaitTimeout = 4 seconds, + testBehaviorFor = 6 seconds) + require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned + + // All subsequent calls to awaitAnyTermination should throw the exception + testAwaitAnyTermination( + ExpectException[SparkException], + awaitTimeout = 2 seconds, + testBehaviorFor = 4 seconds) + + // Terminate a query asynchronously outside the timeout, awaitAnyTerm should be blocked + sqlContext.streams.resetTerminated() + val q3 = stopRandomQueryAsync(2 seconds, withError = true) + testAwaitAnyTermination( + ExpectNotBlocked, + awaitTimeout = 100 milliseconds, + expectedReturnedValue = false, + testBehaviorFor = 4 seconds) + + // After that query is stopped, awaitAnyTerm should throw exception + eventually(Timeout(streamingTimeout)) { require(!q3.isActive) } // wait for query to stop + testAwaitAnyTermination( + ExpectException[SparkException], + awaitTimeout = 100 milliseconds, + testBehaviorFor = 4 seconds) + + + // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws + // the exception + sqlContext.streams.resetTerminated() + + val q4 = stopRandomQueryAsync(10 milliseconds, withError = false) + testAwaitAnyTermination( + ExpectNotBlocked, awaitTimeout = 2 seconds, expectedReturnedValue = true) + require(!q4.isActive) + val q5 = stopRandomQueryAsync(10 milliseconds, withError = true) + eventually(Timeout(streamingTimeout)) { require(!q5.isActive) } + // After q5 terminates with exception, awaitAnyTerm should start throwing exception + testAwaitAnyTermination(ExpectException[SparkException], awaitTimeout = 2 seconds) + } + } + + + /** Run a body of code by defining a query each on multiple datasets */ + private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[ContinuousQuery] => Unit): Unit = { + failAfter(streamingTimeout) { + val queries = withClue("Error starting queries") { + datasets.map { ds => + @volatile var query: StreamExecution = null + try { + val df = ds.toDF + val metadataRoot = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + query = sqlContext + .streams + .startQuery( + StreamExecution.nextName, + metadataRoot, + df, + new MemorySink(df.schema)) + .asInstanceOf[StreamExecution] + } catch { + case NonFatal(e) => + if (query != null) query.stop() + throw e + } + query + } + } + try { + body(queries) + } finally { + queries.foreach(_.stop()) + } + } + } + + /** Test the behavior of awaitAnyTermination */ + private def testAwaitAnyTermination( + expectedBehavior: ExpectedBehavior, + expectedReturnedValue: Boolean = false, + awaitTimeout: Span = null, + testBehaviorFor: Span = 4 seconds + ): Unit = { + + def awaitTermFunc(): Unit = { + if (awaitTimeout != null && awaitTimeout.toMillis > 0) { + val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis) + assert(returnedValue === expectedReturnedValue, "Returned value does not match expected") + } else { + sqlContext.streams.awaitAnyTermination() + } + } + + AwaitTerminationTester.test(expectedBehavior, awaitTermFunc, testBehaviorFor) + } + + /** Stop a random active query either with `stop()` or with an error */ + private def stopRandomQueryAsync(stopAfter: Span, withError: Boolean): ContinuousQuery = { + + import scala.concurrent.ExecutionContext.Implicits.global + + val activeQueries = sqlContext.streams.active + val queryToStop = activeQueries(Random.nextInt(activeQueries.length)) + Future { + Thread.sleep(stopAfter.toMillis) + if (withError) { + logDebug(s"Terminating query ${queryToStop.name} with error") + queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect { + case StreamingExecutionRelation(source, _) => + source.asInstanceOf[MemoryStream[Int]].addData(0) + } + } else { + logDebug(s"Stopping query ${queryToStop.name}") + queryToStop.stop() + } + } + queryToStop + } + + private def makeDataset: (MemoryStream[Int], Dataset[Int]) = { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS.map(6 / _) + (inputData, mapped) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala new file mode 100644 index 0000000000000..3be0ea481dc53 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkException +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution} +import org.apache.spark.sql.test.SharedSQLContext + +class ContinuousQuerySuite extends StreamTest with SharedSQLContext { + + import AwaitTerminationTester._ + import testImplicits._ + + testQuietly("lifecycle states and awaitTermination") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map { 6 / _} + + testStream(mapped)( + AssertOnQuery(_.isActive === true), + AssertOnQuery(_.exception.isEmpty), + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + TestAwaitTermination(ExpectBlocked), + TestAwaitTermination(ExpectBlocked, timeoutMs = 2000), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = false), + StopStream, + AssertOnQuery(_.isActive === false), + AssertOnQuery(_.exception.isEmpty), + TestAwaitTermination(ExpectNotBlocked), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 2000, expectedReturnValue = true), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = true), + StartStream, + AssertOnQuery(_.isActive === true), + AddData(inputData, 0), + ExpectFailure[SparkException], + AssertOnQuery(_.isActive === false), + TestAwaitTermination(ExpectException[SparkException]), + TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000), + TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10), + AssertOnQuery( + q => + q.exception.get.startOffset.get === q.committedOffsets.toCompositeOffset(Seq(inputData)), + "incorrect start offset on exception") + ) + } + + testQuietly("source and sink statuses") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(6 / _) + + testStream(mapped)( + AssertOnQuery(_.sourceStatuses.length === 1), + AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.sourceStatuses(0).offset === None), + AssertOnQuery(_.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.sinkStatus.offset === new CompositeOffset(None :: Nil)), + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(0))), + AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))), + AddData(inputData, 1, 2), + CheckAnswer(6, 3, 6, 3), + AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(1))), + AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(1))), + AddData(inputData, 0), + ExpectFailure[SparkException], + AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(2))), + AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(1))) + ) + } + + /** + * A [[StreamAction]] to test the behavior of `ContinuousQuery.awaitTermination()`. + * + * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) + * @param timeoutMs Timeout in milliseconds + * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) + * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used + */ + case class TestAwaitTermination( + expectedBehavior: ExpectedBehavior, + timeoutMs: Int = -1, + expectedReturnValue: Boolean = false + ) extends AssertOnQuery( + TestAwaitTermination.assertOnQueryCondition(expectedBehavior, timeoutMs, expectedReturnValue), + "Error testing awaitTermination behavior" + ) { + override def toString(): String = { + s"TestAwaitTermination($expectedBehavior, timeoutMs = $timeoutMs, " + + s"expectedReturnValue = $expectedReturnValue)" + } + } + + object TestAwaitTermination { + + /** + * Tests the behavior of `ContinuousQuery.awaitTermination`. + * + * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) + * @param timeoutMs Timeout in milliseconds + * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) + * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used + */ + def assertOnQueryCondition( + expectedBehavior: ExpectedBehavior, + timeoutMs: Int, + expectedReturnValue: Boolean + )(q: StreamExecution): Boolean = { + + def awaitTermFunc(): Unit = { + if (timeoutMs <= 0) { + q.awaitTermination() + } else { + val returnedValue = q.awaitTermination(timeoutMs) + assert(returnedValue === expectedReturnValue, "Returned value does not match expected") + } + } + AwaitTerminationTester.test(expectedBehavior, awaitTermFunc) + true // If the control reached here, then everything worked as expected + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala new file mode 100644 index 0000000000000..00efe21d39de4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.test + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ + +import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.util.Utils + +object LastOptions { + + var mockStreamSourceProvider = mock(classOf[StreamSourceProvider]) + var mockStreamSinkProvider = mock(classOf[StreamSinkProvider]) + var parameters: Map[String, String] = null + var schema: Option[StructType] = null + var partitionColumns: Seq[String] = Nil + + def clear(): Unit = { + parameters = null + schema = null + partitionColumns = null + reset(mockStreamSourceProvider) + reset(mockStreamSinkProvider) + } +} + +/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. */ +class DefaultSource extends StreamSourceProvider with StreamSinkProvider { + + private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + LastOptions.parameters = parameters + LastOptions.schema = schema + LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters) + ("dummySource", fakeSchema) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + LastOptions.parameters = parameters + LastOptions.schema = schema + LastOptions.mockStreamSourceProvider.createSource( + sqlContext, metadataPath, schema, providerName, parameters) + new Source { + override def schema: StructType = fakeSchema + + override def getOffset: Option[Offset] = Some(new LongOffset(0)) + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + import sqlContext.implicits._ + + Seq[Int]().toDS().toDF() + } + } + } + + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink = { + LastOptions.parameters = parameters + LastOptions.partitionColumns = partitionColumns + LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns) + new Sink { + override def addBatch(batchId: Long, data: DataFrame): Unit = {} + } + } +} + +class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + import testImplicits._ + + private def newMetadataDir = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("resolve default source") { + sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + .write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream() + .stop() + } + + test("resolve full class") { + sqlContext.read + .format("org.apache.spark.sql.streaming.test.DefaultSource") + .stream() + .write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream() + .stop() + } + + test("options") { + val map = new java.util.HashMap[String, String] + map.put("opt3", "3") + + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .stream() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + + LastOptions.clear() + + df.write + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .option("checkpointLocation", newMetadataDir) + .startStream() + .stop() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + } + + test("partitioning") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream() + .stop() + assert(LastOptions.partitionColumns == Nil) + + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .partitionBy("a") + .startStream() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + + withSQLConf("spark.sql.caseSensitive" -> "false") { + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .partitionBy("A") + .startStream() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + } + + intercept[AnalysisException] { + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .partitionBy("b") + .startStream() + .stop() + } + } + + test("stream paths") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .stream("/test") + + assert(LastOptions.parameters("path") == "/test") + + LastOptions.clear() + + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream("/test") + .stop() + + assert(LastOptions.parameters("path") == "/test") + } + + test("test different data types for options") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .stream("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + + LastOptions.clear() + df.write + .format("org.apache.spark.sql.streaming.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .option("checkpointLocation", newMetadataDir) + .startStream("/test") + .stop() + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + } + + test("unique query names") { + + /** Start a query with a specific name */ + def startQueryWithName(name: String = ""): ContinuousQuery = { + sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + .write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .queryName(name) + .startStream() + } + + /** Start a query without specifying a name */ + def startQueryWithoutName(): ContinuousQuery = { + sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + .write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream() + } + + /** Get the names of active streams */ + def activeStreamNames: Set[String] = { + val streams = sqlContext.streams.active + val names = streams.map(_.name).toSet + assert(streams.length === names.size, s"names of active queries are not unique: $names") + names + } + + val q1 = startQueryWithName("name") + + // Should not be able to start another query with the same name + intercept[IllegalArgumentException] { + startQueryWithName("name") + } + assert(activeStreamNames === Set("name")) + + // Should be able to start queries with other names + val q3 = startQueryWithName("another-name") + assert(activeStreamNames === Set("name", "another-name")) + + // Should be able to start queries with auto-generated names + val q4 = startQueryWithoutName() + assert(activeStreamNames.contains(q4.name)) + + // Should not be able to start a query with same auto-generated name + intercept[IllegalArgumentException] { + startQueryWithName(q4.name) + } + + // Should be able to start query with that name after stopping the previous query + q1.stop() + val q5 = startQueryWithName("name") + assert(activeStreamNames.contains("name")) + sqlContext.streams.active.foreach(_.stop()) + } + + test("trigger") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + + var q = df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .trigger(ProcessingTime(10.seconds)) + .startStream() + q.stop() + + assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000)) + + q = df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) + .startStream() + q.stop() + + assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) + } + + test("source metadataPath") { + LastOptions.clear() + + val checkpointLocation = newMetadataDir + + val df1 = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val df2 = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val q = df1.union(df2).write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", checkpointLocation) + .trigger(ProcessingTime(10.seconds)) + .startStream() + q.stop() + + verify(LastOptions.mockStreamSourceProvider).createSource( + sqlContext, + checkpointLocation + "/sources/0", + None, + "org.apache.spark.sql.streaming.test", + Map.empty) + + verify(LastOptions.mockStreamSourceProvider).createSource( + sqlContext, + checkpointLocation + "/sources/1", + None, + "org.apache.spark.sql.streaming.test", + Map.empty) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala new file mode 100644 index 0000000000000..8cf5dedabcee1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class FileStreamSinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("unpartitioned writing") { + val inputData = MemoryStream[Int] + val df = inputData.toDF() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + val query = + df.write + .format("parquet") + .option("checkpointLocation", checkpointDir) + .startStream(outputDir) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { query.processAllAvailable() } + + val outputDf = sqlContext.read.parquet(outputDir).as[Int] + checkDataset( + outputDf, + 1, 2, 3) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala new file mode 100644 index 0000000000000..73d1b1b1d507d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.File + +import org.apache.spark.sql.{AnalysisException, StreamTest} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.Utils + +class FileStreamSourceTest extends StreamTest with SharedSQLContext { + + import testImplicits._ + + case class AddTextFileData(source: FileStreamSource, content: String, src: File, tmp: File) + extends AddData { + + override def addData(): Offset = { + source.withBatchingLocked { + val file = Utils.tempFileWith(new File(tmp, "text")) + stringToFile(file, content).renameTo(new File(src, file.getName)) + source.currentOffset + } + 1 + } + } + + case class AddParquetFileData( + source: FileStreamSource, + content: Seq[String], + src: File, + tmp: File) extends AddData { + + override def addData(): Offset = { + source.withBatchingLocked { + val file = Utils.tempFileWith(new File(tmp, "parquet")) + content.toDS().toDF().write.parquet(file.getCanonicalPath) + file.renameTo(new File(src, file.getName)) + source.currentOffset + } + 1 + } + } + + /** Use `format` and `path` to create FileStreamSource via DataFrameReader */ + def createFileStreamSource( + format: String, + path: String, + schema: Option[StructType] = None): FileStreamSource = { + val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + val reader = + if (schema.isDefined) { + sqlContext.read.format(format).schema(schema.get) + } else { + sqlContext.read.format(format) + } + reader.stream(path) + .queryExecution.analyzed + .collect { case StreamingRelation(dataSource, _, _) => + // There is only one source in our tests so just set sourceId to 0 + dataSource.createSource(s"$checkpointLocation/sources/0").asInstanceOf[FileStreamSource] + }.head + } + + val valueSchema = new StructType().add("value", StringType) +} + +class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { + + import testImplicits._ + + private def createFileStreamSourceAndGetSchema( + format: Option[String], + path: Option[String], + schema: Option[StructType] = None): StructType = { + val reader = sqlContext.read + format.foreach(reader.format) + schema.foreach(reader.schema) + val df = + if (path.isDefined) { + reader.stream(path.get) + } else { + reader.stream() + } + df.queryExecution.analyzed + .collect { case StreamingRelation(dataSource, _, _) => + dataSource.sourceSchema() + }.head._2 + } + + test("FileStreamSource schema: no path") { + val e = intercept[IllegalArgumentException] { + createFileStreamSourceAndGetSchema(format = None, path = None, schema = None) + } + assert("'path' is not specified" === e.getMessage) + } + + test("FileStreamSource schema: path doesn't exist") { + intercept[AnalysisException] { + createFileStreamSourceAndGetSchema(format = None, path = Some("/a/b/c"), schema = None) + } + } + + test("FileStreamSource schema: text, no existing files, no schema") { + withTempDir { src => + val schema = createFileStreamSourceAndGetSchema( + format = Some("text"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("value", StringType)) + } + } + + test("FileStreamSource schema: text, existing files, no schema") { + withTempDir { src => + stringToFile(new File(src, "1"), "a\nb\nc") + val schema = createFileStreamSourceAndGetSchema( + format = Some("text"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("value", StringType)) + } + } + + test("FileStreamSource schema: text, existing files, schema") { + withTempDir { src => + stringToFile(new File(src, "1"), "a\nb\nc") + val userSchema = new StructType().add("userColumn", StringType) + val schema = createFileStreamSourceAndGetSchema( + format = Some("text"), path = Some(src.getCanonicalPath), schema = Some(userSchema)) + assert(schema === userSchema) + } + } + + test("FileStreamSource schema: parquet, no existing files, no schema") { + withTempDir { src => + val e = intercept[AnalysisException] { + createFileStreamSourceAndGetSchema( + format = Some("parquet"), path = Some(new File(src, "1").getCanonicalPath), schema = None) + } + assert("Unable to infer schema. It must be specified manually.;" === e.getMessage) + } + } + + test("FileStreamSource schema: parquet, existing files, no schema") { + withTempDir { src => + Seq("a", "b", "c").toDS().as("userColumn").toDF() + .write.parquet(new File(src, "1").getCanonicalPath) + val schema = createFileStreamSourceAndGetSchema( + format = Some("parquet"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("value", StringType)) + } + } + + test("FileStreamSource schema: parquet, existing files, schema") { + withTempPath { src => + Seq("a", "b", "c").toDS().as("oldUserColumn").toDF() + .write.parquet(new File(src, "1").getCanonicalPath) + val userSchema = new StructType().add("userColumn", StringType) + val schema = createFileStreamSourceAndGetSchema( + format = Some("parquet"), path = Some(src.getCanonicalPath), schema = Some(userSchema)) + assert(schema === userSchema) + } + } + + test("FileStreamSource schema: json, no existing files, no schema") { + withTempDir { src => + val e = intercept[AnalysisException] { + createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = None) + } + assert("Unable to infer schema. It must be specified manually.;" === e.getMessage) + } + } + + test("FileStreamSource schema: json, existing files, no schema") { + withTempDir { src => + stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c': '3'}") + val schema = createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("c", StringType)) + } + } + + test("FileStreamSource schema: json, existing files, schema") { + withTempDir { src => + stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c', '3'}") + val userSchema = new StructType().add("userColumn", StringType) + val schema = createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = Some(userSchema)) + assert(schema === userSchema) + } + } + + test("read from text files") { + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + val filtered = textSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData(textSource, "drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData(textSource, "drop4\nkeep5\nkeep6", src, tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData(textSource, "drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("read from json files") { + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") + + val textSource = createFileStreamSource("json", src.getCanonicalPath, Some(valueSchema)) + val filtered = textSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData( + textSource, + "{'value': 'drop1'}\n{'value': 'keep2'}\n{'value': 'keep3'}", + src, + tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData( + textSource, + "{'value': 'drop4'}\n{'value': 'keep5'}\n{'value': 'keep6'}", + src, + tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData( + textSource, + "{'value': 'drop7'}\n{'value': 'keep8'}\n{'value': 'keep9'}", + src, + tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("read from json files with inferring schema") { + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") + + // Add a file so that we can infer its schema + stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}") + + val textSource = createFileStreamSource("json", src.getCanonicalPath) + + // FileStreamSource should infer the column "c" + val filtered = textSource.toDF().filter($"c" contains "keep") + + testStream(filtered)( + AddTextFileData(textSource, "{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("read from parquet files") { + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") + + val fileSource = createFileStreamSource("parquet", src.getCanonicalPath, Some(valueSchema)) + val filtered = fileSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddParquetFileData(fileSource, Seq("drop1", "keep2", "keep3"), src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddParquetFileData(fileSource, Seq("drop4", "keep5", "keep6"), src, tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddParquetFileData(fileSource, Seq("drop7", "keep8", "keep9"), src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } + + test("file stream source without schema") { + val src = Utils.createTempDir(namePrefix = "streaming.src") + + // Only "text" doesn't need a schema + createFileStreamSource("text", src.getCanonicalPath) + + // Both "json" and "parquet" require a schema if no existing file to infer + intercept[AnalysisException] { + createFileStreamSource("json", src.getCanonicalPath) + } + intercept[AnalysisException] { + createFileStreamSource("parquet", src.getCanonicalPath) + } + + Utils.deleteRecursively(src) + } + + test("fault tolerance") { + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + val filtered = textSource.toDF().filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData(textSource, "drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + StopStream, + AddTextFileData(textSource, "drop4\nkeep5\nkeep6", src, tmp), + StartStream, + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AddTextFileData(textSource, "drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") + ) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } +} + +class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext { + + import testImplicits._ + + test("file source stress test") { + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") + + val textSource = createFileStreamSource("text", src.getCanonicalPath) + val ds = textSource.toDS[String]().map(_.toInt + 1) + runStressTest(ds, data => { + AddTextFileData(textSource, data.mkString("\n"), src, tmp) + }) + + Utils.deleteRecursively(src) + Utils.deleteRecursively(tmp) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala new file mode 100644 index 0000000000000..5b49a0a86a04f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.File +import java.util.UUID + +import scala.util.Random +import scala.util.control.NonFatal + +import org.apache.spark.sql.{ContinuousQuery, ContinuousQueryException, StreamTest} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +/** + * A stress test for streaming queries that read and write files. This test consists of + * two threads: + * - one that writes out `numRecords` distinct integers to files of random sizes (the total + * number of records is fixed but each files size / creation time is random). + * - another that continually restarts a buggy streaming query (i.e. fails with 5% probability on + * any partition). + * + * At the end, the resulting files are loaded and the answer is checked. + */ +class FileStressSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("fault tolerance stress test") { + val numRecords = 10000 + val inputDir = Utils.createTempDir(namePrefix = "stream.input").getCanonicalPath + val stagingDir = Utils.createTempDir(namePrefix = "stream.staging").getCanonicalPath + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpoint = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + @volatile + var continue = true + @volatile + var stream: ContinuousQuery = null + + val writer = new Thread("stream writer") { + override def run(): Unit = { + var i = numRecords + while (i > 0) { + val count = Random.nextInt(100) + var j = 0 + var string = "" + while (j < count && i > 0) { + if (i % 10000 == 0) { logError(s"Wrote record $i") } + string = string + i + "\n" + j += 1 + i -= 1 + } + + val uuid = UUID.randomUUID().toString + val fileName = new File(stagingDir, uuid) + stringToFile(fileName, string) + fileName.renameTo(new File(inputDir, uuid)) + val sleep = Random.nextInt(100) + Thread.sleep(sleep) + } + + logError("== DONE WRITING ==") + var done = false + while (!done) { + try { + stream.processAllAvailable() + done = true + } catch { + case NonFatal(_) => + } + } + + continue = false + stream.stop() + } + } + writer.start() + + val input = sqlContext.read.format("text").stream(inputDir) + def startStream(): ContinuousQuery = input + .repartition(5) + .as[String] + .mapPartitions { iter => + val rand = Random.nextInt(100) + if (rand < 5) { sys.error("failure") } + iter.map(_.toLong) + } + .write + .format("parquet") + .option("checkpointLocation", checkpoint) + .startStream(outputDir) + + var failures = 0 + val streamThread = new Thread("stream runner") { + while (continue) { + if (failures % 10 == 0) { logError(s"Query restart #$failures") } + stream = startStream() + + try { + stream.awaitTermination() + } catch { + case ce: ContinuousQueryException => + failures += 1 + } + } + } + + streamThread.join() + + logError(s"Stream restarted $failures times.") + assert(sqlContext.read.parquet(outputDir).distinct().count() == numRecords) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala new file mode 100644 index 0000000000000..1f2834054519b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.{AnalysisException, Row, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class MemorySinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("registering as a table") { + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + sqlContext.table("memStream").as[Int], + 1, 2, 3) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + sqlContext.table("memStream").as[Int], + 1, 2, 3, 4, 5, 6) + + query.stop() + } + + test("error when no name is specified") { + val error = intercept[AnalysisException] { + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .startStream() + } + + assert(error.message contains "queryName must be specified") + } + + test("error if attempting to resume specific checkpoint") { + val location = Utils.createTempDir(namePrefix = "steaming.checkpoint").getCanonicalPath + + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .option("checkpointLocation", location) + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + intercept[AnalysisException] { + input.toDF().write + .format("memory") + .queryName("memStream") + .option("checkpointLocation", location) + .startStream() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala new file mode 100644 index 0000000000000..81760d2aa8205 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext + +class MemorySourceStressSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("memory stress test") { + val input = MemoryStream[Int] + val mapped = input.toDS().map(_ + 1) + + runStressTest(mapped, AddData(input, _: _*)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala new file mode 100644 index 0000000000000..9590af4e7737d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, Offset} + +trait OffsetSuite extends SparkFunSuite { + /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. */ + def compare(one: Offset, two: Offset): Unit = { + test(s"comparison $one <=> $two") { + assert(one < two) + assert(one <= two) + assert(one <= one) + assert(two > one) + assert(two >= one) + assert(one >= one) + assert(one == one) + assert(two == two) + assert(one != two) + assert(two != one) + } + } + + /** Creates test to check that non-equality comparisons throw exception. */ + def compareInvalid(one: Offset, two: Offset): Unit = { + test(s"invalid comparison $one <=> $two") { + intercept[IllegalArgumentException] { + assert(one < two) + } + + intercept[IllegalArgumentException] { + assert(one <= two) + } + + intercept[IllegalArgumentException] { + assert(one > two) + } + + intercept[IllegalArgumentException] { + assert(one >= two) + } + + assert(!(one == two)) + assert(!(two == one)) + assert(one != two) + assert(two != one) + } + } +} + +class LongOffsetSuite extends OffsetSuite { + val one = LongOffset(1) + val two = LongOffset(2) + compare(one, two) +} + +class CompositeOffsetSuite extends OffsetSuite { + compare( + one = CompositeOffset(Some(LongOffset(1)) :: Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compare( + one = CompositeOffset(None :: Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compareInvalid( // sizes must be same + one = CompositeOffset(Nil), + two = CompositeOffset(Some(LongOffset(2)) :: Nil)) + + compare( + one = CompositeOffset.fill(LongOffset(0), LongOffset(1)), + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) + + compare( + one = CompositeOffset.fill(LongOffset(1), LongOffset(1)), + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) + + compareInvalid( + one = CompositeOffset.fill(LongOffset(2), LongOffset(1)), // vector time inconsistent + two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala new file mode 100644 index 0000000000000..2bd27c7efdbdc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class StreamSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + test("map with recovery") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(_ + 1) + + testStream(mapped)( + AddData(inputData, 1, 2, 3), + StartStream, + CheckAnswer(2, 3, 4), + StopStream, + AddData(inputData, 4, 5, 6), + StartStream, + CheckAnswer(2, 3, 4, 5, 6, 7)) + } + + test("join") { + // Make a table and ensure it will be broadcast. + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val inputData = MemoryStream[Int] + val joined = inputData.toDS().toDF().join(smallTable, $"value" === $"number") + + testStream(joined)( + AddData(inputData, 1, 2, 3), + CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two")), + AddData(inputData, 4), + CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) + } + + test("union two streams") { + val inputData1 = MemoryStream[Int] + val inputData2 = MemoryStream[Int] + + val unioned = inputData1.toDS().union(inputData2.toDS()) + + testStream(unioned)( + AddData(inputData1, 1, 3, 5), + CheckAnswer(1, 3, 5), + AddData(inputData2, 2, 4, 6), + CheckAnswer(1, 2, 3, 4, 5, 6), + StopStream, + AddData(inputData1, 7), + StartStream, + AddData(inputData2, 8), + CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8)) + } + + test("sql queries") { + val inputData = MemoryStream[Int] + inputData.toDF().registerTempTable("stream") + val evens = sql("SELECT * FROM stream WHERE value % 2 = 0") + + testStream(evens)( + AddData(inputData, 1, 2, 3, 4), + CheckAnswer(2, 4)) + } + + test("DataFrame reuse") { + def assertDF(df: DataFrame) { + withTempDir { outputDir => + withTempDir { checkpointDir => + val query = df.write.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .startStream(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDataset[Long](outputDf, (0L to 10L).toArray: _*) + } finally { + query.stop() + } + } + } + } + + val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream() + assertDF(df) + assertDF(df) + } +} + +/** + * A fake StreamSourceProvider thats creates a fake Source that cannot be reused. + */ +class FakeDefaultSource extends StreamSourceProvider { + + private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + // Create a fake Source that emits 0 to 10. + new Source { + private var offset = -1L + + override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + + override def getOffset: Option[Offset] = { + if (offset >= 10) { + None + } else { + offset += 1 + Some(LongOffset(offset)) + } + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 + sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala new file mode 100644 index 0000000000000..3af7c01e525ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkException +import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +object FailureSinglton { + var firstTime = true +} + +class StreamingAggregationSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + test("simple count") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated)( + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream, + StartStream, + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)) + ) + } + + test("multiple keys") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value", $"value" + 1) + .agg(count("*")) + .as[(Int, Int, Long)] + + testStream(aggregated)( + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 1), (2, 3, 1)), + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 2), (2, 3, 2)) + ) + } + + test("multiple aggregations") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*") as 'count) + .groupBy($"value" % 2) + .agg(sum($"count")) + .as[(Int, Long)] + + testStream(aggregated)( + AddData(inputData, 1, 2, 3, 4), + CheckLastBatch((0, 2), (1, 2)), + AddData(inputData, 1, 3, 5), + CheckLastBatch((1, 5)) + ) + } + + testQuietly("midbatch failure") { + val inputData = MemoryStream[Int] + FailureSinglton.firstTime = true + val aggregated = + inputData.toDS() + .map { i => + if (i == 4 && FailureSinglton.firstTime) { + FailureSinglton.firstTime = false + sys.error("injected failure") + } + + i + } + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated)( + StartStream, + AddData(inputData, 1, 2, 3, 4), + ExpectFailure[SparkException](), + StartStream, + CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1)) + ) + } + + test("typed aggregators") { + val inputData = MemoryStream[(String, Int)] + val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) + + testStream(aggregated)( + AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)), + CheckLastBatch(("a", 30), ("b", 3), ("c", 1)) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala index 152c9c8459de9..df530d8587ef7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import java.io.{IOException, InputStream} +import java.io.{InputStream, IOException} import scala.sys.process.BasicIO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 520dea7f7dd92..7fa6760b71c8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.test +import java.nio.charset.StandardCharsets + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} @@ -103,11 +105,11 @@ private[sql] trait SQLTestData { self => protected lazy val binaryData: DataFrame = { val df = sqlContext.sparkContext.parallelize( - BinaryData("12".getBytes, 1) :: - BinaryData("22".getBytes, 5) :: - BinaryData("122".getBytes, 3) :: - BinaryData("121".getBytes, 2) :: - BinaryData("123".getBytes, 4) :: Nil).toDF() + BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: + BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) :: + BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) :: + BinaryData("121".getBytes(StandardCharsets.UTF_8), 2) :: + BinaryData("123".getBytes(StandardCharsets.UTF_8), 4) :: Nil).toDF() df.registerTempTable("binaryData") df } @@ -242,6 +244,17 @@ private[sql] trait SQLTestData { self => df } + protected lazy val courseSales: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + CourseSales("dotNET", 2012, 10000) :: + CourseSales("Java", 2012, 20000) :: + CourseSales("dotNET", 2012, 5000) :: + CourseSales("dotNET", 2013, 48000) :: + CourseSales("Java", 2013, 30000) :: Nil).toDF() + df.registerTempTable("courseSales") + df + } + /** * Initialize all test data such that all temp tables are properly registered. */ @@ -270,6 +283,7 @@ private[sql] trait SQLTestData { self => person salary complexData + courseSales } } @@ -295,4 +309,5 @@ private[sql] object SQLTestData { case class Person(id: Int, name: String, age: Int) case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) + case class CourseSales(course: String, year: Int, earnings: Double) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 9214569f18e93..7844d1b296597 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.test import java.io.File import java.util.UUID -import scala.util.Try import scala.language.implicitConversions +import scala.util.Try import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.Filter import org.apache.spark.util.Utils /** @@ -64,13 +67,6 @@ private[sql] trait SQLTestUtils */ protected object testImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.sqlContext - - // This must live here to preserve binary compatibility with Spark < 1.5. - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } } /** @@ -136,11 +132,38 @@ private[sql] trait SQLTestUtils try f(dir) finally Utils.deleteRecursively(dir) } + /** + * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). + */ + protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { + try { + f + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + try functions.foreach { case (functionName, isTemporary) => + val withTemporary = if (isTemporary) "TEMPORARY" else "" + sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + assert( + !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + } + /** * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + try tableNames.foreach(sqlContext.dropTempTable) catch { + case _: NoSuchTableException => + } + } } /** @@ -154,9 +177,22 @@ private[sql] trait SQLTestUtils } } + /** + * Drops view `viewName` after calling `f`. + */ + protected def withView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + viewNames.foreach { name => + sqlContext.sql(s"DROP VIEW IF EXISTS $name") + } + } + } + /** * Creates a temporary database and switches current database to it before executing `f`. This * database is dropped after `f` returns. + * + * Note that this method doesn't switch current database before executing `f`. */ protected def withTempDatabase(f: String => Unit): Unit = { val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" @@ -175,8 +211,24 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sql(s"USE $db") - try f finally sqlContext.sql(s"USE default") + sqlContext.sessionState.catalog.setCurrentDatabase(db) + try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default") + } + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.sparkPlan transform { + case Filter(_, child) => child + } + + val childRDD = withoutFilters + .execute() + .map(row => Row.fromSeq(row.copy().toSeq(schema))) + + sqlContext.createDataFrame(childRDD, schema) } /** @@ -184,7 +236,21 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(sqlContext, plan) + Dataset.ofRows(sqlContext, plan) + } + + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 963d10eed62ed..914c6a550900a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.test +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext @@ -25,6 +26,8 @@ import org.apache.spark.sql.SQLContext */ trait SharedSQLContext extends SQLTestUtils { + protected val sparkConf = new SparkConf() + /** * The [[TestSQLContext]] to use for all tests in this suite. * @@ -36,14 +39,15 @@ trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected def sqlContext: SQLContext = _ctx + protected implicit def sqlContext: SQLContext = _ctx /** * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { + SQLContext.clearSqlListener() if (_ctx == null) { - _ctx = new TestSQLContext + _ctx = new TestSQLContext(sparkConf) } // Ensure we have initialized the context before calling parent code super.beforeAll() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index c89a1516503e0..7ab79b12ce246 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,29 +18,33 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLConf, SQLContext} - +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.internal.{SessionState, SQLConf} /** * A special [[SQLContext]] prepared for testing. */ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => - def this() { + def this(sparkConf: SparkConf) { this(new SparkContext("local[2]", "test-sql-context", - new SparkConf().set("spark.sql.testkey", "true"))) + sparkConf.set("spark.sql.testkey", "true"))) } - protected[sql] override lazy val conf: SQLConf = new SQLConf { - - clear() - - override def clear(): Unit = { - super.clear() + def this() { + this(new SparkConf) + } - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.map { - case (key, value) => setConfString(key, value) + @transient + protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { + override lazy val conf: SQLConf = { + new SQLConf { + clear() + override def clear(): Unit = { + super.clear() + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala new file mode 100644 index 0000000000000..3498fe83d02eb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.util.control.NonFatal + +import org.scalatest.BeforeAndAfter +import org.scalatest.PrivateMethodTester._ +import org.scalatest.concurrent.AsyncAssertions.Waiter +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated} + +class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + assert(sqlContext.streams.active.isEmpty) + assert(addedListeners.isEmpty) + // Make sure we don't leak any events to the next test + sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + } + + test("single listener") { + val listener = new QueryStatusCollector + val input = MemoryStream[Int] + withListenerAdded(listener) { + testStream(input.toDS)( + StartStream, + Assert("Incorrect query status in onQueryStarted") { + val status = listener.startStatus + assert(status != null) + assert(status.active == true) + assert(status.sourceStatuses.size === 1) + assert(status.sourceStatuses(0).description.contains("Memory")) + + // The source and sink offsets must be None as this must be called before the + // batches have started + assert(status.sourceStatuses(0).offset === None) + assert(status.sinkStatus.offset === CompositeOffset(None :: Nil)) + + // No progress events or termination events + assert(listener.progressStatuses.isEmpty) + assert(listener.terminationStatus === null) + }, + AddDataMemory(input, Seq(1, 2, 3)), + CheckAnswer(1, 2, 3), + Assert("Incorrect query status in onQueryProgress") { + eventually(Timeout(streamingTimeout)) { + + // There should be only on progress event as batch has been processed + assert(listener.progressStatuses.size === 1) + val status = listener.progressStatuses.peek() + assert(status != null) + assert(status.active == true) + assert(status.sourceStatuses(0).offset === Some(LongOffset(0))) + assert(status.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))) + + // No termination events + assert(listener.terminationStatus === null) + } + }, + StopStream, + Assert("Incorrect query status in onQueryTerminated") { + eventually(Timeout(streamingTimeout)) { + val status = listener.terminationStatus + assert(status != null) + + assert(status.active === false) // must be inactive by the time onQueryTerm is called + assert(status.sourceStatuses(0).offset === Some(LongOffset(0))) + assert(status.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))) + } + listener.checkAsyncErrors() + } + ) + } + } + + test("adding and removing listener") { + def isListenerActive(listener: QueryStatusCollector): Boolean = { + listener.reset() + testStream(MemoryStream[Int].toDS)( + StartStream, + StopStream + ) + listener.startStatus != null + } + + try { + val listener1 = new QueryStatusCollector + val listener2 = new QueryStatusCollector + + sqlContext.streams.addListener(listener1) + assert(isListenerActive(listener1) === true) + assert(isListenerActive(listener2) === false) + sqlContext.streams.addListener(listener2) + assert(isListenerActive(listener1) === true) + assert(isListenerActive(listener2) === true) + sqlContext.streams.removeListener(listener1) + assert(isListenerActive(listener1) === false) + assert(isListenerActive(listener2) === true) + } finally { + addedListeners.foreach(sqlContext.streams.removeListener) + } + } + + test("event ordering") { + val listener = new QueryStatusCollector + withListenerAdded(listener) { + for (i <- 1 to 100) { + listener.reset() + require(listener.startStatus === null) + testStream(MemoryStream[Int].toDS)( + StartStream, + Assert(listener.startStatus !== null, "onQueryStarted not called before query returned"), + StopStream, + Assert { listener.checkAsyncErrors() } + ) + } + } + } + + + private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = { + try { + failAfter(1 minute) { + sqlContext.streams.addListener(listener) + body + } + } finally { + sqlContext.streams.removeListener(listener) + } + } + + private def addedListeners(): Array[ContinuousQueryListener] = { + val listenerBusMethod = + PrivateMethod[ContinuousQueryListenerBus]('listenerBus) + val listenerBus = sqlContext.streams invokePrivate listenerBusMethod() + listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener]) + } + + class QueryStatusCollector extends ContinuousQueryListener { + + private val asyncTestWaiter = new Waiter // to catch errors in the async listener events + + @volatile var startStatus: QueryStatus = null + @volatile var terminationStatus: QueryStatus = null + val progressStatuses = new ConcurrentLinkedQueue[QueryStatus] + + def reset(): Unit = { + startStatus = null + terminationStatus = null + progressStatuses.clear() + + // To reset the waiter + try asyncTestWaiter.await(timeout(1 milliseconds)) catch { + case NonFatal(e) => + } + } + + def checkAsyncErrors(): Unit = { + asyncTestWaiter.await(timeout(streamingTimeout)) + } + + + override def onQueryStarted(queryStarted: QueryStarted): Unit = { + asyncTestWaiter { + startStatus = QueryStatus(queryStarted.query) + } + } + + override def onQueryProgress(queryProgress: QueryProgress): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryProgress called before onQueryStarted") + progressStatuses.add(QueryStatus(queryProgress.query)) + } + } + + override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryTerminated called before onQueryStarted") + terminationStatus = QueryStatus(queryTerminated.query) + } + asyncTestWaiter.dismiss() + } + } + + case class QueryStatus( + active: Boolean, + exception: Option[Exception], + sourceStatuses: Array[SourceStatus], + sinkStatus: SinkStatus) + + object QueryStatus { + def apply(query: ContinuousQuery): QueryStatus = { + QueryStatus(query.isActive, query.exception, query.sourceStatuses, query.sinkStatus) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index b46b0d2f6040a..e7d2b5ad96821 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.sql.{functions, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegen} import org.apache.spark.sql.test.SharedSQLContext class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { @@ -92,7 +92,11 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { - metrics += qe.executedPlan.longMetric("numInputRows").value.value + val metric = qe.executedPlan match { + case w: WholeStageCodegen => w.child.longMetric("numOutputRows") + case other => other.longMetric("numOutputRows") + } + metrics += metric.value.value } } sqlContext.listenerManager.register(listener) @@ -103,9 +107,9 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() assert(metrics.length == 3) - assert(metrics(0) == 1) - assert(metrics(1) == 1) - assert(metrics(2) == 2) + assert(metrics(0) === 1) + assert(metrics(1) === 1) + assert(metrics(2) === 2) sqlContext.listenerManager.unregister(listener) } @@ -140,7 +144,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) assert(peakMemoryAccumulator.size == 1) - peakMemoryAccumulator.head._2.value.toLong + peakMemoryAccumulator.head._2.value.get.asInstanceOf[Long] } assert(sparkListener.getCompletedStageInfos.length == 2) diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index b5b2143292a69..c8d17bd468582 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-hive-thriftserver_2.10 + spark-hive-thriftserver_2.11 jar Spark Project Hive Thrift Server http://spark.apache.org/ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala index 2228f651e2387..60bb4dc5e77b6 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala @@ -16,7 +16,7 @@ */ package org.apache.hive.service.server -import org.apache.hive.service.server.HiveServer2.{StartOptionExecutor, ServerOptionsProcessor} +import org.apache.hive.service.server.HiveServer2.{ServerOptionsProcessor, StartOptionExecutor} /** * Class to upgrade a package-private class to public, and diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index a4fd0c3ce9702..ee0d23a6e57c4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -27,17 +27,17 @@ import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} -import org.apache.hive.service.server.{HiveServerServerOptionsProcessor, HiveServer2} +import org.apache.hive.service.server.{HiveServer2, HiveServerServerOptionsProcessor} +import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ShutdownHookManager, Utils} -import org.apache.spark.{Logging, SparkContext} - /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -67,6 +67,7 @@ object HiveThriftServer2 extends Logging { } def main(args: Array[String]) { + Utils.initDaemon(log) val optionsProcessor = new HiveServerServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { System.exit(-1) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 719b03e1c7c71..673a293ce2601 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} +import java.util.{Arrays, Map => JMap, UUID} import java.util.concurrent.RejectedExecutionException -import java.util.{Arrays, UUID, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} @@ -32,12 +32,13 @@ import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession -import org.apache.spark.Logging -import org.apache.spark.sql.execution.SetCommand +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Row => SparkRow} +import org.apache.spark.sql.execution.command.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} - +import org.apache.spark.util.{Utils => SparkUtils} private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, @@ -53,6 +54,18 @@ private[hive] class SparkExecuteStatementOperation( private var dataTypes: Array[DataType] = _ private var statementId: String = _ + private lazy val resultSchema: TableSchema = { + if (result == null || result.queryExecution.analyzed.output.size == 0) { + new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) + } else { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema.asJava) + } + } + def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. hiveContext.sparkContext.clearJobGroup() @@ -120,24 +133,14 @@ private[hive] class SparkExecuteStatementOperation( } } - def getResultSetSchema: TableSchema = { - if (result == null || result.queryExecution.analyzed.output.size == 0) { - new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) - } else { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema.asJava) - } - } + def getResultSetSchema: TableSchema = resultSchema - override def run(): Unit = { + override def runInternal(): Unit = { setState(OperationState.PENDING) setHasResultSet(true) // avoid no resultset for async run if (!runInBackground) { - runInternal() + execute() } else { val sparkServiceUGI = Utils.getUGI() @@ -149,7 +152,7 @@ private[hive] class SparkExecuteStatementOperation( val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { try { - runInternal() + execute() } catch { case e: HiveSQLException => setOperationException(e) @@ -186,7 +189,7 @@ private[hive] class SparkExecuteStatementOperation( } } - override def runInternal(): Unit = { + private def execute(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) @@ -219,7 +222,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.rdd.toLocalIterator + result.toLocalIterator.asScala } else { result.collect().iterator } @@ -230,7 +233,7 @@ private[hive] class SparkExecuteStatementOperation( if (getStatus().getState() == OperationState.CANCELED) { return } else { - setState(OperationState.ERROR); + setState(OperationState.ERROR) throw e } // Actually do need to catch Throwable as some failures don't inherit from Exception and @@ -240,7 +243,7 @@ private[hive] class SparkExecuteStatementOperation( logError(s"Error executing query, currentState $currentState, ", e) setState(OperationState.ERROR) HiveThriftServer2.listener.onStatementError( - statementId, e.getMessage, e.getStackTraceString) + statementId, e.getMessage, SparkUtils.exceptionString(e)) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 6419002a2aa89..57693284b01df 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -20,13 +20,10 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ import java.util.{ArrayList => JArrayList, Locale} -import org.apache.spark.sql.AnalysisException - import scala.collection.JavaConverters._ import jline.console.ConsoleReader import jline.console.history.FileHistory - import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration @@ -35,13 +32,15 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, CommandProcessor, + CommandProcessorFactory, SetProcessor} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket -import org.apache.spark.Logging +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.ShutdownHookManager /** * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver @@ -83,7 +82,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { val cliConf = new HiveConf(classOf[SessionState]) // Override the location of the metastore since this is only used for local execution. - HiveContext.newTemporaryConfiguration().foreach { + HiveContext.newTemporaryConfiguration(useInMemoryDerby = false).foreach { case (key, value) => cliConf.set(key, value) } val sessionState = new CliSessionState(cliConf) @@ -151,7 +150,8 @@ private[hive] object SparkSQLCLIDriver extends Logging { } if (sessionState.database != null) { - SparkSQLEnv.hiveContext.runSqlHive(s"USE ${sessionState.database}") + SparkSQLEnv.hiveContext.sessionState.catalog.setCurrentDatabase( + s"${sessionState.database}") } // Execute -i init files (always in silent mode) @@ -194,6 +194,20 @@ private[hive] object SparkSQLCLIDriver extends Logging { logWarning(e.getMessage) } + // add shutdown hook to flush the history to history file + ShutdownHookManager.addShutdownHook { () => + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => + } + } + // TODO: missing /* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") @@ -276,8 +290,11 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() if (cmd_lower.equals("quit") || - cmd_lower.equals("exit") || - tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || + cmd_lower.equals("exit")) { + sessionState.close() + System.exit(0) + } + if (tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || cmd_trimmed.startsWith("!") || tokens(0).toLowerCase.equals("list") || isRemoteMode) { @@ -357,7 +374,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (counter != 0) { responseMsg += s", Fetched $counter row(s)" } - console.printInfo(responseMsg , null) + console.printInfo(responseMsg, null) // Destroy the driver to release all the locks. driver.destroy() } else { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 5ad8c54f296d5..6fe57554cf580 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,11 +27,11 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ import org.apache.hive.service.server.HiveServer2 -import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index f1ec7238520ac..b8bc8ea44dc84 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.{Arrays, ArrayList => JArrayList, List => JList} -import org.apache.log4j.LogManager -import org.apache.spark.sql.AnalysisException +import java.util.{ArrayList => JArrayList, Arrays, List => JList} import scala.collection.JavaConverters._ @@ -28,7 +26,8 @@ import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse -import org.apache.spark.Logging +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} private[hive] class SparkSQLDriver( diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index bacf6cc458fd5..2594c5bfdb3af 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -21,9 +21,10 @@ import java.io.PrintStream import scala.collection.JavaConverters._ +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 33aaead3fbf96..de4e9c62b57a4 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -41,6 +41,11 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) + // Create operation log root directory, if operation logging is enabled + if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { + invoke(classOf[SessionManager], this, "initOperationLogRootDir") + } + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) getAncestorField[Log](this, 3, "LOG").info( @@ -66,7 +71,11 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - val ctx = hiveContext.newSession() + val ctx = if (hiveContext.hiveThriftServerSingleSession) { + hiveContext + } else { + hiveContext.newSession() + } ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx sessionHandle diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 476651a559d2c..0c468a408ba98 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.hive.thriftserver.server import java.util.{Map => JMap} + import scala.collection.mutable.Map import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession -import org.apache.spark.Logging + +import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, ReflectionUtils} +import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index e990bd06011ff..c82fa4eaaa4e5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -23,10 +23,11 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.Logging -import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{SessionInfo, ExecutionState, ExecutionInfo} -import org.apache.spark.ui.UIUtils._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState, SessionInfo} import org.apache.spark.ui._ +import org.apache.spark.ui.UIUtils._ /** Page for Spark Web UI that shows statistics of a thrift server */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index af16cb31df187..008108a5ce06d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -23,10 +23,11 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.Logging + +import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState} -import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui._ +import org.apache.spark.ui.UIUtils._ /** Page for Spark Web UI that shows statistics of a streaming job */ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 4eabeaa6735e6..923ba8a30c5c5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive.thriftserver.ui -import org.apache.spark.sql.hive.thriftserver.{HiveThriftServer2, SparkSQLEnv} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab._ import org.apache.spark.ui.{SparkUI, SparkUITab} -import org.apache.spark.{SparkContext, Logging, SparkException} /** * Spark Web UI tab that shows statistics of a streaming job. diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3fa5c8528b602..eb49eabcb1ba9 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,39 +18,46 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ +import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.Date import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} -import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import scala.concurrent.duration._ import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.BeforeAndAfter +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkFunSuite} /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary * Hive metastore and warehouse. */ -class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { +class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() val scratchDirPath = Utils.createTempDir() - before { + override def beforeAll(): Unit = { + super.beforeAll() warehousePath.delete() metastorePath.delete() scratchDirPath.delete() } - after { - warehousePath.delete() - metastorePath.delete() - scratchDirPath.delete() + override def afterAll(): Unit = { + try { + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() + } finally { + super.afterAll() + } } /** @@ -62,7 +69,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { * with one of these strings is found, fail the test immediately. * The default value is `Seq("Error:")` * - * @param queriesAndExpectedAnswers one or more tupes of query + answer + * @param queriesAndExpectedAnswers one or more tuples of query + answer */ def runCliWithin( timeout: FiniteDuration, @@ -79,6 +86,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local + | --driver-java-options -Dderby.system.durability=test + | --conf spark.ui.enabled=false | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath @@ -114,7 +123,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val process = new ProcessBuilder(command: _*).start() - val stdinWriter = new OutputStreamWriter(process.getOutputStream) + val stdinWriter = new OutputStreamWriter(process.getOutputStream, StandardCharsets.UTF_8) stdinWriter.write(queriesString) stdinWriter.flush() stdinWriter.close() @@ -153,7 +162,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { runCliWithin(3.minute)( "CREATE TABLE hive_test(key INT, val STRING);" - -> "OK", + -> "", "SHOW TABLES;" -> "hive_test", s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;" @@ -163,31 +172,28 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "OK" + -> "" ) } test("Single command with -e") { - runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "") } test("Single command with --database") { runCliWithin(2.minute)( "CREATE DATABASE hive_test_db;" - -> "OK", + -> "", "USE hive_test_db;" - -> "OK", + -> "", "CREATE TABLE hive_test(key INT, val STRING);" - -> "OK", + -> "", "SHOW TABLES;" -> "hive_test" ) runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( - "" - -> "OK", - "" - -> "hive_test" + "" -> "hive_test" ) } @@ -204,9 +210,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { """CREATE TABLE t1(key string, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; """.stripMargin - -> "OK", + -> "", "CREATE TABLE sourceTable (key INT, val STRING);" - -> "OK", + -> "", s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" -> "OK", "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" @@ -214,9 +220,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { "SELECT count(key) FROM t1;" -> "5", "DROP TABLE t1;" - -> "OK", + -> "", "DROP TABLE sourceTable;" - -> "OK" + -> "" ) } @@ -224,7 +230,12 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { runCliWithin(timeout = 2.minute, errorResponses = Seq("AnalysisException"))( "select * from nonexistent_table;" - -> "Error in query: Table not found: nonexistent_table;" + -> "Error in query: Table or View not found: nonexistent_table;" ) } + + test("SPARK-11624 Spark SQL CLI should set sessionState only once") { + runCliWithin(2.minute, Seq("-e", "!echo \"This is a test for Spark-11624\";"))( + "" -> "This is a test for Spark-11624") + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ff8ca0150649d..a1268b8e94f56 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL +import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise, future} +import scala.io.Source import scala.util.{Random, Try} -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver @@ -40,10 +40,11 @@ import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.util.{ThreadUtils, Utils} object TestData { def getTestDataFilePath(name: String): URL = { @@ -202,7 +203,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } test("test multiple session") { - import org.apache.spark.sql.SQLConf + import org.apache.spark.sql.internal.SQLConf var defaultV1: String = null var defaultV2: String = null var data: ArrayBuffer[Int] = null @@ -347,7 +348,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { ) } - test("test jdbc cancel") { + // This test often hangs and then times out, leaving the hanging processes. + // Let's ignore it and improve the test. + ignore("test jdbc cancel") { withJdbcStatement { statement => val queries = Seq( "DROP TABLE IF EXISTS test_map", @@ -355,31 +358,54 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") queries.foreach(statement.execute) - - val largeJoin = "SELECT COUNT(*) FROM test_map " + - List.fill(10)("join test_map").mkString(" ") - val f = future { Thread.sleep(100); statement.cancel(); } - val e = intercept[SQLException] { - statement.executeQuery(largeJoin) + implicit val ec = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("test-jdbc-cancel")) + try { + // Start a very-long-running query that will take hours to finish, then cancel it in order + // to demonstrate that cancellation works. + val f = Future { + statement.executeQuery( + "SELECT COUNT(*) FROM test_map " + + List.fill(10)("join test_map").mkString(" ")) + } + // Note that this is slightly race-prone: if the cancel is issued before the statement + // begins executing then we'll fail with a timeout. As a result, this fixed delay is set + // slightly more conservatively than may be strictly necessary. + Thread.sleep(1000) + statement.cancel() + val e = intercept[SQLException] { + Await.result(f, 3.minute) + } + assert(e.getMessage.contains("cancelled")) + + // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false + statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") + try { + val sf = Future { + statement.executeQuery( + "SELECT COUNT(*) FROM test_map " + + List.fill(4)("join test_map").mkString(" ") + ) + } + // Similarly, this is also slightly race-prone on fast machines where the query above + // might race and complete before we issue the cancel. + Thread.sleep(1000) + statement.cancel() + val rs1 = Await.result(sf, 3.minute) + rs1.next() + assert(rs1.getInt(1) === math.pow(5, 5)) + rs1.close() + + val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") + rs2.next() + assert(rs2.getInt(1) === 5) + rs2.close() + } finally { + statement.executeQuery("SET spark.sql.hive.thriftServer.async=true") + } + } finally { + ec.shutdownNow() } - assert(e.getMessage contains "cancelled") - Await.result(f, 3.minute) - - // cancel is a noop - statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") - val sf = future { Thread.sleep(100); statement.cancel(); } - val smallJoin = "SELECT COUNT(*) FROM test_map " + - List.fill(4)("join test_map").mkString(" ") - val rs1 = statement.executeQuery(smallJoin) - Await.result(sf, 3.minute) - rs1.next() - assert(rs1.getInt(1) === math.pow(5, 5)) - rs1.close() - - val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") - rs2.next() - assert(rs2.getInt(1) === 5) - rs2.close() } } @@ -462,6 +488,112 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) } } + + test("SPARK-11595 ADD JAR with input path having URL scheme") { + withJdbcStatement { statement => + try { + val jarPath = "../hive/src/test/resources/TestUDTF.jar" + val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" + + Seq( + s"ADD JAR $jarURL", + s"""CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin + ).foreach(statement.execute) + + val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + + assert(rs1.next()) + assert(rs1.getString(1) === "Function: udtf_count2") + + assert(rs1.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs1.getString(1) + } + + assert(rs1.next()) + assert(rs1.getString(1) === "Usage: To be added.") + + val dataPath = "../hive/src/test/resources/data/files/kv1.txt" + + Seq( + s"CREATE TABLE test_udtf(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf" + ).foreach(statement.execute) + + val rs2 = statement.executeQuery( + "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc") + + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) + + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) + } finally { + statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2") + } + } + } + + test("SPARK-11043 check operation log root directory") { + val expectedLine = + "Operation log root directory is created: " + operationLogPath.getAbsoluteFile + assert(Source.fromFile(logPath).getLines().exists(_.contains(expectedLine))) + } +} + +class SingleSessionSuite extends HiveThriftJdbcTest { + override def mode: ServerMode.Value = ServerMode.binary + + override protected def extraConf: Seq[String] = + "--conf spark.sql.hive.thriftServer.singleSession=true" :: Nil + + test("test single session") { + withMultipleConnectionJdbcStatement( + { statement => + val jarPath = "../hive/src/test/resources/TestUDTF.jar" + val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" + + // Configurations and temporary functions added in this session should be visible to all + // the other sessions. + Seq( + "SET foo=bar", + s"ADD JAR $jarURL", + s"""CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin + ).foreach(statement.execute) + }, + + { statement => + try { + val rs1 = statement.executeQuery("SET foo") + + assert(rs1.next()) + assert(rs1.getString(1) === "foo") + assert(rs1.getString(2) === "bar") + + val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + + assert(rs2.next()) + assert(rs2.getString(1) === "Function: udtf_count2") + + assert(rs2.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs2.getString(1) + } + + assert(rs2.next()) + assert(rs2.getString(1) === "Usage: To be added.") + } finally { + statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2") + } + } + ) + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { @@ -550,10 +682,13 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" private val pidDir: File = Utils.createTempDir("thriftserver-pid") - private var logPath: File = _ + protected var logPath: File = _ + protected var operationLogPath: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] + protected def extraConf: Seq[String] = Nil + protected def serverStartCommand(port: Int) = { val portConf = if (mode == ServerMode.binary) { ConfVars.HIVE_SERVER2_THRIFT_PORT @@ -574,7 +709,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n """.stripMargin, new File(s"$tempLog4jConf/log4j.properties"), - UTF_8) + StandardCharsets.UTF_8) tempLog4jConf } @@ -585,21 +720,23 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug | --conf spark.ui.enabled=false + | ${extraConf.mkString("\n")} """.stripMargin.split("\\s+").toSeq } /** - * String to scan for when looking for the the thrift binary endpoint running. + * String to scan for when looking for the thrift binary endpoint running. * This can change across Hive versions. */ val THRIFT_BINARY_SERVICE_LIVE = "Starting ThriftBinaryCLIService on port" /** - * String to scan for when looking for the the thrift HTTP endpoint running. + * String to scan for when looking for the thrift HTTP endpoint running. * This can change across Hive versions. */ val THRIFT_HTTP_SERVICE_LIVE = "Started ThriftHttpCLIService in http" @@ -611,6 +748,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl warehousePath.delete() metastorePath = Utils.createTempDir() metastorePath.delete() + operationLogPath = Utils.createTempDir() + operationLogPath.delete() logPath = null logTailingProcess = null @@ -632,11 +771,15 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl extraEnvironment = Map( // Disables SPARK_TESTING to exclude log4j.properties in test directories. "SPARK_TESTING" -> "0", + // But set SPARK_SQL_TESTING to make spark-class happy. + "SPARK_SQL_TESTING" -> "1", // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be // started at a time, which is not Jenkins friendly. "SPARK_PID_DIR" -> pidDir.getCanonicalPath), redirectStderr = true) + logInfo(s"COMMAND: $command") + logInfo(s"OUTPUT: $lines") lines.split("\n").collectFirst { case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) }.getOrElse { @@ -687,6 +830,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() metastorePath = null + operationLogPath.delete() + operationLogPath = null + Option(logPath).foreach(_.delete()) logPath = null @@ -708,6 +854,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl } override protected def beforeAll(): Unit = { + super.beforeAll() // Chooses a random port between 10000 and 19999 listeningPort = 10000 + Random.nextInt(10000) diagnosisBuffer.clear() @@ -729,7 +876,11 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl } override protected def afterAll(): Unit = { - stopThriftServer() - logInfo("HiveThriftServer2 stopped") + try { + stopThriftServer() + logInfo("HiveThriftServer2 stopped") + } finally { + super.afterAll() + } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2d0d7b8af3581..989e68aebed9b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -20,17 +20,16 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.sql.internal.SQLConf /** * Runs the test cases that are included in the hive distribution. */ -@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( @@ -40,10 +39,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning + private val originalConvertMetastoreOrc = TestHive.convertMetastoreOrc - def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + def testCases: Seq[(String, File)] = { + hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + } override def beforeAll() { + super.beforeAll() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -53,22 +56,33 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + // Use Hive hash expression instead of the native one + TestHive.sessionState.functionRegistry.unregisterFunction("hash") + // Ensures that the plans generation use metastore relation and not OrcRelation + // Was done because SqlBuilder does not work with plans having logical relation + TestHive.setConf(HiveContext.CONVERT_METASTORE_ORC, false) RuleExecutor.resetTime() } override def afterAll() { - TestHive.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - - // For debugging dump some statistics about how much time was spent in various optimizer rules. - logWarning(RuleExecutor.dumpTimeSpent()) + try { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + TestHive.setConf(HiveContext.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc) + TestHive.sessionState.functionRegistry.restore() + + // For debugging dump some statistics about how much time was spent in various optimizer rules. + logWarning(RuleExecutor.dumpTimeSpent()) + } finally { + super.afterAll() + } } /** A list of tests deemed out of scope currently and thus completely disregarded. */ - override def blackList = Seq( + override def blackList: Seq[String] = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", "hook_context_cs", @@ -103,7 +117,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - //"describe_table", + // "describe_table", "describe_comment_nonascii", "create_merge_compressed", @@ -283,12 +297,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "compute_stats_empty_table", "compute_stats_long", "create_view_translate", - "show_create_table_serde", "show_tblproperties", // Odd changes to output "merge4", + // Unsupported underscore syntax. + "inputddl5", + // Thift is broken... "inputddl8", @@ -308,14 +324,141 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // The difference between the double numbers generated by Hive and Spark // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322) - "udaf_corr" + "udaf_corr", + + // Feature removed in HIVE-11145 + "alter_partition_protect_mode", + "drop_partitions_ignore_protection", + "protectmode", + + // Hive returns null rather than NaN when n = 1 + "udaf_covar_samp", + + // The implementation of GROUPING__ID in Hive is wrong (not match with doc). + "groupby_grouping_id1", + "groupby_grouping_id2", + "groupby_grouping_sets1", + + // Spark parser treats numerical literals differently: it creates decimals instead of doubles. + "udf_abs", + "udf_format_number", + "udf_round", + "udf_round_3", + "view_cast", + + // These tests check the VIEW table definition, but Spark handles CREATE VIEW itself and + // generates different View Expanded Text. + "alter_view_as_select", + + // We don't support show create table commands in general + "show_create_table_alter", + "show_create_table_db_table", + "show_create_table_delimited", + "show_create_table_does_not_exist", + "show_create_table_index", + "show_create_table_partitioned", + "show_create_table_serde", + "show_create_table_view", + + // These tests try to change how a table is bucketed, which we don't support + "alter4", + "sort_merge_join_desc_5", + "sort_merge_join_desc_6", + "sort_merge_join_desc_7", + + // These tests try to create a table with bucketed columns, which we don't support + "auto_join32", + "auto_join_filters", + "auto_smb_mapjoin_14", + "ct_case_insensitive", + "explain_rearrange", + "groupby_sort_10", + "groupby_sort_2", + "groupby_sort_3", + "groupby_sort_4", + "groupby_sort_5", + "groupby_sort_7", + "groupby_sort_8", + "groupby_sort_9", + "groupby_sort_test_1", + "inputddl4", + "join_filters", + "join_nulls", + "join_nullsafe", + "load_dyn_part2", + "orc_empty_files", + "reduce_deduplicate", + "smb_mapjoin9", + "smb_mapjoin_1", + "smb_mapjoin_10", + "smb_mapjoin_13", + "smb_mapjoin_14", + "smb_mapjoin_15", + "smb_mapjoin_16", + "smb_mapjoin_17", + "smb_mapjoin_2", + "smb_mapjoin_21", + "smb_mapjoin_25", + "smb_mapjoin_3", + "smb_mapjoin_4", + "smb_mapjoin_5", + "smb_mapjoin_6", + "smb_mapjoin_7", + "smb_mapjoin_8", + "sort_merge_join_desc_1", + "sort_merge_join_desc_2", + "sort_merge_join_desc_3", + "sort_merge_join_desc_4", + + // These tests try to create a table with skewed columns, which we don't support + "create_skewed_table1", + "skewjoinopt13", + "skewjoinopt18", + "skewjoinopt9", + + // This test tries to create a table like with TBLPROPERTIES clause, which we don't support. + "create_like_tbl_props", + + // Index commands are not supported + "drop_index", + "drop_index_removes_partition_dirs", + "alter_index", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + + // Macro commands are not supported + "macro", + + // Create partitioned view is not supported + "create_like_view", + "describe_formatted_view_partitioned", + + // This uses CONCATENATE, which we don't support + "alter_merge_2", + + // TOUCH is not supported + "touch" ) /** * The set of tests that are believed to be working in catalyst. Tests not on whiteList or * blacklist are implicitly marked as ignored. */ - override def whiteList = Seq( + override def whiteList: Seq[String] = Seq( "add_part_exist", "add_part_multiple", "add_partition_no_whitelist", @@ -323,18 +466,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alias_casted_column", "alter2", "alter3", - "alter4", "alter5", - "alter_index", - "alter_merge_2", "alter_partition_format_loc", - "alter_partition_protect_mode", "alter_partition_with_whitelist", "alter_rename_partition", "alter_table_serde", "alter_varchar1", "alter_varchar2", - "alter_view_as_select", "ambiguous_col", "annotate_stats_join", "annotate_stats_limit", @@ -366,33 +504,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "auto_join3", "auto_join30", "auto_join31", - "auto_join32", "auto_join4", "auto_join5", "auto_join6", "auto_join7", "auto_join8", "auto_join9", - "auto_join_filters", "auto_join_nulls", "auto_join_reordering_values", - "auto_smb_mapjoin_14", - "auto_sortmerge_join_1", - "auto_sortmerge_join_10", - "auto_sortmerge_join_11", - "auto_sortmerge_join_12", - "auto_sortmerge_join_13", - "auto_sortmerge_join_14", - "auto_sortmerge_join_15", - "auto_sortmerge_join_16", - "auto_sortmerge_join_2", - "auto_sortmerge_join_3", - "auto_sortmerge_join_4", - "auto_sortmerge_join_5", - "auto_sortmerge_join_6", - "auto_sortmerge_join_7", - "auto_sortmerge_join_8", - "auto_sortmerge_join_9", "binary_constant", "binarysortable_1", "cast1", @@ -421,16 +540,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "count", "cp_mj_rc", "create_insert_outputformat", - "create_like_tbl_props", - "create_like_view", "create_nested_type", - "create_skewed_table1", "create_struct_table", "create_view_translate", "cross_join", "cross_product_check_1", "cross_product_check_2", - "ct_case_insensitive", "database_drop", "database_location", "database_properties", @@ -447,20 +562,16 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "default_partition_name", "delimiter", "desc_non_existent_tbl", - "describe_formatted_view_partitioned", "diff_part_input_formats", "disable_file_format_check", "disallow_incompatible_type_change_off", "distinct_stats", "drop_database_removes_partition_dirs", "drop_function", - "drop_index", - "drop_index_removes_partition_dirs", "drop_multi_partitions", "drop_partitions_filter", "drop_partitions_filter2", "drop_partitions_filter3", - "drop_partitions_ignore_protection", "drop_table", "drop_table2", "drop_table_removes_partition_dirs", @@ -470,7 +581,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "escape_distributeby1", "escape_orderby1", "escape_sortby1", - "explain_rearrange", "fileformat_mix", "fileformat_sequencefile", "fileformat_text", @@ -480,9 +590,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby11", "groupby12", "groupby1_limit", - "groupby_grouping_id1", - "groupby_grouping_id2", - "groupby_grouping_sets1", "groupby_grouping_sets2", "groupby_grouping_sets3", "groupby_grouping_sets4", @@ -528,16 +635,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_neg_float", "groupby_ppd", "groupby_ppr", - "groupby_sort_10", - "groupby_sort_2", - "groupby_sort_3", - "groupby_sort_4", - "groupby_sort_5", "groupby_sort_6", - "groupby_sort_7", - "groupby_sort_8", - "groupby_sort_9", - "groupby_sort_test_1", "having", "implicit_cast1", "index_serde", @@ -592,8 +690,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl1", "inputddl2", "inputddl3", - "inputddl4", - "inputddl5", "inputddl6", "inputddl7", "inputddl8", @@ -649,11 +745,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_array", "join_casesensitive", "join_empty", - "join_filters", "join_hive_626", "join_map_ppr", - "join_nulls", - "join_nullsafe", "join_rc", "join_reorder2", "join_reorder3", @@ -677,7 +770,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_dyn_part13", "load_dyn_part14", "load_dyn_part14_win", - "load_dyn_part2", "load_dyn_part3", "load_dyn_part4", "load_dyn_part5", @@ -688,7 +780,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_file_with_space_in_the_name", "loadpart1", "louter_join_ppr", - "macro", "mapjoin_distinct", "mapjoin_filter_on_outerjoin", "mapjoin_mapjoin", @@ -731,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "nullscript", "optional_outer", "orc_dictionary_threshold", - "orc_empty_files", "order", "order2", "outer_join_ppr", @@ -778,7 +868,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ppr_pushdown2", "ppr_pushdown3", "progress_1", - "protectmode", "push_or", "query_with_semi", "quote1", @@ -788,7 +877,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "rcfile_null_value", "rcfile_toleratecorruptions", "rcfile_union", - "reduce_deduplicate", "reduce_deduplicate_exclude_gby", "reduce_deduplicate_exclude_join", "reduce_deduplicate_extended", @@ -805,46 +893,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "serde_reported_schema", "set_variable_sub", "show_columns", - "show_create_table_alter", - "show_create_table_db_table", - "show_create_table_delimited", - "show_create_table_does_not_exist", - "show_create_table_index", - "show_create_table_partitioned", - "show_create_table_serde", - "show_create_table_view", "show_describe_func_quotes", "show_functions", "show_partitions", "show_tblproperties", - "skewjoinopt13", - "skewjoinopt18", - "skewjoinopt9", - "smb_mapjoin9", - "smb_mapjoin_1", - "smb_mapjoin_10", - "smb_mapjoin_13", - "smb_mapjoin_14", - "smb_mapjoin_15", - "smb_mapjoin_16", - "smb_mapjoin_17", - "smb_mapjoin_2", - "smb_mapjoin_21", - "smb_mapjoin_25", - "smb_mapjoin_3", - "smb_mapjoin_4", - "smb_mapjoin_5", - "smb_mapjoin_6", - "smb_mapjoin_7", - "smb_mapjoin_8", "sort", - "sort_merge_join_desc_1", - "sort_merge_join_desc_2", - "sort_merge_join_desc_3", - "sort_merge_join_desc_4", - "sort_merge_join_desc_5", - "sort_merge_join_desc_6", - "sort_merge_join_desc_7", "stats0", "stats_aggregator_error_1", "stats_empty_partition", @@ -855,7 +908,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_comparison", "timestamp_lazy", "timestamp_null", - "touch", "transform_ppr1", "transform_ppr2", "truncate_table", @@ -863,7 +915,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "type_widening", "udaf_collect_set", "udaf_covar_pop", - "udaf_covar_samp", "udaf_histogram_numeric", "udf2", "udf5", @@ -873,7 +924,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_10_trims", "udf_E", "udf_PI", - "udf_abs", "udf_acos", "udf_add", "udf_array", @@ -917,7 +967,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_find_in_set", "udf_float", "udf_floor", - "udf_format_number", "udf_from_unixtime", "udf_greaterthan", "udf_greaterthanorequal", @@ -965,8 +1014,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - "udf_round", - "udf_round_3", "udf_rpad", "udf_rtrim", "udf_sign", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 92bb9e6d73af1..d0b4cbe401eb3 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) TestHive.reset() + super.afterAll() } ///////////////////////////////////////////////////////////////////////////// @@ -454,6 +455,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) """.stripMargin, reset = false) + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 15. testExpressions", s""" |select p_mfgr,p_name, p_size, @@ -472,7 +476,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 16. testMultipleWindows", s""" |select p_mfgr,p_name, p_size, @@ -530,6 +534,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // when running this test suite under Java 7 and 8. // We change the original sql query a little bit for making the test suite passed // under different JDK + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 20. testSTATs", """ |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp @@ -547,12 +554,12 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |) t lateral view explode(uniq_size) d as uniq_data |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 21. testDISTs", """ |select p_mfgr,p_name, p_size, |histogram_numeric(p_retailprice, 5) over w1 as hist, - |percentile(p_partkey, 0.5) over w1 as per, + |percentile(p_partkey, cast(0.5 as double)) over w1 as per, |row_number() over(distribute by p_mfgr sort by p_name) as rn |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index d96f3e2b9f62b..61504becf1f38 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,13 +21,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-hive_2.10 + spark-hive_2.11 jar Spark Project Hive http://spark.apache.org/ @@ -72,6 +72,12 @@ protobuf-java ${protobuf.version} +--> + + ${hive.group} + hive-cli + + - - org.apache.maven.plugins - maven-dependency-plugin - - - copy-dependencies - package - - copy-dependencies - - - - ${basedir}/../../lib_managed/jars - false - false - true - org.datanucleus - - - - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala deleted file mode 100644 index 7f8449cdc282d..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import scala.language.implicitConversions - -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} - -/** - * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. - */ -private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ADD = Keyword("ADD") - protected val DFS = Keyword("DFS") - protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") - - protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl - - protected lazy val hiveQl: Parser[LogicalPlan] = - restInput ^^ { - case statement => HiveQl.createPlan(statement.trim) - } - - protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> wholeInput ^^ { - case command => HiveNativeCommand(command.trim) - } - - private lazy val addFile: Parser[LogicalPlan] = - ADD ~ FILE ~> restInput ^^ { - case input => AddFile(input.trim) - } - - private lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> restInput ^^ { - case input => AddJar(input.trim) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2d72b959af134..ff93bfc4a3d16 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.io.File import java.net.{URL, URLClassLoader} +import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.concurrent.TimeUnit import java.util.regex.Pattern @@ -27,6 +28,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal @@ -35,37 +37,28 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.hadoop.util.VersionInfo +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.sql.SQLConf.SQLConfEntry -import org.apache.spark.sql.SQLConf.SQLConfEntry._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} -import org.apache.spark.sql.execution.datasources.{ResolveDataSource, DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.command.{ExecutedCommand, SetCommand} import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.execution.{CacheManager, ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkContext} - - -/** - * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext - */ -private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect { - override def parse(sqlText: String): LogicalPlan = { - sqlContext.executionHive.withHiveState { - HiveQl.parseSql(sqlText) - } - } -} /** * Returns the current database of metadataHive. @@ -76,7 +69,7 @@ private[hive] case class CurrentDatabase(ctx: HiveContext) override def foldable: Boolean = true override def nullable: Boolean = false override def eval(input: InternalRow): Any = { - UTF8String.fromString(ctx.metadataHive.currentDatabase) + UTF8String.fromString(ctx.sessionState.catalog.getCurrentDatabase) } } @@ -90,15 +83,31 @@ class HiveContext private[hive]( sc: SparkContext, cacheManager: CacheManager, listener: SQLListener, - @transient private val execHive: ClientWrapper, - @transient private val metaHive: ClientInterface, - isRootContext: Boolean) - extends SQLContext(sc, cacheManager, listener, isRootContext) with Logging { + @transient private[hive] val executionHive: HiveClientImpl, + @transient private[hive] val metadataHive: HiveClient, + isRootContext: Boolean, + @transient private[sql] val hiveCatalog: HiveExternalCatalog) + extends SQLContext(sc, cacheManager, listener, isRootContext, hiveCatalog) with Logging { self => + private def this(sc: SparkContext, execHive: HiveClientImpl, metaHive: HiveClient) { + this( + sc, + new CacheManager, + SQLContext.createListenerAndUI(sc), + execHive, + metaHive, + true, + new HiveExternalCatalog(metaHive)) + } + def this(sc: SparkContext) = { - this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), null, null, true) + this( + sc, + HiveContext.newClientForExecution(sc.conf, sc.hadoopConfiguration), + HiveContext.newClientForMetadata(sc.conf, sc.hadoopConfiguration)) } + def this(sc: JavaSparkContext) = this(sc.sc) import org.apache.spark.sql.hive.HiveContext._ @@ -115,11 +124,20 @@ class HiveContext private[hive]( sc = sc, cacheManager = cacheManager, listener = listener, - execHive = executionHive.newSession(), - metaHive = metadataHive.newSession(), - isRootContext = false) + executionHive = executionHive.newSession(), + metadataHive = metadataHive.newSession(), + isRootContext = false, + hiveCatalog = hiveCatalog) } + @transient + protected[sql] override lazy val sessionState = new HiveSessionState(self) + + // The Hive UDF current_database() is foldable, will be evaluated by optimizer, + // but the optimizer can't access the SessionState of metadataHive. + sessionState.functionRegistry.registerFunction( + "current_database", (e: Seq[Expression]) => new CurrentDatabase(self)) + /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive @@ -136,6 +154,13 @@ class HiveContext private[hive]( protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) + /** + * When true, enables an experimental feature where metastore tables that use the Orc SerDe + * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive + * SerDe. + */ + protected[sql] def convertMetastoreOrc: Boolean = getConf(CONVERT_METASTORE_ORC) + /** * When true, a table created by a Hive CTAS statement (no USING clause) will be * converted to a data source table, using the data source set by spark.sql.sources.default. @@ -150,69 +175,16 @@ class HiveContext private[hive]( */ protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS) - /** - * The version of the hive client that will be used to communicate with the metastore. Note that - * this does not necessarily need to be the same version of Hive that is used internally by - * Spark SQL for execution. - */ - protected[hive] def hiveMetastoreVersion: String = getConf(HIVE_METASTORE_VERSION) - - /** - * The location of the jars that should be used to instantiate the HiveMetastoreClient. This - * property can be one of three options: - * - a classpath in the standard format for both hive and hadoop. - * - builtin - attempt to discover the jars that were used to load Spark SQL and use those. This - * option is only valid when using the execution version of Hive. - * - maven - download the correct version of hive on demand from maven. - */ - protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS) - - /** - * A comma separated list of class prefixes that should be loaded using the classloader that - * is shared between Spark SQL and a specific version of Hive. An example of classes that should - * be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need - * to be shared are those that interact with classes that are already shared. For example, - * custom appenders that are used by log4j. - */ - protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] = - getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "") - - /** - * A comma separated list of class prefixes that should explicitly be reloaded for each version - * of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a - * prefix that typically would be shared (i.e. org.apache.spark.*) - */ - protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] = - getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "") - /* * hive thrift server use background spark sql thread pool to execute sql queries */ protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) - @transient - protected[sql] lazy val substitutor = new VariableSubstitution() + protected[hive] def hiveThriftServerSingleSession: Boolean = + sc.conf.get("spark.sql.hive.thriftServer.singleSession", "false").toBoolean - /** - * The copy of the hive client that is used for execution. Currently this must always be - * Hive 13 as this is the version of Hive that is packaged with Spark SQL. This copy of the - * client is used for execution related tasks like registering temporary functions or ensuring - * that the ThreadLocal SessionState is correctly populated. This copy of Hive is *not* used - * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. - */ @transient - protected[hive] lazy val executionHive: ClientWrapper = if (execHive != null) { - execHive - } else { - logInfo(s"Initializing execution hive, version $hiveExecutionVersion") - val loader = new IsolatedClientLoader( - version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), - execJars = Seq(), - config = newTemporaryConfiguration(), - isolationOn = false, - baseClassLoader = Utils.getContextOrSparkClassLoader) - loader.createClient().asInstanceOf[ClientWrapper] - } + protected[sql] lazy val substitutor = new VariableSubstitution() /** * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. @@ -224,106 +196,10 @@ class HiveContext private[hive]( defaultOverrides() - /** - * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. - * The version of the Hive client that is used here must match the metastore that is configured - * in the hive-site.xml file. - */ - @transient - protected[hive] lazy val metadataHive: ClientInterface = if (metaHive != null) { - metaHive - } else { - val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) - - // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options - // into the isolated client loader - val metadataConf = new HiveConf() - - val defaultWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") - logInfo("default warehouse location is " + defaultWarehouseLocation) - - // `configure` goes second to override other settings. - val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure - - val isolatedLoader = if (hiveMetastoreJars == "builtin") { - if (hiveExecutionVersion != hiveMetastoreVersion) { - throw new IllegalArgumentException( - "Builtin jars can only be used when hive execution version == hive metastore version. " + - s"Execution: ${hiveExecutionVersion} != Metastore: ${hiveMetastoreVersion}. " + - "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") - } - - // We recursively find all jars in the class loader chain, - // starting from the given classLoader. - def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { - case null => Array.empty[URL] - case urlClassLoader: URLClassLoader => - urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) - case other => allJars(other.getParent) - } - - val classLoader = Utils.getContextOrSparkClassLoader - val jars = allJars(classLoader) - if (jars.length == 0) { - throw new IllegalArgumentException( - "Unable to locate hive jars to connect to metastore. " + - "Please set spark.sql.hive.metastore.jars.") - } - - logInfo( - s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using Spark classes.") - new IsolatedClientLoader( - version = metaVersion, - execJars = jars.toSeq, - config = allConfig, - isolationOn = true, - barrierPrefixes = hiveMetastoreBarrierPrefixes, - sharedPrefixes = hiveMetastoreSharedPrefixes) - } else if (hiveMetastoreJars == "maven") { - // TODO: Support for loading the jars from an already downloaded location. - logInfo( - s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion( - version = hiveMetastoreVersion, - config = allConfig, - barrierPrefixes = hiveMetastoreBarrierPrefixes, - sharedPrefixes = hiveMetastoreSharedPrefixes) - } else { - // Convert to files and expand any directories. - val jars = - hiveMetastoreJars - .split(File.pathSeparator) - .flatMap { - case path if new File(path).getName() == "*" => - val files = new File(path).getParentFile().listFiles() - if (files == null) { - logWarning(s"Hive jar path '$path' does not exist.") - Nil - } else { - files.filter(_.getName().toLowerCase().endsWith(".jar")) - } - case path => - new File(path) :: Nil - } - .map(_.toURI.toURL) - - logInfo( - s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion " + - s"using ${jars.mkString(":")}") - new IsolatedClientLoader( - version = metaVersion, - execJars = jars.toSeq, - config = allConfig, - isolationOn = true, - barrierPrefixes = hiveMetastoreBarrierPrefixes, - sharedPrefixes = hiveMetastoreSharedPrefixes) - } - isolatedLoader.createClient() - } - protected[sql] override def parseSql(sql: String): LogicalPlan = { - super.parseSql(substitutor.substitute(hiveconf, sql)) + executionHive.withHiveState { + super.parseSql(substitutor.substitute(hiveconf, sql)) + } } override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = @@ -338,13 +214,13 @@ class HiveContext private[hive]( * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = SqlParser.parseTableIdentifier(tableName) - catalog.refreshTable(tableIdent) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) + sessionState.catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = SqlParser.parseTableIdentifier(tableName) - catalog.invalidateTable(tableIdent) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) + sessionState.catalog.invalidateTable(tableIdent) } /** @@ -357,8 +233,8 @@ class HiveContext private[hive]( * @since 1.2.0 */ def analyze(tableName: String) { - val tableIdent = SqlParser.parseTableIdentifier(tableName) - val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent)) + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) relation match { case relation: MetastoreRelation => @@ -375,7 +251,7 @@ class HiveContext private[hive]( def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDir) { + val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath().getName().startsWith(stagingDir)) { @@ -419,7 +295,7 @@ class HiveContext private[hive]( // recorded in the Hive metastore. // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). if (newTotalSize > 0 && newTotalSize != oldTotalSize) { - catalog.client.alterTable( + sessionState.catalog.alterTable( relation.table.copy( properties = relation.table.properties + (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) @@ -442,102 +318,14 @@ class HiveContext private[hive]( hiveconf.set(key, value) } - override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + override private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = { setConf(entry.key, entry.stringConverter(value)) } - /* A catalyst metadata catalog that points to the Hive Metastore. */ - @transient - override protected[sql] lazy val catalog = - new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog - - // Note that HiveUDFs will be overridden by functions registered in this context. - @transient - override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) - - // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer - // can't access the SessionState of metadataHive. - functionRegistry.registerFunction( - "current_database", - (expressions: Seq[Expression]) => new CurrentDatabase(this)) - - /* An analyzer that uses the Hive metastore. */ - @transient - override protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, functionRegistry, conf) { - override val extendedResolutionRules = - catalog.ParquetConversions :: - catalog.CreateTables :: - catalog.PreInsertionCasts :: - ExtractPythonUDFs :: - ResolveHiveWindowFunction :: - PreInsertCastAndRename :: - (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) - - override val extendedCheckRules = Seq( - PreWriteCheck(catalog) - ) - } - - /** Overridden by child classes that need to set configuration before the client init. */ - protected def configure(): Map[String, String] = { - // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch - // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- - // compatibility when users are trying to connecting to a Hive metastore of lower version, - // because these options are expected to be integral values in lower versions of Hive. - // - // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according - // to their output time units. - Seq( - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, - ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, - ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, - ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, - ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, - ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, - ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, - ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, - ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, - ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, - ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, - ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, - ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, - ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, - ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS - ).map { case (confVar, unit) => - confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString - }.toMap - } - /** * SQLConf and HiveConf contracts: * - * 1. create a new SessionState for each HiveContext + * 1. create a new o.a.h.hive.ql.session.SessionState for each HiveContext * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be * set in the SQLConf *as well as* in the HiveConf. @@ -549,44 +337,6 @@ class HiveContext private[hive]( c } - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - } - - protected[sql] override def getSQLDialect(): ParserDialect = { - if (conf.dialect == "hiveql") { - new HiveQLDialect(this) - } else { - super.getSQLDialect() - } - } - - @transient - private val hivePlanner = new SparkPlanner with HiveStrategies { - val hiveContext = self - - override def strategies: Seq[Strategy] = experimental.extraStrategies ++ Seq( - DataSourceStrategy, - HiveCommandStrategy(self), - HiveDDLStrategy, - DDLStrategy, - TakeOrderedAndProject, - InMemoryScans, - HiveTableScans, - DataSinks, - Scripts, - HashAggregation, - Aggregation, - LeftSemiJoin, - EquiJoinSelection, - BasicOperators, - BroadcastNestedLoop, - CartesianProduct, - DefaultJoin - ) - } - private def functionOrMacroDDLPattern(command: String) = Pattern.compile( ".*(create|drop)\\s+(temporary\\s+)?(function|macro).+", Pattern.DOTALL).matcher(command) @@ -602,8 +352,14 @@ class HiveContext private[hive]( } } - @transient - override protected[sql] val planner = hivePlanner + /** + * Executes a SQL query without parsing it, but instead passing it directly to Hive. + * This is currently only used for DDLs and will be removed as soon as Spark can parse + * all supported Hive DDLs itself. + */ + protected[sql] override def runNativeSql(sqlText: String): Seq[Row] = { + runSqlHive(sqlText).map { s => Row(s) } + } /** Extends QueryExecution with hive specific features. */ protected[sql] class QueryExecution(logicalPlan: LogicalPlan) @@ -653,71 +409,329 @@ class HiveContext private[hive]( } -private[hive] object HiveContext { +private[hive] object HiveContext extends Logging { /** The version of hive used internally by Spark SQL. */ val hiveExecutionVersion: String = "1.2.1" - val HIVE_METASTORE_VERSION = stringConf("spark.sql.hive.metastore.version", - defaultValue = Some(hiveExecutionVersion), - doc = "Version of the Hive metastore. Available options are " + + val HIVE_METASTORE_VERSION = SQLConfigBuilder("spark.sql.hive.metastore.version") + .doc("Version of the Hive metastore. Available options are " + s"0.12.0 through $hiveExecutionVersion.") + .stringConf + .createWithDefault(hiveExecutionVersion) - val HIVE_EXECUTION_VERSION = stringConf( - key = "spark.sql.hive.version", - defaultValue = Some(hiveExecutionVersion), - doc = "Version of Hive used internally by Spark SQL.") + val HIVE_EXECUTION_VERSION = SQLConfigBuilder("spark.sql.hive.version") + .doc("Version of Hive used internally by Spark SQL.") + .stringConf + .createWithDefault(hiveExecutionVersion) - val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", - defaultValue = Some("builtin"), - doc = s""" + val HIVE_METASTORE_JARS = SQLConfigBuilder("spark.sql.hive.metastore.jars") + .doc(s""" | Location of the jars that should be used to instantiate the HiveMetastoreClient. | This property can be one of three options: " | 1. "builtin" - | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly jar when + | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly when | -Phive is enabled. When this option is chosen, | spark.sql.hive.metastore.version must be either | ${hiveExecutionVersion} or not defined. | 2. "maven" | Use Hive jars of specified version downloaded from Maven repositories. | 3. A classpath in the standard format for both Hive and Hadoop. - """.stripMargin) - val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", - defaultValue = Some(true), - doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + - "the built in support.") - - val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf( - "spark.sql.hive.convertMetastoreParquet.mergeSchema", - defaultValue = Some(false), - doc = "TODO") + """.stripMargin) + .stringConf + .createWithDefault("builtin") - val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", - defaultValue = Some(false), - doc = "TODO") + val CONVERT_METASTORE_PARQUET = SQLConfigBuilder("spark.sql.hive.convertMetastoreParquet") + .doc("When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + + "the built in support.") + .booleanConf + .createWithDefault(true) + + val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = + SQLConfigBuilder("spark.sql.hive.convertMetastoreParquet.mergeSchema") + .doc("When true, also tries to merge possibly different but compatible Parquet schemas in " + + "different Parquet data files. This configuration is only effective " + + "when \"spark.sql.hive.convertMetastoreParquet\" is true.") + .booleanConf + .createWithDefault(false) + + val CONVERT_CTAS = SQLConfigBuilder("spark.sql.hive.convertCTAS") + .doc("When true, a table created by a Hive CTAS statement (no USING clause) will be " + + "converted to a data source table, using the data source set by spark.sql.sources.default.") + .booleanConf + .createWithDefault(false) + + val CONVERT_METASTORE_ORC = SQLConfigBuilder("spark.sql.hive.convertMetastoreOrc") + .doc("When set to false, Spark SQL will use the Hive SerDe for ORC tables instead of " + + "the built in support.") + .booleanConf + .createWithDefault(true) - val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes", - defaultValue = Some(jdbcPrefixes), - doc = "A comma separated list of class prefixes that should be loaded using the classloader " + + val HIVE_METASTORE_SHARED_PREFIXES = SQLConfigBuilder("spark.sql.hive.metastore.sharedPrefixes") + .doc("A comma separated list of class prefixes that should be loaded using the classloader " + "that is shared between Spark SQL and a specific version of Hive. An example of classes " + "that should be shared is JDBC drivers that are needed to talk to the metastore. Other " + "classes that need to be shared are those that interact with classes that are already " + "shared. For example, custom appenders that are used by log4j.") + .stringConf + .toSequence + .createWithDefault(jdbcPrefixes) private def jdbcPrefixes = Seq( "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc") - val HIVE_METASTORE_BARRIER_PREFIXES = stringSeqConf("spark.sql.hive.metastore.barrierPrefixes", - defaultValue = Some(Seq()), - doc = "A comma separated list of class prefixes that should explicitly be reloaded for each " + + val HIVE_METASTORE_BARRIER_PREFIXES = SQLConfigBuilder("spark.sql.hive.metastore.barrierPrefixes") + .doc("A comma separated list of class prefixes that should explicitly be reloaded for each " + "version of Hive that Spark SQL is communicating with. For example, Hive UDFs that are " + "declared in a prefix that typically would be shared (i.e. org.apache.spark.*).") + .stringConf + .toSequence + .createWithDefault(Nil) + + val HIVE_THRIFT_SERVER_ASYNC = SQLConfigBuilder("spark.sql.hive.thriftServer.async") + .doc("When set to true, Hive Thrift server executes SQL queries in an asynchronous way.") + .booleanConf + .createWithDefault(true) - val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async", - defaultValue = Some(true), - doc = "TODO") + /** + * The version of the hive client that will be used to communicate with the metastore. Note that + * this does not necessarily need to be the same version of Hive that is used internally by + * Spark SQL for execution. + */ + private def hiveMetastoreVersion(conf: SQLConf): String = { + conf.getConf(HIVE_METASTORE_VERSION) + } + + /** + * The location of the jars that should be used to instantiate the HiveMetastoreClient. This + * property can be one of three options: + * - a classpath in the standard format for both hive and hadoop. + * - builtin - attempt to discover the jars that were used to load Spark SQL and use those. This + * option is only valid when using the execution version of Hive. + * - maven - download the correct version of hive on demand from maven. + */ + private def hiveMetastoreJars(conf: SQLConf): String = { + conf.getConf(HIVE_METASTORE_JARS) + } + + /** + * A comma separated list of class prefixes that should be loaded using the classloader that + * is shared between Spark SQL and a specific version of Hive. An example of classes that should + * be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + * to be shared are those that interact with classes that are already shared. For example, + * custom appenders that are used by log4j. + */ + private def hiveMetastoreSharedPrefixes(conf: SQLConf): Seq[String] = { + conf.getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "") + } + + /** + * A comma separated list of class prefixes that should explicitly be reloaded for each version + * of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + * prefix that typically would be shared (i.e. org.apache.spark.*) + */ + private def hiveMetastoreBarrierPrefixes(conf: SQLConf): Seq[String] = { + conf.getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "") + } + + /** + * Configurations needed to create a [[HiveClient]]. + */ + private[hive] def hiveClientConfigurations(hiveconf: HiveConf): Map[String, String] = { + // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch + // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- + // compatibility when users are trying to connecting to a Hive metastore of lower version, + // because these options are expected to be integral values in lower versions of Hive. + // + // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according + // to their output time units. + Seq( + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, + ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, + ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS + ).map { case (confVar, unit) => + confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString + }.toMap + } + + /** + * Create a [[HiveClient]] used for execution. + * + * Currently this must always be Hive 13 as this is the version of Hive that is packaged + * with Spark SQL. This copy of the client is used for execution related tasks like + * registering temporary functions or ensuring that the ThreadLocal SessionState is + * correctly populated. This copy of Hive is *not* used for storing persistent metadata, + * and only point to a dummy metastore in a temporary directory. + */ + protected[hive] def newClientForExecution( + conf: SparkConf, + hadoopConf: Configuration): HiveClientImpl = { + logInfo(s"Initializing execution hive, version $hiveExecutionVersion") + val loader = new IsolatedClientLoader( + version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), + sparkConf = conf, + execJars = Seq(), + hadoopConf = hadoopConf, + config = newTemporaryConfiguration(useInMemoryDerby = true), + isolationOn = false, + baseClassLoader = Utils.getContextOrSparkClassLoader) + loader.createClient().asInstanceOf[HiveClientImpl] + } + + /** + * Create a [[HiveClient]] used to retrieve metadata from the Hive MetaStore. + * + * The version of the Hive client that is used here must match the metastore that is configured + * in the hive-site.xml file. + */ + private def newClientForMetadata(conf: SparkConf, hadoopConf: Configuration): HiveClient = { + val hiveConf = new HiveConf(hadoopConf, classOf[HiveConf]) + val configurations = hiveClientConfigurations(hiveConf) + newClientForMetadata(conf, hiveConf, hadoopConf, configurations) + } + + protected[hive] def newClientForMetadata( + conf: SparkConf, + hiveConf: HiveConf, + hadoopConf: Configuration, + configurations: Map[String, String]): HiveClient = { + val sqlConf = new SQLConf + sqlConf.setConf(SQLContext.getSQLProperties(conf)) + val hiveMetastoreVersion = HiveContext.hiveMetastoreVersion(sqlConf) + val hiveMetastoreJars = HiveContext.hiveMetastoreJars(sqlConf) + val hiveMetastoreSharedPrefixes = HiveContext.hiveMetastoreSharedPrefixes(sqlConf) + val hiveMetastoreBarrierPrefixes = HiveContext.hiveMetastoreBarrierPrefixes(sqlConf) + val metaVersion = IsolatedClientLoader.hiveVersion(hiveMetastoreVersion) + + val defaultWarehouseLocation = hiveConf.get("hive.metastore.warehouse.dir") + logInfo("default warehouse location is " + defaultWarehouseLocation) + + // `configure` goes second to override other settings. + val allConfig = hiveConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configurations + + val isolatedLoader = if (hiveMetastoreJars == "builtin") { + if (hiveExecutionVersion != hiveMetastoreVersion) { + throw new IllegalArgumentException( + "Builtin jars can only be used when hive execution version == hive metastore version. " + + s"Execution: $hiveExecutionVersion != Metastore: $hiveMetastoreVersion. " + + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") + } + + // We recursively find all jars in the class loader chain, + // starting from the given classLoader. + def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { + case null => Array.empty[URL] + case urlClassLoader: URLClassLoader => + urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) + case other => allJars(other.getParent) + } + + val classLoader = Utils.getContextOrSparkClassLoader + val jars = allJars(classLoader) + if (jars.length == 0) { + throw new IllegalArgumentException( + "Unable to locate hive jars to connect to metastore. " + + "Please set spark.sql.hive.metastore.jars.") + } + + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using Spark classes.") + new IsolatedClientLoader( + version = metaVersion, + sparkConf = conf, + hadoopConf = hadoopConf, + execJars = jars.toSeq, + config = allConfig, + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } else if (hiveMetastoreJars == "maven") { + // TODO: Support for loading the jars from an already downloaded location. + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = hiveMetastoreVersion, + hadoopVersion = VersionInfo.getVersion, + sparkConf = conf, + hadoopConf = hadoopConf, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } else { + // Convert to files and expand any directories. + val jars = + hiveMetastoreJars + .split(File.pathSeparator) + .flatMap { + case path if new File(path).getName == "*" => + val files = new File(path).getParentFile.listFiles() + if (files == null) { + logWarning(s"Hive jar path '$path' does not exist.") + Nil + } else { + files.filter(_.getName.toLowerCase.endsWith(".jar")) + } + case path => + new File(path) :: Nil + } + .map(_.toURI.toURL) + + logInfo( + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion " + + s"using ${jars.mkString(":")}") + new IsolatedClientLoader( + version = metaVersion, + sparkConf = conf, + hadoopConf = hadoopConf, + execJars = jars.toSeq, + config = allConfig, + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) + } + isolatedLoader.createClient() + } /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ - def newTemporaryConfiguration(): Map[String, String] = { + def newTemporaryConfiguration(useInMemoryDerby: Boolean): Map[String, String] = { + val withInMemoryMode = if (useInMemoryDerby) "memory:" else "" + val tempDir = Utils.createTempDir() val localMetastore = new File(tempDir, "metastore") val propMap: HashMap[String, String] = HashMap() @@ -731,9 +745,24 @@ private[hive] object HiveContext { } propMap.put(HiveConf.ConfVars.METASTOREWAREHOUSE.varname, localMetastore.toURI.toString) propMap.put(HiveConf.ConfVars.METASTORECONNECTURLKEY.varname, - s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") + s"jdbc:derby:${withInMemoryMode};databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + + // SPARK-11783: When "hive.metastore.uris" is set, the metastore connection mode will be + // remote (https://cwiki.apache.org/confluence/display/Hive/AdminManual+MetastoreAdmin + // mentions that "If hive.metastore.uris is empty local mode is assumed, remote otherwise"). + // Remote means that the metastore server is running in its own process. + // When the mode is remote, configurations like "javax.jdo.option.ConnectionURL" will not be + // used (because they are used by remote metastore server that talks to the database). + // Because execution Hive should always connects to a embedded derby metastore. + // We have to remove the value of hive.metastore.uris. So, the execution Hive client connects + // to the actual embedded derby metastore instead of the remote metastore. + // You can search HiveConf.ConfVars.METASTOREURIS in the code of HiveConf (in Hive's repo). + // Then, you will find that the local metastore mode is only set to true when + // hive.metastore.uris is not set. + propMap.put(ConfVars.METASTOREURIS.varname, "") + propMap.toMap } @@ -756,7 +785,7 @@ private[hive] object HiveContext { case (null, _) => "NULL" case (d: Int, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString - case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") + case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString HiveDecimal.create(decimal).toString diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala new file mode 100644 index 0000000000000..f627384253aa9 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.util.control.NonFatal + +import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.thrift.TException + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.client.HiveClient + + +/** + * A persistent implementation of the system catalog using Hive. + * All public methods must be synchronized for thread-safety. + */ +private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCatalog with Logging { + import ExternalCatalog._ + + // Exceptions thrown by the hive client that we would like to wrap + private val clientExceptions = Set( + classOf[HiveException].getCanonicalName, + classOf[TException].getCanonicalName) + + /** + * Whether this is an exception thrown by the hive client that should be wrapped. + * + * Due to classloader isolation issues, pattern matching won't work here so we need + * to compare the canonical names of the exceptions, which we assume to be stable. + */ + private def isClientException(e: Throwable): Boolean = { + var temp: Class[_] = e.getClass + var found = false + while (temp != null && !found) { + found = clientExceptions.contains(temp.getCanonicalName) + temp = temp.getSuperclass + } + found + } + + /** + * Run some code involving `client` in a [[synchronized]] block and wrap certain + * exceptions thrown in the process in [[AnalysisException]]. + */ + private def withClient[T](body: => T): T = synchronized { + try { + body + } catch { + case NonFatal(e) if isClientException(e) => + throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage) + } + } + + private def requireDbMatches(db: String, table: CatalogTable): Unit = { + if (table.identifier.database != Some(db)) { + throw new AnalysisException( + s"Provided database $db does not match the one specified in the " + + s"table definition (${table.identifier.database.getOrElse("n/a")})") + } + } + + private def requireTableExists(db: String, table: String): Unit = { + withClient { getTable(db, table) } + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase( + dbDefinition: CatalogDatabase, + ignoreIfExists: Boolean): Unit = withClient { + client.createDatabase(dbDefinition, ignoreIfExists) + } + + override def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = withClient { + client.dropDatabase(db, ignoreIfNotExists, cascade) + } + + /** + * Alter a database whose name matches the one specified in `dbDefinition`, + * assuming the database exists. + * + * Note: As of now, this only supports altering database properties! + */ + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + val existingDb = getDatabase(dbDefinition.name) + if (existingDb.properties == dbDefinition.properties) { + logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + + s"the provided database properties are the same as the old ones. Hive does not " + + s"currently support altering other database fields.") + } + client.alterDatabase(dbDefinition) + } + + override def getDatabase(db: String): CatalogDatabase = withClient { + client.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = withClient { + client.getDatabaseOption(db).isDefined + } + + override def listDatabases(): Seq[String] = withClient { + client.listDatabases("*") + } + + override def listDatabases(pattern: String): Seq[String] = withClient { + client.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = withClient { + client.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable( + db: String, + tableDefinition: CatalogTable, + ignoreIfExists: Boolean): Unit = withClient { + requireDbExists(db) + requireDbMatches(db, tableDefinition) + client.createTable(tableDefinition, ignoreIfExists) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean): Unit = withClient { + requireDbExists(db) + client.dropTable(db, table, ignoreIfNotExists) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { + val newTable = client.getTable(db, oldName) + .copy(identifier = TableIdentifier(newName, Some(db))) + client.alterTable(oldName, newTable) + } + + /** + * Alter a table whose name that matches the one specified in `tableDefinition`, + * assuming the table exists. + * + * Note: As of now, this only supports altering table properties, serde properties, + * and num buckets! + */ + override def alterTable(db: String, tableDefinition: CatalogTable): Unit = withClient { + requireDbMatches(db, tableDefinition) + requireTableExists(db, tableDefinition.identifier.table) + client.alterTable(tableDefinition) + } + + override def getTable(db: String, table: String): CatalogTable = withClient { + client.getTable(db, table) + } + + override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient { + client.getTableOption(db, table) + } + + override def tableExists(db: String, table: String): Boolean = withClient { + client.getTableOption(db, table).isDefined + } + + override def listTables(db: String): Seq[String] = withClient { + requireDbExists(db) + client.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = withClient { + requireDbExists(db) + client.listTables(db, pattern) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = withClient { + requireTableExists(db, table) + client.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + parts: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = withClient { + requireTableExists(db, table) + client.dropPartitions(db, table, parts, ignoreIfNotExists) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = withClient { + client.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + newParts: Seq[CatalogTablePartition]): Unit = withClient { + client.alterPartitions(db, table, newParts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = withClient { + client.getPartition(db, table, spec) + } + + override def listPartitions( + db: String, + table: String): Seq[CatalogTablePartition] = withClient { + client.getAllPartitions(db, table) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction( + db: String, + funcDefinition: CatalogFunction): Unit = withClient { + // Hive's metastore is case insensitive. However, Hive's createFunction does + // not normalize the function name (unlike the getFunction part). So, + // we are normalizing the function name. + val functionName = funcDefinition.identifier.funcName.toLowerCase + val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) + client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) + } + + override def dropFunction(db: String, name: String): Unit = withClient { + client.dropFunction(db, name) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + client.renameFunction(db, oldName, newName) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = withClient { + client.getFunction(db, funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = withClient { + client.functionExists(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = withClient { + client.listFunctions(db, pattern) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 36f0708f9da3d..585befe37825c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -19,18 +19,19 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.{io => hadoopIo} +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.{io => hadoopIo} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, types} import org.apache.spark.unsafe.types.UTF8String /** @@ -61,6 +62,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Type * Java Boxed Primitives: * org.apache.hadoop.hive.common.type.HiveVarchar + * org.apache.hadoop.hive.common.type.HiveChar * java.lang.String * java.lang.Integer * java.lang.Boolean @@ -75,6 +77,7 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Timestamp * Writables: * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable + * org.apache.hadoop.hive.serde2.io.HiveCharWritable * org.apache.hadoop.io.Text * org.apache.hadoop.io.IntWritable * org.apache.hadoop.hive.serde2.io.DoubleWritable @@ -93,7 +96,8 @@ import org.apache.spark.unsafe.types.UTF8String * Struct: Object[] / java.util.List / java POJO * Union: class StandardUnion { byte tag; Object object } * - * NOTICE: HiveVarchar is not supported by catalyst, it will be simply considered as String type. + * NOTICE: HiveVarchar/HiveChar is not supported by catalyst, it will be simply considered as + * String type. * * * 2. Hive ObjectInspector is a group of flexible APIs to inspect value in different data @@ -137,6 +141,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector + * WritableConstantHiveCharObjectInspector * WritableConstantHiveDecimalObjectInspector * WritableConstantTimestampObjectInspector * WritableConstantIntObjectInspector @@ -259,6 +264,8 @@ private[hive] trait HiveInspectors { UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) + case poi: WritableConstantHiveCharObjectInspector => + UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -303,13 +310,20 @@ private[hive] trait HiveInspectors { case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { - // We think HiveVarchar is also a String + // We think HiveVarchar/HiveChar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) + case hvoi: HiveCharObjectInspector if hvoi.preferWritable() => + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue) + case hvoi: HiveCharObjectInspector => + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) + // Text is in UTF-8 already. No need to convert again via fromString. Copy bytes + val wObj = x.getPrimitiveWritableObject(data) + val result = wObj.copyBytes() + UTF8String.fromBytes(result, 0, result.length) case x: StringObjectInspector => UTF8String.fromString(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) @@ -372,7 +386,16 @@ private[hive] trait HiveInspectors { (o: Any) => if (o != null) { val s = o.asInstanceOf[UTF8String].toString - new HiveVarchar(s, s.size) + new HiveVarchar(s, s.length) + } else { + null + } + + case _: JavaHiveCharObjectInspector => + (o: Any) => + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.length) } else { null } @@ -427,9 +450,7 @@ private[hive] trait HiveInspectors { if (o != null) { val array = o.asInstanceOf[ArrayData] val values = new java.util.ArrayList[Any](array.numElements()) - array.foreach(elementType, (_, e) => { - values.add(wrapper(e)) - }) + array.foreach(elementType, (_, e) => values.add(wrapper(e))) values } else { null @@ -445,9 +466,8 @@ private[hive] trait HiveInspectors { if (o != null) { val map = o.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(mt.keyType, mt.valueType, (k, v) => { - jmap.put(keyWrapper(k), valueWrapper(v)) - }) + map.foreach(mt.keyType, mt.valueType, (k, v) => + jmap.put(keyWrapper(k), valueWrapper(v))) jmap } else { null @@ -564,9 +584,9 @@ private[hive] trait HiveInspectors { case x: ListObjectInspector => val list = new java.util.ArrayList[Object] val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => { + a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => list.add(wrap(e, x.getListElementObjectInspector, tpe)) - }) + ) list case x: MapObjectInspector => val keyType = dataType.asInstanceOf[MapType].keyType @@ -576,10 +596,10 @@ private[hive] trait HiveInspectors { // Some UDFs seem to assume we pass in a HashMap. val hashMap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(keyType, valueType, (k, v) => { + map.foreach(keyType, valueType, (k, v) => hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType), wrap(v, x.getMapValueObjectInspector, valueType)) - }) + ) hashMap } @@ -681,9 +701,8 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[ArrayData].foreach(dt, (_, e) => { - list.add(wrap(e, listObjectInspector, dt)) - }) + value.asInstanceOf[ArrayData].foreach(dt, (_, e) => + list.add(wrap(e, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -695,9 +714,8 @@ private[hive] trait HiveInspectors { val map = value.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(keyType, valueType, (k, v) => { - jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType)) - }) + map.foreach(keyType, valueType, (k, v) => + jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType))) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, jmap) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f4d45714fae4e..ccc8345d7375d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -22,30 +22,32 @@ import scala.collection.mutable import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.Warehouse +import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.ql.metadata._ +import org.apache.hadoop.hive.ql.metadata.{Table => HiveTable, _} import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.DataTypeParser import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} -import org.apache.spark.sql.execution.{FileRelation, datasources} +import org.apache.spark.sql.execution.FileRelation +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetDefaultSource, ParquetRelation} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.hive.orc.{DefaultSource => OrcDefaultSource} +import org.apache.spark.sql.sources.{FileFormat, HadoopFsRelation, HDFSFileCatalog} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} private[hive] case class HiveSerDe( inputFormat: Option[String] = None, @@ -84,7 +86,18 @@ private[hive] object HiveSerDe { HiveSerDe( inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")), + + "textfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + + "avro" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe"))) val key = source.toLowerCase match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" @@ -96,27 +109,35 @@ private[hive] object HiveSerDe { } } -private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) - extends Catalog with Logging { - val conf = hive.conf +/** + * Legacy catalog for interacting with the Hive metastore. + * + * This is still used for things like creating data source tables, but in the future will be + * cleaned up to integrate more nicely with [[HiveExternalCatalog]]. + */ +private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) + extends Logging { - /** Usages should lock on `this`. */ - protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) + val conf = hive.conf /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) - private def getQualifiedTableName(tableIdent: TableIdentifier) = { + private def getCurrentDatabase: String = { + hive.sessionState.catalog.getCurrentDatabase + } + + def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { QualifiedTableName( - tableIdent.database.getOrElse(client.currentDatabase).toLowerCase, + tableIdent.database.getOrElse(getCurrentDatabase).toLowerCase, tableIdent.table.toLowerCase) } - private def getQualifiedTableName(hiveTable: HiveTable) = { + private def getQualifiedTableName(t: CatalogTable): QualifiedTableName = { QualifiedTableName( - hiveTable.specifiedDatabase.getOrElse(client.currentDatabase).toLowerCase, - hiveTable.name.toLowerCase) + t.identifier.database.getOrElse(getCurrentDatabase).toLowerCase, + t.identifier.table.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -143,19 +164,15 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } - def partColsFromParts: Option[Seq[String]] = { - table.properties.get("spark.sql.sources.schema.numPartCols").map { numPartCols => - (0 until numPartCols.toInt).map { index => - val partCol = table.properties.get(s"spark.sql.sources.schema.partCol.$index").orNull - if (partCol == null) { + def getColumnNames(colType: String): Seq[String] = { + table.properties.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map { + numCols => (0 until numCols.toInt).map { index => + table.properties.getOrElse(s"spark.sql.sources.schema.${colType}Col.$index", throw new AnalysisException( - "Could not read partitioned columns from the metastore because it is corrupted " + - s"(missing part $index of the it, $numPartCols parts are expected).") - } - - partCol + s"Could not read $colType columns from the metastore because it is corrupted " + + s"(missing part $index of it, $numCols parts are expected).")) } - } + }.getOrElse(Nil) } // Originally, we used spark.sql.sources.schema to store the schema of a data source table. @@ -170,28 +187,32 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // We only need names at here since userSpecifiedSchema we loaded from the metastore // contains partition columns. We can always get datatypes of partitioning columns // from userSpecifiedSchema. - val partitionColumns = partColsFromParts.getOrElse(Nil) + val partitionColumns = getColumnNames("part") - // It does not appear that the ql client for the metastore has a way to enumerate all the - // SerDe properties directly... - val options = table.serdeProperties + val bucketSpec = table.properties.get("spark.sql.sources.schema.numBuckets").map { n => + BucketSpec(n.toInt, getColumnNames("bucket"), getColumnNames("sort")) + } - val resolvedRelation = - ResolvedDataSource( + val options = table.storage.serdeProperties + val dataSource = + DataSource( hive, - userSpecifiedSchema, - partitionColumns.toArray, - table.properties("spark.sql.sources.provider"), - options) - - LogicalRelation(resolvedRelation.relation) + userSpecifiedSchema = userSpecifiedSchema, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + className = table.properties("spark.sql.sources.provider"), + options = options) + + LogicalRelation( + dataSource.resolveRelation(), + metastoreTableIdentifier = Some(TableIdentifier(in.name, Some(in.database)))) } } CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - override def refreshTable(tableIdent: TableIdentifier): Unit = { + def refreshTable(tableIdent: TableIdentifier): Unit = { // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. // Since we also cache ParquetRelations converted from Hive Parquet tables and @@ -211,6 +232,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { @@ -240,6 +262,25 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } + if (userSpecifiedSchema.isDefined && bucketSpec.isDefined) { + val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get + + tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString) + tableProperties.put("spark.sql.sources.schema.numBucketCols", + bucketColumnNames.length.toString) + bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) => + tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol) + } + + if (sortColumnNames.nonEmpty) { + tableProperties.put("spark.sql.sources.schema.numSortCols", + sortColumnNames.length.toString) + sortColumnNames.zipWithIndex.foreach { case (sortCol, index) => + tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) + } + } + } + if (userSpecifiedSchema.isEmpty && partitionColumns.length > 0) { // The table does not have a specified schema, which means that the schema will be inferred // when we load the table. So, we are not expecting partition columns and we will discover @@ -252,78 +293,91 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val tableType = if (isExternal) { tableProperties.put("EXTERNAL", "TRUE") - ExternalTable + CatalogTableType.EXTERNAL_TABLE } else { tableProperties.put("EXTERNAL", "FALSE") - ManagedTable + CatalogTableType.MANAGED_TABLE } val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) - val dataSource = ResolvedDataSource( - hive, userSpecifiedSchema, partitionColumns, provider, options) - - def newSparkSQLSpecificMetastoreTable(): HiveTable = { - HiveTable( - specifiedDatabase = Option(dbName), - name = tblName, - schema = Nil, - partitionColumns = Nil, + val dataSource = + DataSource( + hive, + userSpecifiedSchema = userSpecifiedSchema, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + className = provider, + options = options) + + def newSparkSQLSpecificMetastoreTable(): CatalogTable = { + CatalogTable( + identifier = TableIdentifier(tblName, Option(dbName)), tableType = tableType, - properties = tableProperties.toMap, - serdeProperties = options) + schema = Nil, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + serdeProperties = options + ), + properties = tableProperties.toMap) } - def newHiveCompatibleMetastoreTable(relation: HadoopFsRelation, serde: HiveSerDe): HiveTable = { - def schemaToHiveColumn(schema: StructType): Seq[HiveColumn] = { - schema.map { field => - HiveColumn( - name = field.name, - hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), - comment = "") - } - } - + def newHiveCompatibleMetastoreTable( + relation: HadoopFsRelation, + serde: HiveSerDe): CatalogTable = { assert(partitionColumns.isEmpty) - assert(relation.partitionColumns.isEmpty) + assert(relation.partitionSchema.isEmpty) - HiveTable( - specifiedDatabase = Option(dbName), - name = tblName, - schema = schemaToHiveColumn(relation.schema), - partitionColumns = Nil, + CatalogTable( + identifier = TableIdentifier(tblName, Option(dbName)), tableType = tableType, + storage = CatalogStorageFormat( + locationUri = Some(relation.location.paths.map(_.toUri.toString).head), + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde, + serdeProperties = options + ), + schema = relation.schema.map { f => + CatalogColumn(f.name, HiveMetastoreTypes.toMetastoreType(f.dataType)) + }, properties = tableProperties.toMap, - serdeProperties = options, - location = Some(relation.paths.head), - viewText = None, // TODO We need to place the SQL string here. - inputFormat = serde.inputFormat, - outputFormat = serde.outputFormat, - serde = serde.serde) + viewText = None) // TODO: We need to place the SQL string here } // TODO: Support persisting partitioned data source relations in Hive compatible format val qualifiedTableName = tableIdent.quotedString - val (hiveCompatibleTable, logMessage) = (maybeSerDe, dataSource.relation) match { + val skipHiveMetadata = options.getOrElse("skipHiveMetadata", "false").toBoolean + val (hiveCompatibleTable, logMessage) = (maybeSerDe, dataSource.resolveRelation()) match { + case _ if skipHiveMetadata => + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) + case (Some(serde), relation: HadoopFsRelation) - if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + if relation.location.paths.length == 1 && relation.partitionSchema.isEmpty => val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) val message = s"Persisting data source relation $qualifiedTableName with a single input path " + - s"into Hive metastore in Hive compatible format. Input path: ${relation.paths.head}." + s"into Hive metastore in Hive compatible format. Input path: " + + s"${relation.location.paths.head}." (Some(hiveTable), message) - case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => + case (Some(serde), relation: HadoopFsRelation) if relation.partitionSchema.nonEmpty => val message = s"Persisting partitioned data source relation $qualifiedTableName into " + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + - "Input path(s): " + relation.paths.mkString("\n", "\n", "") + "Input path(s): " + relation.location.paths.mkString("\n", "\n", "") (None, message) case (Some(serde), relation: HadoopFsRelation) => val message = s"Persisting data source relation $qualifiedTableName with multiple input paths into " + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + - s"Input paths: " + relation.paths.mkString("\n", "\n", "") + s"Input paths: " + relation.location.paths.mkString("\n", "\n", "") (None, message) case (Some(serde), _) => @@ -348,7 +402,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // specific way. try { logInfo(message) - client.createTable(table) + client.createTable(table, ignoreIfExists = false) } catch { case throwable: Throwable => val warningMessage = @@ -356,28 +410,23 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive s"it into Hive metastore in Spark SQL specific format." logWarning(warningMessage, throwable) val sparkSqlSpecificTable = newSparkSQLSpecificMetastoreTable() - client.createTable(sparkSqlSpecificTable) + client.createTable(sparkSqlSpecificTable, ignoreIfExists = false) } case (None, message) => logWarning(message) val hiveTable = newSparkSQLSpecificMetastoreTable() - client.createTable(hiveTable) + client.createTable(hiveTable, ignoreIfExists = false) } } def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - new Path(new Path(client.getDatabase(dbName).location), tblName).toString - } - - override def tableExists(tableIdent: TableIdentifier): Boolean = { - val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - client.getTableOption(dbName, tblName).isDefined + new Path(new Path(client.getDatabase(dbName).locationUri), tblName).toString } - override def lookupRelation( + def lookupRelation( tableIdent: TableIdentifier, alias: Option[String]): LogicalPlan = { val qualifiedTableName = getQualifiedTableName(tableIdent) @@ -385,123 +434,191 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive if (table.properties.get("spark.sql.sources.provider").isDefined) { val dataSourceTable = cachedDataSourceTables(qualifiedTableName) - val tableWithQualifiers = Subquery(qualifiedTableName.name, dataSourceTable) + val qualifiedTable = SubqueryAlias(qualifiedTableName.name, dataSourceTable) // Then, if alias is specified, wrap the table with a Subquery using the alias. // Otherwise, wrap the table with a Subquery using the table name. - alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) - } else if (table.tableType == VirtualView) { + alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable) + } else if (table.tableType == CatalogTableType.VIRTUAL_VIEW) { val viewText = table.viewText.getOrElse(sys.error("Invalid view without text.")) alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." - case None => Subquery(table.name, HiveQl.createPlan(viewText)) - case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText)) + case None => SubqueryAlias(table.identifier.table, hive.parseSql(viewText)) + case Some(aliasText) => SubqueryAlias(aliasText, hive.parseSql(viewText)) } } else { - MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive) + MetastoreRelation( + qualifiedTableName.database, qualifiedTableName.name, alias)(table, client, hive) } } - private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { - val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) - val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging - - // NOTE: Instead of passing Metastore schema directly to `ParquetRelation`, we have to - // serialize the Metastore schema to JSON and pass it as a data source option because of the - // evil case insensitivity issue, which is reconciled within `ParquetRelation`. - val parquetOptions = Map( - ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) - val tableIdentifier = - QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) - - def getCached( - tableIdentifier: QualifiedTableName, - pathsInMetastore: Seq[String], - schemaInMetastore: StructType, - partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { - cachedDataSourceTables.getIfPresent(tableIdentifier) match { - case null => None // Cache miss - case logical @ LogicalRelation(parquetRelation: ParquetRelation, _) => - // If we have the same paths, same schema, and same partition spec, - // we will use the cached Parquet Relation. - val useCached = - parquetRelation.paths.toSet == pathsInMetastore.toSet && - logical.schema.sameType(metastoreSchema) && - parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[datasources.Partition]) + private def getCached( + tableIdentifier: QualifiedTableName, + metastoreRelation: MetastoreRelation, + schemaInMetastore: StructType, + expectedFileFormat: Class[_ <: FileFormat], + expectedBucketSpec: Option[BucketSpec], + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) => + val pathsInMetastore = metastoreRelation.table.storage.locationUri.toSeq + val cachedRelationFileFormatClass = relation.fileFormat.getClass + + expectedFileFormat match { + case `cachedRelationFileFormatClass` => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached relation. + val useCached = + relation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet && + logical.schema.sameType(schemaInMetastore) && + relation.bucketSpec == expectedBucketSpec && + relation.partitionSpec == partitionSpecInMetastore.getOrElse { + PartitionSpec(StructType(Nil), Array.empty[PartitionDirectory]) + } + + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. + cachedDataSourceTables.invalidate(tableIdentifier) + None } - - if (useCached) { - Some(logical) - } else { - // If the cached relation is not updated, we invalidate it right away. + case _ => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} " + + s"should be stored as $expectedFileFormat. However, we are getting " + + s"a ${relation.fileFormat} from the metastore cache. This cached " + + s"entry will be invalidated.") cachedDataSourceTables.invalidate(tableIdentifier) None - } - case other => - logWarning( - s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + - s"as Parquet. However, we are getting a $other from the metastore cache. " + - s"This cached entry will be invalidated.") - cachedDataSourceTables.invalidate(tableIdentifier) - None - } + } + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + + s"as $expectedFileFormat. However, we are getting a $other from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None } + } + + private def convertToLogicalRelation(metastoreRelation: MetastoreRelation, + options: Map[String, String], + defaultSource: FileFormat, + fileFormatClass: Class[_ <: FileFormat], + fileType: String): LogicalRelation = { + val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + val bucketSpec = None // We don't support hive bucketed tables, only ones we write out. val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) - // We're converting the entire table into ParquetRelation, so predicates to Hive metastore + // We're converting the entire table into HadoopFsRelation, so predicates to Hive metastore // are empty. val partitions = metastoreRelation.getHiveQlPartitions().map { p => val location = p.getLocation val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) - ParquetPartition(values, location) + PartitionDirectory(values, location) } val partitionSpec = PartitionSpec(partitionSchema, partitions) - val paths = partitions.map(_.path) - val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) - val parquetRelation = cached.getOrElse { - val created = LogicalRelation( - new ParquetRelation( - paths.toArray, None, Some(partitionSpec), parquetOptions)(hive)) + val cached = getCached( + tableIdentifier, + metastoreRelation, + metastoreSchema, + fileFormatClass, + bucketSpec, + Some(partitionSpec)) + + val hadoopFsRelation = cached.getOrElse { + val paths = new Path(metastoreRelation.table.storage.locationUri.get) :: Nil + val fileCatalog = new MetaStoreFileCatalog(hive, paths, partitionSpec) + + val inferredSchema = if (fileType.equals("parquet")) { + val inferredSchema = defaultSource.inferSchema(hive, options, fileCatalog.allFiles()) + inferredSchema.map { inferred => + ParquetRelation.mergeMetastoreParquetSchema(metastoreSchema, inferred) + }.getOrElse(metastoreSchema) + } else { + defaultSource.inferSchema(hive, options, fileCatalog.allFiles()).get + } + + val relation = HadoopFsRelation( + sqlContext = hive, + location = fileCatalog, + partitionSchema = partitionSchema, + dataSchema = inferredSchema, + bucketSpec = bucketSpec, + fileFormat = defaultSource, + options = options) + + val created = LogicalRelation(relation) cachedDataSourceTables.put(tableIdentifier, created) created } - parquetRelation + hadoopFsRelation } else { val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) - val cached = getCached(tableIdentifier, paths, metastoreSchema, None) - val parquetRelation = cached.getOrElse { - val created = LogicalRelation( - new ParquetRelation(paths.toArray, None, None, parquetOptions)(hive)) + val cached = getCached(tableIdentifier, + metastoreRelation, + metastoreSchema, + fileFormatClass, + bucketSpec, + None) + val logicalRelation = cached.getOrElse { + val created = + LogicalRelation( + DataSource( + sqlContext = hive, + paths = paths, + userSpecifiedSchema = Some(metastoreRelation.schema), + bucketSpec = bucketSpec, + options = options, + className = fileType).resolveRelation()) + cachedDataSourceTables.put(tableIdentifier, created) created } - parquetRelation + logicalRelation } - result.copy(expectedOutputAttributes = Some(metastoreRelation.output)) } - override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - val db = databaseName.getOrElse(client.currentDatabase) - - client.listTables(db).map(tableName => (tableName, false)) - } - /** * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet * data source relations for better performance. */ object ParquetConversions extends Rule[LogicalPlan] { + private def shouldConvertMetastoreParquet(relation: MetastoreRelation): Boolean = { + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") && + hive.convertMetastoreParquet + } + + private def convertToParquetRelation(relation: MetastoreRelation): LogicalRelation = { + val defaultSource = new ParquetDefaultSource() + val fileFormatClass = classOf[ParquetDefaultSource] + + val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging + val options = Map( + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, + ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( + relation.tableName, + Some(relation.databaseName) + ).unquotedString + ) + + convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "parquet") + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (!plan.resolved || plan.analyzed) { return plan @@ -511,24 +628,63 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Write path case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) // Inserting into partitioned table is not supported in Parquet data source (yet). - if !r.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - r.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => - val parquetRelation = convertToParquetRelation(r) - InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) // Write path case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) // Inserting into partitioned table is not supported in Parquet data source (yet). - if !r.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - r.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => - val parquetRelation = convertToParquetRelation(r) - InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) // Read path - case relation: MetastoreRelation if hive.convertMetastoreParquet && - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) => val parquetRelation = convertToParquetRelation(relation) - Subquery(relation.alias.getOrElse(relation.tableName), parquetRelation) + SubqueryAlias(relation.alias.getOrElse(relation.tableName), parquetRelation) + } + } + } + + /** + * When scanning Metastore ORC tables, convert them to ORC data source relations + * for better performance. + */ + object OrcConversions extends Rule[LogicalPlan] { + private def shouldConvertMetastoreOrc(relation: MetastoreRelation): Boolean = { + relation.tableDesc.getSerdeClassName.toLowerCase.contains("orc") && + hive.convertMetastoreOrc + } + + private def convertToOrcRelation(relation: MetastoreRelation): LogicalRelation = { + val defaultSource = new OrcDefaultSource() + val fileFormatClass = classOf[OrcDefaultSource] + val options = Map[String, String]() + + convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "orc") + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.resolved || plan.analyzed) { + return plan + } + + plan transformUp { + // Write path + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + // Inserting into partitioned table is not supported in Orc data source (yet). + if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + + // Write path + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + // Inserting into partitioned table is not supported in Orc data source (yet). + if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + + // Read path + case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) => + val orcRelation = convertToOrcRelation(relation) + SubqueryAlias(relation.alias.getOrElse(relation.tableName), orcRelation) } } } @@ -543,43 +699,38 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p - case CreateViewAsSelect(table, child, allowExisting, replace, sql) => - if (conf.nativeView) { - if (allowExisting && replace) { - throw new AnalysisException( - "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") - } + case CreateViewAsSelect(table, child, allowExisting, replace, sql) if conf.nativeView => + if (allowExisting && replace) { + throw new AnalysisException( + "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") + } - val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) + val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) - execution.CreateViewAsSelect( - table.copy( - specifiedDatabase = Some(dbName), - name = tblName), - child.output, - allowExisting, - replace) - } else { - HiveNativeCommand(sql) - } + execution.CreateViewAsSelect( + table.copy(identifier = TableIdentifier(tblName, Some(dbName))), + child, + allowExisting, + replace) + + case CreateViewAsSelect(table, child, allowExisting, replace, sql) => + HiveNativeCommand(sql) case p @ CreateTableAsSelect(table, child, allowExisting) => val schema = if (table.schema.nonEmpty) { table.schema } else { - child.output.map { - attr => new HiveColumn( - attr.name, - HiveMetastoreTypes.toMetastoreType(attr.dataType), null) + child.output.map { a => + CatalogColumn(a.name, HiveMetastoreTypes.toMetastoreType(a.dataType), a.nullable) } } val desc = table.copy(schema = schema) - if (hive.convertCTAS && table.serde.isEmpty) { + if (hive.convertCTAS && table.storage.serde.isEmpty) { // Do the conversion when spark.sql.hive.convertCTAS is true and the query // does not specify any storage format (file format and storage handler). - if (table.specifiedDatabase.isDefined) { + if (table.identifier.database.isDefined) { throw new AnalysisException( "Cannot specify database name in a CTAS statement " + "when spark.sql.hive.convertCTAS is set to true.") @@ -587,18 +738,19 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( - TableIdentifier(desc.name), + TableIdentifier(desc.identifier.table), conf.defaultDataSourceName, temporary = false, Array.empty[String], + bucketSpec = None, mode, options = Map.empty[String, String], child ) } else { - val desc = if (table.serde.isEmpty) { + val desc = if (table.storage.serde.isEmpty) { // add default serde - table.copy( + table.withNewStorage( serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) } else { table @@ -607,9 +759,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) execution.CreateTableAsSelect( - desc.copy( - specifiedDatabase = Some(dbName), - name = tblName), + desc.copy(identifier = TableIdentifier(tblName, Some(dbName))), child, allowExisting) } @@ -658,23 +808,25 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } - /** - * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. - * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. - */ - override def registerTable(tableIdent: TableIdentifier, plan: LogicalPlan): Unit = { - throw new UnsupportedOperationException - } +} - /** - * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. - * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. - */ - override def unregisterTable(tableIdent: TableIdentifier): Unit = { - throw new UnsupportedOperationException +/** + * An override of the standard HDFS listing based catalog, that overrides the partition spec with + * the information from the metastore. + */ +class MetaStoreFileCatalog( + hive: HiveContext, + paths: Seq[Path], + partitionSpecFromHive: PartitionSpec) + extends HDFSFileCatalog(hive, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) { + + + override def getStatus(path: Path): Array[FileStatus] = { + val fs = path.getFileSystem(hive.sparkContext.hadoopConfiguration) + fs.listStatus(path) } - override def unregisterAllTables(): Unit = {} + override def partitionSpec(): PartitionSpec = partitionSpecFromHive } /** @@ -704,10 +856,13 @@ private[hive] case class InsertIntoHiveTable( } } -private[hive] case class MetastoreRelation - (databaseName: String, tableName: String, alias: Option[String]) - (val table: HiveTable) - (@transient private val sqlContext: SQLContext) +private[hive] case class MetastoreRelation( + databaseName: String, + tableName: String, + alias: Option[String]) + (val table: CatalogTable, + @transient private val client: HiveClient, + @transient private val sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { @@ -723,38 +878,54 @@ private[hive] case class MetastoreRelation Objects.hashCode(databaseName, tableName, alias, output) } - @transient val hiveQlTable: Table = { + override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil + + private def toHiveColumn(c: CatalogColumn): FieldSchema = { + new FieldSchema(c.name, c.dataType, c.comment.orNull) + } + + // TODO: merge this with HiveClientImpl#toHiveTable + @transient val hiveQlTable: HiveTable = { // We start by constructing an API table as Hive performs several important transformations // internally when converting an API table to a QL table. val tTable = new org.apache.hadoop.hive.metastore.api.Table() - tTable.setTableName(table.name) + tTable.setTableName(table.identifier.table) tTable.setDbName(table.database) val tableParameters = new java.util.HashMap[String, String]() tTable.setParameters(tableParameters) table.properties.foreach { case (k, v) => tableParameters.put(k, v) } - tTable.setTableType(table.tableType.name) + tTable.setTableType(table.tableType match { + case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE.toString + case CatalogTableType.MANAGED_TABLE => HiveTableType.MANAGED_TABLE.toString + case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE.toString + case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW.toString + }) val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - tTable.setPartitionKeys( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - table.location.foreach(sd.setLocation) - table.inputFormat.foreach(sd.setInputFormat) - table.outputFormat.foreach(sd.setOutputFormat) + // Note: In Hive the schema and partition columns must be disjoint sets + val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => + table.partitionColumnNames.contains(c.getName) + } + sd.setCols(schema.asJava) + tTable.setPartitionKeys(partCols.asJava) + + table.storage.locationUri.foreach(sd.setLocation) + table.storage.inputFormat.foreach(sd.setInputFormat) + table.storage.outputFormat.foreach(sd.setOutputFormat) val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - table.serde.foreach(serdeInfo.setSerializationLib) + table.storage.serde.foreach(serdeInfo.setSerializationLib) sd.setSerdeInfo(serdeInfo) val serdeParameters = new java.util.HashMap[String, String]() - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + table.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } serdeInfo.setParameters(serdeParameters) - new Table(tTable) + new HiveTable(tTable) } @transient override lazy val statistics: Statistics = Statistics( @@ -779,11 +950,11 @@ private[hive] case class MetastoreRelation // When metastore partition pruning is turned off, we cache the list of all partitions to // mimic the behavior of Spark < 1.5 - lazy val allPartitions = table.getAllPartitions + private lazy val allPartitions: Seq[CatalogTablePartition] = client.getAllPartitions(table) def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { val rawPartitions = if (sqlContext.conf.metastorePartitionPruning) { - table.getPartitions(predicates) + client.getPartitionsByFilter(table, predicates) } else { allPartitions } @@ -792,24 +963,24 @@ private[hive] case class MetastoreRelation val tPartition = new org.apache.hadoop.hive.metastore.api.Partition tPartition.setDbName(databaseName) tPartition.setTableName(tableName) - tPartition.setValues(p.values.asJava) + tPartition.setValues(p.spec.values.toList.asJava) val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - - sd.setLocation(p.storage.location) - sd.setInputFormat(p.storage.inputFormat) - sd.setOutputFormat(p.storage.outputFormat) + sd.setCols(table.schema.map(toHiveColumn).asJava) + p.storage.locationUri.foreach(sd.setLocation) + p.storage.inputFormat.foreach(sd.setInputFormat) + p.storage.outputFormat.foreach(sd.setOutputFormat) val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) - serdeInfo.setSerializationLib(p.storage.serde) + // maps and lists should be set only after all elements are ready (see HIVE-7975) + p.storage.serde.foreach(serdeInfo.setSerializationLib) val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + table.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Partition(hiveQlTable, tPartition) } @@ -834,20 +1005,23 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: HiveColumn) { + implicit class SchemaAttribute(f: CatalogColumn) { def toAttribute: AttributeReference = AttributeReference( f.name, - HiveMetastoreTypes.toDataType(f.hiveType), + HiveMetastoreTypes.toDataType(f.dataType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true - )(qualifiers = Seq(alias.getOrElse(tableName))) + )(qualifier = Some(alias.getOrElse(tableName))) } /** PartitionKey attributes */ val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = table.schema.map(_.toAttribute) + // TODO: just make this hold the schema itself, not just non-partition columns + val attributes = table.schema + .filter { c => !table.partitionColumnNames.contains(c.name) } + .map(_.toAttribute) val output = attributes ++ partitionKeys @@ -858,19 +1032,22 @@ private[hive] case class MetastoreRelation val columnOrdinals = AttributeMap(attributes.zipWithIndex) override def inputFiles: Array[String] = { - val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + val partLocations = client + .getPartitionsByFilter(table, Nil) + .flatMap(_.storage.locationUri) + .toArray if (partLocations.nonEmpty) { partLocations } else { Array( - table.location.getOrElse( + table.storage.locationUri.getOrElse( sys.error(s"Could not get the location of ${table.qualifiedName}."))) } } override def newInstance(): MetastoreRelation = { - MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) + MetastoreRelation(databaseName, tableName, alias)(table, client, sqlContext) } } @@ -905,3 +1082,28 @@ private[hive] object HiveMetastoreTypes { case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) } } + +private[hive] case class CreateTableAsSelect( + tableDesc: CatalogTable, + child: LogicalPlan, + allowExisting: Boolean) extends UnaryNode with Command { + + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = + tableDesc.identifier.database.isDefined && + tableDesc.schema.nonEmpty && + tableDesc.storage.serde.isDefined && + tableDesc.storage.inputFormat.isDefined && + tableDesc.storage.outputFormat.isDefined && + childrenResolved +} + +private[hive] case class CreateViewAsSelect( + tableDesc: CatalogTable, + child: LogicalPlan, + allowExisting: Boolean, + replace: Boolean, + sql: String) extends UnaryNode with Command { + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = false +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala deleted file mode 100644 index ab88c1e68fd72..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ /dev/null @@ -1,1867 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.sql.Date -import java.util.Locale - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} -import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.parse._ -import org.apache.hadoop.hive.ql.plan.PlanUtils -import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Context, ErrorMsg} -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - -import org.apache.spark.Logging -import org.apache.spark.sql.{AnalysisException, catalyst} -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{logical, _} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.ExplainCommand -import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler - -/** - * Used when we need to start parsing the AST before deciding that we are going to pass the command - * back for Hive to execute natively. Will be replaced with a native command that contains the - * cmd string. - */ -private[hive] case object NativePlaceholder extends LogicalPlan { - override def children: Seq[LogicalPlan] = Seq.empty - override def output: Seq[Attribute] = Seq.empty -} - -private[hive] case class CreateTableAsSelect( - tableDesc: HiveTable, - child: LogicalPlan, - allowExisting: Boolean) extends UnaryNode with Command { - - override def output: Seq[Attribute] = Seq.empty[Attribute] - override lazy val resolved: Boolean = - tableDesc.specifiedDatabase.isDefined && - tableDesc.schema.size > 0 && - tableDesc.serde.isDefined && - tableDesc.inputFormat.isDefined && - tableDesc.outputFormat.isDefined && - childrenResolved -} - -private[hive] case class CreateViewAsSelect( - tableDesc: HiveTable, - child: LogicalPlan, - allowExisting: Boolean, - replace: Boolean, - sql: String) extends UnaryNode with Command { - override def output: Seq[Attribute] = Seq.empty[Attribute] - override lazy val resolved: Boolean = false -} - -/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] object HiveQl extends Logging { - protected val nativeCommands = Seq( - "TOK_ALTERDATABASE_OWNER", - "TOK_ALTERDATABASE_PROPERTIES", - "TOK_ALTERINDEX_PROPERTIES", - "TOK_ALTERINDEX_REBUILD", - "TOK_ALTERTABLE", - "TOK_ALTERTABLE_ADDCOLS", - "TOK_ALTERTABLE_ADDPARTS", - "TOK_ALTERTABLE_ALTERPARTS", - "TOK_ALTERTABLE_ARCHIVE", - "TOK_ALTERTABLE_CLUSTER_SORT", - "TOK_ALTERTABLE_DROPPARTS", - "TOK_ALTERTABLE_PARTITION", - "TOK_ALTERTABLE_PROPERTIES", - "TOK_ALTERTABLE_RENAME", - "TOK_ALTERTABLE_RENAMECOL", - "TOK_ALTERTABLE_REPLACECOLS", - "TOK_ALTERTABLE_SKEWED", - "TOK_ALTERTABLE_TOUCH", - "TOK_ALTERTABLE_UNARCHIVE", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_AS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME", - - "TOK_CREATEDATABASE", - "TOK_CREATEFUNCTION", - "TOK_CREATEINDEX", - "TOK_CREATEMACRO", - "TOK_CREATEROLE", - - "TOK_DESCDATABASE", - "TOK_DESCFUNCTION", - - "TOK_DROPDATABASE", - "TOK_DROPFUNCTION", - "TOK_DROPINDEX", - "TOK_DROPMACRO", - "TOK_DROPROLE", - "TOK_DROPTABLE_PROPERTIES", - "TOK_DROPVIEW", - "TOK_DROPVIEW_PROPERTIES", - - "TOK_EXPORT", - - "TOK_GRANT", - "TOK_GRANT_ROLE", - - "TOK_IMPORT", - - "TOK_LOAD", - - "TOK_LOCKTABLE", - - "TOK_MSCK", - - "TOK_REVOKE", - - "TOK_SHOW_COMPACTIONS", - "TOK_SHOW_CREATETABLE", - "TOK_SHOW_GRANT", - "TOK_SHOW_ROLE_GRANT", - "TOK_SHOW_ROLE_PRINCIPALS", - "TOK_SHOW_ROLES", - "TOK_SHOW_SET_ROLE", - "TOK_SHOW_TABLESTATUS", - "TOK_SHOW_TBLPROPERTIES", - "TOK_SHOW_TRANSACTIONS", - "TOK_SHOWCOLUMNS", - "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", - "TOK_SHOWINDEXES", - "TOK_SHOWLOCKS", - "TOK_SHOWPARTITIONS", - - "TOK_SWITCHDATABASE", - - "TOK_UNLOCKTABLE" - ) - - // Commands that we do not need to explain. - protected val noExplainCommands = Seq( - "TOK_DESCTABLE", - "TOK_SHOWTABLES", - "TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain. - ) ++ nativeCommands - - protected val hqlParser = new ExtendedHiveQlParser - - /** - * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations - * similar to [[catalyst.trees.TreeNode]]. - * - * Note that this should be considered very experimental and is not indented as a replacement - * for TreeNode. Primarily it should be noted ASTNodes are not immutable and do not appear to - * have clean copy semantics. Therefore, users of this class should take care when - * copying/modifying trees that might be used elsewhere. - */ - implicit class TransformableNode(n: ASTNode) { - /** - * Returns a copy of this node where `rule` has been recursively applied to it and all of its - * children. When `rule` does not apply to a given node it is left unchanged. - * @param rule the function use to transform this nodes children - */ - def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = { - try { - val afterRule = rule.applyOrElse(n, identity[ASTNode]) - afterRule.withChildren( - nilIfEmpty(afterRule.getChildren) - .asInstanceOf[Seq[ASTNode]] - .map(ast => Option(ast).map(_.transform(rule)).orNull)) - } catch { - case e: Exception => - logError(dumpTree(n).toString) - throw e - } - } - - /** - * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. - */ - private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.asScala).getOrElse(Nil) - - /** - * Returns this ASTNode with the text changed to `newText`. - */ - def withText(newText: String): ASTNode = { - n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText) - n - } - - /** - * Returns this ASTNode with the children changed to `newChildren`. - */ - def withChildren(newChildren: Seq[ASTNode]): ASTNode = { - (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren.asJava) - n - } - - /** - * Throws an error if this is not equal to other. - * - * Right now this function only checks the name, type, text and children of the node - * for equality. - */ - def checkEquals(other: ASTNode): Unit = { - def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { - sys.error(s"$field does not match for trees. " + - s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") - } - check("name", _.getName) - check("type", _.getType) - check("text", _.getText) - check("numChildren", n => nilIfEmpty(n.getChildren).size) - - val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] - val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] - leftChildren zip rightChildren foreach { - case (l, r) => l checkEquals r - } - } - } - - /** - * Returns the AST for the given SQL string. - */ - def getAst(sql: String): ASTNode = { - /* - * Context has to be passed in hive0.13.1. - * Otherwise, there will be Null pointer exception, - * when retrieving properties form HiveConf. - */ - val hContext = createContext() - val node = getAst(sql, hContext) - hContext.clear() - node - } - - private def createContext(): Context = new Context(hiveConf) - - private def getAst(sql: String, context: Context) = - ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) - - /** - * Returns the HiveConf - */ - private[this] def hiveConf: HiveConf = { - var ss = SessionState.get() - // SessionState is lazy initialization, it can be null here - if (ss == null) { - val original = Thread.currentThread().getContextClassLoader - val conf = new HiveConf(classOf[SessionState]) - conf.setClassLoader(original) - ss = new SessionState(conf) - SessionState.start(ss) - } - ss.getConf - } - - /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) - - val errorRegEx = "line (\\d+):(\\d+) (.*)".r - - /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String): LogicalPlan = { - try { - val context = createContext() - val tree = getAst(sql, context) - val plan = if (nativeCommands contains tree.getText) { - HiveNativeCommand(sql) - } else { - nodeToPlan(tree, context) match { - case NativePlaceholder => HiveNativeCommand(sql) - case other => other - } - } - context.clear() - plan - } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => - pe.getMessage match { - case errorRegEx(line, start, message) => - throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) - case otherMessage => - throw new AnalysisException(otherMessage) - } - case e: MatchError => throw e - case e: Exception => - throw new AnalysisException(e.getMessage) - case e: NotImplementedError => - throw new AnalysisException( - s""" - |Unsupported language features in query: $sql - |${dumpTree(getAst(sql))} - |$e - |${e.getStackTrace.head} - """.stripMargin) - } - } - - def parseDdl(ddl: String): Seq[Attribute] = { - val tree = - try { - ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) - } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => - throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) - } - assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") - val tableOps = tree.getChildren - val colList = - tableOps.asScala - .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") - .getOrElse(sys.error("No columnList!")).getChildren - - colList.asScala.map(nodeToAttribute) - } - - /** Extractor for matching Hive's AST Tokens. */ - object Token { - /** @return matches of the form (tokenName, children). */ - def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { - case t: ASTNode => - CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) - Some((t.getText, - Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) - case _ => None - } - } - - protected def getClauses( - clauseNames: Seq[String], - nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { - var remainingNodes = nodeList - val clauses = clauseNames.map { clauseName => - val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) - remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) - matches.headOption - } - - if (remainingNodes.nonEmpty) { - sys.error( - s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. - |You are likely trying to use an unsupported Hive feature."""".stripMargin) - } - clauses - } - - def getClause(clauseName: String, nodeList: Seq[Node]): Node = - getClauseOption(clauseName, nodeList).getOrElse(sys.error( - s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) - - def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = { - nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match { - case Seq(oneMatch) => Some(oneMatch) - case Seq() => None - case _ => sys.error(s"Found multiple instances of clause $clauseName") - } - } - - protected def nodeToAttribute(node: Node): Attribute = node match { - case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => - AttributeReference(colName, nodeToDataType(dataType), true)() - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } - - protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_DECIMAL", precision :: scale :: Nil) => - DecimalType(precision.getText.toInt, scale.getText.toInt) - case Token("TOK_DECIMAL", precision :: Nil) => - DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT - case Token("TOK_BIGINT", Nil) => LongType - case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => ByteType - case Token("TOK_SMALLINT", Nil) => ShortType - case Token("TOK_BOOLEAN", Nil) => BooleanType - case Token("TOK_STRING", Nil) => StringType - case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => DoubleType - case Token("TOK_DATE", Nil) => DateType - case Token("TOK_TIMESTAMP", Nil) => TimestampType - case Token("TOK_BINARY", Nil) => BinaryType - case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) - case Token("TOK_STRUCT", - Token("TOK_TABCOLLIST", fields) :: Nil) => - StructType(fields.map(nodeToStructField)) - case Token("TOK_MAP", - keyType :: - valueType :: Nil) => - MapType(nodeToDataType(keyType), nodeToDataType(valueType)) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ") - } - - protected def nodeToStructField(node: Node): StructField = node match { - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: - _ /* comment */:: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") - } - - protected def extractTableIdent(tableNameParts: Node): TableIdentifier = { - tableNameParts.getChildren.asScala.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => TableIdentifier(tableOnly) - case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } - } - - /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to - * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 - * Check the following link for details. - * -https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup - * - * The bitmask denotes the grouping expressions validity for a grouping set, - * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. - */ - protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { - val (keyASTs, setASTs) = children.partition( n => n match { - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => false // grouping sets - case _ => true // grouping keys - }) - - val keys = keyASTs.map(nodeToExpr).toSeq - val keyMap = keyASTs.map(_.toStringTree).zipWithIndex.toMap - - val bitmasks: Seq[Int] = setASTs.map(set => set match { - case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => - children.foldLeft(0)((bitmap, col) => { - val colString = col.asInstanceOf[ASTNode].toStringTree() - require(keyMap.contains(colString), s"$colString doens't show up in the GROUP BY list") - bitmap | 1 << keyMap(colString) - }) - case _ => sys.error("Expect GROUPING SETS clause") - }) - - (keys, bitmasks) - } - - protected def getProperties(node: Node): Seq[(String, String)] = node match { - case Token("TOK_TABLEPROPLIST", list) => - list.map { - case Token("TOK_TABLEPROPERTY", Token(key, Nil) :: Token(value, Nil) :: Nil) => - (unquoteString(key) -> unquoteString(value)) - } - } - - private def createView( - view: ASTNode, - context: Context, - viewNameParts: ASTNode, - query: ASTNode, - schema: Seq[HiveColumn], - properties: Map[String, String], - allowExist: Boolean, - replace: Boolean): CreateViewAsSelect = { - val TableIdentifier(viewName, dbName) = extractTableIdent(viewNameParts) - - val originalText = context.getTokenRewriteStream - .toString(query.getTokenStartIndex, query.getTokenStopIndex) - - val tableDesc = HiveTable( - specifiedDatabase = dbName, - name = viewName, - schema = schema, - partitionColumns = Seq.empty[HiveColumn], - properties = properties, - serdeProperties = Map[String, String](), - tableType = VirtualView, - location = None, - inputFormat = None, - outputFormat = None, - serde = None, - viewText = Some(originalText)) - - // We need to keep the original SQL string so that if `spark.sql.nativeView` is - // false, we can fall back to use hive native command later. - // We can remove this when parser is configurable(can access SQLConf) in the future. - val sql = context.getTokenRewriteStream - .toString(view.getTokenStartIndex, view.getTokenStopIndex) - CreateViewAsSelect(tableDesc, nodeToPlan(query, context), allowExist, replace, sql) - } - - protected def nodeToPlan(node: ASTNode, context: Context): LogicalPlan = node match { - // Special drop table that also uncaches. - case Token("TOK_DROPTABLE", - Token("TOK_TABNAME", tableNameParts) :: - ifExists) => - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - DropTable(tableName, ifExists.nonEmpty) - // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" - case Token("TOK_ANALYZE", - Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: - isNoscan) => - // Reference: - // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables - if (partitionSpec.nonEmpty) { - // Analyze partitions will be treated as a Hive native command. - NativePlaceholder - } else if (isNoscan.isEmpty) { - // If users do not specify "noscan", it will be treated as a Hive native command. - NativePlaceholder - } else { - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - AnalyzeTable(tableName) - } - // Just fake explain for any of the native commands. - case Token("TOK_EXPLAIN", explainArgs) - if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(OneRowRelation) - case Token("TOK_EXPLAIN", explainArgs) - if "TOK_CREATETABLE" == explainArgs.head.getText => - val Some(crtTbl) :: _ :: extended :: Nil = - getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand( - nodeToPlan(crtTbl, context), - extended = extended.isDefined) - case Token("TOK_EXPLAIN", explainArgs) => - // Ignore FORMATTED if present. - val Some(query) :: _ :: extended :: Nil = - getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand( - nodeToPlan(query, context), - extended = extended.isDefined) - - case Token("TOK_DESCTABLE", describeArgs) => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val Some(tableType) :: formatted :: extended :: pretty :: Nil = - getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) - if (formatted.isDefined || pretty.isDefined) { - // FORMATTED and PRETTY are not supported and this statement will be treated as - // a Hive native command. - NativePlaceholder - } else { - tableType match { - case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { - nameParts.head match { - case Token(".", dbName :: tableName :: Nil) => - // It is describing a table with the format like "describe db.table". - // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val tableIdent = extractTableIdent(nameParts.head) - DescribeCommand( - UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) - case Token(".", dbName :: tableName :: colName :: Nil) => - // It is describing a column with the format like "describe db.table column". - NativePlaceholder - case tableName => - // It is describing a table with the format like "describe table". - DescribeCommand( - UnresolvedRelation(TableIdentifier(tableName.getText), None), - isExtended = extended.isDefined) - } - } - // All other cases. - case _ => NativePlaceholder - } - } - - case view @ Token("TOK_ALTERVIEW", children) => - val Some(viewNameParts) :: maybeQuery :: ignores = - getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME"), children) - - // if ALTER VIEW doesn't have query part, let hive to handle it. - maybeQuery.map { query => - createView(view, context, viewNameParts, query, Nil, Map(), false, true) - }.getOrElse(NativePlaceholder) - - case view @ Token("TOK_CREATEVIEW", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - val Seq( - Some(viewNameParts), - Some(query), - maybeComment, - replace, - allowExisting, - maybeProperties, - maybeColumns, - maybePartCols - ) = getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_TABLECOMMENT", - "TOK_ORREPLACE", - "TOK_IFNOTEXISTS", - "TOK_TABLEPROPERTIES", - "TOK_TABCOLNAME", - "TOK_VIEWPARTCOLS"), children) - - // If the view is partitioned, we let hive handle it. - if (maybePartCols.isDefined) { - NativePlaceholder - } else { - val schema = maybeColumns.map { cols => - BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => - // We can't specify column types when create view, so fill it with null first, and - // update it after the schema has been resolved later. - HiveColumn(field.getName, null, field.getComment) - } - }.getOrElse(Seq.empty[HiveColumn]) - - val properties = scala.collection.mutable.Map.empty[String, String] - - maybeProperties.foreach { - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - properties ++= getProperties(list) - } - - maybeComment.foreach { - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - if (comment ne null) { - properties += ("comment" -> comment) - } - } - - createView(view, context, viewNameParts, query, schema, properties.toMap, - allowExisting.isDefined, replace.isDefined) - } - - case Token("TOK_CREATETABLE", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val ( - Some(tableNameParts) :: - _ /* likeTable */ :: - externalTable :: - Some(query) :: - allowExisting +: - ignores) = - getClauses( - Seq( - "TOK_TABNAME", - "TOK_LIKETABLE", - "EXTERNAL", - "TOK_QUERY", - "TOK_IFNOTEXISTS", - "TOK_TABLECOMMENT", - "TOK_TABCOLLIST", - "TOK_TABLEPARTCOLS", // Partitioned by - "TOK_TABLEBUCKETS", // Clustered by - "TOK_TABLESKEWED", // Skewed by - "TOK_TABLEROWFORMAT", - "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", - "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat - "TOK_STORAGEHANDLER", // Storage handler - "TOK_TABLELOCATION", - "TOK_TABLEPROPERTIES"), - children) - val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) - - // TODO add bucket support - var tableDesc: HiveTable = HiveTable( - specifiedDatabase = dbName, - name = tblName, - schema = Seq.empty[HiveColumn], - partitionColumns = Seq.empty[HiveColumn], - properties = Map[String, String](), - serdeProperties = Map[String, String](), - tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, - location = None, - inputFormat = None, - outputFormat = None, - serde = None, - viewText = None) - - // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) - val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbreviation - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } - - hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) - hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) - hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) - - children.collect { - case list @ Token("TOK_TABCOLLIST", _) => - val cols = BaseSemanticAnalyzer.getColumns(list, true) - if (cols != null) { - tableDesc = tableDesc.copy( - schema = cols.asScala.map { field => - HiveColumn(field.getName, field.getType, field.getComment) - }) - } - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - // TODO support the sql text - tableDesc = tableDesc.copy(viewText = Option(comment)) - case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = BaseSemanticAnalyzer.getColumns(list(0), false) - if (cols != null) { - tableDesc = tableDesc.copy( - partitionColumns = cols.asScala.map { field => - HiveColumn(field.getName, field.getType, field.getComment) - }) - } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => - val serdeParams = new java.util.HashMap[String, String]() - child match { - case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = BaseSemanticAnalyzer.unescapeSQLString (rowChild1.getText()) - serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) - serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) - if (rowChild2.length > 1) { - val fieldEscape = BaseSemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) - serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) - } - case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) - serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) - case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) - serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) - case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) - if (!(lineDelim == "\n") && !(lineDelim == "10")) { - throw new AnalysisException( - SemanticAnalyzer.generateErrorMessage( - rowChild, - ErrorMsg.LINES_TERMINATED_BY_NON_NEWLINE.getMsg)) - } - serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) - case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) - // TODO support the nullFormat - case _ => assert(false) - } - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) - case Token("TOK_TABLELOCATION", child :: Nil) => - var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - location = EximUtil.relativeToAbsolutePath(hiveConf, location) - tableDesc = tableDesc.copy(location = Option(location)) - case Token("TOK_TABLESERIALIZER", child :: Nil) => - tableDesc = tableDesc.copy( - serde = Option(BaseSemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) - if (child.getChildCount == 2) { - val serdeParams = new java.util.HashMap[String, String]() - BaseSemanticAnalyzer.readProps( - (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) - } - case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - child.getText().toLowerCase(Locale.ENGLISH) match { - case "orc" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - case "parquet" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - - case "rcfile" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } - - case "textfile" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - - case "sequencefile" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - case "avro" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) - } - - case _ => - throw new SemanticException( - s"Unrecognized file format in STORED AS clause: ${child.getText}") - } - - case Token("TOK_TABLESERIALIZER", - Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => - tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) - - otherProps match { - case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) - case Nil => - } - - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", children) => - tableDesc = tableDesc.copy( - inputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), - outputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) - case Token("TOK_STORAGEHANDLER", _) => - throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) - case _ => // Unsupport features - } - - CreateTableAsSelect(tableDesc, nodeToPlan(query, context), allowExisting != None) - - // If its not a "CTAS" like above then take it as a native command - case Token("TOK_CREATETABLE", _) => NativePlaceholder - - // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" - case Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder - - case Token("TOK_QUERY", queryArgs) - if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => - - val (fromClause: Option[ASTNode], insertClauses, cteRelations) = - queryArgs match { - case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => - // check if has CTE - insertClauses.last match { - case Token("TOK_CTE", cteClauses) => - val cteRelations = cteClauses.map(node => { - val relation = nodeToRelation(node, context).asInstanceOf[Subquery] - (relation.alias, relation) - }).toMap - (Some(args.head), insertClauses.init, Some(cteRelations)) - - case _ => (Some(args.head), insertClauses, None) - } - - case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) - } - - // Return one query for each insert clause. - val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => - val ( - intoClause :: - destClause :: - selectClause :: - selectDistinctClause :: - whereClause :: - groupByClause :: - rollupGroupByClause :: - cubeGroupByClause :: - groupingSetsClause :: - orderByClause :: - havingClause :: - sortByClause :: - clusterByClause :: - distributeByClause :: - limitClause :: - lateralViewClause :: - windowClause :: Nil) = { - getClauses( - Seq( - "TOK_INSERT_INTO", - "TOK_DESTINATION", - "TOK_SELECT", - "TOK_SELECTDI", - "TOK_WHERE", - "TOK_GROUPBY", - "TOK_ROLLUP_GROUPBY", - "TOK_CUBE_GROUPBY", - "TOK_GROUPING_SETS", - "TOK_ORDERBY", - "TOK_HAVING", - "TOK_SORTBY", - "TOK_CLUSTERBY", - "TOK_DISTRIBUTEBY", - "TOK_LIMIT", - "TOK_LATERAL_VIEW", - "WINDOW"), - singleInsert) - } - - val relations = fromClause match { - case Some(f) => nodeToRelation(f, context) - case None => OneRowRelation - } - - val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.asScala - Filter(nodeToExpr(whereExpr), relations) - }.getOrElse(relations) - - val select = - (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause.")) - - // Script transformations are expressed as a select clause with a single expression of type - // TOK_TRANSFORM - val transformation = select.getChildren.iterator().next() match { - case Token("TOK_SELEXPR", - Token("TOK_TRANSFORM", - Token("TOK_EXPLIST", inputExprs) :: - Token("TOK_SERDE", inputSerdeClause) :: - Token("TOK_RECORDWRITER", writerClause) :: - // TODO: Need to support other types of (in/out)put - Token(script, Nil) :: - Token("TOK_SERDE", outputSerdeClause) :: - Token("TOK_RECORDREADER", readerClause) :: - outputClause) :: Nil) => - - val (output, schemaLess) = outputClause match { - case Token("TOK_ALIASLIST", aliases) :: Nil => - (aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }, - false) - case Token("TOK_TABCOLLIST", attributes) :: Nil => - (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => - AttributeReference(name, nodeToDataType(dataType))() }, false) - case Nil => - (List(AttributeReference("key", StringType)(), - AttributeReference("value", StringType)()), true) - } - - type SerDeInfo = ( - Seq[(String, String)], // Input row format information - Option[String], // Optional input SerDe class - Seq[(String, String)], // Input SerDe properties - Boolean // Whether to use default record reader/writer - ) - - def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { - case Token("TOK_SERDEPROPS", propsClause) :: Nil => - val rowFormat = propsClause.map { - case Token(name, Token(value, Nil) :: Nil) => (name, value) - } - (rowFormat, None, Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: - Token("TOK_TABLEPROPERTIES", - Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => - val serdeProps = propsClause.map { - case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (BaseSemanticAnalyzer.unescapeSQLString(name), - BaseSemanticAnalyzer.unescapeSQLString(value)) - } - - // SPARK-10310: Special cases LazySimpleSerDe - // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) - val useDefaultRecordReaderWriter = - unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName - (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) - - case Nil => - // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here - val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") - (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) - } - - val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = - matchSerDe(inputSerdeClause) - - val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = - matchSerDe(outputSerdeClause) - - val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) - - // TODO Adds support for user-defined record reader/writer classes - val recordReaderClass = if (useDefaultRecordReader) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) - } else { - None - } - - val recordWriterClass = if (useDefaultRecordWriter) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) - } else { - None - } - - val schema = HiveScriptIOSchema( - inRowFormat, outRowFormat, - inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, - recordReaderClass, recordWriterClass, - schemaLess) - - Some( - logical.ScriptTransformation( - inputExprs.map(nodeToExpr), - unescapedScript, - output, - withWhere, schema)) - case _ => None - } - - val withLateralView = lateralViewClause.map { lv => - val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() - - val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() - .asInstanceOf[ASTNode].getText - - val (generator, attributes) = nodesToGenerator(clauses) - Generate( - generator, - join = true, - outer = false, - Some(alias.toLowerCase), - attributes.map(UnresolvedAttribute(_)), - withWhere) - }.getOrElse(withWhere) - - // The projection of the query can either be a normal projection, an aggregation - // (if there is a group by) or a script transformation. - val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = - select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias) - Seq( - groupByClause.map(e => e match { - case Token("TOK_GROUPBY", children) => - // Not a transformation so must be either project or aggregation. - Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView) - case _ => sys.error("Expect GROUP BY") - }), - groupingSetsClause.map(e => e match { - case Token("TOK_GROUPING_SETS", children) => - val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withLateralView, selectExpressions) - case _ => sys.error("Expect GROUPING SETS") - }), - rollupGroupByClause.map(e => e match { - case Token("TOK_ROLLUP_GROUPBY", children) => - Rollup(children.map(nodeToExpr), withLateralView, selectExpressions) - case _ => sys.error("Expect WITH ROLLUP") - }), - cubeGroupByClause.map(e => e match { - case Token("TOK_CUBE_GROUPBY", children) => - Cube(children.map(nodeToExpr), withLateralView, selectExpressions) - case _ => sys.error("Expect WITH CUBE") - }), - Some(Project(selectExpressions, withLateralView))).flatten.head - } - - // Handle HAVING clause. - val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(havingExpr, BooleanType), withProject) - }.getOrElse(withProject) - - // Handle SELECT DISTINCT - val withDistinct = - if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. - val withSort = - (orderByClause, sortByClause, distributeByClause, clusterByClause) match { - case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) - case (None, Some(perPartitionOrdering), None, None) => - Sort( - perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), - false, withDistinct) - case (None, None, Some(partitionExprs), None) => - RepartitionByExpression( - partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) - case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort( - perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, - RepartitionByExpression( - partitionExprs.getChildren.asScala.map(nodeToExpr), - withDistinct)) - case (None, None, None, Some(clusterExprs)) => - Sort( - clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), - false, - RepartitionByExpression( - clusterExprs.getChildren.asScala.map(nodeToExpr), - withDistinct)) - case (None, None, None, None) => withDistinct - case _ => sys.error("Unsupported set of ordering / distribution clauses.") - } - - val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) - .map(Limit(_, withSort)) - .getOrElse(withSort) - - // Collect all window specifications defined in the WINDOW clause. - val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { - case Token("TOK_WINDOWDEF", - Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - windowName -> nodesToWindowSpecification(spec) - }.toMap) - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val resolvedCrossReference = windowDefinitions.map { - windowDefMap => windowDefMap.map { - case (windowName, WindowSpecReference(other)) => - (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) - case o => o.asInstanceOf[(String, WindowSpecDefinition)] - } - } - - val withWindowDefinitions = - resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) - - // TOK_INSERT_INTO means to add files to the table. - // TOK_DESTINATION means to overwrite the table. - val resultDestination = - (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = intoClause.isEmpty - nodeToDest( - resultDestination, - withWindowDefinitions, - overwrite) - } - - // If there are multiple INSERTS just UNION them together into on query. - val query = queries.reduceLeft(Union) - - // return With plan if there is CTE - cteRelations.map(With(query, _)).getOrElse(query) - - // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT - case Token("TOK_UNIONALL", left :: right :: Nil) => - Union(nodeToPlan(left, context), nodeToPlan(right, context)) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") - } - - val allJoinTokens = "(TOK_.*JOIN)".r - val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - def nodeToRelation(node: Node, context: Context): LogicalPlan = node match { - case Token("TOK_SUBQUERY", - query :: Token(alias, Nil) :: Nil) => - Subquery(cleanIdentifier(alias), nodeToPlan(query, context)) - - case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => - val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - - val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() - .asInstanceOf[ASTNode].getText - - val (generator, attributes) = nodesToGenerator(clauses) - Generate( - generator, - join = true, - outer = isOuter.nonEmpty, - Some(alias.toLowerCase), - attributes.map(UnresolvedAttribute(_)), - nodeToRelation(relationClause, context)) - - /* All relations, possibly with aliases or sampling clauses. */ - case Token("TOK_TABREF", clauses) => - // If the last clause is not a token then it's the alias of the table. - val (nonAliasClauses, aliasClause) = - if (clauses.last.getText.startsWith("TOK")) { - (clauses, None) - } else { - (clauses.dropRight(1), Some(clauses.last)) - } - - val (Some(tableNameParts) :: - splitSampleClause :: - bucketSampleClause :: Nil) = { - getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), - nonAliasClauses) - } - - val tableIdent = extractTableIdent(tableNameParts) - val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } - val relation = UnresolvedRelation(tableIdent, alias) - - // Apply sampling if requested. - (bucketSampleClause orElse splitSampleClause).map { - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_ROWCOUNT", Nil) :: - Token(count, Nil) :: Nil) => - Limit(Literal(count.toInt), relation) - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_PERCENT", Nil) :: - Token(fraction, Nil) :: Nil) => - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. - require( - fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) - && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), - s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, - relation) - case Token("TOK_TABLEBUCKETSAMPLE", - Token(numerator, Nil) :: - Token(denominator, Nil) :: Nil) => - val fraction = numerator.toDouble / denominator.toDouble - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : - |${dumpTree(a).toString}" + - """.stripMargin) - }.getOrElse(relation) - - case Token("TOK_UNIQUEJOIN", joinArgs) => - val tableOrdinals = - joinArgs.zipWithIndex.filter { - case (arg, i) => arg.getText == "TOK_TABREF" - }.map(_._2) - - val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") - val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i), context)) - val joinExpressions = - tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) - - val joinConditions = joinExpressions.sliding(2).map { - case Seq(c1, c2) => - val predicates = (c1, c2).zipped.map { case (e1, e2) => EqualTo(e1, e2): Expression } - predicates.reduceLeft(And) - }.toBuffer - - val joinType = isPreserved.sliding(2).map { - case Seq(true, true) => FullOuter - case Seq(true, false) => LeftOuter - case Seq(false, true) => RightOuter - case Seq(false, false) => Inner - }.toBuffer - - val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) - - // Must be transform down. - val joinedResult = joinedTables transform { - case j: Join => - j.copy( - condition = Some(joinConditions.remove(joinConditions.length - 1)), - joinType = joinType.remove(joinType.length - 1)) - } - - val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) - - // Unique join is not really the same as an outer join so we must group together results where - // the joinExpressions are the same, taking the First of each value is only okay because the - // user of a unique join is implicitly promising that there is only one result. - // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression. - // instead we should figure out how important supporting this feature is and whether it is - // worth the number of hacks that will be required to implement it. Namely, we need to add - // some sort of mapped star expansion that would expand all child output row to be similarly - // named output expressions where some aggregate expression has been applied (i.e. First). - // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) - throw new UnsupportedOperationException - - case Token(allJoinTokens(joinToken), - relation1 :: - relation2 :: other) => - if (!(other.size <= 1)) { - sys.error(s"Unsupported join operation: $other") - } - - val joinType = joinToken match { - case "TOK_JOIN" => Inner - case "TOK_CROSSJOIN" => Inner - case "TOK_RIGHTOUTERJOIN" => RightOuter - case "TOK_LEFTOUTERJOIN" => LeftOuter - case "TOK_FULLOUTERJOIN" => FullOuter - case "TOK_LEFTSEMIJOIN" => LeftSemi - } - Join(nodeToRelation(relation1, context), - nodeToRelation(relation2, context), - joinType, - other.headOption.map(nodeToExpr)) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } - - def nodeToSortOrder(node: Node): SortOrder = node match { - case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Descending) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } - - val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r - protected def nodeToDest( - node: Node, - query: LogicalPlan, - overwrite: Boolean): LogicalPlan = node match { - case Token(destinationToken(), - Token("TOK_DIR", - Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => - query - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.getChildren.asScala.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: - Token("TOK_IFNOTEXISTS", - ifNotExists) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.getChildren.asScala.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for ${a.getName}:" + - s"\n ${dumpTree(a).toString} ") - } - - protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { - case Token("TOK_SELEXPR", e :: Nil) => - Some(nodeToExpr(e)) - - case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) - - case Token("TOK_SELEXPR", e :: aliasChildren) => - var aliasNames = ArrayBuffer[String]() - aliasChildren.foreach { _ match { - case Token(name, Nil) => aliasNames += cleanIdentifier(name) - case _ => - } - } - Some(MultiAlias(nodeToExpr(e), aliasNames)) - - /* Hints are ignored */ - case Token("TOK_HINTLIST", _) => None - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for ${a.getName }:" + - s"\n ${dumpTree(a).toString } ") - } - - protected val escapedIdentifier = "`([^`]+)`".r - protected val doubleQuotedString = "\"([^\"]+)\"".r - protected val singleQuotedString = "'([^']+)'".r - - protected def unquoteString(str: String) = str match { - case singleQuotedString(s) => s - case doubleQuotedString(s) => s - case other => other - } - - /** Strips backticks from ident if present */ - protected def cleanIdentifier(ident: String): String = ident match { - case escapedIdentifier(i) => i - case plainIdent => plainIdent - } - - val numericAstTypes = Seq( - HiveParser.Number, - HiveParser.TinyintLiteral, - HiveParser.SmallintLiteral, - HiveParser.BigintLiteral, - HiveParser.DecimalLiteral) - - /* Case insensitive matches */ - val COUNT = "(?i)COUNT".r - val SUM = "(?i)SUM".r - val AND = "(?i)AND".r - val OR = "(?i)OR".r - val NOT = "(?i)NOT".r - val TRUE = "(?i)TRUE".r - val FALSE = "(?i)FALSE".r - val LIKE = "(?i)LIKE".r - val RLIKE = "(?i)RLIKE".r - val REGEXP = "(?i)REGEXP".r - val IN = "(?i)IN".r - val DIV = "(?i)DIV".r - val BETWEEN = "(?i)BETWEEN".r - val WHEN = "(?i)WHEN".r - val CASE = "(?i)CASE".r - - protected def nodeToExpr(node: Node): Expression = node match { - /* Attribute References */ - case Token("TOK_TABLE_OR_COL", - Token(name, Nil) :: Nil) => - UnresolvedAttribute.quoted(cleanIdentifier(name)) - case Token(".", qualifier :: Token(attr, Nil) :: Nil) => - nodeToExpr(qualifier) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(attr)) - } - - /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) - // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only - // has a single child which is tableName. - case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) - - /* Aggregate Functions */ - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) - case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) - - /* Casts */ - case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), IntegerType) - case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), LongType) - case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), FloatType) - case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DoubleType) - case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ShortType) - case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ByteType) - case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BinaryType) - case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BooleanType) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) - case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), TimestampType) - case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DateType) - - /* Arithmetic */ - case Token("+", child :: Nil) => nodeToExpr(child) - case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) - case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) - case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) - case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) - case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) - case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => - Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) - case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) - case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) - case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) - case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - - /* Comparisons */ - case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) - case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) - case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) - case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => - IsNotNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => - IsNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => - In(nodeToExpr(value), list.map(nodeToExpr)) - case Token("TOK_FUNCTION", - Token(BETWEEN(), Nil) :: - kw :: - target :: - minValue :: - maxValue :: Nil) => - - val targetExpression = nodeToExpr(target) - val betweenExpr = - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) - kw match { - case Token("KW_FALSE", Nil) => betweenExpr - case Token("KW_TRUE", Nil) => Not(betweenExpr) - } - - /* Boolean Logic */ - case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) - case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) - case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) - case Token("!", child :: Nil) => Not(nodeToExpr(child)) - - /* Case statements */ - case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => - CaseWhen(branches.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => - val keyExpr = nodeToExpr(branches.head) - CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) - - /* Complex datatype manipulation */ - case Token("[", child :: ordinal :: Nil) => - UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - - /* Window Functions */ - case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = UnresolvedWindowFunction(name, args.map(nodeToExpr)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - // Safe to use Literal(1)? - val function = UnresolvedWindowFunction(name, Literal(1) :: Nil) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - - /* UDFs - Must be last otherwise will preempt built in functions */ - case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) - // Aggregate function with DISTINCT keyword. - case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) - - /* Literals */ - case Token("TOK_NULL", Nil) => Literal.create(null, NullType) - case Token(TRUE(), Nil) => Literal.create(true, BooleanType) - case Token(FALSE(), Nil) => Literal.create(false, BooleanType) - case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) - - // This code is adapted from - // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 - case ast: ASTNode if numericAstTypes contains ast.getType => - var v: Literal = null - try { - if (ast.getText.endsWith("L")) { - // Literal bigint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) - } else if (ast.getText.endsWith("S")) { - // Literal smallint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) - } else if (ast.getText.endsWith("Y")) { - // Literal tinyint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) - } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { - // Literal decimal - val strVal = ast.getText.stripSuffix("D").stripSuffix("B") - v = Literal(Decimal(strVal)) - } else { - v = Literal.create(ast.getText.toDouble, DoubleType) - v = Literal.create(ast.getText.toLong, LongType) - v = Literal.create(ast.getText.toInt, IntegerType) - } - } catch { - case nfe: NumberFormatException => // Do nothing - } - - if (v == null) { - sys.error(s"Failed to parse number '${ast.getText}'.") - } else { - v - } - - case ast: ASTNode if ast.getType == HiveParser.StringLiteral => - Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) - - case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => - Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) - - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : - |${dumpTree(a).toString}" + - """.stripMargin) - } - - /* Case insensitive matches for Window Specification */ - val PRECEDING = "(?i)preceding".r - val FOLLOWING = "(?i)following".r - val CURRENT = "(?i)current".r - def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { - case Token(windowName, Nil) :: Nil => - // Refer to a window spec defined in the window clause. - WindowSpecReference(windowName) - case Nil => - // OVER() - WindowSpecDefinition( - partitionSpec = Nil, - orderSpec = Nil, - frameSpecification = UnspecifiedFrame) - case spec => - val (partitionClause :: rowFrame :: rangeFrame :: Nil) = - getClauses( - Seq( - "TOK_PARTITIONINGSPEC", - "TOK_WINDOWRANGE", - "TOK_WINDOWVALUES"), - spec) - - // Handle Partition By and Order By. - val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => - val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = - getClauses( - Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) - - (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { - case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.getChildren.asScala.map(nodeToExpr), - orderByExpr.getChildren.asScala.map(nodeToSortOrder)) - case (Some(partitionByExpr), None, None) => - (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) - case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) - case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) - (expressions, expressions.map(SortOrder(_, Ascending))) - case _ => - throw new NotImplementedError( - s"""No parse rules for Node ${partitionAndOrdering.getName} - """.stripMargin) - } - }.getOrElse { - (Nil, Nil) - } - - // Handle Window Frame - val windowFrame = - if (rowFrame.isEmpty && rangeFrame.isEmpty) { - UnspecifiedFrame - } else { - val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) - def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token(PRECEDING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedPreceding - } else { - ValuePreceding(count.toInt) - } - case Token(FOLLOWING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedFollowing - } else { - ValueFollowing(count.toInt) - } - case Token(CURRENT(), Nil) => CurrentRow - case _ => - throw new NotImplementedError( - s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} - """.stripMargin) - } - - rowFrame.orElse(rangeFrame).map { frame => - frame.getChildren.asScala.toList match { - case precedingNode :: followingNode :: Nil => - SpecifiedWindowFrame( - frameType, - nodeToBoundary(precedingNode), - nodeToBoundary(followingNode)) - case precedingNode :: Nil => - SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) - case _ => - throw new NotImplementedError( - s"""No parse rules for the Window Frame based on Node ${frame.getName} - """.stripMargin) - } - }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) - } - - WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) - } - - val explode = "(?i)explode".r - def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { - val function = nodes.head - - val attributes = nodes.flatMap { - case Token(a, Nil) => a.toLowerCase :: Nil - case _ => Nil - } - - function match { - case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - (Explode(nodeToExpr(child)), attributes) - - case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $functionName")) - val functionClassName = functionInfo.getFunctionClass.getName - - (HiveGenericUDTF( - new HiveFunctionWrapper(functionClassName), - children.map(nodeToExpr)), attributes) - - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree: - |${dumpTree(a).toString} - """.stripMargin) - } - } - - def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0) - : StringBuilder = { - node match { - case a: ASTNode => builder.append( - (" " * indent) + a.getText + " " + - a.getLine + ", " + - a.getTokenStartIndex + "," + - a.getTokenStopIndex + ", " + - a.getCharPositionInLine + "\n") - case other => sys.error(s"Non ASTNode encountered: $other") - } - - Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) - builder - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala new file mode 100644 index 0000000000000..0cccc22e5a624 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} +import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper +import org.apache.spark.sql.hive.client.HiveClient +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + + +private[sql] class HiveSessionCatalog( + externalCatalog: HiveExternalCatalog, + client: HiveClient, + context: HiveContext, + functionResourceLoader: FunctionResourceLoader, + functionRegistry: FunctionRegistry, + conf: SQLConf) + extends SessionCatalog(externalCatalog, functionResourceLoader, functionRegistry, conf) { + + override def setCurrentDatabase(db: String): Unit = { + super.setCurrentDatabase(db) + client.setCurrentDatabase(db) + } + + override def lookupRelation(name: TableIdentifier, alias: Option[String]): LogicalPlan = { + val table = formatTableName(name.table) + if (name.database.isDefined || !tempTables.contains(table)) { + val newName = name.copy(table = table) + metastoreCatalog.lookupRelation(newName, alias) + } else { + val relation = tempTables(table) + val tableWithQualifiers = SubqueryAlias(table, relation) + // If an alias was specified by the lookup, wrap the plan in a subquery so that + // attributes are properly qualified with this alias. + alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + } + } + + // ---------------------------------------------------------------- + // | Methods and fields for interacting with HiveMetastoreCatalog | + // ---------------------------------------------------------------- + + override def getDefaultDBPath(db: String): String = { + val defaultPath = context.hiveconf.getVar(HiveConf.ConfVars.METASTOREWAREHOUSE) + new Path(new Path(defaultPath), db + ".db").toString + } + + // Catalog for handling data source tables. TODO: This really doesn't belong here since it is + // essentially a cache for metastore tables. However, it relies on a lot of session-specific + // things so it would be a lot of work to split its functionality between HiveSessionCatalog + // and HiveCatalog. We should still do it at some point... + private val metastoreCatalog = new HiveMetastoreCatalog(client, context) + + val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions + val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions + val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables + val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts + + override def refreshTable(name: TableIdentifier): Unit = { + metastoreCatalog.refreshTable(name) + } + + override def invalidateTable(name: TableIdentifier): Unit = { + metastoreCatalog.invalidateTable(name) + } + + def invalidateCache(): Unit = { + metastoreCatalog.cachedDataSourceTables.invalidateAll() + } + + def createDataSourceTable( + name: TableIdentifier, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], + provider: String, + options: Map[String, String], + isExternal: Boolean): Unit = { + metastoreCatalog.createDataSourceTable( + name, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options, isExternal) + } + + def hiveDefaultTableFilePath(name: TableIdentifier): String = { + metastoreCatalog.hiveDefaultTableFilePath(name) + } + + // For testing only + private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { + val key = metastoreCatalog.getQualifiedTableName(table) + metastoreCatalog.cachedDataSourceTables.getIfPresent(key) + } + + override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { + makeFunctionBuilder(funcName, Utils.classForName(className)) + } + + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + */ + private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = { + // When we instantiate hive UDF wrapper class, we may throw exception if the input + // expressions don't satisfy the hive UDF, such as type mismatch, input number + // mismatch, etc. Here we catch the exception and throw AnalysisException instead. + (children: Seq[Expression]) => { + try { + if (classOf[UDF].isAssignableFrom(clazz)) { + val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children) + udf.dataType // Force it to check input data types. + udf + } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { + val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children) + udf.dataType // Force it to check input data types. + udf + } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children) + udaf.dataType // Force it to check input data types. + udaf + } else if (classOf[UDAF].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction( + name, + new HiveFunctionWrapper(clazz.getName), + children, + isUDAFBridgeRequired = true) + udaf.dataType // Force it to check input data types. + udaf + } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) + udtf.elementTypes // Force it to check input data types. + udtf + } else { + throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") + } + } catch { + case ae: AnalysisException => + throw ae + case NonFatal(e) => + val analysisException = + new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}': $e") + analysisException.setStackTrace(e.getStackTrace) + throw analysisException + } + } + } + + // We have a list of Hive built-in functions that we do not support. So, we will check + // Hive's function registry and lazily load needed functions into our own function registry. + // Those Hive built-in functions are + // assert_true, collect_list, collect_set, compute_stats, context_ngrams, create_union, + // current_user ,elt, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field, + // histogram_numeric, in_file, index, inline, java_method, map_keys, map_values, + // matchpath, ngrams, noop, noopstreaming, noopwithmap, noopwithmapstreaming, + // parse_url, parse_url_tuple, percentile, percentile_approx, posexplode, reflect, reflect2, + // regexp, sentences, stack, std, str_to_map, windowingtablefunction, xpath, xpath_boolean, + // xpath_double, xpath_float, xpath_int, xpath_long, xpath_number, + // xpath_short, and xpath_string. + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to + // if (super.functionExists(name)) { + // super.lookupFunction(name, children) + // } else { + // // This function is a Hive builtin function. + // ... + // } + Try(super.lookupFunction(name, children)) match { + case Success(expr) => expr + case Failure(error) => + if (functionRegistry.functionExists(name)) { + // If the function actually exists in functionRegistry, it means that there is an + // error when we create the Expression using the given children. + // We need to throw the original exception. + throw error + } else { + // This function is not in functionRegistry, let's try to load it as a Hive's + // built-in function. + // Hive is case insensitive. + val functionName = name.toLowerCase + // TODO: This may not really work for current_user because current_user is not evaluated + // with session info. + // We do not need to use executionHive at here because we only load + // Hive's builtin functions, which do not need current db. + val functionInfo = { + try { + Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse( + failFunctionLookup(name)) + } catch { + // If HiveFunctionRegistry.getFunctionInfo throws an exception, + // we are failing to load a Hive builtin function, which means that + // the given function is not a Hive builtin function. + case NonFatal(e) => failFunctionLookup(name) + } + } + val className = functionInfo.getFunctionClass.getName + val builder = makeFunctionBuilder(functionName, className) + // Put this Hive built-in function to our function registry. + val info = new ExpressionInfo(className, functionName) + createTempFunction(functionName, info, builder, ignoreIfExists = false) + // Now, we need to create the Expression. + functionRegistry.lookupFunction(functionName, children) + } + } + } + + // Pre-load a few commonly used Hive built-in functions. + HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach { + case (functionName, clazz) => + val builder = makeFunctionBuilder(functionName, clazz) + val info = new ExpressionInfo(clazz.getCanonicalName, functionName) + createTempFunction(functionName, info, builder, ignoreIfExists = false) + } +} + +private[sql] object HiveSessionCatalog { + // This is the list of Hive's built-in functions that are commonly used and we want to + // pre-load when we create the FunctionRegistry. + val preloadedHiveBuiltinFunctions = + ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) :: + ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala new file mode 100644 index 0000000000000..b992fda18cef7 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.execution.{python, SparkPlanner} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.hive.execution.HiveSqlParser +import org.apache.spark.sql.internal.{SessionState, SQLConf} + + +/** + * A class that holds all session-specific state in a given [[HiveContext]]. + */ +private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) { + + override lazy val conf: SQLConf = new SQLConf { + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + } + + /** + * Internal catalog for managing table and database states. + */ + override lazy val catalog = { + new HiveSessionCatalog( + ctx.hiveCatalog, + ctx.metadataHive, + ctx, + ctx.functionResourceLoader, + functionRegistry, + conf) + } + + /** + * An analyzer that uses the Hive metastore. + */ + override lazy val analyzer: Analyzer = { + new Analyzer(catalog, conf) { + override val extendedResolutionRules = + catalog.ParquetConversions :: + catalog.OrcConversions :: + catalog.CreateTables :: + catalog.PreInsertionCasts :: + PreInsertCastAndRename :: + DataSourceAnalysis :: + (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) + + override val extendedCheckRules = Seq(PreWriteCheck(conf, catalog)) + } + } + + /** + * Parser for HiveQl query texts. + */ + override lazy val sqlParser: ParserInterface = HiveSqlParser + + /** + * Planner that takes into account Hive-specific strategies. + */ + override def planner: SparkPlanner = { + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) + with HiveStrategies { + override val hiveContext = ctx + + override def strategies: Seq[Strategy] = { + experimentalMethods.extraStrategies ++ Seq( + FileSourceStrategy, + DataSourceStrategy, + HiveCommandStrategy(ctx), + HiveDDLStrategy, + DDLStrategy, + SpecialLimits, + InMemoryScans, + HiveTableScans, + DataSinks, + Scripts, + Aggregation, + ExistenceJoin, + EquiJoinSelection, + BasicOperators, + BroadcastNestedLoop, + CartesianProduct, + DefaultJoin + ) + } + } + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index f0697613cff3b..0d2a765a388aa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,25 +20,25 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID -import org.apache.avro.Schema - import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.io.{Input, Output} - +import com.google.common.base.Objects +import org.apache.avro.Schema import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro import org.apache.hadoop.hive.serde2.ColumnProjectionUtils import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector import org.apache.hadoop.io.Writable +import org.apache.hive.com.esotericsoftware.kryo.Kryo +import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.types.Decimal import org.apache.spark.util.Utils @@ -47,6 +47,7 @@ private[hive] object HiveShim { // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) val UNLIMITED_DECIMAL_PRECISION = 38 val UNLIMITED_DECIMAL_SCALE = 18 + val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro" /* * This function in hive-0.13 become private, but we have to do this to walkaround hive bug @@ -125,6 +126,26 @@ private[hive] object HiveShim { // for Serialization def this() = this(null) + override def hashCode(): Int = { + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody()) + } else { + functionClassName.hashCode() + } + } + + override def equals(other: Any): Boolean = other match { + case a: HiveFunctionWrapper if functionClassName == a.functionClassName => + // In case of udf macro, check to make sure they point to the same underlying UDF + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + a.instance.asInstanceOf[GenericUDFMacro].getBody() == + instance.asInstanceOf[GenericUDFMacro].getBody() + } else { + true + } + case _ => false + } + @transient def deserializeObjectByKryo[T: ClassTag]( kryo: Kryo, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d38ad9127327d..010361a32eb34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -22,14 +22,14 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, CreateTempTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.hive.execution._ - private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. - self: SQLContext#SparkPlanner => + self: SparkPlanner => val hiveContext: HiveContext @@ -89,10 +89,14 @@ private[hive] trait HiveStrategies { tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect( - tableIdent, provider, false, partitionCols, mode, opts, query) => - val cmd = - CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) + case c: CreateTableUsingAsSelect if c.temporary => + val cmd = CreateTempTableUsingAsSelect( + c.tableIdent, c.provider, c.partitionColumns, c.mode, c.options, c.child) + ExecutedCommand(cmd) :: Nil + + case c: CreateTableUsingAsSelect => + val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns, + c.bucketSpec, c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case _ => Nil @@ -102,18 +106,8 @@ private[hive] trait HiveStrategies { case class HiveCommandStrategy(context: HiveContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case describe: DescribeCommand => - val resolvedTable = context.executePlan(describe.table).analyzed - resolvedTable match { - case t: MetastoreRelation => - ExecutedCommand( - DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil - - case o: LogicalPlan => - val resultPlan = context.executePlan(o).executedPlan - ExecutedCommand(RunnableDescribeCommand( - resultPlan, describe.output, describe.isExtended)) :: Nil - } - + ExecutedCommand( + DescribeHiveTableCommand(describe.table, describe.output, describe.isExtended)) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala new file mode 100644 index 0000000000000..e54358e657690 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.util.concurrent.atomic.AtomicLong + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, CombineUnions} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.execution.HiveScriptIOSchema +import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType} + +/** + * A builder class used to convert a resolved logical plan into a SQL query string. Note that not + * all resolved logical plan are convertible. They either don't have corresponding SQL + * representations (e.g. logical plans that operate on local Scala collections), or are simply not + * supported by this builder (yet). + */ +class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + require(logicalPlan.resolved, "SQLBuilder only supports resolved logical query plans") + + def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) + + private val nextSubqueryId = new AtomicLong(0) + private def newSubqueryName(): String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" + + def toSQL: String = { + val canonicalizedPlan = Canonicalizer.execute(logicalPlan) + val outputNames = logicalPlan.output.map(_.name) + val qualifiers = logicalPlan.output.flatMap(_.qualifier).distinct + + // Keep the qualifier information by using it as sub-query name, if there is only one qualifier + // present. + val finalName = if (qualifiers.length == 1) { + qualifiers.head + } else { + newSubqueryName() + } + + // Canonicalizer will remove all naming information, we should add it back by adding an extra + // Project and alias the outputs. + val aliasedOutput = canonicalizedPlan.output.zip(outputNames).map { + case (attr, name) => Alias(attr.withQualifier(None), name)() + } + val finalPlan = Project(aliasedOutput, SubqueryAlias(finalName, canonicalizedPlan)) + + try { + val replaced = finalPlan.transformAllExpressions { + case e: SubqueryExpression => + SubqueryHolder(new SQLBuilder(e.query, sqlContext).toSQL) + case e: NonSQLExpression => + throw new UnsupportedOperationException( + s"Expression $e doesn't have a SQL representation" + ) + case e => e + } + + val generatedSQL = toSQL(replaced) + logDebug( + s"""Built SQL query string successfully from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${replaced.treeString} + |# Generated SQL: + |$generatedSQL + """.stripMargin) + generatedSQL + } catch { case NonFatal(e) => + logDebug( + s"""Failed to build SQL query string from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + """.stripMargin) + throw e + } + } + + private def toSQL(node: LogicalPlan): String = node match { + case Distinct(p: Project) => + projectToSQL(p, isDistinct = true) + + case p: Project => + projectToSQL(p, isDistinct = false) + + case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) => + groupingSetToSQL(a, e, p) + + case p: Aggregate => + aggregateToSQL(p) + + case w: Window => + windowToSQL(w) + + case g: Generate => + generateToSQL(g) + + case Limit(limitExpr, child) => + s"${toSQL(child)} LIMIT ${limitExpr.sql}" + + case Filter(condition, child) => + val whereOrHaving = child match { + case _: Aggregate => "HAVING" + case _ => "WHERE" + } + build(toSQL(child), whereOrHaving, condition.sql) + + case p @ Distinct(u: Union) if u.children.length > 1 => + val childrenSql = u.children.map(c => s"(${toSQL(c)})") + childrenSql.mkString(" UNION DISTINCT ") + + case p: Union if p.children.length > 1 => + val childrenSql = p.children.map(c => s"(${toSQL(c)})") + childrenSql.mkString(" UNION ALL ") + + case p: Intersect => + build("(" + toSQL(p.left), ") INTERSECT (", toSQL(p.right) + ")") + + case p: Except => + build("(" + toSQL(p.left), ") EXCEPT (", toSQL(p.right) + ")") + + case p: SubqueryAlias => build("(" + toSQL(p.child) + ")", "AS", p.alias) + + case p: Join => + build( + toSQL(p.left), + p.joinType.sql, + "JOIN", + toSQL(p.right), + p.condition.map(" ON " + _.sql).getOrElse("")) + + case SQLTable(database, table, _, sample) => + val qualifiedName = s"${quoteIdentifier(database)}.${quoteIdentifier(table)}" + sample.map { case (lowerBound, upperBound) => + val fraction = math.min(100, math.max(0, (upperBound - lowerBound) * 100)) + qualifiedName + " TABLESAMPLE(" + fraction + " PERCENT)" + }.getOrElse(qualifiedName) + + case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) + if orders.map(_.child) == partitionExprs => + build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", ")) + + case p: Sort => + build( + toSQL(p.child), + if (p.global) "ORDER BY" else "SORT BY", + p.order.map(_.sql).mkString(", ") + ) + + case p: RepartitionByExpression => + build( + toSQL(p.child), + "DISTRIBUTE BY", + p.partitionExpressions.map(_.sql).mkString(", ") + ) + + case p: ScriptTransformation => + scriptTransformationToSQL(p) + + case OneRowRelation => + "" + + case _ => + throw new UnsupportedOperationException(s"unsupported plan $node") + } + + /** + * Turns a bunch of string segments into a single string and separate each segment by a space. + * The segments are trimmed so only a single space appears in the separation. + * For example, `build("a", " b ", " c")` becomes "a b c". + */ + private def build(segments: String*): String = + segments.map(_.trim).filter(_.nonEmpty).mkString(" ") + + private def projectToSQL(plan: Project, isDistinct: Boolean): String = { + build( + "SELECT", + if (isDistinct) "DISTINCT" else "", + plan.projectList.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child) + ) + } + + private def scriptTransformationToSQL(plan: ScriptTransformation): String = { + val ioSchema = plan.ioschema.asInstanceOf[HiveScriptIOSchema] + val inputRowFormatSQL = ioSchema.inputRowFormatSQL.getOrElse( + throw new UnsupportedOperationException( + s"unsupported row format ${ioSchema.inputRowFormat}")) + val outputRowFormatSQL = ioSchema.outputRowFormatSQL.getOrElse( + throw new UnsupportedOperationException( + s"unsupported row format ${ioSchema.outputRowFormat}")) + + val outputSchema = plan.output.map { attr => + s"${attr.sql} ${attr.dataType.simpleString}" + }.mkString(", ") + + build( + "SELECT TRANSFORM", + "(" + plan.input.map(_.sql).mkString(", ") + ")", + inputRowFormatSQL, + s"USING \'${plan.script}\'", + "AS (" + outputSchema + ")", + outputRowFormatSQL, + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child) + ) + } + + private def aggregateToSQL(plan: Aggregate): String = { + val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ") + build( + "SELECT", + plan.aggregateExpressions.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child), + if (groupingSQL.isEmpty) "" else "GROUP BY", + groupingSQL + ) + } + + private def generateToSQL(g: Generate): String = { + val columnAliases = g.generatorOutput.map(_.sql).mkString(", ") + + val childSQL = if (g.child == OneRowRelation) { + // This only happens when we put UDTF in project list and there is no FROM clause. Because we + // always generate LATERAL VIEW for `Generate`, here we use a trick to put a dummy sub-query + // after FROM clause, so that we can generate a valid LATERAL VIEW SQL string. + // For example, if the original SQL is: "SELECT EXPLODE(ARRAY(1, 2))", we will convert in to + // LATERAL VIEW format, and generate: + // SELECT col FROM (SELECT 1) sub_q0 LATERAL VIEW EXPLODE(ARRAY(1, 2)) sub_q1 AS col + s"(SELECT 1) ${newSubqueryName()}" + } else { + toSQL(g.child) + } + + // The final SQL string for Generate contains 7 parts: + // 1. the SQL of child, can be a table or sub-query + // 2. the LATERAL VIEW keyword + // 3. an optional OUTER keyword + // 4. the SQL of generator, e.g. EXPLODE(array_col) + // 5. the table alias for output columns of generator. + // 6. the AS keyword + // 7. the column alias, can be more than one, e.g. AS key, value + // An concrete example: "tbl LATERAL VIEW EXPLODE(map_col) sub_q AS key, value", and the builder + // will put it in FROM clause later. + build( + childSQL, + "LATERAL VIEW", + if (g.outer) "OUTER" else "", + g.generator.sql, + newSubqueryName(), + "AS", + columnAliases + ) + } + + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = + output1.size == output2.size && + output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) + + private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = { + assert(a.child == e && e.child == p) + a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && + sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) + } + + private def groupingSetToSQL( + agg: Aggregate, + expand: Expand, + project: Project): String = { + assert(agg.groupingExpressions.length > 1) + + // The last column of Expand is always grouping ID + val gid = expand.output.last + + val numOriginalOutput = project.child.output.length + // Assumption: Aggregate's groupingExpressions is composed of + // 1) the attributes of aliased group by expressions + // 2) gid, which is always the last one + val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) + // Assumption: Project's projectList is composed of + // 1) the original output (Project's child.output), + // 2) the aliased group by expressions. + val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) + val groupingSQL = groupByExprs.map(_.sql).mkString(", ") + + // a map from group by attributes to the original group by expressions. + val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) + + val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project => + // Assumption: expand.projections is composed of + // 1) the original output (Project's child.output), + // 2) group by attributes(or null literal) + // 3) gid, which is always the last one in each project in Expand + project.drop(numOriginalOutput).dropRight(1).collect { + case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) + } + } + val groupingSetSQL = "GROUPING SETS(" + + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" + + val aggExprs = agg.aggregateExpressions.map { case aggExpr => + val originalAggExpr = aggExpr.transformDown { + // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. + case ar: AttributeReference if ar == gid => GroupingID(Nil) + case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar) + case a @ Cast(BitwiseAnd( + ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)), + Literal(1, IntegerType)), ByteType) if ar == gid => + // for converting an expression to its original SQL format grouping(col) + val idx = groupByExprs.length - 1 - value.asInstanceOf[Int] + groupByExprs.lift(idx).map(Grouping).getOrElse(a) + } + + originalAggExpr match { + // Ancestor operators may reference the output of this grouping set, and we use exprId to + // generate a unique name for each attribute, so we should make sure the transformed + // aggregate expression won't change the output, i.e. exprId and alias name should remain + // the same. + case ne: NamedExpression if ne.exprId == aggExpr.exprId => ne + case e => Alias(e, normalizedName(aggExpr))(exprId = aggExpr.exprId) + } + } + + build( + "SELECT", + aggExprs.map(_.sql).mkString(", "), + if (agg.child == OneRowRelation) "" else "FROM", + toSQL(project.child), + "GROUP BY", + groupingSQL, + groupingSetSQL + ) + } + + private def windowToSQL(w: Window): String = { + build( + "SELECT", + (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "), + if (w.child == OneRowRelation) "" else "FROM", + toSQL(w.child) + ) + } + + private def normalizedName(n: NamedExpression): String = "gen_attr_" + n.exprId.id + + object Canonicalizer extends RuleExecutor[LogicalPlan] { + override protected def batches: Seq[Batch] = Seq( + Batch("Prepare", FixedPoint(100), + // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over + // `Aggregate`s to perform type casting. This rule merges these `Project`s into + // `Aggregate`s. + CollapseProject, + // Parser is unable to parse the following query: + // SELECT `u_1`.`id` + // FROM (((SELECT `t0`.`id` FROM `default`.`t0`) + // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) + // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 + // This rule combine adjacent Unions together so we can generate flat UNION ALL SQL string. + CombineUnions), + Batch("Recover Scoping Info", Once, + // A logical plan is allowed to have same-name outputs with different qualifiers(e.g. the + // `Join` operator). However, this kind of plan can't be put under a sub query as we will + // erase and assign a new qualifier to all outputs and make it impossible to distinguish + // same-name outputs. This rule renames all attributes, to guarantee different + // attributes(with different exprId) always have different names. It also removes all + // qualifiers, as attributes have unique names now and we don't need qualifiers to resolve + // ambiguity. + NormalizedAttribute, + // Our analyzer will add one or more sub-queries above table relation, this rule removes + // these sub-queries so that next rule can combine adjacent table relation and sample to + // SQLTable. + RemoveSubqueriesAboveSQLTable, + // Finds the table relations and wrap them with `SQLTable`s. If there are any `Sample` + // operators on top of a table relation, merge the sample information into `SQLTable` of + // that table relation, as we can only convert table sample to standard SQL string. + ResolveSQLTable, + // Insert sub queries on top of operators that need to appear after FROM clause. + AddSubquery + ) + ) + + object NormalizedAttribute extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case a: AttributeReference => + AttributeReference(normalizedName(a), a.dataType)(exprId = a.exprId, qualifier = None) + case a: Alias => + Alias(a.child, normalizedName(a))(exprId = a.exprId, qualifier = None) + } + } + + object RemoveSubqueriesAboveSQLTable extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case SubqueryAlias(_, t @ ExtractSQLTable(_)) => t + } + } + + object ResolveSQLTable extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown { + case Sample(lowerBound, upperBound, _, _, ExtractSQLTable(table)) => + aliasColumns(table.withSample(lowerBound, upperBound)) + case ExtractSQLTable(table) => + aliasColumns(table) + } + + /** + * Aliases the table columns to the generated attribute names, as we use exprId to generate + * unique name for each attribute when normalize attributes, and we can't reference table + * columns with their real names. + */ + private def aliasColumns(table: SQLTable): LogicalPlan = { + val aliasedOutput = table.output.map { attr => + Alias(attr, normalizedName(attr))(exprId = attr.exprId) + } + addSubquery(Project(aliasedOutput, table)) + } + } + + object AddSubquery extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp { + // This branch handles aggregate functions within HAVING clauses. For example: + // + // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" + // + // This kind of query results in query plans of the following form because of analysis rule + // `ResolveAggregateFunctions`: + // + // Project ... + // +- Filter ... + // +- Aggregate ... + // +- MetastoreRelation default, src, None + case p @ Project(_, f @ Filter(_, _: Aggregate)) => p.copy(child = addSubquery(f)) + + case w @ Window(_, _, _, f @ Filter(_, _: Aggregate)) => w.copy(child = addSubquery(f)) + + case p: Project => p.copy(child = addSubqueryIfNeeded(p.child)) + + // We will generate "SELECT ... FROM ..." for Window operator, so its child operator should + // be able to put in the FROM clause, or we wrap it with a subquery. + case w: Window => w.copy(child = addSubqueryIfNeeded(w.child)) + + case j: Join => j.copy( + left = addSubqueryIfNeeded(j.left), + right = addSubqueryIfNeeded(j.right)) + + // A special case for Generate. When we put UDTF in project list, followed by WHERE, e.g. + // SELECT EXPLODE(arr) FROM tbl WHERE id > 1, the Filter operator will be under Generate + // operator and we need to add a sub-query between them, as it's not allowed to have a WHERE + // before LATERAL VIEW, e.g. "... FROM tbl WHERE id > 2 EXPLODE(arr) ..." is illegal. + case g @ Generate(_, _, _, _, _, f: Filter) => + // Add an extra `Project` to make sure we can generate legal SQL string for sub-query, + // for example, Subquery -> Filter -> Table will generate "(tbl WHERE ...) AS name", which + // misses the SELECT part. + val proj = Project(f.output, f) + g.copy(child = addSubquery(proj)) + } + } + + private def addSubquery(plan: LogicalPlan): SubqueryAlias = { + SubqueryAlias(newSubqueryName(), plan) + } + + private def addSubqueryIfNeeded(plan: LogicalPlan): LogicalPlan = plan match { + case _: SubqueryAlias => plan + case _: Filter => plan + case _: Join => plan + case _: LocalLimit => plan + case _: GlobalLimit => plan + case _: SQLTable => plan + case _: Generate => plan + case OneRowRelation => plan + case _ => addSubquery(plan) + } + } + + case class SQLTable( + database: String, + table: String, + output: Seq[Attribute], + sample: Option[(Double, Double)] = None) extends LeafNode { + def withSample(lowerBound: Double, upperBound: Double): SQLTable = + this.copy(sample = Some(lowerBound -> upperBound)) + } + + object ExtractSQLTable { + def unapply(plan: LogicalPlan): Option[SQLTable] = plan match { + case l @ LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => + Some(SQLTable(database, table, l.output.map(_.withQualifier(None)))) + + case m: MetastoreRelation => + Some(SQLTable(m.databaseName, m.tableName, m.output.map(_.withQualifier(None)))) + + case _ => None + } + } + + /** + * A place holder for generated SQL for subquery expression. + */ + case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable { + override def dataType: DataType = NullType + override def nullable: Boolean = true + override def sql: String = s"($query)" + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 69f481c49a655..54afe9c2a3550 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -23,16 +23,19 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable, Hive, HiveUtils, HiveStorageHandler} +import org.apache.hadoop.hive.ql.metadata.{HiveUtils, Partition => HivePartition, + Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, + StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -72,8 +75,7 @@ class HadoopTableReader( math.max(sc.hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinPartitions) } - // TODO: set aws s3 credentials. - + SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveExtraConf) private val _broadcastedHiveConf = sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) @@ -382,6 +384,9 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) + case oi: HiveCharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala deleted file mode 100644 index 9d9a55edd7314..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.client - -import java.io.PrintStream -import java.util.{Map => JMap} -import javax.annotation.Nullable - -import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} -import org.apache.spark.sql.catalyst.expressions.Expression - -private[hive] case class HiveDatabase(name: String, location: String) - -private[hive] abstract class TableType { val name: String } -private[hive] case object ExternalTable extends TableType { override val name = "EXTERNAL_TABLE" } -private[hive] case object IndexTable extends TableType { override val name = "INDEX_TABLE" } -private[hive] case object ManagedTable extends TableType { override val name = "MANAGED_TABLE" } -private[hive] case object VirtualView extends TableType { override val name = "VIRTUAL_VIEW" } - -// TODO: Use this for Tables and Partitions -private[hive] case class HiveStorageDescriptor( - location: String, - inputFormat: String, - outputFormat: String, - serde: String, - serdeProperties: Map[String, String]) - -private[hive] case class HivePartition( - values: Seq[String], - storage: HiveStorageDescriptor) - -private[hive] case class HiveColumn(name: String, @Nullable hiveType: String, comment: String) -private[hive] case class HiveTable( - specifiedDatabase: Option[String], - name: String, - schema: Seq[HiveColumn], - partitionColumns: Seq[HiveColumn], - properties: Map[String, String], - serdeProperties: Map[String, String], - tableType: TableType, - location: Option[String] = None, - inputFormat: Option[String] = None, - outputFormat: Option[String] = None, - serde: Option[String] = None, - viewText: Option[String] = None) { - - @transient - private[client] var client: ClientInterface = _ - - private[client] def withClient(ci: ClientInterface): this.type = { - client = ci - this - } - - def database: String = specifiedDatabase.getOrElse(sys.error("database not resolved")) - - def isPartitioned: Boolean = partitionColumns.nonEmpty - - def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) - - def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = - client.getPartitionsByFilter(this, predicates) - - // Hive does not support backticks when passing names to the client. - def qualifiedName: String = s"$database.$name" -} - -/** - * An externally visible interface to the Hive client. This interface is shared across both the - * internal and external classloaders for a given version of Hive and thus must expose only - * shared classes. - */ -private[hive] trait ClientInterface { - - /** Returns the Hive Version of this client. */ - def version: HiveVersion - - /** Returns the configuration for the given key in the current session. */ - def getConf(key: String, defaultValue: String): String - - /** - * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will - * result in one string. - */ - def runSqlHive(sql: String): Seq[String] - - def setOut(stream: PrintStream): Unit - def setInfo(stream: PrintStream): Unit - def setError(stream: PrintStream): Unit - - /** Returns the names of all tables in the given database. */ - def listTables(dbName: String): Seq[String] - - /** Returns the name of the active database. */ - def currentDatabase: String - - /** Returns the metadata for specified database, throwing an exception if it doesn't exist */ - def getDatabase(name: String): HiveDatabase = { - getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException) - } - - /** Returns the metadata for a given database, or None if it doesn't exist. */ - def getDatabaseOption(name: String): Option[HiveDatabase] - - /** Returns the specified table, or throws [[NoSuchTableException]]. */ - def getTable(dbName: String, tableName: String): HiveTable = { - getTableOption(dbName, tableName).getOrElse(throw new NoSuchTableException) - } - - /** Returns the metadata for the specified table or None if it doens't exist. */ - def getTableOption(dbName: String, tableName: String): Option[HiveTable] - - /** Creates a view with the given metadata. */ - def createView(view: HiveTable): Unit - - /** Updates the given view with new metadata. */ - def alertView(view: HiveTable): Unit - - /** Creates a table with the given metadata. */ - def createTable(table: HiveTable): Unit - - /** Updates the given table with new metadata. */ - def alterTable(table: HiveTable): Unit - - /** Creates a new database with the given name. */ - def createDatabase(database: HiveDatabase): Unit - - /** Returns the specified paritition or None if it does not exist. */ - def getPartitionOption( - hTable: HiveTable, - partitionSpec: JMap[String, String]): Option[HivePartition] - - /** Returns all partitions for the given table. */ - def getAllPartitions(hTable: HiveTable): Seq[HivePartition] - - /** Returns partitions filtered by predicates for the given table. */ - def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition] - - /** Loads a static partition into an existing table. */ - def loadPartition( - loadPath: String, - tableName: String, - partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering - replace: Boolean, - holdDDLTime: Boolean, - inheritTableSpecs: Boolean, - isSkewedStoreAsSubdir: Boolean): Unit - - /** Loads data into an existing table. */ - def loadTable( - loadPath: String, // TODO URI - tableName: String, - replace: Boolean, - holdDDLTime: Boolean): Unit - - /** Loads new dynamic partitions into an existing table. */ - def loadDynamicPartitions( - loadPath: String, - tableName: String, - partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering - replace: Boolean, - numDP: Int, - holdDDLTime: Boolean, - listBucketingEnabled: Boolean): Unit - - /** Add a jar into class loader */ - def addJar(path: String): Unit - - /** Return a ClientInterface as new session, that will share the class loader and Hive client */ - def newSession(): ClientInterface - - /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */ - def withHiveState[A](f: => A): A - - /** Used for testing only. Removes all metadata from this instance of Hive. */ - def reset(): Unit -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala deleted file mode 100644 index 3dce86c480747..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ /dev/null @@ -1,575 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.client - -import java.io.{File, PrintStream} -import java.util.{Map => JMap} -import javax.annotation.concurrent.GuardedBy - -import scala.collection.JavaConverters._ -import scala.language.reflectiveCalls - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} -import org.apache.hadoop.hive.metastore.{TableType => HTableType} -import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Driver, metadata} -import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} -import org.apache.hadoop.util.VersionInfo - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.util.{CircularBuffer, Utils} - -/** - * A class that wraps the HiveClient and converts its responses to externally visible classes. - * Note that this class is typically loaded with an internal classloader for each instantiation, - * allowing it to interact directly with a specific isolated version of Hive. Loading this class - * with the isolated classloader however will result in it only being visible as a ClientInterface, - * not a ClientWrapper. - * - * This class needs to interact with multiple versions of Hive, but will always be compiled with - * the 'native', execution version of Hive. Therefore, any places where hive breaks compatibility - * must use reflection after matching on `version`. - * - * @param version the version of hive used when pick function calls that are not compatible. - * @param config a collection of configuration options that will be added to the hive conf before - * opening the hive client. - * @param initClassLoader the classloader used when creating the `state` field of - * this ClientWrapper. - */ -private[hive] class ClientWrapper( - override val version: HiveVersion, - config: Map[String, String], - initClassLoader: ClassLoader, - val clientLoader: IsolatedClientLoader) - extends ClientInterface - with Logging { - - overrideHadoopShims() - - // !! HACK ALERT !! - // - // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking - // major version number gathered from Hadoop jar files: - // - // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with - // security. - // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. - // - // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It - // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but - // `Hadoop23Shims` is chosen because the major version number here is 2. - // - // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` - // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. If Hadoop version information is - // not available, we decide whether to override the shims or not by checking for existence of a - // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. - private def overrideHadoopShims(): Unit = { - val hadoopVersion = VersionInfo.getVersion - val VersionPattern = """(\d+)\.(\d+).*""".r - - hadoopVersion match { - case null => - logError("Failed to inspect Hadoop version") - - // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. - val probeMethod = "getPathWithoutSchemeAndAuthority" - if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { - logInfo( - s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + - s"we are probably using Hadoop 1.x or 2.0.x") - loadHadoop20SShims() - } - - case VersionPattern(majorVersion, minorVersion) => - logInfo(s"Inspected Hadoop version: $hadoopVersion") - - // Loads Hadoop20SShims for 1.x and 2.0.x versions - val (major, minor) = (majorVersion.toInt, minorVersion.toInt) - if (major < 2 || (major == 2 && minor == 0)) { - loadHadoop20SShims() - } - } - - // Logs the actual loaded Hadoop shims class - val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName - logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") - } - - private def loadHadoop20SShims(): Unit = { - val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" - logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") - - try { - val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") - // scalastyle:off classforname - val shimsClass = Class.forName(hadoop20SShimsClassName) - // scalastyle:on classforname - val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) - shimsField.setAccessible(true) - shimsField.set(null, shims) - } catch { case cause: Throwable => - throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) - } - } - - // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. - private val outputBuffer = new CircularBuffer() - - private val shim = version match { - case hive.v12 => new Shim_v0_12() - case hive.v13 => new Shim_v0_13() - case hive.v14 => new Shim_v0_14() - case hive.v1_0 => new Shim_v1_0() - case hive.v1_1 => new Shim_v1_1() - case hive.v1_2 => new Shim_v1_2() - } - - // Create an internal session state for this ClientWrapper. - val state = { - val original = Thread.currentThread().getContextClassLoader - // Switch to the initClassLoader. - Thread.currentThread().setContextClassLoader(initClassLoader) - val ret = try { - val initialConf = new HiveConf(classOf[SessionState]) - // HiveConf is a Hadoop Configuration, which has a field of classLoader and - // the initial value will be the current thread's context class loader - // (i.e. initClassLoader at here). - // We call initialConf.setClassLoader(initClassLoader) at here to make - // this action explicit. - initialConf.setClassLoader(initClassLoader) - config.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { - logDebug(s"Hive Config: $k=xxx") - } else { - logDebug(s"Hive Config: $k=$v") - } - initialConf.set(k, v) - } - val state = new SessionState(initialConf) - if (clientLoader.cachedHive != null) { - Hive.set(clientLoader.cachedHive.asInstanceOf[Hive]) - } - SessionState.start(state) - state.out = new PrintStream(outputBuffer, true, "UTF-8") - state.err = new PrintStream(outputBuffer, true, "UTF-8") - state - } finally { - Thread.currentThread().setContextClassLoader(original) - } - ret - } - - /** Returns the configuration for the current session. */ - def conf: HiveConf = SessionState.get().getConf - - override def getConf(key: String, defaultValue: String): String = { - conf.get(key, defaultValue) - } - - // We use hive's conf for compatibility. - private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) - private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) - - /** - * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. - */ - private def retryLocked[A](f: => A): A = clientLoader.synchronized { - // Hive sometimes retries internally, so set a deadline to avoid compounding delays. - val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong - var numTries = 0 - var caughtException: Exception = null - do { - numTries += 1 - try { - return f - } catch { - case e: Exception if causedByThrift(e) => - caughtException = e - logWarning( - "HiveClientWrapper got thrift exception, destroying client and retrying " + - s"(${retryLimit - numTries} tries remaining)", e) - clientLoader.cachedHive = null - Thread.sleep(retryDelayMillis) - } - } while (numTries <= retryLimit && System.nanoTime < deadline) - if (System.nanoTime > deadline) { - logWarning("Deadline exceeded") - } - throw caughtException - } - - private def causedByThrift(e: Throwable): Boolean = { - var target = e - while (target != null) { - val msg = target.getMessage() - if (msg != null && msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) { - return true - } - target = target.getCause() - } - false - } - - def client: Hive = { - if (clientLoader.cachedHive != null) { - clientLoader.cachedHive.asInstanceOf[Hive] - } else { - val c = Hive.get(conf) - clientLoader.cachedHive = c - c - } - } - - /** - * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. - */ - def withHiveState[A](f: => A): A = retryLocked { - val original = Thread.currentThread().getContextClassLoader - // Set the thread local metastore client to the client associated with this ClientWrapper. - Hive.set(client) - // The classloader in clientLoader could be changed after addJar, always use the latest - // classloader - state.getConf.setClassLoader(clientLoader.classLoader) - // setCurrentSessionState will use the classLoader associated - // with the HiveConf in `state` to override the context class loader of the current - // thread. - shim.setCurrentSessionState(state) - val ret = try f finally { - Thread.currentThread().setContextClassLoader(original) - } - ret - } - - def setOut(stream: PrintStream): Unit = withHiveState { - state.out = stream - } - - def setInfo(stream: PrintStream): Unit = withHiveState { - state.info = stream - } - - def setError(stream: PrintStream): Unit = withHiveState { - state.err = stream - } - - override def currentDatabase: String = withHiveState { - state.getCurrentDatabase - } - - override def createDatabase(database: HiveDatabase): Unit = withHiveState { - client.createDatabase( - new Database( - database.name, - "", - new File(database.location).toURI.toString, - new java.util.HashMap), - true) - } - - override def getDatabaseOption(name: String): Option[HiveDatabase] = withHiveState { - Option(client.getDatabase(name)).map { d => - HiveDatabase( - name = d.getName, - location = d.getLocationUri) - } - } - - override def getTableOption( - dbName: String, - tableName: String): Option[HiveTable] = withHiveState { - - logDebug(s"Looking up $dbName.$tableName") - - val hiveTable = Option(client.getTable(dbName, tableName, false)) - val converted = hiveTable.map { h => - - HiveTable( - name = h.getTableName, - specifiedDatabase = Option(h.getDbName), - schema = h.getCols.asScala.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - partitionColumns = h.getPartCols.asScala.map(f => - HiveColumn(f.getName, f.getType, f.getComment)), - properties = h.getParameters.asScala.toMap, - serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap, - tableType = h.getTableType match { - case HTableType.MANAGED_TABLE => ManagedTable - case HTableType.EXTERNAL_TABLE => ExternalTable - case HTableType.VIRTUAL_VIEW => VirtualView - case HTableType.INDEX_TABLE => IndexTable - }, - location = shim.getDataLocation(h), - inputFormat = Option(h.getInputFormatClass).map(_.getName), - outputFormat = Option(h.getOutputFormatClass).map(_.getName), - serde = Option(h.getSerializationLib), - viewText = Option(h.getViewExpandedText)).withClient(this) - } - converted - } - - private def toInputFormat(name: String) = - Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] - - private def toOutputFormat(name: String) = - Utils.classForName(name) - .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] - - private def toQlTable(table: HiveTable): metadata.Table = { - val qlTable = new metadata.Table(table.database, table.name) - - qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - qlTable.setPartCols( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - table.properties.foreach { case (k, v) => qlTable.setProperty(k, v) } - table.serdeProperties.foreach { case (k, v) => qlTable.setSerdeParam(k, v) } - - // set owner - qlTable.setOwner(conf.getUser) - // set create time - qlTable.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) - - table.location.foreach { loc => shim.setDataLocation(qlTable, loc) } - table.inputFormat.map(toInputFormat).foreach(qlTable.setInputFormatClass) - table.outputFormat.map(toOutputFormat).foreach(qlTable.setOutputFormatClass) - table.serde.foreach(qlTable.setSerializationLib) - - qlTable - } - - private def toViewTable(view: HiveTable): metadata.Table = { - // TODO: this is duplicated with `toQlTable` except the table type stuff. - val tbl = new metadata.Table(view.database, view.name) - tbl.setTableType(HTableType.VIRTUAL_VIEW) - tbl.setSerializationLib(null) - tbl.clearSerDeInfo() - - // TODO: we will save the same SQL string to original and expanded text, which is different - // from Hive. - tbl.setViewOriginalText(view.viewText.get) - tbl.setViewExpandedText(view.viewText.get) - - tbl.setFields(view.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) - view.properties.foreach { case (k, v) => tbl.setProperty(k, v) } - - // set owner - tbl.setOwner(conf.getUser) - // set create time - tbl.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) - - tbl - } - - override def createView(view: HiveTable): Unit = withHiveState { - client.createTable(toViewTable(view)) - } - - override def alertView(view: HiveTable): Unit = withHiveState { - client.alterTable(view.qualifiedName, toViewTable(view)) - } - - override def createTable(table: HiveTable): Unit = withHiveState { - val qlTable = toQlTable(table) - client.createTable(qlTable) - } - - override def alterTable(table: HiveTable): Unit = withHiveState { - val qlTable = toQlTable(table) - client.alterTable(table.qualifiedName, qlTable) - } - - private def toHivePartition(partition: metadata.Partition): HivePartition = { - val apiPartition = partition.getTPartition - HivePartition( - values = Option(apiPartition.getValues).map(_.asScala).getOrElse(Seq.empty), - storage = HiveStorageDescriptor( - location = apiPartition.getSd.getLocation, - inputFormat = apiPartition.getSd.getInputFormat, - outputFormat = apiPartition.getSd.getOutputFormat, - serde = apiPartition.getSd.getSerdeInfo.getSerializationLib, - serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) - } - - override def getPartitionOption( - table: HiveTable, - partitionSpec: JMap[String, String]): Option[HivePartition] = withHiveState { - - val qlTable = toQlTable(table) - val qlPartition = client.getPartition(qlTable, partitionSpec, false) - Option(qlPartition).map(toHivePartition) - } - - override def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withHiveState { - val qlTable = toQlTable(hTable) - shim.getAllPartitions(client, qlTable).map(toHivePartition) - } - - override def getPartitionsByFilter( - hTable: HiveTable, - predicates: Seq[Expression]): Seq[HivePartition] = withHiveState { - val qlTable = toQlTable(hTable) - shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition) - } - - override def listTables(dbName: String): Seq[String] = withHiveState { - client.getAllTables(dbName).asScala - } - - /** - * Runs the specified SQL query using Hive. - */ - override def runSqlHive(sql: String): Seq[String] = { - val maxResults = 100000 - val results = runHive(sql, maxResults) - // It is very confusing when you only get back some of the results... - if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") - results - } - - /** - * Execute the command using Hive and return the results as a sequence. Each element - * in the sequence is one row. - */ - protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState { - logDebug(s"Running hiveql '$cmd'") - if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") } - try { - val cmd_trimmed: String = cmd.trim() - val tokens: Array[String] = cmd_trimmed.split("\\s+") - // The remainder of the command. - val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - val proc = shim.getCommandProcessor(tokens(0), conf) - proc match { - case driver: Driver => - val response: CommandProcessorResponse = driver.run(cmd) - // Throw an exception if there is an error in query processing. - if (response.getResponseCode != 0) { - driver.close() - throw new QueryExecutionException(response.getErrorMessage) - } - driver.setMaxRows(maxRows) - - val results = shim.getDriverResults(driver) - driver.close() - results - - case _ => - if (state.out != null) { - // scalastyle:off println - state.out.println(tokens(0) + " " + cmd_1) - // scalastyle:on println - } - Seq(proc.run(cmd_1).getResponseCode.toString) - } - } catch { - case e: Exception => - logError( - s""" - |====================== - |HIVE FAILURE OUTPUT - |====================== - |${outputBuffer.toString} - |====================== - |END HIVE FAILURE OUTPUT - |====================== - """.stripMargin) - throw e - } - } - - def loadPartition( - loadPath: String, - tableName: String, - partSpec: java.util.LinkedHashMap[String, String], - replace: Boolean, - holdDDLTime: Boolean, - inheritTableSpecs: Boolean, - isSkewedStoreAsSubdir: Boolean): Unit = withHiveState { - shim.loadPartition( - client, - new Path(loadPath), // TODO: Use URI - tableName, - partSpec, - replace, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) - } - - def loadTable( - loadPath: String, // TODO URI - tableName: String, - replace: Boolean, - holdDDLTime: Boolean): Unit = withHiveState { - shim.loadTable( - client, - new Path(loadPath), - tableName, - replace, - holdDDLTime) - } - - def loadDynamicPartitions( - loadPath: String, - tableName: String, - partSpec: java.util.LinkedHashMap[String, String], - replace: Boolean, - numDP: Int, - holdDDLTime: Boolean, - listBucketingEnabled: Boolean): Unit = withHiveState { - shim.loadDynamicPartitions( - client, - new Path(loadPath), - tableName, - partSpec, - replace, - numDP, - holdDDLTime, - listBucketingEnabled) - } - - def addJar(path: String): Unit = { - clientLoader.addJar(path) - runSqlHive(s"ADD JAR $path") - } - - def newSession(): ClientWrapper = { - clientLoader.createClient().asInstanceOf[ClientWrapper] - } - - def reset(): Unit = withHiveState { - client.getAllTables("default").asScala.foreach { t => - logDebug(s"Deleting table $t") - val table = client.getTable("default", t) - client.getIndexes("default", t, 255).asScala.foreach { index => - shim.dropIndex(client, "default", t, index.getIndexName) - } - if (!table.isIndexTable) { - client.dropTable("default", t) - } - } - client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => - logDebug(s"Dropping Database: $db") - client.dropDatabase(db, true, false, true) - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala new file mode 100644 index 0000000000000..6f7e7bf45106f --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.PrintStream + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Expression + + +/** + * An externally visible interface to the Hive client. This interface is shared across both the + * internal and external classloaders for a given version of Hive and thus must expose only + * shared classes. + */ +private[hive] trait HiveClient { + + /** Returns the Hive Version of this client. */ + def version: HiveVersion + + /** Returns the configuration for the given key in the current session. */ + def getConf(key: String, defaultValue: String): String + + /** + * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will + * result in one string. + */ + def runSqlHive(sql: String): Seq[String] + + def setOut(stream: PrintStream): Unit + def setInfo(stream: PrintStream): Unit + def setError(stream: PrintStream): Unit + + /** Returns the names of all tables in the given database. */ + def listTables(dbName: String): Seq[String] + + /** Returns the names of tables in the given database that matches the given pattern. */ + def listTables(dbName: String, pattern: String): Seq[String] + + /** Sets the name of current database. */ + def setCurrentDatabase(databaseName: String): Unit + + /** Returns the metadata for specified database, throwing an exception if it doesn't exist */ + final def getDatabase(name: String): CatalogDatabase = { + getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException(name)) + } + + /** Returns the metadata for a given database, or None if it doesn't exist. */ + def getDatabaseOption(name: String): Option[CatalogDatabase] + + /** List the names of all the databases that match the specified pattern. */ + def listDatabases(pattern: String): Seq[String] + + /** Returns the specified table, or throws [[NoSuchTableException]]. */ + final def getTable(dbName: String, tableName: String): CatalogTable = { + getTableOption(dbName, tableName).getOrElse(throw new NoSuchTableException(dbName, tableName)) + } + + /** Returns the metadata for the specified table or None if it doesn't exist. */ + def getTableOption(dbName: String, tableName: String): Option[CatalogTable] + + /** Creates a view with the given metadata. */ + def createView(view: CatalogTable): Unit + + /** Updates the given view with new metadata. */ + def alertView(view: CatalogTable): Unit + + /** Creates a table with the given metadata. */ + def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit + + /** Drop the specified table. */ + def dropTable(dbName: String, tableName: String, ignoreIfNotExists: Boolean): Unit + + /** Alter a table whose name matches the one specified in `table`, assuming it exists. */ + final def alterTable(table: CatalogTable): Unit = alterTable(table.identifier.table, table) + + /** Updates the given table with new metadata, optionally renaming the table. */ + def alterTable(tableName: String, table: CatalogTable): Unit + + /** Creates a new database with the given name. */ + def createDatabase(database: CatalogDatabase, ignoreIfExists: Boolean): Unit + + /** + * Drop the specified database, if it exists. + * + * @param name database to drop + * @param ignoreIfNotExists if true, do not throw error if the database does not exist + * @param cascade whether to remove all associated objects such as tables and functions + */ + def dropDatabase(name: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + + /** + * Alter a database whose name matches the one specified in `database`, assuming it exists. + */ + def alterDatabase(database: CatalogDatabase): Unit + + /** + * Create one or many partitions in the given table. + */ + def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit + + /** + * Drop one or many partitions in the given table, assuming they exist. + */ + def dropPartitions( + db: String, + table: String, + specs: Seq[ExternalCatalog.TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit + + /** + * Rename one or many existing table partitions, assuming they exist. + */ + def renamePartitions( + db: String, + table: String, + specs: Seq[ExternalCatalog.TablePartitionSpec], + newSpecs: Seq[ExternalCatalog.TablePartitionSpec]): Unit + + /** + * Alter one or more table partitions whose specs match the ones specified in `newParts`, + * assuming the partitions exist. + */ + def alterPartitions( + db: String, + table: String, + newParts: Seq[CatalogTablePartition]): Unit + + /** Returns the specified partition, or throws [[NoSuchPartitionException]]. */ + final def getPartition( + dbName: String, + tableName: String, + spec: ExternalCatalog.TablePartitionSpec): CatalogTablePartition = { + getPartitionOption(dbName, tableName, spec).getOrElse { + throw new NoSuchPartitionException(dbName, tableName, spec) + } + } + + /** Returns the specified partition or None if it does not exist. */ + final def getPartitionOption( + db: String, + table: String, + spec: ExternalCatalog.TablePartitionSpec): Option[CatalogTablePartition] = { + getPartitionOption(getTable(db, table), spec) + } + + /** Returns the specified partition or None if it does not exist. */ + def getPartitionOption( + table: CatalogTable, + spec: ExternalCatalog.TablePartitionSpec): Option[CatalogTablePartition] + + /** Returns all partitions for the given table. */ + final def getAllPartitions(db: String, table: String): Seq[CatalogTablePartition] = { + getAllPartitions(getTable(db, table)) + } + + /** Returns all partitions for the given table. */ + def getAllPartitions(table: CatalogTable): Seq[CatalogTablePartition] + + /** Returns partitions filtered by predicates for the given table. */ + def getPartitionsByFilter( + table: CatalogTable, + predicates: Seq[Expression]): Seq[CatalogTablePartition] + + /** Loads a static partition into an existing table. */ + def loadPartition( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit + + /** Loads data into an existing table. */ + def loadTable( + loadPath: String, // TODO URI + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit + + /** Loads new dynamic partitions into an existing table. */ + def loadDynamicPartitions( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit + + /** Create a function in an existing database. */ + def createFunction(db: String, func: CatalogFunction): Unit + + /** Drop an existing function an the database. */ + def dropFunction(db: String, name: String): Unit + + /** Rename an existing function in the database. */ + def renameFunction(db: String, oldName: String, newName: String): Unit + + /** Alter a function whose name matches the one specified in `func`, assuming it exists. */ + def alterFunction(db: String, func: CatalogFunction): Unit + + /** Return an existing function in the database, assuming it exists. */ + final def getFunction(db: String, name: String): CatalogFunction = { + getFunctionOption(db, name).getOrElse(throw new NoSuchFunctionException(db, name)) + } + + /** Return an existing function in the database, or None if it doesn't exist. */ + def getFunctionOption(db: String, name: String): Option[CatalogFunction] + + /** Return whether a function exists in the specified database. */ + final def functionExists(db: String, name: String): Boolean = { + getFunctionOption(db, name).isDefined + } + + /** Return the names of all functions that match the given pattern in the database. */ + def listFunctions(db: String, pattern: String): Seq[String] + + /** Add a jar into class loader */ + def addJar(path: String): Unit + + /** Return a [[HiveClient]] as new session, that will share the class loader and Hive client */ + def newSession(): HiveClient + + /** Run a function within Hive state (SessionState, HiveConf, Hive client and class loader) */ + def withHiveState[A](f: => A): A + + /** Used for testing only. Removes all metadata from this instance of Hive. */ + def reset(): Unit + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala new file mode 100644 index 0000000000000..2a1fff92b570a --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -0,0 +1,757 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.io.{File, PrintStream} + +import scala.collection.JavaConverters._ +import scala.language.reflectiveCalls + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.cli.CliSessionState +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.{PartitionDropOptions, TableType => HiveTableType} +import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException} +import org.apache.hadoop.hive.ql.plan.AddPartitionDesc +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.security.UserGroupInformation + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException} +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.util.{CircularBuffer, Utils} + +/** + * A class that wraps the HiveClient and converts its responses to externally visible classes. + * Note that this class is typically loaded with an internal classloader for each instantiation, + * allowing it to interact directly with a specific isolated version of Hive. Loading this class + * with the isolated classloader however will result in it only being visible as a [[HiveClient]], + * not a [[HiveClientImpl]]. + * + * This class needs to interact with multiple versions of Hive, but will always be compiled with + * the 'native', execution version of Hive. Therefore, any places where hive breaks compatibility + * must use reflection after matching on `version`. + * + * @param version the version of hive used when pick function calls that are not compatible. + * @param config a collection of configuration options that will be added to the hive conf before + * opening the hive client. + * @param initClassLoader the classloader used when creating the `state` field of + * this [[HiveClientImpl]]. + */ +private[hive] class HiveClientImpl( + override val version: HiveVersion, + sparkConf: SparkConf, + hadoopConf: Configuration, + config: Map[String, String], + initClassLoader: ClassLoader, + val clientLoader: IsolatedClientLoader) + extends HiveClient + with Logging { + + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. + private val outputBuffer = new CircularBuffer() + + private val shim = version match { + case hive.v12 => new Shim_v0_12() + case hive.v13 => new Shim_v0_13() + case hive.v14 => new Shim_v0_14() + case hive.v1_0 => new Shim_v1_0() + case hive.v1_1 => new Shim_v1_1() + case hive.v1_2 => new Shim_v1_2() + } + + // Create an internal session state for this HiveClientImpl. + val state = { + val original = Thread.currentThread().getContextClassLoader + // Switch to the initClassLoader. + Thread.currentThread().setContextClassLoader(initClassLoader) + + // Set up kerberos credentials for UserGroupInformation.loginUser within + // current class loader + // Instead of using the spark conf of the current spark context, a new + // instance of SparkConf is needed for the original value of spark.yarn.keytab + // and spark.yarn.principal set in SparkSubmit, as yarn.Client resets the + // keytab configuration for the link name in distributed cache + if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { + val principalName = sparkConf.get("spark.yarn.principal") + val keytabFileName = sparkConf.get("spark.yarn.keytab") + if (!new File(keytabFileName).exists()) { + throw new SparkException(s"Keytab file: ${keytabFileName}" + + " specified in spark.yarn.keytab does not exist") + } else { + logInfo("Attempting to login to Kerberos" + + s" using principal: ${principalName} and keytab: ${keytabFileName}") + UserGroupInformation.loginUserFromKeytab(principalName, keytabFileName) + } + } + + val ret = try { + // originState will be created if not exists, will never be null + val originalState = SessionState.get() + if (originalState.isInstanceOf[CliSessionState]) { + // In `SparkSQLCLIDriver`, we have already started a `CliSessionState`, + // which contains information like configurations from command line. Later + // we call `SparkSQLEnv.init()` there, which would run into this part again. + // so we should keep `conf` and reuse the existing instance of `CliSessionState`. + originalState + } else { + val initialConf = new HiveConf(hadoopConf, classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + initialConf.setClassLoader(initClassLoader) + config.foreach { case (k, v) => + if (k.toLowerCase.contains("password")) { + logDebug(s"Hive Config: $k=xxx") + } else { + logDebug(s"Hive Config: $k=$v") + } + initialConf.set(k, v) + } + val state = new SessionState(initialConf) + if (clientLoader.cachedHive != null) { + Hive.set(clientLoader.cachedHive.asInstanceOf[Hive]) + } + SessionState.start(state) + state.out = new PrintStream(outputBuffer, true, "UTF-8") + state.err = new PrintStream(outputBuffer, true, "UTF-8") + state + } + } finally { + Thread.currentThread().setContextClassLoader(original) + } + ret + } + + /** Returns the configuration for the current session. */ + def conf: HiveConf = SessionState.get().getConf + + override def getConf(key: String, defaultValue: String): String = { + conf.get(key, defaultValue) + } + + // We use hive's conf for compatibility. + private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) + private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) + + /** + * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. + */ + private def retryLocked[A](f: => A): A = clientLoader.synchronized { + // Hive sometimes retries internally, so set a deadline to avoid compounding delays. + val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong + var numTries = 0 + var caughtException: Exception = null + do { + numTries += 1 + try { + return f + } catch { + case e: Exception if causedByThrift(e) => + caughtException = e + logWarning( + "HiveClient got thrift exception, destroying client and retrying " + + s"(${retryLimit - numTries} tries remaining)", e) + clientLoader.cachedHive = null + Thread.sleep(retryDelayMillis) + } + } while (numTries <= retryLimit && System.nanoTime < deadline) + if (System.nanoTime > deadline) { + logWarning("Deadline exceeded") + } + throw caughtException + } + + private def causedByThrift(e: Throwable): Boolean = { + var target = e + while (target != null) { + val msg = target.getMessage() + if (msg != null && msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) { + return true + } + target = target.getCause() + } + false + } + + def client: Hive = { + if (clientLoader.cachedHive != null) { + clientLoader.cachedHive.asInstanceOf[Hive] + } else { + val c = Hive.get(conf) + clientLoader.cachedHive = c + c + } + } + + /** + * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. + */ + def withHiveState[A](f: => A): A = retryLocked { + val original = Thread.currentThread().getContextClassLoader + // Set the thread local metastore client to the client associated with this HiveClientImpl. + Hive.set(client) + // The classloader in clientLoader could be changed after addJar, always use the latest + // classloader + state.getConf.setClassLoader(clientLoader.classLoader) + // setCurrentSessionState will use the classLoader associated + // with the HiveConf in `state` to override the context class loader of the current + // thread. + shim.setCurrentSessionState(state) + val ret = try f finally { + Thread.currentThread().setContextClassLoader(original) + } + ret + } + + def setOut(stream: PrintStream): Unit = withHiveState { + state.out = stream + } + + def setInfo(stream: PrintStream): Unit = withHiveState { + state.info = stream + } + + def setError(stream: PrintStream): Unit = withHiveState { + state.err = stream + } + + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + if (getDatabaseOption(databaseName).isDefined) { + state.setCurrentDatabase(databaseName) + } else { + throw new NoSuchDatabaseException(databaseName) + } + } + + override def createDatabase( + database: CatalogDatabase, + ignoreIfExists: Boolean): Unit = withHiveState { + client.createDatabase( + new HiveDatabase( + database.name, + database.description, + database.locationUri, + database.properties.asJava), + ignoreIfExists) + } + + override def dropDatabase( + name: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = withHiveState { + client.dropDatabase(name, true, ignoreIfNotExists, cascade) + } + + override def alterDatabase(database: CatalogDatabase): Unit = withHiveState { + client.alterDatabase( + database.name, + new HiveDatabase( + database.name, + database.description, + database.locationUri, + database.properties.asJava)) + } + + override def getDatabaseOption(name: String): Option[CatalogDatabase] = withHiveState { + Option(client.getDatabase(name)).map { d => + CatalogDatabase( + name = d.getName, + description = d.getDescription, + locationUri = d.getLocationUri, + properties = d.getParameters.asScala.toMap) + } + } + + override def listDatabases(pattern: String): Seq[String] = withHiveState { + client.getDatabasesByPattern(pattern).asScala.toSeq + } + + override def getTableOption( + dbName: String, + tableName: String): Option[CatalogTable] = withHiveState { + logDebug(s"Looking up $dbName.$tableName") + Option(client.getTable(dbName, tableName, false)).map { h => + // Note: Hive separates partition columns and the schema, but for us the + // partition columns are part of the schema + val partCols = h.getPartCols.asScala.map(fromHiveColumn) + val schema = h.getCols.asScala.map(fromHiveColumn) ++ partCols + CatalogTable( + identifier = TableIdentifier(h.getTableName, Option(h.getDbName)), + tableType = h.getTableType match { + case HiveTableType.EXTERNAL_TABLE => CatalogTableType.EXTERNAL_TABLE + case HiveTableType.MANAGED_TABLE => CatalogTableType.MANAGED_TABLE + case HiveTableType.INDEX_TABLE => CatalogTableType.INDEX_TABLE + case HiveTableType.VIRTUAL_VIEW => CatalogTableType.VIRTUAL_VIEW + }, + schema = schema, + partitionColumnNames = partCols.map(_.name), + sortColumnNames = Seq(), // TODO: populate this + bucketColumnNames = h.getBucketCols.asScala, + numBuckets = h.getNumBuckets, + createTime = h.getTTable.getCreateTime.toLong * 1000, + lastAccessTime = h.getLastAccessTime.toLong * 1000, + storage = CatalogStorageFormat( + locationUri = shim.getDataLocation(h), + inputFormat = Option(h.getInputFormatClass).map(_.getName), + outputFormat = Option(h.getOutputFormatClass).map(_.getName), + serde = Option(h.getSerializationLib), + serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap + ), + properties = h.getParameters.asScala.toMap, + viewOriginalText = Option(h.getViewOriginalText), + viewText = Option(h.getViewExpandedText)) + } + } + + override def createView(view: CatalogTable): Unit = withHiveState { + client.createTable(toHiveViewTable(view)) + } + + override def alertView(view: CatalogTable): Unit = withHiveState { + client.alterTable(view.qualifiedName, toHiveViewTable(view)) + } + + override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState { + client.createTable(toHiveTable(table), ignoreIfExists) + } + + override def dropTable( + dbName: String, + tableName: String, + ignoreIfNotExists: Boolean): Unit = withHiveState { + client.dropTable(dbName, tableName, true, ignoreIfNotExists) + } + + override def alterTable(tableName: String, table: CatalogTable): Unit = withHiveState { + val hiveTable = toHiveTable(table) + // Do not use `table.qualifiedName` here because this may be a rename + val qualifiedTableName = s"${table.database}.$tableName" + client.alterTable(qualifiedTableName, hiveTable) + } + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = withHiveState { + val addPartitionDesc = new AddPartitionDesc(db, table, ignoreIfExists) + parts.foreach { s => + addPartitionDesc.addPartition(s.spec.asJava, s.storage.locationUri.orNull) + } + client.createPartitions(addPartitionDesc) + } + + override def dropPartitions( + db: String, + table: String, + specs: Seq[ExternalCatalog.TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = withHiveState { + // TODO: figure out how to drop multiple partitions in one call + val hiveTable = client.getTable(db, table, true /* throw exception */) + specs.foreach { s => + // The provided spec here can be a partial spec, i.e. it will match all partitions + // whose specs are supersets of this partial spec. E.g. If a table has partitions + // (b='1', c='1') and (b='1', c='2'), a partial spec of (b='1') will match both. + val matchingParts = client.getPartitions(hiveTable, s.asJava).asScala + if (matchingParts.isEmpty && !ignoreIfNotExists) { + throw new AnalysisException( + s"partition to drop '$s' does not exist in table '$table' database '$db'") + } + matchingParts.foreach { hivePartition => + val dropOptions = new PartitionDropOptions + dropOptions.ifExists = ignoreIfNotExists + client.dropPartition(db, table, hivePartition.getValues, dropOptions) + } + } + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[ExternalCatalog.TablePartitionSpec], + newSpecs: Seq[ExternalCatalog.TablePartitionSpec]): Unit = withHiveState { + require(specs.size == newSpecs.size, "number of old and new partition specs differ") + val catalogTable = getTable(db, table) + val hiveTable = toHiveTable(catalogTable) + specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => + val hivePart = getPartitionOption(catalogTable, oldSpec) + .map { p => toHivePartition(p.copy(spec = newSpec), hiveTable) } + .getOrElse { throw new NoSuchPartitionException(db, table, oldSpec) } + client.renamePartition(hiveTable, oldSpec.asJava, hivePart) + } + } + + override def alterPartitions( + db: String, + table: String, + newParts: Seq[CatalogTablePartition]): Unit = withHiveState { + val hiveTable = toHiveTable(getTable(db, table)) + client.alterPartitions(table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + } + + override def getPartitionOption( + table: CatalogTable, + spec: ExternalCatalog.TablePartitionSpec): Option[CatalogTablePartition] = withHiveState { + val hiveTable = toHiveTable(table) + val hivePartition = client.getPartition(hiveTable, spec.asJava, false) + Option(hivePartition).map(fromHivePartition) + } + + override def getAllPartitions(table: CatalogTable): Seq[CatalogTablePartition] = withHiveState { + val hiveTable = toHiveTable(table) + shim.getAllPartitions(client, hiveTable).map(fromHivePartition) + } + + override def getPartitionsByFilter( + table: CatalogTable, + predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { + val hiveTable = toHiveTable(table) + shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) + } + + override def listTables(dbName: String): Seq[String] = withHiveState { + client.getAllTables(dbName).asScala + } + + override def listTables(dbName: String, pattern: String): Seq[String] = withHiveState { + client.getTablesByPattern(dbName, pattern).asScala + } + + /** + * Runs the specified SQL query using Hive. + */ + override def runSqlHive(sql: String): Seq[String] = { + val maxResults = 100000 + val results = runHive(sql, maxResults) + // It is very confusing when you only get back some of the results... + if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") + results + } + + /** + * Execute the command using Hive and return the results as a sequence. Each element + * in the sequence is one row. + */ + protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState { + logDebug(s"Running hiveql '$cmd'") + if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") } + try { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + // The remainder of the command. + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + val proc = shim.getCommandProcessor(tokens(0), conf) + proc match { + case driver: Driver => + val response: CommandProcessorResponse = driver.run(cmd) + // Throw an exception if there is an error in query processing. + if (response.getResponseCode != 0) { + driver.close() + throw new QueryExecutionException(response.getErrorMessage) + } + driver.setMaxRows(maxRows) + + val results = shim.getDriverResults(driver) + driver.close() + results + + case _ => + if (state.out != null) { + // scalastyle:off println + state.out.println(tokens(0) + " " + cmd_1) + // scalastyle:on println + } + Seq(proc.run(cmd_1).getResponseCode.toString) + } + } catch { + case e: Exception => + logError( + s""" + |====================== + |HIVE FAILURE OUTPUT + |====================== + |${outputBuffer.toString} + |====================== + |END HIVE FAILURE OUTPUT + |====================== + """.stripMargin) + throw e + } + } + + def loadPartition( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = withHiveState { + shim.loadPartition( + client, + new Path(loadPath), // TODO: Use URI + tableName, + partSpec, + replace, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } + + def loadTable( + loadPath: String, // TODO URI + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = withHiveState { + shim.loadTable( + client, + new Path(loadPath), + tableName, + replace, + holdDDLTime) + } + + def loadDynamicPartitions( + loadPath: String, + tableName: String, + partSpec: java.util.LinkedHashMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = withHiveState { + shim.loadDynamicPartitions( + client, + new Path(loadPath), + tableName, + partSpec, + replace, + numDP, + holdDDLTime, + listBucketingEnabled) + } + + override def createFunction(db: String, func: CatalogFunction): Unit = withHiveState { + client.createFunction(toHiveFunction(func, db)) + } + + override def dropFunction(db: String, name: String): Unit = withHiveState { + client.dropFunction(db, name) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = withHiveState { + val catalogFunc = getFunction(db, oldName) + .copy(identifier = FunctionIdentifier(newName, Some(db))) + val hiveFunc = toHiveFunction(catalogFunc, db) + client.alterFunction(db, oldName, hiveFunc) + } + + override def alterFunction(db: String, func: CatalogFunction): Unit = withHiveState { + client.alterFunction(db, func.identifier.funcName, toHiveFunction(func, db)) + } + + override def getFunctionOption( + db: String, + name: String): Option[CatalogFunction] = withHiveState { + try { + Option(client.getFunction(db, name)).map(fromHiveFunction) + } catch { + case he: HiveException => None + } + } + + override def listFunctions(db: String, pattern: String): Seq[String] = withHiveState { + client.getFunctions(db, pattern).asScala + } + + def addJar(path: String): Unit = { + val uri = new Path(path).toUri + val jarURL = if (uri.getScheme == null) { + // `path` is a local file path without a URL scheme + new File(path).toURI.toURL + } else { + // `path` is a URL with a scheme + uri.toURL + } + clientLoader.addJar(jarURL) + runSqlHive(s"ADD JAR $path") + } + + def newSession(): HiveClientImpl = { + clientLoader.createClient().asInstanceOf[HiveClientImpl] + } + + def reset(): Unit = withHiveState { + client.getAllTables("default").asScala.foreach { t => + logDebug(s"Deleting table $t") + val table = client.getTable("default", t) + client.getIndexes("default", t, 255).asScala.foreach { index => + shim.dropIndex(client, "default", t, index.getIndexName) + } + if (!table.isIndexTable) { + client.dropTable("default", t) + } + } + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => + logDebug(s"Dropping Database: $db") + client.dropDatabase(db, true, false, true) + } + } + + + /* -------------------------------------------------------- * + | Helper methods for converting to and from Hive classes | + * -------------------------------------------------------- */ + + private def toInputFormat(name: String) = + Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] + + private def toOutputFormat(name: String) = + Utils.classForName(name) + .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] + + private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { + val resourceUris = f.resources.map { case (resourceType, resourcePath) => + new ResourceUri(ResourceType.valueOf(resourceType.toUpperCase), resourcePath) + } + new HiveFunction( + f.identifier.funcName, + db, + f.className, + null, + PrincipalType.USER, + (System.currentTimeMillis / 1000).toInt, + FunctionType.JAVA, + resourceUris.asJava) + } + + private def fromHiveFunction(hf: HiveFunction): CatalogFunction = { + val name = FunctionIdentifier(hf.getFunctionName, Option(hf.getDbName)) + val resources = hf.getResourceUris.asScala.map { uri => + val resourceType = uri.getResourceType() match { + case ResourceType.ARCHIVE => "archive" + case ResourceType.FILE => "file" + case ResourceType.JAR => "jar" + case r => throw new AnalysisException(s"Unknown resource type: $r") + } + (resourceType, uri.getUri()) + } + new CatalogFunction(name, hf.getClassName, resources) + } + + private def toHiveColumn(c: CatalogColumn): FieldSchema = { + new FieldSchema(c.name, c.dataType, c.comment.orNull) + } + + private def fromHiveColumn(hc: FieldSchema): CatalogColumn = { + new CatalogColumn( + name = hc.getName, + dataType = hc.getType, + nullable = true, + comment = Option(hc.getComment)) + } + + private def toHiveTable(table: CatalogTable): HiveTable = { + val hiveTable = new HiveTable(table.database, table.identifier.table) + // For EXTERNAL_TABLE, we also need to set EXTERNAL field in the table properties. + // Otherwise, Hive metastore will change the table to a MANAGED_TABLE. + // (metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105) + hiveTable.setTableType(table.tableType match { + case CatalogTableType.EXTERNAL_TABLE => + hiveTable.setProperty("EXTERNAL", "TRUE") + HiveTableType.EXTERNAL_TABLE + case CatalogTableType.MANAGED_TABLE => + HiveTableType.MANAGED_TABLE + case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE + case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW + }) + // Note: In Hive the schema and partition columns must be disjoint sets + val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => + table.partitionColumnNames.contains(c.getName) + } + if (table.schema.isEmpty) { + // This is a hack to preserve existing behavior. Before Spark 2.0, we do not + // set a default serde here (this was done in Hive), and so if the user provides + // an empty schema Hive would automatically populate the schema with a single + // field "col". However, after SPARK-14388, we set the default serde to + // LazySimpleSerde so this implicit behavior no longer happens. Therefore, + // we need to do it in Spark ourselves. + hiveTable.setFields( + Seq(new FieldSchema("col", "array", "from deserializer")).asJava) + } else { + hiveTable.setFields(schema.asJava) + } + hiveTable.setPartCols(partCols.asJava) + // TODO: set sort columns here too + hiveTable.setBucketCols(table.bucketColumnNames.asJava) + hiveTable.setOwner(conf.getUser) + hiveTable.setNumBuckets(table.numBuckets) + hiveTable.setCreateTime((table.createTime / 1000).toInt) + hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) + table.storage.locationUri.foreach { loc => shim.setDataLocation(hiveTable, loc) } + table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) + table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) + hiveTable.setSerializationLib( + table.storage.serde.getOrElse("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + table.storage.serdeProperties.foreach { case (k, v) => hiveTable.setSerdeParam(k, v) } + table.properties.foreach { case (k, v) => hiveTable.setProperty(k, v) } + table.comment.foreach { c => hiveTable.setProperty("comment", c) } + table.viewOriginalText.foreach { t => hiveTable.setViewOriginalText(t) } + table.viewText.foreach { t => hiveTable.setViewExpandedText(t) } + hiveTable + } + + private def toHiveViewTable(view: CatalogTable): HiveTable = { + val tbl = toHiveTable(view) + tbl.setTableType(HiveTableType.VIRTUAL_VIEW) + tbl.setSerializationLib(null) + tbl.clearSerDeInfo() + tbl + } + + private def toHivePartition( + p: CatalogTablePartition, + ht: HiveTable): HivePartition = { + new HivePartition(ht, p.spec.asJava, p.storage.locationUri.map { l => new Path(l) }.orNull) + } + + private def fromHivePartition(hp: HivePartition): CatalogTablePartition = { + val apiPartition = hp.getTPartition + CatalogTablePartition( + spec = Option(hp.getSpec).map(_.asScala.toMap).getOrElse(Map.empty), + storage = CatalogStorageFormat( + locationUri = Option(apiPartition.getSd.getLocation), + inputFormat = Option(apiPartition.getSd.getInputFormat), + outputFormat = Option(apiPartition.getSd.getOutputFormat), + serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib), + serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 48bbb21e6c1de..4ecf866f96395 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -33,13 +33,13 @@ import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorF import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, IntegralType} +import org.apache.spark.sql.types.{IntegralType, StringType} /** - * A shim that defines the interface between ClientWrapper and the underlying Hive library used to - * talk to the metastore. Each Hive version has its own implementation of this class, defining + * A shim that defines the interface between [[HiveClientImpl]] and the underlying Hive library used + * to talk to the metastore. Each Hive version has its own implementation of this class, defining * version-specific version of needed functions. * * The guideline for writing shims is: @@ -52,7 +52,6 @@ private[client] sealed abstract class Shim { /** * Set the current SessionState to the given SessionState. Also, set the context classloader of * the current thread to the one set in the HiveConf of this given `state`. - * @param state */ def setCurrentSessionState(state: SessionState): Unit @@ -321,7 +320,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet filters.collect { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index f99c3ed2ae987..f45264af34d93 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -22,37 +22,69 @@ import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util -import scala.collection.mutable import scala.language.reflectiveCalls import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.hadoop.conf.Configuration -import org.apache.spark.Logging +import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkSubmitUtils -import org.apache.spark.util.{MutableURLClassLoader, Utils} - +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** Factory for `IsolatedClientLoader` with specific versions of hive. */ -private[hive] object IsolatedClientLoader { +private[hive] object IsolatedClientLoader extends Logging { /** * Creates isolated Hive client loaders by downloading the requested version from maven. */ def forVersion( - version: String, + hiveMetastoreVersion: String, + hadoopVersion: String, + sparkConf: SparkConf, + hadoopConf: Configuration, config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { - val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, - downloadVersion(resolvedVersion, ivyPath)) + val resolvedVersion = hiveVersion(hiveMetastoreVersion) + // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact + // with the given version, we will use Hadoop 2.4.0 and then will not share Hadoop classes. + var sharesHadoopClasses = true + val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { + resolvedVersions((resolvedVersion, hadoopVersion)) + } else { + val (downloadedFiles, actualHadoopVersion) = + try { + (downloadVersion(resolvedVersion, hadoopVersion, ivyPath), hadoopVersion) + } catch { + case e: RuntimeException if e.getMessage.contains("hadoop") => + // If the error message contains hadoop, it is probably because the hadoop + // version cannot be resolved (e.g. it is a vendor specific version like + // 2.0.0-cdh4.1.1). If it is the case, we will try just + // "org.apache.hadoop:hadoop-client:2.4.0". "org.apache.hadoop:hadoop-client:2.4.0" + // is used just because we used to hard code it as the hadoop artifact to download. + logWarning(s"Failed to resolve Hadoop artifacts for the version ${hadoopVersion}. " + + s"We will change the hadoop version from ${hadoopVersion} to 2.4.0 and try again. " + + "Hadoop classes will not be shared between Spark and Hive metastore client. " + + "It is recommended to set jars used by Hive metastore client through " + + "spark.sql.hive.metastore.jars in the production environment.") + sharesHadoopClasses = false + (downloadVersion(resolvedVersion, "2.4.0", ivyPath), "2.4.0") + } + resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) + resolvedVersions((resolvedVersion, actualHadoopVersion)) + } + new IsolatedClientLoader( - version = hiveVersion(version), + hiveVersion(hiveMetastoreVersion), + sparkConf, execJars = files, + hadoopConf = hadoopConf, config = config, + sharesHadoopClasses = sharesHadoopClasses, sharedPrefixes = sharedPrefixes, barrierPrefixes = barrierPrefixes) } @@ -66,12 +98,15 @@ private[hive] object IsolatedClientLoader { case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 } - private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { + private def downloadVersion( + version: HiveVersion, + hadoopVersion: String, + ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ - Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") + Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde", "hive-cli") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ Seq("com.google.guava:guava:14.0.1", - "org.apache.hadoop:hadoop-client:2.4.0") + s"org.apache.hadoop:hadoop-client:$hadoopVersion") val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( @@ -88,19 +123,22 @@ private[hive] object IsolatedClientLoader { tempDir.listFiles().map(_.toURI.toURL) } - private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[URL]] + // A map from a given pair of HiveVersion and Hadoop version to jar files. + // It is only used by forVersion. + private val resolvedVersions = + new scala.collection.mutable.HashMap[(HiveVersion, String), Seq[URL]] } /** - * Creates a Hive `ClientInterface` using a classloader that works according to the following rules: + * Creates a [[HiveClient]] using a classloader that works according to the following rules: * - Shared classes: Java, Scala, logging, and Spark classes are delegated to `baseClassLoader` - * allowing the results of calls to the `ClientInterface` to be visible externally. + * allowing the results of calls to the [[HiveClient]] to be visible externally. * - Hive classes: new instances are loaded from `execJars`. These classes are not * accessible externally due to their custom loading. - * - ClientWrapper: a new copy is created for each instance of `IsolatedClassLoader`. + * - [[HiveClientImpl]]: a new copy is created for each instance of `IsolatedClassLoader`. * This new instance is able to see a specific version of hive without using reflection. Since * this is a unique instance, it is not visible externally other than as a generic - * `ClientInterface`, unless `isolationOn` is set to `false`. + * [[HiveClient]], unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures * that are not compatible across versions. @@ -108,14 +146,18 @@ private[hive] object IsolatedClientLoader { * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. + * @param sharesHadoopClasses When true, we will share Hadoop classes between Spark and * @param rootClassLoader The system root classloader. Must not know about Hive classes. * @param baseClassLoader The spark classloader that is used to load shared classes. */ private[hive] class IsolatedClientLoader( val version: HiveVersion, + val sparkConf: SparkConf, + val hadoopConf: Configuration, val execJars: Seq[URL] = Seq.empty, val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, + val sharesHadoopClasses: Boolean = true, val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, val sharedPrefixes: Seq[String] = Seq.empty, @@ -128,20 +170,24 @@ private[hive] class IsolatedClientLoader( /** All jars used by the hive specific classloader. */ protected def allJars = execJars.toArray - protected def isSharedClass(name: String): Boolean = + protected def isSharedClass(name: String): Boolean = { + val isHadoopClass = + name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") + name.contains("slf4j") || name.contains("log4j") || name.startsWith("org.apache.spark.") || - (name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.")) || + (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || name.startsWith("java.net") || sharedPrefixes.exists(name.startsWith) + } /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = - name.startsWith(classOf[ClientWrapper].getName) || + name.startsWith(classOf[HiveClientImpl].getName) || name.startsWith(classOf[Shim].getName) || barrierPrefixes.exists(name.startsWith) @@ -190,15 +236,14 @@ private[hive] class IsolatedClientLoader( new NonClosableMutableURLClassLoader(isolatedClassLoader) } - private[hive] def addJar(path: String): Unit = synchronized { - val jarURL = new java.io.File(path).toURI.toURL - classLoader.addURL(jarURL) + private[hive] def addJar(path: URL): Unit = synchronized { + classLoader.addURL(path) } /** The isolated client interface to Hive. */ - private[hive] def createClient(): ClientInterface = { + private[hive] def createClient(): HiveClient = { if (!isolationOn) { - return new ClientWrapper(version, config, baseClassLoader, this) + return new HiveClientImpl(version, sparkConf, hadoopConf, config, baseClassLoader, this) } // Pre-reflective instantiation setup. logDebug("Initializing the logger to avoid disaster...") @@ -207,10 +252,10 @@ private[hive] class IsolatedClientLoader( try { classLoader - .loadClass(classOf[ClientWrapper].getName) + .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head - .newInstance(version, config, classLoader, this) - .asInstanceOf[ClientInterface] + .newInstance(version, sparkConf, hadoopConf, config, classLoader, this) + .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => if (e.getCause().isInstanceOf[NoClassDefFoundError]) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index e72a60b42e653..29f7dc2997d26 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} /** * Create table and insert the query result into it. @@ -33,12 +33,12 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext} */ private[hive] case class CreateTableAsSelect( - tableDesc: HiveTable, + tableDesc: CatalogTable, query: LogicalPlan, allowExisting: Boolean) extends RunnableCommand { - val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database)) + private val tableIdentifier = tableDesc.identifier override def children: Seq[LogicalPlan] = Seq(query) @@ -51,35 +51,35 @@ case class CreateTableAsSelect( import org.apache.hadoop.mapred.TextInputFormat val withFormat = - tableDesc.copy( + tableDesc.withNewStorage( inputFormat = - tableDesc.inputFormat.orElse(Some(classOf[TextInputFormat].getName)), + tableDesc.storage.inputFormat.orElse(Some(classOf[TextInputFormat].getName)), outputFormat = - tableDesc.outputFormat + tableDesc.storage.outputFormat .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)), - serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName()))) + serde = tableDesc.storage.serde.orElse(Some(classOf[LazySimpleSerDe].getName))) val withSchema = if (withFormat.schema.isEmpty) { // Hive doesn't support specifying the column list for target table in CTAS // However we don't think SparkSQL should follow that. - tableDesc.copy(schema = - query.output.map(c => - HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null))) + tableDesc.copy(schema = query.output.map { c => + CatalogColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType)) + }) } else { withFormat } - hiveContext.catalog.client.createTable(withSchema) + hiveContext.sessionState.catalog.createTable(withSchema, ignoreIfExists = false) // Get the Metastore Relation - hiveContext.catalog.lookupRelation(tableIdentifier, None) match { + hiveContext.sessionState.catalog.lookupRelation(tableIdentifier) match { case r: MetastoreRelation => r } } // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. - if (hiveContext.catalog.tableExists(tableIdentifier)) { + if (hiveContext.sessionState.catalog.tableExists(tableIdentifier)) { if (allowExisting) { // table already exists, will do nothing, to keep consistent with Hive } else { @@ -93,6 +93,8 @@ case class CreateTableAsSelect( } override def argString: String = { - s"[Database:${tableDesc.database}}, TableName: ${tableDesc.name}, InsertIntoHiveTable]" + s"[Database:${tableDesc.database}}, " + + s"TableName: ${tableDesc.identifier.table}, " + + s"InsertIntoHiveTable]" } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 2c81115ee4fed..33cd8b44805b8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext} +import scala.util.control.NonFatal + import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.hive.{ HiveContext, HiveMetastoreTypes, SQLBuilder} /** * Create Hive view on non-hive-compatible tables by specifying schema ourselves instead of @@ -31,68 +34,104 @@ import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} // TODO: Note that this class can NOT canonicalize the view SQL string entirely, which is different // from Hive and may not work for some cases like create view on self join. private[hive] case class CreateViewAsSelect( - tableDesc: HiveTable, - childSchema: Seq[Attribute], + tableDesc: CatalogTable, + child: LogicalPlan, allowExisting: Boolean, orReplace: Boolean) extends RunnableCommand { + private val childSchema = child.output + assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length) assert(tableDesc.viewText.isDefined) - val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database)) + private val tableIdentifier = tableDesc.identifier override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableIdentifier)) { - if (allowExisting) { - // view already exists, will do nothing, to keep consistent with Hive - } else if (orReplace) { - hiveContext.catalog.client.alertView(prepareTable()) - } else { + hiveContext.sessionState.catalog.tableExists(tableIdentifier) match { + case true if allowExisting => + // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view + // already exists. + + case true if orReplace => + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` + hiveContext.metadataHive.alertView(prepareTable(sqlContext)) + + case true => + // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already + // exists. throw new AnalysisException(s"View $tableIdentifier already exists. " + "If you want to update the view definition, please use ALTER VIEW AS or " + "CREATE OR REPLACE VIEW AS") - } - } else { - hiveContext.catalog.client.createView(prepareTable()) + + case false => + hiveContext.metadataHive.createView(prepareTable(sqlContext)) } Seq.empty[Row] } - private def prepareTable(): HiveTable = { - // setup column types according to the schema of child. - val schema = if (tableDesc.schema == Nil) { - childSchema.map { attr => - HiveColumn(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), null) + private def prepareTable(sqlContext: SQLContext): CatalogTable = { + val expandedText = if (sqlContext.conf.canonicalView) { + try rebuildViewQueryString(sqlContext) catch { + case NonFatal(e) => wrapViewTextWithSelect } } else { - childSchema.zip(tableDesc.schema).map { case (attr, col) => - HiveColumn(col.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), col.comment) + wrapViewTextWithSelect + } + + val viewSchema = { + if (tableDesc.schema.isEmpty) { + childSchema.map { a => + CatalogColumn(a.name, HiveMetastoreTypes.toMetastoreType(a.dataType)) + } + } else { + childSchema.zip(tableDesc.schema).map { case (a, col) => + CatalogColumn( + col.name, + HiveMetastoreTypes.toMetastoreType(a.dataType), + nullable = true, + col.comment) + } } } - val columnNames = childSchema.map(f => verbose(f.name)) + tableDesc.copy(schema = viewSchema, viewText = Some(expandedText)) + } + private def wrapViewTextWithSelect: String = { // When user specified column names for view, we should create a project to do the renaming. // When no column name specified, we still need to create a project to declare the columns // we need, to make us more robust to top level `*`s. - val projectList = if (tableDesc.schema == Nil) { - columnNames.mkString(", ") - } else { - columnNames.zip(tableDesc.schema.map(f => verbose(f.name))).map { - case (name, alias) => s"$name AS $alias" - }.mkString(", ") + val viewOutput = { + val columnNames = childSchema.map(f => quote(f.name)) + if (tableDesc.schema.isEmpty) { + columnNames.mkString(", ") + } else { + columnNames.zip(tableDesc.schema.map(f => quote(f.name))).map { + case (name, alias) => s"$name AS $alias" + }.mkString(", ") + } } - val viewName = verbose(tableDesc.name) - - val expandedText = s"SELECT $projectList FROM (${tableDesc.viewText.get}) $viewName" + val viewText = tableDesc.viewText.get + val viewName = quote(tableDesc.identifier.table) + s"SELECT $viewOutput FROM ($viewText) $viewName" + } - tableDesc.copy(schema = schema, viewText = Some(expandedText)) + private def rebuildViewQueryString(sqlContext: SQLContext): String = { + val logicalPlan = if (tableDesc.schema.isEmpty) { + child + } else { + val projectList = childSchema.zip(tableDesc.schema).map { + case (attr, col) => Alias(attr, col.name)() + } + sqlContext.executePlan(Project(projectList, child)).analyzed + } + new SQLBuilder(logicalPlan, sqlContext).toSQL } // escape backtick with double-backtick in column name and wrap it with backtick. - private def verbose(name: String) = s"`${name.replaceAll("`", "``")}`" + private def quote(name: String) = s"`${name.replaceAll("`", "``")}`" } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 441b6b6033e1f..8481324086c34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -21,43 +21,56 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.command.{DescribeCommand, RunnableCommand} import org.apache.spark.sql.hive.MetastoreRelation -import org.apache.spark.sql.{Row, SQLContext} /** * Implementation for "describe [extended] table". */ private[hive] case class DescribeHiveTableCommand( - table: MetastoreRelation, + tableId: TableIdentifier, override val output: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - // Trying to mimic the format of Hive's output. But not exactly the same. - var results: Seq[(String, String, String)] = Nil - - val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala - val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala - results ++= columns.map(field => (field.getName, field.getType, field.getComment)) - if (partitionColumns.nonEmpty) { - val partColumnInfo = - partitionColumns.map(field => (field.getName, field.getType, field.getComment)) - results ++= - partColumnInfo ++ - Seq(("# Partition Information", "", "")) ++ - Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ - partColumnInfo - } + // There are two modes here: + // For metastore tables, create an output similar to Hive's. + // For other tables, delegate to DescribeCommand. - if (isExtended) { - results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) - } + // In the future, we will consolidate the two and simply report what the catalog reports. + sqlContext.sessionState.catalog.lookupRelation(tableId) match { + case table: MetastoreRelation => + // Trying to mimic the format of Hive's output. But not exactly the same. + var results: Seq[(String, String, String)] = Nil + + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala + results ++= columns.map(field => (field.getName, field.getType, field.getComment)) + if (partitionColumns.nonEmpty) { + val partColumnInfo = + partitionColumns.map(field => (field.getName, field.getType, field.getComment)) + results ++= + partColumnInfo ++ + Seq(("# Partition Information", "", "")) ++ + Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ + partColumnInfo + } + + if (isExtended) { + results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) + } + + results.map { case (name, dataType, comment) => + Row(name, dataType, comment) + } - results.map { case (name, dataType, comment) => - Row(name, dataType, comment) + case o: LogicalPlan => + DescribeCommand(tableId, output, isExtended).run(sqlContext) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 41b645b2c9c93..9bb971992d0d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.StringType -import org.apache.spark.sql.{Row, SQLContext} private[hive] case class HiveNativeCommand(sql: String) extends RunnableCommand { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala new file mode 100644 index 0000000000000..a97b65e27bc59 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -0,0 +1,503 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConverters._ + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.parse.EximUtil +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe + +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkSqlAstBuilder +import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} +import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView, HiveSerDe} +import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveMetastoreTypes, HiveSerDe} +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper + +/** + * Concrete parser for HiveQl statements. + */ +object HiveSqlParser extends AbstractSqlParser { + val astBuilder = new HiveSqlAstBuilder + + override protected def nativeCommand(sqlText: String): LogicalPlan = { + HiveNativeCommand(sqlText) + } +} + +/** + * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. + */ +class HiveSqlAstBuilder extends SparkSqlAstBuilder { + import ParserUtils._ + + /** + * Get the current Hive Configuration. + */ + private[this] def hiveConf: HiveConf = { + var ss = SessionState.get() + // SessionState is lazy initialization, it can be null here + if (ss == null) { + val original = Thread.currentThread().getContextClassLoader + val conf = new HiveConf(classOf[SessionState]) + conf.setClassLoader(original) + ss = new SessionState(conf) + SessionState.start(ss) + } + ss.getConf + } + + /** + * Pass a command to Hive using a [[HiveNativeCommand]]. + */ + override def visitExecuteNativeCommand( + ctx: ExecuteNativeCommandContext): LogicalPlan = withOrigin(ctx) { + HiveNativeCommand(command(ctx)) + } + + /** + * Fail an unsupported Hive native command. + */ + override def visitFailNativeCommand( + ctx: FailNativeCommandContext): LogicalPlan = withOrigin(ctx) { + val keywords = if (ctx.kws != null) { + Seq(ctx.kws.kw1, ctx.kws.kw2, ctx.kws.kw3).filter(_ != null).map(_.getText).mkString(" ") + } else { + // SET ROLE is the exception to the rule, because we handle this before other SET commands. + "SET ROLE" + } + throw new ParseException(s"Unsupported operation: $keywords", ctx) + } + + /** + * Create an [[AddJar]] or [[AddFile]] command depending on the requested resource. + */ + override def visitAddResource(ctx: AddResourceContext): LogicalPlan = withOrigin(ctx) { + ctx.identifier.getText.toLowerCase match { + case "file" => AddFile(remainder(ctx.identifier).trim) + case "jar" => AddJar(remainder(ctx.identifier).trim) + case other => throw new ParseException(s"Unsupported resource type '$other'.", ctx) + } + } + + /** + * Create an [[AnalyzeTable]] command. This currently only implements the NOSCAN option (other + * options are passed on to Hive) e.g.: + * {{{ + * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; + * }}} + */ + override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { + if (ctx.partitionSpec == null && + ctx.identifier != null && + ctx.identifier.getText.toLowerCase == "noscan") { + AnalyzeTable(visitTableIdentifier(ctx.tableIdentifier).toString) + } else { + HiveNativeCommand(command(ctx)) + } + } + + /** + * Create a [[CatalogStorageFormat]] for creating tables. + */ + override def visitCreateFileFormat( + ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + visitTableFileFormat(c) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + visitGenericFileFormat(c) + case (null, storageHandler) => + throw new ParseException("Operation not allowed: ... STORED BY storage_handler ...", ctx) + case _ => + throw new ParseException("expected either STORED AS or STORED BY, not both", ctx) + } + } + + /** + * Create a table, returning either a [[CreateTable]] or a [[CreateTableAsSelect]]. + * + * This is not used to create datasource tables, which is handled through + * "CREATE TABLE ... USING ...". + * + * Note: several features are currently not supported - temporary tables, bucketing, + * skewed columns and storage handlers (STORED BY). + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1 data_type [COMMENT col_comment], ...)] + * [COMMENT table_comment] + * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] + * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] + * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) [STORED AS DIRECTORIES]] + * [ROW FORMAT row_format] + * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] + * [AS select_statement]; + * }}} + */ + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + // TODO: implement temporary tables + if (temp) { + throw new ParseException( + "CREATE TEMPORARY TABLE is not supported yet. " + + "Please use registerTempTable as an alternative.", ctx) + } + if (ctx.skewSpec != null) { + throw new ParseException("Operation not allowed: CREATE TABLE ... SKEWED BY ...", ctx) + } + if (ctx.bucketSpec != null) { + throw new ParseException("Operation not allowed: CREATE TABLE ... CLUSTERED BY ...", ctx) + } + val tableType = if (external) { + CatalogTableType.EXTERNAL_TABLE + } else { + CatalogTableType.MANAGED_TABLE + } + val comment = Option(ctx.STRING).map(string) + val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns) + val cols = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns) + val properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val selectQuery = Option(ctx.query).map(plan) + + // Note: Hive requires partition columns to be distinct from the schema, so we need + // to include the partition columns here explicitly + val schema = cols ++ partitionCols + + // Storage format + val defaultStorage: CatalogStorageFormat = { + val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) + val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf) + CatalogStorageFormat( + locationUri = None, + inputFormat = defaultHiveSerde.flatMap(_.inputFormat) + .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), + outputFormat = defaultHiveSerde.flatMap(_.outputFormat) + .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + // Note: Keep this unspecified because we use the presence of the serde to decide + // whether to convert a table created by CTAS to a datasource table. + serde = None, + serdeProperties = Map()) + } + val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + .getOrElse(EmptyStorageFormat) + val rowStorage = Option(ctx.rowFormat).map(visitRowFormat).getOrElse(EmptyStorageFormat) + val location = Option(ctx.locationSpec).map(visitLocationSpec) + val storage = CatalogStorageFormat( + locationUri = location, + inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), + outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), + serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), + serdeProperties = rowStorage.serdeProperties ++ fileStorage.serdeProperties) + + // TODO support the sql text - have a proper location for this! + val tableDesc = CatalogTable( + identifier = name, + tableType = tableType, + storage = storage, + schema = schema, + partitionColumnNames = partitionCols.map(_.name), + properties = properties, + comment = comment) + + selectQuery match { + case Some(q) => CTAS(tableDesc, q, ifNotExists) + case None => CreateTable(tableDesc, ifNotExists) + } + } + + /** + * Create a [[CreateTableLike]] command. + */ + override def visitCreateTableLike(ctx: CreateTableLikeContext): LogicalPlan = withOrigin(ctx) { + val targetTable = visitTableIdentifier(ctx.target) + val sourceTable = visitTableIdentifier(ctx.source) + CreateTableLike(targetTable, sourceTable, ctx.EXISTS != null) + } + + /** + * Create or replace a view. This creates a [[CreateViewAsSelect]] command. + * + * For example: + * {{{ + * CREATE VIEW [IF NOT EXISTS] [db_name.]view_name + * [(column_name [COMMENT column_comment], ...) ] + * [COMMENT view_comment] + * [TBLPROPERTIES (property_name = property_value, ...)] + * AS SELECT ...; + * }}} + */ + override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) { + if (ctx.identifierList != null) { + throw new ParseException(s"Operation not allowed: partitioned views", ctx) + } else { + val identifiers = Option(ctx.identifierCommentList).toSeq.flatMap(_.identifierComment.asScala) + val schema = identifiers.map { ic => + CatalogColumn(ic.identifier.getText, null, nullable = true, Option(ic.STRING).map(string)) + } + createView( + ctx, + ctx.tableIdentifier, + comment = Option(ctx.STRING).map(string), + schema, + ctx.query, + Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), + ctx.EXISTS != null, + ctx.REPLACE != null + ) + } + } + + /** + * Alter the query of a view. This creates a [[CreateViewAsSelect]] command. + */ + override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + createView( + ctx, + ctx.tableIdentifier, + comment = None, + Seq.empty, + ctx.query, + Map.empty, + allowExist = false, + replace = true) + } + + /** + * Create a [[CreateViewAsSelect]] command. + */ + private def createView( + ctx: ParserRuleContext, + name: TableIdentifierContext, + comment: Option[String], + schema: Seq[CatalogColumn], + query: QueryContext, + properties: Map[String, String], + allowExist: Boolean, + replace: Boolean): LogicalPlan = { + val sql = Option(source(query)) + val tableDesc = CatalogTable( + identifier = visitTableIdentifier(name), + tableType = CatalogTableType.VIRTUAL_VIEW, + schema = schema, + storage = EmptyStorageFormat, + properties = properties, + viewOriginalText = sql, + viewText = sql, + comment = comment) + CreateView(tableDesc, plan(query), allowExist, replace, command(ctx)) + } + + /** + * Create a [[HiveScriptIOSchema]]. + */ + override protected def withScriptIOSchema( + ctx: QuerySpecificationContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): HiveScriptIOSchema = { + if (recordWriter != null || recordReader != null) { + throw new ParseException( + "Unsupported operation: Used defined record reader/writer classes.", ctx) + } + + // Decode and input/output format. + type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + def format(fmt: RowFormatContext, confVar: ConfVars): Format = fmt match { + case c: RowFormatDelimitedContext => + // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).map(t => key -> t.getText).toSeq + } + val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs) + + (entries, None, Seq.empty, None) + + case c: RowFormatSerdeContext => + // Use a serde format. + val CatalogStorageFormat(None, None, None, Some(name), props) = visitRowFormatSerde(c) + + // SPARK-10310: Special cases LazySimpleSerDe + val recordHandler = if (name == classOf[LazySimpleSerDe].getCanonicalName) { + Option(hiveConf.getVar(confVar)) + } else { + None + } + (Seq.empty, Option(name), props.toSeq, recordHandler) + + case null => + // Use default (serde) format. + val name = hiveConf.getVar(ConfVars.HIVESCRIPTSERDE) + val props = Seq(serdeConstants.FIELD_DELIM -> "\t") + val recordHandler = Option(hiveConf.getVar(confVar)) + (Nil, Option(name), props, recordHandler) + } + + val (inFormat, inSerdeClass, inSerdeProps, reader) = + format(inRowFormat, ConfVars.HIVESCRIPTRECORDREADER) + + val (outFormat, outSerdeClass, outSerdeProps, writer) = + format(inRowFormat, ConfVars.HIVESCRIPTRECORDWRITER) + + HiveScriptIOSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) + } + + /** + * Create location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = { + EximUtil.relativeToAbsolutePath(hiveConf, super.visitLocationSpec(ctx)) + } + + /** Empty storage format for default values and copies. */ + private val EmptyStorageFormat = CatalogStorageFormat(None, None, None, None, Map.empty) + + /** + * Create a [[CatalogStorageFormat]]. + */ + override def visitTableFileFormat( + ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + EmptyStorageFormat.copy( + inputFormat = Option(string(ctx.inFmt)), + outputFormat = Option(string(ctx.outFmt)), + serde = Option(ctx.serdeCls).map(string) + ) + } + + /** + * Resolve a [[HiveSerDe]] based on the name given and return it as a [[CatalogStorageFormat]]. + */ + override def visitGenericFileFormat( + ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + val source = ctx.identifier.getText + HiveSerDe.sourceToSerDe(source, hiveConf) match { + case Some(s) => + EmptyStorageFormat.copy( + inputFormat = s.inputFormat, + outputFormat = s.outputFormat, + serde = s.serde) + case None => + throw new ParseException(s"Unrecognized file format in STORED AS clause: $source", ctx) + } + } + + /** + * Create a [[RowFormat]] used for creating tables. + * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} + */ + private def visitRowFormat(ctx: RowFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } + } + + /** + * Create SERDE row format name and properties pair. + */ + override def visitRowFormatSerde( + ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) { + import ctx._ + EmptyStorageFormat.copy( + serde = Option(string(name)), + serdeProperties = Option(tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) + } + + /** + * Create a delimited row format properties object. + */ + override def visitRowFormatDelimited( + ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) { + // Collect the entries if any. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).toSeq.map(x => key -> string(x)) + } + // TODO we need proper support for the NULL format. + val entries = entry(serdeConstants.FIELD_DELIM, ctx.fieldsTerminatedBy) ++ + entry(serdeConstants.SERIALIZATION_FORMAT, ctx.fieldsTerminatedBy) ++ + entry(serdeConstants.ESCAPE_CHAR, ctx.escapedBy) ++ + entry(serdeConstants.COLLECTION_DELIM, ctx.collectionItemsTerminatedBy) ++ + entry(serdeConstants.MAPKEY_DELIM, ctx.keysTerminatedBy) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + assert( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + serdeConstants.LINE_DELIM -> value + } + EmptyStorageFormat.copy(serdeProperties = entries.toMap) + } + + /** + * Create a sequence of [[CatalogColumn]]s from a column list + */ + private def visitCatalogColumns(ctx: ColTypeListContext): Seq[CatalogColumn] = withOrigin(ctx) { + ctx.colType.asScala.map { col => + CatalogColumn( + col.identifier.getText.toLowerCase, + // Note: for types like "STRUCT" we can't + // just convert the whole type string to lower case, otherwise the struct field names + // will no longer be case sensitive. Instead, we rely on our parser to get the proper + // case before passing it to Hive. + CatalystSqlParser.parseDataType(col.dataType.getText).simpleString, + nullable = true, + Option(col.STRING).map(string)) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 806d2b9b0b7d4..235b80b7c697c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -30,8 +30,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.{BooleanType, DataType} +import org.apache.spark.util.Utils /** * The Hive table scan operator. Column and partition pruning are both handled. @@ -51,6 +53,12 @@ case class HiveTableScan( require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + override def producedAttributes: AttributeSet = outputSet ++ + AttributeSet(partitionPruningPred.flatMap(_.references)) + // Retrieve the original attributes based on expression ID so that capitalization matches. val attributes = requestedAttributes.map(relation.attributeMap) @@ -129,11 +137,27 @@ case class HiveTableScan( } } - protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + protected override def doExecute(): RDD[InternalRow] = { + // Using dummyCallSite, as getCallSite can turn out to be expensive with + // with multiple partitions. + val rdd = if (!relation.hiveQlTable.isPartitioned) { + Utils.withDummyCallSite(sqlContext.sparkContext) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } + } else { + Utils.withDummyCallSite(sqlContext.sparkContext) { + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + } + } + val numOutputRows = longMetric("numOutputRows") + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map { r => + numOutputRows += 1 + proj(r) + } + } } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f936cf565b2bc..430fa4616fc2b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -23,22 +23,16 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} -import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} -import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.hive._ -import org.apache.spark.sql.types.DataType -import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.SparkException import org.apache.spark.util.SerializableJobConf private[hive] @@ -47,18 +41,11 @@ case class InsertIntoHiveTable( partition: Map[String, Option[String]], child: SparkPlan, overwrite: Boolean, - ifNotExists: Boolean) extends UnaryNode with HiveInspectors { + ifNotExists: Boolean) extends UnaryNode { @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] - @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val catalog = sc.catalog - - private def newSerializer(tableDesc: TableDesc): Serializer = { - val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] - serializer.initialize(null, tableDesc.getProperties) - serializer - } + @transient private lazy val client = sc.metadataHive def output: Seq[Attribute] = Seq.empty @@ -79,42 +66,10 @@ case class InsertIntoHiveTable( conf.value, SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - writerContainer.driverSideSetup() - sc.sparkContext.runJob(rdd, writeToFile _) + sc.sparkContext.runJob(rdd, writerContainer.writeToFile _) writerContainer.commitJob() - // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] - - val fieldOIs = standardOI.getAllStructFieldRefs.asScala - .map(_.getFieldObjectInspector).toArray - val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray - val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)} - val outputData = new Array[Any](fieldOIs.length) - - writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) - - iterator.foreach { row => - var i = 0 - while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) - i += 1 - } - - writerContainer - .getLocalFileWriter(row, table.schema) - .write(serializer.serialize(outputData, standardOI)) - } - - writerContainer.close() - } } /** @@ -193,11 +148,21 @@ case class InsertIntoHiveTable( val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) - new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) + new SparkHiveDynamicPartitionWriterContainer( + jobConf, + fileSinkConf, + dynamicPartColNames, + child.output, + table) } else { - new SparkHiveWriterContainer(jobConf, fileSinkConf) + new SparkHiveWriterContainer( + jobConf, + fileSinkConf, + child.output, + table) } + @transient val outputClass = writerContainer.newSerializer(table.tableDesc).getSerializedClass saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) val outputPath = FileOutputFormat.getOutputPath(jobConf) @@ -212,7 +177,7 @@ case class InsertIntoHiveTable( // loadPartition call orders directories created on the iteration order of the this map val orderedPartitionSpec = new util.LinkedHashMap[String, String]() table.hiveQlTable.getPartCols.asScala.foreach { entry => - orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) + orderedPartitionSpec.put(entry.getName, partitionSpec.getOrElse(entry.getName, "")) } // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -221,8 +186,8 @@ case class InsertIntoHiveTable( // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false if (numDynamicPartitions > 0) { - catalog.synchronized { - catalog.client.loadDynamicPartitions( + client.synchronized { + client.loadDynamicPartitions( outputPath.toString, qualifiedTableName, orderedPartitionSpec, @@ -237,12 +202,12 @@ case class InsertIntoHiveTable( // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries // scalastyle:on val oldPart = - catalog.client.getPartitionOption( - catalog.client.getTable(table.databaseName, table.tableName), - partitionSpec.asJava) + client.getPartitionOption( + client.getTable(table.databaseName, table.tableName), + partitionSpec) if (oldPart.isEmpty || !ifNotExists) { - catalog.client.loadPartition( + client.loadPartition( outputPath.toString, qualifiedTableName, orderedPartitionSpec, @@ -253,7 +218,7 @@ case class InsertIntoHiveTable( } } } else { - catalog.client.loadTable( + client.loadTable( outputPath.toString, // TODO: URI qualifiedTableName, overwrite, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index b30117f0de997..3566526561b2f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io._ +import java.nio.charset.StandardCharsets import java.util.Properties import javax.annotation.Nullable @@ -31,16 +32,17 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types.DataType import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} -import org.apache.spark.{Logging, TaskContext} /** * Transforms the input by forking and running the specified script. @@ -58,7 +60,9 @@ case class ScriptTransformation( ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { - override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil + + override def producedAttributes: AttributeSet = outputSet -- inputSet private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) @@ -111,7 +115,7 @@ case class ScriptTransformation( ioschema.initOutputSerDe(output).getOrElse((null, null)) } - val reader = new BufferedReader(new InputStreamReader(inputStream)) + val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) @@ -211,7 +215,8 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => if (iter.hasNext) { - processIterator(iter) + val proj = UnsafeProjection.create(schema) + processIterator(iter).map(proj) } else { // If the input iterator has no rows then do not launch the external script. Iterator.empty @@ -268,7 +273,7 @@ private class ScriptTransformationWriterThread( sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) sb.toString() } - outputStream.write(data.getBytes("utf-8")) + outputStream.write(data.getBytes(StandardCharsets.UTF_8)) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) @@ -396,4 +401,52 @@ case class HiveScriptIOSchema ( instance } } + + def inputRowFormatSQL: Option[String] = + getRowFormatSQL(inputRowFormat, inputSerdeClass, inputSerdeProps) + + def outputRowFormatSQL: Option[String] = + getRowFormatSQL(outputRowFormat, outputSerdeClass, outputSerdeProps) + + /** + * Get the row format specification + * Note: + * 1. Changes are needed when readerClause and writerClause are supported. + * 2. Changes are needed when "ESCAPED BY" is supported. + */ + private def getRowFormatSQL( + rowFormat: Seq[(String, String)], + serdeClass: Option[String], + serdeProps: Seq[(String, String)]): Option[String] = { + if (schemaLess) return Some("") + + val rowFormatDelimited = + rowFormat.map { + case ("TOK_TABLEROWFORMATFIELD", value) => + "FIELDS TERMINATED BY " + value + case ("TOK_TABLEROWFORMATCOLLITEMS", value) => + "COLLECTION ITEMS TERMINATED BY " + value + case ("TOK_TABLEROWFORMATMAPKEYS", value) => + "MAP KEYS TERMINATED BY " + value + case ("TOK_TABLEROWFORMATLINES", value) => + "LINES TERMINATED BY " + value + case ("TOK_TABLEROWFORMATNULL", value) => + "NULL DEFINED AS " + value + case o => return None + } + + val serdeClassSQL = serdeClass.map("'" + _ + "'").getOrElse("") + val serdePropsSQL = + if (serdeClass.nonEmpty) { + val props = serdeProps.map{p => s"'${p._1}' = '${p._2}'"}.mkString(", ") + if (props.nonEmpty) " WITH SERDEPROPERTIES(" + props + ")" else "" + } else { + "" + } + if (rowFormat.nonEmpty) { + Some("ROW FORMAT DELIMITED " + rowFormatDelimited.mkString(" ")) + } else { + Some("ROW FORMAT SERDE " + serdeClassSQL + serdePropsSQL) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 94210a5394f9b..06badff474f49 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -21,12 +21,11 @@ import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSource, LogicalRelation} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -47,35 +46,6 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand { } } -/** - * Drops a table from the metastore and removes it if it is cached. - */ -private[hive] -case class DropTable( - tableName: String, - ifExists: Boolean) extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - val ifExistsClause = if (ifExists) "IF EXISTS " else "" - try { - hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) - } catch { - // This table's metadata is not in Hive metastore (e.g. the table does not exist). - case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => - case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => - // Other Throwables can be caused by users providing wrong parameters in OPTIONS - // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an error. - case e: Throwable => log.warn(s"${e.getMessage}", e) - } - hiveContext.invalidateTable(tableName) - hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") - hiveContext.catalog.unregisterTable(TableIdentifier(tableName)) - Seq.empty[Row] - } -} - private[hive] case class AddJar(path: String) extends RunnableCommand { @@ -130,7 +100,7 @@ case class CreateMetastoreDataSource( val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableIdent)) { + if (hiveContext.sessionState.catalog.tableExists(tableIdent)) { if (allowExisting) { return Seq.empty[Row] } else { @@ -142,15 +112,25 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) + options + ("path" -> + hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } - hiveContext.catalog.createDataSourceTable( + // Create the relation to validate the arguments before writing the metadata to the metastore. + DataSource( + sqlContext = sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + className = provider, + bucketSpec = None, + options = optionsWithPath).resolveRelation() + + hiveContext.sessionState.catalog.createDataSourceTable( tableIdent, userSpecifiedSchema, Array.empty[String], + bucketSpec = None, provider, optionsWithPath, isExternal) @@ -164,6 +144,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { @@ -190,13 +171,14 @@ case class CreateMetastoreDataSourceAsSelect( val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) + options + ("path" -> + hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } var existingSchema = None: Option[StructType] - if (sqlContext.catalog.tableExists(tableIdent)) { + if (sqlContext.sessionState.catalog.tableExists(tableIdent)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -210,28 +192,19 @@ case class CreateMetastoreDataSourceAsSelect( return Seq.empty[Row] case SaveMode.Append => // Check if the specified data source match the data source of the existing table. - val resolved = ResolvedDataSource( - sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) - val createdRelation = LogicalRelation(resolved.relation) - EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match { - case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _) => - if (l.relation != createdRelation.relation) { - val errorDescription = - s"Cannot append to table $tableName because the resolved relation does not " + - s"match the existing relation of $tableName. " + - s"You can use insertInto($tableName, false) to append this DataFrame to the " + - s"table $tableName and using its data source and options." - val errorMessage = - s""" - |$errorDescription - |== Relations == - |${sideBySide( - s"== Expected Relation ==" :: l.toString :: Nil, - s"== Actual Relation ==" :: createdRelation.toString :: Nil - ).mkString("\n")} - """.stripMargin - throw new AnalysisException(errorMessage) - } + val dataSource = DataSource( + sqlContext = sqlContext, + userSpecifiedSchema = Some(query.schema.asNullable), + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + className = provider, + options = optionsWithPath) + // TODO: Check that options from the resolved relation match the relation that we are + // inserting into (i.e. using the same compression). + + EliminateSubqueryAliases( + sqlContext.sessionState.catalog.lookupRelation(tableIdent)) match { + case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => existingSchema = Some(l.schema) case o => throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") @@ -246,32 +219,39 @@ case class CreateMetastoreDataSourceAsSelect( createMetastoreTable = true } - val data = DataFrame(hiveContext, query) + val data = Dataset.ofRows(hiveContext, query) val df = existingSchema match { // If we are inserting into an existing table, just use the existing schema. - case Some(schema) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, schema) + case Some(s) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, s) case None => data } // Create the relation based on the data of df. - val resolved = - ResolvedDataSource(sqlContext, provider, partitionColumns, mode, optionsWithPath, df) + val dataSource = DataSource( + sqlContext, + className = provider, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + options = optionsWithPath) + + val result = dataSource.write(mode, df) if (createMetastoreTable) { // We will use the schema of resolved.relation as the schema of the table (instead of // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). - hiveContext.catalog.createDataSourceTable( + hiveContext.sessionState.catalog.createDataSourceTable( tableIdent, - Some(resolved.relation.schema), + Some(result.schema), partitionColumns, + bucketSpec, provider, optionsWithPath, isExternal) } // Refresh the cache of the table in the catalog. - hiveContext.catalog.refreshTable(tableIdent) + hiveContext.sessionState.catalog.refreshTable(tableIdent) Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a9db70119d011..784b018353472 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,119 +17,29 @@ package org.apache.spark.sql.hive -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ -import scala.util.Try +import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, + ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions -import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ -private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) - extends analysis.FunctionRegistry with HiveInspectors { - - def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) - - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - Try(underlying.lookupFunction(name, children)).getOrElse { - // We only look it up to see if it exists, but do not include it in the HiveUDF since it is - // not always serializable. - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"undefined function $name")) - - val functionClassName = functionInfo.getFunctionClass.getName - - // When we instantiate hive UDF wrapper class, we may throw exception if the input expressions - // don't satisfy the hive UDF, such as type mismatch, input number mismatch, etc. Here we - // catch the exception and throw AnalysisException instead. - try { - if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF( - new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) - } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction( - new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) - udtf.elementTypes // Force it to check input data types. - udtf - } else { - throw new AnalysisException(s"No handler for udf ${functionInfo.getFunctionClass}") - } - } catch { - case analysisException: AnalysisException => - // If the exception is an AnalysisException, just throw it. - throw analysisException - case throwable: Throwable => - // If there is any other error, we throw an AnalysisException. - val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " + - s"because: ${throwable.getMessage}." - throw new AnalysisException(errorMessage) - } - } - } - - override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) - : Unit = underlying.registerFunction(name, info, builder) - - /* List all of the registered function names. */ - override def listFunction(): Seq[String] = { - (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted - } - - /* Get the class of the registered function by specified name. */ - override def lookupFunction(name: String): Option[ExpressionInfo] = { - underlying.lookupFunction(name).orElse( - Try { - val info = FunctionRegistry.getFunctionInfo(name) - val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) - if (annotation != null) { - Some(new ExpressionInfo( - info.getFunctionClass.getCanonicalName, - annotation.name(), - annotation.value(), - annotation.extended())) - } else { - Some(new ExpressionInfo( - info.getFunctionClass.getCanonicalName, - name, - null, - null)) - } - }.getOrElse(None)) - } -} - -private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def deterministic: Boolean = isUDFDeterministic @@ -158,7 +68,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient private lazy val conversionHelper = new ConversionHelper(method, arguments) - override val dataType = javaClassToDataType(method.getReturnType) + override lazy val dataType = javaClassToDataType(method.getReturnType) @transient lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( @@ -183,6 +93,10 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def prettyName: String = name + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -197,7 +111,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp override def get(): AnyRef = wrap(func(), oi, dataType) } -private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def nullable: Boolean = true @@ -225,11 +140,11 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr } @transient - private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) => + private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType) }.toArray[DeferredObject] - override val dataType: DataType = inspectorToDataType(returnInspector) + override lazy val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. @@ -237,242 +152,20 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr var i = 0 while (i < children.length) { val idx = i - deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set( + deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set( () => { children(idx).eval(input) }) i += 1 } - unwrap(function.evaluate(deferedObjects), returnInspector) - } - - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - } -} - -/** - * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]]. - */ -private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { - private def shouldResolveFunction( - unresolvedWindowFunction: UnresolvedWindowFunction, - windowSpec: WindowSpecDefinition): Boolean = { - unresolvedWindowFunction.childrenResolved && windowSpec.childrenResolved - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p: LogicalPlan if !p.childrenResolved => p - - // We are resolving WindowExpressions at here. When we get here, we have already - // replaced those WindowSpecReferences. - case p: LogicalPlan => - p transformExpressions { - // We will not start to resolve the function unless all arguments are resolved - // and all expressions in window spec are fixed. - case WindowExpression( - u @ UnresolvedWindowFunction(name, children), - windowSpec: WindowSpecDefinition) if shouldResolveFunction(u, windowSpec) => - // First, let's find the window function info. - val windowFunctionInfo: WindowFunctionInfo = - Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"Couldn't find window function $name")) - - // Get the class of this function. - // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use - // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. - val functionClass = windowFunctionInfo.getFunctionClass() - val newChildren = - // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit - // input parameters and requires implicit parameters, which - // are expressions in Order By clause. - if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) { - if (children.nonEmpty) { - throw new AnalysisException(s"$name does not take input parameters.") - } - windowSpec.orderSpec.map(_.child) - } else { - children - } - - // If the class is UDAF, we need to use UDAFBridge. - val isUDAFBridgeRequired = - if (classOf[UDAF].isAssignableFrom(functionClass)) { - true - } else { - false - } - - // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of - // HiveWindowFunction. - val windowFunction = - HiveWindowFunction( - new HiveFunctionWrapper(functionClass.getName), - windowFunctionInfo.isPivotResult, - isUDAFBridgeRequired, - newChildren) - - // Second, check if the specified window function can accept window definition. - windowSpec.frameSpecification match { - case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow => - // This Hive window function does not support user-speficied window frame. - throw new AnalysisException( - s"Window function $name does not take a frame specification.") - case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow && - windowFunctionInfo.isPivotResult => - // These two should not be true at the same time when a window frame is defined. - // If so, throw an exception. - throw new AnalysisException(s"Could not handle Hive window function $name because " + - s"it supports both a user specified window frame and pivot result.") - case _ => // OK - } - // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs - // a window frame specification to work. - val newWindowSpec = windowSpec.frameSpecification match { - case UnspecifiedFrame => - val newWindowFrame = - SpecifiedWindowFrame.defaultWindowFrame( - windowSpec.orderSpec.nonEmpty, - windowFunctionInfo.isSupportsWindow) - WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame) - case _ => windowSpec - } - - // Finally, we create a WindowExpression with the resolved window function and - // specified window spec. - WindowExpression(windowFunction, newWindowSpec) - } - } -} - -/** - * A [[WindowFunction]] implementation wrapping Hive's window function. - * @param funcWrapper The wrapper for the Hive Window Function. - * @param pivotResult If it is true, the Hive function will return a list of values representing - * the values of the added columns. Otherwise, a single value is returned for - * current row. - * @param isUDAFBridgeRequired If it is true, the function returned by functionWrapper's - * createFunction is UDAF, we need to use GenericUDAFBridge to wrap - * it as a GenericUDAFResolver2. - * @param children Input parameters. - */ -private[hive] case class HiveWindowFunction( - funcWrapper: HiveFunctionWrapper, - pivotResult: Boolean, - isUDAFBridgeRequired: Boolean, - children: Seq[Expression]) extends WindowFunction - with HiveInspectors with Unevaluable { - - // Hive window functions are based on GenericUDAFResolver2. - type UDFType = GenericUDAFResolver2 - - @transient - protected lazy val resolver: GenericUDAFResolver2 = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[GenericUDAFResolver2]() - } - - @transient - protected lazy val inputInspectors = children.map(toInspector).toArray - - // The GenericUDAFEvaluator used to evaluate the window function. - @transient - protected lazy val evaluator: GenericUDAFEvaluator = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) - resolver.getEvaluator(parameterInfo) - } - - // The object inspector of values returned from the Hive window function. - @transient - protected lazy val returnInspector = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - override val dataType: DataType = - if (!pivotResult) { - inspectorToDataType(returnInspector) - } else { - // If pivotResult is true, we should take the element type out as the data type of this - // function. - inspectorToDataType(returnInspector) match { - case ArrayType(dt, _) => dt - case _ => - sys.error( - s"error resolve the data type of window function ${funcWrapper.functionClassName}") - } - } - - override def nullable: Boolean = true - - @transient - lazy val inputProjection = new InterpretedProjection(children) - - @transient - private var hiveEvaluatorBuffer: AggregationBuffer = _ - // Output buffer. - private var outputBuffer: Any = _ - - @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - override def init(): Unit = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - // Reset the hiveEvaluatorBuffer and outputPosition - override def reset(): Unit = { - // We create a new aggregation buffer to workaround the bug in GenericUDAFRowNumber. - // Basically, GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init. - // However, RowNumberBuffer.init does not really reset this buffer. - hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer - evaluator.reset(hiveEvaluatorBuffer) - } - - override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap( - inputProjection(input), - inputInspectors, - new Array[AnyRef](children.length), - inputDataTypes) - } - - // Add input parameters for a single row. - override def update(input: AnyRef): Unit = { - evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) - } - - override def batchUpdate(inputs: Array[AnyRef]): Unit = { - var i = 0 - while (i < inputs.length) { - evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]]) - i += 1 - } - } - - override def evaluate(): Unit = { - outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector) + unwrap(function.evaluate(deferredObjects), returnInspector) } - override def get(index: Int): Any = { - if (!pivotResult) { - // if pivotResult is false, we will get a single value for all rows in the frame. - outputBuffer - } else { - // if pivotResult is true, we will get a ArrayData having the same size with the size - // of the window frame. At here, we will return the result at the position of - // index in the output buffer. - outputBuffer.asInstanceOf[ArrayData].get(index, dataType) - } - } + override def prettyName: String = name override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - - override def newInstance(): WindowFunction = - new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } /** @@ -487,6 +180,7 @@ private[hive] case class HiveWindowFunction( * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUDTF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors with CodegenFallback { @@ -552,6 +246,8 @@ private[hive] case class HiveGenericUDTF( override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def prettyName: String = name } /** @@ -559,6 +255,7 @@ private[hive] case class HiveGenericUDTF( * performance a lot. */ private[hive] case class HiveUDAFFunction( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression], isUDAFBridgeRequired: Boolean = false, @@ -642,6 +339,12 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false - override val dataType: DataType = inspectorToDataType(returnInspector) -} + override lazy val dataType: DataType = inspectorToDataType(returnInspector) + + override def prettyName: String = name + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$name($distinct${children.map(_.sql).mkString(", ")})" + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 93c016b6c6c7a..794fe264ead5d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -20,21 +20,27 @@ package org.apache.spark.sql.hive import java.text.NumberFormat import java.util.Date -import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde2.Serializer +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorUtils, StructObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ -import org.apache.hadoop.hive.common.FileUtils +import org.apache.hadoop.mapreduce.TaskType +import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableJobConf @@ -44,10 +50,12 @@ import org.apache.spark.util.SerializableJobConf * It is based on [[SparkHadoopWriter]]. */ private[hive] class SparkHiveWriterContainer( - jobConf: JobConf, - fileSinkConf: FileSinkDesc) + @transient private val jobConf: JobConf, + fileSinkConf: FileSinkDesc, + inputSchema: Seq[Attribute], + table: MetastoreRelation) extends Logging - with SparkHadoopMapRedUtil + with HiveInspectors with Serializable { private val now = new Date() @@ -68,8 +76,8 @@ private[hive] class SparkHiveWriterContainer( @transient private var writer: FileSinkOperator.RecordWriter = null @transient protected lazy val committer = conf.value.getOutputCommitter - @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) - @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient protected lazy val jobContext = new JobContextImpl(conf.value, jID.value) + @transient private lazy val taskContext = new TaskAttemptContextImpl(conf.value, taID.value) @transient private lazy val outputFormat = conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] @@ -94,14 +102,12 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { - writer - } - def close() { // Seems the boolean value passed into close does not matter. - writer.close(false) - commit() + if (writer != null) { + writer.close(false) + commit() + } } def commitJob() { @@ -124,6 +130,13 @@ private[hive] class SparkHiveWriterContainer( SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } + def abortTask(): Unit = { + if (committer != null) { + committer.abortTask(taskContext) + } + logError(s"Task attempt $taskContext aborted.") + } + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { jobID = jobId splitID = splitId @@ -131,7 +144,7 @@ private[hive] class SparkHiveWriterContainer( jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } private def setConfParams() { @@ -141,6 +154,44 @@ private[hive] class SparkHiveWriterContainer( conf.value.setBoolean("mapred.task.is.map", true) conf.value.setInt("mapred.task.partition", splitID) } + + def newSerializer(tableDesc: TableDesc): Serializer = { + val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] + serializer.initialize(null, tableDesc.getProperties) + serializer + } + + protected def prepareForWrite() = { + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray + val dataTypes = inputSchema.map(_.dataType) + val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) } + val outputData = new Array[Any](fieldOIs.length) + (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) + } + + // this function is executed on executor side + def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() + executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + i += 1 + } + writer.write(serializer.serialize(outputData, standardOI)) + } + + close() + } } private[hive] object SparkHiveWriterContainer { @@ -164,25 +215,22 @@ private[spark] object SparkHiveDynamicPartitionWriterContainer { private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf: JobConf, fileSinkConf: FileSinkDesc, - dynamicPartColNames: Array[String]) - extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + dynamicPartColNames: Array[String], + inputSchema: Seq[Attribute], + table: MetastoreRelation) + extends SparkHiveWriterContainer(jobConf, fileSinkConf, inputSchema, table) { import SparkHiveDynamicPartitionWriterContainer._ private val defaultPartName = jobConf.get( ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal) - @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ - override protected def initWriters(): Unit = { - // NOTE: This method is executed at the executor side. - // Actual writers are created for each dynamic partition on the fly. - writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] + // do nothing } override def close(): Unit = { - writers.values.foreach(_.close(false)) - commit() + // do nothing } override def commitJob(): Unit = { @@ -199,33 +247,90 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: InternalRow, schema: StructType) - : FileSinkOperator.RecordWriter = { - def convertToHiveRawString(col: String, value: Any): String = { - val raw = String.valueOf(value) - schema(col).dataType match { - case DateType => DateTimeUtils.dateToString(raw.toInt) - case _: DecimalType => BigDecimal(raw).toString() - case _ => raw - } + // this function is executed on executor side + override def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() + executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + + val partitionOutput = inputSchema.takeRight(dynamicPartColNames.length) + val dataOutput = inputSchema.take(fieldOIs.length) + // Returns the partition key given an input row + val getPartitionKey = UnsafeProjection.create(partitionOutput, inputSchema) + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataOutput, inputSchema) + + val fun: AnyRef = (pathString: String) => FileUtils.escapePathName(pathString, defaultPartName) + // Expressions that given a partition key build a string like: col1=val/col2=val/... + val partitionStringExpression = partitionOutput.zipWithIndex.flatMap { case (c, i) => + val escaped = + ScalaUDF(fun, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + val str = If(IsNull(c), Literal(defaultPartName), escaped) + val partitionName = Literal(dynamicPartColNames(i) + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName } - val nonDynamicPartLen = row.numFields - dynamicPartColNames.length - val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => - val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) - val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, defaultPartName) - } - s"/$colName=$colString" - }.mkString + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionOutput) + + // If anything below fails, we should abort the task. + try { + val sorter: UnsafeKVExternalSorter = new UnsafeKVExternalSorter( + StructType.fromAttributes(partitionOutput), + StructType.fromAttributes(dataOutput), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val inputRow = iterator.next() + val currentKey = getPartitionKey(inputRow) + sorter.insertKV(currentKey, getOutputRow(inputRow)) + } - def newWriter(): FileSinkOperator.RecordWriter = { + logInfo(s"Sorting complete. Writing out partition files one at a time.") + val sortedIterator = sorter.sortedIterator() + var currentKey: InternalRow = null + var currentWriter: FileSinkOperator.RecordWriter = null + try { + while (sortedIterator.next()) { + if (currentKey != sortedIterator.getKey) { + if (currentWriter != null) { + currentWriter.close(false) + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + currentWriter = newOutputWriter(currentKey) + } + + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (sortedIterator.getValue.isNullAt(i)) { + null + } else { + wrappers(i)(sortedIterator.getValue.get(i, dataTypes(i))) + } + i += 1 + } + currentWriter.write(serializer.serialize(outputData, standardOI)) + } + } finally { + if (currentWriter != null) { + currentWriter.close(false) + } + } + commit() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + /** Open and returns a new OutputWriter given a partition key. */ + def newOutputWriter(key: InternalRow): FileSinkOperator.RecordWriter = { + val partitionPath = getPartitionString(key).getString(0) val newFileSinkDesc = new FileSinkDesc( - fileSinkConf.getDirName + dynamicPartPath, + fileSinkConf.getDirName + partitionPath, fileSinkConf.getTableInfo, fileSinkConf.getCompressed) newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) @@ -235,7 +340,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( // to avoid write to the same file when `spark.speculation=true` val path = FileOutputFormat.getTaskOutputPath( conf.value, - dynamicPartPath.stripPrefix("/") + "/" + getOutputName) + partitionPath.stripPrefix("/") + "/" + getOutputName) HiveFileFormatUtils.getHiveRecordWriter( conf.value, @@ -245,7 +350,5 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( path, Reporter.NULL) } - - writers.getOrElseUpdate(dynamicPartPath, newWriter()) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 0f9a1a6ef3b27..8248a112a0af4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -22,9 +22,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector -import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.AnalysisException +import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType @@ -45,7 +44,6 @@ private[orc] object OrcFileOperator extends Logging { * directly from HDFS via Spark SQL, because we have to discover the schema from raw ORC * files. So this method always tries to find a ORC file whose schema is non-empty, and * create the result reader from that file. If no such file is found, it returns `None`. - * * @todo Needs to consider all files when schema evolution is taken into account. */ def getFileReader(basePath: String, config: Option[Configuration] = None): Option[Reader] = { @@ -73,16 +71,15 @@ private[orc] object OrcFileOperator extends Logging { } } - def readSchema(path: String, conf: Option[Configuration]): StructType = { - val reader = getFileReader(path, conf).getOrElse { - throw new AnalysisException( - s"Failed to discover schema from ORC files stored in $path. " + - "Probably there are either no ORC files or only empty ORC files.") + def readSchema(paths: Seq[String], conf: Option[Configuration]): Option[StructType] = { + // Take the first file where we can open a valid reader if we can find one. Otherwise just + // return None to indicate we can't infer the schema. + paths.flatMap(getFileReader(_, conf)).headOption.map { reader => + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") + HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] } - val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] - val schema = readerInspector.getTypeName - logDebug(s"Reading schema from file $path, got Hive schema string: $schema") - HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] } def getObjectInspector( @@ -91,20 +88,14 @@ private[orc] object OrcFileOperator extends Logging { } def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + // TODO: Check if the paths coming in are already qualified and simplify. val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) - val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) - .filterNot(_.isDir) + .filterNot(_.isDirectory) .map(_.getPath) .filterNot(_.getName.startsWith("_")) .filterNot(_.getName.startsWith(".")) - - if (paths == null || paths.isEmpty) { - throw new IllegalArgumentException( - s"orcFileOperator: path $path does not have valid orc files matching the pattern") - } - paths } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 27193f54d3a91..c025c12a90a2d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -18,23 +18,55 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.ql.io.sarg.{SearchArgumentFactory, SearchArgument} +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.serde2.io.DateWritable -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.sql.sources._ /** - * It may be optimized by push down partial filters. But we are conservative here. - * Because if some filters fail to be parsed, the tree may be corrupted, - * and cannot be used anymore. + * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. + * + * Due to limitation of ORC `SearchArgument` builder, we had to end up with a pretty weird double- + * checking pattern when converting `And`/`Or`/`Not` filters. + * + * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't + * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite + * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using + * existing simpler ones. + * + * The annoying part is that, `SearchArgument` builder methods like `startAnd()`, `startOr()`, and + * `startNot()` mutate internal state of the builder instance. This forces us to translate all + * convertible filters with a single builder instance. However, before actually converting a filter, + * we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is + * found, we may already end up with a builder whose internal state is inconsistent. + * + * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then + * try to convert its children. Say we convert `left` child successfully, but find that `right` + * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent + * now. + * + * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + * children with brand new builders, and only do the actual conversion with the right builder + * instance when the children are proven to be convertible. + * + * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of + * builder methods mentioned above can only be found in test code, where all tested filters are + * known to be convertible. */ private[orc] object OrcFilters extends Logging { def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + // First, tries to convert each filter individually to see whether it's convertible, and then + // collect all convertible ones to build the final `SearchArgument`. + val convertibleFilters = for { + filter <- filters + _ <- buildSearchArgument(filter, SearchArgumentFactory.newBuilder()) + } yield filter + for { - // Combines all filters with `And`s to produce a single conjunction predicate - conjunction <- filters.reduceOption(And) + // Combines all convertible filters using `And` to produce a single conjunction + conjunction <- convertibleFilters.reduceOption(And) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) } yield builder.build() @@ -50,46 +82,22 @@ private[orc] object OrcFilters extends Logging { case _ => false } - // lian: I probably missed something here, and had to end up with a pretty weird double-checking - // pattern when converting `And`/`Or`/`Not` filters. - // - // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`, - // and `startNot()` mutate internal state of the builder instance. This forces us to translate - // all convertible filters with a single builder instance. However, before actually converting a - // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible - // filter is found, we may already end up with a builder whose internal state is inconsistent. - // - // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and - // then try to convert its children. Say we convert `left` child successfully, but find that - // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is - // inconsistent now. - // - // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their - // children with brand new builders, and only do the actual conversion with the right builder - // instance when the children are proven to be convertible. - // - // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. - // Usage of builder methods mentioned above can only be found in test code, where all tested - // filters are known to be convertible. - expression match { case And(left, right) => - val tryLeft = buildSearchArgument(left, newBuilder) - val tryRight = buildSearchArgument(right, newBuilder) - - val conjunction = for { - _ <- tryLeft - _ <- tryRight + // At here, it is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. + for { + _ <- buildSearchArgument(left, newBuilder) + _ <- buildSearchArgument(right, newBuilder) lhs <- buildSearchArgument(left, builder.startAnd()) rhs <- buildSearchArgument(right, lhs) } yield rhs.end() - // For filter `left AND right`, we can still push down `left` even if `right` is not - // convertible, and vice versa. - conjunction - .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) - .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) - case Or(left, right) => for { _ <- buildSearchArgument(left, newBuilder) @@ -104,6 +112,10 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() + // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` + // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be + // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). + case EqualTo(attribute, value) if isSearchableLiteral(value) => Some(builder.startAnd().equals(attribute, value).end()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 45de567039760..21591ec093d3d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -17,57 +17,168 @@ package org.apache.spark.sql.hive.orc +import java.net.URI import java.util.Properties -import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit, OrcStruct} -import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector +import org.apache.hadoop.hive.ql.io.orc._ +import org.apache.hadoop.hive.ql.io.orc.OrcFile.OrcTableProperties +import org.apache.hadoop.hive.serde2.objectinspector.{SettableStructObjectInspector, StructObjectInspector} import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.spark.Logging -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.PartitionSpec -import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.hive.{HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource + extends FileFormat with DataSourceRegister with Serializable { override def shortName(): String = "orc" - override def createRelation( + override def toString: String = "ORC" + + override def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + OrcFileOperator.readSchema( + files.map(_.getPath.toUri.toString), + Some(sqlContext.sparkContext.hadoopConfiguration) + ) + } + + override def prepareWrite( sqlContext: SQLContext, - paths: Array[String], - dataSchema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - assert( - sqlContext.isInstanceOf[HiveContext], - "The ORC data source can only be used with HiveContext.") - - new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext) + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val compressionCodec: Option[String] = options + .get("compression") + .map { codecName => + // Validate if given compression codec is supported or not. + val shortOrcCompressionCodecNames = OrcRelation.shortOrcCompressionCodecNames + if (!shortOrcCompressionCodecNames.contains(codecName.toLowerCase)) { + val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase) + throw new IllegalArgumentException(s"Codec [$codecName] " + + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") + } + codecName.toLowerCase + } + + compressionCodec.foreach { codecName => + job.getConfiguration.set( + OrcTableProperties.COMPRESSION.getPropName, + OrcRelation + .shortOrcCompressionCodecNames + .getOrElse(codecName, CompressionKind.NONE).name()) + } + + job.getConfiguration match { + case conf: JobConf => + conf.setOutputFormat(classOf[OrcOutputFormat]) + case conf => + conf.setClass( + "mapred.output.format.class", + classOf[OrcOutputFormat], + classOf[MapRedOutputFormat[_, _]]) + } + + new OutputWriterFactory { + override def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, bucketId, dataSchema, context) + } + } + } + + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { + val orcConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + + if (sqlContext.conf.orcFilterPushDown) { + // Sets pushed predicates + OrcFilters.createFilter(filters.toArray).foreach { f => + orcConf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) + orcConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + + val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(orcConf)) + + (file: PartitionedFile) => { + val conf = broadcastedConf.value.value + + // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this + // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file + // using the given physical schema. Instead, we simply return an empty iterator. + val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)) + if (maybePhysicalSchema.isEmpty) { + Iterator.empty + } else { + val physicalSchema = maybePhysicalSchema.get + OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) + + val orcRecordReader = { + val job = Job.getInstance(conf) + FileInputFormat.setInputPaths(job, file.filePath) + + val inputFormat = new OrcNewInputFormat + val fileSplit = new FileSplit( + new Path(new URI(file.filePath)), file.start, file.length, Array.empty + ) + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + inputFormat.createRecordReader(fileSplit, hadoopAttemptContext) + } + + // Unwraps `OrcStruct`s to `UnsafeRow`s + val unsafeRowIterator = OrcRelation.unwrapOrcStructs( + file.filePath, conf, requiredSchema, new RecordReaderIterator[OrcStruct](orcRecordReader) + ) + + // Appends partition values + val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val joinedRow = new JoinedRow() + val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput) + + unsafeRowIterator.map { dataRow => + appendPartitionColumns(joinedRow(dataRow, file.partitionValues)) + } + } + } } } private[orc] class OrcOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriter with HiveInspectors { private val serializer = { val table = new Properties() @@ -77,7 +188,7 @@ private[orc] class OrcOutputWriter( }.mkString(":")) val serde = new OrcSerde - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration serde.initialize(configuration, table) serde } @@ -99,11 +210,19 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val conf = context.getConfiguration val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val partition = taskAttemptId.getTaskID.getId - val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" + val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + val compressionExtension = { + val name = conf.get(OrcTableProperties.COMPRESSION.getPropName) + OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + } + // It has the `.orc` extension at the end because (de)compression tools + // such as gunzip would not be able to decompress this as the compression + // is not applied on this whole file but on each "stream" in ORC format. + val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString$compressionExtension.orc" new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), @@ -113,7 +232,8 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + override def write(row: Row): Unit = + throw new UnsupportedOperationException("call writeInternal") private def wrapOrcStruct( struct: OrcStruct, @@ -122,6 +242,7 @@ private[orc] class OrcOutputWriter( val fieldRefs = oi.getAllStructFieldRefs var i = 0 while (i < fieldRefs.size) { + oi.setStructFieldData( struct, fieldRefs.get(i), @@ -150,147 +271,17 @@ private[orc] class OrcOutputWriter( } } -private[sql] class OrcRelation( - override val paths: Array[String], - maybeDataSchema: Option[StructType], - maybePartitionSpec: Option[PartitionSpec], - override val userDefinedPartitionColumns: Option[StructType], - parameters: Map[String, String])( - @transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) - with Logging { - - private[sql] def this( - paths: Array[String], - maybeDataSchema: Option[StructType], - maybePartitionSpec: Option[PartitionSpec], - parameters: Map[String, String])( - sqlContext: SQLContext) = { - this( - paths, - maybeDataSchema, - maybePartitionSpec, - maybePartitionSpec.map(_.partitionColumns), - parameters)(sqlContext) - } - - override val dataSchema: StructType = maybeDataSchema.getOrElse { - OrcFileOperator.readSchema( - paths.head, Some(sqlContext.sparkContext.hadoopConfiguration)) - } - - override def needConversion: Boolean = false - - override def equals(other: Any): Boolean = other match { - case that: OrcRelation => - paths.toSet == that.paths.toSet && - dataSchema == that.dataSchema && - schema == that.schema && - partitionColumns == that.partitionColumns - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode( - paths.toSet, - dataSchema, - schema, - partitionColumns) - } - - override private[sql] def buildInternalScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputPaths: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute() - } - - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { - case conf: JobConf => - conf.setOutputFormat(classOf[OrcOutputFormat]) - case conf => - conf.setClass( - "mapred.output.format.class", - classOf[OrcOutputFormat], - classOf[MapRedOutputFormat[_, _]]) - } - - new OutputWriterFactory { - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(path, dataSchema, context) - } - } - } -} - private[orc] case class OrcTableScan( + @transient sqlContext: SQLContext, attributes: Seq[Attribute], - @transient relation: OrcRelation, filters: Array[Filter], - @transient inputPaths: Array[FileStatus]) + @transient inputPaths: Seq[FileStatus]) extends Logging with HiveInspectors { - @transient private val sqlContext = relation.sqlContext - - private def addColumnIds( - output: Seq[Attribute], - relation: OrcRelation, - conf: Configuration): Unit = { - val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) - val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip - HiveShim.appendReadColumns(conf, sortedIds, sortedNames) - } - - // Transform all given raw `Writable`s into `InternalRow`s. - private def fillObject( - path: String, - conf: Configuration, - iterator: Iterator[Writable], - nonPartitionKeyAttrs: Seq[Attribute]): Iterator[InternalRow] = { - val deserializer = new OrcSerde - val maybeStructOI = OrcFileOperator.getObjectInspector(path, Some(conf)) - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) - val unsafeProjection = UnsafeProjection.create(StructType.fromAttributes(nonPartitionKeyAttrs)) - - // SPARK-8501: ORC writes an empty schema ("struct<>") to an ORC file if the file contains zero - // rows, and thus couldn't give a proper ObjectInspector. In this case we just return an empty - // partition since we know that this file is empty. - maybeStructOI.map { soi => - val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.zipWithIndex.map { - case (attr, ordinal) => - soi.getStructFieldRef(attr.name) -> ordinal - }.unzip - val unwrappers = fieldRefs.map(unwrapperFor) - // Map each tuple to a row object - iterator.map { value => - val raw = deserializer.deserialize(value) - var i = 0 - while (i < fieldRefs.length) { - val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) - if (fieldValue == null) { - mutableRow.setNullAt(fieldOrdinals(i)) - } else { - unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) - } - i += 1 - } - unsafeProjection(mutableRow) - } - }.getOrElse { - Iterator.empty - } - } - def execute(): RDD[InternalRow] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { @@ -300,8 +291,15 @@ private[orc] case class OrcTableScan( } } + // Figure out the actual schema from the ORC source (without partition columns) so that we + // can pick the correct ordinals. Note that this assumes that all files have the same schema. + val orcFormat = new DefaultSource + val dataSchema = + orcFormat + .inferSchema(sqlContext, Map.empty, inputPaths) + .getOrElse(sys.error("Failed to read schema from target ORC files.")) // Sets requested columns - addColumnIds(attributes, relation, conf) + OrcRelation.setRequiredColumns(conf, dataSchema, StructType.fromAttributes(attributes)) if (inputPaths.isEmpty) { // the input path probably be pruned, return an empty RDD. @@ -324,7 +322,12 @@ private[orc] case class OrcTableScan( rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => val writableIterator = iterator.map(_._2) - fillObject(split.getPath.toString, wrappedConf.value, writableIterator, attributes) + OrcRelation.unwrapOrcStructs( + split.getPath.toString, + wrappedConf.value, + StructType.fromAttributes(attributes), + writableIterator + ) } } } @@ -333,3 +336,64 @@ private[orc] object OrcTableScan { // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. private[orc] val SARG_PUSHDOWN = "sarg.pushdown" } + +private[orc] object OrcRelation extends HiveInspectors { + // The ORC compression short names + val shortOrcCompressionCodecNames = Map( + "none" -> CompressionKind.NONE, + "uncompressed" -> CompressionKind.NONE, + "snappy" -> CompressionKind.SNAPPY, + "zlib" -> CompressionKind.ZLIB, + "lzo" -> CompressionKind.LZO) + + // The extensions for ORC compression codecs + val extensionsForCompressionCodecNames = Map( + CompressionKind.NONE.name -> "", + CompressionKind.SNAPPY.name -> ".snappy", + CompressionKind.ZLIB.name -> ".zlib", + CompressionKind.LZO.name -> ".lzo" + ) + + def unwrapOrcStructs( + filePath: String, + conf: Configuration, + dataSchema: StructType, + iterator: Iterator[Writable]): Iterator[InternalRow] = { + val deserializer = new OrcSerde + val maybeStructOI = OrcFileOperator.getObjectInspector(filePath, Some(conf)) + val mutableRow = new SpecificMutableRow(dataSchema.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(dataSchema) + + def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { + val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map { + case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal + }.unzip + + val unwrappers = fieldRefs.map(unwrapperFor) + + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = oi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + unsafeProjection(mutableRow) + } + } + + maybeStructOI.map(unwrap).getOrElse(Iterator.empty) + } + + def setRequiredColumns( + conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = { + val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer) + val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6883d305cbead..7f6ca21782da4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -24,25 +24,33 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.{SQLContext, SQLConf} +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.CacheTableCommand +import org.apache.spark.sql.execution.CacheManager +import org.apache.spark.sql.execution.command.CacheTableCommand +import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.client.{HiveClient, HiveClientImpl} import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ShutdownHookManager, Utils} -import org.apache.spark.{SparkConf, SparkContext} // SPARK-3729: Test key required to check for initialization errors with config. object TestHive extends TestHiveContext( new SparkContext( - System.getProperty("spark.sql.test.master", "local[32]"), + System.getProperty("spark.sql.test.master", "local[1]"), "TestSQLContext", new SparkConf() .set("spark.sql.test", "") @@ -51,10 +59,6 @@ object TestHive // SPARK-8910 .set("spark.ui.enabled", "false"))) -trait TestHiveSingleton { - protected val sqlContext: SQLContext = TestHive - protected val hiveContext: TestHiveContext = TestHive -} /** * A locally running test instance of Spark's Hive execution engine. @@ -67,10 +71,87 @@ trait TestHiveSingleton { * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of * test cases that rely on TestHive must be serialized. */ -class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { - self => +class TestHiveContext private[hive]( + sc: SparkContext, + cacheManager: CacheManager, + listener: SQLListener, + executionHive: HiveClientImpl, + metadataHive: HiveClient, + isRootContext: Boolean, + hiveCatalog: HiveExternalCatalog, + val warehousePath: File, + val scratchDirPath: File, + metastoreTemporaryConf: Map[String, String]) + extends HiveContext( + sc, + cacheManager, + listener, + executionHive, + metadataHive, + isRootContext, + hiveCatalog) { self => + + // Unfortunately, due to the complex interactions between the construction parameters + // and the limitations in scala constructors, we need many of these constructors to + // provide a shorthand to create a new TestHiveContext with only a SparkContext. + // This is not a great design pattern but it's necessary here. + + private def this( + sc: SparkContext, + executionHive: HiveClientImpl, + metadataHive: HiveClient, + warehousePath: File, + scratchDirPath: File, + metastoreTemporaryConf: Map[String, String]) { + this( + sc, + new CacheManager, + SQLContext.createListenerAndUI(sc), + executionHive, + metadataHive, + true, + new HiveExternalCatalog(metadataHive), + warehousePath, + scratchDirPath, + metastoreTemporaryConf) + } + + private def this( + sc: SparkContext, + warehousePath: File, + scratchDirPath: File, + metastoreTemporaryConf: Map[String, String]) { + this( + sc, + HiveContext.newClientForExecution(sc.conf, sc.hadoopConfiguration), + TestHiveContext.newClientForMetadata( + sc.conf, sc.hadoopConfiguration, warehousePath, scratchDirPath, metastoreTemporaryConf), + warehousePath, + scratchDirPath, + metastoreTemporaryConf) + } - import HiveContext._ + def this(sc: SparkContext) { + this( + sc, + Utils.createTempDir(namePrefix = "warehouse"), + TestHiveContext.makeScratchDir(), + HiveContext.newTemporaryConfiguration(useInMemoryDerby = false)) + } + + override def newSession(): HiveContext = { + new TestHiveContext( + sc = sc, + cacheManager = cacheManager, + listener = listener, + executionHive = executionHive.newSession(), + metadataHive = metadataHive.newSession(), + isRootContext = false, + hiveCatalog = hiveCatalog, + warehousePath = warehousePath, + scratchDirPath = scratchDirPath, + metastoreTemporaryConf = metastoreTemporaryConf) + } // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. @@ -79,24 +160,12 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveconf.set("hive.plan.serialization.format", "javaXML") - lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") - - lazy val scratchDirPath = { - val dir = Utils.createTempDir(namePrefix = "scratch-") - dir.delete() - dir - } - - private lazy val temporaryConfig = newTemporaryConfiguration() - - /** Sets up the system initially or after a RESET command */ - protected override def configure(): Map[String, String] = { - super.configure() ++ temporaryConfig ++ Map( - ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, - ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", - ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1" - ) + // A snapshot of the entries in the starting SQLConf + // We save this because tests can mutate this singleton object if they want + val initialSQLConf: SQLConf = { + val snapshot = new SQLConf + conf.getAllConfs.foreach { case (k, v) => snapshot.setConfString(k, v) } + snapshot } val testTempDir = Utils.createTempDir() @@ -116,19 +185,29 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution(plan) - protected[sql] override lazy val conf: SQLConf = new SQLConf { - // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" - override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - - clear() - - override def clear(): Unit = { - super.clear() + @transient + protected[sql] override lazy val sessionState = new HiveSessionState(this) { + override lazy val conf: SQLConf = { + new SQLConf { + clear() + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + override def clear(): Unit = { + super.clear() + TestHiveContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } + } + } - TestHiveContext.overrideConfs.map { - case (key, value) => setConfString(key, value) + override lazy val functionRegistry = { + // We use TestHiveFunctionRegistry at here to track functions that have been explicitly + // unregistered (through TestHiveFunctionRegistry.unregisterFunction method). + val fr = new TestHiveFunctionRegistry + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.expressions.foreach { + case (name, (info, builder)) => fr.registerFunction(name, info, builder) } + fr } } @@ -196,7 +275,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. - analyzer.execute(logical) + sessionState.analyzer.execute(logical) } } @@ -306,8 +385,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { """.stripMargin.cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd ), - // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING - // IS NOT YET SUPPORTED + // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC + // PARTITIONING IS NOT YET SUPPORTED TestTable("episodes_part", s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) |PARTITIONED BY (doctor_pt INT) @@ -410,14 +489,17 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { try { // HACK: Hive is too noisy by default. org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => - log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + val logger = log.asInstanceOf[org.apache.log4j.Logger] + if (!logger.getName.contains("org.apache.spark")) { + logger.setLevel(org.apache.log4j.Level.WARN) + } } cacheManager.clearCache() loadedTables.clear() - catalog.cachedDataSourceTables.invalidateAll() - catalog.client.reset() - catalog.unregisterAllTables() + sessionState.catalog.clearTempTables() + sessionState.catalog.invalidateCache() + metadataHive.reset() FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } @@ -436,25 +518,35 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { // Lots of tests fail if we do not change the partition whitelist from the default. runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") - configure().foreach { - case (k, v) => - metadataHive.runSqlHive(s"SET $k=$v") - } + // In case a test changed any of these values, restore all the original ones here. + TestHiveContext.hiveClientConfigurations( + hiveconf, warehousePath, scratchDirPath, metastoreTemporaryConf) + .foreach { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } defaultOverrides() - runSqlHive("USE default") - - // Just loading src makes a lot of tests pass. This is because some tests do something like - // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our - // Analyzer and thus the test table auto-loading mechanism. - // Remove after we handle more DDL operations natively. - loadTestTable("src") - loadTestTable("srcpart") + sessionState.catalog.setCurrentDatabase("default") } catch { case e: Exception => logError("FATAL ERROR: Failed to reset TestDB state.", e) } } + +} + +private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry { + + private val removedFunctions = + collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))] + + def unregisterFunction(name: String): Unit = { + functionBuilders.remove(name).foreach(f => removedFunctions += name -> f) + } + + def restore(): Unit = { + removedFunctions.foreach { + case (name, (info, builder)) => registerFunction(name, info, builder) + } + } } private[hive] object TestHiveContext { @@ -467,4 +559,43 @@ private[hive] object TestHiveContext { // Fewer shuffle partitions to speed up testing. SQLConf.SHUFFLE_PARTITIONS.key -> "5" ) + + /** + * Create a [[HiveClient]] used to retrieve metadata from the Hive MetaStore. + */ + private def newClientForMetadata( + conf: SparkConf, + hadoopConf: Configuration, + warehousePath: File, + scratchDirPath: File, + metastoreTemporaryConf: Map[String, String]): HiveClient = { + val hiveConf = new HiveConf(hadoopConf, classOf[HiveConf]) + HiveContext.newClientForMetadata( + conf, + hiveConf, + hadoopConf, + hiveClientConfigurations(hiveConf, warehousePath, scratchDirPath, metastoreTemporaryConf)) + } + + /** + * Configurations needed to create a [[HiveClient]]. + */ + private def hiveClientConfigurations( + hiveconf: HiveConf, + warehousePath: File, + scratchDirPath: File, + metastoreTemporaryConf: Map[String, String]): Map[String, String] = { + HiveContext.hiveClientConfigurations(hiveconf) ++ metastoreTemporaryConf ++ Map( + ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, + ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", + ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") + } + + private def makeScratchDir(): File = { + val scratchDir = Utils.createTempDir(namePrefix = "scratch") + scratchDir.delete() + scratchDir + } + } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index b4bf9eef8fca5..397421ae92a47 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -38,9 +38,9 @@ public class JavaDataFrameSuite { private transient JavaSparkContext sc; private transient HiveContext hc; - DataFrame df; + Dataset df; - private static void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(Dataset actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -82,12 +82,12 @@ public void saveTableAndQueryIt() { @Test public void testUDAF() { - DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value")); + Dataset df = hc.range(0, 100).union(hc.range(0, 100)).select(col("id").as("value")); UserDefinedAggregateFunction udaf = new MyDoubleSum(); UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf); // Create Columns for the UDAF. For now, callUDF does not take an argument to specific if // we want to use distinct aggregation. - DataFrame aggregatedDF = + Dataset aggregatedDF = df.groupBy() .agg( udaf.distinct(col("value")), diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 8c4af1b8eaf44..2fc38e2b2d2e7 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; import org.apache.spark.sql.hive.test.TestHive$; @@ -48,13 +48,12 @@ public class JavaMetastoreDataSourcesSuite { private transient JavaSparkContext sc; private transient HiveContext sqlContext; - String originalDefaultSource; File path; Path hiveManagedPath; FileSystem fs; - DataFrame df; + Dataset df; - private static void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(Dataset actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -66,14 +65,14 @@ public void setUp() throws IOException { sqlContext = TestHive$.MODULE$; sc = new JavaSparkContext(sqlContext.sparkContext()); - originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); if (path.exists()) { path.delete(); } - hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath( - new TableIdentifier("javaSavedTable"))); + hiveManagedPath = new Path( + sqlContext.sessionState().catalog().hiveDefaultTableFilePath( + new TableIdentifier("javaSavedTable"))); fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); if (fs.exists(hiveManagedPath)){ fs.delete(hiveManagedPath, true); @@ -111,7 +110,7 @@ public void saveExternalTableAndQueryIt() { sqlContext.sql("SELECT * FROM javaSavedTable"), df.collectAsList()); - DataFrame loadedDF = + Dataset loadedDF = sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options); checkAnswer(loadedDF, df.collectAsList()); @@ -137,7 +136,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = + Dataset loadedDF = sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options); checkAnswer( diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index 5a167edd89592..ae0c097c362ab 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -42,14 +42,14 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleAvg() { - List inputFields = new ArrayList(); + List inputFields = new ArrayList<>(); inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); _inputDataType = DataTypes.createStructType(inputFields); // The buffer has two values, bufferSum for storing the current sum and // bufferCount for storing the number of non-null input values that have been contribuetd // to the current sum. - List bufferFields = new ArrayList(); + List bufferFields = new ArrayList<>(); bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); _bufferSchema = DataTypes.createStructType(bufferFields); diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index c3b7768e71bf8..d17fb3e5194f3 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -41,11 +41,11 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleSum() { - List inputFields = new ArrayList(); + List inputFields = new ArrayList<>(); inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); _inputDataType = DataTypes.createStructType(inputFields); - List bufferFields = new ArrayList(); + List bufferFields = new ArrayList<>(); bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); _bufferSchema = DataTypes.createStructType(bufferFields); diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index e010112bb9327..a8cbd4fab15bb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -39,7 +39,7 @@ * does not contain union fields that are not supported by Spark SQL. */ -@SuppressWarnings({"ALL", "unchecked"}) +@SuppressWarnings("all") public class Complex implements org.apache.thrift.TBase, java.io.Serializable, Cloneable { private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Complex"); @@ -50,7 +50,7 @@ public class Complex implements org.apache.thrift.TBase, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + private static final Map, SchemeFactory> schemes = new HashMap<>(); static { schemes.put(StandardScheme.class, new ComplexStandardSchemeFactory()); schemes.put(TupleScheme.class, new ComplexTupleSchemeFactory()); @@ -72,7 +72,7 @@ public enum _Fields implements org.apache.thrift.TFieldIdEnum { LINT_STRING((short)5, "lintString"), M_STRING_STRING((short)6, "mStringString"); - private static final Map byName = new HashMap(); + private static final Map byName = new HashMap<>(); static { for (_Fields field : EnumSet.allOf(_Fields.class)) { @@ -141,7 +141,7 @@ public String getFieldName() { private byte __isset_bitfield = 0; public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; static { - Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<>(_Fields.class); tmpMap.put(_Fields.AINT, new org.apache.thrift.meta_data.FieldMetaData("aint", org.apache.thrift.TFieldRequirementType.DEFAULT, new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); tmpMap.put(_Fields.A_STRING, new org.apache.thrift.meta_data.FieldMetaData("aString", org.apache.thrift.TFieldRequirementType.DEFAULT, @@ -194,28 +194,28 @@ public Complex(Complex other) { this.aString = other.aString; } if (other.isSetLint()) { - List __this__lint = new ArrayList(); + List __this__lint = new ArrayList<>(); for (Integer other_element : other.lint) { __this__lint.add(other_element); } this.lint = __this__lint; } if (other.isSetLString()) { - List __this__lString = new ArrayList(); + List __this__lString = new ArrayList<>(); for (String other_element : other.lString) { __this__lString.add(other_element); } this.lString = __this__lString; } if (other.isSetLintString()) { - List __this__lintString = new ArrayList(); + List __this__lintString = new ArrayList<>(); for (IntString other_element : other.lintString) { __this__lintString.add(new IntString(other_element)); } this.lintString = __this__lintString; } if (other.isSetMStringString()) { - Map __this__mStringString = new HashMap(); + Map __this__mStringString = new HashMap<>(); for (Map.Entry other_element : other.mStringString.entrySet()) { String other_element_key = other_element.getKey(); @@ -339,7 +339,7 @@ public java.util.Iterator getLStringIterator() { public void addToLString(String elem) { if (this.lString == null) { - this.lString = new ArrayList(); + this.lString = new ArrayList<>(); } this.lString.add(elem); } @@ -411,7 +411,7 @@ public int getMStringStringSize() { public void putToMStringString(String key, String val) { if (this.mStringString == null) { - this.mStringString = new HashMap(); + this.mStringString = new HashMap<>(); } this.mStringString.put(key, val); } @@ -489,6 +489,7 @@ public void setFieldValue(_Fields field, Object value) { } break; + default: } } @@ -512,6 +513,7 @@ public Object getFieldValue(_Fields field) { case M_STRING_STRING: return getMStringString(); + default: } throw new IllegalStateException(); } @@ -535,75 +537,91 @@ public boolean isSet(_Fields field) { return isSetLintString(); case M_STRING_STRING: return isSetMStringString(); + default: } throw new IllegalStateException(); } @Override public boolean equals(Object that) { - if (that == null) + if (that == null) { return false; - if (that instanceof Complex) + } + if (that instanceof Complex) { return this.equals((Complex)that); + } return false; } public boolean equals(Complex that) { - if (that == null) + if (that == null) { return false; + } boolean this_present_aint = true; boolean that_present_aint = true; if (this_present_aint || that_present_aint) { - if (!(this_present_aint && that_present_aint)) + if (!(this_present_aint && that_present_aint)) { return false; - if (this.aint != that.aint) + } + if (this.aint != that.aint) { return false; + } } boolean this_present_aString = true && this.isSetAString(); boolean that_present_aString = true && that.isSetAString(); if (this_present_aString || that_present_aString) { - if (!(this_present_aString && that_present_aString)) + if (!(this_present_aString && that_present_aString)) { return false; - if (!this.aString.equals(that.aString)) + } + if (!this.aString.equals(that.aString)) { return false; + } } boolean this_present_lint = true && this.isSetLint(); boolean that_present_lint = true && that.isSetLint(); if (this_present_lint || that_present_lint) { - if (!(this_present_lint && that_present_lint)) + if (!(this_present_lint && that_present_lint)) { return false; - if (!this.lint.equals(that.lint)) + } + if (!this.lint.equals(that.lint)) { return false; + } } boolean this_present_lString = true && this.isSetLString(); boolean that_present_lString = true && that.isSetLString(); if (this_present_lString || that_present_lString) { - if (!(this_present_lString && that_present_lString)) + if (!(this_present_lString && that_present_lString)) { return false; - if (!this.lString.equals(that.lString)) + } + if (!this.lString.equals(that.lString)) { return false; + } } boolean this_present_lintString = true && this.isSetLintString(); boolean that_present_lintString = true && that.isSetLintString(); if (this_present_lintString || that_present_lintString) { - if (!(this_present_lintString && that_present_lintString)) + if (!(this_present_lintString && that_present_lintString)) { return false; - if (!this.lintString.equals(that.lintString)) + } + if (!this.lintString.equals(that.lintString)) { return false; + } } boolean this_present_mStringString = true && this.isSetMStringString(); boolean that_present_mStringString = true && that.isSetMStringString(); if (this_present_mStringString || that_present_mStringString) { - if (!(this_present_mStringString && that_present_mStringString)) + if (!(this_present_mStringString && that_present_mStringString)) { return false; - if (!this.mStringString.equals(that.mStringString)) + } + if (!this.mStringString.equals(that.mStringString)) { return false; + } } return true; @@ -615,33 +633,39 @@ public int hashCode() { boolean present_aint = true; builder.append(present_aint); - if (present_aint) + if (present_aint) { builder.append(aint); + } boolean present_aString = true && (isSetAString()); builder.append(present_aString); - if (present_aString) + if (present_aString) { builder.append(aString); + } boolean present_lint = true && (isSetLint()); builder.append(present_lint); - if (present_lint) + if (present_lint) { builder.append(lint); + } boolean present_lString = true && (isSetLString()); builder.append(present_lString); - if (present_lString) + if (present_lString) { builder.append(lString); + } boolean present_lintString = true && (isSetLintString()); builder.append(present_lintString); - if (present_lintString) + if (present_lintString) { builder.append(lintString); + } boolean present_mStringString = true && (isSetMStringString()); builder.append(present_mStringString); - if (present_mStringString) + if (present_mStringString) { builder.append(mStringString); + } return builder.toHashCode(); } @@ -737,7 +761,9 @@ public String toString() { sb.append("aint:"); sb.append(this.aint); first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("aString:"); if (this.aString == null) { sb.append("null"); @@ -745,7 +771,9 @@ public String toString() { sb.append(this.aString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lint:"); if (this.lint == null) { sb.append("null"); @@ -753,7 +781,9 @@ public String toString() { sb.append(this.lint); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lString:"); if (this.lString == null) { sb.append("null"); @@ -761,7 +791,9 @@ public String toString() { sb.append(this.lString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lintString:"); if (this.lintString == null) { sb.append("null"); @@ -769,7 +801,9 @@ public String toString() { sb.append(this.lintString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("mStringString:"); if (this.mStringString == null) { sb.append("null"); @@ -842,7 +876,7 @@ public void read(org.apache.thrift.protocol.TProtocol iprot, Complex struct) thr if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { { org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); - struct.lint = new ArrayList(_list0.size); + struct.lint = new ArrayList<>(_list0.size); for (int _i1 = 0; _i1 < _list0.size; ++_i1) { int _elem2; // required @@ -860,7 +894,7 @@ public void read(org.apache.thrift.protocol.TProtocol iprot, Complex struct) thr if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { { org.apache.thrift.protocol.TList _list3 = iprot.readListBegin(); - struct.lString = new ArrayList(_list3.size); + struct.lString = new ArrayList<>(_list3.size); for (int _i4 = 0; _i4 < _list3.size; ++_i4) { String _elem5; // required @@ -878,7 +912,7 @@ public void read(org.apache.thrift.protocol.TProtocol iprot, Complex struct) thr if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { { org.apache.thrift.protocol.TList _list6 = iprot.readListBegin(); - struct.lintString = new ArrayList(_list6.size); + struct.lintString = new ArrayList<>(_list6.size); for (int _i7 = 0; _i7 < _list6.size; ++_i7) { IntString _elem8; // required @@ -1080,7 +1114,7 @@ public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) thro if (incoming.get(2)) { { org.apache.thrift.protocol.TList _list21 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, iprot.readI32()); - struct.lint = new ArrayList(_list21.size); + struct.lint = new ArrayList<>(_list21.size); for (int _i22 = 0; _i22 < _list21.size; ++_i22) { int _elem23; // required @@ -1093,7 +1127,7 @@ public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) thro if (incoming.get(3)) { { org.apache.thrift.protocol.TList _list24 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); - struct.lString = new ArrayList(_list24.size); + struct.lString = new ArrayList<>(_list24.size); for (int _i25 = 0; _i25 < _list24.size; ++_i25) { String _elem26; // required @@ -1106,7 +1140,7 @@ public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) thro if (incoming.get(4)) { { org.apache.thrift.protocol.TList _list27 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, iprot.readI32()); - struct.lintString = new ArrayList(_list27.size); + struct.lintString = new ArrayList<>(_list27.size); for (int _i28 = 0; _i28 < _list27.size; ++_i28) { IntString _elem29; // required @@ -1120,7 +1154,7 @@ public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) thro if (incoming.get(5)) { { org.apache.thrift.protocol.TMap _map30 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, iprot.readI32()); - struct.mStringString = new HashMap(2*_map30.size); + struct.mStringString = new HashMap<>(2*_map30.size); for (int _i31 = 0; _i31 < _map30.size; ++_i31) { String _key32; // required diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala new file mode 100644 index 0000000000000..154ada3daae51 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.test + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLContext + + +trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { + protected val sqlContext: SQLContext = TestHive + protected val hiveContext: TestHiveContext = TestHive + + protected override def afterAll(): Unit = { + try { + hiveContext.reset() + } finally { + super.afterAll() + } + } + +} diff --git a/sql/hive/src/test/resources/data/conf/hive-log4j.properties b/sql/hive/src/test/resources/data/conf/hive-log4j.properties index 885c86f2b94f4..6a042472adb90 100644 --- a/sql/hive/src/test/resources/data/conf/hive-log4j.properties +++ b/sql/hive/src/test/resources/data/conf/hive-log4j.properties @@ -47,7 +47,7 @@ log4j.appender.DRFA.layout.ConversionPattern=%d{ISO8601} %-5p %c{2} (%F:%M(%L)) # # console -# Add "console" to rootlogger above if you want to use this +# Add "console" to rootlogger above if you want to use this # log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/sql/hive/src/test/resources/golden/'1' + 1.0-0-5db3b55120a19863d96460d399c2d0e b/sql/hive/src/test/resources/golden/'1' + 1.0-0-404b0ea20c125c9648b7919a8f41add3 similarity index 100% rename from sql/hive/src/test/resources/golden/'1' + 1.0-0-5db3b55120a19863d96460d399c2d0e rename to sql/hive/src/test/resources/golden/'1' + 1.0-0-404b0ea20c125c9648b7919a8f41add3 diff --git a/sql/hive/src/test/resources/golden/1 + 1.0-0-4f5da98a11db8e7192423c27db767ca6 b/sql/hive/src/test/resources/golden/1 + 1.0-0-77ca48f121bd2ef41efb9ee3bc28418 similarity index 100% rename from sql/hive/src/test/resources/golden/1 + 1.0-0-4f5da98a11db8e7192423c27db767ca6 rename to sql/hive/src/test/resources/golden/1 + 1.0-0-77ca48f121bd2ef41efb9ee3bc28418 diff --git a/sql/hive/src/test/resources/golden/1.0 + '1'-0-a6ec78b3b93d52034aab829d43210e73 b/sql/hive/src/test/resources/golden/1.0 + '1'-0-6beb1ef5178117a9fd641008ed5ebb80 similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + '1'-0-a6ec78b3b93d52034aab829d43210e73 rename to sql/hive/src/test/resources/golden/1.0 + '1'-0-6beb1ef5178117a9fd641008ed5ebb80 diff --git a/sql/hive/src/test/resources/golden/1.0 + 1-0-30a4b1c8227906931cd0532367bebc43 b/sql/hive/src/test/resources/golden/1.0 + 1-0-bec2842d2b009973b4d4b8f10b5554f8 similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1-0-30a4b1c8227906931cd0532367bebc43 rename to sql/hive/src/test/resources/golden/1.0 + 1-0-bec2842d2b009973b4d4b8f10b5554f8 diff --git a/sql/hive/src/test/resources/golden/1.0 + 1.0-0-87321b2e30ee2986b00b631d0e4f4d8d b/sql/hive/src/test/resources/golden/1.0 + 1.0-0-eafdfdbb14980ee517c388dc117d91a8 similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1.0-0-87321b2e30ee2986b00b631d0e4f4d8d rename to sql/hive/src/test/resources/golden/1.0 + 1.0-0-eafdfdbb14980ee517c388dc117d91a8 diff --git a/sql/hive/src/test/resources/golden/1.0 + 1L-0-44bb88a1c9280952e8119a3ab1bb4205 b/sql/hive/src/test/resources/golden/1.0 + 1L-0-ef273f05968cd0e91af8c76949c73798 similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1L-0-44bb88a1c9280952e8119a3ab1bb4205 rename to sql/hive/src/test/resources/golden/1.0 + 1L-0-ef273f05968cd0e91af8c76949c73798 diff --git a/sql/hive/src/test/resources/golden/1.0 + 1S-0-31fbe14d01fb532176c1689680398368 b/sql/hive/src/test/resources/golden/1.0 + 1S-0-9f93538c38920d52b322bfc40cc2f31a similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1S-0-31fbe14d01fb532176c1689680398368 rename to sql/hive/src/test/resources/golden/1.0 + 1S-0-9f93538c38920d52b322bfc40cc2f31a diff --git a/sql/hive/src/test/resources/golden/1.0 + 1Y-0-12bcf6e49e83abd2aa36ea612b418d43 b/sql/hive/src/test/resources/golden/1.0 + 1Y-0-9e354e022b1b423f366bf79ed7522f2a similarity index 100% rename from sql/hive/src/test/resources/golden/1.0 + 1Y-0-12bcf6e49e83abd2aa36ea612b418d43 rename to sql/hive/src/test/resources/golden/1.0 + 1Y-0-9e354e022b1b423f366bf79ed7522f2a diff --git a/sql/hive/src/test/resources/golden/1L + 1.0-0-95a30c4b746f520f1251981a66cef5c8 b/sql/hive/src/test/resources/golden/1L + 1.0-0-9b0510d0bb3e9ee6a7698369b008a280 similarity index 100% rename from sql/hive/src/test/resources/golden/1L + 1.0-0-95a30c4b746f520f1251981a66cef5c8 rename to sql/hive/src/test/resources/golden/1L + 1.0-0-9b0510d0bb3e9ee6a7698369b008a280 diff --git a/sql/hive/src/test/resources/golden/1S + 1.0-0-8dfa46ec33c1be5ffba2e40cbfe5349e b/sql/hive/src/test/resources/golden/1S + 1.0-0-c3d54e5b6034b7796ed16896a434d1ba similarity index 100% rename from sql/hive/src/test/resources/golden/1S + 1.0-0-8dfa46ec33c1be5ffba2e40cbfe5349e rename to sql/hive/src/test/resources/golden/1S + 1.0-0-c3d54e5b6034b7796ed16896a434d1ba diff --git a/sql/hive/src/test/resources/golden/1Y + 1.0-0-3ad5e3db0d0300312d33231e7c2a6c8d b/sql/hive/src/test/resources/golden/1Y + 1.0-0-7b54e1d367c2ed1f5c181298ee5470d0 similarity index 100% rename from sql/hive/src/test/resources/golden/1Y + 1.0-0-3ad5e3db0d0300312d33231e7c2a6c8d rename to sql/hive/src/test/resources/golden/1Y + 1.0-0-7b54e1d367c2ed1f5c181298ee5470d0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 deleted file mode 100644 index dac1b84b916d7..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 +++ /dev/null @@ -1,6 +0,0 @@ -500 NULL 0 -91 0 1 -84 1 1 -105 2 1 -113 3 1 -107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 deleted file mode 100644 index c7cb747c0a659..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #2-0-7a511f02a16f0af4f810b1666cfcd896 +++ /dev/null @@ -1,10 +0,0 @@ -1 NULL -3 2 -1 NULL -1 2 -1 NULL 3 2 -1 NULL 4 2 -1 NULL 5 2 -1 NULL 6 2 -1 NULL 12 2 -1 NULL 14 2 -1 NULL 15 2 -1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c deleted file mode 100644 index c7cb747c0a659..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for GroupingSet-0-8c14c24670a4b06c440346277ce9cf1c +++ /dev/null @@ -1,10 +0,0 @@ -1 NULL -3 2 -1 NULL -1 2 -1 NULL 3 2 -1 NULL 4 2 -1 NULL 5 2 -1 NULL 6 2 -1 NULL 12 2 -1 NULL 14 2 -1 NULL 15 2 -1 NULL 22 2 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a deleted file mode 100644 index dac1b84b916d7..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a +++ /dev/null @@ -1,6 +0,0 @@ -500 NULL 0 -91 0 1 -84 1 1 -105 2 1 -113 3 1 -107 4 1 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 deleted file mode 100644 index 1eea4a9b23687..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 +++ /dev/null @@ -1,10 +0,0 @@ -1 0 5 3 -1 0 15 3 -1 0 25 3 -1 0 60 3 -1 0 75 3 -1 0 80 3 -1 0 100 3 -1 0 140 3 -1 0 145 3 -1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce deleted file mode 100644 index 1eea4a9b23687..0000000000000 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce +++ /dev/null @@ -1,10 +0,0 @@ -1 0 5 3 -1 0 15 3 -1 0 25 3 -1 0 60 3 -1 0 75 3 -1 0 80 3 -1 0 100 3 -1 0 140 3 -1 0 145 3 -1 0 150 3 diff --git a/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 b/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-cf71a4c4cce08635cc80a64a1ae6bc83 similarity index 100% rename from sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 rename to sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-cf71a4c4cce08635cc80a64a1ae6bc83 diff --git a/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 b/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-dfc876530eeaa7c42978d1bc0b1fd58 similarity index 100% rename from sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 rename to sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-dfc876530eeaa7c42978d1bc0b1fd58 diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 new file mode 100644 index 0000000000000..a01c2622c68e2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 @@ -0,0 +1 @@ +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 deleted file mode 100644 index 7c41615f8c184..0000000000000 --- a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 +++ /dev/null @@ -1 +0,0 @@ -1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 b/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 index 175795534fff5..f400819b67c26 100644 --- a/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 +++ b/sql/hive/src/test/resources/golden/show_functions-1-4a6f611305f58bdbafb2fd89ec62d797 @@ -1,4 +1,5 @@ case +cbrt ceil ceiling coalesce @@ -17,3 +18,6 @@ covar_samp create_union cume_dist current_database +current_date +current_timestamp +current_user diff --git a/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c b/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c index 3c25d656bda1c..19458fc86e439 100644 --- a/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c +++ b/sql/hive/src/test/resources/golden/show_functions-2-97cbada21ad9efda7ce9de5891deca7c @@ -2,6 +2,7 @@ assert_true case coalesce current_database +current_date decode e encode diff --git a/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 b/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 index cd2e58d04a4ef..1d05f843a7e0f 100644 --- a/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 +++ b/sql/hive/src/test/resources/golden/show_functions-4-4deaa213aff83575bbaf859f79bfdd48 @@ -1,4 +1,6 @@ +current_date date_add +date_format date_sub datediff to_date diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 deleted file mode 100644 index d00491fd7e5bb..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 +++ /dev/null @@ -1 +0,0 @@ -1 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 deleted file mode 100644 index 84a31a5a6970b..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ /dev/null @@ -1 +0,0 @@ --0.001 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f deleted file mode 100644 index 3fbedf693b51d..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f +++ /dev/null @@ -1 +0,0 @@ --2 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-672d4cb385b7ced2e446f132474293ad b/sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 similarity index 100% rename from sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-672d4cb385b7ced2e446f132474293ad rename to sql/hive/src/test/resources/golden/windowing.q -- 21. testDISTs-0-d9065e533430691d70b3370174fbbd50 diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar new file mode 100644 index 0000000000000..26d410f33029b Binary files /dev/null and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar differ diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.11.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.11.jar new file mode 100644 index 0000000000000..f34784752f69f Binary files /dev/null and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.11.jar differ diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar deleted file mode 100644 index 5944aa6076a5f..0000000000000 Binary files a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar and /dev/null differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 5c2fc7d82ffbd..11384a0275ae3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} +import org.apache.spark.sql.execution.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils @@ -30,17 +29,19 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { import hiveContext._ def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan - executedPlan.collect { + val plan = table(tableName).queryExecution.sparkPlan + plan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id case _ => - fail(s"Table $tableName is not cached\n" + executedPlan) + fail(s"Table $tableName is not cached\n" + plan) }.head } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + val maybeBlock = sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)) + maybeBlock.foreach(_ => sparkContext.env.blockManager.releaseLock(RDDBlockId(rddId, 0))) + maybeBlock.nonEmpty } test("cache table") { @@ -185,7 +186,7 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { assertCached(table("refreshTable")) checkAnswer( table("refreshTable"), - table("src").unionAll(table("src")).collect()) + table("src").union(table("src")).collect()) // Drop the table and create it again. sql("DROP TABLE refreshTable") @@ -197,7 +198,7 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { sql("REFRESH TABLE refreshTable") checkAnswer( table("refreshTable"), - table("src").unionAll(table("src")).collect()) + table("src").union(table("src")).collect()) // It is not cached. assert(!isCached("refreshTable"), "refreshTable should not be cached.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala index 34b2edb44b033..f262ef62be036 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala @@ -24,9 +24,7 @@ import org.apache.spark.SparkFunSuite /** * Verify that some classes load and that others are not found on the classpath. * - * - * This is used to detect classpath and shading conflict, especially between - * Spark's required Kryo version and that which can be found in some Hive versions. + * This is used to detect classpath and shading conflicts. */ class ClasspathDependenciesSuite extends SparkFunSuite { private val classloader = this.getClass.getClassLoader @@ -40,10 +38,6 @@ class ClasspathDependenciesSuite extends SparkFunSuite { classloader.loadClass(classname) } - private def assertLoads(classes: String*): Unit = { - classes.foreach(assertLoads) - } - private def findResource(classname: String): URL = { val resource = resourceName(classname) classloader.getResource(resource) @@ -63,17 +57,12 @@ class ClasspathDependenciesSuite extends SparkFunSuite { } } - private def assertClassNotFound(classes: String*): Unit = { - classes.foreach(assertClassNotFound) + test("shaded Protobuf") { + assertLoads("org.apache.hive.com.google.protobuf.ServiceException") } - private val KRYO = "com.esotericsoftware.kryo.Kryo" - - private val SPARK_HIVE = "org.apache.hive." - private val SPARK_SHADED = "org.spark-project.hive.shaded." - - test("shaded Protobuf") { - assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") + test("shaded Kryo") { + assertLoads("org.apache.hive.com.esotericsoftware.kryo.Kryo") } test("hive-common") { @@ -86,25 +75,13 @@ class ClasspathDependenciesSuite extends SparkFunSuite { private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" - test("unshaded kryo") { - assertLoads(KRYO, STD_INSTANTIATOR) - } - test("Forbidden Dependencies") { - assertClassNotFound( - SPARK_HIVE + KRYO, - SPARK_SHADED + KRYO, - "org.apache.hive." + KRYO, - "com.esotericsoftware.shaded." + STD_INSTANTIATOR, - SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, - "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR - ) + assertClassNotFound("com.esotericsoftware.shaded." + STD_INSTANTIATOR) + assertClassNotFound("org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR) } test("parquet-hadoop-bundle") { - assertLoads( - "parquet.hadoop.ParquetOutputFormat", - "parquet.hadoop.ParquetInputFormat" - ) + assertLoads("parquet.hadoop.ParquetOutputFormat") + assertLoads("parquet.hadoop.ParquetInputFormat") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index cf737836939f9..d9664680f4a11 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -19,20 +19,34 @@ package org.apache.spark.sql.hive import scala.util.Try -import org.scalatest.BeforeAndAfter +import org.scalatest.BeforeAndAfterEach +import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.execution.HiveSqlParser import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{AnalysisException, QueryTest} - -class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { +class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterEach { import hiveContext.implicits._ - before { + override protected def beforeEach(): Unit = { + super.beforeEach() + if (sqlContext.tableNames().contains("src")) { + sqlContext.dropTempTable("src") + } + Seq((1, "")).toDF("key", "value").registerTempTable("src") Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") } + override protected def afterEach(): Unit = { + try { + sqlContext.dropTempTable("src") + sqlContext.dropTempTable("dupAttributes") + } finally { + super.afterEach() + } + } + positionTest("ambiguous attribute reference 1", "SELECT a from dupAttributes", "a") @@ -117,8 +131,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { - def parseTree = - Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") + def ast = HiveSqlParser.parsePlan(query) + def parseTree = Try(quietly(ast.treeString)).getOrElse("") test(name) { val error = intercept[AnalysisException] { @@ -140,10 +154,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd val expectedStart = line.indexOf(token) val actualStart = error.startPosition.getOrElse { - fail( - s"start not returned for error on token $token\n" + - HiveQl.dumpTree(HiveQl.getAst(query)) - ) + fail(s"start not returned for error on token $token\n${ast.treeString}") } assert(expectedStart === actualStart, s"""Incorrect start position. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala new file mode 100644 index 0000000000000..38c84abd7c595 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.Timestamp + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} + +class ExpressionSQLBuilderSuite extends SQLBuilderTest { + test("literal") { + checkSQL(Literal("foo"), "\"foo\"") + checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") + checkSQL(Literal(1: Byte), "1Y") + checkSQL(Literal(2: Short), "2S") + checkSQL(Literal(4: Int), "4") + checkSQL(Literal(8: Long), "8L") + checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") + checkSQL(Literal(2.5D), "2.5D") + checkSQL( + Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')") + // TODO tests for decimals + } + + test("attributes") { + checkSQL('a.int, "`a`") + checkSQL(Symbol("foo bar").int, "`foo bar`") + // Keyword + checkSQL('int.int, "`int`") + } + + test("binary comparisons") { + checkSQL('a.int === 'b.int, "(`a` = `b`)") + checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") + checkSQL('a.int =!= 'b.int, "(NOT (`a` = `b`))") + + checkSQL('a.int < 'b.int, "(`a` < `b`)") + checkSQL('a.int <= 'b.int, "(`a` <= `b`)") + checkSQL('a.int > 'b.int, "(`a` > `b`)") + checkSQL('a.int >= 'b.int, "(`a` >= `b`)") + + checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))") + checkSQL('a.int in (1, 2), "(`a` IN (1, 2))") + + checkSQL('a.int.isNull, "(`a` IS NULL)") + checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)") + } + + test("logical operators") { + checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)") + checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)") + checkSQL(!'a.boolean, "(NOT `a`)") + checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))") + } + + test("arithmetic expressions") { + checkSQL('a.int + 'b.int, "(`a` + `b`)") + checkSQL('a.int - 'b.int, "(`a` - `b`)") + checkSQL('a.int * 'b.int, "(`a` * `b`)") + checkSQL('a.int / 'b.int, "(`a` / `b`)") + checkSQL('a.int % 'b.int, "(`a` % `b`)") + + checkSQL(-'a.int, "(-`a`)") + checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala new file mode 100644 index 0000000000000..bf85d71c66759 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.util.control.NonFatal + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils + +class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") + + val bytes = Array[Byte](1, 2, 3, 4) + Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0") + + sqlContext + .range(10) + .select('id as 'key, concat(lit("val_"), 'id) as 'value) + .write + .saveAsTable("t1") + + sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + } + + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") + } finally { + super.afterAll() + } + } + + private def checkSqlGeneration(hiveQl: String): Unit = { + val df = sql(hiveQl) + + val convertedSQL = try new SQLBuilder(df).toSQL catch { + case NonFatal(e) => + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) + } + + try { + checkAnswer(sql(convertedSQL), df) + } catch { case cause: Throwable => + fail( + s"""Failed to execute converted SQL string or got wrong answer: + | + |# Converted SQL query string: + |$convertedSQL + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, + cause) + } + } + + test("misc non-aggregate functions") { + checkSqlGeneration("SELECT abs(15), abs(-15)") + checkSqlGeneration("SELECT array(1,2,3)") + checkSqlGeneration("SELECT coalesce(null, 1, 2)") + // wait for resolution of JIRA SPARK-12719 SQL Generation for Generators + // checkSqlGeneration("SELECT explode(array(1,2,3))") + checkSqlGeneration("SELECT greatest(1,null,3)") + checkSqlGeneration("SELECT if(1==2, 'yes', 'no')") + checkSqlGeneration("SELECT isnan(15), isnan('invalid')") + checkSqlGeneration("SELECT isnull(null), isnull('a')") + checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')") + checkSqlGeneration("SELECT least(1,null,3)") + checkSqlGeneration("SELECT map(1, 'a', 2, 'b')") + checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)") + checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2") + checkSqlGeneration("SELECT nvl(null, 1, 2)") + checkSqlGeneration("SELECT rand(1)") + checkSqlGeneration("SELECT randn(3)") + checkSqlGeneration("SELECT struct(1,2,3)") + } + + test("math functions") { + checkSqlGeneration("SELECT acos(-1)") + checkSqlGeneration("SELECT asin(-1)") + checkSqlGeneration("SELECT atan(1)") + checkSqlGeneration("SELECT atan2(1, 1)") + checkSqlGeneration("SELECT bin(10)") + checkSqlGeneration("SELECT cbrt(1000.0)") + checkSqlGeneration("SELECT ceil(2.333)") + checkSqlGeneration("SELECT ceiling(2.333)") + checkSqlGeneration("SELECT cos(1.0)") + checkSqlGeneration("SELECT cosh(1.0)") + checkSqlGeneration("SELECT conv(15, 10, 16)") + checkSqlGeneration("SELECT degrees(pi())") + checkSqlGeneration("SELECT e()") + checkSqlGeneration("SELECT exp(1.0)") + checkSqlGeneration("SELECT expm1(1.0)") + checkSqlGeneration("SELECT floor(-2.333)") + checkSqlGeneration("SELECT factorial(5)") + checkSqlGeneration("SELECT hex(10)") + checkSqlGeneration("SELECT hypot(3, 4)") + checkSqlGeneration("SELECT log(10.0)") + checkSqlGeneration("SELECT log10(1000.0)") + checkSqlGeneration("SELECT log1p(0.0)") + checkSqlGeneration("SELECT log2(8.0)") + checkSqlGeneration("SELECT ln(10.0)") + checkSqlGeneration("SELECT negative(-1)") + checkSqlGeneration("SELECT pi()") + checkSqlGeneration("SELECT pmod(3, 2)") + checkSqlGeneration("SELECT positive(3)") + checkSqlGeneration("SELECT pow(2, 3)") + checkSqlGeneration("SELECT power(2, 3)") + checkSqlGeneration("SELECT radians(180.0)") + checkSqlGeneration("SELECT rint(1.63)") + checkSqlGeneration("SELECT round(31.415, -1)") + checkSqlGeneration("SELECT shiftleft(2, 3)") + checkSqlGeneration("SELECT shiftright(16, 3)") + checkSqlGeneration("SELECT shiftrightunsigned(16, 3)") + checkSqlGeneration("SELECT sign(-2.63)") + checkSqlGeneration("SELECT signum(-2.63)") + checkSqlGeneration("SELECT sin(1.0)") + checkSqlGeneration("SELECT sinh(1.0)") + checkSqlGeneration("SELECT sqrt(100.0)") + checkSqlGeneration("SELECT tan(1.0)") + checkSqlGeneration("SELECT tanh(1.0)") + } + + test("aggregate functions") { + checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT covar_pop(value, key) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT covar_samp(value, key) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT first(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT first_value(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT kurtosis(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT last(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT last_value(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT stddev_samp(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT sum(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT variance(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT var_pop(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT var_samp(value) FROM t1 GROUP BY key") + } + + test("string functions") { + checkSqlGeneration("SELECT ascii('SparkSql')") + checkSqlGeneration("SELECT base64(a) FROM t0") + checkSqlGeneration("SELECT concat('This ', 'is ', 'a ', 'test')") + checkSqlGeneration("SELECT concat_ws(' ', 'This', 'is', 'a', 'test')") + checkSqlGeneration("SELECT decode(a, 'UTF-8') FROM t0") + checkSqlGeneration("SELECT encode('SparkSql', 'UTF-8')") + checkSqlGeneration("SELECT find_in_set('ab', 'abc,b,ab,c,def')") + checkSqlGeneration("SELECT format_number(1234567.890, 2)") + checkSqlGeneration("SELECT format_string('aa%d%s',123, 'cc')") + checkSqlGeneration("SELECT get_json_object('{\"a\":\"bc\"}','$.a')") + checkSqlGeneration("SELECT initcap('This is a test')") + checkSqlGeneration("SELECT instr('This is a test', 'is')") + checkSqlGeneration("SELECT lcase('SparkSql')") + checkSqlGeneration("SELECT length('This is a test')") + checkSqlGeneration("SELECT levenshtein('This is a test', 'Another test')") + checkSqlGeneration("SELECT lower('SparkSql')") + checkSqlGeneration("SELECT locate('is', 'This is a test', 3)") + checkSqlGeneration("SELECT lpad('SparkSql', 16, 'Learning')") + checkSqlGeneration("SELECT ltrim(' SparkSql ')") + // wait for resolution of JIRA SPARK-12719 SQL Generation for Generators + // checkSqlGeneration("SELECT json_tuple('{\"f1\": \"value1\", \"f2\": \"value2\"}','f1')") + checkSqlGeneration("SELECT printf('aa%d%s', 123, 'cc')") + checkSqlGeneration("SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1)") + checkSqlGeneration("SELECT regexp_replace('100-200', '(\\d+)', 'num')") + checkSqlGeneration("SELECT repeat('SparkSql', 3)") + checkSqlGeneration("SELECT reverse('SparkSql')") + checkSqlGeneration("SELECT rpad('SparkSql', 16, ' is Cool')") + checkSqlGeneration("SELECT rtrim(' SparkSql ')") + checkSqlGeneration("SELECT soundex('SparkSql')") + checkSqlGeneration("SELECT space(2)") + checkSqlGeneration("SELECT split('aa2bb3cc', '[1-9]+')") + checkSqlGeneration("SELECT space(2)") + checkSqlGeneration("SELECT substr('This is a test', 1)") + checkSqlGeneration("SELECT substring('This is a test', 1)") + checkSqlGeneration("SELECT substring_index('www.apache.org','.',1)") + checkSqlGeneration("SELECT translate('translate', 'rnlt', '123')") + checkSqlGeneration("SELECT trim(' SparkSql ')") + checkSqlGeneration("SELECT ucase('SparkSql')") + checkSqlGeneration("SELECT unbase64('SparkSql')") + checkSqlGeneration("SELECT unhex(41)") + checkSqlGeneration("SELECT upper('SparkSql')") + } + + test("datetime functions") { + checkSqlGeneration("SELECT add_months('2001-03-31', 1)") + checkSqlGeneration("SELECT count(current_date())") + checkSqlGeneration("SELECT count(current_timestamp())") + checkSqlGeneration("SELECT datediff('2001-01-02', '2001-01-01')") + checkSqlGeneration("SELECT date_add('2001-01-02', 1)") + checkSqlGeneration("SELECT date_format('2001-05-02', 'yyyy-dd')") + checkSqlGeneration("SELECT date_sub('2001-01-02', 1)") + checkSqlGeneration("SELECT day('2001-05-02')") + checkSqlGeneration("SELECT dayofyear('2001-05-02')") + checkSqlGeneration("SELECT dayofmonth('2001-05-02')") + checkSqlGeneration("SELECT from_unixtime(1000, 'yyyy-MM-dd HH:mm:ss')") + checkSqlGeneration("SELECT from_utc_timestamp('2015-07-24 00:00:00', 'PST')") + checkSqlGeneration("SELECT hour('11:35:55')") + checkSqlGeneration("SELECT last_day('2001-01-01')") + checkSqlGeneration("SELECT minute('11:35:55')") + checkSqlGeneration("SELECT month('2001-05-02')") + checkSqlGeneration("SELECT months_between('2001-10-30 10:30:00', '1996-10-30')") + checkSqlGeneration("SELECT next_day('2001-05-02', 'TU')") + checkSqlGeneration("SELECT count(now())") + checkSqlGeneration("SELECT quarter('2001-05-02')") + checkSqlGeneration("SELECT second('11:35:55')") + checkSqlGeneration("SELECT to_date('2001-10-30 10:30:00')") + checkSqlGeneration("SELECT to_unix_timestamp('2015-07-24 00:00:00', 'yyyy-MM-dd HH:mm:ss')") + checkSqlGeneration("SELECT to_utc_timestamp('2015-07-24 00:00:00', 'PST')") + checkSqlGeneration("SELECT trunc('2001-10-30 10:30:00', 'YEAR')") + checkSqlGeneration("SELECT unix_timestamp('2001-10-30 10:30:00')") + checkSqlGeneration("SELECT weekofyear('2001-05-02')") + checkSqlGeneration("SELECT year('2001-05-02')") + + checkSqlGeneration("SELECT interval 3 years - 3 month 7 week 123 microseconds as i") + } + + test("collection functions") { + checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)") + checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))") + checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))") + } + + test("misc functions") { + checkSqlGeneration("SELECT crc32('Spark')") + checkSqlGeneration("SELECT md5('Spark')") + checkSqlGeneration("SELECT hash('Spark')") + checkSqlGeneration("SELECT sha('Spark')") + checkSqlGeneration("SELECT sha1('Spark')") + checkSqlGeneration("SELECT sha2('Spark', 0)") + checkSqlGeneration("SELECT spark_partition_id()") + checkSqlGeneration("SELECT input_file_name()") + checkSqlGeneration("SELECT monotonically_increasing_id()") + } + + test("subquery") { + checkSqlGeneration("SELECT 1 + (SELECT 2)") + checkSqlGeneration("SELECT 1 + (SELECT 2 + (SELECT 3 as a))") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala new file mode 100644 index 0000000000000..b644a50613337 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala @@ -0,0 +1,37 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.hive.test.TestHive + + +class HiveContextSuite extends SparkFunSuite { + + test("HiveContext can access `spark.sql.*` configs") { + // Avoid creating another SparkContext in the same JVM + val sc = TestHive.sparkContext + require(sc.conf.get("spark.sql.hive.metastore.barrierPrefixes") == + "org.apache.spark.sql.hive.execution.PairSerDe") + assert(TestHive.initialSQLConf.getConfString("spark.sql.hive.metastore.barrierPrefixes") == + "org.apache.spark.sql.hive.execution.PairSerDe") + assert(TestHive.metadataHive.getConf("spark.sql.hive.metastore.barrierPrefixes", "") == + "org.apache.spark.sql.hive.execution.PairSerDe") + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala new file mode 100644 index 0000000000000..110c6d19d89ba --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -0,0 +1,582 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.hive.serde.serdeConstants + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} +import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} +import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser} + +class HiveDDLCommandSuite extends PlanTest { + val parser = HiveSqlParser + + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { + parser.parsePlan(sql).collect { + case CreateTable(desc, allowExisting) => (desc, allowExisting) + case CreateTableAsSelect(desc, _, allowExisting) => (desc, allowExisting) + case CreateViewAsSelect(desc, _, allowExisting, _, _) => (desc, allowExisting) + }.head + } + + private def assertUnsupported(sql: String): Unit = { + val e = intercept[ParseException] { + parser.parsePlan(sql) + } + assert(e.getMessage.toLowerCase.contains("unsupported")) + } + + test("Test CTAS #1") { + val s1 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |(viewTime INT, + |userid BIGINT, + |page_url STRING, + |referrer_url STRING, + |ip STRING COMMENT 'IP Address of the User', + |country STRING COMMENT 'country of origination') + |COMMENT 'This is the staging page view table' + |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\054' STORED AS RCFILE + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, exists) = extractTableDesc(s1) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.schema == + CatalogColumn("viewtime", "int") :: + CatalogColumn("userid", "bigint") :: + CatalogColumn("page_url", "string") :: + CatalogColumn("referrer_url", "string") :: + CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: + CatalogColumn("country", "string", comment = Some("country of origination")) :: + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.partitionColumns == + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.storage.serdeProperties == + Map((serdeConstants.SERIALIZATION_FORMAT, "\u002C"), (serdeConstants.FIELD_DELIM, "\u002C"))) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) + } + + test("Test CTAS #2") { + val s2 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |(viewTime INT, + |userid BIGINT, + |page_url STRING, + |referrer_url STRING, + |ip STRING COMMENT 'IP Address of the User', + |country STRING COMMENT 'country of origination') + |COMMENT 'This is the staging page view table' + |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, exists) = extractTableDesc(s2) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.schema == + CatalogColumn("viewtime", "int") :: + CatalogColumn("userid", "bigint") :: + CatalogColumn("page_url", "string") :: + CatalogColumn("referrer_url", "string") :: + CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: + CatalogColumn("country", "string", comment = Some("country of origination")) :: + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.partitionColumns == + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) + } + + test("Test CTAS #3") { + val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" + val (desc, exists) = extractTableDesc(s3) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.storage.locationUri == None) + assert(desc.schema == Seq.empty[CatalogColumn]) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde.isEmpty) + assert(desc.properties == Map()) + } + + test("Test CTAS #4") { + val s4 = + """CREATE TABLE page_view + |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin + intercept[AnalysisException] { + extractTableDesc(s4) + } + } + + test("Test CTAS #5") { + val s5 = """CREATE TABLE ctas2 + | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + | STORED AS RCFile + | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + | AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin + val (desc, exists) = extractTableDesc(s5) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "ctas2") + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.storage.locationUri == None) + assert(desc.schema == Seq.empty[CatalogColumn]) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) + } + + test("unsupported operations") { + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TEMPORARY TABLE ctas2 + |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + |STORED AS RCFile + |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |CLUSTERED BY(user_id) INTO 256 BUCKETS + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |SKEWED BY (key) ON (1,5,6) + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe' + |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader' + |FROM testData + """.stripMargin) + } + } + + test("Invalid interval term should throw AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + parser.parsePlan(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } + + test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { + val plan = parser.parsePlan( + """ + |SELECT * + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin) + + assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) + } + + test("transform query spec") { + val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + + val p = ScriptTransformation( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + "func", Seq.empty, plans.table("e"), null) + + comparePlans(plan1, + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + comparePlans(plan2, + p.copy(output = Seq('c.string, 'd.string))) + comparePlans(plan3, + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("use backticks in output of Script Transform") { + parser.parsePlan( + """SELECT `t`.`thing1` + |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) + |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t + """.stripMargin) + } + + test("use backticks in output of Generator") { + parser.parsePlan( + """ + |SELECT `gentab2`.`gencol2` + |FROM `default`.`src` + |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1` + |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2` + """.stripMargin) + } + + test("use escaped backticks in output of Generator") { + parser.parsePlan( + """ + |SELECT `gen``tab2`.`gen``col2` + |FROM `default`.`src` + |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1` + |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2` + """.stripMargin) + } + + test("create table - basic") { + val query = "CREATE TABLE my_table (id int, name string)" + val (desc, allowExisting) = extractTableDesc(query) + assert(!allowExisting) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.schema == Seq(CatalogColumn("id", "int"), CatalogColumn("name", "string"))) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.sortColumnNames.isEmpty) + assert(desc.bucketColumnNames.isEmpty) + assert(desc.numBuckets == -1) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.locationUri.isEmpty) + assert(desc.storage.inputFormat == + Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde.isEmpty) + assert(desc.storage.serdeProperties.isEmpty) + assert(desc.properties.isEmpty) + assert(desc.comment.isEmpty) + } + + test("create table - with database name") { + val query = "CREATE TABLE dbx.my_table (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + } + + test("create table - temporary") { + val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" + val e = intercept[ParseException] { parser.parsePlan(query) } + assert(e.message.contains("registerTempTable")) + } + + test("create table - external") { + val query = "CREATE EXTERNAL TABLE tab1 (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + } + + test("create table - if not exists") { + val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" + val (_, allowExisting) = extractTableDesc(query) + assert(allowExisting) + } + + test("create table - comment") { + val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" + val (desc, _) = extractTableDesc(query) + assert(desc.comment == Some("its hot as hell below")) + } + + test("create table - partitioned columns") { + val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" + val (desc, _) = extractTableDesc(query) + assert(desc.schema == Seq( + CatalogColumn("id", "int"), + CatalogColumn("name", "string"), + CatalogColumn("month", "int"))) + assert(desc.partitionColumnNames == Seq("month")) + } + + test("create table - clustered by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) CLUSTERED BY(id)" + val query1 = s"$baseQuery INTO 10 BUCKETS" + val query2 = s"$baseQuery SORTED BY(id) INTO 10 BUCKETS" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - skewed by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + val e3 = intercept[ParseException] { parser.parsePlan(query3) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + assert(e3.getMessage.contains("Operation not allowed")) + } + + test("create table - row format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" + val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" + val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" + val query3 = + s""" + |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' + |COLLECTION ITEMS TERMINATED BY 'a' + |MAP KEYS TERMINATED BY 'b' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'c' + """.stripMargin + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + val (desc3, _) = extractTableDesc(query3) + assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc1.storage.serdeProperties.isEmpty) + assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc2.storage.serdeProperties == Map("k1" -> "v1")) + assert(desc3.storage.serdeProperties == Map( + "field.delim" -> "x", + "escape.delim" -> "y", + "serialization.format" -> "x", + "line.delim" -> "\n", + "colelction.delim" -> "a", // yes, it's a typo from Hive :) + "mapkey.delim" -> "b")) + } + + test("create table - file format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" + val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" + val query2 = s"$baseQuery ORC" + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + assert(desc1.storage.inputFormat == Some("winput")) + assert(desc1.storage.outputFormat == Some("wowput")) + assert(desc1.storage.serde.isEmpty) + assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + test("create table - storage handler") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - location") { + val query = "CREATE TABLE my_table (id int, name string) LOCATION '/path/to/mars'" + val (desc, _) = extractTableDesc(query) + assert(desc.storage.locationUri == Some("/path/to/mars")) + } + + test("create table - properties") { + val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" + val (desc, _) = extractTableDesc(query) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + } + + test("create table - everything!") { + val query = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) + |COMMENT 'no comment' + |PARTITIONED BY (month int) + |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') + |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' + |LOCATION '/path/to/mercury' + |TBLPROPERTIES ('k1'='v1', 'k2'='v2') + """.stripMargin + val (desc, allowExisting) = extractTableDesc(query) + assert(allowExisting) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.schema == Seq( + CatalogColumn("id", "int"), + CatalogColumn("name", "string"), + CatalogColumn("month", "int"))) + assert(desc.partitionColumnNames == Seq("month")) + assert(desc.sortColumnNames.isEmpty) + assert(desc.bucketColumnNames.isEmpty) + assert(desc.numBuckets == -1) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.locationUri == Some("/path/to/mercury")) + assert(desc.storage.inputFormat == Some("winput")) + assert(desc.storage.outputFormat == Some("wowput")) + assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc.storage.serdeProperties == Map("k1" -> "v1")) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + assert(desc.comment == Some("no comment")) + } + + test("create view -- basic") { + val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" + val (desc, exists) = extractTableDesc(v1) + assert(!exists) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "view1") + assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW) + assert(desc.storage.locationUri.isEmpty) + assert(desc.schema == Seq.empty[CatalogColumn]) + assert(desc.viewText == Option("SELECT * FROM tab1")) + assert(desc.viewOriginalText == Option("SELECT * FROM tab1")) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat.isEmpty) + assert(desc.storage.outputFormat.isEmpty) + assert(desc.storage.serde.isEmpty) + assert(desc.properties == Map()) + } + + test("create view - full") { + val v1 = + """ + |CREATE OR REPLACE VIEW IF NOT EXISTS view1 + |(col1, col3) + |COMMENT 'BLABLA' + |TBLPROPERTIES('prop1Key'="prop1Val") + |AS SELECT * FROM tab1 + """.stripMargin + val (desc, exists) = extractTableDesc(v1) + assert(exists) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "view1") + assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW) + assert(desc.storage.locationUri.isEmpty) + assert(desc.schema == + CatalogColumn("col1", null, nullable = true, None) :: + CatalogColumn("col3", null, nullable = true, None) :: Nil) + assert(desc.viewText == Option("SELECT * FROM tab1")) + assert(desc.viewOriginalText == Option("SELECT * FROM tab1")) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat.isEmpty) + assert(desc.storage.outputFormat.isEmpty) + assert(desc.storage.serde.isEmpty) + assert(desc.properties == Map("prop1Key" -> "prop1Val")) + assert(desc.comment == Option("BLABLA")) + } + + test("create view -- partitioned view") { + val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart" + intercept[ParseException] { + parser.parsePlan(v1).isInstanceOf[HiveNativeCommand] + } + } + + test("MSCK repair table (not supported)") { + assertUnsupported("MSCK REPAIR TABLE tab1") + } + + test("create table like") { + val v1 = "CREATE TABLE table1 LIKE table2" + val (target, source, exists) = parser.parsePlan(v1).collect { + case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + }.head + assert(exists == false) + assert(target.database.isEmpty) + assert(target.table == "table1") + assert(source.database.isEmpty) + assert(source.table == "table2") + + val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" + val (target2, source2, exists2) = parser.parsePlan(v2).collect { + case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + }.head + assert(exists2) + assert(target2.database.isEmpty) + assert(target2.table == "table1") + assert(source2.database.isEmpty) + assert(source2.table == "table2") + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 2e5cae415e54b..57f96e725a044 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.scalatest.BeforeAndAfterAll // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't @@ -32,12 +33,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with private var testData: DataFrame = _ override def beforeAll() { - testData = Seq((1, 2), (2, 4)).toDF("a", "b") + super.beforeAll() + testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") } override def afterAll(): Unit = { - hiveContext.dropTempTable("mytable") + try { + hiveContext.dropTempTable("mytable") + } finally { + super.afterAll() + } } test("rollup") { @@ -52,6 +58,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with ) } + test("collect functions") { + checkAnswer( + testData.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) + ) + checkAnswer( + testData.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 4))) + ) + } + test("cube") { checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index f621367eb553b..63cf5030ab8b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala new file mode 100644 index 0000000000000..7fdc5d71937ff --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.QueryTest + +class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { + test("table name with schema") { + // regression test for SPARK-11778 + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c int)") + hiveContext.read.table("usrdb.test") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala deleted file mode 100644 index 2c98f1c3cc49c..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton - -class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { - import hiveContext.implicits._ - import hiveContext.sql - - test("reuse window partitionBy") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val w = Window.partitionBy("key").orderBy("value") - - checkAnswer( - df.select( - lead("key", 1).over(w), - lead("value", 1).over(w)), - Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) - } - - test("reuse window orderBy") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val w = Window.orderBy("value").partitionBy("key") - - checkAnswer( - df.select( - lead("key", 1).over(w), - lead("value", 1).over(w)), - Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) - } - - test("lead") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - - checkAnswer( - df.select( - lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lead(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) - } - - test("lag") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - - checkAnswer( - df.select( - lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) - } - - test("lead with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - sql( - """SELECT - | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) - } - - test("lag with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) - } - - test("rank functions in unspecific window") { - val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - $"key", - max("key").over(Window.partitionBy("value").orderBy("key")), - min("key").over(Window.partitionBy("value").orderBy("key")), - mean("key").over(Window.partitionBy("value").orderBy("key")), - count("key").over(Window.partitionBy("value").orderBy("key")), - sum("key").over(Window.partitionBy("value").orderBy("key")), - ntile(2).over(Window.partitionBy("value").orderBy("key")), - rowNumber().over(Window.partitionBy("value").orderBy("key")), - denseRank().over(Window.partitionBy("value").orderBy("key")), - rank().over(Window.partitionBy("value").orderBy("key")), - cumeDist().over(Window.partitionBy("value").orderBy("key")), - percentRank().over(Window.partitionBy("value").orderBy("key"))), - sql( - s"""SELECT - |key, - |max(key) over (partition by value order by key), - |min(key) over (partition by value order by key), - |avg(key) over (partition by value order by key), - |count(key) over (partition by value order by key), - |sum(key) over (partition by value order by key), - |ntile(2) over (partition by value order by key), - |row_number() over (partition by value order by key), - |dense_rank() over (partition by value order by key), - |rank() over (partition by value order by key), - |cume_dist() over (partition by value order by key), - |percent_rank() over (partition by value order by key) - |FROM window_table""".stripMargin).collect()) - } - - test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) - | FROM window_table""".stripMargin).collect()) - } - - test("aggregation and range betweens") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) - | FROM window_table""".stripMargin).collect()) - } - - test("aggregation and rows betweens with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - $"key", - last("value").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), - last("value").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), - last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) - | FROM window_table""".stripMargin).collect()) - } - - test("aggregation and range betweens with unbounded") { - val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - $"key", - last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) - .equalTo("2") - .as("last_v"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) - .as("avg_key1"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) - .as("avg_key2"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) - .as("avg_key3") - ), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) - | FROM window_table""".stripMargin).collect()) - } - - test("reverse sliding range frame") { - val df = Seq( - (1, "Thin", "Cell Phone", 6000), - (2, "Normal", "Tablet", 1500), - (3, "Mini", "Tablet", 5500), - (4, "Ultra thin", "Cell Phone", 5500), - (5, "Very thin", "Cell Phone", 6000), - (6, "Big", "Tablet", 2500), - (7, "Bendable", "Cell Phone", 3000), - (8, "Foldable", "Cell Phone", 3000), - (9, "Pro", "Tablet", 4500), - (10, "Pro2", "Tablet", 6500)). - toDF("id", "product", "category", "revenue") - val window = Window. - partitionBy($"category"). - orderBy($"revenue".desc). - rangeBetween(-2000L, 1000L) - checkAnswer( - df.select( - $"id", - avg($"revenue").over(window).cast("int")), - Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: - Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: - Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: - Row(10, 6000) :: Nil) - } - - // This is here to illustrate the fact that reverse order also reverses offsets. - test("reverse unbounded range frame") { - val df = Seq(1, 2, 4, 3, 2, 1). - map(Tuple1.apply). - toDF("value") - val window = Window.orderBy($"value".desc) - checkAnswer( - df.select( - $"value", - sum($"value").over(window.rangeBetween(Long.MinValue, 1)), - sum($"value").over(window.rangeBetween(1, Long.MaxValue))), - Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: - Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) - - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala new file mode 100644 index 0000000000000..3334c16f0be87 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.util.VersionInfo + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader} +import org.apache.spark.util.Utils + +/** + * Test suite for the [[HiveExternalCatalog]]. + */ +class HiveExternalCatalogSuite extends CatalogTestCases { + + private val client: HiveClient = { + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion, + sparkConf = new SparkConf(), + hadoopConf = new Configuration()).createClient() + } + + protected override val utils: CatalogTestUtils = new CatalogTestUtils { + override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat" + override val tableOutputFormat: String = "org.apache.hadoop.mapred.SequenceFileOutputFormat" + override def newEmptyCatalog(): ExternalCatalog = new HiveExternalCatalog(client) + } + + protected override def resetState(): Unit = client.reset() + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 8bb9058cd74ef..3b867bbfa1817 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index d63f3d3996523..8648834f0d881 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} +import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} -import org.apache.spark.sql.{SQLConf, QueryTest, Row, SaveMode} -class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton { +class HiveMetastoreCatalogSuite extends TestHiveSingleton { import hiveContext.implicits._ test("struct field should accept underscore in sub-column name") { @@ -82,17 +83,17 @@ class DataSourceWithHiveMetastoreCatalogSuite .saveAsTable("t") } - val hiveTable = catalog.client.getTable("default", "t") - assert(hiveTable.inputFormat === Some(inputFormat)) - assert(hiveTable.outputFormat === Some(outputFormat)) - assert(hiveTable.serde === Some(serde)) + val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) + assert(hiveTable.storage.inputFormat === Some(inputFormat)) + assert(hiveTable.storage.outputFormat === Some(outputFormat)) + assert(hiveTable.storage.serde === Some(serde)) - assert(!hiveTable.isPartitioned) - assert(hiveTable.tableType === ManagedTable) + assert(hiveTable.partitionColumnNames.isEmpty) + assert(hiveTable.tableType === CatalogTableType.MANAGED_TABLE) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + assert(columns.map(_.dataType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) @@ -113,17 +114,19 @@ class DataSourceWithHiveMetastoreCatalogSuite .saveAsTable("t") } - val hiveTable = catalog.client.getTable("default", "t") - assert(hiveTable.inputFormat === Some(inputFormat)) - assert(hiveTable.outputFormat === Some(outputFormat)) - assert(hiveTable.serde === Some(serde)) + val hiveTable = + sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) + assert(hiveTable.storage.inputFormat === Some(inputFormat)) + assert(hiveTable.storage.outputFormat === Some(outputFormat)) + assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.tableType === ExternalTable) - assert(hiveTable.location.get === path.toURI.toString.stripSuffix(File.separator)) + assert(hiveTable.tableType === CatalogTableType.EXTERNAL_TABLE) + assert(hiveTable.storage.locationUri === + Some(path.toURI.toString.stripSuffix(File.separator))) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) + assert(columns.map(_.dataType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) @@ -142,18 +145,18 @@ class DataSourceWithHiveMetastoreCatalogSuite |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) - val hiveTable = catalog.client.getTable("default", "t") - assert(hiveTable.inputFormat === Some(inputFormat)) - assert(hiveTable.outputFormat === Some(outputFormat)) - assert(hiveTable.serde === Some(serde)) + val hiveTable = + sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) + assert(hiveTable.storage.inputFormat === Some(inputFormat)) + assert(hiveTable.storage.outputFormat === Some(outputFormat)) + assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.isPartitioned === false) - assert(hiveTable.tableType === ExternalTable) - assert(hiveTable.partitionColumns.length === 0) + assert(hiveTable.partitionColumnNames.isEmpty) + assert(hiveTable.tableType === CatalogTableType.EXTERNAL_TABLE) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.dataType) === Seq("int", "string")) checkAnswer(table("t"), Row(1, "val_1")) assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 5596ec6882ea2..b5af758a65b1c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) @@ -61,7 +61,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton } test("INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + // Don't run with vectorized: currently relies on UnsafeRow. + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t", false) { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala deleted file mode 100644 index 528a7398b10df..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.hadoop.hive.serde.serdeConstants -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable} - - -class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { - private def extractTableDesc(sql: String): (HiveTable, Boolean) = { - HiveQl.createPlan(sql).collect { - case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) - }.head - } - - test("Test CTAS #1") { - val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |(viewTime INT, - |userid BIGINT, - |page_url STRING, - |referrer_url STRING, - |ip STRING COMMENT 'IP Address of the User', - |country STRING COMMENT 'country of origination') - |COMMENT 'This is the staging page view table' - |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') - |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\054' STORED AS RCFILE - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin - - val (desc, exists) = extractTableDesc(s1) - assert(exists == true) - assert(desc.specifiedDatabase == Some("mydb")) - assert(desc.name == "page_view") - assert(desc.tableType == ExternalTable) - assert(desc.location == Some("/user/external/page_view")) - assert(desc.schema == - HiveColumn("viewtime", "int", null) :: - HiveColumn("userid", "bigint", null) :: - HiveColumn("page_url", "string", null) :: - HiveColumn("referrer_url", "string", null) :: - HiveColumn("ip", "string", "IP Address of the User") :: - HiveColumn("country", "string", "country of origination") :: Nil) - // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) - assert(desc.partitionColumns == - HiveColumn("dt", "string", "date type") :: - HiveColumn("hour", "string", "hour of the day") :: Nil) - assert(desc.serdeProperties == - Map((serdeConstants.SERIALIZATION_FORMAT, "\054"), (serdeConstants.FIELD_DELIM, "\054"))) - assert(desc.inputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) - } - - test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |(viewTime INT, - |userid BIGINT, - |page_url STRING, - |referrer_url STRING, - |ip STRING COMMENT 'IP Address of the User', - |country STRING COMMENT 'country of origination') - |COMMENT 'This is the staging page view table' - |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') - |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' - | STORED AS - | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' - | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin - - val (desc, exists) = extractTableDesc(s2) - assert(exists == true) - assert(desc.specifiedDatabase == Some("mydb")) - assert(desc.name == "page_view") - assert(desc.tableType == ExternalTable) - assert(desc.location == Some("/user/external/page_view")) - assert(desc.schema == - HiveColumn("viewtime", "int", null) :: - HiveColumn("userid", "bigint", null) :: - HiveColumn("page_url", "string", null) :: - HiveColumn("referrer_url", "string", null) :: - HiveColumn("ip", "string", "IP Address of the User") :: - HiveColumn("country", "string", "country of origination") :: Nil) - // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) - assert(desc.partitionColumns == - HiveColumn("dt", "string", "date type") :: - HiveColumn("hour", "string", "hour of the day") :: Nil) - assert(desc.serdeProperties == Map()) - assert(desc.inputFormat == Option("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.outputFormat == Option("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.serde == Option("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) - } - - test("Test CTAS #3") { - val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" - val (desc, exists) = extractTableDesc(s3) - assert(exists == false) - assert(desc.specifiedDatabase == None) - assert(desc.name == "page_view") - assert(desc.tableType == ManagedTable) - assert(desc.location == None) - assert(desc.schema == Seq.empty[HiveColumn]) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.serdeProperties == Map()) - assert(desc.inputFormat == Option("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - assert(desc.serde.isEmpty) - assert(desc.properties == Map()) - } - - test("Test CTAS #4") { - val s4 = - """CREATE TABLE page_view - |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin - intercept[AnalysisException] { - extractTableDesc(s4) - } - } - - test("Test CTAS #5") { - val s5 = """CREATE TABLE ctas2 - | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" - | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") - | STORED AS RCFile - | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") - | AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin - val (desc, exists) = extractTableDesc(s5) - assert(exists == false) - assert(desc.specifiedDatabase == None) - assert(desc.name == "ctas2") - assert(desc.tableType == ManagedTable) - assert(desc.location == None) - assert(desc.schema == Seq.empty[HiveColumn]) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) - assert(desc.inputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.outputFormat == Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) - assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) - } - - test("Invalid interval term should throw AnalysisException") { - def assertError(sql: String, errorMessage: String): Unit = { - val e = intercept[AnalysisException] { - HiveQl.parseSql(sql) - } - assert(e.getMessage.contains(errorMessage)) - } - assertError("select interval '42-32' year to month", - "month 32 outside range [0, 11]") - assertError("select interval '5 49:12:15' day to second", - "hour 49 outside range [0, 23]") - assertError("select interval '.1111111111' second", - "nanosecond 1111111111 outside range") - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 10e4ae2c50308..c5417b06a455b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -22,14 +22,18 @@ import java.sql.Timestamp import java.util.Date import scala.collection.mutable.ArrayBuffer +import scala.tools.nsc.Properties -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -42,17 +46,68 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} class HiveSparkSubmitSuite extends SparkFunSuite with Matchers - // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we - // add a timestamp to provide more diagnosis information. + with BeforeAndAfterEach with ResetSystemProperties with Timeouts { // TODO: rewrite these or mark them as slow tests to be run sparingly - def beforeAll() { + override def beforeEach() { + super.beforeEach() System.setProperty("spark.testing", "true") } + test("temporary Hive UDF: define a UDF and use it") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", TemporaryHiveUDFTest.getClass.getName.stripSuffix("$"), + "--name", "TemporaryHiveUDFTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("permanent Hive UDF: define a UDF and use it") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", PermanentHiveUDFTest1.getClass.getName.stripSuffix("$"), + "--name", "PermanentHiveUDFTest1", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("permanent Hive UDF: use a already defined permanent function") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", PermanentHiveUDFTest2.getClass.getName.stripSuffix("$"), + "--name", "PermanentHiveUDFTest2", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + test("SPARK-8368: includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) @@ -66,6 +121,7 @@ class HiveSparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -79,6 +135,9 @@ class HiveSparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.sql.hive.metastore.version=0.12", + "--conf", "spark.sql.hive.metastore.jars=maven", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } @@ -89,10 +148,15 @@ class HiveSparkSubmitSuite // Before the fix in SPARK-8470, this results in a MissingRequirementError because // the HiveContext code mistakenly overrides the class loader that contains user classes. // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. - val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" + val version = Properties.versionNumberString match { + case v if v.startsWith("2.10") || v.startsWith("2.11") => v.substring(0, 4) + case x => throw new Exception(s"Unsupported Scala Version: $x") + } + val testJar = s"sql/hive/src/test/resources/regression-test-SPARK-8489/test-$version.jar" val args = Seq( "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", "--class", "Main", testJar) runSparkSubmit(args) @@ -104,6 +168,9 @@ class HiveSparkSubmitSuite "--class", SPARK_9757.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } @@ -114,6 +181,22 @@ class HiveSparkSubmitSuite "--class", SPARK_11009.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + + test("SPARK-14244 fix window partition size attribute binding failure") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_14244.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } @@ -178,6 +261,118 @@ class HiveSparkSubmitSuite } } +// This application is used to test defining a new Hive UDF (with an associated jar) +// and use this UDF. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object TemporaryHiveUDFTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + + // Load a Hive UDF from the jar. + logInfo("Registering a temporary Hive UDF provided in a jar.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + hiveContext.sql( + s""" + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + |USING JAR '$jar' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. + logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP temporary FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + +// This application is used to test defining a new Hive UDF (with an associated jar) +// and use this UDF. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object PermanentHiveUDFTest1 extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + + // Load a Hive UDF from the jar. + logInfo("Registering a permanent Hive UDF provided in a jar.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + hiveContext.sql( + s""" + |CREATE FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + |USING JAR '$jar' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. + logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + +// This application is used to test that a pre-defined permanent function with a jar +// resources can be used. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object PermanentHiveUDFTest2 extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + // Load a Hive UDF from the jar. + logInfo("Write the metadata of a permanent Hive UDF into metastore.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val function = CatalogFunction( + FunctionIdentifier("example_max"), + "org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax", + ("JAR" -> jar) :: Nil) + hiveContext.sessionState.catalog.createFunction(function, ignoreIfExists = false) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. + logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + // This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368. // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { @@ -199,14 +394,14 @@ object SparkSubmitClassLoaderTest extends Logging { } // Second, we load classes at the executor side. logInfo("Testing load classes at the executor side.") - val result = df.mapPartitions { x => + val result = df.rdd.mapPartitions { x => var exception: String = null try { Utils.classForName(args(0)) Utils.classForName(args(1)) } catch { case t: Throwable => - exception = t + "\n" + t.getStackTraceString + exception = t + "\n" + Utils.exceptionString(t) exception = exception.replaceAll("\n", "\n\t") } Option(exception).toSeq.iterator @@ -352,7 +547,7 @@ object SPARK_11009 extends QueryTest { val df = sqlContext.range(1 << 20) val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) val ws = Window.partitionBy(df2("A")).orderBy(df2("B")) - val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0") + val df3 = df2.select(df2("A"), df2("B"), row_number().over(ws).alias("rn")).filter("rn < 0") if (df3.rdd.count() != 0) { throw new Exception("df3 should have 0 output row.") } @@ -361,3 +556,32 @@ object SPARK_11009 extends QueryTest { } } } + +object SPARK_14244 extends QueryTest { + import org.apache.spark.sql.expressions.Window + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.ui.enabled", "false") + .set("spark.sql.shuffle.partitions", "100")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + + import hiveContext.implicits._ + + try { + val window = Window.orderBy('id) + val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) + checkAnswer(df, Seq(Row(0.5D), Row(1.0D))) + } finally { + sparkContext.stop() + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 81ee9ba71beb6..4db95636e7610 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -22,8 +22,8 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -81,7 +81,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef test("Double create fails when allowExisting = false") { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - intercept[QueryExecutionException] { + intercept[AnalysisException] { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") } } @@ -154,8 +154,8 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef } val expected = List( "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=2"::Nil, - "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil , - "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=3"::Nil, + "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil, "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 183aca29cf98d..e8188e5f02f28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.hive.test.TestHiveSingleton class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { import hiveContext._ @@ -31,18 +31,25 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { + super.beforeAll() // The catalog in HiveContext is a case insensitive one. - catalog.registerTable(TableIdentifier("ListTablesSuiteTable"), df.logicalPlan) + sessionState.catalog.createTempTable( + "ListTablesSuiteTable", df.logicalPlan, overrideIfExists = true) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") } override def afterAll(): Unit = { - catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) - sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") - sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") - sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + try { + sessionState.catalog.dropTable( + TableIdentifier("ListTablesSuiteTable"), ignoreIfNotExists = true) + sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") + sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") + sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") + } finally { + super.afterAll() + } } test("get all tables of current database") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala new file mode 100644 index 0000000000000..c9bcf819effaf --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -0,0 +1,744 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.util.control.NonFatal + +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils + +class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + sql("DROP TABLE IF EXISTS parquet_t0") + sql("DROP TABLE IF EXISTS parquet_t1") + sql("DROP TABLE IF EXISTS parquet_t2") + sql("DROP TABLE IF EXISTS t0") + + sqlContext.range(10).write.saveAsTable("parquet_t0") + sql("CREATE TABLE t0 AS SELECT * FROM parquet_t0") + + sqlContext + .range(10) + .select('id as 'key, concat(lit("val_"), 'id) as 'value) + .write + .saveAsTable("parquet_t1") + + sqlContext + .range(10) + .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .write + .saveAsTable("parquet_t2") + + def createArray(id: Column): Column = { + when(id % 3 === 0, lit(null)).otherwise(array('id, 'id + 1)) + } + + sqlContext + .range(10) + .select( + createArray('id).as("arr"), + array(array('id), createArray('id)).as("arr2"), + lit("""{"f1": "1", "f2": "2", "f3": 3}""").as("json"), + 'id + ) + .write + .saveAsTable("parquet_t3") + } + + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS parquet_t0") + sql("DROP TABLE IF EXISTS parquet_t1") + sql("DROP TABLE IF EXISTS parquet_t2") + sql("DROP TABLE IF EXISTS parquet_t3") + sql("DROP TABLE IF EXISTS t0") + } finally { + super.afterAll() + } + } + + private def checkHiveQl(hiveQl: String): Unit = { + val df = sql(hiveQl) + + val convertedSQL = try new SQLBuilder(df).toSQL catch { + case NonFatal(e) => + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, e) + } + + try { + checkAnswer(sql(convertedSQL), df) + } catch { case cause: Throwable => + fail( + s"""Failed to execute converted SQL string or got wrong answer: + | + |# Converted SQL query string: + |$convertedSQL + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, cause) + } + } + + test("in") { + checkHiveQl("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)") + } + + test("not in") { + checkHiveQl("SELECT id FROM t0 WHERE id NOT IN (1, 2, 3)") + } + + test("not like") { + checkHiveQl("SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%'") + } + + test("aggregate function in having clause") { + checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0") + } + + test("aggregate function in order by clause") { + checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY MAX(key)") + } + + // When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into + // Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query + // execution since these aliases have different expression ID. But this introduces name collision + // when converting resolved plans back to SQL query strings as expression IDs are stripped. + test("aggregate function in order by clause with multiple order keys") { + checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key)") + } + + test("type widening in union") { + checkHiveQl("SELECT id FROM parquet_t0 UNION ALL SELECT CAST(id AS INT) AS id FROM parquet_t0") + } + + test("union distinct") { + checkHiveQl("SELECT * FROM t0 UNION SELECT * FROM t0") + } + + test("three-child union") { + checkHiveQl( + """ + |SELECT id FROM parquet_t0 + |UNION ALL SELECT id FROM parquet_t0 + |UNION ALL SELECT id FROM parquet_t0 + """.stripMargin) + } + + test("intersect") { + checkHiveQl("SELECT * FROM t0 INTERSECT SELECT * FROM t0") + } + + test("except") { + checkHiveQl("SELECT * FROM t0 EXCEPT SELECT * FROM t0") + } + + test("self join") { + checkHiveQl("SELECT x.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key") + } + + test("self join with group by") { + checkHiveQl( + "SELECT x.key, COUNT(*) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key group by x.key") + } + + test("case") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM parquet_t0") + } + + test("case with else") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM parquet_t0") + } + + test("case with key") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0") + } + + test("case with key and else") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0") + } + + test("select distinct without aggregate functions") { + checkHiveQl("SELECT DISTINCT id FROM parquet_t0") + } + + test("rollup/cube #1") { + // Original logical plan: + // Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46], + // [(count(1),mode=Complete,isDistinct=false) AS cnt#43L, + // (key#17L % cast(5 as bigint))#47L AS _c1#45L, + // grouping__id#46 AS _c2#44] + // +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0), + // List(key#17L, value#18, null, 1)], + // [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46] + // +- Project [key#17L, + // value#18, + // (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L] + // +- Subquery t1 + // +- Relation[key#17L,value#18] ParquetRelation + // Converted SQL: + // SELECT count( 1) AS `cnt`, + // (`t1`.`key` % CAST(5 AS BIGINT)), + // grouping_id() AS `_c2` + // FROM `default`.`t1` + // GROUP BY (`t1`.`key` % CAST(5 AS BIGINT)) + // GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ()) + checkHiveQl( + "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP") + checkHiveQl( + "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE") + } + + test("rollup/cube #2") { + checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP") + checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE") + } + + test("rollup/cube #3") { + checkHiveQl( + "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP") + checkHiveQl( + "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE") + } + + test("rollup/cube #4") { + checkHiveQl( + s""" + |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1 + |GROUP BY key % 5, key - 5 WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1 + |GROUP BY key % 5, key - 5 WITH CUBE + """.stripMargin) + } + + test("rollup/cube #5") { + checkHiveQl( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 + |FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 + |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 + |WITH CUBE + """.stripMargin) + } + + test("rollup/cube #6") { + checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") + checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b") + checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP") + checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE") + } + + test("rollup/cube #7") { + checkHiveQl("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)") + checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)") + } + + test("rollup/cube #8") { + // grouping_id() is part of another expression + checkHiveQl( + s""" + |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid + |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid + |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 + |WITH CUBE + """.stripMargin) + } + + test("rollup/cube #9") { + // self join is used as the child node of ROLLUP/CUBE with replaced quantifiers + checkHiveQl( + s""" + |SELECT t.key - 5, cnt, SUM(cnt) + |FROM (SELECT x.key, COUNT(*) as cnt + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t + |GROUP BY cnt, t.key - 5 + |WITH ROLLUP + """.stripMargin) + checkHiveQl( + s""" + |SELECT t.key - 5, cnt, SUM(cnt) + |FROM (SELECT x.key, COUNT(*) as cnt + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t + |GROUP BY cnt, t.key - 5 + |WITH CUBE + """.stripMargin) + } + + test("grouping sets #1") { + checkHiveQl( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 + |GROUPING SETS (key % 5, key - 5) + """.stripMargin) + } + + test("grouping sets #2") { + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b") + checkHiveQl( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b") + checkHiveQl( + s""" + |SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b + |GROUPING SETS ((), (a), (a, b)) ORDER BY a, b + """.stripMargin) + } + + test("cluster by") { + checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id") + } + + test("distribute by") { + checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id") + } + + test("distribute by with sort by") { + checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id SORT BY id") + } + + test("SPARK-13720: sort by after having") { + checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 SORT BY key") + } + + test("distinct aggregation") { + checkHiveQl("SELECT COUNT(DISTINCT id) FROM parquet_t0") + } + + test("TABLESAMPLE") { + // Project [id#2L] + // +- Sample 0.0, 1.0, false, ... + // +- Subquery s + // +- Subquery parquet_t0 + // +- Relation[id#2L] ParquetRelation + checkHiveQl("SELECT s.id FROM parquet_t0 TABLESAMPLE(100 PERCENT) s") + + // Project [id#2L] + // +- Sample 0.0, 1.0, false, ... + // +- Subquery parquet_t0 + // +- Relation[id#2L] ParquetRelation + checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(100 PERCENT)") + + // Project [id#21L] + // +- Sample 0.0, 1.0, false, ... + // +- MetastoreRelation default, t0, Some(s) + checkHiveQl("SELECT s.id FROM t0 TABLESAMPLE(100 PERCENT) s") + + // Project [id#24L] + // +- Sample 0.0, 1.0, false, ... + // +- MetastoreRelation default, t0, None + checkHiveQl("SELECT * FROM t0 TABLESAMPLE(100 PERCENT)") + + // When a sampling fraction is not 100%, the returned results are random. + // Thus, added an always-false filter here to check if the generated plan can be successfully + // executed. + checkHiveQl("SELECT s.id FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) s WHERE 1=0") + checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) WHERE 1=0") + } + + test("multi-distinct columns") { + checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM parquet_t2 GROUP BY a") + } + + test("persisted data source relations") { + Seq("orc", "json", "parquet").foreach { format => + val tableName = s"${format}_parquet_t0" + withTable(tableName) { + sqlContext.range(10).write.format(format).saveAsTable(tableName) + checkHiveQl(s"SELECT id FROM $tableName") + } + } + } + + test("script transformation - schemaless") { + checkHiveQl("SELECT TRANSFORM (a, b, c, d) USING 'cat' FROM parquet_t2") + checkHiveQl("SELECT TRANSFORM (*) USING 'cat' FROM parquet_t2") + } + + test("script transformation - alias list") { + checkHiveQl("SELECT TRANSFORM (a, b, c, d) USING 'cat' AS (d1, d2, d3, d4) FROM parquet_t2") + } + + test("script transformation - alias list with type") { + checkHiveQl( + """FROM + |(FROM parquet_t1 SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t + |SELECT thing1 + 1 + """.stripMargin) + } + + test("script transformation - row format delimited clause with only one format property") { + checkHiveQl( + """SELECT TRANSFORM (key) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |USING 'cat' AS (tKey) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + |FROM parquet_t1 + """.stripMargin) + } + + test("script transformation - row format delimited clause with multiple format properties") { + checkHiveQl( + """SELECT TRANSFORM (key) + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t' + |USING 'cat' AS (tKey) + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t' + |FROM parquet_t1 + """.stripMargin) + } + + test("script transformation - row format serde clauses with SERDEPROPERTIES") { + checkHiveQl( + """SELECT TRANSFORM (key, value) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + |USING 'cat' AS (tKey, tValue) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + |FROM parquet_t1 + """.stripMargin) + } + + test("script transformation - row format serde clauses without SERDEPROPERTIES") { + checkHiveQl( + """SELECT TRANSFORM (key, value) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |USING 'cat' AS (tKey, tValue) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |FROM parquet_t1 + """.stripMargin) + } + + test("plans with non-SQL expressions") { + sqlContext.udf.register("foo", (_: Int) * 2) + intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL) + } + + test("named expression in column names shouldn't be quoted") { + def checkColumnNames(query: String, expectedColNames: String*): Unit = { + checkHiveQl(query) + assert(sql(query).columns === expectedColNames) + } + + // Attributes + checkColumnNames( + """SELECT * FROM ( + | SELECT 1 AS a, 2 AS b, 3 AS `we``ird` + |) s + """.stripMargin, + "a", "b", "we`ird" + ) + + checkColumnNames( + """SELECT x.a, y.a, x.b, y.b + |FROM (SELECT 1 AS a, 2 AS b) x + |INNER JOIN (SELECT 1 AS a, 2 AS b) y + |ON x.a = y.a + """.stripMargin, + "a", "a", "b", "b" + ) + + // String literal + checkColumnNames( + "SELECT 'foo', '\"bar\\''", + "foo", "\"bar\'" + ) + + // Numeric literals (should have CAST or suffixes in column names) + checkColumnNames( + "SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D", + "1", "2", "3", "4", "5.1", "6.1" + ) + + // Aliases + checkColumnNames( + "SELECT 1 AS a", + "a" + ) + + // Complex type extractors + checkColumnNames( + """SELECT + | a.f1, b[0].f1, b.f1, c["foo"], d[0] + |FROM ( + | SELECT + | NAMED_STRUCT("f1", 1, "f2", "foo") AS a, + | ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b, + | MAP("foo", 1) AS c, + | ARRAY(1) AS d + |) s + """.stripMargin, + "f1", "b[0].f1", "f1", "c[foo]", "d[0]" + ) + } + + test("window basic") { + checkHiveQl("SELECT MAX(value) OVER (PARTITION BY key % 3) FROM parquet_t1") + checkHiveQl( + """ + |SELECT key, value, ROUND(AVG(key) OVER (), 2) + |FROM parquet_t1 ORDER BY key + """.stripMargin) + checkHiveQl( + """ + |SELECT value, MAX(key + 1) OVER (PARTITION BY key % 5 ORDER BY key % 7) AS max + |FROM parquet_t1 + """.stripMargin) + } + + test("multiple window functions in one expression") { + checkHiveQl( + """ + |SELECT + | MAX(key) OVER (ORDER BY key DESC, value) / MIN(key) OVER (PARTITION BY key % 3) + |FROM parquet_t1 + """.stripMargin) + } + + test("regular expressions and window functions in one expression") { + checkHiveQl("SELECT MAX(key) OVER (PARTITION BY key % 3) + key FROM parquet_t1") + } + + test("aggregate functions and window functions in one expression") { + checkHiveQl("SELECT MAX(c) + COUNT(a) OVER () FROM parquet_t2 GROUP BY a, b") + } + + test("window with different window specification") { + checkHiveQl( + """ + |SELECT key, value, + |DENSE_RANK() OVER (ORDER BY key, value) AS dr, + |MAX(value) OVER (PARTITION BY key ORDER BY key ASC) AS max + |FROM parquet_t1 + """.stripMargin) + } + + test("window with the same window specification with aggregate + having") { + checkHiveQl( + """ + |SELECT key, value, + |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key DESC) AS max + |FROM parquet_t1 GROUP BY key, value HAVING key > 5 + """.stripMargin) + } + + test("window with the same window specification with aggregate functions") { + checkHiveQl( + """ + |SELECT key, value, + |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max + |FROM parquet_t1 GROUP BY key, value + """.stripMargin) + } + + test("window with the same window specification with aggregate") { + checkHiveQl( + """ + |SELECT key, value, + |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, + |COUNT(key) + |FROM parquet_t1 GROUP BY key, value + """.stripMargin) + } + + test("window with the same window specification without aggregate and filter") { + checkHiveQl( + """ + |SELECT key, value, + |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, + |COUNT(key) OVER(DISTRIBUTE BY key SORT BY key, value) AS ca + |FROM parquet_t1 + """.stripMargin) + } + + test("window clause") { + checkHiveQl( + """ + |SELECT key, MAX(value) OVER w1 AS MAX, MIN(value) OVER w2 AS min + |FROM parquet_t1 + |WINDOW w1 AS (PARTITION BY key % 5 ORDER BY key), w2 AS (PARTITION BY key % 6) + """.stripMargin) + } + + test("special window functions") { + checkHiveQl( + """ + |SELECT + | RANK() OVER w, + | PERCENT_RANK() OVER w, + | DENSE_RANK() OVER w, + | ROW_NUMBER() OVER w, + | NTILE(10) OVER w, + | CUME_DIST() OVER w, + | LAG(key, 2) OVER w, + | LEAD(key, 2) OVER w + |FROM parquet_t1 + |WINDOW w AS (PARTITION BY key % 5 ORDER BY key) + """.stripMargin) + } + + test("window with join") { + checkHiveQl( + """ + |SELECT x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key + """.stripMargin) + } + + test("join 2 tables and aggregate function in having clause") { + checkHiveQl( + """ + |SELECT COUNT(a.value), b.KEY, a.KEY + |FROM parquet_t1 a, parquet_t1 b + |GROUP BY a.KEY, b.KEY + |HAVING MAX(a.KEY) > 0 + """.stripMargin) + } + + test("generator in project list without FROM clause") { + checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3))") + checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) AS val") + } + + test("generator in project list with non-referenced table") { + checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) FROM t0") + checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) AS val FROM t0") + } + + test("generator in project list with referenced table") { + checkHiveQl("SELECT EXPLODE(arr) FROM parquet_t3") + checkHiveQl("SELECT EXPLODE(arr) AS val FROM parquet_t3") + } + + test("generator in project list with non-UDTF expressions") { + checkHiveQl("SELECT EXPLODE(arr), id FROM parquet_t3") + checkHiveQl("SELECT EXPLODE(arr) AS val, id as a FROM parquet_t3") + } + + test("generator in lateral view") { + checkHiveQl("SELECT val, id FROM parquet_t3 LATERAL VIEW EXPLODE(arr) exp AS val") + checkHiveQl("SELECT val, id FROM parquet_t3 LATERAL VIEW OUTER EXPLODE(arr) exp AS val") + } + + test("generator in lateral view with ambiguous names") { + checkHiveQl( + """ + |SELECT exp.id, parquet_t3.id + |FROM parquet_t3 + |LATERAL VIEW EXPLODE(arr) exp AS id + """.stripMargin) + checkHiveQl( + """ + |SELECT exp.id, parquet_t3.id + |FROM parquet_t3 + |LATERAL VIEW OUTER EXPLODE(arr) exp AS id + """.stripMargin) + } + + test("use JSON_TUPLE as generator") { + checkHiveQl( + """ + |SELECT c0, c1, c2 + |FROM parquet_t3 + |LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt + """.stripMargin) + checkHiveQl( + """ + |SELECT a, b, c + |FROM parquet_t3 + |LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt AS a, b, c + """.stripMargin) + } + + test("nested generator in lateral view") { + checkHiveQl( + """ + |SELECT val, id + |FROM parquet_t3 + |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array + |LATERAL VIEW EXPLODE(nested_array) exp1 AS val + """.stripMargin) + + checkHiveQl( + """ + |SELECT val, id + |FROM parquet_t3 + |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array + |LATERAL VIEW OUTER EXPLODE(nested_array) exp1 AS val + """.stripMargin) + } + + test("generate with other operators") { + checkHiveQl( + """ + |SELECT EXPLODE(arr) AS val, id + |FROM parquet_t3 + |WHERE id > 2 + |ORDER BY val, id + |LIMIT 5 + """.stripMargin) + + checkHiveQl( + """ + |SELECT val, id + |FROM parquet_t3 + |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array + |LATERAL VIEW EXPLODE(nested_array) exp1 AS val + |WHERE val > 2 + |ORDER BY val, id + |LIMIT 5 + """.stripMargin) + } + + test("filter after subquery") { + checkHiveQl("SELECT a FROM (SELECT key + 1 AS a FROM parquet_t1) t WHERE a > 5") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f74eb1500b989..3c299daa778cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -17,20 +17,21 @@ package org.apache.spark.sql.hive -import java.io.{IOException, File} +import java.io.File import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.util.Utils /** @@ -43,6 +44,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv var jsonFilePath: String = _ override def beforeAll(): Unit = { + super.beforeAll() jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } @@ -163,7 +165,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("check change without refresh") { withTempPath { tempDir => withTable("jsonTable") { - (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + (("a", "b") :: Nil).toDF().toJSON.rdd.saveAsTextFile(tempDir.getCanonicalPath) sql( s"""CREATE TABLE jsonTable @@ -178,7 +180,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv Row("a", "b")) Utils.deleteRecursively(tempDir) - (("a1", "b1", "c1") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + (("a1", "b1", "c1") :: Nil).toDF().toJSON.rdd.saveAsTextFile(tempDir.getCanonicalPath) // Schema is cached so the new column does not show. The updated values in existing columns // will show. @@ -198,7 +200,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("drop, change, recreate") { withTempPath { tempDir => - (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + (("a", "b") :: Nil).toDF().toJSON.rdd.saveAsTextFile(tempDir.getCanonicalPath) withTable("jsonTable") { sql( @@ -214,7 +216,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv Row("a", "b")) Utils.deleteRecursively(tempDir) - (("a", "b", "c") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + (("a", "b", "c") :: Nil).toDF().toJSON.rdd.saveAsTextFile(tempDir.getCanonicalPath) sql("DROP TABLE jsonTable") @@ -368,7 +370,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv |) """.stripMargin) - val expectedPath = catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) + val expectedPath = + sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) @@ -402,20 +405,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - test("SPARK-5286 Fail to drop an invalid table when using the data source API") { - withTable("jsonTable") { - sql( - s"""CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path 'it is not a path at all!' - |) - """.stripMargin) - - sql("DROP TABLE jsonTable").collect().foreach(i => logInfo(i.toString)) - } - } - test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { withTable("savedJsonTable") { // Save the df as a managed table (by not specifying the path). @@ -472,8 +461,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // Drop table will also delete the data. sql("DROP TABLE savedJsonTable") - intercept[IOException] { - read.json(catalog.hiveDefaultTableFilePath(TableIdentifier("savedJsonTable"))) + intercept[AnalysisException] { + read.json( + sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("savedJsonTable"))) } } @@ -540,21 +530,26 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sql("SELECT b FROM savedJsonTable")) sql("DROP TABLE createdJsonTable") - - assert( - intercept[RuntimeException] { - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map.empty[String, String]) - }.getMessage.contains("key not found: path"), - "We should complain that path is not specified.") } } } } + test("path required error") { + assert( + intercept[AnalysisException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + Map.empty[String, String]) + + table("createdJsonTable") + }.getMessage.contains("Unable to infer schema"), + "We should complain that path is not specified.") + + sql("DROP TABLE createdJsonTable") + } + test("scan a parquet table created through a CTAS statement") { withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "true") { withTempTable("jt") { @@ -571,9 +566,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv Row(3) :: Row(4) :: Nil) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation, _) => // OK + case LogicalRelation(p: HadoopFsRelation, _, _) => // OK case _ => - fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation]}") + fail(s"test_parquet_ctas should have be converted to ${classOf[HadoopFsRelation]}") } } } @@ -699,22 +694,25 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("SPARK-6024 wide schema support") { withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") { withTable("wide_schema") { - // We will need 80 splits for this schema if the threshold is 4000. - val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) - - // Manually create a metastore data source table. - catalog.createDataSourceTable( - tableIdent = TableIdentifier("wide_schema"), - userSpecifiedSchema = Some(schema), - partitionColumns = Array.empty[String], - provider = "json", - options = Map("path" -> "just a dummy path"), - isExternal = false) - - invalidateTable("wide_schema") - - val actualSchema = table("wide_schema").schema - assert(schema === actualSchema) + withTempDir { tempDir => + // We will need 80 splits for this schema if the threshold is 4000. + val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) + + // Manually create a metastore data source table. + sessionState.catalog.createDataSourceTable( + name = TableIdentifier("wide_schema"), + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = "json", + options = Map("path" -> tempDir.getCanonicalPath), + isExternal = false) + + invalidateTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) + } } } } @@ -723,20 +721,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val tableName = "spark6655" withTable(tableName) { val schema = StructType(StructField("int", IntegerType, true) :: Nil) - val hiveTable = HiveTable( - specifiedDatabase = Some("default"), - name = tableName, + val hiveTable = CatalogTable( + identifier = TableIdentifier(tableName, Some("default")), + tableType = CatalogTableType.MANAGED_TABLE, schema = Seq.empty, - partitionColumns = Seq.empty, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + serdeProperties = Map( + "path" -> sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier(tableName))) + ), properties = Map( "spark.sql.sources.provider" -> "json", "spark.sql.sources.schema" -> schema.json, - "EXTERNAL" -> "FALSE"), - tableType = ManagedTable, - serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(TableIdentifier(tableName)))) + "EXTERNAL" -> "FALSE")) - catalog.client.createTable(hiveTable) + hiveCatalog.createTable("default", hiveTable, ignoreIfExists = false) invalidateTable(tableName) val actualSchema = table(tableName).schema @@ -744,14 +746,14 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - test("Saving partition columns information") { + test("Saving partitionBy columns information") { val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") val tableName = s"partitionInfo_${System.currentTimeMillis()}" withTable(tableName) { df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) invalidateTable(tableName) - val metastoreTable = catalog.client.getTable("default", tableName) + val metastoreTable = hiveCatalog.getTable("default", tableName) val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) val numPartCols = metastoreTable.properties("spark.sql.sources.schema.numPartCols").toInt @@ -775,6 +777,59 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + test("Saving information for sortBy and bucketBy columns") { + val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") + val tableName = s"bucketingInfo_${System.currentTimeMillis()}" + + withTable(tableName) { + df.write + .format("parquet") + .bucketBy(8, "d", "b") + .sortBy("c") + .saveAsTable(tableName) + invalidateTable(tableName) + val metastoreTable = hiveCatalog.getTable("default", tableName) + val expectedBucketByColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + val expectedSortByColumns = StructType(df.schema("c") :: Nil) + + val numBuckets = metastoreTable.properties("spark.sql.sources.schema.numBuckets").toInt + assert(numBuckets == 8) + + val numBucketCols = metastoreTable.properties("spark.sql.sources.schema.numBucketCols").toInt + assert(numBucketCols == 2) + + val numSortCols = metastoreTable.properties("spark.sql.sources.schema.numSortCols").toInt + assert(numSortCols == 1) + + val actualBucketByColumns = + StructType( + (0 until numBucketCols).map { index => + df.schema(metastoreTable.properties(s"spark.sql.sources.schema.bucketCol.$index")) + }) + // Make sure bucketBy columns are correctly stored in metastore. + assert( + expectedBucketByColumns.sameType(actualBucketByColumns), + s"Partitions columns stored in metastore $actualBucketByColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedBucketByColumns.") + + val actualSortByColumns = + StructType( + (0 until numSortCols).map { index => + df.schema(metastoreTable.properties(s"spark.sql.sources.schema.sortCol.$index")) + }) + // Make sure sortBy columns are correctly stored in metastore. + assert( + expectedSortByColumns.sameType(actualSortByColumns), + s"Partitions columns stored in metastore $actualSortByColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedSortByColumns.") + + // Check the content of the saved table. + checkAnswer( + table(tableName).select("c", "b", "d", "a"), + df.select("c", "b", "d", "a")) + } + } + test("insert into a table") { def createDF(from: Int, to: Int): DataFrame = { (from to to).map(i => i -> s"str$i").toDF("c1", "c2") @@ -846,4 +901,40 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sqlContext.sql("""use default""") sqlContext.sql("""drop database if exists testdb8156 CASCADE""") } + + + test("skip hive metadata on table creation") { + withTempDir { tempPath => + val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) + + sessionState.catalog.createDataSourceTable( + name = TableIdentifier("not_skip_hive_metadata"), + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = "parquet", + options = Map("path" -> tempPath.getCanonicalPath, "skipHiveMetadata" -> "false"), + isExternal = false) + + // As a proxy for verifying that the table was stored in Hive compatible format, + // we verify that each column of the table is of native type StringType. + assert(hiveCatalog.getTable("default", "not_skip_hive_metadata").schema + .forall(column => HiveMetastoreTypes.toDataType(column.dataType) == StringType)) + + sessionState.catalog.createDataSourceTable( + name = TableIdentifier("skip_hive_metadata"), + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = "parquet", + options = Map("path" -> tempPath.getCanonicalPath, "skipHiveMetadata" -> "true"), + isExternal = false) + + // As a proxy for verifying that the table was stored in SparkSQL format, + // we verify that the table has a column type as array of StringType. + assert(hiveCatalog.getTable("default", "skip_hive_metadata").schema.forall { c => + HiveMetastoreTypes.toDataType(c.dataType) == ArrayType(StringType) + }) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index f16c257ab5ab4..3c003506efcb1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - private lazy val df = sqlContext.range(10).coalesce(1) + private lazy val df = sqlContext.range(10).coalesce(1).toDF() private def checkTablePath(dbName: String, tableName: String): Unit = { - val metastoreTable = hiveContext.catalog.client.getTable(dbName, tableName) - val expectedPath = hiveContext.catalog.client.getDatabase(dbName).location + "/" + tableName + val metastoreTable = hiveContext.hiveCatalog.getTable(dbName, tableName) + val expectedPath = hiveContext.hiveCatalog.getDatabase(dbName).locationUri + "/" + tableName - assert(metastoreTable.serdeProperties("path") === expectedPath) + assert(metastoreTable.storage.serdeProperties("path") === expectedPath) } test(s"saveAsTable() to non-default database - with USE - Overwrite") { @@ -112,11 +112,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle df.write.mode(SaveMode.Overwrite).saveAsTable("t") df.write.mode(SaveMode.Append).saveAsTable("t") assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df.unionAll(df)) + checkAnswer(sqlContext.table("t"), df.union(df)) } assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) checkTablePath(db, "t") } @@ -127,7 +127,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) checkTablePath(db, "t") } @@ -140,7 +140,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle assert(sqlContext.tableNames().contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) } } } @@ -155,7 +155,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle assert(sqlContext.tableNames(db).contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) } } @@ -219,7 +219,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle hiveContext.refreshTable("t") checkAnswer( sqlContext.table("t"), - df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) } } } @@ -251,7 +251,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle hiveContext.refreshTable(s"$db.t") checkAnswer( sqlContext.table(s"$db.t"), - df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 49aab85cf1aaf..a9823ae26278d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -21,9 +21,10 @@ import java.sql.Timestamp import org.apache.hadoop.hive.conf.HiveConf +import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest -import org.apache.spark.sql.{Row, SQLConf} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index f542a5a02508c..78569c58085cd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.util.Utils -import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala new file mode 100644 index 0000000000000..9a63ecb4ca8d0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.util.control.NonFatal + +import org.apache.spark.sql.{DataFrame, Dataset, QueryTest} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.hive.test.TestHiveSingleton + +abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { + protected def checkSQL(e: Expression, expectedSQL: String): Unit = { + val actualSQL = e.sql + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following expression: + | + |${e.prettyName} + | + |$cause + """.stripMargin) + } + } + + protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = { + val generatedSQL = try new SQLBuilder(plan, hiveContext).toSQL catch { case NonFatal(e) => + fail( + s"""Cannot convert the following logical query plan to SQL: + | + |${plan.treeString} + """.stripMargin) + } + + try { + assert(generatedSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following logical query plan: + | + |${plan.treeString} + | + |$cause + """.stripMargin) + } + + checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext, plan)) + } + + protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { + checkSQL(df.queryExecution.analyzed, expectedSQL) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9bb32f11b76bd..05318f51af01e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,18 +19,19 @@ package org.apache.spark.sql.hive import scala.reflect.ClassTag -import org.apache.spark.sql.{Row, SQLConf, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf class StatisticsSuite extends QueryTest with TestHiveSingleton { import hiveContext.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = HiveQl.parseSql(analyzeCommand) + val parsed = HiveSqlParser.parsePlan(analyzeCommand) val operators = parsed.collect { case a: AnalyzeTable => a case o => o @@ -69,7 +70,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - hiveContext.catalog.lookupRelation(TableIdentifier(tableName)).statistics.sizeInBytes + hiveContext.sessionState.catalog.lookupRelation( + TableIdentifier(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -116,7 +118,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { intercept[UnsupportedOperationException] { hiveContext.analyze("tempTable") } - hiveContext.catalog.unregisterTable(TableIdentifier("tempTable")) + hiveContext.sessionState.catalog.dropTable( + TableIdentifier("tempTable"), ignoreIfNotExists = true) } test("estimates the size of a test MetastoreRelation") { @@ -166,7 +169,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { val shj = df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j } assert(shj.size === 1, - "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") + "SortMergeJoin should be planned when BroadcastHashJoin is turned off") sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } @@ -207,7 +210,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. var bhj = df.queryExecution.sparkPlan.collect { - case j: BroadcastLeftSemiJoinHash => j + case j: BroadcastHashJoin => j } assert(bhj.size === 1, s"actual query plans do not contain broadcast join: ${df.queryExecution}") @@ -220,12 +223,12 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) bhj = df.queryExecution.sparkPlan.collect { - case j: BroadcastLeftSemiJoinHash => j + case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") val shj = df.queryExecution.sparkPlan.collect { - case j: LeftSemiJoinHash => j + case j: ShuffledHashJoin => j } assert(shj.size === 1, "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 3ab4576811194..d1aa5aa931947 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,12 +17,51 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with TestHiveSingleton { +/** + * A test suite for UDF related functionalities. Because Hive metastore is + * case insensitive, database names and function names have both upper case + * letters and lower case letters. + */ +class UDFSuite + extends QueryTest + with SQLTestUtils + with TestHiveSingleton + with BeforeAndAfterEach { + + import hiveContext.implicits._ + + private[this] val functionName = "myUPper" + private[this] val functionNameUpper = "MYUPPER" + private[this] val functionNameLower = "myupper" + + private[this] val functionClass = + classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName + + private var testDF: DataFrame = null + private[this] val testTableName = "testDF_UDFSuite" + private var expectedDF: DataFrame = null + + override def beforeAll(): Unit = { + sql("USE default") + + testDF = (1 to 10).map(i => s"sTr$i").toDF("value") + testDF.registerTempTable(testTableName) + expectedDF = (1 to 10).map(i => s"STR$i").toDF("value") + super.beforeAll() + } + + override def afterEach(): Unit = { + sql("USE default") + super.afterEach() + } test("UDF case insensitive") { hiveContext.udf.register("random0", () => { Math.random() }) @@ -32,4 +71,128 @@ class UDFSuite extends QueryTest with TestHiveSingleton { assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } + + test("temporary function: create and drop") { + withUserDefinedFunction(functionName -> true) { + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY FUNCTION default.$functionName AS '$functionClass'") + } + sql(s"CREATE TEMPORARY FUNCTION $functionName AS '$functionClass'") + checkAnswer( + sql(s"SELECT $functionNameLower(value) from $testTableName"), + expectedDF + ) + intercept[AnalysisException] { + sql(s"DROP TEMPORARY FUNCTION default.$functionName") + } + } + } + + test("permanent function: create and drop without specifying db name") { + withUserDefinedFunction(functionName -> false) { + sql(s"CREATE FUNCTION $functionName AS '$functionClass'") + checkAnswer( + sql("SHOW functions like '.*upper'"), + Row(s"default.$functionNameLower") + ) + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + assert( + sql("SHOW functions").collect() + .map(_.getString(0)) + .contains(s"default.$functionNameLower")) + } + } + + test("permanent function: create and drop with a db name") { + // For this block, drop function command uses functionName as the function name. + withUserDefinedFunction(functionNameUpper -> false) { + sql(s"CREATE FUNCTION default.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // in SessionCatalog.lookupFunction. + // checkAnswer( + // sql(s"SELECT default.myuPPer(value) from $testTableName"), + // expectedDF + // ) + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + checkAnswer( + sql(s"SELECT default.$functionName(value) from $testTableName"), + expectedDF + ) + } + + // For this block, drop function command uses default.functionName as the function name. + withUserDefinedFunction(s"DEfault.$functionNameLower" -> false) { + sql(s"CREATE FUNCTION dEFault.$functionName AS '$functionClass'") + checkAnswer( + sql(s"SELECT $functionNameUpper(value) from $testTableName"), + expectedDF + ) + } + } + + test("permanent function: create and drop a function in another db") { + // For this block, drop function command uses functionName as the function name. + withTempDatabase { dbName => + withUserDefinedFunction(functionName -> false) { + sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // checkAnswer( + // sql(s"SELECT $dbName.myuPPer(value) from $testTableName"), + // expectedDF + // ) + + checkAnswer( + sql(s"SHOW FUNCTIONS like $dbName.$functionNameUpper"), + Row(s"$dbName.$functionNameLower") + ) + + sql(s"USE $dbName") + + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + + sql(s"USE default") + + checkAnswer( + sql(s"SELECT $dbName.$functionName(value) from $testTableName"), + expectedDF + ) + + sql(s"USE $dbName") + } + + sql(s"USE default") + + // For this block, drop function command uses default.functionName as the function name. + withUserDefinedFunction(s"$dbName.$functionNameUpper" -> false) { + sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // checkAnswer( + // sql(s"SELECT $dbName.myupper(value) from $testTableName"), + // expectedDF + // ) + + sql(s"USE $dbName") + + assert( + sql("SHOW functions").collect() + .map(_.getString(0)) + .contains(s"$dbName.$functionNameLower")) + checkAnswer( + sql(s"SELECT $functionNameLower(value) from $testTableName"), + expectedDF + ) + + sql(s"USE default") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 5e7b93d457106..cd96c85f3e209 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -22,7 +22,8 @@ import java.util.Collections import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -65,7 +66,7 @@ class FiltersSuite extends SparkFunSuite with Logging { "") private def filterTest(name: String, filters: Seq[Expression], result: String) = { - test(name){ + test(name) { val converted = shim.convertFilters(testTable, filters) if (converted != result) { fail( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index c6d034a23a1c6..8b0719209dedf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -19,16 +19,22 @@ package org.apache.spark.sql.hive.client import java.io.File -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.util.VersionInfo + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, NamedExpression} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.IntegerType import org.apache.spark.tags.ExtendedHiveTest import org.apache.spark.util.Utils /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * A simple set of tests that call the methods of a [[HiveClient]], loading different version * of hive from maven central. These tests are simple in that they are mostly just testing to make * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. @@ -36,10 +42,14 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - // Do not use a temp path here to speed up subsequent executions of the unit test during - // development. - private val ivyPath = Some( - new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + private val sparkConf = new SparkConf() + + // In order to speed up test execution during development or in Jenkins, you can specify the path + // of an existing Ivy cache: + private val ivyPath: Option[String] = { + sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( + Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) + } private def buildConf() = { lazy val warehousePath = Utils.createTempDir() @@ -51,11 +61,28 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, - buildConf(), - ivyPath).createClient() - val db = new HiveDatabase("default", "") - badClient.createDatabase(db) + val badClient = IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, + hadoopConf = new Configuration(), + config = buildConf(), + ivyPath = ivyPath).createClient() + val db = new CatalogDatabase("default", "desc", "loc", Map()) + badClient.createDatabase(db, ignoreIfExists = true) + } + + test("hadoop configuration preserved") { + val hadoopConf = new Configuration(); + hadoopConf.set("test", "success") + val client = IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, + hadoopConf = hadoopConf, + config = buildConf(), + ivyPath = ivyPath).createClient() + assert("success" === client.getConf("test", null)) } private def getNestedMessages(e: Throwable): String = { @@ -83,7 +110,13 @@ class VersionsSuite extends SparkFunSuite with Logging { ignore("failure sanity check") { val e = intercept[Throwable] { val badClient = quietly { - IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).createClient() + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = "13", + hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, + hadoopConf = new Configuration(), + config = buildConf(), + ivyPath = ivyPath).createClient() } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") @@ -91,39 +124,43 @@ class VersionsSuite extends SparkFunSuite with Logging { private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") - private var client: ClientInterface = null + private var client: HiveClient = null versions.foreach { version => test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. - client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).createClient() + client = + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = version, + hadoopVersion = VersionInfo.getVersion, + sparkConf = sparkConf, + hadoopConf = new Configuration(), + config = buildConf(), + ivyPath = ivyPath).createClient() } test(s"$version: createDatabase") { - val db = HiveDatabase("default", "") - client.createDatabase(db) + val db = CatalogDatabase("default", "desc", "loc", Map()) + client.createDatabase(db, ignoreIfExists = true) } test(s"$version: createTable") { val table = - HiveTable( - specifiedDatabase = Option("default"), - name = "src", - schema = Seq(HiveColumn("key", "int", "")), - partitionColumns = Seq.empty, - properties = Map.empty, - serdeProperties = Map.empty, - tableType = ManagedTable, - location = None, - inputFormat = - Some(classOf[org.apache.hadoop.mapred.TextInputFormat].getName), - outputFormat = - Some(classOf[org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat[_, _]].getName), - serde = - Some(classOf[org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe].getName())) - - client.createTable(table) + CatalogTable( + identifier = TableIdentifier("src", Some("default")), + tableType = CatalogTableType.MANAGED_TABLE, + schema = Seq(CatalogColumn("key", "int")), + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = Some(classOf[org.apache.hadoop.mapred.TextInputFormat].getName), + outputFormat = Some( + classOf[org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat[_, _]].getName), + serde = Some(classOf[org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe].getName()), + serdeProperties = Map.empty + )) + + client.createTable(table, ignoreIfExists = false) } test(s"$version: getTable") { @@ -134,10 +171,6 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(client.listTables("default") === Seq("src")) } - test(s"$version: currentDatabase") { - assert(client.currentDatabase === "default") - } - test(s"$version: getDatabase") { client.getDatabase("default") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ea80060e370e0..84bb7edf03821 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -18,14 +18,15 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import scala.util.Random -import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -66,14 +67,68 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun } } +class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction { + + def inputSchema: StructType = StructType(Nil) + + def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer.update(0, 0L) + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + buffer.update(0, input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum + buffer.getLong(0)) + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0)) + } + + def evaluate(buffer: Row): Any = { + buffer.getLong(0) + } +} + +class LongProductSum extends UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ - var originalUseAggregate2: Boolean = _ - override def beforeAll(): Unit = { - originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") + super.beforeAll() val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -106,6 +161,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te (3, null, null)).toDF("key", "value1", "value2") data2.write.saveAsTable("agg2") + val data3 = Seq[(Seq[Integer], Integer, Integer)]( + (Seq[Integer](1, 1), 10, -10), + (Seq[Integer](null), -60, 60), + (Seq[Integer](1, 1), 30, -30), + (Seq[Integer](1), 30, 30), + (Seq[Integer](2), 1, 1), + (null, -10, 10), + (Seq[Integer](2, 3), -1, null), + (Seq[Integer](2, 3), 1, 1), + (Seq[Integer](2, 3, 4), null, 1), + (Seq[Integer](null), 100, -10), + (Seq[Integer](3), null, 3), + (null, null, null), + (Seq[Integer](3), null, null)).toDF("key", "value1", "value2") + data3.write.saveAsTable("agg3") + val emptyDF = sqlContext.createDataFrame( sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) @@ -114,13 +185,26 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // Register UDAFs sqlContext.udf.register("mydoublesum", new MyDoubleSum) sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("longProductSum", new LongProductSum) } override def afterAll(): Unit = { - sqlContext.sql("DROP TABLE IF EXISTS agg1") - sqlContext.sql("DROP TABLE IF EXISTS agg2") - sqlContext.dropTempTable("emptyTable") - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) + try { + sqlContext.sql("DROP TABLE IF EXISTS agg1") + sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.sql("DROP TABLE IF EXISTS agg3") + sqlContext.dropTempTable("emptyTable") + } finally { + super.afterAll() + } + } + + test("group by function") { + Seq((1, 2)).toDF("a", "b").registerTempTable("data") + + checkAnswer( + sql("SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a"), + Row(1, Array(2)) :: Nil) } test("empty table") { @@ -240,6 +324,41 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(100, null) :: Row(null, 3) :: Row(null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT DISTINCT key + |FROM agg3 + """.stripMargin), + Row(Seq[Integer](1, 1)) :: + Row(Seq[Integer](null)) :: + Row(Seq[Integer](1)) :: + Row(Seq[Integer](2)) :: + Row(null) :: + Row(Seq[Integer](2, 3)) :: + Row(Seq[Integer](2, 3, 4)) :: + Row(Seq[Integer](3)) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT value1, key + |FROM agg3 + |GROUP BY value1, key + """.stripMargin), + Row(10, Seq[Integer](1, 1)) :: + Row(-60, Seq[Integer](null)) :: + Row(30, Seq[Integer](1, 1)) :: + Row(30, Seq[Integer](1)) :: + Row(1, Seq[Integer](2)) :: + Row(-10, null) :: + Row(-1, Seq[Integer](2, 3)) :: + Row(1, Seq[Integer](2, 3)) :: + Row(null, Seq[Integer](2, 3, 4)) :: + Row(100, Seq[Integer](null)) :: + Row(null, Seq[Integer](3)) :: + Row(null, null) :: Nil) } test("case in-sensitive resolution") { @@ -516,6 +635,50 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(3, 4, 4, 3, null) :: Nil) } + test("single distinct multiple columns set") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1, value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3) :: + Row(1, 3) :: + Row(2, 1) :: + Row(3, 0) :: Nil) + } + + test("multiple distinct multiple columns sets") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1), + | sum(distinct value1), + | count(distinct value2), + | sum(distinct value2), + | count(distinct value1, value2), + | longProductSum(distinct value1, value2), + | count(value1), + | sum(value1), + | count(value2), + | sum(value2), + | longProductSum(value1, value2), + | count(*), + | count(1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) :: + Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) :: + Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) :: + Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil) + } + test("test count") { checkAnswer( sqlContext.sql( @@ -641,7 +804,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """ |SELECT corr(b, c) FROM covar_tab WHERE a = 3 """.stripMargin), - Row(null) :: Nil) + Row(Double.NaN) :: Nil) checkAnswer( sqlContext.sql( @@ -650,55 +813,42 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """.stripMargin), Row(1, null) :: Row(2, null) :: - Row(3, null) :: - Row(4, null) :: - Row(5, null) :: - Row(6, null) :: Nil) + Row(3, Double.NaN) :: + Row(4, Double.NaN) :: + Row(5, Double.NaN) :: + Row(6, Double.NaN) :: Nil) val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) - - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - val errorMessage = intercept[SparkException] { - val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") - val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) - }.getMessage - assert(errorMessage.contains("java.lang.UnsupportedOperationException: " + - "Corr only supports the new AggregateExpression2")) - } } - test("test Last implemented based on AggregateExpression1") { - // TODO: Remove this test once we remove AggregateExpression1. - import org.apache.spark.sql.functions._ - val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1) - withSQLConf( - SQLConf.SHUFFLE_PARTITIONS.key -> "1", - SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + test("covariance: covar_pop and covar_samp") { + // non-trivial example. To reproduce in python, use: + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> np.cov(a, b, bias = 0)[0][1] + // 595.0 + // >>> np.cov(a, b, bias = 1)[0][1] + // 565.25 + val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") + val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_samp - 595.0) < 1e-12) - checkAnswer( - df.groupBy("i").agg(last("j")), - df - ) - } - } + val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_pop - 565.25) < 1e-12) - test("error handling") { - withSQLConf("spark.sql.useAggregate2" -> "false") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | mydoublesum(value), - | mydoubleavg(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - } + val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b") + val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_samp2 - 11564.0) < 1e-12) + + val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12) + + // one row test + val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") + checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(Double.NaN)) + checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0)) } test("no aggregation function (SPARK-11486)") { @@ -741,7 +891,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te RandomDataGenerator.forType( dataType = schemaForGenerator, nullable = true, - seed = Some(System.nanoTime())) + new Random(System.nanoTime())) val dataGenerator = maybeDataGenerator .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) @@ -760,92 +910,90 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te val df = sqlContext.createDataFrame(rdd, schema) val allColumns = df.schema.fields.map(f => col(f.name)) - val expectedAnaswer = + val expectedAnswer = data .find(r => r.getInt(0) == 50) .getOrElse(fail("A row with id 50 should be the expected answer.")) checkAnswer( df.groupBy().agg(udaf(allColumns: _*)), // udaf returns a Row as the output value. - Row(expectedAnaswer) + Row(expectedAnswer) ) } } -} - -class SortBasedAggregationQuerySuite extends AggregationQuerySuite { - var originalUnsafeEnabled: Boolean = _ + test("udaf without specifying inputSchema") { + withTempTable("noInputSchemaUDAF") { + sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema) + + val data = + Row(1, Seq(Row(1), Row(2), Row(3))) :: + Row(1, Seq(Row(4), Row(5), Row(6))) :: + Row(2, Seq(Row(-10))) :: Nil + val schema = + StructType( + StructField("key", IntegerType) :: + StructField("myArray", + ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil) + sqlContext.createDataFrame( + sparkContext.parallelize(data, 2), + schema) + .registerTempTable("noInputSchemaUDAF") - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") - super.beforeAll() - } + checkAnswer( + sqlContext.sql( + """ + |SELECT key, noInputSchema(myArray) + |FROM noInputSchemaUDAF + |GROUP BY key + """.stripMargin), + Row(1, 21) :: Row(2, -10) :: Nil) - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + checkAnswer( + sqlContext.sql( + """ + |SELECT noInputSchema(myArray) + |FROM noInputSchemaUDAF + """.stripMargin), + Row(11) :: Nil) + } } } -class TungstenAggregationQuerySuite extends AggregationQuerySuite { - var originalUnsafeEnabled: Boolean = _ +class TungstenAggregationQuerySuite extends AggregationQuerySuite - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() - } - - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - } -} class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { - var originalUnsafeEnabled: Boolean = _ - - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() - } - - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") - } - override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - (0 to 2).foreach { fallbackStartsAt => - sqlContext.setConf( - "spark.sql.TungstenAggregate.testFallbackStartsAt", - fallbackStartsAt.toString) - - // Create a new df to make sure its physical operator picks up - // spark.sql.TungstenAggregate.testFallbackStartsAt. - // todo: remove it? - val newActual = DataFrame(sqlContext, actual.logicalPlan) - - QueryTest.checkAnswer(newActual, expectedAnswer) match { - case Some(errorMessage) => - val newErrorMessage = - s""" - |The following aggregation query failed when using TungstenAggregate with - |controlled fallback (it falls back to sort-based aggregation once it has processed - |$fallbackStartsAt input rows). The query is - |${actual.queryExecution} - | - |$errorMessage - """.stripMargin - - fail(newErrorMessage) - case None => + Seq(false, true).foreach { enableColumnarHashMap => + withSQLConf("spark.sql.codegen.aggregate.map.enabled" -> enableColumnarHashMap.toString) { + (1 to 3).foreach { fallbackStartsAt => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> + s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = Dataset.ofRows(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to bytes to bytes map once it has processed + |${fallbackStartsAt -1} input rows and to sort-based aggregation once it has + |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => // Success + } + } + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index e38d1eb5779fe..f5cd73d45ed75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.BeforeAndAfterAll class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala new file mode 100644 index 0000000000000..061d1512a5250 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + protected override def beforeAll(): Unit = { + super.beforeAll() + sql( + """ + |CREATE TABLE parquet_tab1 (c1 INT, c2 STRING) + |USING org.apache.spark.sql.parquet.DefaultSource + """.stripMargin) + + sql( + """ + |CREATE EXTERNAL TABLE parquet_tab2 (c1 INT, c2 STRING) + |STORED AS PARQUET + |TBLPROPERTIES('prop1Key'="prop1Val", '`prop2Key`'="prop2Val") + """.stripMargin) + } + + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS parquet_tab1") + sql("DROP TABLE IF EXISTS parquet_tab2") + } finally { + super.afterAll() + } + } + + test("show tables") { + withTable("show1a", "show2b") { + sql("CREATE TABLE show1a(c1 int)") + sql("CREATE TABLE show2b(c2 int)") + checkAnswer( + sql("SHOW TABLES IN default 'show1*'"), + Row("show1a", false) :: Nil) + checkAnswer( + sql("SHOW TABLES IN default 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + checkAnswer( + sql("SHOW TABLES 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + assert( + sql("SHOW TABLES").count() >= 2) + assert( + sql("SHOW TABLES IN default").count() >= 2) + } + } + + test("show tblproperties of data source tables - basic") { + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1") + .filter(s"key = 'spark.sql.sources.provider'"), + Row("spark.sql.sources.provider", "org.apache.spark.sql.parquet.DefaultSource") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1(spark.sql.sources.provider)"), + Row("org.apache.spark.sql.parquet.DefaultSource") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1") + .filter(s"key = 'spark.sql.sources.schema.numParts'"), + Row("spark.sql.sources.schema.numParts", "1") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1('spark.sql.sources.schema.numParts')"), + Row("1")) + } + + test("show tblproperties for datasource table - errors") { + val message1 = intercept[AnalysisException] { + sql("SHOW TBLPROPERTIES badtable") + }.getMessage + assert(message1.contains("Table or View badtable not found in database default")) + + // When key is not found, a row containing the error is returned. + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1('invalid.prop.key')"), + Row("Table default.parquet_tab1 does not have property: invalid.prop.key") :: Nil + ) + } + + test("show tblproperties for hive table") { + checkAnswer(sql("SHOW TBLPROPERTIES parquet_tab2('prop1Key')"), Row("prop1Val")) + checkAnswer(sql("SHOW TBLPROPERTIES parquet_tab2('`prop2Key`')"), Row("prop2Val")) + } + + test("show tblproperties for spark temporary table - empty row") { + withTempTable("parquet_temp") { + sql( + """ + |CREATE TEMPORARY TABLE parquet_temp (c1 INT, c2 STRING) + |USING org.apache.spark.sql.parquet.DefaultSource + """.stripMargin) + + // An empty sequence of row is returned for session temporary table. + checkAnswer(sql("SHOW TBLPROPERTIES parquet_temp"), Nil) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index aa95ba94fa873..e67fcbedc3364 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io._ +import java.nio.charset.StandardCharsets import scala.util.control.NonFatal @@ -27,8 +28,9 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} +import org.apache.spark.sql.execution.command.{ExplainCommand, SetCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand +import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder} import org.apache.spark.sql.hive.test.TestHive /** @@ -126,10 +128,18 @@ abstract class HiveComparisonTest protected val cacheDigest = java.security.MessageDigest.getInstance("MD5") protected def getMd5(str: String): String = { val digest = java.security.MessageDigest.getInstance("MD5") - digest.update(str.replaceAll(System.lineSeparator(), "\n").getBytes("utf-8")) + digest.update(str.replaceAll(System.lineSeparator(), "\n").getBytes(StandardCharsets.UTF_8)) new java.math.BigInteger(1, digest.digest).toString(16) } + override protected def afterAll(): Unit = { + try { + TestHive.reset() + } finally { + super.afterAll() + } + } + protected def prepareAnswer( hiveQuery: TestHive.type#QueryExecution, answer: Seq[String]): Seq[String] = { @@ -209,7 +219,11 @@ abstract class HiveComparisonTest } val installHooksCommand = "(?i)SET.*hooks".r - def createQueryTest(testCaseName: String, sql: String, reset: Boolean = true) { + def createQueryTest( + testCaseName: String, + sql: String, + reset: Boolean = true, + tryWithoutResettingFirst: Boolean = false) { // testCaseName must not contain ':', which is not allowed to appear in a filename of Windows assert(!testCaseName.contains(":")) @@ -240,9 +254,6 @@ abstract class HiveComparisonTest test(testCaseName) { logDebug(s"=== HIVE TEST: $testCaseName ===") - // Clear old output for this testcase. - outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) - val sqlWithoutComment = sql.split("\n").filterNot(l => l.matches("--.*(?<=[^\\\\]);")).mkString("\n") val allQueries = @@ -269,11 +280,32 @@ abstract class HiveComparisonTest }.mkString("\n== Console version of this test ==\n", "\n", "\n") } - try { + def doTest(reset: Boolean, isSpeculative: Boolean = false): Unit = { + // Clear old output for this testcase. + outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) + if (reset) { TestHive.reset() } + // Many tests drop indexes on src and srcpart at the beginning, so we need to load those + // tables here. Since DROP INDEX DDL is just passed to Hive, it bypasses the analyzer and + // thus the tables referenced in those DDL commands cannot be extracted for use by our + // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES + // command expect these tables to exist. + val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + for (table <- Seq("src", "srcpart")) { + val hasMatchingQuery = queryList.exists { query => + val normalizedQuery = query.toLowerCase.stripSuffix(";") + normalizedQuery.endsWith(table) || + normalizedQuery.contains(s"from $table") || + normalizedQuery.contains(s"from default.$table") + } + if (hasShowTableCommand || hasMatchingQuery) { + TestHive.loadTestTable(table) + } + } + val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" @@ -350,14 +382,63 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.QueryExecution(queryString) - try { (query, prepareAnswer(query, query.stringResult())) } catch { + var query: TestHive.QueryExecution = null + try { + query = { + val originalQuery = new TestHive.QueryExecution(queryString) + val containsCommands = originalQuery.analyzed.collectFirst { + case _: Command => () + case _: LogicalInsertIntoHiveTable => () + }.nonEmpty + + if (containsCommands) { + originalQuery + } else { + val convertedSQL = try { + new SQLBuilder(originalQuery.analyzed, TestHive).toSQL + } catch { + case NonFatal(e) => fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$queryString + | + |# Resolved query plan: + |${originalQuery.analyzed.treeString} + """.stripMargin, e) + } + + try { + val queryExecution = new TestHive.QueryExecution(convertedSQL) + // Trigger the analysis of this converted SQL query. + queryExecution.analyzed + queryExecution + } catch { + case NonFatal(e) => fail( + s"""Failed to analyze the converted SQL string: + | + |# Original HiveQL query string: + |$queryString + | + |# Resolved query plan: + |${originalQuery.analyzed.treeString} + | + |# Converted SQL query string: + |$convertedSQL + """.stripMargin, e) + } + } + } + + (query, prepareAnswer(query, query.stringResult())) + } catch { case e: Throwable => val errorMessage = s""" |Failed to execute query using catalyst: |Error: ${e.getMessage} |${stackTraceToString(e)} + |$queryString |$query |== HIVE - ${hive.size} row(s) == |${hive.mkString("\n")} @@ -399,6 +480,10 @@ abstract class HiveComparisonTest val executions = queryList.map(new TestHive.QueryExecution(_)) executions.foreach(_.toRdd) val tablesGenerated = queryList.zip(executions).flatMap { + // We should take executedPlan instead of sparkPlan, because in following codes we + // will run the collected plans. As we will do extra processing for sparkPlan such + // as adding exchange, collapsing codegen stages, etc., collecting sparkPlan here + // will cause some errors when running these plans later. case (q, e) => e.executedPlan.collect { case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => (q, e, i) @@ -430,12 +515,45 @@ abstract class HiveComparisonTest """.stripMargin stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) - fail(errorMessage) + if (isSpeculative && !reset) { + fail("Failed on first run; retrying") + } else { + fail(errorMessage) + } } } // Touch passed file. new FileOutputStream(new File(passedDirectory, testCaseName)).close() + } + + val canSpeculativelyTryWithoutReset: Boolean = { + val excludedSubstrings = Seq( + "into table", + "create table", + "drop index" + ) + !queryList.map(_.toLowerCase).exists { query => + excludedSubstrings.exists(s => query.contains(s)) + } + } + + try { + try { + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + doTest(reset = false, isSpeculative = true) + } else { + doTest(reset) + } + } catch { + case tf: org.scalatest.exceptions.TestFailedException => + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + logWarning("Test failed without reset(); retrying with reset()") + doTest(reset = true) + } else { + throw tf + } + } } catch { case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala new file mode 100644 index 0000000000000..206d911e0d8a2 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.File + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} +import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ + + // check if the directory for recording the data of the table exists. + private def tableDirectoryExists(tableIdentifier: TableIdentifier): Boolean = { + val expectedTablePath = + hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier) + val filesystemPath = new Path(expectedTablePath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + fs.exists(filesystemPath) + } + + test("drop tables") { + withTable("tab1") { + val tabName = "tab1" + + assert(!tableDirectoryExists(TableIdentifier(tabName))) + sql(s"CREATE TABLE $tabName(c1 int)") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP TABLE $tabName") + + assert(!tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP TABLE IF EXISTS $tabName") + sql(s"DROP VIEW IF EXISTS $tabName") + } + } + + test("drop managed tables") { + withTempDir { tmpDir => + val tabName = "tab1" + withTable(tabName) { + assert(tmpDir.listFiles.isEmpty) + sql( + s""" + |create table $tabName + |stored as parquet + |location '$tmpDir' + |as select 1, '3' + """.stripMargin) + + val hiveTable = + hiveContext.sessionState.catalog + .getTableMetadata(TableIdentifier(tabName, Some("default"))) + // It is a managed table, although it uses external in SQL + assert(hiveTable.tableType == CatalogTableType.MANAGED_TABLE) + + assert(tmpDir.listFiles.nonEmpty) + sql(s"DROP TABLE $tabName") + // The data are deleted since the table type is not EXTERNAL + assert(tmpDir.listFiles == null) + } + } + } + + test("drop external data source table") { + withTempDir { tmpDir => + val tabName = "tab1" + withTable(tabName) { + assert(tmpDir.listFiles.isEmpty) + + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { + Seq(1 -> "a").toDF("i", "j") + .write + .mode(SaveMode.Overwrite) + .format("parquet") + .option("path", tmpDir.toString) + .saveAsTable(tabName) + } + + val hiveTable = + hiveContext.sessionState.catalog + .getTableMetadata(TableIdentifier(tabName, Some("default"))) + // This data source table is external table + assert(hiveTable.tableType == CatalogTableType.EXTERNAL_TABLE) + + assert(tmpDir.listFiles.nonEmpty) + sql(s"DROP TABLE $tabName") + // The data are not deleted since the table type is EXTERNAL + assert(tmpDir.listFiles.nonEmpty) + } + } + } + + test("create table and view with comment") { + val catalog = hiveContext.sessionState.catalog + val tabName = "tab1" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(c1 int) COMMENT 'BLABLA'") + val viewName = "view1" + withView(viewName) { + sql(s"CREATE VIEW $viewName COMMENT 'no comment' AS SELECT * FROM $tabName") + val tableMetadata = catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) + val viewMetadata = catalog.getTableMetadata(TableIdentifier(viewName, Some("default"))) + assert(tableMetadata.properties.get("comment") == Option("BLABLA")) + assert(viewMetadata.properties.get("comment") == Option("no comment")) + } + } + } + + test("add/drop partitions - external table") { + val catalog = hiveContext.sessionState.catalog + withTempDir { tmpDir => + val basePath = tmpDir.getCanonicalPath + val partitionPath_1stCol_part1 = new File(basePath + "/ds=2008-04-08") + val partitionPath_1stCol_part2 = new File(basePath + "/ds=2008-04-09") + val partitionPath_part1 = new File(basePath + "/ds=2008-04-08/hr=11") + val partitionPath_part2 = new File(basePath + "/ds=2008-04-09/hr=11") + val partitionPath_part3 = new File(basePath + "/ds=2008-04-08/hr=12") + val partitionPath_part4 = new File(basePath + "/ds=2008-04-09/hr=12") + val dirSet = + tmpDir :: partitionPath_1stCol_part1 :: partitionPath_1stCol_part2 :: + partitionPath_part1 :: partitionPath_part2 :: partitionPath_part3 :: + partitionPath_part4 :: Nil + + val externalTab = "extTable_with_partitions" + withTable(externalTab) { + assert(tmpDir.listFiles.isEmpty) + sql( + s""" + |CREATE EXTERNAL TABLE $externalTab (key INT, value STRING) + |PARTITIONED BY (ds STRING, hr STRING) + |LOCATION '$basePath' + """.stripMargin) + + // Before data insertion, all the directory are empty + assert(dirSet.forall(dir => dir.listFiles == null || dir.listFiles.isEmpty)) + + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + sql( + s""" + |INSERT OVERWRITE TABLE $externalTab + |partition (ds='$ds',hr='$hr') + |SELECT 1, 'a' + """.stripMargin) + } + + val hiveTable = catalog.getTableMetadata(TableIdentifier(externalTab, Some("default"))) + assert(hiveTable.tableType == CatalogTableType.EXTERNAL_TABLE) + // After data insertion, all the directory are not empty + assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + + sql( + s""" + |ALTER TABLE $externalTab DROP PARTITION (ds='2008-04-08'), + |PARTITION (ds='2008-04-09', hr='12') + """.stripMargin) + assert(catalog.listPartitions(TableIdentifier(externalTab)).map(_.spec).toSet == + Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + // drop partition will not delete the data of external table + assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + + sql(s"ALTER TABLE $externalTab ADD PARTITION (ds='2008-04-08', hr='12')") + assert(catalog.listPartitions(TableIdentifier(externalTab)).map(_.spec).toSet == + Set(Map("ds" -> "2008-04-08", "hr" -> "12"), Map("ds" -> "2008-04-09", "hr" -> "11"))) + // add partition will not delete the data + assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + + sql(s"DROP TABLE $externalTab") + // drop table will not delete the data of external table + assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + } + } + } + + test("drop views") { + withTable("tab1") { + val tabName = "tab1" + sqlContext.range(10).write.saveAsTable("tab1") + withView("view1") { + val viewName = "view1" + + assert(tableDirectoryExists(TableIdentifier(tabName))) + assert(!tableDirectoryExists(TableIdentifier(viewName))) + sql(s"CREATE VIEW $viewName AS SELECT * FROM tab1") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + assert(!tableDirectoryExists(TableIdentifier(viewName))) + sql(s"DROP VIEW $viewName") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP VIEW IF EXISTS $viewName") + } + } + } + + test("alter views - rename") { + val tabName = "tab1" + withTable(tabName) { + sqlContext.range(10).write.saveAsTable(tabName) + val oldViewName = "view1" + val newViewName = "view2" + withView(oldViewName, newViewName) { + val catalog = hiveContext.sessionState.catalog + sql(s"CREATE VIEW $oldViewName AS SELECT * FROM $tabName") + + assert(catalog.tableExists(TableIdentifier(oldViewName))) + assert(!catalog.tableExists(TableIdentifier(newViewName))) + sql(s"ALTER VIEW $oldViewName RENAME TO $newViewName") + assert(!catalog.tableExists(TableIdentifier(oldViewName))) + assert(catalog.tableExists(TableIdentifier(newViewName))) + } + } + } + + test("alter views - set/unset tblproperties") { + val tabName = "tab1" + withTable(tabName) { + sqlContext.range(10).write.saveAsTable(tabName) + val viewName = "view1" + withView(viewName) { + val catalog = hiveContext.sessionState.catalog + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tabName") + + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map()) + sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')") + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "an")) + + // no exception or message will be issued if we set it again + sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')") + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "an")) + + // the value will be updated if we set the same key to a different value + sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'b')") + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "b")) + + sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map()) + + val message = intercept[AnalysisException] { + sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") + }.getMessage + assert(message.contains( + "attempted to unset non-existent property 'p' in table '`view1`'")) + } + } + } + + test("alter views and alter table - misuse") { + val tabName = "tab1" + withTable(tabName) { + sqlContext.range(10).write.saveAsTable(tabName) + val oldViewName = "view1" + val newViewName = "view2" + withView(oldViewName, newViewName) { + val catalog = hiveContext.sessionState.catalog + sql(s"CREATE VIEW $oldViewName AS SELECT * FROM $tabName") + + assert(catalog.tableExists(TableIdentifier(tabName))) + assert(catalog.tableExists(TableIdentifier(oldViewName))) + + var message = intercept[AnalysisException] { + sql(s"ALTER VIEW $tabName RENAME TO $newViewName") + }.getMessage + assert(message.contains( + "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER VIEW $tabName SET TBLPROPERTIES ('p' = 'an')") + }.getMessage + assert(message.contains( + "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER VIEW $tabName UNSET TBLPROPERTIES ('p')") + }.getMessage + assert(message.contains( + "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER TABLE $oldViewName RENAME TO $newViewName") + }.getMessage + assert(message.contains( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") + }.getMessage + assert(message.contains( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") + }.getMessage + assert(message.contains( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + + assert(catalog.tableExists(TableIdentifier(tabName))) + assert(catalog.tableExists(TableIdentifier(oldViewName))) + } + } + } + + test("drop table using drop view") { + withTable("tab1") { + sql("CREATE TABLE tab1(c1 int)") + val message = intercept[AnalysisException] { + sql("DROP VIEW tab1") + }.getMessage + assert(message.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) + } + } + + test("drop view using drop table") { + withTable("tab1") { + sqlContext.range(10).write.saveAsTable("tab1") + withView("view1") { + sql("CREATE VIEW view1 AS SELECT * FROM tab1") + val message = intercept[AnalysisException] { + sql("DROP TABLE view1") + }.getMessage + assert(message.contains("Cannot drop a view with DROP TABLE. Please use DROP VIEW instead")) + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 94162da4eae1a..c45d49d6c0d10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils /** * A set of tests that validates support for Hive Explain command. @@ -37,8 +37,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", - "== Physical Plan ==", - "Code Generation") + "== Physical Plan ==") } test("explain create table command") { @@ -102,4 +101,33 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "Physical Plan should not contain Subquery since it's eliminated by optimizer") } } + + test("EXPLAIN CODEGEN command") { + checkExistence(sql("EXPLAIN CODEGEN SELECT 1"), true, + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + "/* 002 */ return new GeneratedIterator(references);", + "/* 003 */ }" + ) + + checkExistence(sql("EXPLAIN CODEGEN SELECT 1"), false, + "== Physical Plan ==" + ) + + checkExistence(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), true, + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + "/* 002 */ return new GeneratedIterator(references);", + "/* 003 */ }" + ) + + checkExistence(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), false, + "== Parsed Logical Plan ==", + "== Analyzed Logical Plan ==", + "== Optimized Logical Plan ==", + "== Physical Plan ==" + ) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index 0d4c7f86b315a..b252c6ee2faae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton /** * A set of tests that validates commands can also be queried by like a table diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index cd055f9eca37e..d8d3448adde0b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -29,8 +29,8 @@ class HivePlanTest extends QueryTest with TestHiveSingleton { test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") - val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan - val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan + val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan comparePlans(optimized, correctAnswer) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index f7b37dae0a5f3..f96c989c4614f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -59,7 +59,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) - createQueryTest(testCaseName, queriesString) + createQueryTest(testCaseName, queriesString, reset = true, tryWithoutResettingFirst = true) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fc72e3c7dc6aa..af73baa1f3914 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,23 +18,21 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.sql.Timestamp import java.util.{Locale, TimeZone} -import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin - import scala.util.Try -import org.scalatest.BeforeAndAfter - import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFiles, SparkException} +import org.apache.spark.{SparkException, SparkFiles} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.hive.test.TestHive._ case class TestData(a: Int, b: String) @@ -50,6 +48,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ override def beforeAll() { + super.beforeAll() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -58,10 +57,19 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } override def afterAll() { - TestHive.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - sql("DROP TEMPORARY FUNCTION udtf_count2") + try { + TestHive.cacheTables = false + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2") + } finally { + super.afterAll() + } + } + + private def assertUnsupportedFeature(body: => Unit): Unit = { + val e = intercept[AnalysisException] { body } + assert(e.getMessage.toLowerCase.contains("unsupported operation")) } test("SPARK-4908: concurrent hive native commands") { @@ -123,60 +131,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertBroadcastNestedLoopJoin(spark_10484_4) } - createQueryTest("SPARK-8976 Wrong Result for Rollup #1", - """ - SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for Rollup #2", - """ - SELECT - count(*) AS cnt, - key % 5 as k1, - key-5 as k2, - GROUPING__ID as k3 - FROM src group by key%5, key-5 - WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for Rollup #3", - """ - SELECT - count(*) AS cnt, - key % 5 as k1, - key-5 as k2, - GROUPING__ID as k3 - FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 - WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for CUBE #1", - """ - SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH CUBE - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for CUBE #2", - """ - SELECT - count(*) AS cnt, - key % 5 as k1, - key-5 as k2, - GROUPING__ID as k3 - FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 - WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin) - - createQueryTest("SPARK-8976 Wrong Result for GroupingSet", - """ - SELECT - count(*) AS cnt, - key % 5 as k1, - key-5 as k2, - GROUPING__ID as k3 - FROM (SELECT key, key%2, key - 5 FROM src) t group by key%5, key-5 - GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin) - createQueryTest("insert table with generator with column name", """ | CREATE TABLE gen_tmp (key Int); @@ -251,12 +205,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18, |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, - |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, - |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, - |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, - |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL22, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23 |FROM src LIMIT 1""".stripMargin) + test("constant null testing timestamp") { + val r1 = sql("SELECT IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL20") + .collect().head + assert(new Timestamp(1000) == r1.getTimestamp(0)) + } + createQueryTest("constant array", """ |SELECT sort_array( @@ -319,12 +278,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), " + "(101 / 2) % 10 FROM src LIMIT 1") - test("Query expressed in SQL") { - setConf("spark.sql.dialect", "sql") - assert(sql("SELECT 1").collect() === Array(Row(1))) - setConf("spark.sql.dialect", "hiveql") - } - test("Query expressed in HiveQL") { sql("FROM src SELECT key").collect() } @@ -384,9 +337,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("partitioned table scan", "SELECT ds, hr, key, value FROM srcpart") - createQueryTest("hash", - "SELECT hash('test') FROM src LIMIT 1") - createQueryTest("create table as", """ |CREATE TABLE createdtable AS SELECT * FROM src; @@ -606,26 +556,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("timestamp cast #1") { val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head - assert(0.001 == res.getDouble(0)) + assert(1 == res.getDouble(0)) } createQueryTest("timestamp cast #2", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #3", - "SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #3") { + val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(1200 == res.getInt(0)) + } createQueryTest("timestamp cast #4", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #5", - "SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("timestamp cast #5") { + val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(-1 == res.get(0)) + } createQueryTest("timestamp cast #6", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #7", - "SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #7") { + val res = sql("SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(-1200 == res.getInt(0)) + } createQueryTest("timestamp cast #8", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") @@ -648,7 +604,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |select * where key = 4 """.stripMargin) - // test get_json_object again Hive, because the HiveCompatabilitySuite cannot handle result + // test get_json_object again Hive, because the HiveCompatibilitySuite cannot handle result // with newline in it. createQueryTest("get_json_object #1", "SELECT get_json_object(src_json.json, '$') FROM src_json") @@ -710,11 +666,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("implement identity function using case statement") { val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") + .rdd .map { case Row(i: Int) => i } .collect() .toSet val expected = sql("SELECT key FROM src") + .rdd .map { case Row(i: Int) => i } .collect() .toSet @@ -748,7 +706,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { def isExplanation(result: DataFrame): Boolean = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.contains("== Physical Plan ==") + explanation.head.startsWith("== Physical Plan ==") } test("SPARK-1704: Explain commands as a DataFrame") { @@ -762,14 +720,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) - .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} + .zipWithIndex.map {case ((value, attr), key) => HavingRow(key, value, attr)} TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() - .map(x => Pair(x.getString(0), x.getInt(1))) + .map(x => (x.getString(0), x.getInt(1))) - assert(results === Array(Pair("foo", 4))) + assert(results === Array(("foo", 4))) TestHive.reset() } @@ -781,6 +739,24 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("union/except/intersect") { + assertResult(Array(Row(1), Row(1))) { + sql("select 1 as a union all select 1 as a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a union distinct select 1 as a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a union select 1 as a").collect() + } + assertResult(Array()) { + sql("select 1 as a except select 1 as a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a intersect select 1 as a").collect() + } + } + test("SPARK-5383 alias for udfs with multi output columns") { assert( sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") @@ -927,7 +903,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2263: Insert Map values") { sql("CREATE TABLE m(value MAP)") sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -961,10 +937,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("CREATE TEMPORARY FUNCTION") { val funcJar = TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath - sql(s"ADD JAR $funcJar") + val jarURL = s"file://$funcJar" + sql(s"ADD JAR $jarURL") sql( """CREATE TEMPORARY FUNCTION udtf_count2 AS - | 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'""".stripMargin) + |'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) assert(sql("DESCRIBE FUNCTION udtf_count2").count > 1) sql("DROP TEMPORARY FUNCTION udtf_count2") } @@ -980,9 +958,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(checkAddFileRDD.first()) } - case class LogEntry(filename: String, message: String) - case class LogFile(name: String) - createQueryTest("dynamic_partition", """ |DROP TABLE IF EXISTS dynamic_part_table; @@ -1235,12 +1210,103 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } + test("use database") { + val currentDatabase = sql("select current_database()").first().getString(0) + + sql("CREATE DATABASE hive_test_db") + sql("USE hive_test_db") + assert("hive_test_db" == sql("select current_database()").first().getString(0)) + + intercept[AnalysisException] { + sql("USE not_existing_db") + } + + sql(s"USE $currentDatabase") + assert(currentDatabase == sql("select current_database()").first().getString(0)) + } + + test("lookup hive UDF in another thread") { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("not_a_udf")) + var success = false + val t = new Thread("test") { + override def run(): Unit = { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("not_a_udf")) + success = true + } + } + t.start() + t.join() + assert(success) + } + createQueryTest("select from thrift based table", "SELECT * from src_thrift") // Put tests that depend on specific Hive settings before these last two test, // since they modify /clear stuff. + + test("role management commands are not supported") { + assertUnsupportedFeature { sql("CREATE ROLE my_role") } + assertUnsupportedFeature { sql("DROP ROLE my_role") } + assertUnsupportedFeature { sql("SHOW CURRENT ROLES") } + assertUnsupportedFeature { sql("SHOW ROLES") } + assertUnsupportedFeature { sql("SHOW GRANT") } + assertUnsupportedFeature { sql("SHOW ROLE GRANT USER my_principal") } + assertUnsupportedFeature { sql("SHOW PRINCIPALS my_role") } + assertUnsupportedFeature { sql("SET ROLE my_role") } + assertUnsupportedFeature { sql("GRANT my_role TO USER my_user") } + assertUnsupportedFeature { sql("GRANT ALL ON my_table TO USER my_user") } + assertUnsupportedFeature { sql("REVOKE my_role FROM USER my_user") } + assertUnsupportedFeature { sql("REVOKE ALL ON my_table FROM USER my_user") } + } + + test("import/export commands are not supported") { + assertUnsupportedFeature { sql("IMPORT TABLE my_table FROM 'my_path'") } + assertUnsupportedFeature { sql("EXPORT TABLE my_table TO 'my_path'") } + } + + test("some show commands are not supported") { + assertUnsupportedFeature { sql("SHOW CREATE TABLE my_table") } + assertUnsupportedFeature { sql("SHOW COMPACTIONS") } + assertUnsupportedFeature { sql("SHOW TRANSACTIONS") } + assertUnsupportedFeature { sql("SHOW INDEXES ON my_table") } + assertUnsupportedFeature { sql("SHOW LOCKS my_table") } + } + + test("lock/unlock table and database commands are not supported") { + assertUnsupportedFeature { sql("LOCK TABLE my_table SHARED") } + assertUnsupportedFeature { sql("UNLOCK TABLE my_table") } + assertUnsupportedFeature { sql("LOCK DATABASE my_db SHARED") } + assertUnsupportedFeature { sql("UNLOCK DATABASE my_db") } + } + + test("create/drop/alter index commands are not supported") { + assertUnsupportedFeature { + sql("CREATE INDEX my_index ON TABLE my_table(a) as 'COMPACT' WITH DEFERRED REBUILD")} + assertUnsupportedFeature { sql("DROP INDEX my_index ON my_table") } + assertUnsupportedFeature { sql("ALTER INDEX my_index ON my_table REBUILD")} + assertUnsupportedFeature { + sql("ALTER INDEX my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")")} + } + + test("create/drop macro commands are not supported") { + assertUnsupportedFeature { + sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))") + } + assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") } + } } // for SPARK-2180 test case class HavingRow(key: Int, value: String, attr: Int) + +case class LogEntry(filename: String, message: String) +case class LogFile(name: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index b08db6de2d2f6..dd13b8392880a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, sql} import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 2209fc2f30a3c..b0c0dcbe5c25c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ - import org.apache.spark.util.Utils class HiveTableScanSuite extends HiveComparisonTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 197e9bfb02c4e..6b424d73430e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -25,25 +25,33 @@ import org.apache.spark.sql.hive.test.TestHive * A set of tests that validate type promotion and coercion rules. */ class HiveTypeCoercionSuite extends HiveComparisonTest { - val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'") + val baseTypes = Seq( + ("1", "1"), + ("1.0", "CAST(1.0 AS DOUBLE)"), + ("1L", "1L"), + ("1S", "1S"), + ("1Y", "1Y"), + ("'1'", "'1'")) - baseTypes.foreach { i => - baseTypes.foreach { j => - createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1") + baseTypes.foreach { case (ni, si) => + baseTypes.foreach { case (nj, sj) => + createQueryTest(s"$ni + $nj", s"SELECT $si + $sj FROM src LIMIT 1") } } val nullVal = "null" - baseTypes.init.foreach { i => + baseTypes.init.foreach { case (i, s) => createQueryTest(s"case when then $i else $nullVal end ", - s"SELECT case when true then $i else $nullVal end FROM src limit 1") + s"SELECT case when true then $s else $nullVal end FROM src limit 1") createQueryTest(s"case when then $nullVal else $i end ", - s"SELECT case when true then $nullVal else $i end FROM src limit 1") + s"SELECT case when true then $nullVal else $s end FROM src limit 1") } test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = TestHive.sql(q).queryExecution.sparkPlan.collect { + case e: Project => e + }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 5f9a447759b48..d07ac56586744 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -17,20 +17,21 @@ package org.apache.spark.sql.hive.execution -import java.io.{DataInput, DataOutput} +import java.io.{DataInput, DataOutput, File, PrintWriter} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.UDAFPercentile -import org.apache.hadoop.hive.ql.udf.generic.{GenericUDFOPAnd, GenericUDTFExplode, GenericUDAFAverage, GenericUDF} +import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) @@ -44,9 +45,9 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUDFSuite extends QueryTest with TestHiveSingleton { +class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { - import hiveContext.{udf, sql} + import hiveContext.udf import hiveContext.implicits._ test("spark sql udf test that returns a struct") { @@ -92,44 +93,36 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { } test("Max/Min on named_struct") { - def testOrderInStruct(): Unit = { - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) - - // nested struct cases - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", named_struct( - "key", key, - "value", value), - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", named_struct( - "key", key, - "value", value), - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) - } - val codegenDefault = hiveContext.getConf(SQLConf.CODEGEN_ENABLED) - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, true) - testOrderInStruct() - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, false) - testOrderInStruct() - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + + // nested struct cases + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) } test("SPARK-6409 UDAF Average test") { @@ -150,10 +143,10 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { } test("Generic UDAF aggregates") { - checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), + checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"), sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) - checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), + checkAnswer(sql("SELECT percentile_approx(100.0D, array(0.9D, 0.9D)) FROM src LIMIT 1"), sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) } @@ -310,7 +303,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { val message = intercept[AnalysisException] { sql("SELECT testUDFTwoListList() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") } @@ -320,7 +313,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { val message = intercept[AnalysisException] { sql("SELECT testUDFAnd() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd") } @@ -330,7 +323,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { val message = intercept[AnalysisException] { sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile") } @@ -340,7 +333,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { val message = intercept[AnalysisException] { sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage") } @@ -350,12 +343,111 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { val message = intercept[AnalysisException] { sql("SELECT testUDTFExplode() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") } sqlContext.dropTempTable("testUDF") } + + test("Hive UDF in group by") { + withTempTable("tab1") { + Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1") + sql(s"CREATE TEMPORARY FUNCTION testUDFToDate AS '${classOf[GenericUDFToDate].getName}'") + val count = sql("select testUDFToDate(cast(test_date as timestamp))" + + " from tab1 group by testUDFToDate(cast(test_date as timestamp))").count() + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToDate") + assert(count == 1) + } + } + + test("SPARK-11522 select input_file_name from non-parquet table") { + + withTempDir { tempDir => + + // EXTERNAL OpenCSVSerde table pointing to LOCATION + + val file1 = new File(tempDir + "/data1") + val writer1 = new PrintWriter(file1) + writer1.write("1,2") + writer1.close() + + val file2 = new File(tempDir + "/data2") + val writer2 = new PrintWriter(file2) + writer2.write("1,2") + writer2.close() + + sql( + s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) + ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + WITH SERDEPROPERTIES ( + \"separatorChar\" = \",\", + \"quoteChar\" = \"\\\"\", + \"escapeChar\" = \"\\\\\") + LOCATION '$tempDir' + """) + + val answer1 = + sql("SELECT input_file_name() FROM csv_table").head().getString(0) + assert(answer1.contains("data1") || answer1.contains("data2")) + + val count1 = sql("SELECT input_file_name() FROM csv_table").distinct().count() + assert(count1 == 2) + sql("DROP TABLE csv_table") + + // EXTERNAL pointing to LOCATION + + sql( + s"""CREATE EXTERNAL TABLE external_t5 (c1 int, c2 int) + ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + LOCATION '$tempDir' + """) + + val answer2 = + sql("SELECT input_file_name() as file FROM external_t5").head().getString(0) + assert(answer1.contains("data1") || answer1.contains("data2")) + + val count2 = sql("SELECT input_file_name() as file FROM external_t5").distinct().count + assert(count2 == 2) + sql("DROP TABLE external_t5") + } + + withTempDir { tempDir => + + // External parquet pointing to LOCATION + + val parquetLocation = tempDir + "/external_parquet" + sql("SELECT 1, 2").write.parquet(parquetLocation) + + sql( + s"""CREATE EXTERNAL TABLE external_parquet(c1 int, c2 int) + STORED AS PARQUET + LOCATION '$parquetLocation' + """) + + val answer3 = + sql("SELECT input_file_name() as file FROM external_parquet").head().getString(0) + assert(answer3.contains("external_parquet")) + + val count3 = sql("SELECT input_file_name() as file FROM external_parquet").distinct().count + assert(count3 == 1) + sql("DROP TABLE external_parquet") + } + + // Non-External parquet pointing to /tmp/... + + sql("CREATE TABLE parquet_tmp(c1 int, c2 int) " + + " STORED AS parquet " + + " AS SELECT 1, 2") + + val answer4 = + sql("SELECT input_file_name() as file FROM parquet_tmp").head().getString(0) + assert(answer4.contains("parquet_tmp")) + + val count4 = sql("SELECT input_file_name() as file FROM parquet_tmp").distinct().count + assert(count4 == 1) + sql("DROP TABLE parquet_tmp") + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 210d566745415..97cb9d972081c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -144,12 +144,12 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { expectedScannedColumns: Seq[String], expectedPartValues: Seq[Seq[String]]): Unit = { test(s"$testCaseName - pruning test") { - val plan = new TestHive.QueryExecution(sql).executedPlan + val plan = new TestHive.QueryExecution(sql).sparkPlan val actualOutputColumns = plan.output.map(_.name) val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) - val partValues = if (relation.table.isPartitioned) { + val partValues = if (relation.table.partitionColumnNames.nonEmpty) { p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) } else { Seq.empty diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index af48d478953b4..5ce16be4dc059 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.hive.execution import java.sql.{Date, Timestamp} -import scala.collection.JavaConverters._ - import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{TableIdentifier, DefaultParserDialect} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} -import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -55,8 +54,6 @@ case class WindowData( month: Int, area: String, product: Int) -/** A SQL Dialect for testing purpose, and it can not be nested type */ -class MyDialect extends DefaultParserDialect /** * A collection of hive query tests where we generate the answers ourselves instead of depending on @@ -68,22 +65,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext.implicits._ test("UDTF") { - sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) + withUserDefinedFunction("udtf_count2" -> true) { + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) - checkAnswer( - sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + } + + test("permanent UDTF") { + withUserDefinedFunction("udtf_count_temp" -> false) { + sql( + s""" + |CREATE FUNCTION udtf_count_temp + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } } test("SPARK-6835: udtf in lateral view") { @@ -93,6 +111,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) } + test("SPARK-13651: generator outputs shouldn't be resolved from its child's output") { + withTempTable("src") { + Seq(("id1", "value1")).toDF("key", "value").registerTempTable("src") + val query = + sql("SELECT genoutput.* FROM src " + + "LATERAL VIEW explode(map('key1', 100, 'key2', 200)) genoutput AS key, value") + checkAnswer(query, Row("key1", 100) :: Row("key2", 200) :: Nil) + } + } + test("SPARK-6851: Self-joined converted parquet tables") { val orders = Seq( Order(1, "Atlas", "MTB", 234, "2015-01-07", "John D", "Pacifica", "CA", 20151), @@ -160,9 +188,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("show functions") { - val allBuiltinFunctions = - (FunctionRegistry.builtin.listFunction().toSet[String] ++ - org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted + val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted // The TestContext is shared by all the test cases, some functions may be registered before // this, so we check that all the builtin functions are returned. val allFunctions = sql("SHOW functions").collect().map(r => r(0)) @@ -176,9 +202,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) checkAnswer(sql("SHOW functions `~`"), Row("~")) checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) - checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear")) // this probably will failed if we add more function with `sha` prefixing. - checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + checkAnswer(sql("SHOW functions `sha*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + // Test '|' for alternation. + checkAnswer( + sql("SHOW functions 'sha*|weekofyea*'"), + Row("sha") :: Row("sha1") :: Row("sha2") :: Row("weekofyear") :: Nil) } test("describe functions") { @@ -200,12 +230,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { "Extended Usage") checkExistence(sql("describe functioN abcadf"), true, - "Function: abcadf is not found.") + "Function: abcadf not found.") checkExistence(sql("describe functioN `~`"), true, "Function: ~", - "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", - "Usage: ~ n - Bitwise not") + "Class: org.apache.spark.sql.catalyst.expressions.BitwiseNot", + "Usage: ~ b - Bitwise NOT.") + + // Hard coded describe functions + checkExistence(sql("describe function `<>`"), true, + "Function: <>", + "Usage: a <> b - Returns TRUE if a is not equal to b") + + checkExistence(sql("describe function `!=`"), true, + "Function: !=", + "Usage: a != b - Returns TRUE if a is not equal to b") + + checkExistence(sql("describe function `between`"), true, + "Function: between", + "Usage: a [NOT] BETWEEN b AND c - evaluate if a is [not] in between b and c") + + checkExistence(sql("describe function `case`"), true, + "Function: case", + "Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + + "When a = b, returns c; when a = d, return e; else return f") } test("SPARK-5371: union with null and sum") { @@ -255,6 +303,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer( sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), Row(1) :: Row(2) :: Row(3) :: Nil) + + checkAnswer( + sql("SELECT `ints` FROM nestedArray LATERAL VIEW explode(a.b) `a` AS `ints`"), + Row(1) :: Row(2) :: Row(3) :: Nil) + + checkAnswer( + sql("SELECT `a`.`ints` FROM nestedArray LATERAL VIEW explode(a.b) `a` AS `ints`"), + Row(1) :: Row(2) :: Row(3) :: Nil) + + checkAnswer( + sql( + """ + |SELECT `weird``tab`.`weird``col` + |FROM nestedArray + |LATERAL VIEW explode(a.b) `weird``tab` AS `weird``col` + """.stripMargin), + Row(1) :: Row(2) :: Row(3) :: Nil) } test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { @@ -266,19 +331,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("CTAS without serde") { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { - val relation = EliminateSubQueries(catalog.lookupRelation(TableIdentifier(tableName))) + val relation = EliminateSubqueryAliases( + sessionState.catalog.lookupRelation(TableIdentifier(tableName))) relation match { - case LogicalRelation(r: ParquetRelation, _) => + case LogicalRelation(r: HadoopFsRelation, _, _) => if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + - s"${ParquetRelation.getClass.getCanonicalName}.") + s"${HadoopFsRelation.getClass.getCanonicalName}.") } case r: MetastoreRelation => if (isDataSourceParquet) { fail( - s"${ParquetRelation.getClass.getCanonicalName} is expected, but found " + + s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") } } @@ -294,7 +360,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { var message = intercept[AnalysisException] { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") }.getMessage - assert(message.contains("ctas1 already exists")) + assert(message.contains("already exists")) checkRelation("ctas1", true) sql("DROP TABLE ctas1") @@ -335,42 +401,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("SQL dialect at the start of HiveContext") { - val hiveContext = new HiveContext(sqlContext.sparkContext) - val dialectConf = "spark.sql.dialect" - checkAnswer(hiveContext.sql(s"set $dialectConf"), Row(dialectConf, "hiveql")) - assert(hiveContext.getSQLDialect().getClass === classOf[HiveQLDialect]) - } - - test("SQL Dialect Switching") { - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) - setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) - assert(getSQLDialect().getClass === classOf[MyDialect]) - assert(sql("SELECT 1").collect() === Array(Row(1))) - - // set the dialect back to the DefaultSQLDialect - sql("SET spark.sql.dialect=sql") - assert(getSQLDialect().getClass === classOf[DefaultParserDialect]) - sql("SET spark.sql.dialect=hiveql") - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) - - // set invalid dialect - sql("SET spark.sql.dialect.abc=MyTestClass") - sql("SET spark.sql.dialect=abc") - intercept[Exception] { - sql("SELECT 1") - } - // test if the dialect set back to HiveQLDialect - getSQLDialect().getClass === classOf[HiveQLDialect] - - sql("SET spark.sql.dialect=MyTestClass") - intercept[DialectException] { - sql("SELECT 1") - } - // test if the dialect set back to HiveQLDialect - assert(getSQLDialect().getClass === classOf[HiveQLDialect]) - } - test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( @@ -730,7 +760,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { (1 to 100).par.map { i => val tableName = s"SPARK_6618_table_$i" sql(s"CREATE TABLE $tableName (col1 string)") - catalog.lookupRelation(TableIdentifier(tableName)) + sessionState.catalog.lookupRelation(TableIdentifier(tableName)) table(tableName) tables() sql(s"DROP TABLE $tableName") @@ -738,7 +768,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-5203 union with different decimal precision") { - Seq.empty[(Decimal, Decimal)] + Seq.empty[(java.math.BigDecimal, java.math.BigDecimal)] .toDF("d1", "d2") .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") @@ -747,20 +777,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .queryExecution.analyzed } + test("Star Expansion - script transform") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(100000 === sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans").count()) + } + test("test script transform for stdout") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(100000 === - sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") - .queryExecution.toRdd.count()) + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans").count()) } test("test script transform for stderr") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === - sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") - .queryExecution.toRdd.count()) + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans").count()) } test("test script transform data type") { @@ -768,12 +802,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { data.toDF("key", "value").registerTempTable("test") checkAnswer( sql("""FROM - |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t + |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (`thing1` int, thing2 string)) t |SELECT thing1 + 1 """.stripMargin), (2 to 6).map(i => Row(i))) } - test("window function: udaf with aggregate expressin") { + test("window function: udaf with aggregate expression") { val data = Seq( WindowData(1, "a", 5), WindowData(2, "a", 6), @@ -915,6 +949,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: distinct should not be silently ignored") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + val e = intercept[AnalysisException] { + sql( + """ + |select month, area, product, sum(distinct product + 1) over (partition by 1 order by 2) + |from windowData + """.stripMargin) + } + assert(e.getMessage.contains("Distinct window functions are not supported")) + } + test("window function: expressions in arguments of a window functions") { val data = Seq( WindowData(1, "a", 5), @@ -943,6 +998,130 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("Sorting columns are not in Generate") { + withTempTable("data") { + sqlContext.range(1, 5) + .select(array($"id", $"id" + 1).as("a"), $"id".as("b"), (lit(10) - $"id").as("c")) + .registerTempTable("data") + + // case 1: missing sort columns are resolvable if join is true + checkAnswer( + sql("SELECT explode(a) AS val, b FROM data WHERE b < 2 order by val, c"), + Row(1, 1) :: Row(2, 1) :: Nil) + + // case 2: missing sort columns are resolvable if join is false + checkAnswer( + sql("SELECT explode(a) AS val FROM data order by val, c"), + Seq(1, 2, 2, 3, 3, 4, 4, 5).map(i => Row(i))) + + // case 3: missing sort columns are resolvable if join is true and outer is true + checkAnswer( + sql( + """ + |SELECT C.val, b FROM data LATERAL VIEW OUTER explode(a) C as val + |where b < 2 order by c, val, b + """.stripMargin), + Row(1, 1) :: Row(2, 1) :: Nil) + } + } + + test("window function: Sorting columns are not in Project") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql("select month, product, sum(product + 1) over() from windowData order by area"), + Seq( + (2, 6, 57), + (3, 7, 57), + (4, 8, 57), + (5, 9, 57), + (6, 11, 57), + (1, 10, 57) + ).map(i => Row(i._1, i._2, i._3))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1 + |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p + """.stripMargin), + Seq( + ("a", 2), + ("b", 2), + ("b", 3), + ("c", 2), + ("d", 2), + ("c", 3) + ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, rank() over (partition by area order by month) as c1 + |from windowData group by product, area, month order by product, area + """.stripMargin), + Seq( + ("a", 1), + ("b", 1), + ("b", 2), + ("c", 1), + ("d", 1), + ("c", 2) + ).map(i => Row(i._1, i._2))) + + checkAnswer( + sql( + """ + |select area, sum(product) / sum(sum(product)) over (partition by area) as c1 + |from windowData group by area, month order by month, c1 + """.stripMargin), + Seq( + ("d", 1.0), + ("a", 1.0), + ("b", 0.4666666666666667), + ("b", 0.5333333333333333), + ("c", 0.45), + ("c", 0.55) + ).map(i => Row(i._1, i._2))) + } + + // todo: fix this test case by reimplementing the function ResolveAggregateFunctions + ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") { + val data = Seq( + WindowData(1, "d", 10), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 11) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + checkAnswer( + sql( + """ + |select area, sum(product) over () as c from windowData + |where product > 3 group by area, product + |having avg(month) > 0 order by avg(month), product + """.stripMargin), + Seq( + ("a", 51), + ("b", 51), + ("b", 51), + ("c", 51), + ("c", 51), + ("d", 51) + ).map(i => Row(i._1, i._2))) + } + test("window function: multiple window expressions in a single expression") { val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") nums.registerTempTable("nums") @@ -1028,9 +1207,9 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | java_method("java.lang.String", "isEmpty"), | java_method("java.lang.Math", "max", 2, 3), | java_method("java.lang.Math", "min", 2, 3), - | java_method("java.lang.Math", "round", 2.5), - | java_method("java.lang.Math", "exp", 1.0), - | java_method("java.lang.Math", "floor", 1.9) + | java_method("java.lang.Math", "round", 2.5D), + | java_method("java.lang.Math", "exp", 1.0D), + | java_method("java.lang.Math", "floor", 1.9D) |FROM src tablesample (1 rows) """.stripMargin), Row( @@ -1185,6 +1364,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .format("parquet") .save(path) + // We don't support creating a temporary table while specifying a database val message = intercept[AnalysisException] { sqlContext.sql( s""" @@ -1195,9 +1375,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |) """.stripMargin) }.getMessage - assert(message.contains("Specifying database name or other qualifiers are not allowed")) - // If you use backticks to quote the name of a temporary table having dot in it. + // If you use backticks to quote the name then it's OK. sqlContext.sql( s""" |CREATE TEMPORARY TABLE `db.t` @@ -1289,7 +1468,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("run sql directly on files") { - val df = sqlContext.range(100) + val df = sqlContext.range(100).toDF() withTempPath(f => { df.write.parquet(f.getCanonicalPath) checkAnswer(sql(s"select id from parquet.`${f.getCanonicalPath}`"), @@ -1309,7 +1488,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql( """CREATE VIEW IF NOT EXISTS |default.testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') - |COMMENT 'blabla' |TBLPROPERTIES ('a' = 'b') |AS SELECT * FROM jt""".stripMargin) checkAnswer(sql("SELECT c1, c2 FROM testView ORDER BY c1"), (1 to 9).map(i => Row(i, i))) @@ -1335,67 +1513,119 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("correctly handle CREATE OR REPLACE VIEW") { - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt", "jt2") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") - checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + Seq(true, false).foreach { enabled => + val prefix = (if (enabled) "With" else "Without") + " canonical native view: " + test(s"$prefix correctly handle CREATE OR REPLACE VIEW") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt", "jt2") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE OR REPLACE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) - val df = (1 until 10).map(i => i -> i).toDF("i", "j") - df.write.format("json").saveAsTable("jt2") - sql("CREATE OR REPLACE VIEW testView AS SELECT * FROM jt2") - // make sure the view has been changed. - checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("CREATE OR REPLACE VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) - sql("DROP VIEW testView") + sql("DROP VIEW testView") - val e = intercept[AnalysisException] { - sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") + val e = intercept[AnalysisException] { + sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") + } + assert(e.message.contains("not allowed to define a view")) } - assert(e.message.contains("not allowed to define a view")) } } - } - test("correctly handle ALTER VIEW") { - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt", "jt2") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") + test(s"$prefix correctly handle ALTER VIEW") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt", "jt2") { + withView("testView") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("ALTER VIEW testView AS SELECT * FROM jt2") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + } + } + } + } - val df = (1 until 10).map(i => i -> i).toDF("i", "j") - df.write.format("json").saveAsTable("jt2") - sql("ALTER VIEW testView AS SELECT * FROM jt2") - // make sure the view has been changed. - checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + test(s"$prefix create hive view for json table") { + // json table is not hive-compatible, make sure the new flag fix it. + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("jt") { + withView("testView") { + sqlContext.range(1, 10).write.format("json").saveAsTable("jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + } + } + } + } - sql("DROP VIEW testView") + test(s"$prefix create hive view for partitioned parquet table") { + // partitioned parquet table is not hive-compatible, make sure the new flag fix it. + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> enabled.toString) { + withTable("parTable") { + withView("testView") { + val df = Seq(1 -> "a").toDF("i", "j") + df.write.format("parquet").partitionBy("i").saveAsTable("parTable") + sql("CREATE VIEW testView AS SELECT i, j FROM parTable") + checkAnswer(sql("SELECT * FROM testView"), Row(1, "a")) + } + } } } } - test("create hive view for json table") { - // json table is not hive-compatible, make sure the new flag fix it. - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("jt") { - sqlContext.range(1, 10).write.format("json").saveAsTable("jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") - checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) - sql("DROP VIEW testView") + test("CTE within view") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withView("cte_view") { + sql("CREATE VIEW cte_view AS WITH w AS (SELECT 1 AS n) SELECT n FROM w") + checkAnswer(sql("SELECT * FROM cte_view"), Row(1)) } } } - test("create hive view for partitioned parquet table") { - // partitioned parquet table is not hive-compatible, make sure the new flag fix it. - withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { - withTable("parTable") { - val df = Seq(1 -> "a").toDF("i", "j") - df.write.format("parquet").partitionBy("i").saveAsTable("parTable") - sql("CREATE VIEW testView AS SELECT i, j FROM parTable") - checkAnswer(sql("SELECT * FROM testView"), Row(1, "a")) - sql("DROP VIEW testView") + test("Using view after switching current database") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withView("v") { + sql("CREATE VIEW v AS SELECT * FROM src") + withTempDatabase { db => + activateDatabase(db) { + // Should look up table `src` in database `default`. + checkAnswer(sql("SELECT * FROM default.v"), sql("SELECT * FROM default.src")) + + // The new `src` table shouldn't be scanned. + sql("CREATE TABLE src(key INT, value STRING)") + checkAnswer(sql("SELECT * FROM default.v"), sql("SELECT * FROM default.src")) + } + } + } + } + } + + test("Using view after adding more columns") { + withSQLConf( + SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { + withTable("add_col") { + sqlContext.range(10).write.saveAsTable("add_col") + withView("v") { + sql("CREATE VIEW v AS SELECT * FROM add_col") + sqlContext.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") + checkAnswer(sql("SELECT * FROM v"), sqlContext.range(10).toDF()) + } } } } @@ -1418,14 +1648,251 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-8976 Wrong Result for Rollup #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + + test("SPARK-8976 Wrong Result for Rollup #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("SPARK-8976 Wrong Result for Rollup #3") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("SPARK-8976 Wrong Result for CUBE #1") { + checkAnswer(sql( + "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } + + test("SPARK-8976 Wrong Result for CUBE #2") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + + test("SPARK-8976 Wrong Result for GroupingSet") { + checkAnswer(sql( + """ + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } + test("SPARK-10562: partition by column with mixed case name") { withTable("tbl10562") { val df = Seq(2012 -> "a").toDF("Year", "val") df.write.partitionBy("Year").saveAsTable("tbl10562") + checkAnswer(sql("SELECT year FROM tbl10562"), Row(2012)) checkAnswer(sql("SELECT Year FROM tbl10562"), Row(2012)) checkAnswer(sql("SELECT yEAr FROM tbl10562"), Row(2012)) checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year > 2015"), Nil) checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a")) } } + + test("SPARK-11453: append data to partitioned table") { + withTable("tbl11453") { + Seq("1" -> "10", "2" -> "20").toDF("i", "j") + .write.partitionBy("i").saveAsTable("tbl11453") + + Seq("3" -> "30").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil) + + // make sure case sensitivity is correct. + Seq("4" -> "40").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) + } + } + + test("SPARK-11590: use native json_tuple in lateral view") { + checkAnswer(sql( + """ + |SELECT a, b + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin), Row("value1", "12")) + + // we should use `c0`, `c1`... as the name of fields if no alias is provided, to follow hive. + checkAnswer(sql( + """ + |SELECT c0, c1 + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt + """.stripMargin), Row("value1", "12")) + + // we can also use `json_tuple` in project list. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2') + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + """.stripMargin), Row("value1", "12")) + + // we can also mix `json_tuple` with other project expressions. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2'), 3.14, str + |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test + """.stripMargin), Row("value1", "12", BigDecimal("3.14"), "hello")) + } + + test("multi-insert with lateral view") { + withTempTable("t1") { + sqlContext.range(10) + .select(array($"id", $"id" + 1).as("arr"), $"id") + .registerTempTable("source") + withTable("dest1", "dest2") { + sql("CREATE TABLE dest1 (i INT)") + sql("CREATE TABLE dest2 (i INT)") + sql( + """ + |FROM source + |INSERT OVERWRITE TABLE dest1 + |SELECT id + |WHERE id > 3 + |INSERT OVERWRITE TABLE dest2 + |select col LATERAL VIEW EXPLODE(arr) exp AS col + |WHERE col > 3 + """.stripMargin) + + checkAnswer( + sqlContext.table("dest1"), + sql("SELECT id FROM source WHERE id > 3")) + checkAnswer( + sqlContext.table("dest2"), + sql("SELECT col FROM source LATERAL VIEW EXPLODE(arr) exp AS col WHERE col > 3")) + } + } + } + + test( + "SPARK-14488 \"CREATE TEMPORARY TABLE ... USING ... AS SELECT ...\" " + + "shouldn't create persisted table" + ) { + withTempPath { dir => + withTempTable("t1", "t2") { + val path = dir.getCanonicalPath + val ds = sqlContext.range(10) + ds.registerTempTable("t1") + + sql( + s"""CREATE TEMPORARY TABLE t2 + |USING PARQUET + |OPTIONS (PATH '$path') + |AS SELECT * FROM t1 + """.stripMargin) + + checkAnswer( + sqlContext.tables().select('isTemporary).filter('tableName === "t2"), + Row(true) + ) + + checkAnswer(table("t2"), table("t1")) + } + } + } + + test( + "SPARK-14493 \"CREATE TEMPORARY TABLE ... USING ... AS SELECT ...\" " + + "shouldn always be used together with PATH data source option" + ) { + withTempTable("t") { + sqlContext.range(10).registerTempTable("t") + + val message = intercept[IllegalArgumentException] { + sql( + s"""CREATE TEMPORARY TABLE t1 + |USING PARQUET + |AS SELECT * FROM t + """.stripMargin) + }.getMessage + + assert(message == "'path' is not specified") + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 7cfdb886b585d..8f163f27c94cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryNode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala new file mode 100644 index 0000000000000..c6b7eb63662c5 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.test.SQLTestUtils + +/** + * This suite contains a couple of Hive window tests which fail in the typical setup due to tiny + * numerical differences or due semantic differences between Hive and Spark. + */ +class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + override def beforeAll(): Unit = { + super.beforeAll() + sql("DROP TABLE IF EXISTS part") + sql( + """ + |CREATE TABLE part( + | p_partkey INT, + | p_name STRING, + | p_mfgr STRING, + | p_brand STRING, + | p_type STRING, + | p_size INT, + | p_container STRING, + | p_retailprice DOUBLE, + | p_comment STRING) + """.stripMargin) + val testData1 = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + sql( + s""" + |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part + """.stripMargin) + } + + override def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS part") + } finally { + super.afterAll() + } + } + + test("windowing.q -- 15. testExpressions") { + // Moved because: + // - Spark uses a different default stddev (sample instead of pop) + // - Tiny numerical differences in stddev results. + // - Different StdDev behavior when n=1 (NaN instead of 0) + checkAnswer(sql(s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |cume_dist() over(distribute by p_mfgr sort by p_name) as cud, + |percent_rank() over(distribute by p_mfgr sort by p_name) as pr, + |ntile(3) over(distribute by p_mfgr sort by p_name) as nt, + |count(p_size) over(distribute by p_mfgr sort by p_name) as ca, + |avg(p_size) over(distribute by p_mfgr sort by p_name) as avg, + |stddev(p_size) over(distribute by p_mfgr sort by p_name) as st, + |first_value(p_size % 5) over(distribute by p_mfgr sort by p_name) as fv, + |last_value(p_size) over(distribute by p_mfgr sort by p_name) as lv, + |first_value(p_size) over w1 as fvW1 + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17), + Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2))) + // scalastyle:on + } + + test("windowing.q -- 20. testSTATs") { + // Moved because: + // - Spark uses a different default stddev/variance (sample instead of pop) + // - Tiny numerical differences in aggregation results. + checkAnswer(sql(""" + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( + |select p_mfgr,p_name, p_size, + |stddev_pop(p_retailprice) over w1 as sdev, + |stddev_pop(p_retailprice) over w1 as sdev_pop, + |collect_set(p_size) over w1 as uniq_size, + |var_pop(p_retailprice) over w1 as var, + |corr(p_size, p_retailprice) over w1 as cor, + |covar_pop(p_size, p_retailprice) over w1 as covarp + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 2, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 6, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 34, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 2, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 34, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 2, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 6, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 28, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 34, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 2, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 6, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 28, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 34, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 42, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 6, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 28, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 34, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 42, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 6, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 28, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 42, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 2, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 14, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 40, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 2, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 14, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 25, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 40, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 2, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 14, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 18, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 25, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 40, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 2, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 18, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 25, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 40, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 2, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 18, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 25, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 14, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 17, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 19, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 1, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 14, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 17, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 19, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 1, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 14, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 17, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 19, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 45, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 1, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 14, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 19, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 45, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 1, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 19, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 45, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 10, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 27, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 39, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 7, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 10, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 27, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 39, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 7, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 10, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 12, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 27, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 39, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 7, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 12, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 27, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 39, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 7, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 12, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 27, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 2, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 6, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 31, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 2, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 6, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 31, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 46, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 2, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 6, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 23, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 31, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 46, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 2, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 6, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 23, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 46, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 2, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 23, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) + // scalastyle:on + } + + test("null arguments") { + checkAnswer(sql(""" + |select p_mfgr, p_name, p_size, + |sum(null) over(distribute by p_mfgr sort by p_name) as sum, + |avg(null) over(distribute by p_mfgr sort by p_name) as avg + |from part + """.stripMargin), + sql(""" + |select p_mfgr, p_name, p_size, + |null as sum, + |null as avg + |from part + """.stripMargin)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala new file mode 100644 index 0000000000000..7b0c7a9f00514 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} + +import org.apache.spark.sql.{Column, DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.sources.HadoopFsRelation + +/** + * A test suite that tests ORC filter API based filter pushdown optimization. + */ +class OrcFilterSuite extends QueryTest with OrcTest { + private def checkFilterPredicate( + df: DataFrame, + predicate: Predicate, + checker: (SearchArgument) => Unit): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(selectedFilters.toArray) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") + checker(maybeFilter.get) + } + + private def checkFilterPredicate + (predicate: Predicate, filterOperator: PredicateLeaf.Operator) + (implicit df: DataFrame): Unit = { + def checkComparisonOperator(filter: SearchArgument) = { + val operator = filter.getLeaves.asScala + assert(operator.map(_.getOperator).contains(filterOperator)) + } + checkFilterPredicate(df, predicate, checkComparisonOperator) + } + + private def checkFilterPredicate + (predicate: Predicate, stringExpr: String) + (implicit df: DataFrame): Unit = { + def checkLogicalOperator(filter: SearchArgument) = { + assert(filter.toString == stringExpr) + } + checkFilterPredicate(df, predicate, checkLogicalOperator) + } + + test("filter pushdown - boolean") { + withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + } + } + + test("filter pushdown - integer") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - long") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - float") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - double") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - string") { + withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === "1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < "2", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= "4", PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal("1") === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal("1") <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal("2") > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal("3") < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("1") >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("4") <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - binary") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) + } + + withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + } + } + + test("filter pushdown - combinations with logical operators") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked + // in string form in order to check filter creation including logical operators + // such as `and`, `or` or `not`. So, this function uses `SearchArgument.toString()` + // to produce string expression and then compare it to given string expression below. + // This might have to be changed after Hive version is upgraded. + checkFilterPredicate( + '_1.isNotNull, + """leaf-0 = (IS_NULL _1) + |expr = (not leaf-0)""".stripMargin.trim + ) + checkFilterPredicate( + '_1 =!= 1, + """leaf-0 = (IS_NULL _1) + |leaf-1 = (EQUALS _1 1) + |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + !('_1 < 4), + """leaf-0 = (IS_NULL _1) + |leaf-1 = (LESS_THAN _1 4) + |expr = (and (not leaf-0) (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + '_1 < 2 || '_1 > 3, + """leaf-0 = (LESS_THAN _1 2) + |leaf-1 = (LESS_THAN_EQUALS _1 3) + |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim + ) + checkFilterPredicate( + '_1 < 2 && '_1 > 3, + """leaf-0 = (IS_NULL _1) + |leaf-1 = (LESS_THAN _1 2) + |leaf-2 = (LESS_THAN_EQUALS _1 3) + |expr = (and (not leaf-0) leaf-1 (not leaf-2))""".stripMargin.trim + ) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 92043d66c914f..2345c1cf9cc09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.hive.orc -import org.apache.hadoop.fs.Path +import java.io.File + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.ql.io.orc.{CompressionKind, OrcFile} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ @@ -60,4 +65,47 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } } + + test("SPARK-12218: 'Not' is included in ORC filter pushdown") { + import testImplicits._ + + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.orc(path) + + checkAnswer( + sqlContext.read.orc(path).where("not (a = 2) or not(b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + + checkAnswer( + sqlContext.read.orc(path).where("not (a = 2 and b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + } + } + } + + test("SPARK-13543: Support for specifying compression codec for ORC via option()") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + val df = (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b") + df.write + .option("compression", "ZlIb") + .orc(path) + + // Check if this is compressed as ZLIB. + val conf = sparkContext.hadoopConfiguration + val fs = FileSystem.getLocal(conf) + val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(".zlib.orc")) + assert(maybeOrcFile.isDefined) + val orcFilePath = new Path(maybeOrcFile.get.toPath.toString) + val orcReader = OrcFile.createReader(orcFilePath, OrcFile.readerOptions(conf)) + assert(orcReader.getCompression == CompressionKind.ZLIB) + + val copyDf = sqlContext + .read + .orc(path) + checkAnswer(df, copyDf) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 52e09f9496f05..6161412a49775 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -22,8 +22,8 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 7efeab528c1dd..5ef8194f28881 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.orc import java.io.File +import java.nio.charset.StandardCharsets import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.CompressionKind @@ -25,8 +26,11 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.internal.SQLConf case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -70,14 +74,14 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("Read/write binary data") { - withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => + withOrcFile(BinaryData("test".getBytes(StandardCharsets.UTF_8)) :: Nil) { file => val bytes = read.orc(file).head().getAs[Array[Byte]](0) - assert(new String(bytes, "utf8") === "test") + assert(new String(bytes, StandardCharsets.UTF_8) === "test") } } test("Read/write all types with non-primitive type") { - val data = (0 to 255).map { i => + val data: Seq[AllDataTypesWithNonPrimitiveType] = (0 to 255).map { i => AllDataTypesWithNonPrimitiveType( s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, 0 until i, @@ -118,6 +122,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // expr = (not leaf-0) assertResult(10) { sql("SELECT name, contacts FROM t where age > 5") + .rdd .flatMap(_.getAs[Seq[_]]("contacts")) .count() } @@ -130,7 +135,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") assert(df.count() === 2) assertResult(4) { - df.flatMap(_.getAs[Seq[_]]("contacts")).count() + df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count() } } @@ -142,7 +147,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") assert(df.count() === 3) assertResult(6) { - df.flatMap(_.getAs[Seq[_]]("contacts")).count() + df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count() } } } @@ -219,7 +224,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } - catalog.unregisterTable(TableIdentifier("tmp")) + sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true) } test("overwriting") { @@ -229,7 +234,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(table("t"), data.map(Row.fromTuple)) } - catalog.unregisterTable(TableIdentifier("tmp")) + sessionState.catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = true) } test("self-join") { @@ -328,7 +333,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { sqlContext.read.orc(path) }.getMessage - assert(errorMessage.contains("Failed to discover schema from ORC files")) + assert(errorMessage.contains("Unable to infer schema for ORC")) val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) singleRowDF.registerTempTable("single") @@ -350,28 +355,87 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { dir => withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { import testImplicits._ - val path = dir.getCanonicalPath - sqlContext.range(10).coalesce(1).write.orc(path) + + // For field "a", the first column has odds integers. This is to check the filtered count + // when `isNull` is performed. For Field "b", `isNotNull` of ORC file filters rows + // only when all the values are null (maybe this works differently when the data + // or query is complicated). So, simply here a column only having `null` is added. + val data = (0 until 10).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + val nullValue: Option[String] = None + (maybeInt, nullValue) + } + // It needs to repartition data so that we can have several ORC files + // in order to skip stripes in ORC. + createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path) val df = sqlContext.read.orc(path) - def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { - checkAnswer(df.where(pred), answer.map(Row(_))) + def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { + val sourceDf = stripSparkFilter(df.where(pred)) + val data = sourceDf.collect().toSet + val expectedData = answer.toSet + + // When a filter is pushed to ORC, ORC can apply it to rows. So, we can check + // the number of rows returned from the ORC to make sure our filter pushdown work. + // A tricky part is, ORC does not process filter rows fully but return some possible + // results. So, this checks if the number of result is less than the original count + // of data, and then checks if it contains the expected data. + assert( + sourceDf.count < 10 && expectedData.subsetOf(data), + s"No data was filtered for predicate: $pred") } - checkPredicate('id === 5, Seq(5L)) - checkPredicate('id <=> 5, Seq(5L)) - checkPredicate('id < 5, 0L to 4L) - checkPredicate('id <= 5, 0L to 5L) - checkPredicate('id > 5, 6L to 9L) - checkPredicate('id >= 5, 5L to 9L) - checkPredicate('id.isNull, Seq.empty[Long]) - checkPredicate('id.isNotNull, 0L to 9L) - checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) - checkPredicate('id > 0 && 'id < 3, 1L to 2L) - checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) - checkPredicate(!('id > 3), 0L to 3L) - checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + checkPredicate('a === 5, List(5).map(Row(_, null))) + checkPredicate('a <=> 5, List(5).map(Row(_, null))) + checkPredicate('a < 5, List(1, 3).map(Row(_, null))) + checkPredicate('a <= 5, List(1, 3, 5).map(Row(_, null))) + checkPredicate('a > 5, List(7, 9).map(Row(_, null))) + checkPredicate('a >= 5, List(5, 7, 9).map(Row(_, null))) + checkPredicate('a.isNull, List(null).map(Row(_, null))) + checkPredicate('b.isNotNull, List()) + checkPredicate('a.isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) + checkPredicate('a > 0 && 'a < 3, List(1).map(Row(_, null))) + checkPredicate('a < 1 || 'a > 8, List(9).map(Row(_, null))) + checkPredicate(!('a > 3), List(1, 3).map(Row(_, null))) + checkPredicate(!('a > 0 && 'a < 3), List(3, 5, 7, 9).map(Row(_, null))) + } + } + } + + test("SPARK-14070 Use ORC data source for SQL queries on ORC tables") { + withTempPath { dir => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true", + HiveContext.CONVERT_METASTORE_ORC.key -> "true") { + val path = dir.getCanonicalPath + + withTable("dummy_orc") { + withTempTable("single") { + sqlContext.sql( + s"""CREATE TABLE dummy_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '$path' + """.stripMargin) + + val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) + singleRowDF.registerTempTable("single") + + sqlContext.sql( + s"""INSERT INTO TABLE dummy_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = sqlContext.sql("SELECT * FROM dummy_orc WHERE key=0") + checkAnswer(df, singleRowDF) + + val queryExecution = df.queryExecution + queryExecution.analyzed.collectFirst { + case _: LogicalRelation => () + }.getOrElse { + fail(s"Expecting the query plan to have LogicalRelation, but got:\n$queryExecution") + } + } + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 7a34cf731b4c5..bdd3428a89742 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.sources._ case class OrcData(intField: Int, stringField: String) @@ -67,8 +68,12 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } override def afterAll(): Unit = { - orcTableDir.delete() - orcTableAsDir.delete() + try { + orcTableDir.delete() + orcTableAsDir.delete() + } finally { + super.afterAll() + } } test("create temporary orc table") { @@ -130,17 +135,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA val df = sql( """ |SELECT - | CAST(null as TINYINT), - | CAST(null as SMALLINT), - | CAST(null as INT), - | CAST(null as BIGINT), - | CAST(null as FLOAT), - | CAST(null as DOUBLE), - | CAST(null as DECIMAL(7,2)), - | CAST(null as TIMESTAMP), - | CAST(null as DATE), - | CAST(null as STRING), - | CAST(null as VARCHAR(10)) + | CAST(null as TINYINT) as c0, + | CAST(null as SMALLINT) as c1, + | CAST(null as INT) as c2, + | CAST(null as BIGINT) as c3, + | CAST(null as FLOAT) as c4, + | CAST(null as DOUBLE) as c5, + | CAST(null as DECIMAL(7,2)) as c6, + | CAST(null as TIMESTAMP) as c7, + | CAST(null as DATE) as c8, + | CAST(null as STRING) as c9, + | CAST(null as VARCHAR(10)) as c10 |FROM orc_temp_table limit 1 """.stripMargin) @@ -174,4 +179,33 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) } + + test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + // The `LessThan` should be converted while the `StringContains` shouldn't + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(Array( + LessThan("a", 10), + StringContains("b", "prefix") + )).get.toString + } + + // The `LessThan` should be converted while the whole inner `And` shouldn't + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(Array( + LessThan("a", 10), + Not(And( + GreaterThan("a", 1), + StringContains("b", "prefix") + )) + )).get.toString + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 88a0ed511749f..637c10611afc6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql._ -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { import testImplicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 905eb7a3925b2..eac65d5720575 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -20,11 +20,14 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.DataSourceScan +import org.apache.spark.sql.execution.command.ExecutedCommand import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -55,6 +58,7 @@ case class ParquetDataWithKeyAndComplexTypes( */ class ParquetMetastoreSuite extends ParquetPartitioningTest { import hiveContext._ + import hiveContext.implicits._ override def beforeAll(): Unit = { super.beforeAll() @@ -168,10 +172,8 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)") } - val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - read.json(rdd1).registerTempTable("jt") - val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) - read.json(rdd2).registerTempTable("jt_array") + (1 to 10).map(i => (i, s"str$i")).toDF("a", "b").registerTempTable("jt") + (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a").registerTempTable("jt_array") setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) } @@ -190,12 +192,12 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test(s"conversion is working") { assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + sql("SELECT * FROM normal_parquet").queryExecution.sparkPlan.collect { case _: HiveTableScan => true }.isEmpty) assert( - sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: PhysicalRDD => true + sql("SELECT * FROM normal_parquet").queryExecution.sparkPlan.collect { + case _: DataSourceScan => true }.nonEmpty) } @@ -282,10 +284,10 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: ParquetRelation, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _) => // OK case _ => fail( "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation].getCanonicalName }") + s"${classOf[HadoopFsRelation ].getCanonicalName }") } } } @@ -305,11 +307,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK + df.queryExecution.sparkPlan match { + case ExecutedCommand(_: InsertIntoHadoopFsRelation) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + + s"${classOf[HadoopFsRelation ].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan. " + s"However, found a ${o.toString} ") } @@ -335,11 +337,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK + df.queryExecution.sparkPlan match { + case ExecutedCommand(_: InsertIntoHadoopFsRelation) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"${classOf[HadoopFsRelation ].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan." + s"However, found a ${o.toString} ") } @@ -369,18 +371,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { assertResult(2) { analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation, _) => r + case r @ LogicalRelation(_: HadoopFsRelation, _, _) => r }.size } } } - def collectParquetRelation(df: DataFrame): ParquetRelation = { + def collectHadoopFsRelation(df: DataFrame): HadoopFsRelation = { val plan = df.queryExecution.analyzed plan.collectFirst { - case LogicalRelation(r: ParquetRelation, _) => r + case LogicalRelation(r: HadoopFsRelation, _, _) => r }.getOrElse { - fail(s"Expecting a ParquetRelation2, but got:\n$plan") + fail(s"Expecting a HadoopFsRelation 2, but got:\n$plan") } } @@ -395,9 +397,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // First lookup fills the cache - val r1 = collectParquetRelation(table("nonPartitioned")) + val r1 = collectHadoopFsRelation (table("nonPartitioned")) // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("nonPartitioned")) + val r2 = collectHadoopFsRelation (table("nonPartitioned")) // They should be the same instance assert(r1 eq r2) } @@ -415,20 +417,20 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // First lookup fills the cache - val r1 = collectParquetRelation(table("partitioned")) + val r1 = collectHadoopFsRelation (table("partitioned")) // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("partitioned")) + val r2 = collectHadoopFsRelation (table("partitioned")) // They should be the same instance assert(r1 eq r2) } } test("Caching converted data source Parquet Relations") { - def checkCached(tableIdentifier: catalog.QualifiedTableName): Unit = { + def checkCached(tableIdentifier: TableIdentifier): Unit = { // Converted test_parquet should be cached. - catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { + sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: ParquetRelation, _) => // OK + case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + @@ -451,17 +453,17 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - var tableIdentifier = catalog.QualifiedTableName("default", "test_insert_parquet") + var tableIdentifier = TableIdentifier("test_insert_parquet", Some("default")) // First, make sure the converted test_parquet is not cached. - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) // Table lookup will make the table cached. table("test_insert_parquet") checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. invalidateTable("test_insert_parquet") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_insert_parquet @@ -474,7 +476,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql("select a, b from jt").collect()) // Invalidate the cache. invalidateTable("test_insert_parquet") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) // Create a partitioned table. sql( @@ -491,8 +493,8 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - tableIdentifier = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + tableIdentifier = TableIdentifier("test_parquet_partitioned_cache_test", Some("default")) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test @@ -501,14 +503,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) // Make sure we can cache the partitioned table. table("test_parquet_partitioned_cache_test") @@ -524,7 +526,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin).collect()) invalidateTable("test_parquet_partitioned_cache_test") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) + assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } @@ -590,7 +592,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { sql("drop table if exists spark_6016_fix") // Create a DataFrame with two partitions. So, the created table will have two parquet files. - val df1 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + val df1 = (1 to 10).map(Tuple1(_)).toDF("a").coalesce(2) df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") checkAnswer( sql("select * from spark_6016_fix"), @@ -598,7 +600,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { ) // Create a DataFrame with four partitions. So, the created table will have four parquet files. - val df2 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + val df2 = (1 to 10).map(Tuple1(_)).toDF("b").coalesce(4) df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, // since the new table has four parquet files, we are trying to read new footers from two files @@ -626,7 +628,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest { sql( s"""CREATE TABLE array_of_struct |STORED AS PARQUET LOCATION '$path' - |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) + |AS SELECT + | '1st' AS a, + | '2nd' AS b, + | ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c """.stripMargin) checkAnswer( @@ -695,6 +700,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with var partitionedTableDirWithKeyAndComplexTypes: File = null override def beforeAll(): Unit = { + super.beforeAll() partitionedTableDir = Utils.createTempDir() normalTableDir = Utils.createTempDir() @@ -752,6 +758,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with /** * Drop named tables if they exist + * * @param tableNames tables to drop */ def dropTables(tableNames: String*): Unit = { @@ -849,7 +856,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with test(s"hive udfs $table") { checkAnswer( sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { + sql(s"SELECT stringField FROM $table").rdd.map { case Row(s: String) => Row(s + s) }.collect().toSeq) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala new file mode 100644 index 0000000000000..a0be55cfba94c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.DataSourceScan +import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy} +import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.joins.SortMergeJoin +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.BitSet + +class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + private val nullDF = (for { + i <- 0 to 50 + s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") + } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + + test("read bucketed data") { + withTable("bucketed_table") { + df.write + .format("parquet") + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + for (i <- 0 until 5) { + val table = hiveContext.table("bucketed_table").filter($"i" === i) + val query = table.queryExecution + val output = query.analyzed.output + val rdd = query.toRdd + + assert(rdd.partitions.length == 8) + + val attrs = table.select("j", "k").queryExecution.analyzed.output + val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => { + val getBucketId = UnsafeProjection.create( + HashPartitioning(attrs, 8).partitionIdExpression :: Nil, + output) + rows.map(row => getBucketId(row).getInt(0) -> index) + }) + checkBucketId.collect().foreach(r => assert(r._1 == r._2)) + } + } + } + + // To verify if the bucket pruning works, this function checks two conditions: + // 1) Check if the pruned buckets (before filtering) are empty. + // 2) Verify the final result is the same as the expected one + private def checkPrunedAnswers( + bucketSpec: BucketSpec, + bucketValues: Seq[Integer], + filterCondition: Column, + originalDataFrame: DataFrame): Unit = { + // This test verifies parts of the plan. Disable whole stage codegen. + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k") + val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec + // Limit: bucket pruning only works when the bucket column has one and only one column + assert(bucketColumnNames.length == 1) + val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) + val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + } + + // Filter could hide the bug in bucket pruning. Thus, skipping all the filters + val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan + val rdd = plan.find(_.isInstanceOf[DataSourceScan]) + assert(rdd.isDefined, plan) + + val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator() + } + // TODO: These tests are not testing the right columns. +// // checking if all the pruned buckets are empty +// val invalidBuckets = checkedResult.collect().toList +// if (invalidBuckets.nonEmpty) { +// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") +// } + + checkAnswer( + bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), + originalDataFrame.filter(filterCondition).orderBy("i", "j", "k")) + } + } + + test("read partitioning bucketed tables with bucket pruning filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + // Case 1: EqualTo + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j, + df) + + // Case 2: EqualNullSafe + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" <=> j, + df) + + // Case 3: In + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1, j + 2, j + 3), + filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), + df) + } + } + } + + test("read non-partitioning bucketed tables with bucket pruning filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j, + df) + } + } + } + + test("read partitioning bucketed tables having null in bucketing key") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + nullDF.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + // Case 1: isNull + checkPrunedAnswers( + bucketSpec, + bucketValues = null :: Nil, + filterCondition = $"j".isNull, + nullDF) + + // Case 2: <=> null + checkPrunedAnswers( + bucketSpec, + bucketValues = null :: Nil, + filterCondition = $"j" <=> null, + nullDF) + } + } + + test("read partitioning bucketed tables having composite filters") { + withTable("bucketed_table") { + val numBuckets = 8 + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .partitionBy("i") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + for (j <- 0 until 13) { + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j && $"k" > $"j", + df) + + checkPrunedAnswers( + bucketSpec, + bucketValues = j :: Nil, + filterCondition = $"j" === j && $"i" > j % 5, + df) + } + } + } + + private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") + private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") + + /** + * A helper method to test the bucket read functionality using join. It will save `df1` and `df2` + * to hive tables, bucketed or not, according to the given bucket specifics. Next we will join + * these 2 tables, and firstly make sure the answer is corrected, and then check if the shuffle + * exists as user expected according to the `shuffleLeft` and `shuffleRight`. + */ + private def testBucketing( + bucketSpecLeft: Option[BucketSpec], + bucketSpecRight: Option[BucketSpec], + joinColumns: Seq[String], + shuffleLeft: Boolean, + shuffleRight: Boolean): Unit = { + withTable("bucketed_table1", "bucketed_table2") { + def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = { + bucketSpec.map { spec => + writer.bucketBy( + spec.numBuckets, + spec.bucketColumnNames.head, + spec.bucketColumnNames.tail: _*) + }.getOrElse(writer) + } + + withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1") + withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val t1 = hiveContext.table("bucketed_table1") + val t2 = hiveContext.table("bucketed_table2") + val joined = t1.join(t2, joinCondition(t1, t2, joinColumns)) + + // First check the result is corrected. + checkAnswer( + joined.sort("bucketed_table1.k", "bucketed_table2.k"), + df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k")) + + assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin]) + val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin] + + assert( + joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft, + s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") + assert( + joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight, + s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") + } + } + } + + private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = { + joinCols.map(col => left(col) === right(col)).reduce(_ && _) + } + + test("avoid shuffle when join 2 bucketed tables") { + val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) + } + + // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704 + ignore("avoid shuffle when join keys are a super-set of bucket keys") { + val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) + testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) + } + + test("only shuffle one side when join bucketed table and non-bucketed table") { + val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) + } + + test("only shuffle one side when 2 bucketed tables have different bucket number") { + val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil)) + val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil)) + testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) + } + + test("only shuffle one side when 2 bucketed tables have different bucket keys") { + val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil)) + val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil)) + testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true) + } + + test("shuffle when join keys are not equal to bucket keys") { + val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) + testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true) + } + + test("shuffle when join 2 bucketed tables with bucketing disabled") { + val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) + withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { + testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true) + } + } + + test("avoid shuffle when grouping keys are equal to bucket keys") { + withTable("bucketed_table") { + df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("bucketed_table") + val tbl = hiveContext.table("bucketed_table") + val agged = tbl.groupBy("i", "j").agg(max("k")) + + checkAnswer( + agged.sort("i", "j"), + df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) + + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + } + } + + test("avoid shuffle when grouping keys are a super-set of bucket keys") { + withTable("bucketed_table") { + df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") + val tbl = hiveContext.table("bucketed_table") + val agged = tbl.groupBy("i", "j").agg(max("k")) + + checkAnswer( + agged.sort("i", "j"), + df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) + + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + } + } + + test("error if there exists any malformed bucket files") { + withTable("bucketed_table") { + df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + Utils.deleteRecursively(tableDir) + df1.write.parquet(tableDir.getAbsolutePath) + + val agged = hiveContext.table("bucketed_table").groupBy("i").count() + val error = intercept[RuntimeException] { + agged.count() + } + + assert(error.toString contains "Invalid bucket file") + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala new file mode 100644 index 0000000000000..a3e7737a7c059 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File +import java.net.URI + +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.datasources.BucketingUtils +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + test("bucketed by non-existing column") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt")) + } + + test("numBuckets not greater than 0 or less than 100000") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.bucketBy(0, "i").saveAsTable("tt")) + intercept[IllegalArgumentException](df.write.bucketBy(100000, "i").saveAsTable("tt")) + } + + test("specify sorting columns without bucketing columns") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + } + + test("sorting by non-orderable column") { + val df = Seq("a" -> Map(1 -> 1), "b" -> Map(2 -> 2)).toDF("i", "j") + intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt")) + } + + test("write bucketed data to unsupported data source") { + val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i") + intercept[SparkException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt")) + } + + test("write bucketed data to non-hive-table or existing hive table") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").parquet("/tmp/path")) + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").json("/tmp/path")) + intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt")) + } + + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + + def tableDir: File = { + val identifier = hiveContext.sessionState.sqlParser.parseTableIdentifier("bucketed_table") + new File(URI.create(hiveContext.sessionState.catalog.hiveDefaultTableFilePath(identifier))) + } + + /** + * A helper method to check the bucket write functionality in low level, i.e. check the written + * bucket files to see if the data are correct. User should pass in a data dir that these bucket + * files are written to, and the format of data(parquet, json, etc.), and the bucketing + * information. + */ + private def testBucketing( + dataDir: File, + source: String, + numBuckets: Int, + bucketCols: Seq[String], + sortCols: Seq[String] = Nil): Unit = { + val allBucketFiles = dataDir.listFiles().filterNot(f => + f.getName.startsWith(".") || f.getName.startsWith("_") + ) + + for (bucketFile <- allBucketFiles) { + val bucketId = BucketingUtils.getBucketId(bucketFile.getName).getOrElse { + fail(s"Unable to find the related bucket files.") + } + + // Remove the duplicate columns in bucketCols and sortCols; + // Otherwise, we got analysis errors due to duplicate names + val selectedColumns = (bucketCols ++ sortCols).distinct + // We may lose the type information after write(e.g. json format doesn't keep schema + // information), here we get the types from the original dataframe. + val types = df.select(selectedColumns.map(col): _*).schema.map(_.dataType) + val columns = selectedColumns.zip(types).map { + case (colName, dt) => col(colName).cast(dt) + } + + // Read the bucket file into a dataframe, so that it's easier to test. + val readBack = sqlContext.read.format(source) + .load(bucketFile.getAbsolutePath) + .select(columns: _*) + + // If we specified sort columns while writing bucket table, make sure the data in this + // bucket file is already sorted. + if (sortCols.nonEmpty) { + checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect()) + } + + // Go through all rows in this bucket file, calculate bucket id according to bucket column + // values, and make sure it equals to the expected bucket id that inferred from file name. + val qe = readBack.select(bucketCols.map(col): _*).queryExecution + val rows = qe.toRdd.map(_.copy()).collect() + val getBucketId = UnsafeProjection.create( + HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil, + qe.analyzed.output) + + for (row <- rows) { + val actualBucketId = getBucketId(row).getInt(0) + assert(actualBucketId == bucketId) + } + } + } + + test("write bucketed data") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k")) + } + } + } + } + + test("write bucketed data with sortBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k")) + } + } + } + } + + test("write bucketed data with the overlapping bucketBy and partitionBy columns") { + intercept[AnalysisException](df.write + .partitionBy("i", "j") + .bucketBy(8, "j", "k") + .sortBy("k") + .saveAsTable("bucketed_table")) + } + + test("write bucketed data with the identical bucketBy and partitionBy columns") { + intercept[AnalysisException](df.write + .partitionBy("i") + .bucketBy(8, "i") + .saveAsTable("bucketed_table")) + } + + test("write bucketed data without partitionBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .saveAsTable("bucketed_table") + + testBucketing(tableDir, source, 8, Seq("i", "j")) + } + } + } + + test("write bucketed data without partitionBy with sortBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k")) + } + } + } + + test("write bucketed data with bucketing disabled") { + // The configuration BUCKETING_ENABLED does not affect the writing path + withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k")) + } + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala deleted file mode 100644 index dc0531a6d4bc5..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import org.apache.hadoop.fs.Path -import org.apache.spark.SparkException -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils - - -class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton { - - // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. - val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName - - test("SPARK-7684: commitTask() failure should fallback to abortTask()") { - withTempPath { file => - // Here we coalesce partition number to 1 to ensure that only a single task is issued. This - // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` - // directory while committing/aborting the job. See SPARK-8513 for more details. - val df = sqlContext.range(0, 10).coalesce(1) - intercept[SparkException] { - df.write.format(dataSourceName).save(file.getCanonicalPath) - } - - val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) - assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index e2d754e806403..a15bd227a9201 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,7 +23,8 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{execution, AnalysisException, SaveMode} +import org.apache.spark.sql._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -125,7 +126,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("SPARK-8604: Parquet data source should write summary file while doing appending") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(0, 5) + val df = sqlContext.range(0, 5).toDF() df.write.mode(SaveMode.Overwrite).parquet(path) val summaryPath = new Path(path, "_metadata") @@ -136,7 +137,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { fs.delete(commonSummaryPath, true) df.write.mode(SaveMode.Append).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + checkAnswer(sqlContext.read.parquet(path), df.union(df)) assert(fs.exists(summaryPath)) assert(fs.exists(commonSummaryPath)) @@ -149,10 +150,82 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path) val df = sqlContext.read.parquet(path).filter('a === 0).select('b) - val physicalPlan = df.queryExecution.executedPlan + val physicalPlan = df.queryExecution.sparkPlan assert(physicalPlan.collect { case p: execution.Project => p }.length === 1) assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) } } + + test("SPARK-11500: Not deterministic order of columns when using merging schemas.") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + withTempPath { dir => + val pathOne = s"${dir.getCanonicalPath}/part=1" + Seq(1, 1).zipWithIndex.toDF("a", "b").write.parquet(pathOne) + val pathTwo = s"${dir.getCanonicalPath}/part=2" + Seq(1, 1).zipWithIndex.toDF("c", "b").write.parquet(pathTwo) + val pathThree = s"${dir.getCanonicalPath}/part=3" + Seq(1, 1).zipWithIndex.toDF("d", "b").write.parquet(pathThree) + + // The schema consists of the leading columns of the first part-file + // in the lexicographic order. + assert(sqlContext.read.parquet(dir.getCanonicalPath).schema.map(_.name) + === Seq("a", "b", "c", "d", "part")) + } + } + } + + test(s"SPARK-13537: Fix readBytes in VectorizedPlainValuesReader") { + withTempPath { file => + val path = file.getCanonicalPath + + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", ByteType, nullable = true) + + val data = Seq(Row(1, -33.toByte), Row(2, 0.toByte), Row(3, -55.toByte), Row(4, 56.toByte), + Row(5, 127.toByte), Row(6, -44.toByte), Row(7, 23.toByte), Row(8, -95.toByte), + Row(9, 127.toByte), Row(10, 13.toByte)) + + val rdd = sqlContext.sparkContext.parallelize(data) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) + + val loadedDF = sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") + + checkAnswer(loadedDF, df) + } + } + + test("SPARK-13543: Support for specifying compression codec for Parquet via option()") { + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "UNCOMPRESSED") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + val df = (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b") + df.write + .option("compression", "GzIP") + .parquet(path) + + val compressedFiles = new File(path).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".gz.parquet"))) + + val copyDf = sqlContext + .read + .parquet(path) + checkAnswer(df, copyDf) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala deleted file mode 100644 index d945408341fc9..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.execution.PhysicalRDD -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ - -class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - import testImplicits._ - - override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName - - // We have a very limited number of supported types at here since it is just for a - // test relation and we do very basic testing at here. - override protected def supportsDataType(dataType: DataType): Boolean = dataType match { - case _: BinaryType => false - // We are using random data generator and the generated strings are not really valid string. - case _: StringType => false - case _: BooleanType => false // see https://issues.apache.org/jira/browse/SPARK-10442 - case _: CalendarIntervalType => false - case _: DateType => false - case _: TimestampType => false - case _: ArrayType => false - case _: MapType => false - case _: StructType => false - case _: UserDefinedType[_] => false - case _ => true - } - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - hiveContext.read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } - - private val writer = testDF.write.option("dataSchema", dataSchema.json).format(dataSourceName) - private val reader = sqlContext.read.option("dataSchema", dataSchema.json).format(dataSourceName) - - test("unhandledFilters") { - withTempPath { dir => - - val path = dir.getCanonicalPath - writer.save(s"$path/p=0") - writer.save(s"$path/p=1") - - val isOdd = udf((_: Int) % 2 == 1) - val df = reader.load(path) - .filter( - // This filter is inconvertible - isOdd('a) && - // This filter is convertible but unhandled - 'a > 1 && - // This filter is convertible and handled - 'b > "val_1" && - // This filter references a partiiton column, won't be pushed down - 'p === 1 - ).select('a, 'p) - val rawScan = df.queryExecution.executedPlan collect { - case p: PhysicalRDD => p - } match { - case Seq(p) => p - } - - val outputSchema = new StructType().add("a", IntegerType).add("p", IntegerType) - - assertResult(Set((2, 1), (3, 1))) { - rawScan.execute().collect() - .map { CatalystTypeConverters.convertToScala(_, outputSchema) } - .map { case Row(a, p) => (a, p) }.toSet - } - - checkAnswer(df, Row(3, 1)) - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala deleted file mode 100644 index da09e1b00ae48..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import java.text.NumberFormat - -import com.google.common.base.Objects -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} - -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.{Row, SQLContext, sources} - -/** - * A simple example [[HadoopFsRelationProvider]]. - */ -class SimpleTextSource extends HadoopFsRelationProvider { - override def createRelation( - sqlContext: SQLContext, - paths: Array[String], - schema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - new SimpleTextRelation(paths, schema, partitionColumns, parameters)(sqlContext) - } -} - -class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullWritable, Text] { - val numberFormat = NumberFormat.getInstance() - - numberFormat.setMinimumIntegerDigits(5) - numberFormat.setGroupingUsed(false) - - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) - val split = taskAttemptId.getTaskID.getId - val name = FileOutputFormat.getOutputName(context) - new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") - } -} - -class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { - private val recordWriter: RecordWriter[NullWritable, Text] = - new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) - - override def write(row: Row): Unit = { - val serialized = row.toSeq.map { v => - if (v == null) "" else v.toString - }.mkString(",") - recordWriter.write(null, new Text(serialized)) - } - - override def close(): Unit = { - recordWriter.close(context) - } -} - -/** - * A simple example [[HadoopFsRelation]], used for testing purposes. Data are stored as comma - * separated string lines. When scanning data, schema must be explicitly provided via data source - * option `"dataSchema"`. - */ -class SimpleTextRelation( - override val paths: Array[String], - val maybeDataSchema: Option[StructType], - override val userDefinedPartitionColumns: Option[StructType], - parameters: Map[String, String])( - @transient val sqlContext: SQLContext) - extends HadoopFsRelation { - - import sqlContext.sparkContext - - override val dataSchema: StructType = - maybeDataSchema.getOrElse(DataType.fromJson(parameters("dataSchema")).asInstanceOf[StructType]) - - override def equals(other: Any): Boolean = other match { - case that: SimpleTextRelation => - this.paths.sameElements(that.paths) && - this.maybeDataSchema == that.maybeDataSchema && - this.dataSchema == that.dataSchema && - this.partitionColumns == that.partitionColumns - - case _ => false - } - - override def hashCode(): Int = - Objects.hashCode(paths, maybeDataSchema, dataSchema, partitionColumns) - - override def buildScan(inputStatuses: Array[FileStatus]): RDD[Row] = { - val fields = dataSchema.map(_.dataType) - - sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => - Row(record.split(",", -1).zip(fields).map { case (v, dataType) => - val value = if (v == "") null else v - // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) - val catalystValue = Cast(Literal(value), dataType).eval() - // Here we're converting Catalyst values to Scala values to test `needsConversion` - CatalystTypeConverters.convertToScala(catalystValue, dataType) - }: _*) - } - } - - override def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputFiles: Array[FileStatus]): RDD[Row] = { - - val fields = this.dataSchema.map(_.dataType) - val inputAttributes = this.dataSchema.toAttributes - val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name)) - val dataSchema = this.dataSchema - - val inputPaths = inputFiles.map(_.getPath).mkString(",") - sparkContext.textFile(inputPaths).mapPartitions { iterator => - // Constructs a filter predicate to simulate filter push-down - val predicate = { - val filterCondition: Expression = filters.collect { - // According to `unhandledFilters`, `SimpleTextRelation` only handles `GreaterThan` filter - case sources.GreaterThan(column, value) => - val dataType = dataSchema(column).dataType - val literal = Literal.create(value, dataType) - val attribute = inputAttributes.find(_.name == column).get - expressions.GreaterThan(attribute, literal) - }.reduceOption(expressions.And).getOrElse(Literal(true)) - InterpretedPredicate.create(filterCondition, inputAttributes) - } - - // Uses a simple projection to simulate column pruning - val projection = new InterpretedMutableProjection(outputAttributes, inputAttributes) - val toScala = { - val requiredSchema = StructType.fromAttributes(outputAttributes) - CatalystTypeConverters.createToScalaConverter(requiredSchema) - } - - iterator.map { record => - new GenericInternalRow(record.split(",", -1).zip(fields).map { - case (v, dataType) => - val value = if (v == "") null else v - // `Cast`ed values are always of internal types (e.g. UTF8String instead of String) - Cast(Literal(value), dataType).eval() - }) - }.filter { row => - predicate(row) - }.map { row => - toScala(projection(row)).asInstanceOf[Row] - } - } - } - - override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { - job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) - - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) - } - } - - // `SimpleTextRelation` only handles `GreaterThan` filter. This is used to test filter push-down - // and `BaseRelation.unhandledFilters()`. - override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { - filters.filter { - case _: GreaterThan => false - case _ => true - } - } -} - -/** - * A simple example [[HadoopFsRelationProvider]]. - */ -class CommitFailureTestSource extends HadoopFsRelationProvider { - override def createRelation( - sqlContext: SQLContext, - paths: Array[String], - schema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - new CommitFailureTestRelation(paths, schema, partitionColumns, parameters)(sqlContext) - } -} - -class CommitFailureTestRelation( - override val paths: Array[String], - maybeDataSchema: Option[StructType], - override val userDefinedPartitionColumns: Option[StructType], - parameters: Map[String, String])( - @transient sqlContext: SQLContext) - extends SimpleTextRelation( - paths, maybeDataSchema, userDefinedPartitionColumns, parameters)(sqlContext) { - override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(path, context) { - override def close(): Unit = { - super.close() - sys.error("Intentional task commitment failure for testing purpose.") - } - } - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 100b97137cff0..10eeb30242e2c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import scala.collection.JavaConverters._ +import scala.util.Random import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -27,9 +28,9 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.execution.ConvertToUnsafe import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -59,7 +60,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes p2 <- Seq("foo", "bar") } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") - lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) + lazy val partitionedTestDF = partitionedTestDF1.union(partitionedTestDF2) def checkQueries(df: DataFrame): Unit = { // Selects everything @@ -115,44 +116,56 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes new MyDenseVectorUDT() ).filter(supportsDataType) - for (dataType <- supportedDataTypes) { - test(s"test all data types - $dataType") { - withTempPath { file => - val path = file.getCanonicalPath - - val dataGenerator = RandomDataGenerator.forType( - dataType = dataType, - nullable = true, - seed = Some(System.nanoTime()) - ).getOrElse { - fail(s"Failed to create data generator for schema $dataType") + try { + for (dataType <- supportedDataTypes) { + for (parquetDictionaryEncodingEnabled <- Seq(true, false)) { + test(s"test all data types - $dataType with parquet.enable.dictionary = " + + s"$parquetDictionaryEncodingEnabled") { + + hadoopConfiguration.setBoolean("parquet.enable.dictionary", + parquetDictionaryEncodingEnabled) + + withTempPath { file => + val path = file.getCanonicalPath + + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, + nullable = true, + new Random(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") + } + + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = + sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) + + val loadedDF = sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") + + checkAnswer(loadedDF, df) + } } - - // Create a DF for the schema with random data. The index field is used to sort the - // DataFrame. This is a workaround for SPARK-10591. - val schema = new StructType() - .add("index", IntegerType, nullable = false) - .add("col", dataType, nullable = true) - val rdd = sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) - val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - - df.write - .mode("overwrite") - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .save(path) - - val loadedDF = sqlContext - .read - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .schema(df.schema) - .load(path) - .orderBy("index") - - checkAnswer(loadedDF, df) } } + } finally { + hadoopConfiguration.unset("parquet.enable.dictionary") } test("save()/load() - non-partitioned table - Overwrite") { @@ -178,7 +191,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath).orderBy("a"), - testDF.unionAll(testDF).orderBy("a").collect()) + testDF.union(testDF).orderBy("a").collect()) } } @@ -255,7 +268,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), - partitionedTestDF.unionAll(partitionedTestDF).collect()) + partitionedTestDF.union(partitionedTestDF).collect()) } } @@ -319,7 +332,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + checkAnswer(sqlContext.table("t"), testDF.union(testDF).orderBy("a").collect()) } } @@ -402,7 +415,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.union(partitionedTestDF).collect()) } } @@ -486,6 +499,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val df = sqlContext.read .format(dataSourceName) .option("dataSchema", dataSchema.json) + .option("basePath", file.getCanonicalPath) .load(s"${file.getCanonicalPath}/p1=*/p2=???") val expectedPaths = Set( @@ -500,8 +514,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } val actualPaths = df.queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: HadoopFsRelation, _) => - relation.paths.toSet + case LogicalRelation(relation: HadoopFsRelation, _, _) => + relation.location.paths.map(_.toString).toSet }.getOrElse { fail("Expect an FSBasedRelation, but none could be found") } @@ -558,7 +572,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) + checkAnswer(sqlContext.table("t").select('b, 'c, 'a), df.select('b, 'c, 'a).collect()) } } @@ -611,7 +625,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .format(dataSourceName) .option("dataSchema", df.schema.json) .load(dir.getCanonicalPath), - df.unionAll(df)) + df.union(df)) // This will fail because AlwaysFailOutputCommitter is used when we do append. intercept[Exception] { @@ -654,70 +668,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") } } - - test("SPARK-9899 Disable customized output committer when speculation is on") { - val clonedConf = new Configuration(hadoopConfiguration) - val speculationEnabled = - sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) - - try { - withTempPath { dir => - // Enables task speculation - sqlContext.sparkContext.conf.set("spark.speculation", "true") - - // Uses a customized output committer which always fails - hadoopConfiguration.set( - SQLConf.OUTPUT_COMMITTER_CLASS.key, - classOf[AlwaysFailOutputCommitter].getName) - - // Code below shouldn't throw since customized output committer should be disabled. - val df = sqlContext.range(10).coalesce(1) - df.write.format(dataSourceName).save(dir.getCanonicalPath) - checkAnswer( - sqlContext - .read - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .load(dir.getCanonicalPath), - df) - } - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) - } - } - - test("HadoopFsRelation produces UnsafeRow") { - withTempTable("test_unsafe") { - withTempPath { dir => - val path = dir.getCanonicalPath - sqlContext.range(3).write.format(dataSourceName).save(path) - sqlContext.read - .format(dataSourceName) - .option("dataSchema", new StructType().add("id", LongType, nullable = false).json) - .load(path) - .registerTempTable("test_unsafe") - - val df = sqlContext.sql( - """SELECT COUNT(*) - |FROM test_unsafe a JOIN test_unsafe b - |WHERE a.id = b.id - """.stripMargin) - - val plan = df.queryExecution.executedPlan - - assert( - plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty, - s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s): - |$plan - """.stripMargin) - - checkAnswer(df, Row(3)) - } - } - } } // This class is used to test SPARK-8578. We should not use any custom output committer when diff --git a/streaming/pom.xml b/streaming/pom.xml index 145c8a7321c05..7d409c5d3b076 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-streaming_2.10 + spark-streaming_2.11 streaming @@ -93,6 +93,11 @@ selenium-java test + + org.mockito + mockito-core + test + target/scala-${scala.binary.version}/classes diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 3738fc1a235c2..2803cad8095dd 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -37,26 +37,26 @@ public abstract class WriteAheadLog { * ensure that the written data is durable and readable (using the record handle) by the * time this function returns. */ - abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + public abstract WriteAheadLogRecordHandle write(ByteBuffer record, long time); /** * Read a written record based on the given record handle. */ - abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + public abstract ByteBuffer read(WriteAheadLogRecordHandle handle); /** * Read and return an iterator of all the records that have been written but not yet cleaned up. */ - abstract public Iterator readAll(); + public abstract Iterator readAll(); /** * Clean all the records that are older than the threshold time. It can wait for * the completion of the deletion. */ - abstract public void clean(long threshTime, boolean waitForCompletion); + public abstract void clean(long threshTime, boolean waitForCompletion); /** * Close this log and release any resources. */ - abstract public void close(); + public abstract void close(); } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b7de6dde61c63..03956009541a0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -21,15 +21,15 @@ import java.io._ import java.util.concurrent.Executors import java.util.concurrent.RejectedExecutionException -import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator - +import org.apache.spark.util.Utils private[streaming] class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) @@ -41,7 +41,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration val pendingTimes = ssc.scheduler.getPendingTimes().toArray - val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConfPairs = ssc.conf.getAll def createSparkConf(): SparkConf = { @@ -55,7 +54,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.port", "spark.master", "spark.yarn.keytab", - "spark.yarn.principal") + "spark.yarn.principal", + "spark.ui.filters") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") @@ -66,6 +66,16 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) newSparkConf.set(prop, value) } } + + // Add Yarn proxy filter specific configurations to the recovered SparkConf + val filter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val filterPrefix = s"spark.$filter.param." + newReloadConf.getAll.foreach { case (k, v) => + if (k.startsWith(filterPrefix) && k.length > filterPrefix.length) { + newSparkConf.set(k, v) + } + } + newSparkConf } @@ -174,18 +184,32 @@ class CheckpointWriter( val executor = Executors.newFixedThreadPool(1) val compressionCodec = CompressionCodec.createCodec(conf) private var stopped = false - private var fs_ : FileSystem = _ + private var _fs: FileSystem = _ + + @volatile private var latestCheckpointTime: Time = null class CheckpointWriteHandler( checkpointTime: Time, bytes: Array[Byte], clearCheckpointDataLater: Boolean) extends Runnable { def run() { + if (latestCheckpointTime == null || latestCheckpointTime < checkpointTime) { + latestCheckpointTime = checkpointTime + } var attempts = 0 val startTime = System.currentTimeMillis() val tempFile = new Path(checkpointDir, "temp") - val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime) - val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime) + // We will do checkpoint when generating a batch and completing a batch. When the processing + // time of a batch is greater than the batch interval, checkpointing for completing an old + // batch may run after checkpointing of a new batch. If this happens, checkpoint of an old + // batch actually has the latest information, so we want to recovery from it. Therefore, we + // also use the latest checkpoint time as the file name, so that we can recovery from the + // latest checkpoint file. + // + // Note: there is only one thread writing the checkpoint files, so we don't need to worry + // about thread-safety. + val checkpointFile = Checkpoint.checkpointFile(checkpointDir, latestCheckpointTime) + val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, latestCheckpointTime) while (attempts < MAX_ATTEMPTS && !stopped) { attempts += 1 @@ -207,7 +231,7 @@ class CheckpointWriter( // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail if (fs.exists(checkpointFile)) { - if (fs.exists(backupFile)){ + if (fs.exists(backupFile)) { fs.delete(backupFile, true) // just in case it exists } if (!fs.rename(checkpointFile, backupFile)) { @@ -223,10 +247,10 @@ class CheckpointWriter( // Delete old checkpoint files val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) if (allCheckpointFiles.size > 10) { - allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { + allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach { file => logInfo("Deleting " + file) fs.delete(file, true) - }) + } } // All done, print success @@ -252,7 +276,7 @@ class CheckpointWriter( val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) - logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") + logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) @@ -275,12 +299,12 @@ class CheckpointWriter( } private def fs = synchronized { - if (fs_ == null) fs_ = new Path(checkpointDir).getFileSystem(hadoopConf) - fs_ + if (_fs == null) _fs = new Path(checkpointDir).getFileSystem(hadoopConf) + _fs } private def reset() = synchronized { - fs_ = null + _fs = null } } @@ -310,8 +334,7 @@ object CheckpointReader extends Logging { ignoreReadError: Boolean = false): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) - // TODO(rxin): Why is this a def?! - def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) + val fs = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse @@ -322,7 +345,7 @@ object CheckpointReader extends Logging { // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) var readError: Exception = null - checkpointFiles.foreach(file => { + checkpointFiles.foreach { file => logInfo("Attempting to load checkpoint from file " + file) try { val fis = fs.open(file) @@ -335,7 +358,7 @@ object CheckpointReader extends Logging { readError = e logWarning("Error reading checkpoint from file " + file, e) } - }) + } // If none of checkpoint files could be read, then throw exception if (!ignoreReadError) { @@ -347,8 +370,8 @@ object CheckpointReader extends Logging { } private[streaming] -class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) - extends ObjectInputStream(inputStream_) { +class ObjectInputStreamWithLoader(_inputStream: InputStream, loader: ClassLoader) + extends ObjectInputStream(_inputStream) { override def resolveClass(desc: ObjectStreamClass): Class[_] = { try { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 1b0b7890b3b00..54d736ee5101b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -17,11 +17,13 @@ package org.apache.spark.streaming +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import scala.collection.mutable.ArrayBuffer -import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import org.apache.spark.Logging + +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.scheduler.Job -import org.apache.spark.streaming.dstream.{DStream, ReceiverInputDStream, InputDStream} import org.apache.spark.util.Utils final private[streaming] class DStreamGraph extends Serializable with Logging { @@ -158,7 +160,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { require(batchDuration != null, "Batch duration has not been set") // assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + // " is very low") - require(getOutputStreams().size > 0, "No output operations registered, so nothing to execute") + require(getOutputStreams().nonEmpty, "No output operations registered, so nothing to execute") } } @@ -167,7 +169,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { * safe remember duration which can be used to perform cleanup operations. */ def getMaxInputStreamRememberDuration(): Duration = { - inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + // If an InputDStream is not used, its `rememberDuration` will be null and we can ignore them + inputStreams.map(_.rememberDuration).filter(_ != null).maxBy(_.milliseconds) } @throws(classOf[IOException]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala new file mode 100644 index 0000000000000..42424d67d8838 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Abstract class for getting and updating the state in mapping function used in the `mapWithState` + * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) + * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Scala example of using `State`: + * {{{ + * // A mapping function that maintains an integer state and returns a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Check if state exists + * if (state.exists) { + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * } + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * } + * ... // return something + * } + * + * }}} + * + * Java example of using `State`: + * {{{ + * // A mapping function that maintains an integer state and returns a String + * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * + * @Override + * public String call(String key, Optional value, State state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * } + * // return something + * } + * }; + * }}} + * + * @tparam S Class of the state + */ +@Experimental +sealed abstract class State[S] { + + /** Whether the state already exists */ + def exists(): Boolean + + /** + * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`. + * Check with `exists()` whether the state exists or not before calling `get()`. + * + * @throws java.util.NoSuchElementException If the state does not exist. + */ + def get(): S + + /** + * Update the state with a new value. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + * + * @throws java.lang.IllegalArgumentException If the state has already been removed, or is + * going to be removed + */ + def update(newState: S): Unit + + /** + * Remove the state if it exists. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + */ + def remove(): Unit + + /** + * Whether the state is timing out and going to be removed by the system after the current batch. + * This timeout can occur if timeout duration has been specified in the + * [[org.apache.spark.streaming.StateSpec StatSpec]] and the key has not received any new data + * for that timeout duration. + */ + def isTimingOut(): Boolean + + /** + * Get the state as an [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`. + */ + @inline final def getOption(): Option[S] = if (exists) Some(get()) else None + + @inline final override def toString(): String = { + getOption.map { _.toString }.getOrElse("") + } +} + +/** Internal implementation of the [[State]] interface */ +private[streaming] class StateImpl[S] extends State[S] { + + private var state: S = null.asInstanceOf[S] + private var defined: Boolean = false + private var timingOut: Boolean = false + private var updated: Boolean = false + private var removed: Boolean = false + + // ========= Public API ========= + override def exists(): Boolean = { + defined + } + + override def get(): S = { + if (defined) { + state + } else { + throw new NoSuchElementException("State is not set") + } + } + + override def update(newState: S): Unit = { + require(!removed, "Cannot update the state after it has been removed") + require(!timingOut, "Cannot update the state that is timing out") + state = newState + defined = true + updated = true + } + + override def isTimingOut(): Boolean = { + timingOut + } + + override def remove(): Unit = { + require(!timingOut, "Cannot remove the state that is timing out") + require(!removed, "Cannot remove the state that has already been removed") + defined = false + updated = false + removed = true + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved(): Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated(): Boolean = { + updated + } + + /** + * Update the internal data and flags in `this` to the given state option. + * This method allows `this` object to be reused across many state records. + */ + def wrap(optionalState: Option[S]): Unit = { + optionalState match { + case Some(newState) => + this.state = newState + defined = true + + case None => + this.state = null.asInstanceOf[S] + defined = false + } + timingOut = false + removed = false + updated = false + } + + /** + * Update the internal data and flags in `this` to the given state that is going to be timed out. + * This method allows `this` object to be reused across many state records. + */ + def wrapTimingOutState(newState: S): Unit = { + this.state = newState + defined = true + timingOut = true + removed = false + updated = false + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala new file mode 100644 index 0000000000000..7c1ea2f89ddb8 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.apache.spark.{HashPartitioner, Partitioner} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaUtils, Optional} +import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.ClosureCleaner + +/** + * :: Experimental :: + * Abstract class representing all the specifications of the DStream transformation + * `mapWithState` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * Use [[org.apache.spark.streaming.StateSpec.function() StateSpec.function]] factory methods + * to create instances of this class. + * + * Example in Scala: + * {{{ + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * + * val spec = StateSpec.function(mappingFunction).numPartitions(10) + * + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) + * }}} + * + * Example in Java: + * {{{ + * // A mapping function that maintains an integer state and return a string + * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * @Override + * public Optional call(Optional value, State state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * }; + * + * JavaMapWithStateDStream mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); + * }}} + * + * @tparam KeyType Class of the state key + * @tparam ValueType Class of the state value + * @tparam StateType Class of the state data + * @tparam MappedType Class of the mapped elements + */ +@Experimental +sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable { + + /** Set the RDD containing the initial states that will be used by `mapWithState` */ + def initialState(rdd: RDD[(KeyType, StateType)]): this.type + + /** Set the RDD containing the initial states that will be used by `mapWithState` */ + def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type + + /** + * Set the number of partitions by which the state RDDs generated by `mapWithState` + * will be partitioned. Hash partitioning will be used. + */ + def numPartitions(numPartitions: Int): this.type + + /** + * Set the partitioner by which the state RDDs generated by `mapWithState` will be partitioned. + */ + def partitioner(partitioner: Partitioner): this.type + + /** + * Set the duration after which the state of an idle key will be removed. A key and its state is + * considered idle if it has not received any data for at least the given duration. The + * mapping function will be called one final time on the idle states that are going to be + * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set + * to `true` in that call. + */ + def timeout(idleDuration: Duration): this.type +} + + +/** + * :: Experimental :: + * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]] + * that is used for specifying the parameters of the DStream transformation `mapWithState` + * that is used for specifying the parameters of the DStream transformation + * `mapWithState` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Example in Scala: + * {{{ + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * + * val spec = StateSpec.function(mappingFunction).numPartitions(10) + * + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) + * }}} + * + * Example in Java: + * {{{ + * // A mapping function that maintains an integer state and return a string + * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * @Override + * public Optional call(Optional value, State state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * }; + * + * JavaMapWithStateDStream mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); + *}}} + */ +@Experimental +object StateSpec { + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `mapWithState` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * + * @param mappingFunction The function applied on every data item to manage the associated state + * and generate the mapped data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam MappedType Class of the mapped data + */ + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType] + ): StateSpec[KeyType, ValueType, StateType, MappedType] = { + ClosureCleaner.clean(mappingFunction, checkSerializable = true) + new StateSpecImpl(mappingFunction) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `mapWithState` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * + * @param mappingFunction The function applied on every data item to manage the associated state + * and generate the mapped data + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam MappedType Class of the mapped data + */ + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType + ): StateSpec[KeyType, ValueType, StateType, MappedType] = { + ClosureCleaner.clean(mappingFunction, checkSerializable = true) + val wrappedFunction = + (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => { + Some(mappingFunction(key, value, state)) + } + new StateSpecImpl(wrappedFunction) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all + * the specifications of the `mapWithState` operation on a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. + * + * @param mappingFunction The function applied on every data item to manage the associated + * state and generate the mapped data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam MappedType Class of the mapped data + */ + def function[KeyType, ValueType, StateType, MappedType](mappingFunction: + JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]): + StateSpec[KeyType, ValueType, StateType, MappedType] = { + val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { + val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s) + if (t.isPresent) { + Some(t.get) + } else { + None + } + } + StateSpec.function(wrappedFunc) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `mapWithState` operation on a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. + * + * @param mappingFunction The function applied on every data item to manage the associated + * state and generate the mapped data + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam MappedType Class of the mapped data + */ + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): + StateSpec[KeyType, ValueType, StateType, MappedType] = { + val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { + mappingFunction.call(k, JavaUtils.optionToOptional(v), s) + } + StateSpec.function(wrappedFunc) + } +} + + +/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. */ +private[streaming] +case class StateSpecImpl[K, V, S, T]( + function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] { + + require(function != null) + + @volatile private var partitioner: Partitioner = null + @volatile private var initialStateRDD: RDD[(K, S)] = null + @volatile private var timeoutInterval: Duration = null + + override def initialState(rdd: RDD[(K, S)]): this.type = { + this.initialStateRDD = rdd + this + } + + override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd + this + } + + override def numPartitions(numPartitions: Int): this.type = { + this.partitioner(new HashPartitioner(numPartitions)) + this + } + + override def partitioner(partitioner: Partitioner): this.type = { + this.partitioner = partitioner + this + } + + override def timeout(interval: Duration): this.type = { + this.timeoutInterval = interval + this + } + + // ================= Private Methods ================= + + private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => Option[T] = function + + private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) + + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) + + private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 97113835f3bd0..928739a416f0f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.{InputStream, NotSerializableException} +import java.util.Properties import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.Map @@ -25,24 +26,26 @@ import scala.collection.mutable.Queue import scala.reflect.ClassTag import scala.util.control.NonFatal -import akka.actor.{Props, SupervisorStrategy} +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.FixedLengthBinaryInputFormat +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, RDDOperationScope} +import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.SerializationDebugger import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} -import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} +import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.{ExecutorAllocationManager, JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} @@ -58,9 +61,9 @@ import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} * of the context by `stop()` or by an exception. */ class StreamingContext private[streaming] ( - sc_ : SparkContext, - cp_ : Checkpoint, - batchDur_ : Duration + _sc: SparkContext, + _cp: Checkpoint, + _batchDur: Duration ) extends Logging { /** @@ -105,7 +108,7 @@ class StreamingContext private[streaming] ( * HDFS compatible filesystems */ def this(path: String, hadoopConf: Configuration) = - this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) + this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).orNull, null) /** * Recreate a StreamingContext from a checkpoint file. @@ -121,23 +124,20 @@ class StreamingContext private[streaming] ( def this(path: String, sparkContext: SparkContext) = { this( sparkContext, - CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, + CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).orNull, null) } + require(_sc != null || _cp != null, + "Spark Streaming cannot be initialized with both SparkContext and checkpoint as null") - if (sc_ == null && cp_ == null) { - throw new Exception("Spark Streaming cannot be initialized with " + - "both SparkContext and checkpoint as null") - } - - private[streaming] val isCheckpointPresent = (cp_ != null) + private[streaming] val isCheckpointPresent: Boolean = _cp != null private[streaming] val sc: SparkContext = { - if (sc_ != null) { - sc_ + if (_sc != null) { + _sc } else if (isCheckpointPresent) { - SparkContext.getOrCreate(cp_.createSparkConf()) + SparkContext.getOrCreate(_cp.createSparkConf()) } else { throw new SparkException("Cannot create StreamingContext without a SparkContext") } @@ -154,13 +154,13 @@ class StreamingContext private[streaming] ( private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { - cp_.graph.setContext(this) - cp_.graph.restoreCheckpointData() - cp_.graph + _cp.graph.setContext(this) + _cp.graph.restoreCheckpointData() + _cp.graph } else { - require(batchDur_ != null, "Batch duration for StreamingContext cannot be null") + require(_batchDur != null, "Batch duration for StreamingContext cannot be null") val newGraph = new DStreamGraph() - newGraph.setBatchDuration(batchDur_) + newGraph.setBatchDuration(_batchDur) newGraph } } @@ -169,15 +169,15 @@ class StreamingContext private[streaming] ( private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { - sc.setCheckpointDir(cp_.checkpointDir) - cp_.checkpointDir + sc.setCheckpointDir(_cp.checkpointDir) + _cp.checkpointDir } else { null } } private[streaming] val checkpointDuration: Duration = { - if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration + if (isCheckpointPresent) _cp.checkpointDuration else graph.batchDuration } private[streaming] val scheduler = new JobScheduler(this) @@ -200,6 +200,10 @@ class StreamingContext private[streaming] ( private val startSite = new AtomicReference[CallSite](null) + // Copy of thread-local properties from SparkContext. These properties will be set in all tasks + // submitted by this StreamingContext after start. + private[streaming] val savedProperties = new AtomicReference[Properties](new Properties) + private[streaming] def getStartSite(): CallSite = startSite.get() private var shutdownHookRef: AnyRef = _ @@ -212,8 +216,8 @@ class StreamingContext private[streaming] ( def sparkContext: SparkContext = sc /** - * Set each DStreams in this context to remember RDDs it generated in the last given duration. - * DStreams remember RDDs only for a limited duration of time and releases them for garbage + * Set each DStream in this context to remember RDDs it generated in the last given duration. + * DStreams remember RDDs only for a limited duration of time and release them for garbage * collection. This method allows the developer to specify how long to remember the RDDs ( * if the developer wishes to query old data outside the DStream computation). * @param duration Minimum duration that each DStream should remember its RDDs @@ -226,7 +230,7 @@ class StreamingContext private[streaming] ( * Set the context to periodically checkpoint the DStream operations for driver * fault-tolerance. * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored. - * Note that this must be a fault-tolerant file system like HDFS for + * Note that this must be a fault-tolerant file system like HDFS. */ def checkpoint(directory: String) { if (directory != null) { @@ -246,7 +250,7 @@ class StreamingContext private[streaming] ( } private[streaming] def initialCheckpoint: Checkpoint = { - if (isCheckpointPresent) cp_ else null + if (isCheckpointPresent) _cp else null } private[streaming] def getNewInputStreamId() = nextInputStreamId.getAndIncrement() @@ -271,21 +275,7 @@ class StreamingContext private[streaming] ( /** * Create an input stream with any arbitrary user implemented receiver. - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * @param receiver Custom implementation of Receiver - * - * @deprecated As of 1.0.0", replaced by `receiverStream`. - */ - @deprecated("Use receiverStream", "1.0.0") - def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { - withNamedScope("network stream") { - receiverStream(receiver) - } - } - - /** - * Create an input stream with any arbitrary user implemented receiver. - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Find more details at http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver */ def receiverStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { @@ -295,34 +285,14 @@ class StreamingContext private[streaming] ( } /** - * Create an input stream with any arbitrary user implemented actor receiver. - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel RDD storage level (default: StorageLevel.MEMORY_AND_DISK_SER_2) - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T: ClassTag]( - props: Props, - name: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, - supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy - ): ReceiverInputDStream[T] = withNamedScope("actor stream") { - receiverStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) - } - - /** - * Create a input stream from TCP source hostname:port. Data is received using + * Creates an input stream from TCP source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded `\n` delimited * lines. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @see [[socketStream]] */ def socketTextStream( hostname: String, @@ -333,8 +303,8 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream from TCP source hostname:port. Data is received using - * a TCP socket and the receive bytes it interepreted as object using the given + * Creates an input stream from TCP source hostname:port. Data is received using + * a TCP socket and the receive bytes it interpreted as object using the given * converter. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data @@ -460,7 +430,7 @@ class StreamingContext private[streaming] ( def binaryRecordsStream( directory: String, recordLength: Int): DStream[Array[Byte]] = withNamedScope("binary records stream") { - val conf = sc_.hadoopConfiguration + val conf = _sc.hadoopConfiguration conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) @@ -480,7 +450,7 @@ class StreamingContext private[streaming] ( * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. * - * @param queue Queue of RDDs + * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD */ @@ -498,7 +468,7 @@ class StreamingContext private[streaming] ( * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of * those RDDs, so `queueStream` doesn't support checkpointing. * - * @param queue Queue of RDDs + * @param queue Queue of RDDs. Modifications to this data structure must be synchronized. * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. * Set as null if no RDD should be returned when empty @@ -530,9 +500,10 @@ class StreamingContext private[streaming] ( new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) } - /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for - * receiving system events related to streaming. - */ + /** + * Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. + */ def addStreamingListener(streamingListener: StreamingListener) { scheduler.listenerBus.addListener(streamingListener) } @@ -549,7 +520,7 @@ class StreamingContext private[streaming] ( // Verify whether the DStream checkpoint is serializable if (isCheckpointingEnabled) { - val checkpoint = new Checkpoint(this, Time.apply(0)) + val checkpoint = new Checkpoint(this, Time(0)) try { Checkpoint.serialize(checkpoint, conf) } catch { @@ -562,11 +533,12 @@ class StreamingContext private[streaming] ( } } - if (Utils.isDynamicAllocationEnabled(sc.conf)) { + if (Utils.isDynamicAllocationEnabled(sc.conf) || + ExecutorAllocationManager.isDynamicAllocationEnabled(conf)) { logWarning("Dynamic Allocation is enabled for this application. " + "Enabling Dynamic allocation for Spark Streaming applications can cause data loss if " + "Write Ahead Log is not enabled for non-replayable sources like Flume. " + - "See the programming guide for details on how to enable the Write Ahead Log") + "See the programming guide for details on how to enable the Write Ahead Log.") } } @@ -574,11 +546,12 @@ class StreamingContext private[streaming] ( * :: DeveloperApi :: * * Return the current state of the context. The context can be in three possible states - - * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. - * Input DStreams, transformations and output operations can be created on the context. - * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. - * Input DStreams, transformations and output operations cannot be created on the context. - * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. + * + * - StreamingContextState.INITIALIZED - The context has been created, but not started yet. + * Input DStreams, transformations and output operations can be created on the context. + * - StreamingContextState.ACTIVE - The context has been started, and not stopped. + * Input DStreams, transformations and output operations cannot be created on the context. + * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. */ @DeveloperApi def getState(): StreamingContextState = synchronized { @@ -606,6 +579,8 @@ class StreamingContext private[streaming] ( sparkContext.setCallSite(startSite.get) sparkContext.clearJobGroup() sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + savedProperties.set(SerializationUtils.clone( + sparkContext.localProperties.get()).asInstanceOf[Properties]) scheduler.start() } state = StreamingContextState.ACTIVE @@ -641,18 +616,6 @@ class StreamingContext private[streaming] ( waiter.waitForStopOrError() } - /** - * Wait for the execution to stop. Any exceptions that occurs during the execution - * will be thrown in this thread. - * @param timeout time to wait in milliseconds - * - * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. - */ - @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") - def awaitTermination(timeout: Long) { - waiter.waitForStopOrError(timeout) - } - /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. @@ -693,29 +656,46 @@ class StreamingContext private[streaming] ( */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { var shutdownHookRefToRemove: AnyRef = null + if (LiveListenerBus.withinListenerThread.value) { + throw new SparkException( + s"Cannot stop StreamingContext within listener thread of ${LiveListenerBus.name}") + } synchronized { - try { - state match { - case INITIALIZED => - logWarning("StreamingContext has not been started yet") - case STOPPED => - logWarning("StreamingContext has already been stopped") - case ACTIVE => + // The state should always be Stopped after calling `stop()`, even if we haven't started yet + state match { + case INITIALIZED => + logWarning("StreamingContext has not been started yet") + state = STOPPED + case STOPPED => + logWarning("StreamingContext has already been stopped") + state = STOPPED + case ACTIVE => + // It's important that we don't set state = STOPPED until the very end of this case, + // since we need to ensure that we're still able to call `stop()` to recover from + // a partially-stopped StreamingContext which resulted from this `stop()` call being + // interrupted. See SPARK-12001 for more details. Because the body of this case can be + // executed twice in the case of a partial stop, all methods called here need to be + // idempotent. + Utils.tryLogNonFatalError { scheduler.stop(stopGracefully) - // Removing the streamingSource to de-register the metrics on stop() + } + // Removing the streamingSource to de-register the metrics on stop() + Utils.tryLogNonFatalError { env.metricsSystem.removeSource(streamingSource) + } + Utils.tryLogNonFatalError { uiTab.foreach(_.detach()) - StreamingContext.setActiveContext(null) + } + StreamingContext.setActiveContext(null) + Utils.tryLogNonFatalError { waiter.notifyStop() - if (shutdownHookRef != null) { - shutdownHookRefToRemove = shutdownHookRef - shutdownHookRef = null - } - logInfo("StreamingContext stopped successfully") - } - } finally { - // The state should always be Stopped after calling `stop()`, even if we haven't started yet - state = STOPPED + } + if (shutdownHookRef != null) { + shutdownHookRefToRemove = shutdownHookRef + shutdownHookRef = null + } + logInfo("StreamingContext stopped successfully") + state = STOPPED } } if (shutdownHookRefToRemove != null) { @@ -780,18 +760,6 @@ object StreamingContext extends Logging { } } - /** - * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object. - * This is kept here only for backward compatibility. - */ - @deprecated("Replaced by implicit functions in the DStream companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) - : PairDStreamFunctions[K, V] = { - DStream.toPairDStreamFunctions(stream)(kt, vt, ord) - } - /** * :: Experimental :: * @@ -882,12 +850,25 @@ object StreamingContext extends Logging { } private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { - if (prefix == null) { - time.milliseconds.toString - } else if (suffix == null || suffix.length ==0) { - prefix + "-" + time.milliseconds - } else { - prefix + "-" + time.milliseconds + "." + suffix + var result = time.milliseconds.toString + if (prefix != null && prefix.length > 0) { + result = s"$prefix-$result" + } + if (suffix != null && suffix.length > 0) { + result = s"$result.$suffix" } + result + } +} + +private class StreamingContextPythonHelper { + + /** + * This is a private method only for Python to implement `getOrCreate`. + */ + def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = { + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, ignoreReadError = false) + checkpointOption.map(new StreamingContext(null, _, null)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index 01cdcb0574040..a59f4efccb575 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -17,14 +17,14 @@ package org.apache.spark.streaming.api.java -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.rdd.RDD - import scala.language.implicitConversions import scala.reflect.ClassTag + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.dstream.DStream /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index edfa474677f15..43632f37ccb16 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.api.java -import java.lang.{Long => JLong} +import java.{lang => jl} import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaRDDLike} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, VoidFunction => JVoidFunction, VoidFunction2 => JVoidFunction2, _} import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ @@ -50,8 +50,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def wrapRDD(in: RDD[T]): R - implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = { - in.map(new JLong(_)) + // This is just unfortunate we made a mistake in naming -- should be scalaLongToJavaLong. + // Don't fix this for now as it would break binary compatibility. + implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[jl.Long] = { + in.map(jl.Long.valueOf) } /** @@ -74,14 +76,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): JavaDStream[JLong] = dstream.count() + def count(): JavaDStream[jl.Long] = dstream.count() /** * Return a new DStream in which each RDD contains the counts of each distinct value in * each RDD of this DStream. Hash partitioning is used to generate the RDDs with * Spark's default number of partitions. */ - def countByValue(): JavaPairDStream[T, JLong] = { + def countByValue(): JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong(dstream.countByValue()) } @@ -91,7 +93,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * partitions. * @param numPartitions number of partitions of each RDD in the new DStream. */ - def countByValue(numPartitions: Int): JavaPairDStream[T, JLong] = { + def countByValue(numPartitions: Int): JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong(dstream.countByValue(numPartitions)) } @@ -101,7 +103,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of elements in a window over this DStream. windowDuration and slideDuration are as defined in * the window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JLong] = { + def countByWindow(windowDuration: Duration, slideDuration: Duration): JavaDStream[jl.Long] = { dstream.countByWindow(windowDuration, slideDuration) } @@ -116,7 +118,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * DStream's batching interval */ def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration) - : JavaPairDStream[T, JLong] = { + : JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong( dstream.countByValueAndWindow(windowDuration, slideDuration)) } @@ -133,7 +135,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * @param numPartitions number of partitions of each RDD in the new DStream. */ def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - : JavaPairDStream[T, JLong] = { + : JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong( dstream.countByValueAndWindow(windowDuration, slideDuration, numPartitions)) } @@ -166,8 +168,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * and then flattening the results */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { - import scala.collection.JavaConverters._ - def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[U] = (x: T) => f.call(x).asScala new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -176,8 +177,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * and then flattening the results */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { - import scala.collection.JavaConverters._ - def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala + def fn: (T) => Iterator[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = fakeClassTag new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -189,7 +189,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -202,7 +202,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => f.call(x.asJava).iterator().asScala + (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -216,27 +216,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. - * @param reduceFunc associative reduce function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @deprecated As this API is not Java compatible. - */ - @deprecated("Use Java-compatible version of reduceByWindow", "1.3.0") - def reduceByWindow( - reduceFunc: (T, T) => T, - windowDuration: Duration, - slideDuration: Duration - ): DStream[T] = { - dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) - } - - /** - * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a sliding window over this DStream. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -259,7 +239,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient than reduceByWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -283,33 +263,11 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T dstream.slice(fromTime, toTime).map(wrapRDD).asJava } - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of release 0.9.0, replaced by foreachRDD - */ - @deprecated("Use foreachRDD", "0.9.0") - def foreach(foreachFunc: JFunction[R, Void]) { - foreachRDD(foreachFunc) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of release 0.9.0, replaced by foreachRDD - */ - @deprecated("Use foreachRDD", "0.9.0") - def foreach(foreachFunc: JFunction2[R, Time, Void]) { - foreachRDD(foreachFunc) - } - /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreachRDD(foreachFunc: JFunction[R, Void]) { + def foreachRDD(foreachFunc: JVoidFunction[R]) { dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) } @@ -317,7 +275,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { + def foreachRDD(foreachFunc: JVoidFunction2[R, Time]) { dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala new file mode 100644 index 0000000000000..16c0d6fff8229 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.streaming.dstream.MapWithStateDStream + +/** + * :: Experimental :: + * DStream representing the stream of data generated by `mapWithState` operation on a + * [[JavaPairDStream]]. Additionally, it also gives access to the + * stream of state snapshots, that is, the state data of all keys after a batch has updated them. + * + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the state data + * @tparam MappedType Class of the mapped data + */ +@Experimental +class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming]( + dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType]) + extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) { + + def stateSnapshots(): JavaPairDStream[KeyType, StateType] = + new JavaPairDStream(dstream.stateSnapshots())( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index e2aec6c2f63e7..2a80cf4466588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -17,19 +17,21 @@ package org.apache.spark.streaming.api.java -import java.lang.{Long => JLong, Iterable => JIterable} +import java.{lang => jl} +import java.lang.{Iterable => JIterable} import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} + import org.apache.spark.Partitioner -import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils, Optional} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} @@ -136,8 +138,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the associative reduce function. Hash partitioning is used to generate the RDDs - * with Spark's default number of partitions. + * merged using the associative and commutative reduce function. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. */ def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = dstream.reduceByKey(func) @@ -153,7 +155,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. org.apache.spark.Partitioner is used to control - * thepartitioning of each RDD. + * the partitioning of each RDD. */ def reduceByKey(func: JFunction2[V, V, V], partitioner: Partitioner): JavaPairDStream[K, V] = { dstream.reduceByKey(func, partitioner) @@ -255,7 +257,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate * the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ @@ -268,7 +270,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -287,7 +289,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -307,7 +309,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to * `DStream.reduceByKey()`, but applies it over a sliding window. - * @param reduceFunc associative reduce function + * @param reduceFunc associative rand commutative educe function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -333,7 +335,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -358,7 +360,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -395,7 +397,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -426,6 +428,42 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( ) } + /** + * :: Experimental :: + * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of + * `this` stream, while maintaining some state data for each unique key. The mapping function + * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this + * transformation can be specified using [[StateSpec]] class. The state data is accessible in + * as a parameter of type [[State]] in the mapping function. + * + * Example of using `mapWithState`: + * {{{ + * // A mapping function that maintains an integer state and return a string + * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * @Override + * public Optional call(Optional value, State state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * }; + * + * JavaMapWithStateDStream mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); + *}}} + * + * @param spec Specification of this transformation + * @tparam StateType Class type of the state data + * @tparam MappedType Class type of the mapped data + */ + @Experimental + def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]): + JavaMapWithStateDStream[K, V, StateType, MappedType] = { + new JavaMapWithStateDStream(dstream.mapWithState(spec)( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag)) + } + private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { @@ -809,7 +847,7 @@ object JavaPairDStream { } def scalaToJavaLong[K: ClassTag](dstream: JavaPairDStream[K, Long]) - : JavaPairDStream[K, JLong] = { - DStream.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) + : JavaPairDStream[K, jl.Long] = { + DStream.toPairDStreamFunctions(dstream.dstream).mapValues(jl.Long.valueOf) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala index e6ff8a0cb545f..da0db02236a1f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming.api.java -import org.apache.spark.streaming.dstream.InputDStream - import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.streaming.dstream.InputDStream + /** * A Java-friendly interface to [[org.apache.spark.streaming.dstream.InputDStream]] of * key-value pairs. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 8f21c79a760c1..922e4a5e4d9cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -17,19 +17,18 @@ package org.apache.spark.streaming.api.java -import java.lang.{Boolean => JBoolean} import java.io.{Closeable, InputStream} +import java.lang.{Boolean => JBoolean} import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import akka.actor.{Props, SupervisorStrategy} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.spark.api.java.function.{Function0 => JFunction0} @@ -37,10 +36,9 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver -import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.scheduler.StreamingListener /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -155,12 +153,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** The underlying SparkContext */ val sparkContext = new JavaSparkContext(ssc.sc) - /** - * @deprecated As of 0.9.0, replaced by `sparkContext` - */ - @deprecated("use sparkContext", "0.9.0") - val sc: JavaSparkContext = sparkContext - /** * Create an input stream from network source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited @@ -356,69 +348,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.fileStream[K, V, F](directory, fn, newFilesOnly, conf) } - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel Storage level to use for storing the received objects - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String, - storageLevel: StorageLevel, - supervisorStrategy: SupervisorStrategy - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - ssc.actorStream[T](props, name, storageLevel, supervisorStrategy) - } - - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * @param storageLevel Storage level to use for storing the received objects - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String, - storageLevel: StorageLevel - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - ssc.actorStream[T](props, name, storageLevel) - } - - /** - * Create an input stream with any arbitrary user implemented actor receiver. - * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. - * @param props Props object defining creation of the actor - * @param name Name of the actor - * - * @note An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of data received and actorStream - * should be same. - */ - def actorStream[T]( - props: Props, - name: String - ): JavaReceiverInputDStream[T] = { - implicit val cm: ClassTag[T] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - ssc.actorStream[T](props, name) - } - /** * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. @@ -588,9 +517,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.remember(duration) } - /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for - * receiving system events related to streaming. - */ + /** + * Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. + */ def addStreamingListener(streamingListener: StreamingListener) { ssc.addStreamingListener(streamingListener) } @@ -601,7 +531,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Return the current state of the context. The context can be in three possible states - *
      *
    • - * StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * StreamingContextState.INITIALIZED - The context has been created, but not been started yet. * Input DStreams, transformations and output operations can be created on the context. *
    • *
    • @@ -632,17 +562,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.awaitTermination() } - /** - * Wait for the execution to stop. Any exceptions that occurs during the execution - * will be thrown in this thread. - * @param timeout time to wait in milliseconds - * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. - */ - @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") - def awaitTermination(timeout: Long): Unit = { - ssc.awaitTermination(timeout) - } - /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. @@ -687,78 +606,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { */ object JavaStreamingContext { - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program - * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. - */ - @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") - def getOrCreate( - checkpointPath: String, - factory: JavaStreamingContextFactory - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - factory.create.ssc - }) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. - */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") - def getOrCreate( - checkpointPath: String, - hadoopConf: Configuration, - factory: JavaStreamingContextFactory - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - factory.create.ssc - }, hadoopConf) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. - */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") - def getOrCreate( - checkpointPath: String, - hadoopConf: Configuration, - factory: JavaStreamingContextFactory, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - factory.create.ssc - }, hadoopConf, createOnError) - new JavaStreamingContext(ssc) - } - /** * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be @@ -831,10 +678,3 @@ object JavaStreamingContext { */ def jarOfClass(cls: Class[_]): Array[String] = SparkContext.jarOfClass(cls).toArray } - -/** - * Factory interface for creating a new JavaStreamingContext - */ -trait JavaStreamingContextFactory { - def create(): JavaStreamingContext -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala new file mode 100644 index 0000000000000..db0bae9958d61 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.streaming.Time + +private[streaming] trait PythonStreamingListener{ + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted) { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError) { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped) { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted) { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted) { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted) { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted) { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted) { } +} + +private[streaming] class PythonStreamingListenerWrapper(listener: PythonStreamingListener) + extends JavaStreamingListener { + + /** Called when a receiver has been started */ + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + listener.onReceiverStarted(receiverStarted) + } + + /** Called when a receiver has reported an error */ + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + listener.onReceiverError(receiverError) + } + + /** Called when a receiver has been stopped */ + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + listener.onReceiverStopped(receiverStopped) + } + + /** Called when a batch of jobs has been submitted for processing. */ + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + listener.onBatchSubmitted(batchSubmitted) + } + + /** Called when processing of a batch of jobs has started. */ + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + listener.onBatchStarted(batchStarted) + } + + /** Called when processing of a batch of jobs has completed. */ + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + listener.onBatchCompleted(batchCompleted) + } + + /** Called when processing of a job of a batch has started. */ + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + listener.onOutputOperationStarted(outputOperationStarted) + } + + /** Called when processing of a job of a batch has completed. */ + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + listener.onOutputOperationCompleted(outputOperationCompleted) + } +} + +/** + * A listener interface for receiving information about an ongoing streaming computation. + */ +private[streaming] class JavaStreamingListener { + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { } +} + +/** + * Base trait for events related to JavaStreamingListener + */ +private[streaming] sealed trait JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchSubmitted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchCompleted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchStarted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationStarted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationCompleted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStarted(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverError(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStopped(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +/** + * Class having information on batches. + * + * @param batchTime Time of the batch + * @param streamIdToInputInfo A map of input stream id to its input info + * @param submissionTime Clock time of when jobs of this batch was submitted to the streaming + * scheduler queue + * @param processingStartTime Clock time of when the first job of this batch started processing. + * `-1` means the batch has not yet started + * @param processingEndTime Clock time of when the last job of this batch finished processing. `-1` + * means the batch has not yet completed. + * @param schedulingDelay Time taken for the first job of this batch to start processing from the + * time this batch was submitted to the streaming scheduler. Essentially, it + * is `processingStartTime` - `submissionTime`. `-1` means the batch has not + * yet started + * @param processingDelay Time taken for the all jobs of this batch to finish processing from the + * time they started processing. Essentially, it is + * `processingEndTime` - `processingStartTime`. `-1` means the batch has not + * yet completed. + * @param totalDelay Time taken for all the jobs of this batch to finish processing from the time + * they were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + * `-1` means the batch has not yet completed. + * @param numRecords The number of recorders received by the receivers in this batch + * @param outputOperationInfos The output operations in this batch + */ +private[streaming] case class JavaBatchInfo( + batchTime: Time, + streamIdToInputInfo: java.util.Map[Int, JavaStreamInputInfo], + submissionTime: Long, + processingStartTime: Long, + processingEndTime: Long, + schedulingDelay: Long, + processingDelay: Long, + totalDelay: Long, + numRecords: Long, + outputOperationInfos: java.util.Map[Int, JavaOutputOperationInfo]) + +/** + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + * @param metadataDescription description of this input stream + */ +private[streaming] case class JavaStreamInputInfo( + inputStreamId: Int, + numRecords: Long, + metadata: java.util.Map[String, Any], + metadataDescription: String) + +/** + * Class having information about a receiver + */ +private[streaming] case class JavaReceiverInfo( + streamId: Int, + name: String, + active: Boolean, + location: String, + executorId: String, + lastErrorMessage: String, + lastError: String, + lastErrorTime: Long) + +/** + * Class having information on output operations. + * + * @param batchTime Time of the batch + * @param id Id of this output operation. Different output operations have different ids in a batch. + * @param name The name of this output operation. + * @param description The description of this output operation. + * @param startTime Clock time of when the output operation started processing. `-1` means the + * output operation has not yet started + * @param endTime Clock time of when the output operation started processing. `-1` means the output + * operation has not yet completed + * @param failureReason Failure reason if this output operation fails. If the output operation is + * successful, this field is `null`. + */ +private[streaming] case class JavaOutputOperationInfo( + batchTime: Time, + id: Int, + name: String, + description: String, + startTime: Long, + endTime: Long, + failureReason: String) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala new file mode 100644 index 0000000000000..b109b9f1cbeae --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.streaming.scheduler._ + +/** + * A wrapper to convert a [[JavaStreamingListener]] to a [[StreamingListener]]. + */ +private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: JavaStreamingListener) + extends StreamingListener { + + private def toJavaReceiverInfo(receiverInfo: ReceiverInfo): JavaReceiverInfo = { + JavaReceiverInfo( + receiverInfo.streamId, + receiverInfo.name, + receiverInfo.active, + receiverInfo.location, + receiverInfo.executorId, + receiverInfo.lastErrorMessage, + receiverInfo.lastError, + receiverInfo.lastErrorTime + ) + } + + private def toJavaStreamInputInfo(streamInputInfo: StreamInputInfo): JavaStreamInputInfo = { + JavaStreamInputInfo( + streamInputInfo.inputStreamId, + streamInputInfo.numRecords: Long, + streamInputInfo.metadata.asJava, + streamInputInfo.metadataDescription.orNull + ) + } + + private def toJavaOutputOperationInfo( + outputOperationInfo: OutputOperationInfo): JavaOutputOperationInfo = { + JavaOutputOperationInfo( + outputOperationInfo.batchTime, + outputOperationInfo.id, + outputOperationInfo.name, + outputOperationInfo.description: String, + outputOperationInfo.startTime.getOrElse(-1), + outputOperationInfo.endTime.getOrElse(-1), + outputOperationInfo.failureReason.orNull + ) + } + + private def toJavaBatchInfo(batchInfo: BatchInfo): JavaBatchInfo = { + JavaBatchInfo( + batchInfo.batchTime, + batchInfo.streamIdToInputInfo.mapValues(toJavaStreamInputInfo(_)).asJava, + batchInfo.submissionTime, + batchInfo.processingStartTime.getOrElse(-1), + batchInfo.processingEndTime.getOrElse(-1), + batchInfo.schedulingDelay.getOrElse(-1), + batchInfo.processingDelay.getOrElse(-1), + batchInfo.totalDelay.getOrElse(-1), + batchInfo.numRecords, + batchInfo.outputOperationInfos.mapValues(toJavaOutputOperationInfo(_)).asJava + ) + } + + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + javaStreamingListener.onReceiverStarted( + new JavaStreamingListenerReceiverStarted(toJavaReceiverInfo(receiverStarted.receiverInfo))) + } + + override def onReceiverError(receiverError: StreamingListenerReceiverError): Unit = { + javaStreamingListener.onReceiverError( + new JavaStreamingListenerReceiverError(toJavaReceiverInfo(receiverError.receiverInfo))) + } + + override def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped): Unit = { + javaStreamingListener.onReceiverStopped( + new JavaStreamingListenerReceiverStopped(toJavaReceiverInfo(receiverStopped.receiverInfo))) + } + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + javaStreamingListener.onBatchSubmitted( + new JavaStreamingListenerBatchSubmitted(toJavaBatchInfo(batchSubmitted.batchInfo))) + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = { + javaStreamingListener.onBatchStarted( + new JavaStreamingListenerBatchStarted(toJavaBatchInfo(batchStarted.batchInfo))) + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + javaStreamingListener.onBatchCompleted( + new JavaStreamingListenerBatchCompleted(toJavaBatchInfo(batchCompleted.batchInfo))) + } + + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { + javaStreamingListener.onOutputOperationStarted(new JavaStreamingListenerOutputOperationStarted( + toJavaOutputOperationInfo(outputOperationStarted.outputOperationInfo))) + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + javaStreamingListener.onOutputOperationCompleted( + new JavaStreamingListenerOutputOperationCompleted( + toJavaOutputOperationInfo(outputOperationCompleted.outputOperationInfo))) + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index dfc569451df86..aeff4d7a98e7a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -24,22 +24,27 @@ import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConverters._ import scala.language.existentials -import py4j.GatewayServer - +import org.apache.spark.SparkException import org.apache.spark.api.java._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Interval, Duration, Time} -import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.{Duration, Interval, Time} import org.apache.spark.streaming.api.java._ +import org.apache.spark.streaming.dstream._ import org.apache.spark.util.Utils - /** * Interface for Python callback function which is used to transform RDDs */ private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] + + /** + * Get the failure, if any, in the last call to `call`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -48,6 +53,13 @@ private[python] trait PythonTransformFunction { private[python] trait PythonTransformFunctionSerializer { def dumps(id: String): Array[Byte] def loads(bytes: Array[Byte]): PythonTransformFunction + + /** + * Get the failure, if any, in the last call to `dumps` or `loads`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -59,18 +71,27 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) - .map(_.rdd) + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava - Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } // for function.Function2 def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { - pfunc.call(time.milliseconds, rdds) + callPythonTransformFunction(time.milliseconds, rdds) + } + + private def callPythonTransformFunction(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] = { + val resultRDD = pfunc.call(time, rdds) + val failure = pfunc.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + resultRDD } private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -103,23 +124,33 @@ private[python] object PythonTransformFunctionSerializer { /* * Register a serializer from Python, should be called during initialization */ - def register(ser: PythonTransformFunctionSerializer): Unit = { + def register(ser: PythonTransformFunctionSerializer): Unit = synchronized { serializer = ser } - def serialize(func: PythonTransformFunction): Array[Byte] = { + def serialize(func: PythonTransformFunction): Array[Byte] = synchronized { require(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") f.setAccessible(true) val id = f.get(h).asInstanceOf[String] - serializer.dumps(id) + val results = serializer.dumps(id) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + results } - def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + def deserialize(bytes: Array[Byte]): PythonTransformFunction = synchronized { require(serializer != null, "Serializer has not been registered!") - serializer.loads(bytes) + val pfunc = serializer.loads(bytes) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + pfunc } } @@ -136,16 +167,6 @@ private[python] object PythonDStream { PythonTransformFunctionSerializer.register(ser) } - /** - * Update the port of callback client to `port` - */ - def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = { - val cl = gws.getCallbackClient - val f = cl.getClass.getDeclaredField("port") - f.setAccessible(true) - f.setInt(cl, port) - } - /** * helper function for DStream.foreachRDD(), * cannot be `foreachRDD`, it will confusing py4j @@ -230,9 +251,19 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction, + initialRDD: Option[RDD[Array[Byte]]]) extends PythonDStream(parent, reduceFunc) { + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None) + + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction, + initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd)) + super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -240,7 +271,7 @@ private[python] class PythonStateDStream( val lastState = getOrCompute(validTime - slideDuration) val rdd = parent.getOrCompute(validTime) if (rdd.isDefined) { - func(lastState, rdd, validTime) + func(lastState.orElse(initialRDD), rdd, validTime) } else { lastState } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index 4eb92dd8b1053..995470ec8deae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.dstream import scala.reflect.ClassTag import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} /** - * An input stream that always returns the same RDD on each timestep. Useful for testing. + * An input stream that always returns the same RDD on each time step. Useful for testing. */ -class ConstantInputDStream[T: ClassTag](ssc_ : StreamingContext, rdd: RDD[T]) - extends InputDStream[T](ssc_) { +class ConstantInputDStream[T: ClassTag](_ssc: StreamingContext, rdd: RDD[T]) + extends InputDStream[T](_ssc) { require(rdd != null, "parameter rdd null is illegal, which will lead to NPE in the following transformation") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 1da0b0a54df07..58842f9c2f446 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -25,14 +25,15 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import scala.util.matching.Regex -import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BlockRDD, PairRDDFunctions, RDD, RDDOperationScope} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming.scheduler.Job import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{CallSite, MetadataCleaner, Utils} +import org.apache.spark.util.{CallSite, Utils} /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -82,7 +83,7 @@ abstract class DStream[T: ClassTag] ( // RDDs generated, marked as private[streaming] so that testsuites can access it @transient - private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () + private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]]() // Time zero for the DStream private[streaming] var zeroTime: Time = null @@ -97,11 +98,13 @@ abstract class DStream[T: ClassTag] ( private[streaming] val mustCheckpoint = false private[streaming] var checkpointDuration: Duration = null private[streaming] val checkpointData = new DStreamCheckpointData(this) + @transient + private var restoredFromCheckpointData = false // Reference to whole DStream graph private[streaming] var graph: DStreamGraph = null - private[streaming] def isInitialized = (zeroTime != null) + private[streaming] def isInitialized = zeroTime != null // Duration for which the DStream requires its parent DStream to remember each RDD created private[streaming] def parentRememberDuration = rememberDuration @@ -187,15 +190,15 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def initialize(time: Time) { if (zeroTime != null && zeroTime != time) { - throw new SparkException("ZeroTime is already initialized to " + zeroTime - + ", cannot initialize it again to " + time) + throw new SparkException(s"ZeroTime is already initialized to $zeroTime" + + s", cannot initialize it again to $time") } zeroTime = time // Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger if (mustCheckpoint && checkpointDuration == null) { checkpointDuration = slideDuration * math.ceil(Seconds(10) / slideDuration).toInt - logInfo("Checkpoint interval automatically set to " + checkpointDuration) + logInfo(s"Checkpoint interval automatically set to $checkpointDuration") } // Set the minimum value of the rememberDuration if not already set @@ -232,7 +235,7 @@ abstract class DStream[T: ClassTag] ( require( !mustCheckpoint || checkpointDuration != null, - "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + + s"The checkpoint interval for ${this.getClass.getSimpleName} has not been set." + " Please use DStream.checkpoint() to set the interval." ) @@ -243,65 +246,53 @@ abstract class DStream[T: ClassTag] ( require( checkpointDuration == null || checkpointDuration >= slideDuration, - "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + - "Please set it to at least " + slideDuration + "." + s"The checkpoint interval for ${this.getClass.getSimpleName} has been set to " + + s"$checkpointDuration which is lower than its slide time ($slideDuration). " + + s"Please set it to at least $slideDuration." ) require( checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), - "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + - "Please set it to a multiple of " + slideDuration + "." + s"The checkpoint interval for ${this.getClass.getSimpleName} has been set to " + + s" $checkpointDuration which not a multiple of its slide time ($slideDuration). " + + s"Please set it to a multiple of $slideDuration." ) require( checkpointDuration == null || storageLevel != StorageLevel.NONE, - "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + s"${this.getClass.getSimpleName} has been marked for checkpointing but the storage " + "level has not been set to enable persisting. Please use DStream.persist() to set the " + "storage level to use memory for better checkpointing performance." ) require( checkpointDuration == null || rememberDuration > checkpointDuration, - "The remember duration for " + this.getClass.getSimpleName + " has been set to " + - rememberDuration + " which is not more than the checkpoint interval (" + - checkpointDuration + "). Please set it to higher than " + checkpointDuration + "." - ) - - val metadataCleanerDelay = MetadataCleaner.getDelaySeconds(ssc.conf) - logInfo("metadataCleanupDelay = " + metadataCleanerDelay) - require( - metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, - "It seems you are doing some DStream window operation or setting a checkpoint interval " + - "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + - "than " + rememberDuration.milliseconds / 1000 + " seconds. But Spark's metadata cleanup" + - "delay is set to " + metadataCleanerDelay + " seconds, which is not sufficient. Please " + - "set the Java cleaner delay to more than " + - math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds." + s"The remember duration for ${this.getClass.getSimpleName} has been set to " + + s" $rememberDuration which is not more than the checkpoint interval" + + s" ($checkpointDuration). Please set it to a value higher than $checkpointDuration." ) dependencies.foreach(_.validateAtStart()) - logInfo("Slide time = " + slideDuration) - logInfo("Storage level = " + storageLevel) - logInfo("Checkpoint interval = " + checkpointDuration) - logInfo("Remember duration = " + rememberDuration) - logInfo("Initialized and validated " + this) + logInfo(s"Slide time = $slideDuration") + logInfo(s"Storage level = ${storageLevel.description}") + logInfo(s"Checkpoint interval = $checkpointDuration") + logInfo(s"Remember interval = $rememberDuration") + logInfo(s"Initialized and validated $this") } private[streaming] def setContext(s: StreamingContext) { if (ssc != null && ssc != s) { - throw new SparkException("Context is already set in " + this + ", cannot set it again") + throw new SparkException(s"Context must not be set again for $this") } ssc = s - logInfo("Set context for " + this) + logInfo(s"Set context for $this") dependencies.foreach(_.setContext(ssc)) } private[streaming] def setGraph(g: DStreamGraph) { if (graph != null && graph != g) { - throw new SparkException("Graph is already set in " + this + ", cannot set it again") + throw new SparkException(s"Graph must not be set again for $this") } graph = g dependencies.foreach(_.setGraph(graph)) @@ -310,7 +301,7 @@ abstract class DStream[T: ClassTag] ( private[streaming] def remember(duration: Duration) { if (duration != null && (rememberDuration == null || duration > rememberDuration)) { rememberDuration = duration - logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) + logInfo(s"Duration for remembering RDDs set to $rememberDuration for $this") } dependencies.foreach(_.remember(parentRememberDuration)) } @@ -320,11 +311,11 @@ abstract class DStream[T: ClassTag] ( if (!isInitialized) { throw new SparkException (this + " has not been initialized") } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { - logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + - " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) + logInfo(s"Time $time is invalid as zeroTime is $zeroTime" + + s" , slideDuration is $slideDuration and difference is ${time - zeroTime}") false } else { - logDebug("Time " + time + " is valid") + logDebug(s"Time $time is valid") true } } @@ -341,7 +332,7 @@ abstract class DStream[T: ClassTag] ( // of RDD generation, else generate nothing. if (isTimeValid(time)) { - val rddOption = createRDDWithLocalProperties(time) { + val rddOption = createRDDWithLocalProperties(time, displayInnerRDDOps = false) { // Disable checks for existing output directories in jobs launched by the streaming // scheduler, since we may need to write output to an existing directory during checkpoint // recovery; see SPARK-4835 for more details. We need to have this call here because @@ -373,27 +364,52 @@ abstract class DStream[T: ClassTag] ( /** * Wrap a body of code such that the call site and operation scope * information are passed to the RDDs created in this body properly. - */ - protected def createRDDWithLocalProperties[U](time: Time)(body: => U): U = { + * @param body RDD creation code to execute with certain local properties. + * @param time Current batch time that should be embedded in the scope names + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the inner RDDs generated + * by `body` will be displayed in the UI; only the scope and callsite + * of the DStream operation that generated `this` will be displayed. + */ + protected[streaming] def createRDDWithLocalProperties[U]( + time: Time, + displayInnerRDDOps: Boolean)(body: => U): U = { val scopeKey = SparkContext.RDD_SCOPE_KEY val scopeNoOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY // Pass this DStream's operation scope and creation site information to RDDs through // thread-local properties in our SparkContext. Since this method may be called from another // DStream, we need to temporarily store any old scope and creation site information to // restore them later after setting our own. - val prevCallSite = ssc.sparkContext.getCallSite() + val prevCallSite = CallSite( + ssc.sparkContext.getLocalProperty(CallSite.SHORT_FORM), + ssc.sparkContext.getLocalProperty(CallSite.LONG_FORM) + ) val prevScope = ssc.sparkContext.getLocalProperty(scopeKey) val prevScopeNoOverride = ssc.sparkContext.getLocalProperty(scopeNoOverrideKey) try { - ssc.sparkContext.setCallSite(creationSite) + if (displayInnerRDDOps) { + // Unset the short form call site, so that generated RDDs get their own + ssc.sparkContext.setLocalProperty(CallSite.SHORT_FORM, null) + ssc.sparkContext.setLocalProperty(CallSite.LONG_FORM, null) + } else { + // Set the callsite, so that the generated RDDs get the DStream's call site and + // the internal RDD call sites do not get displayed + ssc.sparkContext.setCallSite(creationSite) + } + // Use the DStream's base scope for this RDD so we can (1) preserve the higher level // DStream operation name, and (2) share this scope with other DStreams created in the // same operation. Disallow nesting so that low-level Spark primitives do not show up. // TODO: merge callsites with scopes so we can just reuse the code there makeScope(time).foreach { s => ssc.sparkContext.setLocalProperty(scopeKey, s.toJson) - ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + if (displayInnerRDDOps) { + // Allow inner RDDs to add inner scopes + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, null) + } else { + // Do not allow inner RDDs to override the scope set by DStream + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + } } body @@ -413,13 +429,12 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { - case Some(rdd) => { + case Some(rdd) => val jobFunc = () => { val emptyFunc = { (iterator: Iterator[T]) => {} } context.sparkContext.runJob(rdd, emptyFunc) } Some(new Job(time, jobFunc)) - } case None => None } } @@ -437,20 +452,20 @@ abstract class DStream[T: ClassTag] ( oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]") generatedRDDs --= oldRDDs.keys if (unpersistData) { - logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", ")) + logDebug(s"Unpersisting old RDDs: ${oldRDDs.values.map(_.id).mkString(", ")}") oldRDDs.values.foreach { rdd => rdd.unpersist(false) // Explicitly remove blocks of BlockRDD rdd match { case b: BlockRDD[_] => - logInfo("Removing blocks of RDD " + b + " of time " + time) + logInfo(s"Removing blocks of RDD $b of time $time") b.removeBlocks() case _ => } } } - logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) + logDebug(s"Cleared ${oldRDDs.size} RDDs that were older than " + + s"${time - rememberDuration}: ${oldRDDs.keys.mkString(", ")}") dependencies.foreach(_.clearMetadata(time)) } @@ -462,10 +477,10 @@ abstract class DStream[T: ClassTag] ( * this method to save custom checkpoint data. */ private[streaming] def updateCheckpointData(currentTime: Time) { - logDebug("Updating checkpoint data for time " + currentTime) + logDebug(s"Updating checkpoint data for time $currentTime") checkpointData.update(currentTime) dependencies.foreach(_.updateCheckpointData(currentTime)) - logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) + logDebug(s"Updated checkpoint data for time $currentTime: $checkpointData") } private[streaming] def clearCheckpointData(time: Time) { @@ -482,22 +497,25 @@ abstract class DStream[T: ClassTag] ( * override the updateCheckpointData() method would also need to override this method. */ private[streaming] def restoreCheckpointData() { - // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data") - checkpointData.restore() - dependencies.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") + if (!restoredFromCheckpointData) { + // Create RDDs from the checkpoint data + logInfo("Restoring checkpoint data") + checkpointData.restore() + dependencies.foreach(_.restoreCheckpointData()) + restoredFromCheckpointData = true + logInfo("Restored checkpoint data") + } } @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { - logDebug(this.getClass().getSimpleName + ".writeObject used") + logDebug(s"${this.getClass().getSimpleName}.writeObject used") if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { oos.defaultWriteObject() } else { - val msg = "Object of " + this.getClass.getName + " is being serialized " + + val msg = s"Object of ${this.getClass.getName} is being serialized " + " possibly as a part of closure of an RDD operation. This is because " + " the DStream object is being referred to from within the closure. " + " Please rewrite the RDD operation inside this DStream to avoid this. " + @@ -514,9 +532,9 @@ abstract class DStream[T: ClassTag] ( @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { - logDebug(this.getClass().getSimpleName + ".readObject used") + logDebug(s"${this.getClass().getSimpleName}.readObject used") ois.defaultReadObject() - generatedRDDs = new HashMap[Time, RDD[T]] () + generatedRDDs = new HashMap[Time, RDD[T]]() } // ======================================================================= @@ -532,7 +550,7 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ - def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = ssc.withScope { + def flatMap[U: ClassTag](flatMapFunc: T => TraversableOnce[U]): DStream[U] = ssc.withScope { new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) } @@ -600,35 +618,13 @@ abstract class DStream[T: ClassTag] ( this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) } - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of 0.9.0, replaced by `foreachRDD`. - */ - @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { - this.foreachRDD(foreachFunc) - } - - /** - * Apply a function to each RDD in this DStream. This is an output operator, so - * 'this' DStream will be registered as an output stream and therefore materialized. - * - * @deprecated As of 0.9.0, replaced by `foreachRDD`. - */ - @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { - this.foreachRDD(foreachFunc) - } - /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ def foreachRDD(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { val cleanedF = context.sparkContext.clean(foreachFunc, false) - this.foreachRDD((r: RDD[T], t: Time) => cleanedF(r)) + foreachRDD((r: RDD[T], t: Time) => cleanedF(r), displayInnerRDDOps = true) } /** @@ -639,7 +635,23 @@ abstract class DStream[T: ClassTag] ( // because the DStream is reachable from the outer object here, and because // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean - new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register() + foreachRDD(foreachFunc, displayInnerRDDOps = true) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + * @param foreachFunc foreachRDD function + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the RDDs generated + * in the `foreachFunc` to be displayed in the UI. If `false`, then + * only the scopes and callsites of `foreachRDD` will override those + * of the RDDs on the display. + */ + private def foreachRDD( + foreachFunc: (RDD[T], Time) => Unit, + displayInnerRDDOps: Boolean): Unit = { + new ForEachDStream(this, + context.sparkContext.clean(foreachFunc, false), displayInnerRDDOps).register() } /** @@ -722,7 +734,7 @@ abstract class DStream[T: ClassTag] ( val firstNum = rdd.take(num + 1) // scalastyle:off println println("-------------------------------------------") - println("Time: " + time) + println(s"Time: $time") println("-------------------------------------------") firstNum.take(num).foreach(println) if (firstNum.length > num) println("...") @@ -730,7 +742,7 @@ abstract class DStream[T: ClassTag] ( // scalastyle:on println } } - new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() + foreachRDD(context.sparkContext.clean(foreachFunc), displayInnerRDDOps = false) } /** @@ -757,7 +769,7 @@ abstract class DStream[T: ClassTag] ( /** * Return a new DStream in which each RDD has a single element generated by reducing all * elements in a sliding window over this DStream. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -780,7 +792,7 @@ abstract class DStream[T: ClassTag] ( * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient than reduceByWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -869,21 +881,19 @@ abstract class DStream[T: ClassTag] ( val alignedToTime = if ((toTime - zeroTime).isMultipleOf(slideDuration)) { toTime } else { - logWarning("toTime (" + toTime + ") is not a multiple of slideDuration (" - + slideDuration + ")") - toTime.floor(slideDuration, zeroTime) + logWarning(s"toTime ($toTime) is not a multiple of slideDuration ($slideDuration)") + toTime.floor(slideDuration, zeroTime) } val alignedFromTime = if ((fromTime - zeroTime).isMultipleOf(slideDuration)) { fromTime } else { - logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" - + slideDuration + ")") + logWarning(s"fromTime ($fromTime) is not a multiple of slideDuration ($slideDuration)") fromTime.floor(slideDuration, zeroTime) } - logInfo("Slicing from " + fromTime + " to " + toTime + - " (aligned to " + alignedFromTime + " and " + alignedToTime + ")") + logInfo(s"Slicing from $fromTime to $toTime" + + s" (aligned to $alignedFromTime and $alignedToTime)") alignedFromTime.to(alignedToTime, slideDuration).flatMap(time => { if (time >= zeroTime) getOrCompute(time) else None @@ -900,7 +910,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) } - this.foreachRDD(saveFunc) + this.foreachRDD(saveFunc, displayInnerRDDOps = false) } /** @@ -913,7 +923,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) } - this.foreachRDD(saveFunc) + this.foreachRDD(saveFunc, displayInnerRDDOps = false) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 39fd21342813e..e73837eb9602f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -17,17 +17,19 @@ package org.apache.spark.streaming.dstream +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import java.io.{ObjectOutputStream, ObjectInputStream, IOException} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.FileSystem -import org.apache.spark.Logging + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.Logging import org.apache.spark.streaming.Time import org.apache.spark.util.Utils private[streaming] -class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) +class DStreamCheckpointData[T: ClassTag](dstream: DStream[T]) extends Serializable with Logging { protected val data = new HashMap[Time, AnyRef]() @@ -37,13 +39,13 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) // in that batch's checkpoint data @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time] - @transient private var fileSystem : FileSystem = null + @transient private var fileSystem: FileSystem = null protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]] /** * Updates the checkpoint data of the DStream. This gets called every time * the graph checkpoint is initiated. Default implementation records the - * checkpoint files to which the generate RDDs of the DStream has been saved. + * checkpoint files at which the generated RDDs of the DStream have been saved. */ def update(time: Time) { @@ -101,16 +103,15 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) /** * Restore the checkpoint data. This gets called once when the DStream graph - * (along with its DStreams) are being restored from a graph checkpoint file. + * (along with its output DStreams) is being restored from a graph checkpoint file. * Default implementation restores the RDDs from their checkpoint files. */ def restore() { // Create RDDs from the checkpoint data currentCheckpointFiles.foreach { - case(time, file) => { + case(time, file) => logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 40208a64861fb..36f50e04db422 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * class remembers the information about the files selected in past batches for * a certain duration (say, "remember window") as shown in the figure below. * + * {{{ * |<----- remember window ----->| * ignore threshold --->| |<--- current batch time * |____.____.____.____.____.____| @@ -49,6 +50,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * ---------------------|----|----|----|----|----|----|-----------------------> Time * |____|____|____|____|____|____| * remembered batches + * }}} * * The trailing end of the window is the "ignore threshold" and all files whose mod times * are less than this threshold are assumed to have already been selected and are therefore @@ -59,24 +61,25 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * `isNewFile` for more details. * * This makes some assumptions from the underlying file system that the system is monitoring. - * - The clock of the file system is assumed to synchronized with the clock of the machine running - * the streaming app. - * - If a file is to be visible in the directory listings, it must be visible within a certain - * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be - * selected as the mod time will be less than the ignore threshold when it becomes visible. - * - Once a file is visible, the mod time cannot change. If it does due to appends, then the - * processing semantics are undefined. + * + * - The clock of the file system is assumed to synchronized with the clock of the machine running + * the streaming app. + * - If a file is to be visible in the directory listings, it must be visible within a certain + * duration of the mod time of the file. This duration is the "remember window", which is set to + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be + * selected as the mod time will be less than the ignore threshold when it becomes visible. + * - Once a file is visible, the mod time cannot change. If it does due to appends, then the + * processing semantics are undefined. */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( - ssc_ : StreamingContext, + _ssc: StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true, conf: Option[Configuration] = None) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) - extends InputDStream[(K, V)](ssc_) { + extends InputDStream[(K, V)](_ssc) { private val serializableConfOpt = conf.map(new SerializableConfiguration(_)) @@ -114,7 +117,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( // Map of batch-time to selected file info for the remembered batches // This is a concurrent map because it's also accessed in unit tests @transient private[streaming] var batchTimeToSelectedFiles = - new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] + new mutable.HashMap[Time, Array[String]] // Set of files that were selected in the remembered batches @transient private var recentlySelectedFiles = new mutable.HashSet[String]() @@ -125,8 +128,8 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( // Timestamp of the last round of finding files @transient private var lastNewFileFindingTime = 0L - @transient private var path_ : Path = null - @transient private var fs_ : FileSystem = null + @transient private var _path: Path = null + @transient private var _fs: FileSystem = null override def start() { } @@ -145,7 +148,9 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( // Find new files val newFiles = findNewFiles(validTime.milliseconds) logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) - batchTimeToSelectedFiles += ((validTime, newFiles)) + batchTimeToSelectedFiles.synchronized { + batchTimeToSelectedFiles += ((validTime, newFiles)) + } recentlySelectedFiles ++= newFiles val rdds = Some(filesToRDD(newFiles)) // Copy newFiles to immutable.List to prevent from being modified by the user @@ -159,14 +164,15 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( /** Clear the old time-to-files mappings along with old RDDs */ protected[streaming] override def clearMetadata(time: Time) { - super.clearMetadata(time) - val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) - batchTimeToSelectedFiles --= oldFiles.keys - recentlySelectedFiles --= oldFiles.values.flatten - logInfo("Cleared " + oldFiles.size + " old files that were older than " + - (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) - logDebug("Cleared files are:\n" + - oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + batchTimeToSelectedFiles.synchronized { + val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) + batchTimeToSelectedFiles --= oldFiles.keys + recentlySelectedFiles --= oldFiles.values.flatten + logInfo("Cleared " + oldFiles.size + " old files that were older than " + + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) + logDebug("Cleared files are:\n" + + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + } // Delete file mod times that weren't accessed in the last round of getting new files fileToModTime.clearOldValues(lastNewFileFindingTime - 1) } @@ -270,7 +276,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( config) case None => context.sparkContext.newAPIHadoopFile[K, V, F](file) } - if (rdd.partitions.size == 0) { + if (rdd.partitions.isEmpty) { logError("File " + file + " has no data in it. Spark Streaming can only ingest " + "files that have been \"moved\" to the directory assigned to the file stream. " + "Refer to the streaming programming guide for more details.") @@ -286,17 +292,17 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( } private def directoryPath: Path = { - if (path_ == null) path_ = new Path(directory) - path_ + if (_path == null) _path = new Path(directory) + _path } private def fs: FileSystem = { - if (fs_ == null) fs_ = directoryPath.getFileSystem(ssc.sparkContext.hadoopConfiguration) - fs_ + if (_fs == null) _fs = directoryPath.getFileSystem(ssc.sparkContext.hadoopConfiguration) + _fs } private def reset() { - fs_ = null + _fs = null } @throws(classOf[IOException]) @@ -304,8 +310,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() - batchTimeToSelectedFiles = - new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] + batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() fileToModTime = new TimeStampedHashMap[String, Long](true) } @@ -321,21 +326,20 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( override def update(time: Time) { hadoopFiles.clear() - hadoopFiles ++= batchTimeToSelectedFiles + batchTimeToSelectedFiles.synchronized { hadoopFiles ++= batchTimeToSelectedFiles } } override def cleanup(time: Time) { } override def restore() { hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach { - case (t, f) => { + case (t, f) => // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + f.mkString("[", ", ", "]") ) - batchTimeToSelectedFiles += ((t, f)) + batchTimeToSelectedFiles.synchronized { batchTimeToSelectedFiles += ((t, f)) } recentlySelectedFiles ++= f generatedRDDs += ((t, filesToRDD(f))) - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index fcd5216f101af..43079880b2352 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FilteredDStream[T: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 9d09a3baf37ca..f5b1e5f3a1454 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index 475ea2d2d4f38..d60a6179782e0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -17,14 +17,15 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FlatMappedDStream[T: ClassTag, U: ClassTag]( parent: DStream[T], - flatMapFunc: T => Traversable[U] + flatMapFunc: T => TraversableOnce[U] ) extends DStream[U](parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index c109ceccc6989..a0fadee8a9844 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -17,15 +17,25 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.scheduler.Job -import scala.reflect.ClassTag +/** + * An internal DStream used to represent output operations like DStream.foreachRDD. + * @param parent Parent DStream + * @param foreachFunc Function to apply on each RDD generated by the parent DStream + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the RDDs generated + * by `foreachFunc` will be displayed in the UI; only the scope and + * callsite of `DStream.foreachRDD` will be displayed. + */ private[streaming] class ForEachDStream[T: ClassTag] ( parent: DStream[T], - foreachFunc: (RDD[T], Time) => Unit + foreachFunc: (RDD[T], Time) => Unit, + displayInnerRDDOps: Boolean ) extends DStream[Unit](parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) @@ -37,8 +47,7 @@ class ForEachDStream[T: ClassTag] ( override def generateJob(time: Time): Option[Job] = { parent.getOrCompute(time) match { case Some(rdd) => - val jobFunc = () => createRDDWithLocalProperties(time) { - ssc.sparkContext.setCallSite(creationSite) + val jobFunc = () => createRDDWithLocalProperties(time, displayInnerRDDOps) { foreachFunc(rdd, time) } Some(new Job(time, jobFunc)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index dbb295fe54f71..9f1252f091a63 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class GlommedDStream[T: ClassTag](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 95994c983c0cc..a3c125c306954 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -23,12 +23,12 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.streaming.{Duration, StreamingContext, Time} import org.apache.spark.streaming.scheduler.RateController -import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils /** * This is the abstract base class for all input streams. This class provides methods - * start() and stop() which is called by Spark Streaming system to start and stop receiving data. + * start() and stop() which are called by Spark Streaming system to start and stop + * receiving data, respectively. * Input streams that can generate RDDs from new data by running a service/thread only on * the driver node (that is, without running a receiver on worker nodes), can be * implemented by directly inheriting this InputDStream. For example, @@ -37,10 +37,10 @@ import org.apache.spark.util.Utils * that requires running a receiver on the worker nodes, use * [[org.apache.spark.streaming.dstream.ReceiverInputDStream]] as the parent class. * - * @param ssc_ Streaming context that will execute this input stream + * @param _ssc Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) - extends DStream[T](ssc_) { +abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) + extends DStream[T](_ssc) { private[streaming] var lastValidTime: Time = null @@ -90,8 +90,8 @@ abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) } else { // Time is valid, but check it it is more than lastValidTime if (lastValidTime != null && time < lastValidTime) { - logWarning("isTimeValid called with " + time + " where as last valid time is " + - lastValidTime) + logWarning(s"isTimeValid called with $time whereas the last valid time " + + s"is $lastValidTime") } lastValidTime = time true @@ -107,8 +107,8 @@ abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) } /** Method called to start receiving data. Subclasses must implement this method. */ - def start() + def start(): Unit /** Method called to stop receiving data. Subclasses must implement this method. */ - def stop() + def stop(): Unit } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 5994bc1e23f2b..bcdf1752e61e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MapPartitionedDStream[T: ClassTag, U: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 954d2eb4a7b00..c209f86c864ac 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala new file mode 100644 index 0000000000000..ed08191f41cc8 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._ +import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord} + +/** + * :: Experimental :: + * DStream representing the stream of data generated by `mapWithState` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * Additionally, it also gives access to the stream of state snapshots, that is, the state data of + * all keys after a batch has updated them. + * + * @tparam KeyType Class of the key + * @tparam ValueType Class of the value + * @tparam StateType Class of the state data + * @tparam MappedType Class of the mapped data + */ +@Experimental +sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag]( + ssc: StreamingContext) extends DStream[MappedType](ssc) { + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] +} + +/** Internal implementation of the [[MapWithStateDStream]] */ +private[streaming] class MapWithStateDStreamImpl[ + KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag]( + dataStream: DStream[(KeyType, ValueType)], + spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType]) + extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) { + + private val internalStream = + new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec) + + override def slideDuration: Duration = internalStream.slideDuration + + override def dependencies: List[DStream[_]] = List(internalStream) + + override def compute(validTime: Time): Option[RDD[MappedType]] = { + internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } } + } + + /** + * Forward the checkpoint interval to the internal DStream that computes the state maps. This + * to make sure that this DStream does not get checkpointed, only the internal stream. + */ + override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = { + internalStream.checkpoint(checkpointInterval) + this + } + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] = { + internalStream.flatMap { + _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } + } + + def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass + + def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass + + def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass + + def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass +} + +/** + * A DStream that allows per-key state to be maintained, and arbitrary records to be generated + * based on updates to the state. This is the main DStream that implements the `mapWithState` + * operation on DStreams. + * + * @param parent Parent (key, value) stream that is the source + * @param spec Specifications of the mapWithState operation + * @tparam K Key type + * @tparam V Value type + * @tparam S Type of the state maintained + * @tparam E Type of the mapped data + */ +private[streaming] +class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E]) + extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) { + + persist(StorageLevel.MEMORY_ONLY) + + private val partitioner = spec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val mappingFunction = spec.getFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + /** Enable automatic checkpointing */ + override val mustCheckpoint = true + + /** Override the default checkpoint duration */ + override def initialize(time: Time): Unit = { + if (checkpointDuration == null) { + checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER + } + super.initialize(time) + } + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = { + // Get the previous state or create a new empty state RDD + val prevStateRDD = getOrCompute(validTime - slideDuration) match { + case Some(rdd) => + if (rdd.partitioner != Some(partitioner)) { + // If the RDD is not partitioned the right way, let us repartition it using the + // partition index as the key. This is to ensure that state RDD is always partitioned + // before creating another state RDD using it + MapWithStateRDD.createFromRDD[K, V, S, E]( + rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime) + } else { + rdd + } + case None => + MapWithStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime + ) + } + + + // Compute the new state RDD with previous state RDD and partitioned data RDD + // Even if there is no data RDD, use an empty one to create a new state RDD + val dataRDD = parent.getOrCompute(validTime).getOrElse { + context.sparkContext.emptyRDD[(K, V)] + } + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + Some(new MapWithStateRDD( + prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime)) + } +} + +private[streaming] object InternalMapWithStateDStream { + private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index fa14b2e897c3e..e11d82697af89 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MappedDStream[T: ClassTag, U: ClassTag] ( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 71bec96d46c8d..d6ff96e1fc696 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -25,8 +25,9 @@ import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner} +import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, Time} +import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} @@ -35,8 +36,7 @@ import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} */ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) - extends Serializable -{ + extends Serializable { private[streaming] def ssc = self.ssc private[streaming] def sparkContext = self.context.sparkContext @@ -75,8 +75,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the associative reduce function. Hash partitioning is used to generate the RDDs - * with Spark's default number of partitions. + * merged using the associative and commutative reduce function. Hash partitioning is used to + * generate the RDDs with Spark's default number of partitions. */ def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = ssc.withScope { reduceByKey(reduceFunc, defaultPartitioner()) @@ -204,7 +204,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate * the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ @@ -219,7 +219,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -238,7 +238,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -259,7 +259,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to * `DStream.reduceByKey()`, but applies it over a sliding window. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -289,7 +289,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -320,7 +320,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function + * @param reduceFunc associative and commutative reduce function * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -350,6 +350,41 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } + /** + * :: Experimental :: + * Return a [[MapWithStateDStream]] by applying a function to every key-value element of + * `this` stream, while maintaining some state data for each unique key. The mapping function + * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this + * transformation can be specified using [[StateSpec]] class. The state data is accessible in + * as a parameter of type [[State]] in the mapping function. + * + * Example of using `mapWithState`: + * {{{ + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * + * val spec = StateSpec.function(mappingFunction).numPartitions(10) + * + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) + * }}} + * + * @param spec Specification of this transformation + * @tparam StateType Class type of the state data + * @tparam MappedType Class type of the mapped data + */ + @Experimental + def mapWithState[StateType: ClassTag, MappedType: ClassTag]( + spec: StateSpec[K, V, StateType, MappedType] + ): MapWithStateDStream[K, V, StateType, MappedType] = { + new MapWithStateDStreamImpl[K, V, StateType, MappedType]( + self, + spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]] + ) + } + /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. @@ -411,7 +446,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * remember the partitioner despite the key being changed. * @param partitioner Partitioner for controlling the partitioning of each RDD in the new * DStream - * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. + * @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs. * @tparam S State type */ def updateStateByKey[S: ClassTag]( @@ -455,7 +490,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * remember the partitioner despite the key being changed. * @param partitioner Partitioner for controlling the partitioning of each RDD in the new * DStream - * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs. + * @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs. * @param initialRDD initial state value of each key. * @tparam S State type */ @@ -692,7 +727,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) - rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) + rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, + new JobConf(serializableConf.value)) } self.foreachRDD(saveFunc) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 002aac9f43617..e003ddb96c860 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -17,14 +17,15 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.StreamingContext import scala.reflect.ClassTag + +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver private[streaming] class PluggableInputDStream[T: ClassTag]( - ssc_ : StreamingContext, - receiver: Receiver[T]) extends ReceiverInputDStream[T](ssc_) { + _ssc: StreamingContext, + receiver: Receiver[T]) extends ReceiverInputDStream[T](_ssc) { def getReceiver(): Receiver[T] = { receiver diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index cd073646370d0..f9c78699164ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag import org.apache.spark.rdd.{RDD, UnionRDD} -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} private[streaming] class QueueInputDStream[T: ClassTag]( @@ -48,12 +48,15 @@ class QueueInputDStream[T: ClassTag]( override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() - if (oneAtATime && queue.size > 0) { - buffer += queue.dequeue() - } else { - buffer ++= queue.dequeueAll(_ => true) + queue.synchronized { + if (oneAtATime && queue.nonEmpty) { + buffer += queue.dequeue() + } else { + buffer ++= queue + queue.clear() + } } - if (buffer.size > 0) { + if (buffer.nonEmpty) { if (oneAtATime) { Some(buffer.head) } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index 5a9eda7c12776..b2ec33e82ddaa 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -17,19 +17,18 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.StreamingContext - -import scala.reflect.ClassTag - +import java.io.EOFException import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, SocketChannel} -import java.io.EOFException import java.util.concurrent.ArrayBlockingQueue -import org.apache.spark.streaming.receiver.Receiver +import scala.reflect.ClassTag + +import org.apache.spark.internal.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.receiver.Receiver /** * An input stream that reads blocks of serialized objects from a given network address. @@ -39,11 +38,11 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class RawInputDStream[T: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, host: String, port: Int, storageLevel: StorageLevel - ) extends ReceiverInputDStream[T](ssc_ ) with Logging { + ) extends ReceiverInputDStream[T](_ssc) with Logging { def getReceiver(): Receiver[T] = { new RawNetworkReceiver(host, port, storageLevel).asInstanceOf[Receiver[T]] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 87c20afd5c13c..fd3e72e41be26 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,25 +21,25 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.{RateController, ReceivedBlockInfo, StreamInputInfo} import org.apache.spark.streaming.scheduler.rate.RateEstimator -import org.apache.spark.streaming.scheduler.{ReceivedBlockInfo, RateController, StreamInputInfo} import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.streaming.{StreamingContext, Time} /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] * that has to start a receiver on worker nodes to receive external data. * Specific implementations of ReceiverInputDStream must - * define `the getReceiver()` function that gets the receiver object of type + * define [[getReceiver]] function that gets the receiver object of type * [[org.apache.spark.streaming.receiver.Receiver]] that will be sent * to the workers to receive data. - * @param ssc_ Streaming context that will execute this input stream + * @param _ssc Streaming context that will execute this input stream * @tparam T Class type of the object of this stream */ -abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) - extends InputDStream[T](ssc_) { +abstract class ReceiverInputDStream[T: ClassTag](_ssc: StreamingContext) + extends InputDStream[T](_ssc) { /** * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. @@ -108,7 +108,7 @@ abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) } else { // Else, create a BlockRDD. However, if there are some blocks with WAL info but not // others then that is unexpected and log a warning accordingly. - if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { + if (blockInfos.exists(_.walRecordHandleOption.nonEmpty)) { if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logError("Some blocks do not have Write Ahead Log information; " + "this is unexpected and data may not be recoverable after driver failures") @@ -119,9 +119,9 @@ abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) val validBlockIds = blockIds.filter { id => ssc.sparkContext.env.blockManager.master.contains(id) } - if (validBlockIds.size != blockIds.size) { + if (validBlockIds.length != blockIds.length) { logWarning("Some blocks could not be recovered as they were not found in memory. " + - "To prevent such data loss, enabled Write Ahead Log (see programming guide " + + "To prevent such data loss, enable Write Ahead Log (see programming guide " + "for more details.") } new BlockRDD[T](ssc.sc, validBlockIds) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 6a583bf2a3626..a9e93838b8bf7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -17,18 +17,14 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.{CoGroupedRDD, MapPartitionsRDD} +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.Partitioner -import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{CoGroupedRDD, RDD} import org.apache.spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer import org.apache.spark.streaming.{Duration, Interval, Time} -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - private[streaming] class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( parent: DStream[(K, V)], @@ -91,7 +87,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( logDebug("Window time = " + windowDuration) logDebug("Slide time = " + slideDuration) - logDebug("ZeroTime = " + zeroTime) + logDebug("Zero time = " + zeroTime) logDebug("Current window = " + currentWindow) logDebug("Previous window = " + previousWindow) @@ -132,7 +128,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( val numNewValues = newRDDs.size val mergeValues = (arrayOfValues: Array[Iterable[V]]) => { - if (arrayOfValues.size != 1 + numOldValues + numNewValues) { + if (arrayOfValues.length != 1 + numOldValues + numNewValues) { throw new Exception("Unexpected number of sequences of reduced values") } // Getting reduced values "old time steps" that will be removed from current window diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index e0ffd5d86b435..6971a66b380db 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.Partitioner import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{Duration, Time} -import scala.reflect.ClassTag private[streaming] class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index de84e0c9a498d..7853af562368e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -17,27 +17,27 @@ package org.apache.spark.streaming.dstream -import scala.util.control.NonFatal - -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.NextIterator +import java.io._ +import java.net.{ConnectException, Socket} +import java.nio.charset.StandardCharsets import scala.reflect.ClassTag +import scala.util.control.NonFatal -import java.io._ -import java.net.{UnknownHostException, Socket} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.NextIterator private[streaming] class SocketInputDStream[T: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, host: String, port: Int, bytesToObjects: InputStream => Iterator[T], storageLevel: StorageLevel - ) extends ReceiverInputDStream[T](ssc_) { + ) extends ReceiverInputDStream[T](_ssc) { def getReceiver(): Receiver[T] = { new SocketReceiver(host, port, bytesToObjects, storageLevel) @@ -52,7 +52,20 @@ class SocketReceiver[T: ClassTag]( storageLevel: StorageLevel ) extends Receiver[T](storageLevel) with Logging { + private var socket: Socket = _ + def onStart() { + + logInfo(s"Connecting to $host:$port") + try { + socket = new Socket(host, port) + } catch { + case e: ConnectException => + restart(s"Error connecting to $host:$port", e) + return + } + logInfo(s"Connected to $host:$port") + // Start the thread that receives data over a connection new Thread("Socket Receiver") { setDaemon(true) @@ -61,20 +74,22 @@ class SocketReceiver[T: ClassTag]( } def onStop() { - // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // in case restart thread close it twice + synchronized { + if (socket != null) { + socket.close() + socket = null + logInfo(s"Closed socket to $host:$port") + } + } } /** Create a socket connection and receive data until receiver is stopped */ def receive() { - var socket: Socket = null try { - logInfo("Connecting to " + host + ":" + port) - socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) val iterator = bytesToObjects(socket.getInputStream()) while(!isStopped && iterator.hasNext) { - store(iterator.next) + store(iterator.next()) } if (!isStopped()) { restart("Socket data stream had no more data") @@ -82,16 +97,11 @@ class SocketReceiver[T: ClassTag]( logInfo("Stopped receiving") } } catch { - case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) case NonFatal(e) => logWarning("Error receiving data", e) restart("Error receiving data", e) } finally { - if (socket != null) { - socket.close() - logInfo("Closed socket to " + host + ":" + port) - } + onStop() } } } @@ -104,7 +114,8 @@ object SocketReceiver { * to '\n' delimited strings and returns an iterator to access the strings. */ def bytesToLines(inputStream: InputStream): Iterator[String] = { - val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) + val dataInputStream = new BufferedReader( + new InputStreamReader(inputStream, StandardCharsets.UTF_8)) new NextIterator[String] { protected override def getNext() = { val nextValue = dataInputStream.readLine() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 621d6dff788f4..8efb09a8ce981 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -17,21 +17,20 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD +import scala.reflect.ClassTag + import org.apache.spark.Partitioner -import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, Time} -import scala.reflect.ClassTag - private[streaming] class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( parent: DStream[(K, V)], updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, preservePartitioning: Boolean, - initialRDD : Option[RDD[(K, S)]] + initialRDD: Option[RDD[(K, S)]] ) extends DStream[(K, S)](parent.ssc) { super.persist(StorageLevel.MEMORY_ONLY_SER) @@ -43,17 +42,17 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( override val mustCheckpoint = true private [this] def computeUsingPreviousRDD ( - parentRDD : RDD[(K, V)], prevStateRDD : RDD[(K, S)]) = { + parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = { // Define the function for the mapPartition operation on cogrouped RDD; // first map the cogrouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { - val i = iterator.map(t => { + val i = iterator.map { t => val itr = t._2._2.iterator val headOption = if (itr.hasNext) Some(itr.next()) else None (t._1, t._2._1.toSeq, headOption) - }) + } updateFuncLocal(i) } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) @@ -66,14 +65,12 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // Try to get the previous state RDD getOrCompute(validTime - slideDuration) match { - case Some(prevStateRDD) => { // If previous state RDD exists - + case Some(prevStateRDD) => // If previous state RDD exists // Try to get the parent RDD parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual - computeUsingPreviousRDD (parentRDD, prevStateRDD) - } - case None => { // If parent RDD does not exist + case Some(parentRDD) => // If parent RDD exists, then compute as usual + computeUsingPreviousRDD(parentRDD, prevStateRDD) + case None => // If parent RDD does not exist // Re-apply the update function to the old state RDD val updateFuncLocal = updateFunc @@ -83,41 +80,33 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( } val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) Some(stateRDD) - } } - } - - case None => { // If previous session RDD does not exist (first input data) + case None => // If previous session RDD does not exist (first input data) // Try to get the parent RDD parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual + case Some(parentRDD) => // If parent RDD exists, then compute as usual initialRDD match { - case None => { + case None => // Define the function for the mapPartition operation on grouped RDD; // first map the grouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val finalFunc = (iterator : Iterator[(K, Iterable[V])]) => { - updateFuncLocal (iterator.map (tuple => (tuple._1, tuple._2.toSeq, None))) + val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None))) } - val groupedRDD = parentRDD.groupByKey (partitioner) - val sessionRDD = groupedRDD.mapPartitions (finalFunc, preservePartitioning) + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) // logDebug("Generating state RDD for time " + validTime + " (first)") - Some (sessionRDD) - } - case Some (initialStateRDD) => { + Some(sessionRDD) + case Some(initialStateRDD) => computeUsingPreviousRDD(parentRDD, initialStateRDD) - } } - } - case None => { // If parent RDD does not exist, then nothing to do! + case None => // If parent RDD does not exist, then nothing to do! // logDebug("Not generating state RDD (no previous state, no parent)") None - } } - } } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 5eabdf63dc8d7..47eb9b806fa7d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -51,4 +51,17 @@ class TransformedDStream[U: ClassTag] ( } Some(transformedRDD) } + + /** + * Wrap a body of code such that the call site and operation scope + * information are passed to the RDDs created in this body properly. + * This has been overridden to make sure that `displayInnerRDDOps` is always `true`, that is, + * the inner scopes and callsites of RDDs generated in `DStream.transform` are always + * displayed in the UI. + */ + override protected[streaming] def createRDDWithLocalProperties[U]( + time: Time, + displayInnerRDDOps: Boolean)(body: => U): U = { + super.createRDDWithLocalProperties(time, displayInnerRDDOps = true)(body) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index d73ffdfd84d2d..d46c0a01e05d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -21,17 +21,16 @@ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark.SparkException -import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD +import org.apache.spark.streaming.{Duration, Time} private[streaming] class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) extends DStream[T](parents.head.ssc) { require(parents.length > 0, "List of DStreams to union is empty") - require(parents.map(_.ssc).distinct.size == 1, "Some of the DStreams have different contexts") - require(parents.map(_.slideDuration).distinct.size == 1, + require(parents.map(_.ssc).distinct.length == 1, "Some of the DStreams have different contexts") + require(parents.map(_.slideDuration).distinct.length == 1, "Some of the DStreams have different slide durations") override def dependencies: List[DStream[_]] = parents.toList @@ -45,8 +44,8 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) case None => throw new SparkException("Could not generate RDD from a parent for unifying at" + s" time $validTime") } - if (rdds.size > 0) { - Some(new UnionRDD(ssc.sc, rdds)) + if (rdds.nonEmpty) { + Some(ssc.sc.union(rdds)) } else { None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 4efba039f8959..fe0f875525660 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -17,13 +17,13 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.{PartitionerAwareUnionRDD, RDD, UnionRDD} +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.Duration -import scala.reflect.ClassTag - private[streaming] class WindowedDStream[T: ClassTag]( parent: DStream[T], @@ -63,13 +63,6 @@ class WindowedDStream[T: ClassTag]( override def compute(validTime: Time): Option[RDD[T]] = { val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime) val rddsInWindow = parent.slice(currentWindow) - val windowRDD = if (rddsInWindow.flatMap(_.partitioner).distinct.length == 1) { - logDebug("Using partition aware union for windowing at " + validTime) - new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow) - } else { - logDebug("Using normal union for windowing at " + validTime) - new UnionRDD(ssc.sc, rddsInWindow) - } - Some(windowRDD) + Some(ssc.sc.union(rddsInWindow)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala new file mode 100644 index 0000000000000..8119d808ffab3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.rdd + +import java.io.{IOException, ObjectOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{State, StateImpl, Time} +import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} +import org.apache.spark.util.Utils + +/** + * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a + * sequence of records returned by the mapping function of `mapWithState`. + */ +private[streaming] case class MapWithStateRDDRecord[K, S, E]( + var stateMap: StateMap[K, S], var mappedData: Seq[E]) + +private[streaming] object MapWithStateRDDRecord { + def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + prevRecord: Option[MapWithStateRDDRecord[K, S, E]], + dataIterator: Iterator[(K, V)], + mappingFunction: (Time, K, Option[V], State[S]) => Option[E], + batchTime: Time, + timeoutThresholdTime: Option[Long], + removeTimedoutData: Boolean + ): MapWithStateRDDRecord[K, S, E] = { + // Create a new state map by cloning the previous one (if it exists) or by creating an empty one + val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() } + + val mappedData = new ArrayBuffer[E] + val wrappedState = new StateImpl[S]() + + // Call the mapping function on each record in the data iterator, and accordingly + // update the states touched, and collect the data returned by the mapping function + dataIterator.foreach { case (key, value) => + wrappedState.wrap(newStateMap.get(key)) + val returned = mappingFunction(batchTime, key, Some(value), wrappedState) + if (wrappedState.isRemoved) { + newStateMap.remove(key) + } else if (wrappedState.isUpdated + || (wrappedState.exists && timeoutThresholdTime.isDefined)) { + newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) + } + mappedData ++= returned + } + + // Get the timed out state records, call the mapping function on each and collect the + // data returned + if (removeTimedoutData && timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + wrappedState.wrapTimingOutState(state) + val returned = mappingFunction(batchTime, key, None, wrappedState) + mappedData ++= returned + newStateMap.remove(key) + } + } + + MapWithStateRDDRecord(newStateMap, mappedData) + } +} + +/** + * Partition of the [[MapWithStateRDD]], which depends on corresponding partitions of prev state + * RDD, and a partitioned keyed-data RDD + */ +private[streaming] class MapWithStateRDDPartition( + idx: Int, + @transient private var prevStateRDD: RDD[_], + @transient private var partitionedDataRDD: RDD[_]) extends Partition { + + private[rdd] var previousSessionRDDPartition: Partition = null + private[rdd] var partitionedDataRDDPartition: Partition = null + + override def index: Int = idx + override def hashCode(): Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent split at the time of task serialization + previousSessionRDDPartition = prevStateRDD.partitions(index) + partitionedDataRDDPartition = partitionedDataRDD.partitions(index) + oos.defaultWriteObject() + } +} + + +/** + * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data. + * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a + * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping + * function of `mapWithState`. + * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD + * will be created + * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps + * in the `prevStateRDD` to create `this` RDD + * @param mappingFunction The function that will be used to update state and return new data + * @param batchTime The time of the batch to which this RDD belongs to. Use to update + * @param timeoutThresholdTime The time to indicate which keys are timeout + */ +private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]], + private var partitionedDataRDD: RDD[(K, V)], + mappingFunction: (Time, K, Option[V], State[S]) => Option[E], + batchTime: Time, + timeoutThresholdTime: Option[Long] + ) extends RDD[MapWithStateRDDRecord[K, S, E]]( + partitionedDataRDD.sparkContext, + List( + new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD), + new OneToOneDependency(partitionedDataRDD)) + ) { + + @volatile private var doFullScan = false + + require(prevStateRDD.partitioner.nonEmpty) + require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) + + override val partitioner = prevStateRDD.partitioner + + override def checkpoint(): Unit = { + super.checkpoint() + doFullScan = true + } + + override def compute( + partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = { + + val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition] + val prevStateRDDIterator = prevStateRDD.iterator( + stateRDDPartition.previousSessionRDDPartition, context) + val dataIterator = partitionedDataRDD.iterator( + stateRDDPartition.partitionedDataRDDPartition, context) + + val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None + val newRecord = MapWithStateRDDRecord.updateRecordWithData( + prevRecord, + dataIterator, + mappingFunction, + batchTime, + timeoutThresholdTime, + removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled + ) + Iterator(newRecord) + } + + override protected def getPartitions: Array[Partition] = { + Array.tabulate(prevStateRDD.partitions.length) { i => + new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + prevStateRDD = null + partitionedDataRDD = null + } + + def setFullScan(): Unit = { + doFullScan = true + } +} + +private[streaming] object MapWithStateRDD { + + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + pairRDD: RDD[(K, S)], + partitioner: Partitioner, + updateTime: Time): MapWithStateRDD[K, V, S, E] = { + + val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } + Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new MapWithStateRDD[K, V, S, E]( + stateRDD, emptyDataRDD, noOpFunc, updateTime, None) + } + + def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + rdd: RDD[(K, S, Long)], + partitioner: Partitioner, + updateTime: Time): MapWithStateRDD[K, V, S, E] = { + + val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) } + val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, (state, updateTime)) => + stateMap.put(key, state, updateTime) + } + Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new MapWithStateRDD[K, V, S, E]( + stateRDD, emptyDataRDD, noOpFunc, updateTime, None) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index f811784b25c82..53fccd8d5e6ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -28,11 +28,13 @@ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.io.ChunkedByteBuffer /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. * It contains information about the id of the blocks having this partition's data and * the corresponding record handle in the write ahead log that backs the partition. + * * @param index index of the partition * @param blockId id of the block having the partition data * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark @@ -59,7 +61,6 @@ class WriteAheadLogBackedBlockRDDPartition( * correctness, and it can be used in situations where it is known that the block * does not exist in the Spark executors (e.g. after a failed driver is restarted). * - * * @param sc SparkContext * @param _blockIds Ids of the blocks that contains this RDD's data * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data @@ -114,6 +115,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( assertValid() val hadoopConf = broadcastedHadoopConf.value val blockManager = SparkEnv.get.blockManager + val serializerManager = SparkEnv.get.serializerManager val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockId = partition.blockId @@ -156,11 +158,13 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, dataRead, storageLevel) + blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } - blockManager.dataDeserialize(blockId, dataRead).asInstanceOf[Iterator[T]] + serializerManager + .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream()) + .asInstanceOf[Iterator[T]] } if (partition.isBlockIdValid) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala deleted file mode 100644 index 7ec74016a1c2c..0000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.receiver - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicInteger - -import scala.concurrent.duration._ -import scala.language.postfixOps -import scala.reflect.ClassTag - -import akka.actor._ -import akka.actor.SupervisorStrategy.{Escalate, Restart} - -import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.storage.StorageLevel - -/** - * :: DeveloperApi :: - * A helper with set of defaults for supervisor strategy - */ -@DeveloperApi -object ActorSupervisorStrategy { - - val defaultStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange = - 15 millis) { - case _: RuntimeException => Restart - case _: Exception => Escalate - } -} - -/** - * :: DeveloperApi :: - * A receiver trait to be mixed in with your Actor to gain access to - * the API for pushing received data into Spark Streaming for being processed. - * - * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * - * @example {{{ - * class MyActor extends Actor with ActorHelper{ - * def receive { - * case anything: String => store(anything) - * } - * } - * - * // Can be used with an actorStream as follows - * ssc.actorStream[String](Props(new MyActor),"MyActorReceiver") - * - * }}} - * - * @note Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e parametrized type of push block and InputDStream - * should be same. - */ -@DeveloperApi -trait ActorHelper extends Logging{ - - self: Actor => // to ensure that this can be added to Actor classes only - - /** Store an iterator of received data as a data block into Spark's memory. */ - def store[T](iter: Iterator[T]) { - logDebug("Storing iterator") - context.parent ! IteratorData(iter) - } - - /** - * Store the bytes of received data as a data block into Spark's memory. Note - * that the data in the ByteBuffer must be serialized using the same serializer - * that Spark is configured to use. - */ - def store(bytes: ByteBuffer) { - logDebug("Storing Bytes") - context.parent ! ByteBufferData(bytes) - } - - /** - * Store a single item of received data to Spark's memory. - * These single items will be aggregated together into data blocks before - * being pushed into Spark's memory. - */ - def store[T](item: T) { - logDebug("Storing item") - context.parent ! SingleItemData(item) - } -} - -/** - * :: DeveloperApi :: - * Statistics for querying the supervisor about state of workers. Used in - * conjunction with `StreamingContext.actorStream` and - * [[org.apache.spark.streaming.receiver.ActorHelper]]. - */ -@DeveloperApi -case class Statistics(numberOfMsgs: Int, - numberOfWorkers: Int, - numberOfHiccups: Int, - otherInfo: String) - -/** Case class to receive data sent by child actors */ -private[streaming] sealed trait ActorReceiverData -private[streaming] case class SingleItemData[T](item: T) extends ActorReceiverData -private[streaming] case class IteratorData[T](iterator: Iterator[T]) extends ActorReceiverData -private[streaming] case class ByteBufferData(bytes: ByteBuffer) extends ActorReceiverData - -/** - * Provides Actors as receivers for receiving stream. - * - * As Actors can also be used to receive data from almost any stream source. - * A nice set of abstraction(s) for actors as receivers is already provided for - * a few general cases. It is thus exposed as an API where user may come with - * their own Actor to run as receiver for Spark Streaming input source. - * - * This starts a supervisor actor which starts workers and also provides - * [http://doc.akka.io/docs/akka/snapshot/scala/fault-tolerance.html fault-tolerance]. - * - * Here's a way to start more supervisor/workers as its children. - * - * @example {{{ - * context.parent ! Props(new Supervisor) - * }}} OR {{{ - * context.parent ! Props(new Worker, "Worker") - * }}} - */ -private[streaming] class ActorReceiver[T: ClassTag]( - props: Props, - name: String, - storageLevel: StorageLevel, - receiverSupervisorStrategy: SupervisorStrategy - ) extends Receiver[T](storageLevel) with Logging { - - protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), - "Supervisor" + streamId) - - class Supervisor extends Actor { - - override val supervisorStrategy = receiverSupervisorStrategy - private val worker = context.actorOf(props, name) - logInfo("Started receiver worker at:" + worker.path) - - private val n: AtomicInteger = new AtomicInteger(0) - private val hiccups: AtomicInteger = new AtomicInteger(0) - - override def receive: PartialFunction[Any, Unit] = { - - case IteratorData(iterator) => - logDebug("received iterator") - store(iterator.asInstanceOf[Iterator[T]]) - - case SingleItemData(msg) => - logDebug("received single") - store(msg.asInstanceOf[T]) - n.incrementAndGet - - case ByteBufferData(bytes) => - logDebug("received bytes") - store(bytes) - - case props: Props => - val worker = context.actorOf(props) - logInfo("Started receiver worker at:" + worker.path) - sender ! worker - - case (props: Props, name: String) => - val worker = context.actorOf(props, name) - logInfo("Started receiver worker at:" + worker.path) - sender ! worker - - case _: PossiblyHarmful => hiccups.incrementAndGet() - - case _: Statistics => - val workers = context.children - sender ! Statistics(n.get, workers.size, hiccups.get, workers.mkString("\n")) - - } - } - - def onStart(): Unit = { - actorSupervisor - logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path) - } - - def onStop(): Unit = { - actorSupervisor ! PoisonPill - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 421d60ae359f8..4592e015ed9a0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -21,7 +21,8 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, SystemClock} @@ -36,7 +37,7 @@ private[streaming] trait BlockGeneratorListener { * that will be useful when a block is generated. Any long blocking operation in this callback * will hurt the throughput. */ - def onAddData(data: Any, metadata: Any) + def onAddData(data: Any, metadata: Any): Unit /** * Called when a new block of data is generated by the block generator. The block generation @@ -46,7 +47,7 @@ private[streaming] trait BlockGeneratorListener { * be useful when the block has been successfully stored. Any long blocking operation in this * callback will hurt the throughput. */ - def onGenerateBlock(blockId: StreamBlockId) + def onGenerateBlock(blockId: StreamBlockId): Unit /** * Called when a new block is ready to be pushed. Callers are supposed to store the block into @@ -54,13 +55,13 @@ private[streaming] trait BlockGeneratorListener { * thread, that is not synchronized with any other callbacks. Hence it is okay to do long * blocking operation in this callback. */ - def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) + def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit /** * Called when an error has occurred in the BlockGenerator. Can be called form many places * so better to not do any long block operation in this callback. */ - def onError(message: String, throwable: Throwable) + def onError(message: String, throwable: Throwable): Unit } /** @@ -84,13 +85,14 @@ private[streaming] class BlockGenerator( /** * The BlockGenerator can be in 5 possible states, in the order as follows. - * - Initialized: Nothing has been started - * - Active: start() has been called, and it is generating blocks on added data. - * - StoppedAddingData: stop() has been called, the adding of data has been stopped, - * but blocks are still being generated and pushed. - * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but - * they are still being pushed. - * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + * + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. */ private object GeneratorState extends Enumeration { type GeneratorState = Value @@ -125,9 +127,10 @@ private[streaming] class BlockGenerator( /** * Stop everything in the right order such that all the data added is pushed out correctly. - * - First, stop adding data to the current buffer. - * - Second, stop generating blocks. - * - Finally, wait for queue of to-be-pushed blocks to be drained. + * + * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. */ def stop(): Unit = { // Set the state to stop adding data diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index bca1fbc8fda2f..fbac4880bdf65 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -19,24 +19,26 @@ package org.apache.spark.streaming.receiver import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging -/** Provides waitToPush() method to limit the rate at which receivers consume data. - * - * waitToPush method will block the thread if too many messages have been pushed too quickly, - * and only return when a new message has been pushed. It assumes that only one message is - * pushed at a time. - * - * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages - * per second that each receiver will accept. - * - * @param conf spark configuration - */ +/** + * Provides waitToPush() method to limit the rate at which receivers consume data. + * + * waitToPush method will block the thread if too many messages have been pushed too quickly, + * and only return when a new message has been pushed. It assumes that only one message is + * pushed at a time. + * + * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages + * per second that each receiver will accept. + * + * @param conf spark configuration + */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { // treated as an upper limit private val maxRateLimit = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue) - private lazy val rateLimiter = GuavaRateLimiter.create(maxRateLimit.toDouble) + private lazy val rateLimiter = GuavaRateLimiter.create(getInitialRateLimit().toDouble) def waitToPush() { rateLimiter.acquire() @@ -51,7 +53,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by * {{{spark.streaming.receiver.maxRate}}}, even if `newRate` is higher than that. * - * @param newRate A new rate in events per second. It has no effect if it's 0 or negative. + * @param newRate A new rate in records per second. It has no effect if it's 0 or negative. */ private[receiver] def updateRate(newRate: Long): Unit = if (newRate > 0) { @@ -61,4 +63,11 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { rateLimiter.setRate(newRate) } } + + /** + * Get the initial rateLimit to initial rateLimiter + */ + private def getInitialRateLimit(): Long = { + math.min(conf.getLong("spark.streaming.backpressure.initialRate", maxRateLimit), maxRateLimit) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 5f6c5b024085c..7aea1c9b64f5c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -17,18 +17,21 @@ package org.apache.spark.streaming.receiver -import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ import scala.language.{existentials, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.util.io.ChunkedByteBuffer /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -45,7 +48,7 @@ private[streaming] trait ReceivedBlockHandler { def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): ReceivedBlockStoreResult /** Cleanup old blocks older than the given threshold time */ - def cleanupOldBlocks(threshTime: Long) + def cleanupOldBlocks(threshTime: Long): Unit } @@ -69,9 +72,9 @@ private[streaming] class BlockManagerBasedBlockHandler( def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { - var numRecords = None: Option[Long] + var numRecords: Option[Long] = None - val putResult: Seq[(BlockId, BlockStatus)] = block match { + val putSucceeded: Boolean = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = Some(arrayBuffer.size.toLong) blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, @@ -83,12 +86,13 @@ private[streaming] class BlockManagerBasedBlockHandler( numRecords = countIterator.count putResult case ByteBufferBlock(byteBuffer) => - blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) + blockManager.putBytes( + blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true) case o => throw new SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") } - if (!putResult.map { _._1 }.contains(blockId)) { + if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } @@ -120,6 +124,7 @@ private[streaming] case class WriteAheadLogBasedStoreResult( */ private[streaming] class WriteAheadLogBasedBlockHandler( blockManager: BlockManager, + serializerManager: SerializerManager, streamId: Int, storageLevel: StorageLevel, conf: SparkConf, @@ -170,23 +175,26 @@ private[streaming] class WriteAheadLogBasedBlockHandler( val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = Some(arrayBuffer.size.toLong) - blockManager.dataSerialize(blockId, arrayBuffer.iterator) + serializerManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => val countIterator = new CountingIterator(iterator) - val serializedBlock = blockManager.dataSerialize(blockId, countIterator) + val serializedBlock = serializerManager.dataSerialize(blockId, countIterator) numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => - byteBuffer + new ChunkedByteBuffer(byteBuffer.duplicate()) case _ => throw new Exception(s"Could not push $blockId to block manager, unexpected block type") } // Store the block in block manager val storeInBlockManagerFuture = Future { - val putResult = - blockManager.putBytes(blockId, serializedBlock, effectiveStorageLevel, tellMaster = true) - if (!putResult.map { _._1 }.contains(blockId)) { + val putSucceeded = blockManager.putBytes( + blockId, + serializedBlock, + effectiveStorageLevel, + tellMaster = true) + if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } @@ -194,7 +202,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Store the block in write ahead log val storeInWriteAheadLogFuture = Future { - writeAheadLog.write(serializedBlock, clock.getTimeMillis()) + writeAheadLog.write(serializedBlock.toByteBuffer, clock.getTimeMillis()) } // Combine the futures, wait for both to complete, and return the write ahead log record handle diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 2252e28f22af8..5157ca62dc449 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -22,8 +22,8 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -99,16 +99,16 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * (iii) `restart(...)` can be called to restart the receiver. This will call `onStop()` * immediately, and then `onStart()` after a delay. */ - def onStart() + def onStart(): Unit /** * This method is called by the system when the receiver is stopped. All resources - * (threads, buffers, etc.) setup in `onStart()` must be cleaned up in this method. + * (threads, buffers, etc.) set up in `onStart()` must be cleaned up in this method. */ - def onStop() + def onStop(): Unit /** Override this to specify a preferred location (hostname). */ - def preferredLocation : Option[String] = None + def preferredLocation: Option[String] = None /** * Store a single item of received data to Spark's memory. @@ -257,11 +257,11 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable private var id: Int = -1 /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ - @transient private var _supervisor : ReceiverSupervisor = null + @transient private var _supervisor: ReceiverSupervisor = null /** Set the ID of the DStream that this receiver is associated with. */ - private[streaming] def setReceiverId(id_ : Int) { - id = id_ + private[streaming] def setReceiverId(_id: Int) { + id = _id } /** Attach Network Receiver executor to this receiver. */ @@ -273,7 +273,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable /** Get the attached supervisor. */ private[streaming] def supervisor: ReceiverSupervisor = { assert(_supervisor != null, - "A ReceiverSupervisor have not been attached to the receiver yet. Maybe you are starting " + + "A ReceiverSupervisor has not been attached to the receiver yet. Maybe you are starting " + "some computation in the receiver before the Receiver.onStart() has been called.") _supervisor } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 158d1ba2f183a..42fc84c19b971 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -24,9 +24,10 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.util.control.NonFatal -import org.apache.spark.{SparkEnv, Logging, SparkConf} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId -import org.apache.spark.util.{Utils, ThreadUtils} +import org.apache.spark.util.{ThreadUtils, Utils} /** * Abstract class that is responsible for supervising a Receiver in the worker. @@ -69,28 +70,28 @@ private[streaming] abstract class ReceiverSupervisor( @volatile private[streaming] var receiverState = Initialized /** Push a single data item to backend data store. */ - def pushSingle(data: Any) + def pushSingle(data: Any): Unit /** Store the bytes of received data as a data block into Spark's memory. */ def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** Store a iterator of received data as a data block into Spark's memory. */ def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** * Create a custom [[BlockGenerator]] that the receiver implementation can directly control @@ -102,7 +103,7 @@ private[streaming] abstract class ReceiverSupervisor( def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator /** Report errors. */ - def reportError(message: String, throwable: Throwable) + def reportError(message: String, throwable: Throwable): Unit /** * Called when supervisor is started. @@ -143,10 +144,10 @@ private[streaming] abstract class ReceiverSupervisor( def startReceiver(): Unit = synchronized { try { if (onReceiverStart()) { - logInfo("Starting receiver") + logInfo(s"Starting receiver $streamId") receiverState = Started receiver.onStart() - logInfo("Called receiver onStart") + logInfo(s"Called receiver $streamId onStart") } else { // The driver refused us stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None) @@ -174,7 +175,7 @@ private[streaming] abstract class ReceiverSupervisor( } } catch { case NonFatal(t) => - logError("Error stopping receiver " + streamId + t.getStackTraceString) + logError(s"Error stopping receiver $streamId ${Utils.exceptionString(t)}") } } @@ -218,11 +219,9 @@ private[streaming] abstract class ReceiverSupervisor( stopLatch.await() if (stoppingError != null) { logError("Stopped receiver with error: " + stoppingError) + throw stoppingError } else { logInfo("Stopped receiver without error") } - if (stoppingError != null) { - throw stoppingError - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 167f56aa42281..4fb0f8caacbb6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -19,20 +19,22 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.util.RpcUtils -import org.apache.spark.{Logging, SparkEnv, SparkException} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -58,7 +60,7 @@ private[streaming] class ReceiverSupervisorImpl( "Please use streamingContext.checkpoint() to set the checkpoint directory. " + "See documentation for more details.") } - new WriteAheadLogBasedBlockHandler(env.blockManager, receiver.streamId, + new WriteAheadLogBasedBlockHandler(env.blockManager, env.serializerManager, receiver.streamId, receiver.storageLevel, env.conf, hadoopConf, checkpointDirOption.get) } else { new BlockManagerBasedBlockHandler(env.blockManager, receiver.storageLevel) @@ -83,7 +85,7 @@ private[streaming] class ReceiverSupervisorImpl( cleanupOldBlocks(threshTime) case UpdateRateLimit(eps) => logInfo(s"Received a new rate limit: $eps.") - registeredBlockGenerators.foreach { bg => + registeredBlockGenerators.asScala.foreach { bg => bg.updateRate(eps) } } @@ -92,8 +94,7 @@ private[streaming] class ReceiverSupervisorImpl( /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) - private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator] - with mutable.SynchronizedBuffer[BlockGenerator] + private val registeredBlockGenerators = new ConcurrentLinkedQueue[BlockGenerator]() /** Divides received data records into data blocks for pushing in BlockManager. */ private val defaultBlockGeneratorListener = new BlockGeneratorListener { @@ -170,11 +171,11 @@ private[streaming] class ReceiverSupervisorImpl( } override protected def onStart() { - registeredBlockGenerators.foreach { _.start() } + registeredBlockGenerators.asScala.foreach { _.start() } } override protected def onStop(message: String, error: Option[Throwable]) { - registeredBlockGenerators.foreach { _.stop() } + registeredBlockGenerators.asScala.foreach { _.stop() } env.rpcEnv.stop(endpoint) } @@ -194,10 +195,11 @@ private[streaming] class ReceiverSupervisorImpl( override def createBlockGenerator( blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { // Cleanup BlockGenerators that have already been stopped - registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() } + val stoppedGenerators = registeredBlockGenerators.asScala.filter{ _.isStopped() } + stoppedGenerators.foreach(registeredBlockGenerators.remove(_)) val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf) - registeredBlockGenerators += newBlockGenerator + registeredBlockGenerators.add(newBlockGenerator) newBlockGenerator } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 436eb0a566141..5b2b959f8138d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -41,9 +41,6 @@ case class BatchInfo( outputOperationInfos: Map[Int, OutputOperationInfo] ) { - @deprecated("Use streamIdToInputInfo instead", "1.5.0") - def streamIdToNumRecords: Map[Int, Long] = streamIdToInputInfo.mapValues(_.numRecords) - /** * Time taken for the first job of this batch to start processing from the time this batch * was submitted to the streaming scheduler. Essentially, it is diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala new file mode 100644 index 0000000000000..f7b6584893c6e --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming.scheduler + +import scala.util.Random + +import org.apache.spark.{ExecutorAllocationClient, SparkConf} +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, Utils} + +/** + * Class that manages executor allocated to a StreamingContext, and dynamically request or kill + * executors based on the statistics of the streaming computation. This is different from the core + * dynamic allocation policy; the core policy relies on executors being idle for a while, but the + * micro-batch model of streaming prevents any particular executors from being idle for a long + * time. Instead, the measure of "idle-ness" needs to be based on the time taken to process + * each batch. + * + * At a high level, the policy implemented by this class is as follows: + * - Use StreamingListener interface get batch processing times of completed batches + * - Periodically take the average batch completion times and compare with the batch interval + * - If (avg. proc. time / batch interval) >= scaling up ratio, then request more executors. + * The number of executors requested is based on the ratio = (avg. proc. time / batch interval). + * - If (avg. proc. time / batch interval) <= scaling down ratio, then try to kill a executor that + * is not running a receiver. + * + * This features should ideally be used in conjunction with backpressure, as backpressure ensures + * system stability, while executors are being readjusted. + */ +private[streaming] class ExecutorAllocationManager( + client: ExecutorAllocationClient, + receiverTracker: ReceiverTracker, + conf: SparkConf, + batchDurationMs: Long, + clock: Clock) extends StreamingListener with Logging { + + import ExecutorAllocationManager._ + + private val scalingIntervalSecs = conf.getTimeAsSeconds( + SCALING_INTERVAL_KEY, + s"${SCALING_INTERVAL_DEFAULT_SECS}s") + private val scalingUpRatio = conf.getDouble(SCALING_UP_RATIO_KEY, SCALING_UP_RATIO_DEFAULT) + private val scalingDownRatio = conf.getDouble(SCALING_DOWN_RATIO_KEY, SCALING_DOWN_RATIO_DEFAULT) + private val minNumExecutors = conf.getInt( + MIN_EXECUTORS_KEY, + math.max(1, receiverTracker.numReceivers)) + private val maxNumExecutors = conf.getInt(MAX_EXECUTORS_KEY, Integer.MAX_VALUE) + private val timer = new RecurringTimer(clock, scalingIntervalSecs * 1000, + _ => manageAllocation(), "streaming-executor-allocation-manager") + + @volatile private var batchProcTimeSum = 0L + @volatile private var batchProcTimeCount = 0 + + validateSettings() + + def start(): Unit = { + timer.start() + logInfo(s"ExecutorAllocationManager started with " + + s"ratios = [$scalingUpRatio, $scalingDownRatio] and interval = $scalingIntervalSecs sec") + } + + def stop(): Unit = { + timer.stop(interruptTimer = true) + logInfo("ExecutorAllocationManager stopped") + } + + /** + * Manage executor allocation by requesting or killing executors based on the collected + * batch statistics. + */ + private def manageAllocation(): Unit = synchronized { + logInfo(s"Managing executor allocation with ratios = [$scalingUpRatio, $scalingDownRatio]") + if (batchProcTimeCount > 0) { + val averageBatchProcTime = batchProcTimeSum / batchProcTimeCount + val ratio = averageBatchProcTime.toDouble / batchDurationMs + logInfo(s"Average: $averageBatchProcTime, ratio = $ratio" ) + if (ratio >= scalingUpRatio) { + logDebug("Requesting executors") + val numNewExecutors = math.max(math.round(ratio).toInt, 1) + requestExecutors(numNewExecutors) + } else if (ratio <= scalingDownRatio) { + logDebug("Killing executors") + killExecutor() + } + } + batchProcTimeSum = 0 + batchProcTimeCount = 0 + } + + /** Request the specified number of executors over the currently active one */ + private def requestExecutors(numNewExecutors: Int): Unit = { + require(numNewExecutors >= 1) + val allExecIds = client.getExecutorIds() + logDebug(s"Executors (${allExecIds.size}) = ${allExecIds}") + val targetTotalExecutors = + math.max(math.min(maxNumExecutors, allExecIds.size + numNewExecutors), minNumExecutors) + client.requestTotalExecutors(targetTotalExecutors, 0, Map.empty) + logInfo(s"Requested total $targetTotalExecutors executors") + } + + /** Kill an executor that is not running any receiver, if possible */ + private def killExecutor(): Unit = { + val allExecIds = client.getExecutorIds() + logDebug(s"Executors (${allExecIds.size}) = ${allExecIds}") + + if (allExecIds.nonEmpty && allExecIds.size > minNumExecutors) { + val execIdsWithReceivers = receiverTracker.allocatedExecutors.values.flatten.toSeq + logInfo(s"Executors with receivers (${execIdsWithReceivers.size}): ${execIdsWithReceivers}") + + val removableExecIds = allExecIds.diff(execIdsWithReceivers) + logDebug(s"Removable executors (${removableExecIds.size}): ${removableExecIds}") + if (removableExecIds.nonEmpty) { + val execIdToRemove = removableExecIds(Random.nextInt(removableExecIds.size)) + client.killExecutor(execIdToRemove) + logInfo(s"Requested to kill executor $execIdToRemove") + } else { + logInfo(s"No non-receiver executors to kill") + } + } else { + logInfo("No available executor to kill") + } + } + + private def addBatchProcTime(timeMs: Long): Unit = synchronized { + batchProcTimeSum += timeMs + batchProcTimeCount += 1 + logDebug( + s"Added batch processing time $timeMs, sum = $batchProcTimeSum, count = $batchProcTimeCount") + } + + private def validateSettings(): Unit = { + require( + scalingIntervalSecs > 0, + s"Config $SCALING_INTERVAL_KEY must be more than 0") + + require( + scalingUpRatio > 0, + s"Config $SCALING_UP_RATIO_KEY must be more than 0") + + require( + scalingDownRatio > 0, + s"Config $SCALING_DOWN_RATIO_KEY must be more than 0") + + require( + minNumExecutors > 0, + s"Config $MIN_EXECUTORS_KEY must be more than 0") + + require( + maxNumExecutors > 0, + s"$MAX_EXECUTORS_KEY must be more than 0") + + require( + scalingUpRatio > scalingDownRatio, + s"Config $SCALING_UP_RATIO_KEY must be more than config $SCALING_DOWN_RATIO_KEY") + + if (conf.contains(MIN_EXECUTORS_KEY) && conf.contains(MAX_EXECUTORS_KEY)) { + require( + maxNumExecutors >= minNumExecutors, + s"Config $MAX_EXECUTORS_KEY must be more than config $MIN_EXECUTORS_KEY") + } + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + logDebug("onBatchCompleted called: " + batchCompleted) + if (!batchCompleted.batchInfo.outputOperationInfos.values.exists(_.failureReason.nonEmpty)) { + batchCompleted.batchInfo.processingDelay.foreach(addBatchProcTime) + } + } +} + +private[streaming] object ExecutorAllocationManager extends Logging { + val ENABLED_KEY = "spark.streaming.dynamicAllocation.enabled" + + val SCALING_INTERVAL_KEY = "spark.streaming.dynamicAllocation.scalingInterval" + val SCALING_INTERVAL_DEFAULT_SECS = 60 + + val SCALING_UP_RATIO_KEY = "spark.streaming.dynamicAllocation.scalingUpRatio" + val SCALING_UP_RATIO_DEFAULT = 0.9 + + val SCALING_DOWN_RATIO_KEY = "spark.streaming.dynamicAllocation.scalingDownRatio" + val SCALING_DOWN_RATIO_DEFAULT = 0.3 + + val MIN_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.minExecutors" + + val MAX_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.maxExecutors" + + def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { + val numExecutor = conf.getInt("spark.executor.instances", 0) + val streamingDynamicAllocationEnabled = conf.getBoolean(ENABLED_KEY, false) + if (numExecutor != 0 && streamingDynamicAllocationEnabled) { + throw new IllegalArgumentException( + "Dynamic Allocation for streaming cannot be enabled while spark.executor.instances is set.") + } + if (Utils.isDynamicAllocationEnabled(conf) && streamingDynamicAllocationEnabled) { + throw new IllegalArgumentException( + """ + |Dynamic Allocation cannot be enabled for both streaming and core at the same time. + |Please disable core Dynamic Allocation by setting spark.dynamicAllocation.enabled to + |false to use Dynamic Allocation in streaming. + """.stripMargin) + } + val testing = conf.getBoolean("spark.streaming.dynamicAllocation.testing", false) + numExecutor == 0 && streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing) + } + + def createIfEnabled( + client: ExecutorAllocationClient, + receiverTracker: ReceiverTracker, + conf: SparkConf, + batchDurationMs: Long, + clock: Clock): Option[ExecutorAllocationManager] = { + if (isDynamicAllocationEnabled(conf)) { + Some(new ExecutorAllocationManager(client, receiverTracker, conf, batchDurationMs, clock)) + } else None + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index deb15d075975c..4f124a1356b5a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable -import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.{StreamingContext, Time} /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index ab1b3565fcc19..7050d7ef45240 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Try} import org.apache.spark.streaming.Time -import org.apache.spark.util.{Utils, CallSite} +import org.apache.spark.util.{CallSite, Utils} /** * Class representing a Spark computation. It may contain multiple Spark jobs. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 2de035d166e7b..307ff1f7ec235 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -19,10 +19,12 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} -import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Utils, Clock, EventLoop, ManualClock} +import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -217,11 +219,12 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Batches that were unprocessed before failure val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering) - logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + + logInfo("Batches pending processing (" + pendingTimes.length + " batches): " + pendingTimes.mkString(", ")) // Reschedule jobs for these times - val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) - logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + + val timesToReschedule = (pendingTimes ++ downTimes).filter { _ < restartTime } + .distinct.sorted(Time.ordering) + logInfo("Batches to reschedule (" + timesToReschedule.length + " batches): " + timesToReschedule.mkString(", ")) timesToReschedule.foreach { time => // Allocate the related blocks when recovering from failure, because some blocks that were @@ -238,10 +241,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { - // Set the SparkEnv in this thread, so that job generation code can access the environment - // Example: BlockRDDs are created in this thread, and it needs to access BlockManager - // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. - SparkEnv.set(ssc.env) + // Checkpoint all RDDs marked for checkpointing to ensure their lineages are + // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). + ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") Try { jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch graph.generateJobs(time) // generate jobs using allocated block diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 2480b4ec093e2..ac18f73ea86aa 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,16 +17,19 @@ package org.apache.spark.streaming.scheduler +import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ import scala.util.Failure -import org.apache.spark.Logging -import org.apache.spark.rdd.PairRDDFunctions +import org.apache.commons.lang.SerializationUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -49,7 +52,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { ThreadUtils.newDaemonFixedThreadPool(numConcurrentJobs, "streaming-job-executor") private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock - val listenerBus = new StreamingListenerBus() + val listenerBus = new StreamingListenerBus(ssc.sparkContext.listenerBus) // These two are created only when scheduler starts. // eventLoop not being null means the scheduler has been started and not stopped @@ -57,6 +60,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // A tracker to track all the input stream information as well as processed record number var inputInfoTracker: InputInfoTracker = null + private var executorAllocationManager: Option[ExecutorAllocationManager] = None + private var eventLoop: EventLoop[JobSchedulerEvent] = null def start(): Unit = synchronized { @@ -76,11 +81,19 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { rateController <- inputDStream.rateController } ssc.addStreamingListener(rateController) - listenerBus.start(ssc.sparkContext) + listenerBus.start() receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) + executorAllocationManager = ExecutorAllocationManager.createIfEnabled( + ssc.sparkContext, + receiverTracker, + ssc.conf, + ssc.graph.batchDuration.milliseconds, + clock) + executorAllocationManager.foreach(ssc.addStreamingListener) receiverTracker.start() jobGenerator.start() + executorAllocationManager.foreach(_.start()) logInfo("Started JobScheduler") } @@ -88,8 +101,14 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { if (eventLoop == null) return // scheduler has already been stopped logDebug("Stopping JobScheduler") - // First, stop receiving - receiverTracker.stop(processAllReceivedData) + if (receiverTracker != null) { + // First, stop receiving + receiverTracker.stop(processAllReceivedData) + } + + if (executorAllocationManager != null) { + executorAllocationManager.foreach(_.stop()) + } // Second, stop generating jobs. If it has to process all received data, // then this will wait for all the processing through JobScheduler to be over. @@ -198,7 +217,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { import JobScheduler._ def run() { + val oldProps = ssc.sparkContext.getLocalProperties try { + ssc.sparkContext.setLocalProperties( + SerializationUtils.clone(ssc.savedProperties.get()).asInstanceOf[Properties]) val formattedTime = UIUtils.formatBatchTime( job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}" @@ -208,6 +230,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { s"""Streaming job from $batchLinkText""") ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) + // Checkpoint all RDDs marked for checkpointing to ensure their lineages are + // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). + ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") // We need to assign `eventLoop` to a temp variable. Otherwise, because // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then @@ -229,8 +254,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // JobScheduler has been stopped. } } finally { - ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null) - ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null) + ssc.sparkContext.setLocalProperties(oldProps) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index f76300351e3c0..0baedaf275d67 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -18,14 +18,13 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.HashSet -import scala.util.Failure import org.apache.spark.streaming.Time -import org.apache.spark.util.Utils -/** Class representing a set of Jobs - * belong to the same batch. - */ +/** + * Class representing a set of Jobs + * belong to the same batch. + */ private[streaming] case class JobSet( time: Time, @@ -59,17 +58,15 @@ case class JobSet( // Time taken to process all the jobs from the time they were submitted // (i.e. including the time they wait in the streaming scheduler queue) - def totalDelay: Long = { - processingEndTime - time.milliseconds - } + def totalDelay: Long = processingEndTime - time.milliseconds def toBatchInfo: BatchInfo = { BatchInfo( time, streamIdToInputInfo, submissionTime, - if (processingStartTime >= 0) Some(processingStartTime) else None, - if (processingEndTime >= 0) Some(processingEndTime) else None, + if (hasStarted) Some(processingStartTime) else None, + if (hasCompleted) Some(processingEndTime) else None, jobs.map { job => (job.outputOpId, job.toOutputOperationInfo) }.toMap ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index f2711d1355e60..5d9a8ac0d9297 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -22,14 +22,17 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} -import org.apache.spark.{Logging, SparkConf} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -41,7 +44,6 @@ private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: private[streaming] case class BatchCleanupEvent(times: Seq[Time]) extends ReceivedBlockTrackerLogEvent - /** Class representing the blocks of all the streams allocated to a batch */ private[streaming] case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { @@ -82,15 +84,22 @@ private[streaming] class ReceivedBlockTracker( } /** Add received block. This event will get written to the write ahead log (if enabled). */ - def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { try { - writeToLog(BlockAdditionEvent(receivedBlockInfo)) - getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug(s"Stream ${receivedBlockInfo.streamId} received " + - s"block ${receivedBlockInfo.blockStoreResult.blockId}") - true + val writeResult = writeToLog(BlockAdditionEvent(receivedBlockInfo)) + if (writeResult) { + synchronized { + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + } else { + logDebug(s"Failed to acknowledge stream ${receivedBlockInfo.streamId} receiving " + + s"block ${receivedBlockInfo.blockStoreResult.blockId} in the Write Ahead Log.") + } + writeResult } catch { - case e: Exception => + case NonFatal(e) => logError(s"Error adding block $receivedBlockInfo", e) false } @@ -106,10 +115,12 @@ private[streaming] class ReceivedBlockTracker( (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) - writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) - timeToAllocatedBlocks(batchTime) = allocatedBlocks - lastAllocatedBatchTime = batchTime - allocatedBlocks + if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime + } else { + logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery") + } } else { // This situation occurs when: // 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, @@ -118,7 +129,7 @@ private[streaming] class ReceivedBlockTracker( // 2. Slow checkpointing makes recovered batch time older than WAL recovered // lastAllocatedBatchTime. // This situation will only occurs in recovery time. - logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery") } } @@ -156,10 +167,13 @@ private[streaming] class ReceivedBlockTracker( def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized { require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq - logInfo("Deleting batches " + timesToCleanup) - writeToLog(BatchCleanupEvent(timesToCleanup)) - timeToAllocatedBlocks --= timesToCleanup - writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + logInfo(s"Deleting batches: ${timesToCleanup.mkString(" ")}") + if (writeToLog(BatchCleanupEvent(timesToCleanup))) { + timeToAllocatedBlocks --= timesToCleanup + writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + } else { + logWarning("Failed to acknowledge batch clean up in the Write Ahead Log.") + } } /** Stop the block tracker. */ @@ -185,8 +199,8 @@ private[streaming] class ReceivedBlockTracker( logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } - lastAllocatedBatchTime = batchTime timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime } // Cleanup the batch allocations @@ -198,9 +212,9 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") writeAheadLog.readAll().asScala.foreach { byteBuffer => - logTrace("Recovering record " + byteBuffer) + logInfo("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent]( - byteBuffer.array, Thread.currentThread().getContextClassLoader) match { + JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => @@ -213,12 +227,20 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { - logDebug(s"Writing to log $record") - writeAheadLogOption.foreach { logManager => - logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis()) + logTrace(s"Writing record: $record") + try { + writeAheadLogOption.get.write(ByteBuffer.wrap(Utils.serialize(record)), + clock.getTimeMillis()) + true + } catch { + case NonFatal(e) => + logWarning(s"Exception thrown while writing record: $record to the WriteAheadLog.", e) + false } + } else { + true } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index 59df892397fe0..d16e158da35cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -18,7 +18,6 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rpc.RpcEndpointRef /** * :: DeveloperApi :: @@ -30,6 +29,7 @@ case class ReceiverInfo( name: String, active: Boolean, location: String, + executorId: String, lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index 234bc8660da8a..391a461f08125 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -27,28 +27,29 @@ import org.apache.spark.streaming.receiver.Receiver * A class that tries to schedule receivers with evenly distributed. There are two phases for * scheduling receivers. * - * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule - * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. - * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker should - * update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. - * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list - * that contains the scheduled locations. Then when a receiver is starting, it will send a - * register request and `ReceiverTracker.registerReceiver` will be called. In - * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should check - * if the location of this receiver is one of the scheduled locations, if not, the register will - * be rejected. - * - The second phase is local scheduling when a receiver is restarting. There are two cases of - * receiver restarting: - * - If a receiver is restarting because it's rejected due to the real location and the scheduled - * locations mismatching, in other words, it fails to start in one of the locations that - * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that are - * still alive in the list of scheduled locations, then use them to launch the receiver job. - * - If a receiver is restarting without a scheduled locations list, or the executors in the list - * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` should - * not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it should clear - * it. Then when this receiver is registering, we can know this is a local scheduling, and - * `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if the launching - * location is matching. + * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule + * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. + * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker + * should update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list + * that contains the scheduled locations. Then when a receiver is starting, it will send a + * register request and `ReceiverTracker.registerReceiver` will be called. In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should + * check if the location of this receiver is one of the scheduled locations, if not, the register + * will be rejected. + * - The second phase is local scheduling when a receiver is restarting. There are two cases of + * receiver restarting: + * - If a receiver is restarting because it's rejected due to the real location and the scheduled + * locations mismatching, in other words, it fails to start in one of the locations that + * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that + * are still alive in the list of scheduled locations, then use them to launch the receiver + * job. + * - If a receiver is restarting without a scheduled locations list, or the executors in the list + * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` + * should not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it + * should clear it. Then when this receiver is registering, we can know this is a local + * scheduling, and `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if + * the launching location is matching. * * In conclusion, we should make a global schedule, try to achieve that exactly as long as possible, * otherwise do local scheduling. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b183d856f50c3..9aa2f0bbb9952 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,14 +20,15 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap -import scala.concurrent.ExecutionContext +import scala.concurrent.{ExecutionContext, Future} import scala.language.existentials import scala.util.{Failure, Success} import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{TaskLocation, ExecutorCacheTaskLocation} +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util.WriteAheadLogUtils @@ -91,6 +92,8 @@ private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessag private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long) extends ReceiverTrackerLocalMessage +private[streaming] case object GetAllReceiverInfo extends ReceiverTrackerLocalMessage + /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() @@ -131,7 +134,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Track the active receiver job number. When a receiver job exits ultimately, countDown will // be called. - private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.size) + private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.length) /** * Track all receivers' information. The key is the receiver id, the value is the receiver info. @@ -233,6 +236,26 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } + /** + * Get the executors allocated to each receiver. + * @return a map containing receiver ids to optional executor ids. + */ + def allocatedExecutors(): Map[Int, Option[String]] = synchronized { + if (isTrackerStarted) { + endpoint.askWithRetry[Map[Int, ReceiverTrackingInfo]](GetAllReceiverInfo).mapValues { + _.runningExecutor.map { + _.executorId + } + } + } else { + Map.empty + } + } + + def numReceivers(): Int = { + receiverInputStreams.size + } + /** Register a receiver */ private def registerReceiver( streamId: Int, @@ -411,11 +434,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false * worker nodes as a parallel collection, and runs them. */ private def launchReceivers(): Unit = { - val receivers = receiverInputStreams.map(nis => { + val receivers = receiverInputStreams.map { nis => val rcvr = nis.getReceiver() rcvr.setReceiverId(nis.id) rcvr - }) + } runDummySparkJob() @@ -435,9 +458,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** RpcEndpoint to receive messages from the receivers. */ private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { - // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged - private val submitJobThreadPool = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + private val walBatchingThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("wal-batching-thread-pool")) + + @volatile private var active: Boolean = true override def receive: PartialFunction[Any, Unit] = { // Local messages @@ -488,13 +512,28 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false registerReceiver(streamId, typ, host, executorId, receiverEndpoint, context.senderAddress) context.reply(successful) case AddBlock(receivedBlockInfo) => - context.reply(addBlock(receivedBlockInfo)) + if (WriteAheadLogUtils.isBatchingEnabled(ssc.conf, isDriver = true)) { + walBatchingThreadPool.execute(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + if (active) { + context.reply(addBlock(receivedBlockInfo)) + } else { + throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + } + } + }) + } else { + context.reply(addBlock(receivedBlockInfo)) + } case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + // Local messages case AllReceiverIds => context.reply(receiverTrackingInfos.filter(_._2.state != ReceiverState.INACTIVE).keys.toSeq) + case GetAllReceiverInfo => + context.reply(receiverTrackingInfos.toMap) case StopAllReceivers => assert(isTrackerStopping || isTrackerStopped) stopReceivers() @@ -593,12 +632,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logInfo(s"Restarting Receiver $receiverId") self.send(RestartReceiver(receiver)) } - }(submitJobThreadPool) + }(ThreadUtils.sameThread) logInfo(s"Receiver ${receiver.streamId} started") } override def onStop(): Unit = { - submitJobThreadPool.shutdownNow() + active = false + walBatchingThreadPool.shutdown() } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala index ab0a84f05214d..4dc5bb9c3bfbe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -49,6 +49,7 @@ private[streaming] case class ReceiverTrackingInfo( name.getOrElse(""), state == ReceiverState.ACTIVE, location = runningExecutor.map(_.host).getOrElse(""), + executorId = runningExecutor.map(_.executorId).getOrElse(""), lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), lastError = errorInfo.map(_.lastError).getOrElse(""), lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index d19bdbb443c5e..58fc78d552106 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -19,8 +19,8 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.Queue -import org.apache.spark.util.Distribution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Distribution /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index ca111bb636ed5..39f6e711a67ad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -17,19 +17,37 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.atomic.AtomicBoolean +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} +import org.apache.spark.util.ListenerBus -import org.apache.spark.Logging -import org.apache.spark.util.AsynchronousListenerBus +/** + * A Streaming listener bus to forward events to StreamingListeners. This one will wrap received + * Streaming events as WrappedStreamingListenerEvent and send them to Spark listener bus. It also + * registers itself with Spark listener bus, so that it can receive WrappedStreamingListenerEvents, + * unwrap them as StreamingListenerEvent and dispatch them to StreamingListeners. + */ +private[streaming] class StreamingListenerBus(sparkListenerBus: LiveListenerBus) + extends SparkListener with ListenerBus[StreamingListener, StreamingListenerEvent] { -/** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */ -private[spark] class StreamingListenerBus - extends AsynchronousListenerBus[StreamingListener, StreamingListenerEvent]("StreamingListenerBus") - with Logging { + /** + * Post a StreamingListenerEvent to the Spark listener bus asynchronously. This event will be + * dispatched to all StreamingListeners in the thread of the Spark listener bus. + */ + def post(event: StreamingListenerEvent) { + sparkListenerBus.post(new WrappedStreamingListenerEvent(event)) + } - private val logDroppedEvent = new AtomicBoolean(false) + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case WrappedStreamingListenerEvent(e) => + postToAll(e) + case _ => + } + } - override def onPostEvent(listener: StreamingListener, event: StreamingListenerEvent): Unit = { + protected override def doPostEvent( + listener: StreamingListener, + event: StreamingListenerEvent): Unit = { event match { case receiverStarted: StreamingListenerReceiverStarted => listener.onReceiverStarted(receiverStarted) @@ -51,12 +69,31 @@ private[spark] class StreamingListenerBus } } - override def onDropEvent(event: StreamingListenerEvent): Unit = { - if (logDroppedEvent.compareAndSet(false, true)) { - // Only log the following message once to avoid duplicated annoying logs. - logError("Dropping StreamingListenerEvent because no remaining room in event queue. " + - "This likely means one of the StreamingListeners is too slow and cannot keep up with the " + - "rate at which events are being started by the scheduler.") - } + /** + * Register this one with the Spark listener bus so that it can receive Streaming events and + * forward them to StreamingListeners. + */ + def start(): Unit = { + sparkListenerBus.addListener(this) // for getting callbacks on spark events + } + + /** + * Unregister this one with the Spark listener bus and all StreamingListeners won't receive any + * events after that. + */ + def stop(): Unit = { + sparkListenerBus.removeListener(this) + } + + /** + * Wrapper for StreamingListenerEvent as SparkListenerEvent so that it can be posted to Spark + * listener bus. + */ + private case class WrappedStreamingListenerEvent(streamingListenerEvent: StreamingListenerEvent) + extends SparkListenerEvent { + + // Do not log streaming events in event log as history server does not support streaming + // events (SPARK-12140). TODO Once SPARK-12140 is resolved we should set it to true. + protected[spark] override def logEvent: Boolean = false } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index 84a3ca9d74e58..a73e6cc2cd9c1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.scheduler.rate -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** * Implements a proportional-integral-derivative (PID) controller which acts on diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index d7210f64fcc36..7b2ef6881d6f7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -21,18 +21,20 @@ import org.apache.spark.SparkConf import org.apache.spark.streaming.Duration /** - * A component that estimates the rate at wich an InputDStream should ingest - * elements, based on updates at every batch completion. + * A component that estimates the rate at which an `InputDStream` should ingest + * records, based on updates at every batch completion. + * + * @see [[org.apache.spark.streaming.scheduler.RateController]] */ private[streaming] trait RateEstimator extends Serializable { /** - * Computes the number of elements the stream attached to this `RateEstimator` + * Computes the number of records the stream attached to this `RateEstimator` * should ingest per second, given an update on the size and completion * times of the latest batch. * - * @param time The timetamp of the current batch interval that just finished - * @param elements The number of elements that were processed in this batch + * @param time The timestamp of the current batch interval that just finished + * @param elements The number of records that were processed in this batch * @param processingDelay The time in ms that took for the job to complete * @param schedulingDelay The time in ms that the job spent in the scheduling queue */ @@ -46,13 +48,13 @@ private[streaming] trait RateEstimator extends Serializable { object RateEstimator { /** - * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. + * Return a new `RateEstimator` based on the value of + * `spark.streaming.backpressure.rateEstimator`. * - * The only known estimator right now is `pid`. + * The only known and acceptable estimator right now is `pid`. * * @return An instance of RateEstimator - * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any - * known estimators. + * @throws IllegalArgumentException if the configured RateEstimator is not `pid`. */ def create(conf: SparkConf, batchInterval: Duration): RateEstimator = conf.get("spark.streaming.backpressure.rateEstimator", "pid") match { @@ -64,6 +66,6 @@ object RateEstimator { new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate) case estimator => - throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") + throw new IllegalArgumentException(s"Unknown rate estimator: $estimator") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 125cafd41b8af..c024b4ef7e46f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -33,10 +33,26 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) {SparkUIUtils.tooltip("Time taken to process all jobs of a batch", "top")} } + /** + * Return the first failure reason if finding in the batches. + */ + protected def getFirstFailureReason(batches: Seq[BatchUIData]): Option[String] = { + batches.flatMap(_.outputOperations.flatMap(_._2.failureReason)).headOption + } + + protected def getFirstFailureTableCell(batch: BatchUIData): Seq[Node] = { + val firstFailureReason = batch.outputOperations.flatMap(_._2.failureReason).headOption + firstFailureReason.map { failureReason => + val failureReasonForUI = UIUtils.createOutputOperationFailureForUI(failureReason) + UIUtils.failureReasonCell( + failureReasonForUI, rowspan = 1, includeFirstLineInExpandDetails = false) + }.getOrElse(-) + } + protected def baseRow(batch: BatchUIData): Seq[Node] = { val batchTime = batch.batchTime.milliseconds val formattedBatchTime = UIUtils.formatBatchTime(batchTime, batchInterval) - val eventCount = batch.numRecords + val numRecords = batch.numRecords val schedulingDelay = batch.schedulingDelay val formattedSchedulingDelay = schedulingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") val processingTime = batch.processingDelay @@ -49,7 +65,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) {formattedBatchTime} - {eventCount.toString} events + {numRecords.toString} records {formattedSchedulingDelay} @@ -97,9 +113,17 @@ private[ui] class ActiveBatchTable( waitingBatches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("active-batches-table", batchInterval) { + private val firstFailureReason = getFirstFailureReason(runningBatches) + override protected def columns: Seq[Node] = super.columns ++ { Output Ops: Succeeded/Total - Status + Status ++ { + if (firstFailureReason.nonEmpty) { + Error + } else { + Nil + } + } } override protected def renderRows: Seq[Node] = { @@ -110,20 +134,41 @@ private[ui] class ActiveBatchTable( } private def runningBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ processing + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ processing ++ { + if (firstFailureReason.nonEmpty) { + getFirstFailureTableCell(batch) + } else { + Nil + } + } } private def waitingBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ queued + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ queued++ { + if (firstFailureReason.nonEmpty) { + // Waiting batches have not run yet, so must have no failure reasons. + - + } else { + Nil + } + } } } private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("completed-batches-table", batchInterval) { + private val firstFailureReason = getFirstFailureReason(batches) + override protected def columns: Seq[Node] = super.columns ++ { Total Delay {SparkUIUtils.tooltip("Total time taken to handle a batch", "top")} - Output Ops: Succeeded/Total + Output Ops: Succeeded/Total ++ { + if (firstFailureReason.nonEmpty) { + Error + } else { + Nil + } + } } override protected def renderRows: Seq[Node] = { @@ -138,6 +183,12 @@ private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: {formattedTotalDelay} - } ++ createOutputOperationProgressBar(batch) + } ++ createOutputOperationProgressBar(batch)++ { + if (firstFailureReason.nonEmpty) { + getFirstFailureTableCell(batch) + } else { + Nil + } + } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 2ed925572826e..1ef26d2f865da 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -24,9 +24,9 @@ import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.streaming.Time -import org.apache.spark.streaming.ui.StreamingJobProgressListener.{OutputOpId, SparkJobId} -import org.apache.spark.ui.jobs.UIData.JobUIData +import org.apache.spark.streaming.ui.StreamingJobProgressListener.SparkJobId import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} +import org.apache.spark.ui.jobs.UIData.JobUIData private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) @@ -37,10 +37,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private def columns: Seq[Node] = { Output Op Id Description - Duration + Output Op Duration Status Job Id - Duration + Job Duration Stages: Succeeded/Total Tasks (for all stages): Succeeded/Total Error @@ -149,7 +149,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { total = sparkJob.numTasks - sparkJob.numSkippedTasks) } - {failureReasonCell(lastFailureReason, rowspan = 1)} + {UIUtils.failureReasonCell(lastFailureReason)} } @@ -245,48 +245,6 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } - private def failureReasonCell( - failureReason: String, - rowspan: Int, - includeFirstLineInExpandDetails: Boolean = true): Seq[Node] = { - val isMultiline = failureReason.indexOf('\n') >= 0 - // Display the first line by default - val failureReasonSummary = StringEscapeUtils.escapeHtml4( - if (isMultiline) { - failureReason.substring(0, failureReason.indexOf('\n')) - } else { - failureReason - }) - val failureDetails = - if (isMultiline && !includeFirstLineInExpandDetails) { - // Skip the first line - failureReason.substring(failureReason.indexOf('\n') + 1) - } else { - failureReason - } - val details = if (isMultiline) { - // scalastyle:off - - +details - ++ - - // scalastyle:on - } else { - "" - } - - if (rowspan == 1) { - {failureReasonSummary}{details} - } else { - - {failureReasonSummary}{details} - - } - } - private def getJobData(sparkJobId: SparkJobId): Option[JobUIData] = { sparkListener.activeJobs.get(sparkJobId).orElse { sparkListener.completedJobs.find(_.jobId == sparkJobId).orElse { @@ -301,7 +259,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } else { var nextLineIndex = failure.indexOf("\n") if (nextLineIndex < 0) { - nextLineIndex = failure.size + nextLineIndex = failure.length } val firstLine = failure.substring(0, nextLineIndex) s"Failed due to error: $firstLine\n$failure" @@ -315,7 +273,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId - (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).sorted) + (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).toSeq.sorted) } val outputOps: Seq[(OutputOperationUIData, Seq[SparkJobId])] = @@ -434,8 +392,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private def outputOpStatusCell(outputOp: OutputOperationUIData, rowspan: Int): Seq[Node] = { outputOp.failureReason match { case Some(failureReason) => - val failureReasonForUI = generateOutputOperationStatusForUI(failureReason) - failureReasonCell(failureReasonForUI, rowspan, includeFirstLineInExpandDetails = false) + val failureReasonForUI = UIUtils.createOutputOperationFailureForUI(failureReason) + UIUtils.failureReasonCell( + failureReasonForUI, rowspan, includeFirstLineInExpandDetails = false) case None => if (outputOp.endTime.isEmpty) { - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index 3ef3689de1c45..1af60857bc770 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -33,7 +33,7 @@ private[ui] case class BatchUIData( val processingStartTime: Option[Long], val processingEndTime: Option[Long], val outputOperations: mutable.HashMap[OutputOpId, OutputOperationUIData] = mutable.HashMap(), - var outputOpIdSparkJobIdPairs: Seq[OutputOpIdAndSparkJobId] = Seq.empty) { + var outputOpIdSparkJobIdPairs: Iterable[OutputOpIdAndSparkJobId] = Seq.empty) { /** * Time taken for the first job of this batch to start processing from the time this batch diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index f6cc6edf2569a..c086df47d9835 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,22 +17,18 @@ package org.apache.spark.streaming.ui -import java.util.LinkedHashMap -import java.util.{Map => JMap} -import java.util.Properties +import java.util.{LinkedHashMap, Map => JMap, Properties} +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.{ArrayBuffer, Queue, HashMap, SynchronizedBuffer} +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap, Queue} import org.apache.spark.scheduler._ -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted -import org.apache.spark.streaming.scheduler.StreamingListenerBatchSubmitted - private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) - extends StreamingListener with SparkListener { + extends SparkListener with StreamingListener { private val waitingBatchUIData = new HashMap[Time, BatchUIData] private val runningBatchUIData = new HashMap[Time, BatchUIData] @@ -47,9 +43,9 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) // we may not be able to get the corresponding BatchUIData when receiving onJobStart. So here we // cannot use a map of (Time, BatchUIData). private[ui] val batchTimeToOutputOpIdSparkJobIdPair = - new LinkedHashMap[Time, SynchronizedBuffer[OutputOpIdAndSparkJobId]] { + new LinkedHashMap[Time, ConcurrentLinkedQueue[OutputOpIdAndSparkJobId]] { override def removeEldestEntry( - p1: JMap.Entry[Time, SynchronizedBuffer[OutputOpIdAndSparkJobId]]): Boolean = { + p1: JMap.Entry[Time, ConcurrentLinkedQueue[OutputOpIdAndSparkJobId]]): Boolean = { // If a lot of "onBatchCompleted"s happen before "onJobStart" (image if // SparkContext.listenerBus is very slow), "batchTimeToOutputOpIdToSparkJobIds" // may add some information for a removed batch when processing "onJobStart". It will be a @@ -97,7 +93,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized { val batchUIData = BatchUIData(batchStarted.batchInfo) - runningBatchUIData(batchStarted.batchInfo.batchTime) = BatchUIData(batchStarted.batchInfo) + runningBatchUIData(batchStarted.batchInfo.batchTime) = batchUIData waitingBatchUIData.remove(batchStarted.batchInfo.batchTime) totalReceivedRecords += batchUIData.numRecords @@ -137,12 +133,10 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) getBatchTimeAndOutputOpId(jobStart.properties).foreach { case (batchTime, outputOpId) => var outputOpIdToSparkJobIds = batchTimeToOutputOpIdSparkJobIdPair.get(batchTime) if (outputOpIdToSparkJobIds == null) { - outputOpIdToSparkJobIds = - new ArrayBuffer[OutputOpIdAndSparkJobId]() - with SynchronizedBuffer[OutputOpIdAndSparkJobId] + outputOpIdToSparkJobIds = new ConcurrentLinkedQueue[OutputOpIdAndSparkJobId]() batchTimeToOutputOpIdSparkJobIdPair.put(batchTime, outputOpIdToSparkJobIds) } - outputOpIdToSparkJobIds += OutputOpIdAndSparkJobId(outputOpId, jobStart.jobId) + outputOpIdToSparkJobIds.add(OutputOpIdAndSparkJobId(outputOpId, jobStart.jobId)) } } @@ -167,7 +161,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } def numInactiveReceivers: Int = { - ssc.graph.getReceiverInputStreams().size - numActiveReceivers + ssc.graph.getReceiverInputStreams().length - numActiveReceivers } def numTotalCompletedBatches: Long = synchronized { @@ -208,21 +202,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) def streamIds: Seq[Int] = ssc.graph.getInputStreams().map(_.id) /** - * Return all of the event rates for each InputDStream in each batch. The key of the return value - * is the stream id, and the value is a sequence of batch time with its event rate. + * Return all of the record rates for each InputDStream in each batch. The key of the return value + * is the stream id, and the value is a sequence of batch time with its record rate. */ - def receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { + def receivedRecordRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { val _retainedBatches = retainedBatches val latestBatches = _retainedBatches.map { batchUIData => (batchUIData.batchTime.milliseconds, batchUIData.streamIdToInputInfo.mapValues(_.numRecords)) } streamIds.map { streamId => - val eventRates = latestBatches.map { + val recordRates = latestBatches.map { case (batchTime, streamIdToNumRecords) => val numRecords = streamIdToNumRecords.getOrElse(streamId, 0L) (batchTime, numRecords * 1000.0 / batchDuration) } - (streamId, eventRates) + (streamId, recordRates) }.toMap } @@ -262,8 +256,11 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } batchUIData.foreach { _batchUIData => - val outputOpIdToSparkJobIds = - Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime)).getOrElse(Seq.empty) + // We use an Iterable rather than explicitly converting to a seq so that updates + // will propagate + val outputOpIdToSparkJobIds: Iterable[OutputOpIdAndSparkJobId] = + Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime).asScala) + .getOrElse(Seq.empty) _batchUIData.outputOpIdSparkJobIdPairs = outputOpIdToSparkJobIds } batchUIData diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 96d943e75d272..b97e24f28bfc6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -17,15 +17,13 @@ package org.apache.spark.streaming.ui -import java.text.SimpleDateFormat -import java.util.Date import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.ArrayBuffer import scala.xml.{Node, Unparsed} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.ui._ import org.apache.spark.ui.{UIUtils => SparkUIUtils} @@ -127,9 +125,9 @@ private[ui] class MillisecondsStatUIData(data: Seq[(Long, Long)]) { * A helper class for "input rate" to generate data that will be used in the timeline and histogram * graphs. * - * @param data (batchTime, event-rate). + * @param data (batch time, record rate). */ -private[ui] class EventRateUIData(val data: Seq[(Long, Double)]) { +private[ui] class RecordRateUIData(val data: Seq[(Long, Double)]) { val avg: Option[Double] = if (data.isEmpty) None else Some(data.map(_._2).sum / data.size) @@ -217,7 +215,7 @@ private[ui] class StreamingPage(parent: StreamingTab) val minBatchTime = if (batchTimes.isEmpty) startTime else batchTimes.min val maxBatchTime = if (batchTimes.isEmpty) startTime else batchTimes.max - val eventRateForAllStreams = new EventRateUIData(batches.map { batchInfo => + val recordRateForAllStreams = new RecordRateUIData(batches.map { batchInfo => (batchInfo.batchTime.milliseconds, batchInfo.numRecords * 1000.0 / listener.batchDuration) }) @@ -243,24 +241,24 @@ private[ui] class StreamingPage(parent: StreamingTab) // Use the max input rate for all InputDStreams' graphs to make the Y axis ranges same. // If it's not an integral number, just use its ceil integral number. - val maxEventRate = eventRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L) - val minEventRate = 0L + val maxRecordRate = recordRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L) + val minRecordRate = 0L val batchInterval = UIUtils.convertToTimeUnit(listener.batchDuration, normalizedUnit) val jsCollector = new JsCollector - val graphUIDataForEventRateOfAllStreams = + val graphUIDataForRecordRateOfAllStreams = new GraphUIData( - "all-stream-events-timeline", - "all-stream-events-histogram", - eventRateForAllStreams.data, + "all-stream-records-timeline", + "all-stream-records-histogram", + recordRateForAllStreams.data, minBatchTime, maxBatchTime, - minEventRate, - maxEventRate, - "events/sec") - graphUIDataForEventRateOfAllStreams.generateDataJs(jsCollector) + minRecordRate, + maxRecordRate, + "records/sec") + graphUIDataForRecordRateOfAllStreams.generateDataJs(jsCollector) val graphUIDataForSchedulingDelay = new GraphUIData( @@ -336,16 +334,16 @@ private[ui] class StreamingPage(parent: StreamingTab)
      Receivers: {listener.numActiveReceivers} / {numReceivers} active
      } } -
      Avg: {eventRateForAllStreams.formattedAvg} events/sec
      +
      Avg: {recordRateForAllStreams.formattedAvg} records/sec
      - {graphUIDataForEventRateOfAllStreams.generateTimelineHtml(jsCollector)} - {graphUIDataForEventRateOfAllStreams.generateHistogramHtml(jsCollector)} + {graphUIDataForRecordRateOfAllStreams.generateTimelineHtml(jsCollector)} + {graphUIDataForRecordRateOfAllStreams.generateHistogramHtml(jsCollector)} {if (hasStream) { - {generateInputDStreamsTable(jsCollector, minBatchTime, maxBatchTime, minEventRate, maxEventRate)} + {generateInputDStreamsTable(jsCollector, minBatchTime, maxBatchTime, minRecordRate, maxRecordRate)} }} @@ -392,8 +390,16 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { - val content = listener.receivedEventRateWithBatchTime.map { case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + val maxYCalculated = listener.receivedRecordRateWithBatchTime.values + .flatMap { case streamAndRates => streamAndRates.map { case (_, recordRate) => recordRate } } + .reduceOption[Double](math.max) + .map(_.ceil.toLong) + .getOrElse(0L) + + val content = listener.receivedRecordRateWithBatchTime.toList.sortBy(_._1).map { + case (streamId, recordRates) => + generateInputDStreamRow( + jsCollector, streamId, recordRates, minX, maxX, minY, maxYCalculated) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off @@ -402,7 +408,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
      Status
      -
      Location
      +
      Executor ID / Host
      Last Error Time
      Last Error Message @@ -417,7 +423,7 @@ private[ui] class StreamingPage(parent: StreamingTab) private def generateInputDStreamRow( jsCollector: JsCollector, streamId: Int, - eventRates: Seq[(Long, Double)], + recordRates: Seq[(Long, Double)], minX: Long, maxX: Long, minY: Double, @@ -430,33 +436,37 @@ private[ui] class StreamingPage(parent: StreamingTab) val receiverActive = receiverInfo.map { info => if (info.active) "ACTIVE" else "INACTIVE" }.getOrElse(emptyCell) - val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell) + val receiverLocation = receiverInfo.map { info => + val executorId = if (info.executorId.isEmpty) emptyCell else info.executorId + val location = if (info.location.isEmpty) emptyCell else info.location + s"$executorId / $location" + }.getOrElse(emptyCell) val receiverLastError = receiverInfo.map { info => val msg = s"${info.lastErrorMessage} - ${info.lastError}" - if (msg.size > 100) msg.take(97) + "..." else msg + if (msg.length > 100) msg.take(97) + "..." else msg }.getOrElse(emptyCell) val receiverLastErrorTime = receiverInfo.map { r => if (r.lastErrorTime < 0) "-" else SparkUIUtils.formatDate(r.lastErrorTime) }.getOrElse(emptyCell) - val receivedRecords = new EventRateUIData(eventRates) + val receivedRecords = new RecordRateUIData(recordRates) - val graphUIDataForEventRate = + val graphUIDataForRecordRate = new GraphUIData( - s"stream-$streamId-events-timeline", - s"stream-$streamId-events-histogram", + s"stream-$streamId-records-timeline", + s"stream-$streamId-records-histogram", receivedRecords.data, minX, maxX, minY, maxY, - "events/sec") - graphUIDataForEventRate.generateDataJs(jsCollector) + "records/sec") + graphUIDataForRecordRate.generateDataJs(jsCollector)
      -
      {receiverName}
      -
      Avg: {receivedRecords.formattedAvg} events/sec
      +
      {receiverName}
      +
      Avg: {receivedRecords.formattedAvg} records/sec
      {receiverActive} @@ -466,9 +476,9 @@ private[ui] class StreamingPage(parent: StreamingTab) - {graphUIDataForEventRate.generateTimelineHtml(jsCollector)} + {graphUIDataForRecordRate.generateTimelineHtml(jsCollector)} - {graphUIDataForEventRate.generateHistogramHtml(jsCollector)} + {graphUIDataForRecordRate.generateHistogramHtml(jsCollector)} } @@ -549,7 +559,7 @@ private[ui] class JsCollector { def toHtml: Seq[Node] = { val js = s""" - |$$(document).ready(function(){ + |$$(document).ready(function() { | ${preparedStatements.mkString("\n")} | ${statements.mkString("\n")} |});""".stripMargin diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index bc53f2a31f6d1..c5f8aada3fc4a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,18 +17,19 @@ package org.apache.spark.streaming.ui -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.streaming.StreamingContext import org.apache.spark.ui.{SparkUI, SparkUITab} -import StreamingTab._ - /** * Spark Web UI tab that shows statistics of a streaming job. * This assumes the given SparkContext has enabled its SparkUI. */ private[spark] class StreamingTab(val ssc: StreamingContext) - extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { + extends SparkUITab(StreamingTab.getSparkUI(ssc), "streaming") with Logging { + + import StreamingTab._ private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index 86cfb1fa47370..9b1c939e9329f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -21,6 +21,10 @@ import java.text.SimpleDateFormat import java.util.TimeZone import java.util.concurrent.TimeUnit +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils + private[streaming] object UIUtils { /** @@ -124,4 +128,60 @@ private[streaming] object UIUtils { } } } + + def createOutputOperationFailureForUI(failure: String): String = { + if (failure.startsWith("org.apache.spark.Spark")) { + // SparkException or SparkDriverExecutionException + "Failed due to Spark job error\n" + failure + } else { + var nextLineIndex = failure.indexOf("\n") + if (nextLineIndex < 0) { + nextLineIndex = failure.length + } + val firstLine = failure.substring(0, nextLineIndex) + s"Failed due to error: $firstLine\n$failure" + } + } + + def failureReasonCell( + failureReason: String, + rowspan: Int = 1, + includeFirstLineInExpandDetails: Boolean = true): Seq[Node] = { + val isMultiline = failureReason.indexOf('\n') >= 0 + // Display the first line by default + val failureReasonSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + failureReason.substring(0, failureReason.indexOf('\n')) + } else { + failureReason + }) + val failureDetails = + if (isMultiline && !includeFirstLineInExpandDetails) { + // Skip the first line + failureReason.substring(failureReason.indexOf('\n') + 1) + } else { + failureReason + } + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + + if (rowspan == 1) { + {failureReasonSummary}{details} + } else { + + {failureReasonSummary}{details} + + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala new file mode 100644 index 0000000000000..165e81ea41a98 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util.{Iterator => JIterator} +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +/** + * A wrapper for a WriteAheadLog that batches records before writing data. Handles aggregation + * during writes, and de-aggregation in the `readAll` method. The end consumer has to handle + * de-aggregation after the `read` method. In addition, the `WriteAheadLogRecordHandle` returned + * after the write will contain the batch of records rather than individual records. + * + * When writing a batch of records, the `time` passed to the `wrappedLog` will be the timestamp + * of the latest record in the batch. This is very important in achieving correctness. Consider the + * following example: + * We receive records with timestamps 1, 3, 5, 7. We use "log-1" as the filename. Once we receive + * a clean up request for timestamp 3, we would clean up the file "log-1", and lose data regarding + * 5 and 7. + * + * This means the caller can assume the same write semantics as any other WriteAheadLog + * implementation despite the batching in the background - when the write() returns, the data is + * written to the WAL and is durable. To take advantage of the batching, the caller can write from + * multiple threads, each of which will stay blocked until the corresponding data has been written. + * + * All other methods of the WriteAheadLog interface will be passed on to the wrapped WriteAheadLog. + */ +private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: SparkConf) + extends WriteAheadLog with Logging { + + import BatchedWriteAheadLog._ + + private val walWriteQueue = new LinkedBlockingQueue[Record]() + + // Whether the writer thread is active + @volatile private var active: Boolean = true + private val buffer = new ArrayBuffer[Record]() + + private val batchedWriterThread = startBatchedWriterThread() + + /** + * Write a byte buffer to the log file. This method adds the byteBuffer to a queue and blocks + * until the record is properly written by the parent. + */ + override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + val promise = Promise[WriteAheadLogRecordHandle]() + val putSuccessfully = synchronized { + if (active) { + walWriteQueue.offer(Record(byteBuffer, time, promise)) + true + } else { + false + } + } + if (putSuccessfully) { + Await.result(promise.future, WriteAheadLogUtils.getBatchingTimeout(conf).milliseconds) + } else { + throw new IllegalStateException("close() was called on BatchedWriteAheadLog before " + + s"write request with time $time could be fulfilled.") + } + } + + /** + * This method is not supported as the resulting ByteBuffer would actually require de-aggregation. + * This method is primarily used in testing, and to ensure that it is not used in production, + * we throw an UnsupportedOperationException. + */ + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = { + throw new UnsupportedOperationException("read() is not supported for BatchedWriteAheadLog " + + "as the data may require de-aggregation.") + } + + /** + * Read all the existing logs from the log directory. The output of the wrapped WriteAheadLog + * will be de-aggregated. + */ + override def readAll(): JIterator[ByteBuffer] = { + wrappedLog.readAll().asScala.flatMap(deaggregate).asJava + } + + /** + * Delete the log files that are older than the threshold time. + * + * This method is handled by the parent WriteAheadLog. + */ + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wrappedLog.clean(threshTime, waitForCompletion) + } + + + /** + * Stop the batched writer thread, fulfill promises with failures and close the wrapped WAL. + */ + override def close(): Unit = { + logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") + synchronized { + active = false + } + batchedWriterThread.interrupt() + batchedWriterThread.join() + while (!walWriteQueue.isEmpty) { + val Record(_, time, promise) = walWriteQueue.poll() + promise.failure(new IllegalStateException("close() was called on BatchedWriteAheadLog " + + s"before write request with time $time could be fulfilled.")) + } + wrappedLog.close() + } + + /** Start the actual log writer on a separate thread. */ + private def startBatchedWriterThread(): Thread = { + val thread = new Thread(new Runnable { + override def run(): Unit = { + while (active) { + try { + flushRecords() + } catch { + case NonFatal(e) => + logWarning("Encountered exception in Batched Writer Thread.", e) + } + } + logInfo("BatchedWriteAheadLog Writer thread exiting.") + } + }, "BatchedWriteAheadLog Writer") + thread.setDaemon(true) + thread.start() + thread + } + + /** Write all the records in the buffer to the write ahead log. */ + private def flushRecords(): Unit = { + try { + buffer.append(walWriteQueue.take()) + val numBatched = walWriteQueue.drainTo(buffer.asJava) + 1 + logDebug(s"Received $numBatched records from queue") + } catch { + case _: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.") + } + try { + var segment: WriteAheadLogRecordHandle = null + if (buffer.length > 0) { + logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") + // threads may not be able to add items in order by time + val sortedByTime = buffer.sortBy(_.time) + // We take the latest record for the timestamp. Please refer to the class Javadoc for + // detailed explanation + val time = sortedByTime.last.time + segment = wrappedLog.write(aggregate(sortedByTime), time) + } + buffer.foreach(_.promise.success(segment)) + } catch { + case e: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.", e) + buffer.foreach(_.promise.failure(e)) + case NonFatal(e) => + logWarning(s"BatchedWriteAheadLog Writer failed to write $buffer", e) + buffer.foreach(_.promise.failure(e)) + } finally { + buffer.clear() + } + } + + /** Method for querying the queue length. Should only be used in tests. */ + private def getQueueLength(): Int = walWriteQueue.size() +} + +/** Static methods for aggregating and de-aggregating records. */ +private[util] object BatchedWriteAheadLog { + + /** + * Wrapper class for representing the records that we will write to the WriteAheadLog. Coupled + * with the timestamp for the write request of the record, and the promise that will block the + * write request, while a separate thread is actually performing the write. + */ + case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) + + /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ + def aggregate(records: Seq[Record]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( + records.map(record => JavaUtils.bufferToArray(record.data)).toArray)) + } + + /** + * De-aggregate serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. + * A stream may not have used batching initially, but started using it after a restart. This + * method therefore needs to be backwards compatible. + */ + def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + val prevPosition = buffer.position() + try { + Utils.deserialize[Array[Array[Byte]]](JavaUtils.bufferToArray(buffer)).map(ByteBuffer.wrap) + } catch { + case _: ClassCastException => // users may restart a stream with batching enabled + // Restore `position` so that the user can read `buffer` later + buffer.position(prevPosition) + Array(buffer) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index bc3f2486c21fd..9b689f01b8d39 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -18,23 +18,27 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import java.util.concurrent.RejectedExecutionException import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.collection.parallel.ExecutionContextTaskSupport import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.util.{CompletionIterator, ThreadUtils} -import org.apache.spark.{Logging, SparkConf} /** * This class manages write ahead log files. - * - Writes records (bytebuffers) to periodically rotating log files. - * - Recovers the log files and the reads the recovered records upon failures. - * - Cleans up old log files. + * + * - Writes records (bytebuffers) to periodically rotating log files. + * - Recovers the log files and the reads the recovered records upon failures. + * - Cleans up old log files. * * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read. @@ -54,12 +58,18 @@ private[streaming] class FileBasedWriteAheadLog( import FileBasedWriteAheadLog._ private val pastLogs = new ArrayBuffer[LogInfo] - private val callerNameTag = getCallerName.map(c => s" for $c").getOrElse("") + private val callerName = getCallerName + + private val threadpoolName = { + "WriteAheadLogManager" + callerName.map(c => s" for $c").getOrElse("") + } + private val forkJoinPool = ThreadUtils.newForkJoinPool(threadpoolName, 20) + private val executionContext = ExecutionContext.fromExecutorService(forkJoinPool) - private val threadpoolName = s"WriteAheadLogManager $callerNameTag" - implicit private val executionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) - override protected val logName = s"WriteAheadLogManager $callerNameTag" + override protected def logName = { + getClass.getName.stripSuffix("$") + + callerName.map("_" + _).getOrElse("").replaceAll("[ ]", "_") + } private var currentLogPath: Option[String] = None private var currentLogWriter: FileBasedWriteAheadLogWriter = null @@ -124,13 +134,19 @@ private[streaming] class FileBasedWriteAheadLog( */ def readAll(): JIterator[ByteBuffer] = synchronized { val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath - logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) - - logFilesToRead.iterator.map { file => + logInfo("Reading from the logs:\n" + logFilesToRead.mkString("\n")) + def readFile(file: String): Iterator[ByteBuffer] = { logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) - }.flatten.asJava + } + if (!closeFileAfterWrite) { + logFilesToRead.iterator.map(readFile).flatten.asJava + } else { + // For performance gains, it makes sense to parallelize the recovery if + // closeFileAfterWrite = true + seqToParIterator(executionContext, logFilesToRead, readFile).asJava + } } /** @@ -146,30 +162,39 @@ private[streaming] class FileBasedWriteAheadLog( * asynchronously. */ def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { - val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } + val oldLogFiles = synchronized { + val expiredLogs = pastLogs.filter { _.endTime < threshTime } + pastLogs --= expiredLogs + expiredLogs + } logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") - def deleteFiles() { - oldLogFiles.foreach { logInfo => - try { - val path = new Path(logInfo.path) - val fs = HdfsUtils.getFileSystemForPath(path, hadoopConf) - fs.delete(path, true) - synchronized { pastLogs -= logInfo } - logDebug(s"Cleared log file $logInfo") - } catch { - case ex: Exception => - logWarning(s"Error clearing write ahead log file $logInfo", ex) - } + def deleteFile(walInfo: LogInfo): Unit = { + try { + val path = new Path(walInfo.path) + val fs = HdfsUtils.getFileSystemForPath(path, hadoopConf) + fs.delete(path, true) + logDebug(s"Cleared log file $walInfo") + } catch { + case ex: Exception => + logWarning(s"Error clearing write ahead log file $walInfo", ex) } logInfo(s"Cleared log files in $logDirectory older than $threshTime") } - if (!executionContext.isShutdown) { - val f = Future { deleteFiles() } - if (waitForCompletion) { - import scala.concurrent.duration._ - Await.ready(f, 1 second) + oldLogFiles.foreach { logInfo => + if (!executionContext.isShutdown) { + try { + val f = Future { deleteFile(logInfo) }(executionContext) + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } + } catch { + case e: RejectedExecutionException => + logWarning("Execution context shutdown before deleting old WriteAheadLogs. " + + "This would not affect recovery correctness.", e) + } } } } @@ -206,7 +231,8 @@ private[streaming] class FileBasedWriteAheadLog( val logDirectoryPath = new Path(logDirectory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + if (fileSystem.exists(logDirectoryPath) && + fileSystem.getFileStatus(logDirectoryPath).isDirectory) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) pastLogs.clear() pastLogs ++= logFileInfo @@ -234,8 +260,12 @@ private[streaming] object FileBasedWriteAheadLog { } def getCallerName(): Option[String] = { - val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) - stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption) + val blacklist = Seq("WriteAheadLog", "Logging", "java.lang", "scala.") + Thread.currentThread.getStackTrace() + .map(_.getClassName) + .find { c => !blacklist.exists(c.contains) } + .flatMap(_.split("\\.").lastOption) + .flatMap(_.split("\\$\\$").headOption) } /** Convert a sequence of files to a sequence of sorted LogInfo objects */ @@ -251,4 +281,24 @@ private[streaming] object FileBasedWriteAheadLog { } }.sortBy { _.startTime } } + + /** + * This creates an iterator from a parallel collection, by keeping at most `n` objects in memory + * at any given time, where `n` is at most the max of the size of the thread pool or 8. This is + * crucial for use cases where we create `FileBasedWriteAheadLogReader`s during parallel recovery. + * We don't want to open up `k` streams altogether where `k` is the size of the Seq that we want + * to parallelize. + */ + def seqToParIterator[I, O]( + executionContext: ExecutionContext, + source: Seq[I], + handler: I => Iterator[O]): Iterator[O] = { + val taskSupport = new ExecutionContextTaskSupport(executionContext) + val groupSize = taskSupport.parallelismLevel.max(8) + source.grouped(groupSize).flatMap { group => + val parallelCollection = group.par + parallelCollection.tasksupport = taskSupport + parallelCollection.map(handler) + }.flatten + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala index f7168229ec15a..56d4977da0b51 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala @@ -30,7 +30,7 @@ private[streaming] class FileBasedWriteAheadLogRandomReader(path: String, conf: extends Closeable { private val instream = HdfsUtils.getInputStream(path, conf) - private var closed = false + private var closed = (instream == null) // the file may be deleted as we're opening the stream def read(segment: FileBasedWriteAheadLogSegment): ByteBuffer = synchronized { assertOpen() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala index c3bb59f3fef94..14d9bc94a123c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala @@ -16,11 +16,12 @@ */ package org.apache.spark.streaming.util -import java.io.{Closeable, EOFException} +import java.io.{Closeable, EOFException, IOException} import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration -import org.apache.spark.Logging + +import org.apache.spark.internal.Logging /** * A reader for reading write ahead log files written using @@ -32,7 +33,7 @@ private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Config extends Iterator[ByteBuffer] with Closeable with Logging { private val instream = HdfsUtils.getInputStream(path, conf) - private var closed = false + private var closed = (instream == null) // the file may be deleted as we're opening the stream private var nextItem: Option[ByteBuffer] = None override def hasNext: Boolean = synchronized { @@ -55,6 +56,19 @@ private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Config logDebug("Error reading next item, EOF reached", e) close() false + case e: IOException => + logWarning("Error while trying to read data. If the file was deleted, " + + "this should be okay.", e) + close() + if (HdfsUtils.checkFileExists(path, conf)) { + // If file exists, this could be a legitimate error + throw e + } else { + // File was deleted. This can occur when the daemon cleanup thread takes time to + // delete the file during recovery. + false + } + case e: Exception => logWarning("Error while trying to read data from HDFS.", e) close() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index e146bec32a456..1f5c1d4369b53 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -19,10 +19,9 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import scala.util.Try - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FSDataOutputStream + +import org.apache.spark.util.Utils /** * A writer for writing byte-buffers to a write ahead log file. @@ -32,11 +31,6 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf) - private lazy val hadoopFlushMethod = { - // Use reflection to get the right flush operation - val cls = classOf[FSDataOutputStream] - Try(cls.getMethod("hflush")).orElse(Try(cls.getMethod("sync"))).toOption - } private var nextOffset = stream.getPos() private var closed = false @@ -48,17 +42,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: val lengthToWrite = data.remaining() val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) stream.writeInt(lengthToWrite) - if (data.hasArray) { - stream.write(data.array()) - } else { - // If the buffer is not backed by an array, we transfer using temp array - // Note that despite the extra array copy, this should be faster than byte-by-byte copy - while (data.hasRemaining) { - val array = new Array[Byte](data.remaining) - data.get(array) - stream.write(array) - } - } + Utils.writeByteBuffer(data, stream: OutputStream) flush() nextOffset = stream.getPos() segment @@ -70,7 +54,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: } private def flush() { - hadoopFlushMethod.foreach { _.invoke(stream) } + stream.hflush() // Useful for local file system where hflush/sync does not work (HADOOP-7844) stream.getWrappedStream.flush() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index f60688f173c44..13a765d035ee8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.streaming.util +import java.io.IOException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ @@ -42,8 +44,19 @@ private[streaming] object HdfsUtils { def getInputStream(path: String, conf: Configuration): FSDataInputStream = { val dfsPath = new Path(path) val dfs = getFileSystemForPath(dfsPath, conf) - val instream = dfs.open(dfsPath) - instream + if (dfs.isFile(dfsPath)) { + try { + dfs.open(dfsPath) + } catch { + case e: IOException => + // If we are really unlucky, the file may be deleted as we're opening the stream. + // This can happen as clean up is performed by daemon threads that may be left over from + // previous runs. + if (!dfs.isFile(dfsPath)) null else throw e + } + } else { + null + } } def checkState(state: Boolean, errorMsg: => String) { @@ -71,4 +84,11 @@ private[streaming] object HdfsUtils { case _ => fs } } + + /** Check if the file exists at the given path. */ + def checkFileExists(path: String, conf: Configuration): Boolean = { + val hdpPath = new Path(path) + val fs = getFileSystemForPath(hdpPath, conf) + fs.isFile(hdpPath) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala index a96e2924a0b44..29cc1fa00ac0f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala @@ -17,13 +17,12 @@ package org.apache.spark.streaming.util -import scala.annotation.tailrec - import java.io.OutputStream import java.util.concurrent.TimeUnit._ -import org.apache.spark.Logging +import scala.annotation.tailrec +import org.apache.spark.internal.Logging private[streaming] class RateLimitedOutputStream(out: OutputStream, desiredBytesPerSec: Int) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index 6addb96752038..9667af97f03bc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -23,7 +23,8 @@ import java.nio.ByteBuffer import scala.io.Source -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.IntParam diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 0148cb51c6f09..62e681e3e9646 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.util -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.util.{Clock, SystemClock} private[streaming] @@ -72,10 +72,10 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: /** * Stop the timer, and return the last time the callback was made. - * - interruptTimer = true will interrupt the callback - * if it is in progress (not guaranteed to give correct time in this case). - * - interruptTimer = false guarantees that there will be at least one callback after `stop` has - * been called. + * + * @param interruptTimer True will interrupt the callback if it is in progress (not guaranteed to + * give correct time in this case). False guarantees that there will be at + * least one callback after `stop` has been called. */ def stop(interruptTimer: Boolean): Long = synchronized { if (!stopped) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala new file mode 100644 index 0000000000000..3a21cfae5ac2f --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.io._ + +import scala.reflect.ClassTag + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.{KryoInputObjectInputBridge, KryoOutputObjectOutputBridge} +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ +import org.apache.spark.util.collection.OpenHashMap + +/** Internal interface for defining the map that keeps track of sessions. */ +private[streaming] abstract class StateMap[K, S] extends Serializable { + + /** Get the state for a key if it exists */ + def get(key: K): Option[S] + + /** Get all the keys and states whose updated time is older than the given threshold time */ + def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] + + /** Get all the keys and states in this map. */ + def getAll(): Iterator[(K, S, Long)] + + /** Add or update state */ + def put(key: K, state: S, updatedTime: Long): Unit + + /** Remove a key */ + def remove(key: K): Unit + + /** + * Shallow copy `this` map to create a new state map. + * Updates to the new map should not mutate `this` map. + */ + def copy(): StateMap[K, S] + + def toDebugString(): String = toString() +} + +/** Companion object for [[StateMap]], with utility methods */ +private[streaming] object StateMap { + def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S] + + def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { + val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", + DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedStateMap[K, S](deltaChainThreshold) + } +} + +/** Implementation of StateMap interface representing an empty map */ +private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { + override def put(key: K, session: S, updateTime: Long): Unit = { + throw new NotImplementedError("put() should not be called on an EmptyStateMap") + } + override def get(key: K): Option[S] = None + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty + override def getAll(): Iterator[(K, S, Long)] = Iterator.empty + override def copy(): StateMap[K, S] = this + override def remove(key: K): Unit = { } + override def toDebugString(): String = "" +} + +/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ +private[streaming] class OpenHashMapBasedStateMap[K, S]( + @transient @volatile var parentStateMap: StateMap[K, S], + private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, + private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S]) + extends StateMap[K, S] with KryoSerializable { self => + + def this(initialCapacity: Int, deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( + new EmptyStateMap[K, S], + initialCapacity = initialCapacity, + deltaChainThreshold = deltaChainThreshold) + + def this(deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( + initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) + + def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = { + this(DELTA_CHAIN_LENGTH_THRESHOLD) + } + + require(initialCapacity >= 1, "Invalid initial capacity") + require(deltaChainThreshold >= 1, "Invalid delta chain threshold") + + @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity) + + /** Get the session data if it exists */ + override def get(key: K): Option[S] = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + if (!stateInfo.deleted) { + Some(stateInfo.data) + } else { + None + } + } else { + parentStateMap.get(key) + } + } + + /** Get all the keys and states whose updated time is older than the give threshold time */ + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = { + val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { case (_, stateInfo) => + !stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime + }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Get all the keys and states in this map. */ + override def getAll(): Iterator[(K, S, Long)] = { + + val oldStates = parentStateMap.getAll().filter { case (key, _, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Add or update state */ + override def put(key: K, state: S, updateTime: Long): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.update(state, updateTime) + } else { + deltaMap.update(key, new StateInfo(state, updateTime)) + } + } + + /** Remove a state */ + override def remove(key: K): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.markDeleted() + } else { + val newInfo = new StateInfo[S](deleted = true) + deltaMap.update(key, newInfo) + } + } + + /** + * Shallow copy the map to create a new session store. Updates to the new map + * should not mutate `this` map. + */ + override def copy(): StateMap[K, S] = { + new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) + } + + /** Whether the delta chain length is long enough that it should be compacted */ + def shouldCompact: Boolean = { + deltaChainLength >= deltaChainThreshold + } + + /** Length of the delta chains of this map */ + def deltaChainLength: Int = parentStateMap match { + case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 + case _ => 0 + } + + /** + * Approximate number of keys in the map. This is an overestimation that is mainly used to + * reserve capacity in a new map at delta compaction time. + */ + def approxSize: Int = deltaMap.size + { + parentStateMap match { + case s: OpenHashMapBasedStateMap[_, _] => s.approxSize + case _ => 0 + } + } + + /** Get all the data of this map as string formatted as a tree based on the delta depth */ + override def toDebugString(): String = { + val tabs = if (deltaChainLength > 0) { + (" " * (deltaChainLength - 1)) + "+--- " + } else "" + parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") + } + + override def toString(): String = { + s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]" + } + + /** + * Serialize the map data. Besides serialization, this method actually compact the deltas + * (if needed) in a single pass over all the data in the map. + */ + private def writeObjectInternal(outputStream: ObjectOutput): Unit = { + // Write the data in the delta of this state map + outputStream.writeInt(deltaMap.size) + val deltaMapIterator = deltaMap.iterator + var deltaMapCount = 0 + while (deltaMapIterator.hasNext) { + deltaMapCount += 1 + val (key, stateInfo) = deltaMapIterator.next() + outputStream.writeObject(key) + outputStream.writeObject(stateInfo) + } + assert(deltaMapCount == deltaMap.size) + + // Write the data in the parent state map while copying the data into a new parent map for + // compaction (if needed) + val doCompaction = shouldCompact + val newParentSessionStore = if (doCompaction) { + val initCapacity = if (approxSize > 0) approxSize else 64 + new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold) + } else { null } + + val iterOfActiveSessions = parentStateMap.getAll() + + var parentSessionCount = 0 + + // First write the approximate size of the data to be written, so that readObject can + // allocate appropriately sized OpenHashMap. + outputStream.writeInt(approxSize) + + while(iterOfActiveSessions.hasNext) { + parentSessionCount += 1 + + val (key, state, updateTime) = iterOfActiveSessions.next() + outputStream.writeObject(key) + outputStream.writeObject(state) + outputStream.writeLong(updateTime) + + if (doCompaction) { + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + + // Write the final limit marking object with the correct count of records written. + val limiterObj = new LimitMarker(parentSessionCount) + outputStream.writeObject(limiterObj) + if (doCompaction) { + parentStateMap = newParentSessionStore + } + } + + /** Deserialize the map data. */ + private def readObjectInternal(inputStream: ObjectInput): Unit = { + // Read the data of the delta + val deltaMapSize = inputStream.readInt() + deltaMap = if (deltaMapSize != 0) { + new OpenHashMap[K, StateInfo[S]](deltaMapSize) + } else { + new OpenHashMap[K, StateInfo[S]](initialCapacity) + } + var deltaMapCount = 0 + while (deltaMapCount < deltaMapSize) { + val key = inputStream.readObject().asInstanceOf[K] + val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]] + deltaMap.update(key, sessionInfo) + deltaMapCount += 1 + } + + + // Read the data of the parent map. Keep reading records, until the limiter is reached + // First read the approximate number of records to expect and allocate properly size + // OpenHashMap + val parentStateMapSizeHint = inputStream.readInt() + val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY) + val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( + initialCapacity = newStateMapInitialCapacity, deltaChainThreshold) + + // Read the records until the limit marking object has been reached + var parentSessionLoopDone = false + while(!parentSessionLoopDone) { + val obj = inputStream.readObject() + if (obj.isInstanceOf[LimitMarker]) { + parentSessionLoopDone = true + val expectedCount = obj.asInstanceOf[LimitMarker].num + assert(expectedCount == newParentSessionStore.deltaMap.size) + } else { + val key = obj.asInstanceOf[K] + val state = inputStream.readObject().asInstanceOf[S] + val updateTime = inputStream.readLong() + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + parentStateMap = newParentSessionStore + } + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. + outputStream.defaultWriteObject() + writeObjectInternal(outputStream) + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + // Read the non-transient fields, especially class tags, etc. + inputStream.defaultReadObject() + readObjectInternal(inputStream) + } + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(initialCapacity) + output.writeInt(deltaChainThreshold) + kryo.writeClassAndObject(output, keyClassTag) + kryo.writeClassAndObject(output, stateClassTag) + writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output)) + } + + override def read(kryo: Kryo, input: Input): Unit = { + initialCapacity = input.readInt() + deltaChainThreshold = input.readInt() + keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]] + stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]] + readObjectInternal(new KryoInputObjectInputBridge(kryo, input)) + } +} + +/** + * Companion object of [[OpenHashMapBasedStateMap]] having associated helper + * classes and methods + */ +private[streaming] object OpenHashMapBasedStateMap { + + /** Internal class to represent the state information */ + case class StateInfo[S]( + var data: S = null.asInstanceOf[S], + var updateTime: Long = -1, + var deleted: Boolean = false) { + + def markDeleted(): Unit = { + deleted = true + } + + def update(newData: S, newUpdateTime: Long): Unit = { + data = newData + updateTime = newUpdateTime + deleted = false + } + } + + /** + * Internal class to represent a marker the demarkate the end of all state data in the + * serialized bytes. + */ + class LimitMarker(val num: Int) extends Serializable + + val DELTA_CHAIN_LENGTH_THRESHOLD = 20 + + val DEFAULT_INITIAL_CAPACITY = 64 +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 0ea970e61b694..7542e2f5ecf24 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -21,8 +21,9 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf, SparkException} /** A helper class with utility functions related to the WriteAheadLog interface */ private[streaming] object WriteAheadLogUtils extends Logging { @@ -38,6 +39,8 @@ private[streaming] object WriteAheadLogUtils extends Logging { val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + val DRIVER_WAL_BATCHING_CONF_KEY = "spark.streaming.driver.writeAheadLog.allowBatching" + val DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY = "spark.streaming.driver.writeAheadLog.batchingTimeout" val DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY = "spark.streaming.driver.writeAheadLog.closeFileAfterWrite" @@ -64,6 +67,18 @@ private[streaming] object WriteAheadLogUtils extends Logging { } } + def isBatchingEnabled(conf: SparkConf, isDriver: Boolean): Boolean = { + isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = true) + } + + /** + * How long we will wait for the wrappedLog in the BatchedWriteAheadLog to write the records + * before we fail the write attempt to unblock receivers. + */ + def getBatchingTimeout(conf: SparkConf): Long = { + conf.getLong(DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY, defaultValue = 5000) + } + def shouldCloseFileAfterWrite(conf: SparkConf, isDriver: Boolean): Boolean = { if (isDriver) { conf.getBoolean(DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false) @@ -115,7 +130,7 @@ private[streaming] object WriteAheadLogUtils extends Logging { } else { sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY) } - classNameOption.map { className => + val wal = classNameOption.map { className => try { instantiateClass( Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf) @@ -128,6 +143,11 @@ private[streaming] object WriteAheadLogUtils extends Logging { getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver), shouldCloseFileAfterWrite(sparkConf, isDriver)) } + if (isBatchingEnabled(sparkConf, isDriver)) { + new BatchedWriteAheadLog(wal, sparkConf) + } else { + wal + } } /** Instantiate the class, either using single arg constructor or zero arg constructor */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index c5217149224e4..01f0c4de9e3c9 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,7 +18,7 @@ package org.apache.spark.streaming; import java.io.*; -import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; @@ -33,19 +33,20 @@ import org.junit.Assert; import org.junit.Test; -import com.google.common.base.Optional; import com.google.common.io.Files; import com.google.common.collect.Sets; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.*; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -270,12 +271,12 @@ public void testMapPartitions() { JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override - public Iterable call(Iterator in) { + public Iterator call(Iterator in) { StringBuilder out = new StringBuilder(); while (in.hasNext()) { out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Arrays.asList(out.toString()); + return Arrays.asList(out.toString()).iterator(); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -348,7 +349,9 @@ private void testReduceByWindow(boolean withInverse) { JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), + new Duration(2000), + new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); @@ -496,7 +499,8 @@ public JavaRDD call(JavaRDD in) { pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) { + @Override public JavaPairRDD call(JavaPairRDD in, + Time time) { return null; } } @@ -605,7 +609,8 @@ public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time ti pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, + Time time) { return null; } } @@ -615,7 +620,8 @@ public JavaRDD call(JavaRDD rdd1, JavaPairRDD stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, + Time time) { return null; } } @@ -623,9 +629,12 @@ public JavaPairRDD call(JavaRDD rdd1, JavaRDD r stream1.transformWithToPair( pairStream1, - new Function3, JavaPairRDD, Time, JavaPairRDD>() { + new Function3, JavaPairRDD, Time, + JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { + public JavaPairRDD call(JavaRDD rdd1, + JavaPairRDD rdd2, + Time time) { return null; } } @@ -635,7 +644,8 @@ public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, + Time time) { return null; } } @@ -643,9 +653,12 @@ public JavaRDD call(JavaPairRDD rdd1, JavaRDD r pairStream1.transformWith( pairStream1, - new Function3, JavaPairRDD, Time, JavaRDD>() { + new Function3, JavaPairRDD, Time, + JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { + public JavaRDD call(JavaPairRDD rdd1, + JavaPairRDD rdd2, + Time time) { return null; } } @@ -653,9 +666,12 @@ public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD, JavaRDD, Time, JavaPairRDD>() { + new Function3, JavaRDD, Time, + JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { + public JavaPairRDD call(JavaPairRDD rdd1, + JavaRDD rdd2, + Time time) { return null; } } @@ -663,9 +679,12 @@ public JavaPairRDD call(JavaPairRDD rdd1, JavaR pairStream1.transformWithToPair( pairStream2, - new Function3, JavaPairRDD, Time, JavaPairRDD>() { + new Function3, JavaPairRDD, Time, + JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { + public JavaPairRDD call(JavaPairRDD rdd1, + JavaPairRDD rdd2, + Time time) { return null; } } @@ -721,13 +740,16 @@ public JavaRDD call(List> listOfRDDs, Time time) { listOfDStreams2, new Function2>, Time, JavaPairRDD>>() { @Override - public JavaPairRDD> call(List> listOfRDDs, Time time) { + public JavaPairRDD> call(List> listOfRDDs, + Time time) { Assert.assertEquals(3, listOfRDDs.size()); JavaRDD rdd1 = (JavaRDD)listOfRDDs.get(0); JavaRDD rdd2 = (JavaRDD)listOfRDDs.get(1); - JavaRDD> rdd3 = (JavaRDD>)listOfRDDs.get(2); + JavaRDD> rdd3 = + (JavaRDD>)listOfRDDs.get(2); JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); - PairFunction mapToTuple = new PairFunction() { + PairFunction mapToTuple = + new PairFunction() { @Override public Tuple2 call(Integer i) { return new Tuple2<>(i, i); @@ -738,7 +760,8 @@ public Tuple2 call(Integer i) { } ); JavaTestUtils.attachTestOutputStream(transformed2); - List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List>>> result = + JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -758,8 +781,8 @@ public void testFlatMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override - public Iterable call(String x) { - return Arrays.asList(x.split("(?!^)")); + public Iterator call(String x) { + return Arrays.asList(x.split("(?!^)")).iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -768,6 +791,44 @@ public Iterable call(String x) { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sparkContext().accumulator(0); + final Accumulator accumEle = ssc.sparkContext().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + accumRdd.add(1); + rdd.foreach(new VoidFunction() { + @Override + public void call(Integer i) { + accumEle.add(1); + } + }); + } + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD(new VoidFunction2, Time>() { + @Override + public void call(JavaRDD rdd, Time time) { + } + }); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @SuppressWarnings("unchecked") @Test public void testPairFlatMap() { @@ -807,12 +868,12 @@ public void testPairFlatMap() { JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) { + public Iterator> call(String in) { List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { out.add(new Tuple2<>(in.length(), letter)); } - return out; + return out.iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -942,7 +1003,8 @@ public void testPairMap() { // Maps pair -> pair of different type new Tuple2<>(3, "new york"), new Tuple2<>(1, "new york"))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @@ -975,18 +1037,19 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type new Tuple2<>(3, "new york"), new Tuple2<>(1, "new york"))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) { + public Iterator> call(Iterator> in) { List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); } - return out; + return out.iterator(); } }); @@ -1005,7 +1068,8 @@ public void testPairMap2() { // Maps pair -> single Arrays.asList(1, 3, 4, 1), Arrays.asList(5, 5, 3, 1)); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( new Function, Integer>() { @@ -1050,12 +1114,12 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) { + public Iterator> call(Tuple2 in) { List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { out.add(new Tuple2<>(in._2(), s.toString())); } - return out; + return out.iterator(); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -1077,7 +1141,8 @@ public void testPairGroupByKey() { new Tuple2<>("california", Arrays.asList("sharks", "ducks")), new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream> grouped = pairStream.groupByKey(); @@ -1202,7 +1267,8 @@ public void testGroupByKeyAndWindow() { ) ); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream> groupWindowed = @@ -1216,7 +1282,8 @@ public void testGroupByKeyAndWindow() { } } - private static Set>> convert(List>> listOfTuples) { + private static Set>> + convert(List>> listOfTuples) { List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); @@ -1241,7 +1308,8 @@ public void testReduceByKeyAndWindow() { Arrays.asList(new Tuple2<>("california", 10), new Tuple2<>("new york", 4))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = @@ -1265,7 +1333,8 @@ public void testUpdateStateByKey() { Arrays.asList(new Tuple2<>("california", 14), new Tuple2<>("new york", 9))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream updated = pairStream.updateStateByKey( @@ -1293,12 +1362,12 @@ public Optional call(List values, Optional state) { public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; - List> initial = Arrays.asList ( + List> initial = Arrays.asList( new Tuple2<>("california", 1), new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); - JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); + JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD); List>> expected = Arrays.asList( Arrays.asList(new Tuple2<>("california", 5), @@ -1308,7 +1377,8 @@ public void testUpdateStateByKeyWithInitial() { Arrays.asList(new Tuple2<>("california", 15), new Tuple2<>("new york", 11))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream updated = pairStream.updateStateByKey( @@ -1344,7 +1414,8 @@ public void testReduceByKeyAndWindowWithInverse() { Arrays.asList(new Tuple2<>("california", 10), new Tuple2<>("new york", 4))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = @@ -1591,19 +1662,27 @@ public void testCoGroup() { ssc, stringStringKVStream2, 1); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - JavaPairDStream, Iterable>> grouped = pairStream1.cogroup(pairStream2); + JavaPairDStream, Iterable>> grouped = + pairStream1.cogroup(pairStream2); JavaTestUtils.attachTestOutputStream(grouped); - List, Iterable>>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List, Iterable>>>> result = + JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected.size(), result.size()); - Iterator, Iterable>>>> resultItr = result.iterator(); - Iterator, List>>>> expectedItr = expected.iterator(); + Iterator, Iterable>>>> resultItr = + result.iterator(); + Iterator, List>>>> expectedItr = + expected.iterator(); while (resultItr.hasNext() && expectedItr.hasNext()) { - Iterator, Iterable>>> resultElements = resultItr.next().iterator(); - Iterator, List>>> expectedElements = expectedItr.next().iterator(); + Iterator, Iterable>>> resultElements = + resultItr.next().iterator(); + Iterator, List>>> expectedElements = + expectedItr.next().iterator(); while (resultElements.hasNext() && expectedElements.hasNext()) { - Tuple2, Iterable>> resultElement = resultElements.next(); - Tuple2, List>> expectedElement = expectedElements.next(); + Tuple2, Iterable>> resultElement = + resultElements.next(); + Tuple2, List>> expectedElement = + expectedElements.next(); Assert.assertEquals(expectedElement._1(), resultElement._1()); equalIterable(expectedElement._2()._1(), resultElement._2()._1()); equalIterable(expectedElement._2()._2(), resultElement._2()._2()); @@ -1680,7 +1759,8 @@ public void testLeftOuterJoin() { ssc, stringStringKVStream2, 1); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - JavaPairDStream>> joined = pairStream1.leftOuterJoin(pairStream2); + JavaPairDStream>> joined = + pairStream1.leftOuterJoin(pairStream2); JavaDStream counted = joined.count(); JavaTestUtils.attachTestOutputStream(counted); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1827,7 +1907,8 @@ public void testSocketString() { @Override public Iterable call(InputStream in) throws IOException { List out = new ArrayList<>(); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(in, StandardCharsets.UTF_8))) { for (String line; (line = reader.readLine()) != null;) { out.add(line); } @@ -1891,7 +1972,7 @@ public void testRawSocketStream() { private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); - Files.write("0\n", existingFile, Charset.forName("UTF-8")); + Files.write("0\n", existingFile, StandardCharsets.UTF_8); Assert.assertTrue(existingFile.setLastModified(1000)); Assert.assertEquals(1000, existingFile.lastModified()); return Arrays.asList(Arrays.asList("0")); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java new file mode 100644 index 0000000000000..9b7701003d8d0 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import scala.Tuple2; + +import com.google.common.collect.Sets; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.util.ManualClock; +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.java.function.Function3; +import org.apache.spark.api.java.function.Function4; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaMapWithStateDStream; + +public class JavaMapWithStateSuite extends LocalJavaStreamingContext implements Serializable { + + /** + * This test is only for testing the APIs. It's not necessary to run it. + */ + public void testAPI() { + JavaPairRDD initialRDD = null; + JavaPairDStream wordsDstream = null; + + Function4, State, Optional> mappingFunc = + new Function4, State, Optional>() { + @Override + public Optional call( + Time time, String word, Optional one, State state) { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + } + }; + + JavaMapWithStateDStream stateDstream = + wordsDstream.mapWithState( + StateSpec.function(mappingFunc) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + stateDstream.stateSnapshots(); + + Function3, State, Double> mappingFunc2 = + new Function3, State, Double>() { + @Override + public Double call(String key, Optional one, State state) { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + } + }; + + JavaMapWithStateDStream stateDstream2 = + wordsDstream.mapWithState( + StateSpec.function(mappingFunc2) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + stateDstream2.stateSnapshots(); + } + + @Test + public void testBasicFunction() { + List> inputData = Arrays.asList( + Collections.emptyList(), + Arrays.asList("a"), + Arrays.asList("a", "b"), + Arrays.asList("a", "b", "c"), + Arrays.asList("a", "b"), + Arrays.asList("a"), + Collections.emptyList() + ); + + List> outputData = Arrays.asList( + Collections.emptySet(), + Sets.newHashSet(1), + Sets.newHashSet(2, 1), + Sets.newHashSet(3, 2, 1), + Sets.newHashSet(4, 3), + Sets.newHashSet(5), + Collections.emptySet() + ); + + @SuppressWarnings("unchecked") + List>> stateData = Arrays.asList( + Collections.>emptySet(), + Sets.newHashSet(new Tuple2<>("a", 1)), + Sets.newHashSet(new Tuple2<>("a", 2), new Tuple2<>("b", 1)), + Sets.newHashSet(new Tuple2<>("a", 3), new Tuple2<>("b", 2), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 4), new Tuple2<>("b", 3), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1)) + ); + + Function3, State, Integer> mappingFunc = + new Function3, State, Integer>() { + @Override + public Integer call(String key, Optional value, State state) { + int sum = value.orElse(0) + (state.exists() ? state.get() : 0); + state.update(sum); + return sum; + } + }; + testOperation( + inputData, + StateSpec.function(mappingFunc), + outputData, + stateData); + } + + private void testOperation( + List> input, + StateSpec mapWithStateSpec, + List> expectedOutputs, + List>> expectedStateSnapshots) { + int numBatches = expectedOutputs.size(); + JavaDStream inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); + JavaMapWithStateDStream mapWithStateDStream = + JavaPairDStream.fromJavaDStream(inputStream.map(new Function>() { + @Override + public Tuple2 call(K x) { + return new Tuple2<>(x, 1); + } + })).mapWithState(mapWithStateSpec); + + final List> collectedOutputs = + Collections.synchronizedList(new ArrayList>()); + mapWithStateDStream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + collectedOutputs.add(Sets.newHashSet(rdd.collect())); + } + }); + final List>> collectedStateSnapshots = + Collections.synchronizedList(new ArrayList>>()); + mapWithStateDStream.stateSnapshots().foreachRDD(new VoidFunction>() { + @Override + public void call(JavaPairRDD rdd) { + collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); + } + }); + BatchCounter batchCounter = new BatchCounter(ssc.ssc()); + ssc.start(); + ((ManualClock) ssc.ssc().scheduler().clock()) + .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1); + batchCounter.waitUntilBatchesCompleted(numBatches, 10000); + + Assert.assertEquals(expectedOutputs, collectedOutputs); + Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index ec2bffd6a5b97..091ccbfd85cad 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -18,12 +18,14 @@ package org.apache.spark.streaming; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import static org.junit.Assert.*; +import com.google.common.io.Closeables; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -36,6 +38,7 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; +import java.nio.charset.StandardCharsets; import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -67,12 +70,11 @@ public String call(String v1) { return v1 + "."; } }); - mapped.foreachRDD(new Function, Void>() { + mapped.foreachRDD(new VoidFunction>() { @Override - public Void call(JavaRDD rdd) { + public void call(JavaRDD rdd) { long count = rdd.count(); dataCounter.addAndGet(count); - return null; } }); @@ -89,7 +91,7 @@ public Void call(JavaRDD rdd) { Thread.sleep(100); } ssc.stop(); - assertTrue(dataCounter.get() > 0); + Assert.assertTrue(dataCounter.get() > 0); } finally { server.stop(); } @@ -97,8 +99,8 @@ public Void call(JavaRDD rdd) { private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + private String host = null; + private int port = -1; JavaSocketReceiver(String host_ , int port_) { super(StorageLevel.MEMORY_AND_DISK()); @@ -121,14 +123,20 @@ public void onStop() { private void receive() { try { - Socket socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + Socket socket = null; + BufferedReader in = null; + try { + socket = new Socket(host, port); + in = new BufferedReader( + new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + } finally { + Closeables.close(in, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - in.close(); - socket.close(); } catch(ConnectException ce) { ce.printStackTrace(); restart("Could not connect", ce); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java new file mode 100644 index 0000000000000..ff0be820e0a9a --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.*; + +public class JavaStreamingListenerAPISuite extends JavaStreamingListener { + + @Override + public void onReceiverStarted(JavaStreamingListenerReceiverStarted receiverStarted) { + JavaReceiverInfo receiverInfo = receiverStarted.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverError(JavaStreamingListenerReceiverError receiverError) { + JavaReceiverInfo receiverInfo = receiverError.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverStopped(JavaStreamingListenerReceiverStopped receiverStopped) { + JavaReceiverInfo receiverInfo = receiverStopped.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onBatchSubmitted(JavaStreamingListenerBatchSubmitted batchSubmitted) { + super.onBatchSubmitted(batchSubmitted); + } + + @Override + public void onBatchStarted(JavaStreamingListenerBatchStarted batchStarted) { + super.onBatchStarted(batchStarted); + } + + @Override + public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { + super.onBatchCompleted(batchCompleted); + } + + @Override + public void onOutputOperationStarted( + JavaStreamingListenerOutputOperationStarted outputOperationStarted) { + super.onOutputOperationStarted(outputOperationStarted); + } + + @Override + public void onOutputOperationCompleted( + JavaStreamingListenerOutputOperationCompleted outputOperationCompleted) { + super.onOutputOperationCompleted(outputOperationCompleted); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index 57b50bdfd6520..ae44fd07ac558 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -69,7 +69,7 @@ trait JavaTestBase extends TestSuiteBase { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] ssc.getState() val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - res.map(_.asJava).asJava + res.map(_.asJava).toSeq.asJava } /** @@ -85,7 +85,7 @@ trait JavaTestBase extends TestSuiteBase { implicit val cm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput) - res.map(entry => entry.map(_.asJava).asJava).asJava + res.map(entry => entry.map(_.asJava).asJava).toSeq.asJava } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 175b8a496b4e5..f02fa87f6194b 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.streaming; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.nio.ByteBuffer; import java.util.Arrays; @@ -27,6 +26,7 @@ import com.google.common.base.Function; import com.google.common.collect.Iterators; import org.apache.spark.SparkConf; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.streaming.util.WriteAheadLog; import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; import org.apache.spark.streaming.util.WriteAheadLogUtils; @@ -108,23 +108,23 @@ public void close() { public void testCustomWAL() { SparkConf conf = new SparkConf(); conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName()); + conf.set("spark.streaming.driver.writeAheadLog.allowBatching", "false"); WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; - WriteAheadLogRecordHandle handle = - wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234); + WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234); Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); - Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1); + Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1); - wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235); - wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236); - wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237); + wal.write(JavaUtils.stringToBytes("data2"), 1235); + wal.write(JavaUtils.stringToBytes("data3"), 1236); + wal.write(JavaUtils.stringToBytes("data4"), 1237); wal.clean(1236, false); Iterator dataIterator = wal.readAll(); List readData = new ArrayList<>(); while (dataIterator.hasNext()) { - readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8)); + readData.add(JavaUtils.bytesToString(dataIterator.next())); } Assert.assertEquals(readData, Arrays.asList("data3", "data4")); } diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala new file mode 100644 index 0000000000000..0295e059f7bc2 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler._ + +class JavaStreamingListenerWrapperSuite extends SparkFunSuite { + + test("basic") { + val listener = new TestJavaStreamingListener() + val listenerWrapper = new JavaStreamingListenerWrapper(listener) + + val receiverStarted = StreamingListenerReceiverStarted(ReceiverInfo( + streamId = 2, + name = "test", + active = true, + location = "localhost", + executorId = "1" + )) + listenerWrapper.onReceiverStarted(receiverStarted) + assertReceiverInfo(listener.receiverStarted.receiverInfo, receiverStarted.receiverInfo) + + val receiverStopped = StreamingListenerReceiverStopped(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost", + executorId = "1" + )) + listenerWrapper.onReceiverStopped(receiverStopped) + assertReceiverInfo(listener.receiverStopped.receiverInfo, receiverStopped.receiverInfo) + + val receiverError = StreamingListenerReceiverError(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost", + executorId = "1", + lastErrorMessage = "failed", + lastError = "failed", + lastErrorTime = System.currentTimeMillis() + )) + listenerWrapper.onReceiverError(receiverError) + assertReceiverInfo(listener.receiverError.receiverInfo, receiverError.receiverInfo) + + val batchSubmitted = StreamingListenerBatchSubmitted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + None, + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = None, + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = None, + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchSubmitted(batchSubmitted) + assertBatchInfo(listener.batchSubmitted.batchInfo, batchSubmitted.batchInfo) + + val batchStarted = StreamingListenerBatchStarted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchStarted(batchStarted) + assertBatchInfo(listener.batchStarted.batchInfo, batchStarted.batchInfo) + + val batchCompleted = StreamingListenerBatchCompleted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + Some(1010L), + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = Some(1010L), + failureReason = None)) + )) + listenerWrapper.onBatchCompleted(batchCompleted) + assertBatchInfo(listener.batchCompleted.batchInfo, batchCompleted.batchInfo) + + val outputOperationStarted = StreamingListenerOutputOperationStarted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None + )) + listenerWrapper.onOutputOperationStarted(outputOperationStarted) + assertOutputOperationInfo(listener.outputOperationStarted.outputOperationInfo, + outputOperationStarted.outputOperationInfo) + + val outputOperationCompleted = StreamingListenerOutputOperationCompleted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None + )) + listenerWrapper.onOutputOperationCompleted(outputOperationCompleted) + assertOutputOperationInfo(listener.outputOperationCompleted.outputOperationInfo, + outputOperationCompleted.outputOperationInfo) + } + + private def assertReceiverInfo( + javaReceiverInfo: JavaReceiverInfo, receiverInfo: ReceiverInfo): Unit = { + assert(javaReceiverInfo.streamId === receiverInfo.streamId) + assert(javaReceiverInfo.name === receiverInfo.name) + assert(javaReceiverInfo.active === receiverInfo.active) + assert(javaReceiverInfo.location === receiverInfo.location) + assert(javaReceiverInfo.executorId === receiverInfo.executorId) + assert(javaReceiverInfo.lastErrorMessage === receiverInfo.lastErrorMessage) + assert(javaReceiverInfo.lastError === receiverInfo.lastError) + assert(javaReceiverInfo.lastErrorTime === receiverInfo.lastErrorTime) + } + + private def assertBatchInfo(javaBatchInfo: JavaBatchInfo, batchInfo: BatchInfo): Unit = { + assert(javaBatchInfo.batchTime === batchInfo.batchTime) + assert(javaBatchInfo.streamIdToInputInfo.size === batchInfo.streamIdToInputInfo.size) + batchInfo.streamIdToInputInfo.foreach { case (streamId, streamInputInfo) => + assertStreamingInfo(javaBatchInfo.streamIdToInputInfo.get(streamId), streamInputInfo) + } + assert(javaBatchInfo.submissionTime === batchInfo.submissionTime) + assert(javaBatchInfo.processingStartTime === batchInfo.processingStartTime.getOrElse(-1)) + assert(javaBatchInfo.processingEndTime === batchInfo.processingEndTime.getOrElse(-1)) + assert(javaBatchInfo.schedulingDelay === batchInfo.schedulingDelay.getOrElse(-1)) + assert(javaBatchInfo.processingDelay === batchInfo.processingDelay.getOrElse(-1)) + assert(javaBatchInfo.totalDelay === batchInfo.totalDelay.getOrElse(-1)) + assert(javaBatchInfo.numRecords === batchInfo.numRecords) + assert(javaBatchInfo.outputOperationInfos.size === batchInfo.outputOperationInfos.size) + batchInfo.outputOperationInfos.foreach { case (outputOperationId, outputOperationInfo) => + assertOutputOperationInfo( + javaBatchInfo.outputOperationInfos.get(outputOperationId), outputOperationInfo) + } + } + + private def assertStreamingInfo( + javaStreamInputInfo: JavaStreamInputInfo, streamInputInfo: StreamInputInfo): Unit = { + assert(javaStreamInputInfo.inputStreamId === streamInputInfo.inputStreamId) + assert(javaStreamInputInfo.numRecords === streamInputInfo.numRecords) + assert(javaStreamInputInfo.metadata === streamInputInfo.metadata.asJava) + assert(javaStreamInputInfo.metadataDescription === streamInputInfo.metadataDescription.orNull) + } + + private def assertOutputOperationInfo( + javaOutputOperationInfo: JavaOutputOperationInfo, + outputOperationInfo: OutputOperationInfo): Unit = { + assert(javaOutputOperationInfo.batchTime === outputOperationInfo.batchTime) + assert(javaOutputOperationInfo.id === outputOperationInfo.id) + assert(javaOutputOperationInfo.name === outputOperationInfo.name) + assert(javaOutputOperationInfo.description === outputOperationInfo.description) + assert(javaOutputOperationInfo.startTime === outputOperationInfo.startTime.getOrElse(-1)) + assert(javaOutputOperationInfo.endTime === outputOperationInfo.endTime.getOrElse(-1)) + assert(javaOutputOperationInfo.failureReason === outputOperationInfo.failureReason.orNull) + } +} + +class TestJavaStreamingListener extends JavaStreamingListener { + + var receiverStarted: JavaStreamingListenerReceiverStarted = null + var receiverError: JavaStreamingListenerReceiverError = null + var receiverStopped: JavaStreamingListenerReceiverStopped = null + var batchSubmitted: JavaStreamingListenerBatchSubmitted = null + var batchStarted: JavaStreamingListenerBatchStarted = null + var batchCompleted: JavaStreamingListenerBatchCompleted = null + var outputOperationStarted: JavaStreamingListenerOutputOperationStarted = null + var outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted = null + + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + this.receiverStarted = receiverStarted + } + + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + this.receiverError = receiverError + } + + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + this.receiverStopped = receiverStopped + } + + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + this.batchSubmitted = batchSubmitted + } + + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + this.batchStarted = batchStarted + } + + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + this.batchCompleted = batchCompleted + } + + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + this.outputOperationStarted = outputOperationStarted + } + + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + this.outputOperationCompleted = outputOperationCompleted + } +} diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 75e3b53a093f6..fd51f8faf56b9 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 9d296c6d3ef8b..cfcbdc7c382f9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.language.existentials import scala.reflect.ClassTag @@ -74,7 +76,7 @@ class BasicOperationsSuite extends TestSuiteBase { assert(numInputPartitions === 2, "Number of input partitions has been changed from 2") val input = Seq(1 to 4, 5 to 8, 9 to 12) val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23)) - val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _))) + val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.sum)) testOperation(input, operation, output, true) } @@ -84,9 +86,10 @@ class BasicOperationsSuite extends TestSuiteBase { withStreamingContext(setupStreams(input, operation, 2)) { ssc => val output = runStreamsWithPartitions(ssc, 3, 3) assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) + val outputArray = output.toArray + val first = outputArray(0) + val second = outputArray(1) + val third = outputArray(2) assert(first.size === 5) assert(second.size === 5) @@ -104,9 +107,10 @@ class BasicOperationsSuite extends TestSuiteBase { withStreamingContext(setupStreams(input, operation, 5)) { ssc => val output = runStreamsWithPartitions(ssc, 3, 3) assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) + val outputArray = output.toArray + val first = outputArray(0) + val second = outputArray(1) + val third = outputArray(2) assert(first.size === 2) assert(second.size === 2) @@ -186,7 +190,7 @@ class BasicOperationsSuite extends TestSuiteBase { val output = Seq(1 to 8, 101 to 108, 201 to 208) testOperation( input, - (s: DStream[Int]) => s.union(s.map(_ + 4)) , + (s: DStream[Int]) => s.union(s.map(_ + 4)), output ) } @@ -534,10 +538,9 @@ class BasicOperationsSuite extends TestSuiteBase { val stateObj = state.getOrElse(new StateObject) values.sum match { case 0 => stateObj.expireCounter += 1 // no new values - case n => { // has new values, increment and reset expireCounter + case n => // has new values, increment and reset expireCounter stateObj.counter += n stateObj.expireCounter = 0 - } } stateObj.expireCounter match { case 2 => None // seen twice with no new values, give it the boot @@ -645,8 +648,8 @@ class BasicOperationsSuite extends TestSuiteBase { val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) val mappedStream = networkStream.map(_ + ".").persist() - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(mappedStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(mappedStream, outputQueue) outputStream.register() ssc.start() @@ -685,7 +688,7 @@ class BasicOperationsSuite extends TestSuiteBase { testServer.stop() // verify data has been received - assert(outputBuffer.size > 0) + assert(!outputQueue.isEmpty) assert(blockRdds.size > 0) assert(persistentRddIds.size > 0) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 84f5294aa39cc..bdbac64b9bc79 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -17,32 +17,198 @@ package org.apache.spark.streaming -import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File} -import org.apache.spark.TestUtils +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream} +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} -import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver} -import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSystemProperties, + Utils} + +/** + * A input stream that records the times of restore() invoked + */ +private[streaming] +class CheckpointInputDStream(_ssc: StreamingContext) extends InputDStream[Int](_ssc) { + protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData + override def start(): Unit = { } + override def stop(): Unit = { } + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1))) + private[streaming] + class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { + @transient + var restoredTimes = 0 + override def restore() { + restoredTimes += 1 + super.restore() + } + } +} + +/** + * A trait of that can be mixed in to get methods for testing DStream operations under + * DStream checkpointing. Note that the implementations of this trait has to implement + * the `setupCheckpointOperation` + */ +trait DStreamCheckpointTester { self: SparkFunSuite => + + /** + * Tests a streaming operation under checkpointing, by restarting the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. + */ + protected def testCheckpointedOperation[U: ClassTag, V: ClassTag]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatchesBeforeRestart: Int, + batchDuration: Duration = Milliseconds(500), + stopSparkContextAfterTest: Boolean = true + ) { + require(numBatchesBeforeRestart < expectedOutput.size, + "Number of batches before context restart less than number of expected output " + + "(i.e. number of total batches to run)") + require(StreamingContext.getActive().isEmpty, + "Cannot run test with already active streaming context") + + // Current code assumes that number of batches to be run = number of inputs + val totalNumBatches = input.size + val batchDurationMillis = batchDuration.milliseconds + + // Setup the stream computation + val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + logDebug(s"Using checkpoint directory $checkpointDir") + val ssc = createContextForCheckpointOperation(batchDuration) + require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, + "Cannot run test without manual clock in the conf") + + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val operatedStream = operation(inputStream) + operatedStream.print() + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ConcurrentLinkedQueue[Seq[Seq[V]]]) + outputStream.register() + ssc.checkpoint(checkpointDir) + + // Do the computation for initial number of batches, create checkpoint file and quit + val beforeRestartOutput = generateOutput[V](ssc, + Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest) + assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true) + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + + val restartedSsc = new StreamingContext(checkpointDir) + val afterRestartOutput = generateOutput[V](restartedSsc, + Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest) + assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false) + } + + protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = { + val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) + } + + /** + * Get the first TestOutputStreamWithPartitions, does not check the provided generic type. + */ + protected def getTestOutputStream[V: ClassTag](streams: Array[DStream[_]]): + TestOutputStreamWithPartitions[V] = { + streams.collect { + case ds: TestOutputStreamWithPartitions[V @unchecked] => ds + }.head + } + + + protected def generateOutput[V: ClassTag]( + ssc: StreamingContext, + targetBatchTime: Time, + checkpointDir: String, + stopSparkContext: Boolean + ): Seq[Seq[V]] = { + try { + val batchDuration = ssc.graph.batchDuration + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val currentTime = clock.getTimeMillis() + + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) + clock.setTime(targetBatchTime.milliseconds) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) + + val outputStream = getTestOutputStream[V](ssc.graph.getOutputStreams()) + + eventually(timeout(10 seconds)) { + ssc.awaitTerminationOrTimeout(10) + assert(batchCounter.getLastCompletedBatchTime === targetBatchTime) + } + + eventually(timeout(10 seconds)) { + val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { + _.toString.contains(clock.getTimeMillis.toString) + } + // Checkpoint files are written twice for every batch interval. So assert that both + // are written to make sure that both of them have been written. + assert(checkpointFilesOfLatestTime.size === 2) + } + outputStream.output.asScala.map(_.flatten).toSeq + + } finally { + ssc.stop(stopSparkContext = stopSparkContext) + } + } + + private def assertOutput[V: ClassTag]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + beforeRestart: Boolean): Unit = { + val expectedPartialOutput = if (beforeRestart) { + expectedOutput.take(output.size) + } else { + expectedOutput.takeRight(output.size) + } + val setComparison = output.zip(expectedPartialOutput).forall { + case (o, e) => o.toSet === e.toSet + } + assert(setComparison, s"set comparison failed\n" + + s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" + + s"Generated output items: ${output.mkString("\n")}" + ) + } +} /** * This test suites tests the checkpointing functionality of DStreams - * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester + with ResetSystemProperties { var ssc: StreamingContext = null @@ -54,9 +220,17 @@ class CheckpointSuite extends TestSuiteBase { } override def afterFunction() { - super.afterFunction() - if (ssc != null) ssc.stop() - Utils.deleteRecursively(new File(checkpointDir)) + try { + if (ssc != null) { ssc.stop() } + Utils.deleteRecursively(new File(checkpointDir)) + } finally { + super.afterFunction() + } + } + + test("non-existent checkpoint dir") { + // SPARK-13211 + intercept[IllegalArgumentException](new StreamingContext("nosuchdirectory")) } test("basic rdd checkpoints + dstream graph checkpoint recovery") { @@ -93,10 +267,9 @@ class CheckpointSuite extends TestSuiteBase { assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") stateStream.checkpointData.currentCheckpointFiles.foreach { - case (time, file) => { + case (time, file) => assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") - } } // Run till a further time such that previous checkpoint files in the stream would be deleted @@ -123,10 +296,9 @@ class CheckpointSuite extends TestSuiteBase { assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") stateStream.checkpointData.currentCheckpointFiles.foreach { - case (time, file) => { + case (time, file) => assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") - } } ssc.stop() @@ -250,7 +422,9 @@ class CheckpointSuite extends TestSuiteBase { Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), - Seq(("", 2)), Seq() ), + Seq(("", 2)), + Seq() + ), 3 ) } @@ -430,7 +604,7 @@ class CheckpointSuite extends TestSuiteBase { // Set up the streaming context and input streams val batchDuration = Seconds(2) // Due to 1-second resolution of setLastModified() on some OS's. val testDir = Utils.createTempDir() - val outputBuffer = new ArrayBuffer[Seq[Int]] with SynchronizedBuffer[Seq[Int]] + val outputBuffer = new ConcurrentLinkedQueue[Seq[Int]] /** * Writes a file named `i` (which contains the number `i`) to the test directory and sets its @@ -438,7 +612,7 @@ class CheckpointSuite extends TestSuiteBase { */ def writeFile(i: Int, clock: Clock): Unit = { val file = new File(testDir, i.toString) - Files.write(i + "\n", file, Charsets.UTF_8) + Files.write(i + "\n", file, StandardCharsets.UTF_8) assert(file.setLastModified(clock.getTimeMillis())) // Check that the file's modification date is actually the value we wrote, since rounding or // truncation will break the test: @@ -451,7 +625,8 @@ class CheckpointSuite extends TestSuiteBase { def recordedFiles(ssc: StreamingContext): Seq[Int] = { val fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] - val filenames = fileInputDStream.batchTimeToSelectedFiles.values.flatten + val filenames = fileInputDStream.batchTimeToSelectedFiles.synchronized + { fileInputDStream.batchTimeToSelectedFiles.values.flatten } filenames.map(_.split(File.separator).last.toInt).toSeq.sorted } @@ -510,7 +685,7 @@ class CheckpointSuite extends TestSuiteBase { ssc.stop() // Check that we shut down while the third batch was being processed assert(batchCounter.getNumCompletedBatches === 2) - assert(outputStream.output.flatten === Seq(1, 3)) + assert(outputStream.output.asScala.toSeq.flatten === Seq(1, 3)) } // The original StreamingContext has now been stopped. @@ -560,7 +735,7 @@ class CheckpointSuite extends TestSuiteBase { assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) } } - logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]")) + logInfo("Output after restart = " + outputStream.output.asScala.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() @@ -569,17 +744,44 @@ class CheckpointSuite extends TestSuiteBase { assert(recordedFiles(ssc) === (1 to 9)) // Append the new output to the old buffer - outputBuffer ++= outputStream.output + outputBuffer.addAll(outputStream.output) // Verify whether all the elements received are as expected val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45) - assert(outputBuffer.flatten.toSet === expectedOutput.toSet) + assert(outputBuffer.asScala.flatten.toSet === expectedOutput.toSet) } } finally { Utils.deleteRecursively(testDir) } } + test("DStreamCheckpointData.restore invoking times") { + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val checkpointData = inputDStream.checkpointData + val mappedDStream = inputDStream.map(_ + 100) + val outputStream = new TestOutputStreamWithPartitions(mappedDStream) + outputStream.register() + // do two more times output + mappedDStream.foreachRDD(rdd => rdd.count()) + mappedDStream.foreachRDD(rdd => rdd.count()) + assert(checkpointData.restoredTimes === 0) + val batchDurationMillis = ssc.progressListener.batchDuration + generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true) + assert(checkpointData.restoredTimes === 0) + } + logInfo("*********** RESTARTING ************") + withStreamingContext(new StreamingContext(checkpointDir)) { ssc => + val checkpointData = + ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData + assert(checkpointData.restoredTimes === 1) + ssc.start() + ssc.stop() + assert(checkpointData.restoredTimes === 1) + } + } + // This tests whether spark can deserialize array object // refer to SPARK-5569 test("recovery from checkpoint contains array object") { @@ -611,58 +813,103 @@ class CheckpointSuite extends TestSuiteBase { assert(ois.readObject().asInstanceOf[Class[_]].getName == "[LtestClz;") } - /** - * Tests a streaming operation under checkpointing, by restarting the operation - * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay - * data as required. - * - * NOTE: This takes into consideration that the last batch processed before - * master failure will be re-processed after restart/recovery. - */ - def testCheckpointedOperation[U: ClassTag, V: ClassTag]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { - - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 - // because the last batch will be processed again + test("SPARK-11267: the race condition of two checkpoints in a batch") { + val jobGenerator = mock(classOf[JobGenerator]) + val checkpointDir = Utils.createTempDir().toString + val checkpointWriter = + new CheckpointWriter(jobGenerator, conf, checkpointDir, new Configuration()) + val bytes1 = Array.fill[Byte](10)(1) + new checkpointWriter.CheckpointWriteHandler( + Time(2000), bytes1, clearCheckpointDataLater = false).run() + val bytes2 = Array.fill[Byte](10)(2) + new checkpointWriter.CheckpointWriteHandler( + Time(1000), bytes2, clearCheckpointDataLater = true).run() + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir).reverse.map { path => + new File(path.toUri) + } + assert(checkpointFiles.size === 2) + // Although bytes2 was written with an old time, it contains the latest status, so we should + // try to read from it at first. + assert(Files.toByteArray(checkpointFiles(0)) === bytes2) + assert(Files.toByteArray(checkpointFiles(1)) === bytes1) + checkpointWriter.stop() + } - // Do the computation for initial number of batches, create checkpoint file and quit - ssc = setupStreams[U, V](input, operation) - ssc.start() - val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) - ssc.stop() - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) + test("SPARK-6847: stack overflow when updateStateByKey is followed by a checkpointed dstream") { + // In this test, there are two updateStateByKey operators. The RDD DAG is as follows: + // + // batch 1 batch 2 batch 3 ... + // + // 1) input rdd input rdd input rdd + // | | | + // v v v + // 2) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 3) map rdd --- map rdd --- map rdd ... + // | | | + // v v v + // 4) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 5) map rdd --- map rdd --- map rdd ... + // + // Every batch depends on its previous batch, so "updateStateByKey" needs to do checkpoint to + // break the RDD chain. However, before SPARK-6847, when the state RDD (layer 5) of the second + // "updateStateByKey" does checkpoint, it won't checkpoint the state RDD (layer 3) of the first + // "updateStateByKey" (Note: "updateStateByKey" has already marked that its state RDD (layer 3) + // should be checkpointed). Hence, the connections between layer 2 and layer 3 won't be broken + // and the RDD chain will grow infinitely and cause StackOverflow. + // + // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing + // all marked RDDs in the DAG to resolve this issue. (For the previous example, it will break + // connections between layer 2 and layer 3) + ssc = new StreamingContext(master, framework, batchDuration) + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + @volatile var shouldCheckpointAllMarkedRDDs = false + @volatile var rddsCheckpointed = false + inputDStream.map(i => (i, i)) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .foreachRDD { rdd => + /** + * Find all RDDs that are marked for checkpointing in the specified RDD and its ancestors. + */ + def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = { + val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList + if (rdd.checkpointData.isDefined) { + rdd :: markedRDDs + } else { + markedRDDs + } + } - // Restart and complete the computation from checkpoint file - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" - ) - ssc = new StreamingContext(checkpointDir) + shouldCheckpointAllMarkedRDDs = + Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)). + map(_.toBoolean).getOrElse(false) + + val stateRDDs = findAllMarkedRDDs(rdd) + rdd.count() + // Check the two state RDDs are both checkpointed + rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed) + } ssc.start() - val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) - // the first element will be re-processed data of the last batch before restart - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) - ssc.stop() - ssc = null + batchCounter.waitUntilBatchesCompleted(1, 10000) + assert(shouldCheckpointAllMarkedRDDs === true) + assert(rddsCheckpointed === true) } /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. */ - def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = + def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): + Iterable[Seq[V]] = { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.getTimeMillis()) @@ -673,10 +920,8 @@ class CheckpointSuite extends TestSuiteBase { logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams().filter { dstream => - dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] - }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] - outputStream.output.map(_.flatten) + val outputStream = getTestOutputStream[V](ssc.graph.getOutputStreams()) + outputStream.output.asScala.map(_.flatten) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index 9b5e4dc819a2b..1fc34f569f9f4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -33,13 +33,18 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { private var ssc: StreamingContext = null override def beforeAll(): Unit = { + super.beforeAll() val sc = new SparkContext("local", "test") ssc = new StreamingContext(sc, Seconds(1)) } override def afterAll(): Unit = { - ssc.stop(stopSparkContext = true) - ssc = null + try { + ssc.stop(stopSparkContext = true) + ssc = null + } finally { + super.afterAll() + } } test("user provided closures are actually cleaned") { @@ -51,7 +56,6 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { testFilter(dstream) testMapPartitions(dstream) testReduce(dstream) - testForeach(dstream) testForeachRDD(dstream) testTransform(dstream) testTransformWith(dstream) @@ -101,12 +105,6 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { private def testReduce(ds: DStream[Int]): Unit = expectCorrectException { ds.reduce { case (_, _) => return; 1 } } - private def testForeach(ds: DStream[Int]): Unit = { - val foreachF1 = (rdd: RDD[Int], t: Time) => return - val foreachF2 = (rdd: RDD[Int]) => return - expectCorrectException { ds.foreach(foreachF1) } - expectCorrectException { ds.foreach(foreachF2) } - } private def testForeachRDD(ds: DStream[Int]): Unit = { val foreachRDDF1 = (rdd: RDD[Int], t: Time) => return val foreachRDDF2 = (rdd: RDD[Int]) => return diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index 8844c9d74b933..94f1bcebc3a39 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.streaming +import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.ui.UIUtils +import org.apache.spark.util.ManualClock /** * Tests whether scope information is passed from DStream operations to RDDs correctly. @@ -32,11 +35,18 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd private val batchDuration: Duration = Seconds(1) override def beforeAll(): Unit = { - ssc = new StreamingContext(new SparkContext("local", "test"), batchDuration) + super.beforeAll() + val conf = new SparkConf().setMaster("local").setAppName("test") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + ssc = new StreamingContext(new SparkContext(conf), batchDuration) } override def afterAll(): Unit = { - ssc.stop(stopSparkContext = true) + try { + ssc.stop(stopSparkContext = true) + } finally { + super.afterAll() + } } before { assertPropertiesNotSet() } @@ -103,6 +113,8 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd test("scoping nested operations") { val inputStream = new DummyInputDStream(ssc) + // countByKeyAndWindow internally uses reduceByKeyAndWindow, but only countByKeyAndWindow + // should appear in scope val countStream = inputStream.countByWindow(Seconds(10), Seconds(1)) countStream.initialize(Time(0)) @@ -137,6 +149,57 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd testStream(countStream) } + test("transform should allow RDD operations to be captured in scopes") { + val inputStream = new DummyInputDStream(ssc) + val transformedStream = inputStream.transform { _.map { _ -> 1}.reduceByKey(_ + _) } + transformedStream.initialize(Time(0)) + + val transformScopeBase = transformedStream.baseScope.map(RDDOperationScope.fromJson) + val transformScope1 = transformedStream.getOrCompute(Time(1000)).get.scope + val transformScope2 = transformedStream.getOrCompute(Time(2000)).get.scope + val transformScope3 = transformedStream.getOrCompute(Time(3000)).get.scope + + // Assert that all children RDDs inherit the DStream operation name correctly + assertDefined(transformScopeBase, transformScope1, transformScope2, transformScope3) + assert(transformScopeBase.get.name === "transform") + assertNestedScopeCorrect(transformScope1.get, 1000) + assertNestedScopeCorrect(transformScope2.get, 2000) + assertNestedScopeCorrect(transformScope3.get, 3000) + + def assertNestedScopeCorrect(rddScope: RDDOperationScope, batchTime: Long): Unit = { + assert(rddScope.name === "reduceByKey") + assert(rddScope.parent.isDefined) + assertScopeCorrect(transformScopeBase.get, rddScope.parent.get, batchTime) + } + } + + test("foreachRDD should allow RDD operations to be captured in scope") { + val inputStream = new DummyInputDStream(ssc) + val generatedRDDs = new ArrayBuffer[RDD[(Int, Int)]] + inputStream.foreachRDD { rdd => + generatedRDDs += rdd.map { _ -> 1}.reduceByKey(_ + _) + } + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(3000) + batchCounter.waitUntilBatchesCompleted(3, 10000) + assert(generatedRDDs.size === 3) + + val foreachBaseScope = + ssc.graph.getOutputStreams().head.baseScope.map(RDDOperationScope.fromJson) + assertDefined(foreachBaseScope) + assert(foreachBaseScope.get.name === "foreachRDD") + + val rddScopes = generatedRDDs.map { _.scope } + assertDefined(rddScopes: _*) + rddScopes.zipWithIndex.foreach { case (rddScope, idx) => + assert(rddScope.get.name === "reduceByKey") + assert(rddScope.get.parent.isDefined) + assertScopeCorrect(foreachBaseScope.get, rddScope.get.parent.get, (idx + 1) * 1000) + } + } + /** Assert that the RDD operation scope properties are not set in our SparkContext. */ private def assertPropertiesNotSet(): Unit = { assert(ssc != null) @@ -149,19 +212,12 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd baseScope: RDDOperationScope, rddScope: RDDOperationScope, batchTime: Long): Unit = { - assertScopeCorrect(baseScope.id, baseScope.name, rddScope, batchTime) - } - - /** Assert that the given RDD scope inherits the base name and ID correctly. */ - private def assertScopeCorrect( - baseScopeId: String, - baseScopeName: String, - rddScope: RDDOperationScope, - batchTime: Long): Unit = { + val (baseScopeId, baseScopeName) = (baseScope.id, baseScope.name) val formattedBatchTime = UIUtils.formatBatchTime( batchTime, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) assert(rddScope.id === s"${baseScopeId}_$batchTime") assert(rddScope.name.replaceAll("\\n", " ") === s"$baseScopeName @ $formattedBatchTime") + assert(rddScope.parent.isEmpty) // There should not be any higher scope } /** Assert that all the specified options are defined. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index e82c2fa4e72ad..19ceb748e07f7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -21,7 +21,8 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFunSuite, Logging} +import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** @@ -43,6 +44,9 @@ class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging { Utils.deleteRecursively(directory) } StreamingContext.getActive().foreach { _.stop() } + + // Stop SparkContext if active + SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("bla")).stop() } test("multiple failures with map") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 047e38ef90998..a2653000af557 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -17,30 +17,30 @@ package org.apache.spark.streaming -import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.net.{Socket, SocketException, ServerSocket} -import java.nio.charset.Charset -import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue} +import java.io.{BufferedWriter, File, OutputStreamWriter} +import java.net.{ServerSocket, Socket, SocketException} +import java.nio.charset.StandardCharsets +import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue} +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.language.postfixOps import com.google.common.io.Files -import org.apache.hadoop.io.{Text, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.Logging +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} -import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.{ManualClock, Utils} class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -58,15 +58,15 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchCounter = new BatchCounter(ssc) val networkStream = ssc.socketTextStream( "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputQueue) outputStream.register() ssc.start() // Feed data to the server to send to the network receiver val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val expectedOutput = input.map(_.toString) - for (i <- 0 until input.size) { + for (i <- input.indices) { testServer.send(input(i).toString + "\n") Thread.sleep(500) clock.advance(batchDuration.milliseconds) @@ -77,7 +77,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } // Ensure progress listener has been notified of all events - ssc.scheduler.listenerBus.waitUntilEmpty(500) + ssc.sparkContext.listenerBus.waitUntilEmpty(500) // Verify all "InputInfo"s have been reported assert(ssc.progressListener.numTotalReceivedRecords === input.size) @@ -90,9 +90,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + outputQueue.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) @@ -100,9 +100,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected // (whether the elements were received one in each interval is not verified) - val output: ArrayBuffer[String] = outputBuffer.flatMap(x => x) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { + val output: Array[String] = outputQueue.asScala.flatMap(x => x).toArray + assert(output.length === expectedOutput.size) + for (i <- output.indices) { assert(output(i) === expectedOutput(i)) } } @@ -119,8 +119,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchCounter = new BatchCounter(ssc) val networkStream = ssc.socketTextStream( "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputQueue) outputStream.register() ssc.start() @@ -146,7 +146,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val testDir = Utils.createTempDir() // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, Charset.forName("UTF-8")) + Files.write("0\n", existingFile, StandardCharsets.UTF_8) assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) // Set up the streaming context and input streams @@ -156,9 +156,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { clock.setTime(existingFile.lastModified + batchDuration.milliseconds) val batchCounter = new BatchCounter(ssc) val fileStream = ssc.binaryRecordsStream(testDir.toString, 1) - val outputBuffer = new ArrayBuffer[Seq[Array[Byte]]] - with SynchronizedBuffer[Seq[Array[Byte]]] - val outputStream = new TestOutputStream(fileStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[Array[Byte]]] + val outputStream = new TestOutputStream(fileStream, outputQueue) outputStream.register() ssc.start() @@ -183,8 +182,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } val expectedOutput = input.map(i => i.toByte) - val obtainedOutput = outputBuffer.flatten.toList.map(i => i(0).toByte) - assert(obtainedOutput === expectedOutput) + val obtainedOutput = outputQueue.asScala.flatten.toList.map(i => i(0).toByte) + assert(obtainedOutput.toSeq === expectedOutput) } } finally { if (testDir != null) Utils.deleteRecursively(testDir) @@ -206,69 +205,73 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val numTotalRecords = numThreads * numRecordsPerThread val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) MultiThreadTestReceiver.haveAllThreadsFinished = false + val outputQueue = new ConcurrentLinkedQueue[Seq[Long]] + def output: Iterable[Long] = outputQueue.asScala.flatMap(x => x) // set up the network stream using the test receiver - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.receiverStream[Int](testReceiver) - val countStream = networkStream.count - val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] - val outputStream = new TestOutputStream(countStream, outputBuffer) - def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Let the data from the receiver be received - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val startTime = System.currentTimeMillis() - while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && - System.currentTimeMillis() - startTime < 5000) { - Thread.sleep(100) - clock.advance(batchDuration.milliseconds) + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val networkStream = ssc.receiverStream[Int](testReceiver) + val countStream = networkStream.count + + val outputStream = new TestOutputStream(countStream, outputQueue) + outputStream.register() + ssc.start() + + // Let the data from the receiver be received + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val startTime = System.currentTimeMillis() + while ((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && + System.currentTimeMillis() - startTime < 5000) { + Thread.sleep(100) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + outputQueue.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") assert(output.sum === numTotalRecords) } test("queue input stream - oneAtATime = true") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = true) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.nonEmpty) - val inputIterator = input.toIterator - for (i <- 0 until input.size) { - // Enqueue more than 1 item per tick but they should dequeue one at a time - inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new mutable.Queue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = true) + val outputStream = new TestOutputStream(queueStream, outputQueue) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + val inputIterator = input.toIterator + for (i <- input.indices) { + // Enqueue more than 1 item per tick but they should dequeue one at a time + inputIterator.take(2).foreach { i => + queue.synchronized { + queue += ssc.sparkContext.makeRDD(Seq(i)) + } + } + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + outputQueue.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) @@ -276,45 +279,51 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } + output.zipWithIndex.foreach{case (e, i) => assert(e == expectedOutput(i))} } test("queue input stream - oneAtATime = false") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = false) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.nonEmpty) val input = Seq("1", "2", "3", "4", "5") val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5")) - // Enqueue the first 3 items (one by one), they should be merged in the next batch - val inputIterator = input.toIterator - inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new mutable.Queue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = false) + val outputStream = new TestOutputStream(queueStream, outputQueue) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + // Enqueue the first 3 items (one by one), they should be merged in the next batch + val inputIterator = input.toIterator + inputIterator.take(3).foreach { i => + queue.synchronized { + queue += ssc.sparkContext.makeRDD(Seq(i)) + } + } + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) - // Enqueue the remaining items (again one by one), merged in the final batch - inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() + // Enqueue the remaining items (again one by one), merged in the final batch + inputIterator.foreach { i => + queue.synchronized { + queue += ssc.sparkContext.makeRDD(Seq(i)) + } + } + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + } // Verify whether data received was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) + logInfo("output.size = " + outputQueue.size) logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("expected output.size = " + expectedOutput.size) logInfo("expected output") expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) @@ -322,9 +331,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } + output.zipWithIndex.foreach{case (e, i) => assert(e == expectedOutput(i))} } test("test track the number of input stream") { @@ -362,7 +369,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val testDir = Utils.createTempDir() // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, Charset.forName("UTF-8")) + Files.write("0\n", existingFile, StandardCharsets.UTF_8) assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) // Set up the streaming context and input streams @@ -373,8 +380,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchCounter = new BatchCounter(ssc) val fileStream = ssc.fileStream[LongWritable, Text, TextInputFormat]( testDir.toString, (x: Path) => true, newFilesOnly = newFilesOnly).map(_._2.toString) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(fileStream, outputBuffer) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(fileStream, outputQueue) outputStream.register() ssc.start() @@ -386,7 +393,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val input = Seq(1, 2, 3, 4, 5) input.foreach { i => val file = new File(testDir, i.toString) - Files.write(i + "\n", file, Charset.forName("UTF-8")) + Files.write(i + "\n", file, StandardCharsets.UTF_8) assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) logInfo("Created file " + file) @@ -404,7 +411,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } else { (Seq(0) ++ input).map(_.toString).toSet } - assert(outputBuffer.flatten.toSet === expectedOutput) + assert(outputQueue.asScala.flatten.toSet === expectedOutput) } } finally { if (testDir != null) Utils.deleteRecursively(testDir) @@ -441,7 +448,7 @@ class TestServer(portToBind: Int = 0) extends Logging { try { clientSocket.setTcpNoDelay(true) val outputStream = new BufferedWriter( - new OutputStreamWriter(clientSocket.getOutputStream)) + new OutputStreamWriter(clientSocket.getOutputStream, StandardCharsets.UTF_8)) while (clientSocket.isConnected) { val msg = queue.poll(100, TimeUnit.MILLISECONDS) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala new file mode 100644 index 0000000000000..3b662ec1833aa --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.PrivateMethodTester._ + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl} +import org.apache.spark.util.{ManualClock, Utils} + +class MapWithStateSuite extends SparkFunSuite + with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { + + private var sc: SparkContext = null + protected var checkpointDir: File = null + protected val batchDuration = Seconds(1) + + before { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + checkpointDir = Utils.createTempDir("checkpoint") + } + + after { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) + } + } + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + sc = new SparkContext(conf) + } + + override def afterAll(): Unit = { + try { + if (sc != null) { + sc.stop() + } + } finally { + super.afterAll() + } + } + + test("state - get, exists, update, remove, ") { + var state: StateImpl[Int] = null + + def testState( + expectedData: Option[Int], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false, + shouldBeTimingOut: Boolean = false + ): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get() === expectedData.get) + assert(state.getOption() === expectedData) + assert(state.getOption.getOrElse(-1) === expectedData.get) + } else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get() + } + assert(state.getOption() === None) + assert(state.getOption.getOrElse(-1) === -1) + } + + assert(state.isTimingOut() === shouldBeTimingOut) + if (shouldBeTimingOut) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + + assert(state.isUpdated() === shouldBeUpdated) + + assert(state.isRemoved() === shouldBeRemoved) + if (shouldBeRemoved) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + } + + state = new StateImpl[Int]() + testState(None) + + state.wrap(None) + testState(None) + + state.wrap(Some(1)) + testState(Some(1)) + + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state = new StateImpl[Int]() + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state.remove() + testState(None, shouldBeRemoved = true) + + state.wrapTimingOutState(3) + testState(Some(3), shouldBeTimingOut = true) + } + + test("mapWithState - basic operations with simple API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(1), + Seq(2, 1), + Seq(3, 2, 1), + Seq(4, 3), + Seq(5), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, and updated count is returned + val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + sum + } + + testOperation[String, Int, Int]( + inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - basic operations with advanced API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq("aa"), + Seq("aa", "bb"), + Seq("aa", "bb", "cc"), + Seq("aa", "bb"), + Seq("aa"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, key string doubled and returned + val mappingFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + Some(key * 2) + } + + testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - type inferencing and class tags") { + + // Simple track state function with value as Int, state as Double and mapped type as Double + val simpleFunc = (key: String, value: Option[Int], state: State[Double]) => { + 0L + } + + // Advanced track state function with key as String, value as Int, state as Double and + // mapped type as Double + val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => { + Some(0L) + } + + def testTypes(dstream: MapWithStateDStream[_, _, _, _]): Unit = { + val dstreamImpl = dstream.asInstanceOf[MapWithStateDStreamImpl[_, _, _, _]] + assert(dstreamImpl.keyClass === classOf[String]) + assert(dstreamImpl.valueClass === classOf[Int]) + assert(dstreamImpl.stateClass === classOf[Double]) + assert(dstreamImpl.mappedClass === classOf[Long]) + } + val ssc = new StreamingContext(sc, batchDuration) + val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) + + // Defining StateSpec inline with mapWithState and simple function implicitly gets the types + val simpleFunctionStateStream1 = inputStream.mapWithState( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(simpleFunctionStateStream1) + + // Separately defining StateSpec with simple function requires explicitly specifying types + val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc) + val simpleFunctionStateStream2 = inputStream.mapWithState(simpleFuncSpec) + testTypes(simpleFunctionStateStream2) + + // Separately defining StateSpec with advanced function implicitly gets the types + val advFuncSpec1 = StateSpec.function(advancedFunc) + val advFunctionStateStream1 = inputStream.mapWithState(advFuncSpec1) + testTypes(advFunctionStateStream1) + + // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types + val advFunctionStateStream2 = inputStream.mapWithState( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(advFunctionStateStream2) + + // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types + val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc) + val advFunctionStateStream3 = inputStream.mapWithState[Double, Long](advFuncSpec2) + testTypes(advFunctionStateStream3) + } + + test("mapWithState - states as mapped data") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3)), + Seq(("a", 5)), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + Some(output) + } + + testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData) + } + + test("mapWithState - initial states, with nothing returned as from mapping function") { + + val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) + + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = Seq.fill(inputData.size)(Seq.empty[Int]) + + val stateData = + Seq( + Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)), + Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)), + Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + None.asInstanceOf[Option[Int]] + } + + val mapWithStateSpec = StateSpec.function(mappingFunc).initialState(sc.makeRDD(initialState)) + testOperation(inputData, mapWithStateSpec, outputData, stateData) + } + + test("mapWithState - state removing") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), // a will be removed + Seq("a", "b", "c"), // b will be removed + Seq("a", "b", "c"), // a and c will be removed + Seq("a", "b"), // b will be removed + Seq("a"), // a will be removed + Seq() + ) + + // States that were removed + val outputData = + Seq( + Seq(), + Seq(), + Seq("a"), + Seq("b"), + Seq("a", "c"), + Seq("b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("b", 1)), + Seq(("a", 1), ("c", 1)), + Seq(("b", 1)), + Seq(("a", 1)), + Seq(), + Seq() + ) + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (state.exists) { + state.remove() + Some(key) + } else { + state.update(value.get) + None + } + } + + testOperation( + inputData, StateSpec.function(mappingFunc).numPartitions(1), outputData, stateData) + } + + test("mapWithState - state timing out") { + val inputData = + Seq( + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq(), // c will time out + Seq(), // b will time out + Seq("a") // a will not time out + ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active + + val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (value.isDefined) { + state.update(1) + } + if (state.isTimingOut) { + Some(key) + } else { + None + } + } + + val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( + inputData, StateSpec.function(mappingFunc).timeout(Seconds(3)), 20) + + // b and c should be returned once each, when they were marked as expired + assert(collectedOutputs.flatten.sorted === Seq("b", "c")) + + // States for a, b, c should be defined at one point of time + assert(collectedStateSnapshots.exists { + _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) + }) + + // Finally state should be defined only for a + assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) + } + + test("mapWithState - checkpoint durations") { + val privateMethod = PrivateMethod[InternalMapWithStateDStream[_, _, _, _]]('internalStream) + + def testCheckpointDuration( + batchDuration: Duration, + expectedCheckpointDuration: Duration, + explicitCheckpointDuration: Option[Duration] = None + ): Unit = { + val ssc = new StreamingContext(sc, batchDuration) + + try { + val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) + val dummyFunc = (key: Int, value: Option[Int], state: State[Int]) => 0 + val mapWithStateStream = inputStream.mapWithState(StateSpec.function(dummyFunc)) + val internalmapWithStateStream = mapWithStateStream invokePrivate privateMethod() + + explicitCheckpointDuration.foreach { d => + mapWithStateStream.checkpoint(d) + } + mapWithStateStream.register() + ssc.checkpoint(checkpointDir.toString) + ssc.start() // should initialize all the checkpoint durations + assert(mapWithStateStream.checkpointDuration === null) + assert(internalmapWithStateStream.checkpointDuration === expectedCheckpointDuration) + } finally { + ssc.stop(stopSparkContext = false) + } + } + + testCheckpointDuration(Milliseconds(100), Seconds(1)) + testCheckpointDuration(Seconds(1), Seconds(10)) + testCheckpointDuration(Seconds(10), Seconds(100)) + + testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) + } + + + test("mapWithState - driver failure recovery") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + def operation(dstream: DStream[String]): DStream[(String, Int)] = { + + val checkpointDuration = batchDuration * (stateData.size / 2) + + val runningCount = (key: String, value: Option[Int], state: State[Int]) => { + state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) + state.get() + } + + val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState( + StateSpec.function(runningCount)) + // Set interval make sure there is one RDD checkpointing + mapWithStateStream.checkpoint(checkpointDuration) + mapWithStateStream.stateSnapshots() + } + + testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, + batchDuration = batchDuration, stopSparkContextAfterTest = false) + } + + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + mapWithStateSpec: StateSpec[K, Int, S, T], + expectedOutputs: Seq[Seq[T]], + expectedStateSnapshots: Seq[Seq[(K, S)]] + ): Unit = { + require(expectedOutputs.size == expectedStateSnapshots.size) + + val (collectedOutputs, collectedStateSnapshots) = + getOperationOutput(input, mapWithStateSpec, expectedOutputs.size) + assert(expectedOutputs, collectedOutputs, "outputs") + assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") + } + + private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + mapWithStateSpec: StateSpec[K, Int, S, T], + numBatches: Int + ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { + + // Setup the stream computation + val ssc = new StreamingContext(sc, Seconds(1)) + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val trackeStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec) + val collectedOutputs = new ConcurrentLinkedQueue[Seq[T]] + val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs) + val collectedStateSnapshots = new ConcurrentLinkedQueue[Seq[(K, S)]] + val stateSnapshotStream = new TestOutputStream( + trackeStateStream.stateSnapshots(), collectedStateSnapshots) + outputStream.register() + stateSnapshotStream.register() + + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.toString) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds * numBatches) + + batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + ssc.stop(stopSparkContext = false) + (collectedOutputs.asScala.toSeq, collectedStateSnapshots.asScala.toSeq) + } + + private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { + val debugString = "\nExpected:\n" + expected.mkString("\n") + + "\nCollected:\n" + collected.mkString("\n") + assert(expected.size === collected.size, + s"number of collected $typ (${collected.size}) different from expected (${expected.size})" + + debugString) + expected.zip(collected).foreach { case (c, e) => + assert(c.toSet === e.toSet, + s"collected $typ is different from expected $debugString" + ) + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 0e64b57e0ffd8..60c8e702352cf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -17,23 +17,22 @@ package org.apache.spark.streaming -import org.apache.spark.Logging -import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.util.Utils +import java.io.{File, IOException} +import java.nio.charset.StandardCharsets +import java.util.UUID -import scala.util.Random +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag - -import java.io.{File, IOException} -import java.nio.charset.Charset -import java.util.UUID +import scala.util.Random import com.google.common.io.Files - -import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils private[streaming] object MasterFailureTest extends Logging { @@ -202,12 +201,12 @@ object MasterFailureTest extends Logging { * the last expected output is generated. Finally, return */ private def runStreams[T: ClassTag]( - ssc_ : StreamingContext, + _ssc: StreamingContext, lastExpectedOutput: T, maxTimeToRun: Long ): Seq[T] = { - var ssc = ssc_ + var ssc = _ssc var totalTimeRan = 0L var isLastOutputGenerated = false var isTimedOut = false @@ -217,8 +216,8 @@ object MasterFailureTest extends Logging { while(!isLastOutputGenerated && !isTimedOut) { // Get the output buffer - val outputBuffer = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[T]].output - def output = outputBuffer.flatMap(x => x) + val outputQueue = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[T]].output + def output = outputQueue.asScala.flatten // Start the thread to kill the streaming after some time killed = false @@ -243,6 +242,8 @@ object MasterFailureTest extends Logging { } } catch { case e: Exception => logError("Error running streaming context", e) + } finally { + ssc.stop() } if (killingThread.isAlive) { killingThread.interrupt() @@ -251,7 +252,6 @@ object MasterFailureTest extends Logging { // to null after the next test creates the new SparkContext and fail the test. killingThread.join() } - ssc.stop() logInfo("Has been killed = " + killed) logInfo("Is last output generated = " + isLastOutputGenerated) @@ -259,9 +259,9 @@ object MasterFailureTest extends Logging { // Verify whether the output of each batch has only one element or no element // and then merge the new output with all the earlier output - mergedOutput ++= output + mergedOutput ++= output.toSeq totalTimeRan += timeRan - logInfo("New output = " + output) + logInfo("New output = " + output.toSeq) logInfo("Merged output = " + mergedOutput) logInfo("Time ran = " + timeRan) logInfo("Total time ran = " + totalTimeRan) @@ -371,7 +371,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) val localFile = new File(localTestDir, (i + 1).toString) val hadoopFile = new Path(testDir, (i + 1).toString) val tempHadoopFile = new Path(testDir, ".tmp_" + (i + 1).toString) - Files.write(input(i) + "\n", localFile, Charset.forName("UTF-8")) + Files.write(input(i) + "\n", localFile, StandardCharsets.UTF_8) var tries = 0 var done = false while (!done && tries < maxTries) { @@ -382,11 +382,10 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) fs.rename(tempHadoopFile, hadoopFile) done = true } catch { - case ioe: IOException => { + case ioe: IOException => fs = testDir.getFileSystem(new Configuration()) logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe) - } } } if (!done) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c17fb7238151b..39d0de5179ea9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,18 +29,18 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ import org.apache.spark.util.{ManualClock, Utils} -import WriteAheadLogBasedBlockHandler._ -import WriteAheadLogSuite._ +import org.apache.spark.util.io.ChunkedByteBuffer class ReceivedBlockHandlerSuite extends SparkFunSuite @@ -48,6 +48,9 @@ class ReceivedBlockHandlerSuite with Matchers with Logging { + import WriteAheadLogBasedBlockHandler._ + import WriteAheadLogSuite._ + val conf = new SparkConf() .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") .set("spark.app.id", "streaming-test") @@ -57,6 +60,7 @@ class ReceivedBlockHandlerSuite val mapOutputTracker = new MapOutputTrackerMaster(conf) val shuffleManager = new HashShuffleManager(conf) val serializer = new KryoSerializer(conf) + var serializerManager = new SerializerManager(serializer, conf) val manualClock = new ManualClock val blockManagerSize = 10000000 val blockManagerBuffer = new ArrayBuffer[BlockManager]() @@ -105,7 +109,10 @@ class ReceivedBlockHandlerSuite testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) + blockManager + .getLocalValues(blockId) + .map(_.data.map(_.toString).toList) + .getOrElse(List.empty) }.toList storedData shouldEqual data @@ -129,7 +136,10 @@ class ReceivedBlockHandlerSuite testBlockStoring(handler) { case (data, blockIds, storeResults) => // Verify the data in block manager is correct val storedData = blockIds.flatMap { blockId => - blockManager.getLocal(blockId).map(_.data.map(_.toString).toList).getOrElse(List.empty) + blockManager + .getLocalValues(blockId) + .map(_.data.map(_.toString).toList) + .getOrElse(List.empty) }.toList storedData shouldEqual data @@ -147,7 +157,8 @@ class ReceivedBlockHandlerSuite val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) val bytes = reader.read(fileSegment) reader.close() - blockManager.dataDeserialize(generateBlockId(), bytes).toList + serializerManager.dataDeserializeStream( + generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList } loggedData shouldEqual data } @@ -195,13 +206,13 @@ class ReceivedBlockHandlerSuite blockManager = createBlockManager(12000, sparkConf) // there is not enough space to store this block in MEMORY, - // But BlockManager will be able to sereliaze this block to WAL + // But BlockManager will be able to serialize this block to WAL // and hence count returns correct value. testRecordcount(false, StorageLevel.MEMORY_ONLY, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) // there is not enough space to store this block in MEMORY, - // But BlockManager will be able to sereliaze this block to DISK + // But BlockManager will be able to serialize this block to DISK // and hence count returns correct value. testRecordcount(true, StorageLevel.MEMORY_AND_DISK, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) @@ -255,8 +266,8 @@ class ReceivedBlockHandlerSuite conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) - val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf, + val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1) + val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") @@ -265,7 +276,7 @@ class ReceivedBlockHandlerSuite } /** - * Test storing of data using different types of Handler, StorageLevle and ReceivedBlocks + * Test storing of data using different types of Handler, StorageLevel and ReceivedBlocks * and verify the correct record count */ private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean, @@ -325,13 +336,14 @@ class ReceivedBlockHandlerSuite } } - def dataToByteBuffer(b: Seq[String]) = blockManager.dataSerialize(generateBlockId, b.iterator) + def dataToByteBuffer(b: Seq[String]) = + serializerManager.dataSerialize(generateBlockId, b.iterator) val blocks = data.grouped(10).toSeq storeAndVerify(blocks.map { b => IteratorBlock(b.toIterator) }) storeAndVerify(blocks.map { b => ArrayBufferBlock(new ArrayBuffer ++= b) }) - storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b)) }) + storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b).toByteBuffer) }) } /** Test error handling when blocks that cannot be stored */ @@ -357,8 +369,8 @@ class ReceivedBlockHandlerSuite /** Instantiate a WriteAheadLogBasedBlockHandler and run a code with it */ private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) { require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = false) === 1) - val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, - storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) + val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, serializerManager, + 1, storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) try { body(receivedBlockHandler) } finally { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index f793a12843b2f..851013bb1e846 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.File +import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -28,11 +29,12 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.{WriteAheadLogUtils, FileBasedWriteAheadLogReader} +import org.apache.spark.streaming.util._ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -40,7 +42,6 @@ class ReceivedBlockTrackerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() - val akkaTimeout = 10 seconds val streamId = 1 var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]() @@ -207,6 +208,75 @@ class ReceivedBlockTrackerSuite tracker1.isWriteAheadLogEnabled should be (false) } + test("parallel file deletion in FileBasedWriteAheadLog is robust to deletion error") { + conf.set("spark.streaming.driver.writeAheadLog.rollingIntervalSecs", "1") + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = true) === 1) + + val addBlocks = generateBlockInfos() + val batch1 = addBlocks.slice(0, 1) + val batch2 = addBlocks.slice(1, 3) + val batch3 = addBlocks.slice(3, addBlocks.length) + + assert(getWriteAheadLogFiles().length === 0) + + // list of timestamps for files + val t = Seq.tabulate(5)(i => i * 1000) + + writeEventsManually(getLogFileName(t(0)), Seq(createBatchCleanup(t(0)))) + assert(getWriteAheadLogFiles().length === 1) + + // The goal is to create several log files which should have been cleaned up. + // If we face any issue during recovery, because these old files exist, then we need to make + // deletion more robust rather than a parallelized operation where we fire and forget + val batch1Allocation = createBatchAllocation(t(1), batch1) + writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent) :+ batch1Allocation) + + writeEventsManually(getLogFileName(t(2)), Seq(createBatchCleanup(t(1)))) + + val batch2Allocation = createBatchAllocation(t(3), batch2) + writeEventsManually(getLogFileName(t(3)), batch2.map(BlockAdditionEvent) :+ batch2Allocation) + + writeEventsManually(getLogFileName(t(4)), batch3.map(BlockAdditionEvent)) + + // We should have 5 different log files as we called `writeEventsManually` with 5 different + // timestamps + assert(getWriteAheadLogFiles().length === 5) + + // Create the tracker to recover from the log files. We're going to ask the tracker to clean + // things up, and then we're going to rewrite that data, and recover using a different tracker. + // They should have identical data no matter what + val tracker = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + + def compareTrackers(base: ReceivedBlockTracker, subject: ReceivedBlockTracker): Unit = { + subject.getBlocksOfBatchAndStream(t(3), streamId) should be( + base.getBlocksOfBatchAndStream(t(3), streamId)) + subject.getBlocksOfBatchAndStream(t(1), streamId) should be( + base.getBlocksOfBatchAndStream(t(1), streamId)) + subject.getBlocksOfBatchAndStream(t(0), streamId) should be(Nil) + } + + // ask the tracker to clean up some old files + tracker.cleanupOldBatches(t(3), waitForCompletion = true) + assert(getWriteAheadLogFiles().length === 3) + + val tracker2 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker2) + + // rewrite first file + writeEventsManually(getLogFileName(t(0)), Seq(createBatchCleanup(t(0)))) + assert(getWriteAheadLogFiles().length === 4) + // make sure trackers are consistent + val tracker3 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker3) + + // rewrite second file + writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent) :+ batch1Allocation) + assert(getWriteAheadLogFiles().length === 5) + // make sure trackers are consistent + val tracker4 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker4) + } + /** * Create tracker object with the optional provided clock. Use fake clock if you * want to control time by manually incrementing it to test log clean. @@ -228,11 +298,30 @@ class ReceivedBlockTrackerSuite BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } + /** + * Write received block tracker events to a file manually. + */ + def writeEventsManually(filePath: String, events: Seq[ReceivedBlockTrackerLogEvent]): Unit = { + val writer = HdfsUtils.getOutputStream(filePath, hadoopConf) + events.foreach { event => + val bytes = Utils.serialize(event) + writer.writeInt(bytes.size) + writer.write(bytes) + } + writer.close() + } + /** Get all the data written in the given write ahead log file. */ def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = { getWrittenLogData(Seq(logFile)) } + /** Get the log file name for the given log start time. */ + def getLogFileName(time: Long, rollingIntervalSecs: Int = 1): String = { + checkpointDirectory.toString + File.separator + "receivedBlockMetadata" + + File.separator + s"log-$time-${time + rollingIntervalSecs * 1000}" + } + /** * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. @@ -241,8 +330,13 @@ class ReceivedBlockTrackerSuite : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq - }.map { byteBuffer => - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.flatMap { byteBuffer => + val validBuffer = if (WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) { + Utils.deserialize[Array[Array[Byte]]](byteBuffer.array()).map(ByteBuffer.wrap) + } else { + Array(byteBuffer) + } + validBuffer.map(b => Utils.deserialize[ReceivedBlockTrackerLogEvent](b.array())) }.toList } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala index 6d388d9624d92..6763ac64da287 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -28,12 +29,15 @@ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiver, WriteAheadLogBasedStoreResult} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} -import org.apache.spark.{SparkConf, SparkEnv} class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { override def afterAll(): Unit = { - StreamingContext.getActive().map { _.stop() } + try { + StreamingContext.getActive().map { _.stop() } + } finally { + super.afterAll() + } } testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream => @@ -93,7 +97,7 @@ class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { assert(blockRDD.walRecordHandles.toSeq === blockInfos.map { _.walRecordHandleOption.get }) } - testWithWAL("createBlockRDD creates BlockRDD when some block info dont have WAL info") { + testWithWAL("createBlockRDD creates BlockRDD when some block info don't have WAL info") { receiverStream => val blockInfos1 = Seq.fill(2) { createBlockInfo(withWALInfo = true) } val blockInfos2 = Seq.fill(3) { createBlockInfo(withWALInfo = false) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 01279b34f73dc..917232c9cdd63 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -24,8 +24,8 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala new file mode 100644 index 0000000000000..484f3733e8423 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.collection.{immutable, mutable, Map} +import scala.reflect.ClassTag +import scala.util.Random + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer._ +import org.apache.spark.streaming.rdd.MapWithStateRDDRecord +import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} + +class StateMapSuite extends SparkFunSuite { + + private val conf = new SparkConf() + + test("EmptyStateMap") { + val map = new EmptyStateMap[Int, Int] + intercept[scala.NotImplementedError] { + map.put(1, 1, 1) + } + assert(map.get(1) === None) + assert(map.getByTime(10000).isEmpty) + assert(map.getAll().isEmpty) + map.remove(1) // no exception + assert(map.copy().eq(map)) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") { + val map = new OpenHashMapBasedStateMap[Int, Int]() + + map.put(1, 100, 10) + assert(map.get(1) === Some(100)) + assert(map.get(2) === None) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10))) + + map.put(2, 200, 20) + assert(map.getByTime(21).toSet === Set((1, 100, 10), (2, 200, 20))) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20))) + + map.remove(1) + assert(map.get(1) === None) + assert(map.getAll().toSet === Set((2, 200, 20))) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove with copy") { + val parentMap = new OpenHashMapBasedStateMap[Int, Int]() + parentMap.put(1, 100, 1) + parentMap.put(2, 200, 2) + parentMap.remove(1) + + // Create child map and make changes + val map = parentMap.copy() + assert(map.get(1) === None) + assert(map.get(2) === Some(200)) + assert(map.getByTime(10).toSet === Set((2, 200, 2))) + assert(map.getByTime(2).toSet === Set.empty) + assert(map.getAll().toSet === Set((2, 200, 2))) + + // Add new items + map.put(3, 300, 3) + assert(map.get(3) === Some(300)) + map.put(4, 400, 4) + assert(map.get(4) === Some(400)) + assert(map.getByTime(10).toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(map.getByTime(4).toSet === Set((2, 200, 2), (3, 300, 3))) + assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Remove items + map.remove(4) + assert(map.get(4) === None) // item added in this map, then removed in this map + map.remove(2) + assert(map.get(2) === None) // item removed in parent map, then added in this map + assert(map.getAll().toSet === Set((3, 300, 3))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Update items + map.put(1, 1000, 100) + assert(map.get(1) === Some(1000)) // item removed in parent map, then added in this map + map.put(2, 2000, 200) + assert(map.get(2) === Some(2000)) // item added in parent map, then removed + added in this map + map.put(3, 3000, 300) + assert(map.get(3) === Some(3000)) // item added + updated in this map + map.put(4, 4000, 400) + assert(map.get(4) === Some(4000)) // item removed + updated in this map + + assert(map.getAll().toSet === + Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + map.remove(2) // remove item present in parent map, so that its not visible in child map + + // Create child map and see availability of items + val childMap = map.copy() + assert(childMap.getAll().toSet === map.getAll().toSet) + assert(childMap.get(1) === Some(1000)) // item removed in grandparent, but added in parent map + assert(childMap.get(2) === None) // item added in grandparent, but removed in parent map + assert(childMap.get(3) === Some(3000)) // item added and updated in parent map + + childMap.put(2, 20000, 200) + assert(childMap.get(2) === Some(20000)) // item map + } + + test("OpenHashMapBasedStateMap - serializing and deserializing") { + val map1 = new OpenHashMapBasedStateMap[Int, Int]() + testSerialization(map1, "error deserializing and serialized empty map") + + map1.put(1, 100, 1) + map1.put(2, 200, 2) + testSerialization(map1, "error deserializing and serialized map with data + no delta") + + val map2 = map1.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + // Do not test compaction + assert(map2.shouldCompact === false) + testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") + + map2.put(3, 300, 3) + map2.put(4, 400, 4) + testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") + + val map3 = map2.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + assert(map3.shouldCompact === false) + testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") + map3.put(3, 600, 3) + map3.remove(2) + testSerialization(map3, "error deserializing and serialized map with 2 delta + new data") + } + + test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { + val targetDeltaLength = 10 + val deltaChainThreshold = 5 + + var map = new OpenHashMapBasedStateMap[Int, Int]( + deltaChainThreshold = deltaChainThreshold) + + // Make large delta chain with length more than deltaChainThreshold + for(i <- 1 to targetDeltaLength) { + map.put(Random.nextInt(), Random.nextInt(), 1) + map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + } + assert(map.deltaChainLength > deltaChainThreshold) + assert(map.shouldCompact === true) + + val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map") + assert(deser_map.deltaChainLength < deltaChainThreshold) + assert(deser_map.shouldCompact === false) + } + + test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { + /* + * This tests the map using all permutations of sequences operations, across multiple map + * copies as well as between copies. It is to ensure complete coverage, though it is + * kind of hard to debug this. It is set up as follows. + * + * - For any key, there can be 2 types of update ops on a state map - put or remove + * + * - These operations are done on a test map in "sets". After each set, the map is "copied" + * to create a new map, and the next set of operations are done on the new one. This tests + * whether the map data persist correctly across copies. + * + * - Within each set, there are a number of operations to test whether the map correctly + * updates and removes data without affecting the parent state map. + * + * - Overall this creates (numSets * numOpsPerSet) operations, each of which that can 2 types + * of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different sequence + * of operations, which we will test with different keys. + * + * Example: With numSets = 2, and numOpsPerSet = 2 give numTotalOps = 4. This means that + * 2 ^ 4 = 16 possible permutations needs to be tested using 16 keys. + * _______________________________________________ + * | | Set1 | Set2 | + * | |-----------------|-----------------| + * | | Op1 Op2 |c| Op3 Op4 | + * |---------|----------------|o|----------------| + * | key 0 | put put |p| put put | + * | key 1 | put put |y| put rem | + * | key 2 | put put | | rem put | + * | key 3 | put put |t| rem rem | + * | key 4 | put rem |h| put put | + * | key 5 | put rem |e| put rem | + * | key 6 | put rem | | rem put | + * | key 7 | put rem |s| rem rem | + * | key 8 | rem put |t| put put | + * | key 9 | rem put |a| put rem | + * | key 10 | rem put |t| rem put | + * | key 11 | rem put |e| rem rem | + * | key 12 | rem rem | | put put | + * | key 13 | rem rem |m| put rem | + * | key 14 | rem rem |a| rem put | + * | key 15 | rem rem |p| rem rem | + * |_________|________________|_|________________| + */ + + val numTypeMapOps = 2 // 0 = put a new value, 1 = remove value + val numSets = 3 + val numOpsPerSet = 3 // to test seq of ops like update -> remove -> update in same set + val numTotalOps = numOpsPerSet * numSets + val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt // to get all combinations of ops + + val refMap = new mutable.HashMap[Int, (Int, Long)]() + var prevSetRefMap: immutable.Map[Int, (Int, Long)] = null + + var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]() + var prevSetStateMap: StateMap[Int, Int] = null + + var time = 1L + + for (setId <- 0 until numSets) { + for (opInSetId <- 0 until numOpsPerSet) { + val opId = setId * numOpsPerSet + opInSetId + for (keyId <- 0 until numKeys) { + time += 1 + // Find the operation type that needs to be done + // This is similar to finding the nth bit value of a binary number + // E.g. nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2 + val opCode = + (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps + opCode match { + case 0 => + val value = Random.nextInt() + stateMap.put(keyId, value, time) + refMap.put(keyId, (value, time)) + case 1 => + stateMap.remove(keyId) + refMap.remove(keyId) + } + } + + // Test whether the current state map after all key updates is correct + assertMap(stateMap, refMap, time, "State map does not match reference map") + + // Test whether the previous map before copy has not changed + if (prevSetStateMap != null && prevSetRefMap != null) { + assertMap(prevSetStateMap, prevSetRefMap, time, + "Parent state map somehow got modified, does not match corresponding reference map") + } + } + + // Copy the map and remember the previous maps for future tests + prevSetStateMap = stateMap + prevSetRefMap = refMap.toMap + stateMap = stateMap.copy() + + // Assert that the copied map has the same data + assertMap(stateMap, prevSetRefMap, time, + "State map does not match reference map after copying") + } + assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") + } + + private def testSerialization[T: ClassTag]( + map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = { + testSerialization(new JavaSerializer(conf), map, msg) + testSerialization(new KryoSerializer(conf), map, msg) + } + + private def testSerialization[T: ClassTag]( + serializer: Serializer, + map: OpenHashMapBasedStateMap[T, T], + msg: String): OpenHashMapBasedStateMap[T, T] = { + val deserMap = serializeAndDeserialize(serializer, map) + assertMap(deserMap, map, 1, msg) + deserMap + } + + // Assert whether all the data and operations on a state map matches that of a reference state map + private def assertMap[T]( + mapToTest: StateMap[T, T], + refMapToTestWith: StateMap[T, T], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === refMapToTestWith.getAll().toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.getAll().map { _._1 }) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId)) + } + + // Assert that every time threshold returns the correct data + for (t <- 0L to (time + 1)) { + assert(mapToTest.getByTime(t).toSet === refMapToTestWith.getByTime(t).toSet) + } + } + } + + // Assert whether all the data and operations on a state map matches that of a reference map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: Map[Int, (Int, Long)], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === + refMapToTestWith.iterator.map { x => (x._1, x._2._1, x._2._2) }.toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.keys) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId).map { _._1 }) + } + + // Assert that every time threshold returns the correct data + for (t <- 0L to (time + 1)) { + val expectedRecords = + refMapToTestWith.iterator.filter { _._2._2 < t }.map { x => (x._1, x._2._1, x._2._2) } + assert(mapToTest.getByTime(t).toSet === expectedRecords.toSet) + } + } + } + + test("OpenHashMapBasedStateMap - serializing and deserializing with KryoSerializable states") { + val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() + map.put(new KryoState("a"), new KryoState("b"), 1) + testSerialization( + new KryoSerializer(conf), map, "error deserializing and serialized KryoSerializable states") + } + + test("EmptyStateMap - serializing and deserializing") { + val map = StateMap.empty[KryoState, KryoState] + // Since EmptyStateMap doesn't contains any date, KryoState won't break JavaSerializer. + assert(serializeAndDeserialize(new JavaSerializer(conf), map). + isInstanceOf[EmptyStateMap[KryoState, KryoState]]) + assert(serializeAndDeserialize(new KryoSerializer(conf), map). + isInstanceOf[EmptyStateMap[KryoState, KryoState]]) + } + + test("MapWithStateRDDRecord - serializing and deserializing with KryoSerializable states") { + val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() + map.put(new KryoState("a"), new KryoState("b"), 1) + + val record = + MapWithStateRDDRecord[KryoState, KryoState, KryoState](map, Seq(new KryoState("c"))) + val deserRecord = serializeAndDeserialize(new KryoSerializer(conf), record) + assert(!(record eq deserRecord)) + assert(record.stateMap.getAll().toSeq === deserRecord.stateMap.getAll().toSeq) + assert(record.mappedData === deserRecord.mappedData) + } + + private def serializeAndDeserialize[T: ClassTag](serializer: Serializer, t: T): T = { + val serializerInstance = serializer.newInstance() + serializerInstance.deserialize[T]( + serializerInstance.serialize(t), Thread.currentThread().getContextClassLoader) + } +} + +/** A class that only supports Kryo serialization. */ +private[streaming] final class KryoState(var state: String) extends KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + kryo.writeClassAndObject(output, state) + } + + override def read(kryo: Kryo, input: Input): Unit = { + state = kryo.readClassAndObject(input).asInstanceOf[String] + } + + override def equals(other: Any): Boolean = other match { + case that: KryoState => state == that.state + case _ => false + } + + override def hashCode(): Int = { + if (state == null) 0 else state.hashCode() + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index c7a877142b374..806e181f61980 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -31,6 +31,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel @@ -81,9 +82,9 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("from conf with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("from existing SparkContext") { @@ -93,26 +94,27 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("from existing SparkContext with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("from checkpoint") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") val ssc1 = new StreamingContext(myConf, batchDuration) addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) assert( Utils.timeStringAsSeconds(cp.sparkConfPairs - .toMap.getOrElse("spark.cleaner.ttl", "-1")) === 10) + .toMap.getOrElse("spark.dummyTimeConfig", "-1")) === 10) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - assert(newCp.createSparkConf().getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert( + newCp.createSparkConf().getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) ssc = new StreamingContext(null, newCp, null) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("checkPoint from conf") { @@ -146,7 +148,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } } - test("start with non-seriazable DStream checkpoints") { + test("start with non-serializable DStream checkpoints") { val checkpointDir = Utils.createTempDir() ssc = new StreamingContext(conf, batchDuration) ssc.checkpoint(checkpointDir.getAbsolutePath) @@ -180,7 +182,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } - test("start should set job group and description of streaming jobs correctly") { + test("start should set local properties of streaming jobs correctly") { ssc = new StreamingContext(conf, batchDuration) ssc.sc.setJobGroup("non-streaming", "non-streaming", true) val sc = ssc.sc @@ -188,16 +190,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo @volatile var jobGroupFound: String = "" @volatile var jobDescFound: String = "" @volatile var jobInterruptFound: String = "" + @volatile var customPropFound: String = "" @volatile var allFound: Boolean = false addInputStream(ssc).foreachRDD { rdd => jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + customPropFound = sc.getLocalProperty("customPropKey") allFound = true } + ssc.sc.setLocalProperty("customPropKey", "value1") ssc.start() + // Local props set after start should be ignored + ssc.sc.setLocalProperty("customPropKey", "value2") + eventually(timeout(10 seconds), interval(10 milliseconds)) { assert(allFound === true) } @@ -206,11 +214,13 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(jobGroupFound === null) assert(jobDescFound.contains("Streaming job from")) assert(jobInterruptFound === "false") + assert(customPropFound === "value1") // Verify current thread's thread-local properties have not changed assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + assert(sc.getLocalProperty("customPropKey") === "value2") } test("start multiple times") { @@ -288,7 +298,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.cleaner.ttl", "3600s") + conf.set("spark.dummyTimeConfig", "3600s") sc = new SparkContext(conf) for (i <- 1 to 4) { logInfo("==================================\n\n\n") @@ -780,6 +790,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo "Please don't use queueStream when checkpointing is enabled.")) } + test("Creating an InputDStream but not using it should not crash") { + ssc = new StreamingContext(master, appName, batchDuration) + val input1 = addInputStream(ssc) + val input2 = addInputStream(ssc) + val output = new TestOutputStream(input2) + output.register() + val batchCount = new BatchCounter(ssc) + ssc.start() + // Just wait for completing 2 batches to make sure it triggers + // `DStream.getMaxInputStreamRememberDuration` + batchCount.waitUntilBatchesCompleted(2, 10000) + // Throw the exception if crash + ssc.awaitTerminationOrTimeout(1) + ssc.stop() + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) @@ -879,7 +905,7 @@ object SlowTestReceiver { package object testPackage extends Assertions { def test() { val conf = new SparkConf().setMaster("local").setAppName("CreationSite test") - val ssc = new StreamingContext(conf , Milliseconds(100)) + val ssc = new StreamingContext(conf, Milliseconds(100)) try { val inputStream = ssc.receiverStream(new TestReceiver) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 5dc0472c7770c..0f957a1b55706 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -17,20 +17,25 @@ package org.apache.spark.streaming -import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, SynchronizedMap} -import scala.concurrent.Future +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ +import scala.collection.mutable.HashMap import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future + +import org.mockito.Mockito.{mock, reset, verifyNoMoreInteractions} +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler._ -import org.scalatest.Matchers -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ -import org.apache.spark.Logging - class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq @@ -60,43 +65,43 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val batchInfosSubmitted = collector.batchInfosSubmitted batchInfosSubmitted should have size 4 - batchInfosSubmitted.foreach(info => { + batchInfosSubmitted.asScala.foreach(info => { info.schedulingDelay should be (None) info.processingDelay should be (None) info.totalDelay should be (None) }) - batchInfosSubmitted.foreach { info => + batchInfosSubmitted.asScala.foreach { info => info.numRecords should be (1L) info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } - isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosSubmitted.asScala.map(_.submissionTime)) should be (true) // SPARK-6766: processingStartTime of batch info should not be None when starting val batchInfosStarted = collector.batchInfosStarted batchInfosStarted should have size 4 - batchInfosStarted.foreach(info => { + batchInfosStarted.asScala.foreach(info => { info.schedulingDelay should not be None info.schedulingDelay.get should be >= 0L info.processingDelay should be (None) info.totalDelay should be (None) }) - batchInfosStarted.foreach { info => + batchInfosStarted.asScala.foreach { info => info.numRecords should be (1L) info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } - isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosStarted.asScala.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosStarted.asScala.map(_.processingStartTime.get)) should be (true) // test onBatchCompleted val batchInfosCompleted = collector.batchInfosCompleted batchInfosCompleted should have size 4 - batchInfosCompleted.foreach(info => { + batchInfosCompleted.asScala.foreach(info => { info.schedulingDelay should not be None info.processingDelay should not be None info.totalDelay should not be None @@ -105,14 +110,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) - batchInfosCompleted.foreach { info => + batchInfosCompleted.asScala.foreach { info => info.numRecords should be (1L) info.streamIdToInputInfo should be (Map(0 -> StreamInputInfo(0, 1L))) } - isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) - isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) - isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.asScala.map(_.submissionTime)) should be (true) + isInIncreasingOrder(batchInfosCompleted.asScala.map(_.processingStartTime.get)) should be (true) + isInIncreasingOrder(batchInfosCompleted.asScala.map(_.processingEndTime.get)) should be (true) } test("receiver info reporting") { @@ -127,13 +132,13 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { try { eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) - collector.startedReceiverStreamIds(0) should equal (0) - collector.stoppedReceiverStreamIds should have size 1 - collector.stoppedReceiverStreamIds(0) should equal (0) + collector.startedReceiverStreamIds.peek() should equal (0) + collector.stoppedReceiverStreamIds.size should equal (1) + collector.stoppedReceiverStreamIds.peek() should equal (0) collector.receiverErrors should have size 1 - collector.receiverErrors(0)._1 should equal (0) - collector.receiverErrors(0)._2 should include ("report error") - collector.receiverErrors(0)._3 should include ("report exception") + collector.receiverErrors.peek()._1 should equal (0) + collector.receiverErrors.peek()._2 should include ("report error") + collector.receiverErrors.peek()._3 should include ("report exception") } } finally { ssc.stop() @@ -153,14 +158,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { eventually(timeout(30 seconds), interval(20 millis)) { - collector.startedOutputOperationIds.take(3) should be (Seq(0, 1, 2)) - collector.completedOutputOperationIds.take(3) should be (Seq(0, 1, 2)) + collector.startedOutputOperationIds.asScala.take(3) should be (Seq(0, 1, 2)) + collector.completedOutputOperationIds.asScala.take(3) should be (Seq(0, 1, 2)) } } finally { ssc.stop() } } + test("don't call ssc.stop in listener") { + ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + + startStreamingContextAndCallStop(ssc) + } + test("onBatchCompleted with successful batch") { ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) @@ -207,6 +220,42 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { assert(failureReasons(1).contains("This is another failed job")) } + test("StreamingListener receives no events after stopping StreamingListenerBus") { + val streamingListener = mock(classOf[StreamingListener]) + + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc.addStreamingListener(streamingListener) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + ssc.start() + ssc.stop() + + // Because "streamingListener" has already received some events, let's clear that. + reset(streamingListener) + + // Post a Streaming event after stopping StreamingContext + val receiverInfoStopped = ReceiverInfo(0, "test", false, "localhost", "0") + ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfoStopped)) + ssc.sparkContext.listenerBus.waitUntilEmpty(1000) + // The StreamingListener should not receive any event + verifyNoMoreInteractions(streamingListener) + } + + private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = { + val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc) + _ssc.addStreamingListener(contextStoppingCollector) + val batchCounter = new BatchCounter(_ssc) + _ssc.start() + // Make sure running at least one batch + if (!batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)) { + fail("The first batch cannot complete in 10 seconds") + } + // When reaching here, we can make sure `StreamingContextStoppingCollector` won't call + // `ssc.stop()`, so it's safe to call `_ssc.stop()` now. + _ssc.stop() + assert(contextStoppingCollector.sparkExSeen) + } + private def startStreamingContextAndCollectFailureReasons( _ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = { val failureReasonsCollector = new FailureReasonsCollector() @@ -221,73 +270,70 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } _ssc.stop() - failureReasonsCollector.failureReasons.toMap + failureReasonsCollector.failureReasons.synchronized + { + failureReasonsCollector.failureReasons.toMap + } } /** Check if a sequence of numbers is in increasing order */ - def isInIncreasingOrder(seq: Seq[Long]): Boolean = { - for (i <- 1 until seq.size) { - if (seq(i - 1) > seq(i)) { - return false - } - } - true + def isInIncreasingOrder(data: Iterable[Long]): Boolean = { + !data.sliding(2).exists { itr => itr.size == 2 && itr.head > itr.tail.head } } } /** Listener that collects information on processed batches */ class BatchInfoCollector extends StreamingListener { - val batchInfosCompleted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] - val batchInfosStarted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] - val batchInfosSubmitted = new ArrayBuffer[BatchInfo] with SynchronizedBuffer[BatchInfo] + val batchInfosCompleted = new ConcurrentLinkedQueue[BatchInfo] + val batchInfosStarted = new ConcurrentLinkedQueue[BatchInfo] + val batchInfosSubmitted = new ConcurrentLinkedQueue[BatchInfo] override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { - batchInfosSubmitted += batchSubmitted.batchInfo + batchInfosSubmitted.add(batchSubmitted.batchInfo) } override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { - batchInfosStarted += batchStarted.batchInfo + batchInfosStarted.add(batchStarted.batchInfo) } override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - batchInfosCompleted += batchCompleted.batchInfo + batchInfosCompleted.add(batchCompleted.batchInfo) } } /** Listener that collects information on processed batches */ class ReceiverInfoCollector extends StreamingListener { - val startedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] - val stoppedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] - val receiverErrors = - new ArrayBuffer[(Int, String, String)] with SynchronizedBuffer[(Int, String, String)] + val startedReceiverStreamIds = new ConcurrentLinkedQueue[Int] + val stoppedReceiverStreamIds = new ConcurrentLinkedQueue[Int] + val receiverErrors = new ConcurrentLinkedQueue[(Int, String, String)] override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - startedReceiverStreamIds += receiverStarted.receiverInfo.streamId + startedReceiverStreamIds.add(receiverStarted.receiverInfo.streamId) } override def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped) { - stoppedReceiverStreamIds += receiverStopped.receiverInfo.streamId + stoppedReceiverStreamIds.add(receiverStopped.receiverInfo.streamId) } override def onReceiverError(receiverError: StreamingListenerReceiverError) { - receiverErrors += ((receiverError.receiverInfo.streamId, - receiverError.receiverInfo.lastErrorMessage, receiverError.receiverInfo.lastError)) + receiverErrors.add(((receiverError.receiverInfo.streamId, + receiverError.receiverInfo.lastErrorMessage, receiverError.receiverInfo.lastError))) } } /** Listener that collects information on processed output operations */ class OutputOperationInfoCollector extends StreamingListener { - val startedOutputOperationIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] - val completedOutputOperationIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] + val startedOutputOperationIds = new ConcurrentLinkedQueue[Int]() + val completedOutputOperationIds = new ConcurrentLinkedQueue[Int]() override def onOutputOperationStarted( outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { - startedOutputOperationIds += outputOperationStarted.outputOperationInfo.id + startedOutputOperationIds.add(outputOperationStarted.outputOperationInfo.id) } override def onOutputOperationCompleted( outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { - completedOutputOperationIds += outputOperationCompleted.outputOperationInfo.id + completedOutputOperationIds.add(outputOperationCompleted.outputOperationInfo.id) } } @@ -311,12 +357,38 @@ class StreamingListenerSuiteReceiver extends Receiver[Any](StorageLevel.MEMORY_O */ class FailureReasonsCollector extends StreamingListener { - val failureReasons = new HashMap[Int, String] with SynchronizedMap[Int, String] + val failureReasons = new HashMap[Int, String] override def onOutputOperationCompleted( outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { outputOperationCompleted.outputOperationInfo.failureReason.foreach { f => - failureReasons(outputOperationCompleted.outputOperationInfo.id) = f + failureReasons.synchronized + { + failureReasons(outputOperationCompleted.outputOperationInfo.id) = f + } + } + } +} +/** + * A StreamingListener that calls StreamingContext.stop(). + */ +class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener { + @volatile var sparkExSeen = false + + private var isFirstBatch = true + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + if (isFirstBatch) { + // We should only call `ssc.stop()` in the first batch. Otherwise, it's possible that the main + // thread is calling `ssc.stop()`, while StreamingContextStoppingCollector is also calling + // `ssc.stop()` in the listener thread, which becomes a dead-lock. + isFirstBatch = false + try { + ssc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 0d58a7b54412f..fa975a146216d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -17,21 +17,22 @@ package org.apache.spark.streaming -import java.io.{ObjectInputStream, IOException} +import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.SynchronizedBuffer +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag import org.scalatest.BeforeAndAfter -import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration +import org.scalatest.time.{Seconds => ScalaTestSeconds, Span} -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.dstream.{DStream, ForEachDStream, InputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} @@ -56,10 +57,10 @@ private[streaming] class DummyInputDStream(ssc: StreamingContext) extends InputD /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, * replayable, reliable message queue like Kafka. It requires a sequence as input, and - * returns the i_th element at the i_th batch unde manual clock. + * returns the i_th element at the i_th batch under manual clock. */ -class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) - extends InputDStream[T](ssc_) { +class TestInputStream[T: ClassTag](_ssc: StreamingContext, input: Seq[Seq[T]], numPartitions: Int) + extends InputDStream[T](_ssc) { def start() {} @@ -87,18 +88,18 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], /** * This is a output stream just for the testsuites. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + * ConcurrentLinkedQueue. This queue is wiped clean on being restored from checkpoint. * - * The buffer contains a sequence of RDD's, each containing a sequence of items + * The buffer contains a sequence of RDD's, each containing a sequence of items. */ class TestOutputStream[T: ClassTag]( parent: DStream[T], - val output: SynchronizedBuffer[Seq[T]] = - new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + val output: ConcurrentLinkedQueue[Seq[T]] = + new ConcurrentLinkedQueue[Seq[T]]() ) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() - output += collected - }) { + output.add(collected) + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) @@ -110,19 +111,19 @@ class TestOutputStream[T: ClassTag]( /** * This is a output stream just for the testsuites. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + * ConcurrentLinkedQueue. This queue is wiped clean on being restored from checkpoint. * - * The buffer contains a sequence of RDD's, each containing a sequence of partitions, each + * The queue contains a sequence of RDD's, each containing a sequence of partitions, each * containing a sequence of items. */ class TestOutputStreamWithPartitions[T: ClassTag]( parent: DStream[T], - val output: SynchronizedBuffer[Seq[Seq[T]]] = - new ArrayBuffer[Seq[Seq[T]]] with SynchronizedBuffer[Seq[Seq[T]]]) + val output: ConcurrentLinkedQueue[Seq[Seq[T]]] = + new ConcurrentLinkedQueue[Seq[Seq[T]]]()) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.glom().collect().map(_.toSeq) - output += collected - }) { + output.add(collected) + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) @@ -142,6 +143,7 @@ class BatchCounter(ssc: StreamingContext) { // All access to this state should be guarded by `BatchCounter.this.synchronized` private var numCompletedBatches = 0 private var numStartedBatches = 0 + private var lastCompletedBatchTime: Time = null private val listener = new StreamingListener { override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = @@ -152,6 +154,7 @@ class BatchCounter(ssc: StreamingContext) { override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = BatchCounter.this.synchronized { numCompletedBatches += 1 + lastCompletedBatchTime = batchCompleted.batchInfo.batchTime BatchCounter.this.notifyAll() } } @@ -165,6 +168,10 @@ class BatchCounter(ssc: StreamingContext) { numStartedBatches } + def getLastCompletedBatchTime: Time = this.synchronized { + lastCompletedBatchTime + } + /** * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's @@ -316,7 +323,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val inputStream = new TestInputStream(ssc, input, numPartitions) val operatedStream = operation(inputStream) val outputStream = new TestOutputStreamWithPartitions(operatedStream, - new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + new ConcurrentLinkedQueue[Seq[Seq[V]]]) outputStream.register() ssc } @@ -341,7 +348,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions) val operatedStream = operation(inputStream1, inputStream2) val outputStream = new TestOutputStreamWithPartitions(operatedStream, - new ArrayBuffer[Seq[Seq[W]]] with SynchronizedBuffer[Seq[Seq[W]]]) + new ConcurrentLinkedQueue[Seq[Seq[W]]]) outputStream.register() ssc } @@ -412,7 +419,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { } val timeTaken = System.currentTimeMillis() - startTime logInfo("Output generated in " + timeTaken + " milliseconds") - output.foreach(x => logInfo("[" + x.mkString(",") + "]")) + output.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") @@ -420,7 +427,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { } finally { ssc.stop(stopSparkContext = true) } - output + output.asScala.toSeq } /** @@ -495,7 +502,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size withStreamingContext(setupStreams[U, V](input, operation)) { ssc => val output = runStreams[V](ssc, numBatches_, expectedOutput.size) - verifyOutput[V](output, expectedOutput, useSet) + verifyOutput[V](output.toSeq, expectedOutput, useSet) } } @@ -534,7 +541,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size withStreamingContext(setupStreams[U, V, W](input1, input2, operation)) { ssc => val output = runStreams[W](ssc, numBatches_, expectedOutput.size) - verifyOutput[W](output, expectedOutput, useSet) + verifyOutput[W](output.toSeq, expectedOutput, useSet) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index a5744a9009c1c..454c3dffa3db1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -38,14 +38,19 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ override def beforeAll(): Unit = { + super.beforeAll() webDriver = new HtmlUnitDriver { getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) } } override def afterAll(): Unit = { - if (webDriver != null) { - webDriver.quit() + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() } } @@ -138,8 +143,9 @@ class UISeleniumSuite summaryText should contain ("Total delay:") findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be { - List("Output Op Id", "Description", "Duration", "Status", "Job Id", "Duration", - "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error") + List("Output Op Id", "Description", "Output Op Duration", "Status", "Job Id", + "Job Duration", "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", + "Error") } // Check we have 2 output op ids @@ -155,17 +161,17 @@ class UISeleniumSuite jobLinks.size should be (4) // Check stage progress - findAll(cssSelector(""".stage-progress-cell""")).map(_.text).toSeq should be - (List("1/1", "1/1", "1/1", "0/1 (1 failed)")) + findAll(cssSelector(""".stage-progress-cell""")).map(_.text).toList should be ( + List("1/1", "1/1", "1/1", "0/1 (1 failed)")) // Check job progress - findAll(cssSelector(""".progress-cell""")).map(_.text).toSeq should be - (List("1/1", "1/1", "1/1", "0/1 (1 failed)")) + findAll(cssSelector(""".progress-cell""")).map(_.text).toList should be ( + List("4/4", "4/4", "4/4", "0/4 (1 failed)")) // Check stacktrace - val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.text).toSeq + val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.underlying).toSeq errorCells should have size 1 - errorCells(0) should include("java.lang.RuntimeException: Oops") + // Can't get the inner (invisible) text without running JS // Check the job link in the batch page is right go to (jobLinks(0)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index c39ad05f41520..c7d085ec0799b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.dstream.DStream import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.DStream class WindowOperationsSuite extends TestSuiteBase { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala new file mode 100644 index 0000000000000..e8c814ba7184b --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.rdd + +import java.io.File + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{State, Time} +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap +import org.apache.spark.util.Utils + +class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { + + private var sc: SparkContext = null + private var checkpointDir: File = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sc = new SparkContext( + new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite")) + checkpointDir = Utils.createTempDir() + sc.setCheckpointDir(checkpointDir.toString) + } + + override def afterAll(): Unit = { + try { + if (sc != null) { + sc.stop() + } + Utils.deleteRecursively(checkpointDir) + } finally { + super.afterAll() + } + } + + override def sparkContext: SparkContext = sc + + test("creation from pair RDD") { + val data = Seq((1, "1"), (2, "2"), (3, "3")) + val partitioner = new HashPartitioner(10) + val rdd = MapWithStateRDD.createFromPairRDD[Int, Int, String, Int]( + sc.parallelize(data), partitioner, Time(123)) + assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty) + assert(rdd.partitions.size === partitioner.numPartitions) + + assert(rdd.partitioner === Some(partitioner)) + } + + test("updating state and generating mapped data in MapWithStateRDDRecord") { + + val initialTime = 1000L + val updatedTime = 2000L + val thresholdTime = 1500L + @volatile var functionCalled = false + + /** + * Assert that applying given data on a prior record generates correct updated record, with + * correct state map and mapped data + */ + def assertRecordUpdate( + initStates: Iterable[Int], + data: Iterable[String], + expectedStates: Iterable[(Int, Long)], + timeoutThreshold: Option[Long] = None, + removeTimedoutData: Boolean = false, + expectedOutput: Iterable[Int] = None, + expectedTimingOutStates: Iterable[Int] = None, + expectedRemovedStates: Iterable[Int] = None + ): Unit = { + val initialStateMap = new OpenHashMapBasedStateMap[String, Int]() + initStates.foreach { s => initialStateMap.put("key", s, initialTime) } + functionCalled = false + val record = MapWithStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty) + val dataIterator = data.map { v => ("key", v) }.iterator + val removedStates = new ArrayBuffer[Int] + val timingOutStates = new ArrayBuffer[Int] + /** + * Mapping function that updates/removes state based on instructions in the data, and + * return state (when instructed or when state is timing out). + */ + def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = { + functionCalled = true + + assert(t.milliseconds === updatedTime, "mapping func called with wrong time") + + data match { + case Some("noop") => + None + case Some("get-state") => + Some(state.getOption().getOrElse(-1)) + case Some("update-state") => + if (state.exists) state.update(state.get + 1) else state.update(0) + None + case Some("remove-state") => + removedStates += state.get() + state.remove() + None + case None => + assert(state.isTimingOut() === true, "State is not timing out when data = None") + timingOutStates += state.get() + None + case _ => + fail("Unexpected test data") + } + } + + val updatedRecord = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int]( + Some(record), dataIterator, testFunc, + Time(updatedTime), timeoutThreshold, removeTimedoutData) + + val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) } + assert(updatedStateData.toSet === expectedStates.toSet, + "states do not match after updating the MapWithStateRDDRecord") + + assert(updatedRecord.mappedData.toSet === expectedOutput.toSet, + "mapped data do not match after updating the MapWithStateRDDRecord") + + assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " + + "match those that were expected to do so while updating the MapWithStateRDDRecord") + + assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " + + "match those that were expected to do so while updating the MapWithStateRDDRecord") + + } + + // No data, no state should be changed, function should not be called, + assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil) + assert(functionCalled === false) + assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime))) + assert(functionCalled === false) + + // Data present, function should be called irrespective of whether state exists + assertRecordUpdate(initStates = Seq(0), data = Seq("noop"), + expectedStates = Seq((0, initialTime))) + assert(functionCalled === true) + assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None) + assert(functionCalled === true) + + // Function called with right state data + assertRecordUpdate(initStates = None, data = Seq("get-state"), + expectedStates = None, expectedOutput = Seq(-1)) + assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"), + expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123)) + + // Update state and timestamp, when timeout not present + assertRecordUpdate(initStates = Nil, data = Seq("update-state"), + expectedStates = Seq((0, updatedTime))) + assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"), + expectedStates = Seq((1, updatedTime))) + + // Remove state + assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"), + expectedStates = Nil, expectedRemovedStates = Seq(345)) + + // State strictly older than timeout threshold should be timed out + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime), removeTimedoutData = true, + expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil) + + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Seq(123)) + + // State should not be timed out after it has received data + assertRecordUpdate(initStates = Seq(123), data = Seq("noop"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil) + assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123)) + + // If a state is not set but timeoutThreshold is defined, we should ignore this state. + // Previously it threw NoSuchElementException (SPARK-13195). + assertRecordUpdate(initStates = Seq(), data = Seq("noop"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Nil) + } + + test("states generated by MapWithStateRDD") { + val initStates = Seq(("k1", 0), ("k2", 0)) + val initTime = 123 + val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet + val partitioner = new HashPartitioner(2) + val initStateRDD = MapWithStateRDD.createFromPairRDD[String, Int, Int, Int]( + sc.parallelize(initStates), partitioner, Time(initTime)).persist() + assertRDD(initStateRDD, initStateWthTime, Set.empty) + + val updateTime = 345 + + /** + * Test that the test state RDD, when operated with new data, + * creates a new state RDD with expected states + */ + def testStateUpdates( + testStateRDD: MapWithStateRDD[String, Int, Int, Int], + testData: Seq[(String, Int)], + expectedStates: Set[(String, Int, Int)]): MapWithStateRDD[String, Int, Int, Int] = { + + // Persist the test MapWithStateRDD so that its not recomputed while doing the next operation. + // This is to make sure that we only touch which state keys are being touched in the next op. + testStateRDD.persist().count() + + // To track which keys are being touched + MapWithStateRDDSuite.touchedStateKeys.clear() + + val mappingFunction = (time: Time, key: String, data: Option[Int], state: State[Int]) => { + + // Track the key that has been touched + MapWithStateRDDSuite.touchedStateKeys += key + + // If the data is 0, do not do anything with the state + // else if the data is 1, increment the state if it exists, or set new state to 0 + // else if the data is 2, remove the state if it exists + data match { + case Some(1) => + if (state.exists()) { state.update(state.get + 1) } + else state.update(0) + case Some(2) => + state.remove() + case _ => + } + None.asInstanceOf[Option[Int]] // Do not return anything, not being tested + } + val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get) + + // Assert that the new state RDD has expected state data + val newStateRDD = assertOperation( + testStateRDD, newDataRDD, mappingFunction, updateTime, expectedStates, Set.empty) + + // Assert that the function was called only for the keys present in the data + assert(MapWithStateRDDSuite.touchedStateKeys.size === testData.size, + "More number of keys are being touched than that is expected") + assert(MapWithStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, + "Keys not in the data are being touched unexpectedly") + + // Assert that the test RDD's data has not changed + assertRDD(initStateRDD, initStateWthTime, Set.empty) + newStateRDD + } + + // Test no-op, no state should change + testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state + testStateUpdates( + initStateRDD, Seq(("k1", 0)), initStateWthTime) // should not update existing state + testStateUpdates( + initStateRDD, Seq(("k3", 0)), initStateWthTime) // should not create new state + + // Test creation of new state + val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime))) + + val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)), // should create k4's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime))) + + // Test updating of state + val rdd3 = testStateUpdates( + initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1 + Set(("k1", 1, updateTime), ("k2", 0, initTime))) + + val rdd4 = testStateUpdates(rdd3, + Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime))) + + val rdd5 = testStateUpdates( + rdd4, Seq(("k3", 1)), // should update k3's state 0 -> 2 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime))) + + // Test removing of state + val rdd6 = testStateUpdates( // should remove k1's state + initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime))) + + val rdd7 = testStateUpdates( // should remove k2's state + rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime))) + + val rdd8 = testStateUpdates( // should remove k3's state + rdd7, Seq(("k3", 2)), Set()) + } + + test("checkpointing") { + /** + * This tests whether the MapWithStateRDD correctly truncates any references to its parent RDDs + * - the data RDD and the parent MapWithStateRDD. + */ + def rddCollectFunc(rdd: RDD[MapWithStateRDDRecord[Int, Int, Int]]) + : Set[(List[(Int, Int, Long)], List[Int])] = { + rdd.map { record => (record.stateMap.getAll().toList, record.mappedData.toList) } + .collect.toSet + } + + /** Generate MapWithStateRDD with data RDD having a long lineage */ + def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int]) + : MapWithStateRDD[Int, Int, Int, Int] = { + MapWithStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) + } + + testRDD( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + + /** Generate MapWithStateRDD with parent state RDD having a long lineage */ + def makeStateRDDWithLongLineageParenttateRDD( + longLineageRDD: RDD[Int]): MapWithStateRDD[Int, Int, Int, Int] = { + + // Create a MapWithStateRDD that has a long lineage using the data RDD with a long lineage + val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) + + // Create a new MapWithStateRDD, with the lineage lineage MapWithStateRDD as the parent + new MapWithStateRDD[Int, Int, Int, Int]( + stateRDDWithLongLineage, + stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), + (time: Time, key: Int, value: Option[Int], state: State[Int]) => None, + Time(10), + None + ) + } + + testRDD( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + } + + test("checkpointing empty state RDD") { + val emptyStateRDD = MapWithStateRDD.createFromPairRDD[Int, Int, Int, Int]( + sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) + emptyStateRDD.checkpoint() + assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + val cpRDD = sc.checkpointFile[MapWithStateRDDRecord[Int, Int, Int]]( + emptyStateRDD.getCheckpointFile.get) + assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + } + + /** Assert whether the `mapWithState` operation generates expected results */ + private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + testStateRDD: MapWithStateRDD[K, V, S, T], + newDataRDD: RDD[(K, V)], + mappingFunction: (Time, K, Option[V], State[S]) => Option[T], + currentTime: Long, + expectedStates: Set[(K, S, Int)], + expectedMappedData: Set[T], + doFullScan: Boolean = false + ): MapWithStateRDD[K, V, S, T] = { + + val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) { + newDataRDD.partitionBy(testStateRDD.partitioner.get) + } else { + newDataRDD + } + + val newStateRDD = new MapWithStateRDD[K, V, S, T]( + testStateRDD, newDataRDD, mappingFunction, Time(currentTime), None) + if (doFullScan) newStateRDD.setFullScan() + + // Persist to make sure that it gets computed only once and we can track precisely how many + // state keys the computing touched + newStateRDD.persist().count() + assertRDD(newStateRDD, expectedStates, expectedMappedData) + newStateRDD + } + + /** Assert whether the [[MapWithStateRDD]] has the expected state and mapped data */ + private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + stateRDD: MapWithStateRDD[K, V, S, T], + expectedStates: Set[(K, S, Int)], + expectedMappedData: Set[T]): Unit = { + val states = stateRDD.flatMap { _.stateMap.getAll() }.collect().toSet + val mappedData = stateRDD.flatMap { _.mappedData }.collect().toSet + assert(states === expectedStates, + "states after mapWithState operation were not as expected") + assert(mappedData === expectedMappedData, + "mapped data after mapWithState operation were not as expected") + } +} + +object MapWithStateRDDSuite { + private val touchedStateKeys = new ArrayBuffer[String]() +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index cb017b798b2a4..ce5a6e00fb2fe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -23,10 +23,11 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { @@ -39,25 +40,39 @@ class WriteAheadLogBackedBlockRDDSuite var sparkContext: SparkContext = null var blockManager: BlockManager = null + var serializerManager: SerializerManager = null var dir: File = null override def beforeEach(): Unit = { + super.beforeEach() dir = Utils.createTempDir() } override def afterEach(): Unit = { - Utils.deleteRecursively(dir) + try { + Utils.deleteRecursively(dir) + } finally { + super.afterEach() + } } override def beforeAll(): Unit = { + super.beforeAll() sparkContext = new SparkContext(conf) blockManager = sparkContext.env.blockManager + serializerManager = sparkContext.env.serializerManager } override def afterAll(): Unit = { // Copied from LocalSparkContext, simpler than to introduced test dependencies to core tests. - sparkContext.stop() - System.clearProperty("spark.driver.port") + try { + sparkContext.stop() + System.clearProperty("spark.driver.port") + blockManager = null + serializerManager = null + } finally { + super.afterAll() + } } test("Read data available in both block manager and write ahead log") { @@ -97,8 +112,6 @@ class WriteAheadLogBackedBlockRDDSuite * It can also test if the partitions that were read from the log were again stored in * block manager. * - * - * * @param numPartitions Number of partitions in RDD * @param numPartitionsInBM Number of partitions to write to the BlockManager. * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager @@ -213,7 +226,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(blockManager.dataSerialize(id, data.iterator)) + writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer) } writer.close() segments diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index 2f11b255f1104..a1d0561bf308a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -17,17 +17,21 @@ package org.apache.spark.streaming.receiver +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.language.reflectiveCalls import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ -import org.scalatest.concurrent.Timeouts._ import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.util.ManualClock -import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { @@ -83,7 +87,7 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { assert(listener.onPushBlockCalled === true) } } - listener.pushedData should contain theSameElementsInOrderAs (data1) + listener.pushedData.asScala.toSeq should contain theSameElementsInOrderAs (data1) assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly @@ -91,21 +95,24 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { val metadata2 = data2.map { _.toString } data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) } assert(listener.onAddDataCalled === true) - listener.addedData should contain theSameElementsInOrderAs (data2) - listener.addedMetadata should contain theSameElementsInOrderAs (metadata2) + listener.addedData.asScala.toSeq should contain theSameElementsInOrderAs (data2) + listener.addedMetadata.asScala.toSeq should contain theSameElementsInOrderAs (metadata2) clock.advance(blockIntervalMs) // advance clock to generate blocks eventually(timeout(1 second)) { - listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2) + val combined = data1 ++ data2 + listener.pushedData.asScala.toSeq should contain theSameElementsInOrderAs combined } // Verify addMultipleDataWithCallback() add data+metadata and and callbacks are called correctly val data3 = 21 to 30 val metadata3 = "metadata" blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3) - listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3) + val combinedMetadata = metadata2 :+ metadata3 + listener.addedMetadata.asScala.toSeq should contain theSameElementsInOrderAs (combinedMetadata) clock.advance(blockIntervalMs) // advance clock to generate blocks eventually(timeout(1 second)) { - listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3) + val combinedData = data1 ++ data2 ++ data3 + listener.pushedData.asScala.toSeq should contain theSameElementsInOrderAs (combinedData) } // Stop the block generator by starting the stop on a different thread and @@ -190,7 +197,7 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { assert(thread.isAlive === false) } assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped - assert(listener.pushedData === data, "All data not pushed by stop()") + assert(listener.pushedData.asScala.toSeq === data, "All data not pushed by stop()") } test("block push errors are reported") { @@ -230,15 +237,15 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { /** A listener for BlockGenerator that records the data in the callbacks */ private class TestBlockGeneratorListener extends BlockGeneratorListener { - val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] - val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] - val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val pushedData = new ConcurrentLinkedQueue[Any] + val addedData = new ConcurrentLinkedQueue[Any] + val addedMetadata = new ConcurrentLinkedQueue[Any] @volatile var onGenerateBlockCalled = false @volatile var onAddDataCalled = false @volatile var onPushBlockCalled = false override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { - pushedData ++= arrayBuffer + pushedData.addAll(arrayBuffer.asJava) onPushBlockCalled = true } override def onError(message: String, throwable: Throwable): Unit = {} @@ -246,8 +253,8 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { onGenerateBlockCalled = true } override def onAddData(data: Any, metadata: Any): Unit = { - addedData += data - addedMetadata += metadata + addedData.add(data) + addedMetadata.add(metadata) onAddDataCalled = true } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala index c6330eb3673fb..ee3817c4b605d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala @@ -25,21 +25,21 @@ class RateLimiterSuite extends SparkFunSuite { test("rate limiter initializes even without a maxRate set") { val conf = new SparkConf() - val rateLimiter = new RateLimiter(conf){} + val rateLimiter = new RateLimiter(conf) {} rateLimiter.updateRate(105) assert(rateLimiter.getCurrentLimit == 105) } test("rate limiter updates when below maxRate") { val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110") - val rateLimiter = new RateLimiter(conf){} + val rateLimiter = new RateLimiter(conf) {} rateLimiter.updateRate(105) assert(rateLimiter.getCurrentLimit == 105) } test("rate limiter stays below maxRate despite large updates") { val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100") - val rateLimiter = new RateLimiter(conf){} + val rateLimiter = new RateLimiter(conf) {} rateLimiter.updateRate(105) assert(rateLimiter.getCurrentLimit === 100) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala new file mode 100644 index 0000000000000..7630f4a75e336 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.mockito.Matchers.{eq => meq} +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, PrivateMethodTester} +import org.scalatest.concurrent.Eventually.{eventually, timeout} +import org.scalatest.mock.MockitoSugar +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkFunSuite} +import org.apache.spark.streaming.{DummyInputDStream, Seconds, StreamingContext} +import org.apache.spark.util.{ManualClock, Utils} + + +class ExecutorAllocationManagerSuite extends SparkFunSuite + with BeforeAndAfter with BeforeAndAfterAll with MockitoSugar with PrivateMethodTester { + + import ExecutorAllocationManager._ + + private val batchDurationMillis = 1000L + private var allocationClient: ExecutorAllocationClient = null + private var clock: ManualClock = null + + before { + allocationClient = mock[ExecutorAllocationClient] + clock = new ManualClock() + } + + test("basic functionality") { + // Test that adding batch processing time info to allocation manager + // causes executors to be requested and killed accordingly + + // There is 1 receiver, and exec 1 has been allocated to it + withAllocationManager(numReceivers = 1) { case (receiverTracker, allocationManager) => + when(receiverTracker.allocatedExecutors).thenReturn(Map(1 -> Some("1"))) + + /** Add data point for batch processing time and verify executor allocation */ + def addBatchProcTimeAndVerifyAllocation(batchProcTimeMs: Double)(body: => Unit): Unit = { + // 2 active executors + reset(allocationClient) + when(allocationClient.getExecutorIds()).thenReturn(Seq("1", "2")) + addBatchProcTime(allocationManager, batchProcTimeMs.toLong) + clock.advance(SCALING_INTERVAL_DEFAULT_SECS * 1000 + 1) + eventually(timeout(10 seconds)) { + body + } + } + + /** Verify that the expected number of total executor were requested */ + def verifyTotalRequestedExecs(expectedRequestedTotalExecs: Option[Int]): Unit = { + if (expectedRequestedTotalExecs.nonEmpty) { + require(expectedRequestedTotalExecs.get > 0) + verify(allocationClient, times(1)).requestTotalExecutors( + meq(expectedRequestedTotalExecs.get), meq(0), meq(Map.empty)) + } else { + verify(allocationClient, never).requestTotalExecutors(0, 0, Map.empty) + } + } + + /** Verify that a particular executor was killed */ + def verifyKilledExec(expectedKilledExec: Option[String]): Unit = { + if (expectedKilledExec.nonEmpty) { + verify(allocationClient, times(1)).killExecutor(meq(expectedKilledExec.get)) + } else { + verify(allocationClient, never).killExecutor(null) + } + } + + // Batch proc time = batch interval, should increase allocation by 1 + addBatchProcTimeAndVerifyAllocation(batchDurationMillis) { + verifyTotalRequestedExecs(Some(3)) // one already allocated, increase allocation by 1 + verifyKilledExec(None) + } + + // Batch proc time = batch interval * 2, should increase allocation by 2 + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * 2) { + verifyTotalRequestedExecs(Some(4)) + verifyKilledExec(None) + } + + // Batch proc time slightly more than the scale up ratio, should increase allocation by 1 + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_UP_RATIO_DEFAULT + 1) { + verifyTotalRequestedExecs(Some(3)) + verifyKilledExec(None) + } + + // Batch proc time slightly less than the scale up ratio, should not change allocation + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_UP_RATIO_DEFAULT - 1) { + verifyTotalRequestedExecs(None) + verifyKilledExec(None) + } + + // Batch proc time slightly more than the scale down ratio, should not change allocation + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_DOWN_RATIO_DEFAULT + 1) { + verifyTotalRequestedExecs(None) + verifyKilledExec(None) + } + + // Batch proc time slightly more than the scale down ratio, should not change allocation + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_DOWN_RATIO_DEFAULT - 1) { + verifyTotalRequestedExecs(None) + verifyKilledExec(Some("2")) + } + } + } + + test("requestExecutors policy") { + + /** Verify that the expected number of total executor were requested */ + def verifyRequestedExecs( + numExecs: Int, + numNewExecs: Int, + expectedRequestedTotalExecs: Int)( + implicit allocationManager: ExecutorAllocationManager): Unit = { + reset(allocationClient) + when(allocationClient.getExecutorIds()).thenReturn((1 to numExecs).map(_.toString)) + requestExecutors(allocationManager, numNewExecs) + verify(allocationClient, times(1)).requestTotalExecutors( + meq(expectedRequestedTotalExecs), meq(0), meq(Map.empty)) + } + + withAllocationManager(numReceivers = 1) { case (_, allocationManager) => + implicit val am = allocationManager + intercept[IllegalArgumentException] { + verifyRequestedExecs(numExecs = 0, numNewExecs = 0, 0) + } + verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 1) + verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4) + } + + withAllocationManager(numReceivers = 2) { case(_, allocationManager) => + implicit val am = allocationManager + + verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4) + } + + withAllocationManager( + // Test min 2 executors + new SparkConf().set("spark.streaming.dynamicAllocation.minExecutors", "2")) { + case (_, allocationManager) => + implicit val am = allocationManager + + verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 0, numNewExecs = 3, expectedRequestedTotalExecs = 3) + verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 1, numNewExecs = 2, expectedRequestedTotalExecs = 3) + verifyRequestedExecs(numExecs = 2, numNewExecs = 1, expectedRequestedTotalExecs = 3) + verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4) + } + + withAllocationManager( + // Test with max 2 executors + new SparkConf().set("spark.streaming.dynamicAllocation.maxExecutors", "2")) { + case (_, allocationManager) => + implicit val am = allocationManager + + verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 1) + verifyRequestedExecs(numExecs = 0, numNewExecs = 3, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 1, numNewExecs = 2, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 2, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 2) + } + } + + test("killExecutor policy") { + + /** + * Verify that a particular executor was killed, given active executors and executors + * allocated to receivers. + */ + def verifyKilledExec( + execIds: Seq[String], + receiverExecIds: Map[Int, Option[String]], + expectedKilledExec: Option[String])( + implicit x: (ReceiverTracker, ExecutorAllocationManager)): Unit = { + val (receiverTracker, allocationManager) = x + + reset(allocationClient) + when(allocationClient.getExecutorIds()).thenReturn(execIds) + when(receiverTracker.allocatedExecutors).thenReturn(receiverExecIds) + killExecutor(allocationManager) + if (expectedKilledExec.nonEmpty) { + verify(allocationClient, times(1)).killExecutor(meq(expectedKilledExec.get)) + } else { + verify(allocationClient, never).killExecutor(null) + } + } + + withAllocationManager() { case (receiverTracker, allocationManager) => + implicit val rcvrTrackerAndExecAllocMgr = (receiverTracker, allocationManager) + + verifyKilledExec(Nil, Map.empty, None) + verifyKilledExec(Seq("1", "2"), Map.empty, None) + verifyKilledExec(Seq("1"), Map(1 -> Some("1")), None) + verifyKilledExec(Seq("1", "2"), Map(1 -> Some("1")), Some("2")) + verifyKilledExec(Seq("1", "2"), Map(1 -> Some("1"), 2 -> Some("2")), None) + } + + withAllocationManager( + new SparkConf().set("spark.streaming.dynamicAllocation.minExecutors", "2")) { + case (receiverTracker, allocationManager) => + implicit val rcvrTrackerAndExecAllocMgr = (receiverTracker, allocationManager) + + verifyKilledExec(Seq("1", "2"), Map.empty, None) + verifyKilledExec(Seq("1", "2", "3"), Map(1 -> Some("1"), 2 -> Some("2")), Some("3")) + } + } + + test("parameter validation") { + + def validateParams( + numReceivers: Int = 1, + scalingIntervalSecs: Option[Int] = None, + scalingUpRatio: Option[Double] = None, + scalingDownRatio: Option[Double] = None, + minExecs: Option[Int] = None, + maxExecs: Option[Int] = None): Unit = { + require(numReceivers > 0) + val receiverTracker = mock[ReceiverTracker] + when(receiverTracker.numReceivers()).thenReturn(numReceivers) + val conf = new SparkConf() + if (scalingIntervalSecs.nonEmpty) { + conf.set( + "spark.streaming.dynamicAllocation.scalingInterval", + s"${scalingIntervalSecs.get}s") + } + if (scalingUpRatio.nonEmpty) { + conf.set("spark.streaming.dynamicAllocation.scalingUpRatio", scalingUpRatio.get.toString) + } + if (scalingDownRatio.nonEmpty) { + conf.set( + "spark.streaming.dynamicAllocation.scalingDownRatio", + scalingDownRatio.get.toString) + } + if (minExecs.nonEmpty) { + conf.set("spark.streaming.dynamicAllocation.minExecutors", minExecs.get.toString) + } + if (maxExecs.nonEmpty) { + conf.set("spark.streaming.dynamicAllocation.maxExecutors", maxExecs.get.toString) + } + new ExecutorAllocationManager( + allocationClient, receiverTracker, conf, batchDurationMillis, clock) + } + + validateParams(numReceivers = 1) + validateParams(numReceivers = 2, minExecs = Some(1)) + validateParams(numReceivers = 2, minExecs = Some(3)) + validateParams(numReceivers = 2, maxExecs = Some(3)) + validateParams(numReceivers = 2, maxExecs = Some(1)) + validateParams(minExecs = Some(3), maxExecs = Some(3)) + validateParams(scalingIntervalSecs = Some(1)) + validateParams(scalingUpRatio = Some(1.1)) + validateParams(scalingDownRatio = Some(0.1)) + validateParams(scalingUpRatio = Some(1.1), scalingDownRatio = Some(0.1)) + + intercept[IllegalArgumentException] { + validateParams(minExecs = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(minExecs = Some(-1)) + } + intercept[IllegalArgumentException] { + validateParams(maxExecs = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(maxExecs = Some(-1)) + } + intercept[IllegalArgumentException] { + validateParams(minExecs = Some(4), maxExecs = Some(3)) + } + intercept[IllegalArgumentException] { + validateParams(scalingIntervalSecs = Some(-1)) + } + intercept[IllegalArgumentException] { + validateParams(scalingIntervalSecs = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(scalingUpRatio = Some(-0.1)) + } + intercept[IllegalArgumentException] { + validateParams(scalingUpRatio = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(scalingDownRatio = Some(-0.1)) + } + intercept[IllegalArgumentException] { + validateParams(scalingDownRatio = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(scalingUpRatio = Some(0.5), scalingDownRatio = Some(0.5)) + } + intercept[IllegalArgumentException] { + validateParams(scalingUpRatio = Some(0.3), scalingDownRatio = Some(0.5)) + } + } + + test("enabling and disabling") { + withStreamingContext(new SparkConf()) { ssc => + ssc.start() + assert(getExecutorAllocationManager(ssc).isEmpty) + } + + withStreamingContext( + new SparkConf().set("spark.streaming.dynamicAllocation.enabled", "true")) { ssc => + ssc.start() + assert(getExecutorAllocationManager(ssc).nonEmpty) + } + + val confWithBothDynamicAllocationEnabled = new SparkConf() + .set("spark.streaming.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + require(Utils.isDynamicAllocationEnabled(confWithBothDynamicAllocationEnabled) === true) + withStreamingContext(confWithBothDynamicAllocationEnabled) { ssc => + intercept[IllegalArgumentException] { + ssc.start() + } + } + } + + private def withAllocationManager( + conf: SparkConf = new SparkConf, + numReceivers: Int = 1 + )(body: (ReceiverTracker, ExecutorAllocationManager) => Unit): Unit = { + + val receiverTracker = mock[ReceiverTracker] + when(receiverTracker.numReceivers()).thenReturn(numReceivers) + + val manager = new ExecutorAllocationManager( + allocationClient, receiverTracker, conf, batchDurationMillis, clock) + try { + manager.start() + body(receiverTracker, manager) + } finally { + manager.stop() + } + } + + private val _addBatchProcTime = PrivateMethod[Unit]('addBatchProcTime) + private val _requestExecutors = PrivateMethod[Unit]('requestExecutors) + private val _killExecutor = PrivateMethod[Unit]('killExecutor) + private val _executorAllocationManager = + PrivateMethod[Option[ExecutorAllocationManager]]('executorAllocationManager) + + private def addBatchProcTime(manager: ExecutorAllocationManager, timeMs: Long): Unit = { + manager invokePrivate _addBatchProcTime(timeMs) + } + + private def requestExecutors(manager: ExecutorAllocationManager, newExecs: Int): Unit = { + manager invokePrivate _requestExecutors(newExecs) + } + + private def killExecutor(manager: ExecutorAllocationManager): Unit = { + manager invokePrivate _killExecutor() + } + + private def getExecutorAllocationManager( + ssc: StreamingContext): Option[ExecutorAllocationManager] = { + ssc.scheduler invokePrivate _executorAllocationManager() + } + + private def withStreamingContext(conf: SparkConf)(body: StreamingContext => Unit): Unit = { + conf.setMaster("local").setAppName(this.getClass.getSimpleName).set( + "spark.streaming.dynamicAllocation.testing", "true") // to test dynamic allocation + + var ssc: StreamingContext = null + try { + ssc = new StreamingContext(conf, Seconds(1)) + new DummyInputDStream(ssc).foreachRDD(_ => { }) + body(ssc) + } finally { + if (ssc != null) ssc.stop() + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index f5248acf712b9..a7e365649d3e8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index 9b6cd4bc4e315..a2dbae149f311 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -56,8 +56,7 @@ class JobGeneratorSuite extends TestSuiteBase { // 4. allow subsequent batches to be generated (to allow premature deletion of 3rd batch metadata) // 5. verify whether 3rd batch's block metadata still exists // - // TODO: SPARK-7420 enable this test - ignore("SPARK-6222: Do not clear received block data too soon") { + test("SPARK-6222: Do not clear received block data too soon") { import JobGeneratorSuite._ val checkpointDir = Utils.createTempDir() val testConf = conf diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala index 1eb52b7029a21..37ca0ce2f6a30 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable - import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index 3bd8d086abf7f..df122ac090c3e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart, TaskLo import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.dstream.{ConstantInputDStream, ReceiverInputDStream} import org.apache.spark.streaming.receiver._ /** Testsuite for receiver scheduling */ @@ -34,8 +34,6 @@ class ReceiverTrackerSuite extends TestSuiteBase { test("send rate update to receivers") { withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => - ssc.scheduler.listenerBus.start(ssc.sc) - val newRateLimit = 100L val inputDStream = new RateTestInputDStream(ssc) val tracker = new ReceiverTracker(ssc) @@ -104,11 +102,32 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } } + + test("get allocated executors") { + // Test get allocated executors when 1 receiver is registered + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + val input = ssc.receiverStream(new TestReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + assert(ssc.scheduler.receiverTracker.allocatedExecutors().size === 1) + } + + // Test get allocated executors when there's no receiver registered + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + val rdd = ssc.sc.parallelize(1 to 10) + val input = new ConstantInputDStream(ssc, rdd) + val output = new TestOutputStream(input) + output.register() + ssc.start() + assert(ssc.scheduler.receiverTracker.allocatedExecutors() === Map.empty) + } + } } /** An input DStream with for testing rate controlling */ -private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext) - extends ReceiverInputDStream[Int](ssc_) { +private[streaming] class RateTestInputDStream(_ssc: StreamingContext) + extends ReceiverInputDStream[Int](_ssc) { override def getReceiver(): Receiver[Int] = new RateTestReceiver(id) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index af4718b4eb705..26b757cc2d535 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -130,20 +130,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost", "0") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost", "1") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost", "2") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) @@ -200,7 +200,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { batchUIData.get.totalDelay should be (batchInfoSubmitted.totalDelay) batchUIData.get.streamIdToInputInfo should be (Map.empty) batchUIData.get.numRecords should be (0) - batchUIData.get.outputOpIdSparkJobIdPairs should be (Seq(OutputOpIdAndSparkJobId(0, 0))) + batchUIData.get.outputOpIdSparkJobIdPairs.toSeq should be (Seq(OutputOpIdAndSparkJobId(0, 0))) // A lot of "onBatchCompleted"s happen before "onJobStart" for(i <- limit + 1 to limit * 2) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 78fc344b00177..6d9c80d99206b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkFunSuite @@ -34,7 +35,7 @@ class RateLimitedOutputStreamSuite extends SparkFunSuite { val underlying = new ByteArrayOutputStream val data = "X" * 41000 val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000) - val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } + val elapsedNs = benchmark { stream.write(data.getBytes(StandardCharsets.UTF_8)) } val seconds = SECONDS.convert(elapsedNs, NANOSECONDS) assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala index 0544972d95c03..25b70a3d089ee 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.streaming.util -import scala.collection.mutable +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import org.scalatest.PrivateMethodTester @@ -30,34 +32,34 @@ class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester { test("basic") { val clock = new ManualClock() - val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val results = new ConcurrentLinkedQueue[Long]() val timer = new RecurringTimer(clock, 100, time => { - results += time + results.add(time) }, "RecurringTimerSuite-basic") timer.start(0) eventually(timeout(10.seconds), interval(10.millis)) { - assert(results === Seq(0L)) + assert(results.asScala.toSeq === Seq(0L)) } clock.advance(100) eventually(timeout(10.seconds), interval(10.millis)) { - assert(results === Seq(0L, 100L)) + assert(results.asScala.toSeq === Seq(0L, 100L)) } clock.advance(200) eventually(timeout(10.seconds), interval(10.millis)) { - assert(results === Seq(0L, 100L, 200L, 300L)) + assert(results.asScala.toSeq === Seq(0L, 100L, 200L, 300L)) } assert(timer.stop(interruptTimer = true) === 300L) } test("SPARK-10224: call 'callback' after stopping") { val clock = new ManualClock() - val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val results = new ConcurrentLinkedQueue[Long] val timer = new RecurringTimer(clock, 100, time => { - results += time + results.add(time) }, "RecurringTimerSuite-SPARK-10224") timer.start(0) eventually(timeout(10.seconds), interval(10.millis)) { - assert(results === Seq(0L)) + assert(results.asScala.toSeq === Seq(0L)) } @volatile var lastTime = -1L // Now RecurringTimer is waiting for the next interval @@ -77,7 +79,7 @@ class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester { // Then it will find `stopped` is true and exit the loop, but it should call `callback` again // before exiting its internal thread. thread.join() - assert(results === Seq(0L, 100L, 200L)) + assert(results.asScala.toSeq === Seq(0L, 100L, 200L)) assert(lastTime === 200L) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 93ae41a3d2ecd..8c980dee2cc06 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,31 +18,45 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import java.util +import java.util.{Iterator => JIterator} +import java.util.concurrent.{CountDownLatch, RejectedExecutionException, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} -import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach, PrivateMethodTester} +import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.BeforeAndAfter +import org.scalatest.mock.MockitoSugar -import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{CompletionIterator, ManualClock, ThreadUtils, Utils} -class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { +/** Common tests for WriteAheadLogs that we would like to test with different configurations. */ +abstract class CommonWriteAheadLogTests( + allowBatching: Boolean, + closeFileAfterWrite: Boolean, + testTag: String = "") + extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - val hadoopConf = new Configuration() - var tempDir: File = null - var testDir: String = null - var testFile: String = null - var writeAheadLog: FileBasedWriteAheadLog = null + protected val hadoopConf = new Configuration() + protected var tempDir: File = null + protected var testDir: String = null + protected var testFile: String = null + protected var writeAheadLog: WriteAheadLog = null + protected def testPrefix = if (testTag != "") testTag + " - " else testTag before { tempDir = Utils.createTempDir() @@ -58,47 +72,211 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { Utils.deleteRecursively(tempDir) } - test("WriteAheadLogUtils - log selection and creation") { - val logDir = Utils.createTempDir().getAbsolutePath() + test(testPrefix + "read all logs") { + // Write data manually for testing reading through WriteAheadLog + val writtenData = (1 to 10).flatMap { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + } + + val logDirectoryPath = new Path(testDir) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + assert(fileSystem.exists(logDirectoryPath) === true) + + // Read data using manager and verify + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === writtenData) + } + + test(testPrefix + "write logs") { + // Write data with rotation using WriteAheadLog class + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite = closeFileAfterWrite, + allowBatching = allowBatching) + + // Read data manually to verify the written data + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val writtenData = readAndDeserializeDataManually(logFiles, allowBatching) + assert(writtenData === dataToWrite) + } + + test(testPrefix + "read all logs after write") { + // Write data with manager, recover with new manager and verify + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, allowBatching) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(dataToWrite === readData) + } + + test(testPrefix + "clean old logs") { + logCleanUpTest(waitForCompletion = false) + } + + test(testPrefix + "clean old logs synchronously") { + logCleanUpTest(waitForCompletion = true) + } + + private def logCleanUpTest(waitForCompletion: Boolean): Unit = { + // Write data with manager, recover with new manager and verify + val manualClock = new ManualClock + val dataToWrite = generateRandomData() + writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, + allowBatching, manualClock, closeLog = false) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + + writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) + + if (waitForCompletion) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } else { + eventually(Eventually.timeout(1 second), interval(10 milliseconds)) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } + } + } + + test(testPrefix + "handling file errors while reading rotating logs") { + // Generate a set of log files + val manualClock = new ManualClock + val dataToWrite1 = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite1, closeFileAfterWrite, allowBatching, + manualClock) + val logFiles1 = getLogFilesInDirectory(testDir) + assert(logFiles1.size > 1) - def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + + // Recover old files and generate a second set of log files + val dataToWrite2 = generateRandomData() + manualClock.advance(100000) + writeDataUsingWriteAheadLog(testDir, dataToWrite2, closeFileAfterWrite, allowBatching, + manualClock) + val logFiles2 = getLogFilesInDirectory(testDir) + assert(logFiles2.size > logFiles1.size) + + // Read the files and verify that all the written data can be read + val readData1 = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + + // Corrupt the first set of files so that they are basically unreadable + logFiles1.foreach { f => + val raf = new FileOutputStream(f, true).getChannel() + raf.truncate(1) + raf.close() } - def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + // Verify that the corrupted files do not prevent reading of the second set of data + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === dataToWrite2) + } + + test(testPrefix + "do not create directories or files unless write") { + val nonexistentTempPath = File.createTempFile("test", "") + nonexistentTempPath.delete() + assert(!nonexistentTempPath.exists()) + + val writtenSegment = writeDataManually(generateRandomData(), testFile, allowBatching) + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") + if (allowBatching) { + intercept[UnsupportedOperationException](wal.read(writtenSegment.head)) + } else { + wal.read(writtenSegment.head) } + assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + } - val emptyConf = new SparkConf() // no log configuration - assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) - assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + test(testPrefix + "parallel recovery not enabled if closeFileAfterWrite = false") { + // write some data + val writtenData = (1 to 10).flatMap { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + } - // Verify setting driver WAL class - val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[MockWriteAheadLog0](conf1) - assertReceiverLogClass[FileBasedWriteAheadLog](conf1) + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + // create iterator but don't materialize it + val readData = wal.readAll().asScala.map(byteBufferToString) + wal.close() + if (closeFileAfterWrite) { + // the threadpool is shutdown by the wal.close call above, therefore we shouldn't be able + // to materialize the iterator with parallel recovery + intercept[RejectedExecutionException](readData.toArray) + } else { + assert(readData.toSeq === writtenData) + } + } +} - // Verify setting receiver WAL class - val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) - assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) +class FileBasedWriteAheadLogSuite + extends CommonWriteAheadLogTests(false, false, "FileBasedWriteAheadLog") { - // Verify setting receiver WAL class with 1-arg constructor - val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog1].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + import WriteAheadLogSuite._ - // Verify failure setting receiver WAL class with 2-arg constructor - intercept[SparkException] { - val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog2].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + test("FileBasedWriteAheadLog - seqToParIterator") { + /* + If the setting `closeFileAfterWrite` is enabled, we start generating a very large number of + files. This causes recovery to take a very long time. In order to make it quicker, we + parallelized the reading of these files. This test makes sure that we limit the number of + open files to the size of the number of threads in our thread pool rather than the size of + the list of files. + */ + val numThreads = 8 + val fpool = ThreadUtils.newForkJoinPool("wal-test-thread-pool", numThreads) + val executionContext = ExecutionContext.fromExecutorService(fpool) + + class GetMaxCounter { + private val value = new AtomicInteger() + @volatile private var max: Int = 0 + def increment(): Unit = synchronized { + val atInstant = value.incrementAndGet() + if (atInstant > max) max = atInstant + } + def decrement(): Unit = synchronized { value.decrementAndGet() } + def get(): Int = synchronized { value.get() } + def getMax(): Int = synchronized { max } + } + try { + // If Jenkins is slow, we may not have a chance to run many threads simultaneously. Having + // a latch will make sure that all the threads can be launched altogether. + val latch = new CountDownLatch(1) + val testSeq = 1 to 1000 + val counter = new GetMaxCounter() + def handle(value: Int): Iterator[Int] = { + new CompletionIterator[Int, Iterator[Int]](Iterator(value)) { + counter.increment() + // block so that other threads also launch + latch.await(10, TimeUnit.SECONDS) + override def completion() { counter.decrement() } + } + } + @volatile var collected: Seq[Int] = Nil + val t = new Thread() { + override def run() { + // run the calculation on a separate thread so that we can release the latch + val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](executionContext, + testSeq, handle) + collected = iterator.toSeq + } + } + t.start() + eventually(Eventually.timeout(10.seconds)) { + // make sure we are doing a parallel computation! + assert(counter.getMax() > 1) + } + latch.countDown() + t.join(10000) + assert(collected === testSeq) + // make sure we didn't open too many Iterators + assert(counter.getMax() <= numThreads) + } finally { + fpool.shutdownNow() } } @@ -122,7 +300,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { test("FileBasedWriteAheadLogReader - sequentially reading data") { val writtenData = generateRandomData() - writeDataManually(writtenData, testFile) + writeDataManually(writtenData, testFile, allowBatching = false) val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) val readData = reader.toSeq.map(byteBufferToString) assert(readData === writtenData) @@ -163,10 +341,30 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { assert(readDataUsingReader(testFile) === (dataToWrite.dropRight(1))) } + test("FileBasedWriteAheadLogReader - handles errors when file doesn't exist") { + // Write data manually for testing the sequential reader + val dataToWrite = generateRandomData() + writeDataUsingWriter(testFile, dataToWrite) + val tFile = new File(testFile) + assert(tFile.exists()) + // Verify the data can be read and is same as the one correctly written + assert(readDataUsingReader(testFile) === dataToWrite) + + tFile.delete() + assert(!tFile.exists()) + + val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) + assert(!reader.hasNext) + reader.close() + + // Verify that no exception is thrown if file doesn't exist + assert(readDataUsingReader(testFile) === Nil) + } + test("FileBasedWriteAheadLogRandomReader - reading data using random reader") { // Write data manually for testing the random reader val writtenData = generateRandomData() - val segments = writeDataManually(writtenData, testFile) + val segments = writeDataManually(writtenData, testFile, allowBatching = false) // Get a random order of these segments and read them back val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten @@ -190,163 +388,224 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { } reader.close() } +} - test("FileBasedWriteAheadLog - write rotating logs") { - // Write data with rotation using WriteAheadLog class - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) - - // Read data manually to verify the written data - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val writtenData = logFiles.flatMap { file => readDataManually(file)} - assert(writtenData === dataToWrite) - } +abstract class CloseFileAfterWriteTests(allowBatching: Boolean, testTag: String) + extends CommonWriteAheadLogTests(allowBatching, closeFileAfterWrite = true, testTag) { - test("FileBasedWriteAheadLog - close after write flag") { + import WriteAheadLogSuite._ + test(testPrefix + "close after write flag") { // Write data with rotation using WriteAheadLog class val numFiles = 3 val dataToWrite = Seq.tabulate(numFiles)(_.toString) // total advance time is less than 1000, therefore log shouldn't be rolled, but manually closed writeDataUsingWriteAheadLog(testDir, dataToWrite, closeLog = false, clockAdvanceTime = 100, - closeFileAfterWrite = true) + closeFileAfterWrite = true, allowBatching = allowBatching) // Read data manually to verify the written data val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size === numFiles) - val writtenData = logFiles.flatMap { file => readDataManually(file)} + val writtenData: Seq[String] = readAndDeserializeDataManually(logFiles, allowBatching) assert(writtenData === dataToWrite) } +} - test("FileBasedWriteAheadLog - read rotating logs") { - // Write data manually for testing reading through WriteAheadLog - val writtenData = (1 to 10).map { i => - val data = generateRandomData() - val file = testDir + s"/log-$i-$i" - writeDataManually(data, file) - data - }.flatten +class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = false, "FileBasedWriteAheadLog") - val logDirectoryPath = new Path(testDir) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - assert(fileSystem.exists(logDirectoryPath) === true) +class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( + allowBatching = true, + closeFileAfterWrite = false, + "BatchedWriteAheadLog") + with MockitoSugar + with BeforeAndAfterEach + with Eventually + with PrivateMethodTester { - // Read data using manager and verify - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === writtenData) - } + import BatchedWriteAheadLog._ + import WriteAheadLogSuite._ - test("FileBasedWriteAheadLog - recover past logs when creating new manager") { - // Write data with manager, recover with new manager and verify - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val readData = readDataUsingWriteAheadLog(testDir) - assert(dataToWrite === readData) - } + private var wal: WriteAheadLog = _ + private var walHandle: WriteAheadLogRecordHandle = _ + private var walBatchingThreadPool: ThreadPoolExecutor = _ + private var walBatchingExecutionContext: ExecutionContextExecutorService = _ + private val sparkConf = new SparkConf() - test("FileBasedWriteAheadLog - clean old logs") { - logCleanUpTest(waitForCompletion = false) - } + private val queueLength = PrivateMethod[Int]('getQueueLength) - test("FileBasedWriteAheadLog - clean old logs synchronously") { - logCleanUpTest(waitForCompletion = true) + override def beforeEach(): Unit = { + super.beforeEach() + wal = mock[WriteAheadLog] + walHandle = mock[WriteAheadLogRecordHandle] + walBatchingThreadPool = ThreadUtils.newDaemonFixedThreadPool(8, "wal-test-thread-pool") + walBatchingExecutionContext = ExecutionContext.fromExecutorService(walBatchingThreadPool) } - private def logCleanUpTest(waitForCompletion: Boolean): Unit = { - // Write data with manager, recover with new manager and verify - val manualClock = new ManualClock - val dataToWrite = generateRandomData() - writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - - writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) - - if (waitForCompletion) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } else { - eventually(timeout(1 second), interval(10 milliseconds)) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) + override def afterEach(): Unit = { + try { + if (walBatchingExecutionContext != null) { + walBatchingExecutionContext.shutdownNow() } + } finally { + super.afterEach() } } - test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") { - // Generate a set of log files - val manualClock = new ManualClock - val dataToWrite1 = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock) - val logFiles1 = getLogFilesInDirectory(testDir) - assert(logFiles1.size > 1) + test("BatchedWriteAheadLog - serializing and deserializing batched records") { + val events = Seq( + BlockAdditionEvent(ReceivedBlockInfo(0, None, None, null)), + BatchAllocationEvent(null, null), + BatchCleanupEvent(Nil) + ) + val buffers = events.map(e => Record(ByteBuffer.wrap(Utils.serialize(e)), 0L, null)) + val batched = BatchedWriteAheadLog.aggregate(buffers) + val deaggregate = BatchedWriteAheadLog.deaggregate(batched).map(buffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](buffer.array())) - // Recover old files and generate a second set of log files - val dataToWrite2 = generateRandomData() - manualClock.advance(100000) - writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock) - val logFiles2 = getLogFilesInDirectory(testDir) - assert(logFiles2.size > logFiles1.size) + assert(deaggregate.toSeq === events) + } - // Read the files and verify that all the written data can be read - val readData1 = readDataUsingWriteAheadLog(testDir) - assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + test("BatchedWriteAheadLog - failures in wrappedLog get bubbled up") { + when(wal.write(any[ByteBuffer], anyLong)).thenThrow(new RuntimeException("Hello!")) + // the BatchedWriteAheadLog should bubble up any exceptions that may have happened during writes + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - // Corrupt the first set of files so that they are basically unreadable - logFiles1.foreach { f => - val raf = new FileOutputStream(f, true).getChannel() - raf.truncate(1) - raf.close() + intercept[RuntimeException] { + val buffer = mock[ByteBuffer] + batchedWal.write(buffer, 2L) } - - // Verify that the corrupted files do not prevent reading of the second set of data - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === dataToWrite2) } - test("FileBasedWriteAheadLog - do not create directories or files unless write") { - val nonexistentTempPath = File.createTempFile("test", "") - nonexistentTempPath.delete() - assert(!nonexistentTempPath.exists()) + // we make the write requests in separate threads so that we don't block the test thread + private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { + val p = Promise[Unit]() + p.completeWith(Future { + val v = wal.write(event, time) + assert(v === walHandle) + }(walBatchingExecutionContext)) + p + } - val writtenSegment = writeDataManually(generateRandomData(), testFile) - val wal = new FileBasedWriteAheadLog(new SparkConf(), tempDir.getAbsolutePath, - new Configuration(), 1, 1, closeFileAfterWrite = false) - assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") - wal.read(writtenSegment.head) - assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + test("BatchedWriteAheadLog - name log with the highest timestamp of aggregated entries") { + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) + + val event1 = "hello" + val event2 = "world" + val event3 = "this" + val event4 = "is" + val event5 = "doge" + + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. + writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) + } + // rest of the records will be batched while it takes time for 3 to get written + writeAsync(batchedWal, event2, 5L) + writeAsync(batchedWal, event3, 8L) + // we would like event 5 to be written before event 4 in order to test that they get + // sorted before being aggregated + writeAsync(batchedWal, event5, 12L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 3) + } + writeAsync(batchedWal, event4, 10L) + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 5) + assert(batchedWal.invokePrivate(queueLength()) === 4) + } + blockingWal.allowWrite() + + val buffer = wrapArrayArrayByte(Array(event1)) + val queuedEvents = Set(event2, event3, event4, event5) + + eventually(timeout(1 second)) { + assert(batchedWal.invokePrivate(queueLength()) === 0) + verify(wal, times(1)).write(meq(buffer), meq(3L)) + // the file name should be the timestamp of the last record, as events should be naturally + // in order of timestamp, and we need the last element. + val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(12L)) + val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) + assert(records.toSet === queuedEvents) + } } -} -object WriteAheadLogSuite { + test("BatchedWriteAheadLog - shutdown properly") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + batchedWal.close() + verify(wal, times(1)).close() - class MockWriteAheadLog0() extends WriteAheadLog { - override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } - override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } - override def readAll(): util.Iterator[ByteBuffer] = { null } - override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } - override def close(): Unit = { } + intercept[IllegalStateException](batchedWal.write(mock[ByteBuffer], 12L)) } - class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + test("BatchedWriteAheadLog - fail everything in queue during shutdown") { + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) + + val event1 = "hello" + val event2 = "world" + val event3 = "this" + + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. + val promise1 = writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) + } + // rest of the records will be batched while it takes time for 3 to get written + val promise2 = writeAsync(batchedWal, event2, 5L) + val promise3 = writeAsync(batchedWal, event3, 8L) + + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 3) + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 2) // event1 is being written + } + + val writePromises = Seq(promise1, promise2, promise3) - class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() + batchedWal.close() + eventually(timeout(1 second)) { + assert(writePromises.forall(_.isCompleted)) + assert(writePromises.forall(_.future.value.get.isFailure)) // all should have failed + } + } +} + +class BatchedWriteAheadLogWithCloseFileAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = true, "BatchedWriteAheadLog") +object WriteAheadLogSuite { private val hadoopConf = new Configuration() /** Write data to a file directly and return an array of the file segments written. */ - def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = { + def writeDataManually( + data: Seq[String], + file: String, + allowBatching: Boolean): Seq[FileBasedWriteAheadLogSegment] = { val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]() val writer = HdfsUtils.getOutputStream(file, hadoopConf) - data.foreach { item => + def writeToStream(bytes: Array[Byte]): Unit = { val offset = writer.getPos - val bytes = Utils.serialize(item) writer.writeInt(bytes.size) writer.write(bytes) segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size) } + if (allowBatching) { + writeToStream(wrapArrayArrayByte(data.toArray[String]).array()) + } else { + data.foreach { item => + writeToStream(Utils.serialize(item)) + } + } writer.close() segments } @@ -356,8 +615,7 @@ object WriteAheadLogSuite { */ def writeDataUsingWriter( filePath: String, - data: Seq[String] - ): Seq[FileBasedWriteAheadLogSegment] = { + data: Seq[String]): Seq[FileBasedWriteAheadLogSegment] = { val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf) val segments = data.map { item => writer.write(item) @@ -370,13 +628,13 @@ object WriteAheadLogSuite { def writeDataUsingWriteAheadLog( logDirectory: String, data: Seq[String], + closeFileAfterWrite: Boolean, + allowBatching: Boolean, manualClock: ManualClock = new ManualClock, closeLog: Boolean = true, - clockAdvanceTime: Int = 500, - closeFileAfterWrite: Boolean = false): FileBasedWriteAheadLog = { + clockAdvanceTime: Int = 500): WriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, - closeFileAfterWrite) + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => @@ -406,16 +664,16 @@ object WriteAheadLogSuite { } /** Read all the data from a log file directly and return the list of byte buffers. */ - def readDataManually(file: String): Seq[String] = { + def readDataManually[T](file: String): Seq[T] = { val reader = HdfsUtils.getInputStream(file, hadoopConf) - val buffer = new ArrayBuffer[String] + val buffer = new ArrayBuffer[T] try { while (true) { // Read till EOF is thrown val length = reader.readInt() val bytes = new Array[Byte](length) reader.read(bytes) - buffer += Utils.deserialize[String](bytes) + buffer += Utils.deserialize[T](bytes) } } catch { case ex: EOFException => @@ -434,20 +692,23 @@ object WriteAheadLogSuite { } /** Read all the data in the log file in a directory using the WriteAheadLog class. */ - def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, - closeFileAfterWrite = false) - val data = wal.readAll().asScala.map(byteBufferToString).toSeq + def readDataUsingWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): Seq[String] = { + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) + val data = wal.readAll().asScala.map(byteBufferToString).toArray wal.close() data } - /** Get the log files in a direction */ + /** Get the log files in a directory. */ def getLogFilesInDirectory(directory: String): Seq[String] = { val logDirectoryPath = new Path(directory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + if (fileSystem.exists(logDirectoryPath) && + fileSystem.getFileStatus(logDirectoryPath).isDirectory) { fileSystem.listStatus(logDirectoryPath).map { _.getPath() }.sortBy { _.getName().split("-")(1).toLong }.map { @@ -458,10 +719,31 @@ object WriteAheadLogSuite { } } + def createWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): WriteAheadLog = { + val sparkConf = new SparkConf + val wal = new FileBasedWriteAheadLog(sparkConf, logDirectory, hadoopConf, 1, 1, + closeFileAfterWrite) + if (allowBatching) new BatchedWriteAheadLog(wal, sparkConf) else wal + } + def generateRandomData(): Seq[String] = { (1 to 100).map { _.toString } } + def readAndDeserializeDataManually(logFiles: Seq[String], allowBatching: Boolean): Seq[String] = { + if (allowBatching) { + logFiles.flatMap { file => + val data = readDataManually[Array[Array[Byte]]](file) + data.flatMap(byteArray => byteArray.map(Utils.deserialize[String])) + } + } else { + logFiles.flatMap { file => readDataManually[String](file)} + } + } + implicit def stringToByteBuffer(str: String): ByteBuffer = { ByteBuffer.wrap(Utils.serialize(str)) } @@ -469,4 +751,41 @@ object WriteAheadLogSuite { implicit def byteBufferToString(byteBuffer: ByteBuffer): String = { Utils.deserialize[String](byteBuffer.array) } + + def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T]))) + } + + /** + * A wrapper WriteAheadLog that blocks the write function to allow batching with the + * BatchedWriteAheadLog. + */ + class BlockingWriteAheadLog( + wal: WriteAheadLog, + handle: WriteAheadLogRecordHandle) extends WriteAheadLog { + @volatile private var isWriteCalled: Boolean = false + @volatile private var blockWrite: Boolean = true + + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + isWriteCalled = true + eventually(Eventually.timeout(2 second)) { + assert(!blockWrite) + } + wal.write(record, time) + isWriteCalled = false + handle + } + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = wal.read(segment) + override def readAll(): JIterator[ByteBuffer] = wal.readAll() + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wal.clean(threshTime, waitForCompletion) + } + override def close(): Unit = wal.close() + + def allowWrite(): Unit = { + blockWrite = false + } + + def isBlocked: Boolean = isWriteCalled + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala new file mode 100644 index 0000000000000..2a41177a5e638 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.util.Utils + +class WriteAheadLogUtilsSuite extends SparkFunSuite { + import WriteAheadLogUtilsSuite._ + + private val logDir = Utils.createTempDir().getAbsolutePath() + private val hadoopConf = new Configuration() + + def assertDriverLogClass[T <: WriteAheadLog: ClassTag]( + conf: SparkConf, + isBatched: Boolean = false): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + if (isBatched) { + assert(log.isInstanceOf[BatchedWriteAheadLog]) + val parentLog = log.asInstanceOf[BatchedWriteAheadLog].wrappedLog + assert(parentLog.getClass === implicitly[ClassTag[T]].runtimeClass) + } else { + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + } + log + } + + def assertReceiverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + test("log selection and creation") { + + val emptyConf = new SparkConf() // no log configuration + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + + // Verify setting driver WAL class + val driverWALConf = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify setting receiver WAL class + val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + + // Verify setting receiver WAL class with 1-arg constructor + val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog1].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + + // Verify failure setting receiver WAL class with 2-arg constructor + intercept[SparkException] { + val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog2].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + } + } + + test("wrap WriteAheadLog in BatchedWriteAheadLog when batching is enabled") { + def getBatchedSparkConf: SparkConf = + new SparkConf().set("spark.streaming.driver.writeAheadLog.allowBatching", "true") + + val justBatchingConf = getBatchedSparkConf + assertDriverLogClass[FileBasedWriteAheadLog](justBatchingConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](justBatchingConf) + + // Verify setting driver WAL class + val driverWALConf = getBatchedSparkConf.set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify receivers are not wrapped + val receiverWALConf = getBatchedSparkConf.set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + } + + test("batching is enabled by default in WriteAheadLog") { + val conf = new SparkConf() + assert(WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) + // batching is not valid for receiver WALs + assert(!WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = false)) + } + + test("closeFileAfterWrite is disabled by default in WriteAheadLog") { + val conf = new SparkConf() + assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = true)) + assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = false)) + } +} + +object WriteAheadLogUtilsSuite { + + class MockWriteAheadLog0() extends WriteAheadLog { + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } + override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } + override def readAll(): util.Iterator[ByteBuffer] = { null } + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } + override def close(): Unit = { } + } + + class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + + class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() +} diff --git a/tags/pom.xml b/tags/pom.xml deleted file mode 100644 index ca93722e73345..0000000000000 --- a/tags/pom.xml +++ /dev/null @@ -1,50 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../pom.xml - - - org.apache.spark - spark-test-tags_2.10 - jar - Spark Project Test Tags - http://spark.apache.org/ - - test-tags - - - - - org.scalatest - scalatest_${scala.binary.version} - compile - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/tools/pom.xml b/tools/pom.xml index 1e64f280e5bed..9bb20e1381067 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-tools_2.10 + spark-tools_2.11 tools @@ -34,16 +34,6 @@ http://spark.apache.org/ - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - org.scala-lang scala-reflect @@ -52,6 +42,11 @@ org.scala-lang scala-compiler + + org.clapper + classutil_${scala.binary.version} + 1.0.6 + diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index a0524cabff2d4..c9058ff409893 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -18,15 +18,13 @@ // scalastyle:off classforname package org.apache.spark.tools -import java.io.File -import java.util.jar.JarFile - import scala.collection.mutable -import scala.collection.JavaConverters._ -import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} +import scala.reflect.runtime.universe.runtimeMirror import scala.util.Try +import org.clapper.classutil.ClassFinder + /** * A tool for generating classes to be excluded during binary checking with MIMA. It is expected * that this tool is run with ./spark-class. @@ -42,14 +40,6 @@ object GenerateMIMAIgnore { private val classLoader = Thread.currentThread().getContextClassLoader private val mirror = runtimeMirror(classLoader) - - private def isDeveloperApi(sym: unv.Symbol) = - sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi]) - - private def isExperimental(sym: unv.Symbol) = - sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.Experimental]) - - private def isPackagePrivate(sym: unv.Symbol) = !sym.privateWithin.fullName.startsWith("") @@ -58,7 +48,7 @@ object GenerateMIMAIgnore { /** * For every class checks via scala reflection if the class itself or contained members - * have DeveloperApi or Experimental annotations or they are package private. + * are package private. * Returns the tuple of such classes and members. */ private def privateWithin(packageName: String): (Set[String], Set[String]) = { @@ -72,9 +62,9 @@ object GenerateMIMAIgnore { val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader)) val moduleSymbol = mirror.staticModule(className) val directlyPrivateSpark = - isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol) - val developerApi = isDeveloperApi(classSymbol) || isDeveloperApi(moduleSymbol) - val experimental = isExperimental(classSymbol) || isExperimental(moduleSymbol) + isPackagePrivate(classSymbol) || + isPackagePrivateModule(moduleSymbol) || + classSymbol.isPrivate /* Inner classes defined within a private[spark] class or object are effectively invisible, so we account for them as package private. */ lazy val indirectlyPrivateSpark = { @@ -86,7 +76,7 @@ object GenerateMIMAIgnore { false } } - if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi || experimental) { + if (directlyPrivateSpark || indirectlyPrivateSpark) { ignoredClasses += className } // check if this class has package-private/annotated members. @@ -101,10 +91,11 @@ object GenerateMIMAIgnore { (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) } - /** Scala reflection does not let us see inner function even if they are upgraded - * to public for some reason. So had to resort to java reflection to get all inner - * functions with $$ in there name. - */ + /** + * Scala reflection does not let us see inner function even if they are upgraded + * to public for some reason. So had to resort to java reflection to get all inner + * functions with $$ in there name. + */ def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = { try { Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName) @@ -121,9 +112,7 @@ object GenerateMIMAIgnore { private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = { classSymbol.typeSignature.members.filterNot(x => x.fullName.startsWith("java") || x.fullName.startsWith("scala") - ).filter(x => - isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x) - ).map(_.fullName) ++ getInnerFunctions(classSymbol) + ).filter(x => isPackagePrivate(x)).map(_.fullName) ++ getInnerFunctions(classSymbol) } def main(args: Array[String]) { @@ -143,7 +132,6 @@ object GenerateMIMAIgnore { // scalastyle:on println } - private def shouldExclude(name: String) = { // Heuristic to remove JVM classes that do not correspond to user-facing classes in Scala name.contains("anon") || @@ -158,35 +146,13 @@ object GenerateMIMAIgnore { * and subpackages both from directories and jars present on the classpath. */ private def getClasses(packageName: String): Set[String] = { - val path = packageName.replace('.', '/') - val resources = classLoader.getResources(path) - - val jars = resources.asScala.filter(_.getProtocol == "jar") - .map(_.getFile.split(":")(1).split("!")(0)).toSeq - - jars.flatMap(getClassesFromJar(_, path)) - .map(_.getName) - .filterNot(shouldExclude).toSet - } - - /** - * Get all classes in a package from a jar file. - */ - private def getClassesFromJar(jarPath: String, packageName: String) = { - import scala.collection.mutable - val jar = new JarFile(new File(jarPath)) - val enums = jar.entries().asScala.map(_.getName).filter(_.startsWith(packageName)) - val classes = mutable.HashSet[Class[_]]() - for (entry <- enums if entry.endsWith(".class")) { - try { - classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader) - } catch { - // scalastyle:off println - case _: Throwable => println("Unable to load:" + entry) - // scalastyle:on println - } - } - classes + val finder = ClassFinder() + finder + .getClasses + .map(_.name) + .filter(_.startsWith(packageName)) + .filterNot(shouldExclude) + .toSet } } // scalastyle:on classforname diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala deleted file mode 100644 index 856ea177a9a10..0000000000000 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ /dev/null @@ -1,367 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.tools - -import java.lang.reflect.{Type, Method} - -import scala.collection.mutable.ArrayBuffer -import scala.language.existentials - -import org.apache.spark._ -import org.apache.spark.api.java._ -import org.apache.spark.rdd.{RDD, DoubleRDDFunctions, PairRDDFunctions, OrderedRDDFunctions} -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{DStream, PairDStreamFunctions} - - -private[spark] abstract class SparkType(val name: String) - -private[spark] case class BaseType(override val name: String) extends SparkType(name) { - override def toString: String = { - name - } -} - -private[spark] -case class ParameterizedType(override val name: String, - parameters: Seq[SparkType], - typebounds: String = "") extends SparkType(name) { - override def toString: String = { - if (typebounds != "") { - typebounds + " " + name + "<" + parameters.mkString(", ") + ">" - } else { - name + "<" + parameters.mkString(", ") + ">" - } - } -} - -private[spark] -case class SparkMethod(name: String, returnType: SparkType, parameters: Seq[SparkType]) { - override def toString: String = { - returnType + " " + name + "(" + parameters.mkString(", ") + ")" - } -} - -/** - * A tool for identifying methods that need to be ported from Scala to the Java API. - * - * It uses reflection to find methods in the Scala API and rewrites those methods' signatures - * into appropriate Java equivalents. If those equivalent methods have not been implemented in - * the Java API, they are printed. - */ -object JavaAPICompletenessChecker { - - private def parseType(typeStr: String): SparkType = { - if (!typeStr.contains("<")) { - // Base types might begin with "class" or "interface", so we have to strip that off: - BaseType(typeStr.trim.split(" ").last) - } else if (typeStr.endsWith("[]")) { - ParameterizedType("Array", Seq(parseType(typeStr.stripSuffix("[]")))) - } else { - val parts = typeStr.split("<", 2) - val name = parts(0).trim - assert (parts(1).last == '>') - val parameters = parts(1).dropRight(1) - ParameterizedType(name, parseTypeList(parameters)) - } - } - - private def parseTypeList(typeStr: String): Seq[SparkType] = { - val types: ArrayBuffer[SparkType] = new ArrayBuffer[SparkType] - var stack = 0 - var token: StringBuffer = new StringBuffer() - for (c <- typeStr.trim) { - if (c == ',' && stack == 0) { - types += parseType(token.toString) - token = new StringBuffer() - } else if (c == ' ' && stack != 0) { - // continue - } else { - if (c == '<') { - stack += 1 - } else if (c == '>') { - stack -= 1 - } - token.append(c) - } - } - assert (stack == 0) - if (token.toString != "") { - types += parseType(token.toString) - } - types.toSeq - } - - private def parseReturnType(typeStr: String): SparkType = { - if (typeStr(0) == '<') { - val parts = typeStr.drop(0).split(">", 2) - val parsed = parseType(parts(1)).asInstanceOf[ParameterizedType] - ParameterizedType(parsed.name, parsed.parameters, parts(0)) - } else { - parseType(typeStr) - } - } - - private def toSparkMethod(method: Method): SparkMethod = { - val returnType = parseReturnType(method.getGenericReturnType.toString) - val name = method.getName - val parameters = method.getGenericParameterTypes.map(t => parseType(t.toString)) - SparkMethod(name, returnType, parameters) - } - - private def toJavaType(scalaType: SparkType, isReturnType: Boolean): SparkType = { - val renameSubstitutions = Map( - "scala.collection.Map" -> "java.util.Map", - // TODO: the JavaStreamingContext API accepts Array arguments - // instead of Lists, so this isn't a trivial translation / sub: - "scala.collection.Seq" -> "java.util.List", - "scala.Function2" -> "org.apache.spark.api.java.function.Function2", - "scala.collection.Iterator" -> "java.util.Iterator", - "scala.collection.mutable.Queue" -> "java.util.Queue", - "double" -> "java.lang.Double" - ) - // Keep applying the substitutions until we've reached a fixedpoint. - def applySubs(scalaType: SparkType): SparkType = { - scalaType match { - case ParameterizedType(name, parameters, typebounds) => - name match { - case "org.apache.spark.rdd.RDD" => - if (parameters(0).name == classOf[Tuple2[_, _]].getName) { - val tupleParams = - parameters(0).asInstanceOf[ParameterizedType].parameters.map(applySubs) - ParameterizedType(classOf[JavaPairRDD[_, _]].getName, tupleParams) - } else { - ParameterizedType(classOf[JavaRDD[_]].getName, parameters.map(applySubs)) - } - case "org.apache.spark.streaming.dstream.DStream" => - if (parameters(0).name == classOf[Tuple2[_, _]].getName) { - val tupleParams = - parameters(0).asInstanceOf[ParameterizedType].parameters.map(applySubs) - ParameterizedType("org.apache.spark.streaming.api.java.JavaPairDStream", - tupleParams) - } else { - ParameterizedType("org.apache.spark.streaming.api.java.JavaDStream", - parameters.map(applySubs)) - } - case "scala.Option" => { - if (isReturnType) { - ParameterizedType("com.google.common.base.Optional", parameters.map(applySubs)) - } else { - applySubs(parameters(0)) - } - } - case "scala.Function1" => - val firstParamName = parameters.last.name - if (firstParamName.startsWith("scala.collection.Traversable") || - firstParamName.startsWith("scala.collection.Iterator")) { - ParameterizedType("org.apache.spark.api.java.function.FlatMapFunction", - Seq(parameters(0), - parameters.last.asInstanceOf[ParameterizedType].parameters(0)).map(applySubs)) - } else if (firstParamName == "scala.runtime.BoxedUnit") { - ParameterizedType("org.apache.spark.api.java.function.VoidFunction", - parameters.dropRight(1).map(applySubs)) - } else { - ParameterizedType("org.apache.spark.api.java.function.Function", - parameters.map(applySubs)) - } - case _ => - ParameterizedType(renameSubstitutions.getOrElse(name, name), - parameters.map(applySubs)) - } - case BaseType(name) => - if (renameSubstitutions.contains(name)) { - BaseType(renameSubstitutions(name)) - } else { - scalaType - } - } - } - var oldType = scalaType - var newType = applySubs(scalaType) - while (oldType != newType) { - oldType = newType - newType = applySubs(scalaType) - } - newType - } - - private def toJavaMethod(method: SparkMethod): SparkMethod = { - val params = method.parameters - .filterNot(_.name == "scala.reflect.ClassTag") - .map(toJavaType(_, isReturnType = false)) - SparkMethod(method.name, toJavaType(method.returnType, isReturnType = true), params) - } - - private def isExcludedByName(method: Method): Boolean = { - val name = method.getDeclaringClass.getName + "." + method.getName - // Scala methods that are declared as private[mypackage] become public in the resulting - // Java bytecode. As a result, we need to manually exclude those methods here. - // This list also includes a few methods that are only used by the web UI or other - // internal Spark components. - val excludedNames = Seq( - "org.apache.spark.rdd.RDD.origin", - "org.apache.spark.rdd.RDD.elementClassTag", - "org.apache.spark.rdd.RDD.checkpointData", - "org.apache.spark.rdd.RDD.partitioner", - "org.apache.spark.rdd.RDD.partitions", - "org.apache.spark.rdd.RDD.firstParent", - "org.apache.spark.rdd.RDD.doCheckpoint", - "org.apache.spark.rdd.RDD.markCheckpointed", - "org.apache.spark.rdd.RDD.clearDependencies", - "org.apache.spark.rdd.RDD.getDependencies", - "org.apache.spark.rdd.RDD.getPartitions", - "org.apache.spark.rdd.RDD.dependencies", - "org.apache.spark.rdd.RDD.getPreferredLocations", - "org.apache.spark.rdd.RDD.collectPartitions", - "org.apache.spark.rdd.RDD.computeOrReadCheckpoint", - "org.apache.spark.rdd.PairRDDFunctions.getKeyClass", - "org.apache.spark.rdd.PairRDDFunctions.getValueClass", - "org.apache.spark.SparkContext.stringToText", - "org.apache.spark.SparkContext.makeRDD", - "org.apache.spark.SparkContext.runJob", - "org.apache.spark.SparkContext.runApproximateJob", - "org.apache.spark.SparkContext.clean", - "org.apache.spark.SparkContext.metadataCleaner", - "org.apache.spark.SparkContext.ui", - "org.apache.spark.SparkContext.newShuffleId", - "org.apache.spark.SparkContext.newRddId", - "org.apache.spark.SparkContext.cleanup", - "org.apache.spark.SparkContext.receiverJobThread", - "org.apache.spark.SparkContext.getRDDStorageInfo", - "org.apache.spark.SparkContext.addedFiles", - "org.apache.spark.SparkContext.addedJars", - "org.apache.spark.SparkContext.persistentRdds", - "org.apache.spark.SparkContext.executorEnvs", - "org.apache.spark.SparkContext.checkpointDir", - "org.apache.spark.SparkContext.getSparkHome", - "org.apache.spark.SparkContext.executorMemoryRequested", - "org.apache.spark.SparkContext.getExecutorStorageStatus", - "org.apache.spark.streaming.dstream.DStream.generatedRDDs", - "org.apache.spark.streaming.dstream.DStream.zeroTime", - "org.apache.spark.streaming.dstream.DStream.rememberDuration", - "org.apache.spark.streaming.dstream.DStream.storageLevel", - "org.apache.spark.streaming.dstream.DStream.mustCheckpoint", - "org.apache.spark.streaming.dstream.DStream.checkpointDuration", - "org.apache.spark.streaming.dstream.DStream.checkpointData", - "org.apache.spark.streaming.dstream.DStream.graph", - "org.apache.spark.streaming.dstream.DStream.isInitialized", - "org.apache.spark.streaming.dstream.DStream.parentRememberDuration", - "org.apache.spark.streaming.dstream.DStream.initialize", - "org.apache.spark.streaming.dstream.DStream.validate", - "org.apache.spark.streaming.dstream.DStream.setContext", - "org.apache.spark.streaming.dstream.DStream.setGraph", - "org.apache.spark.streaming.dstream.DStream.remember", - "org.apache.spark.streaming.dstream.DStream.getOrCompute", - "org.apache.spark.streaming.dstream.DStream.generateJob", - "org.apache.spark.streaming.dstream.DStream.clearOldMetadata", - "org.apache.spark.streaming.dstream.DStream.addMetadata", - "org.apache.spark.streaming.dstream.DStream.updateCheckpointData", - "org.apache.spark.streaming.dstream.DStream.restoreCheckpointData", - "org.apache.spark.streaming.dstream.DStream.isTimeValid", - "org.apache.spark.streaming.StreamingContext.nextNetworkInputStreamId", - "org.apache.spark.streaming.StreamingContext.checkpointDir", - "org.apache.spark.streaming.StreamingContext.checkpointDuration", - "org.apache.spark.streaming.StreamingContext.receiverJobThread", - "org.apache.spark.streaming.StreamingContext.scheduler", - "org.apache.spark.streaming.StreamingContext.initialCheckpoint", - "org.apache.spark.streaming.StreamingContext.getNewNetworkStreamId", - "org.apache.spark.streaming.StreamingContext.validate", - "org.apache.spark.streaming.StreamingContext.createNewSparkContext", - "org.apache.spark.streaming.StreamingContext.rddToFileName", - "org.apache.spark.streaming.StreamingContext.getSparkCheckpointDir", - "org.apache.spark.streaming.StreamingContext.env", - "org.apache.spark.streaming.StreamingContext.graph", - "org.apache.spark.streaming.StreamingContext.isCheckpointPresent" - ) - val excludedPatterns = Seq( - """^org\.apache\.spark\.SparkContext\..*To.*Functions""", - """^org\.apache\.spark\.SparkContext\..*WritableConverter""", - """^org\.apache\.spark\.SparkContext\..*To.*Writable""" - ).map(_.r) - lazy val excludedByPattern = - !excludedPatterns.map(_.findFirstIn(name)).filter(_.isDefined).isEmpty - name.contains("$") || excludedNames.contains(name) || excludedByPattern - } - - private def isExcludedByInterface(method: Method): Boolean = { - val excludedInterfaces = - Set("org.apache.spark.Logging", "org.apache.hadoop.mapreduce.HadoopMapReduceUtil") - def toComparisionKey(method: Method): (Class[_], String, Type) = - (method.getReturnType, method.getName, method.getGenericReturnType) - val interfaces = method.getDeclaringClass.getInterfaces.filter { i => - excludedInterfaces.contains(i.getName) - } - val excludedMethods = interfaces.flatMap(_.getMethods.map(toComparisionKey)) - excludedMethods.contains(toComparisionKey(method)) - } - - private def printMissingMethods(scalaClass: Class[_], javaClass: Class[_]) { - val methods = scalaClass.getMethods - .filterNot(_.isAccessible) - .filterNot(isExcludedByName) - .filterNot(isExcludedByInterface) - val javaEquivalents = methods.map(m => toJavaMethod(toSparkMethod(m))).toSet - - val javaMethods = javaClass.getMethods.map(toSparkMethod).toSet - - val missingMethods = javaEquivalents -- javaMethods - - for (method <- missingMethods) { - // scalastyle:off println - println(method) - // scalastyle:on println - } - } - - def main(args: Array[String]) { - // scalastyle:off println - println("Missing RDD methods") - printMissingMethods(classOf[RDD[_]], classOf[JavaRDD[_]]) - println() - - println("Missing PairRDD methods") - printMissingMethods(classOf[PairRDDFunctions[_, _]], classOf[JavaPairRDD[_, _]]) - println() - - println("Missing DoubleRDD methods") - printMissingMethods(classOf[DoubleRDDFunctions], classOf[JavaDoubleRDD]) - println() - - println("Missing OrderedRDD methods") - printMissingMethods(classOf[OrderedRDDFunctions[_, _, _]], classOf[JavaPairRDD[_, _]]) - println() - - println("Missing SparkContext methods") - printMissingMethods(classOf[SparkContext], classOf[JavaSparkContext]) - println() - - println("Missing StreamingContext methods") - printMissingMethods(classOf[StreamingContext], classOf[JavaStreamingContext]) - println() - - println("Missing DStream methods") - printMissingMethods(classOf[DStream[_]], classOf[JavaDStream[_]]) - println() - - println("Missing PairDStream methods") - printMissingMethods(classOf[PairDStreamFunctions[_, _]], classOf[JavaPairDStream[_, _]]) - println() - // scalastyle:on println - } -} diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala deleted file mode 100644 index 0dc2861253f17..0000000000000 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.tools - -import java.util.concurrent.{CountDownLatch, Executors} -import java.util.concurrent.atomic.AtomicLong - -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.shuffle.hash.HashShuffleManager -import org.apache.spark.util.Utils - -/** - * Internal utility for micro-benchmarking shuffle write performance. - * - * Writes simulated shuffle output from several threads and records the observed throughput. - */ -object StoragePerfTester { - def main(args: Array[String]): Unit = { - /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */ - val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) - - /** Number of map tasks. All tasks execute concurrently. */ - val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8) - - /** Number of reduce splits for each map task. */ - val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500) - - val recordLength = 1000 // ~1KB records - val totalRecords = dataSizeMb * 1000 - val recordsPerMap = totalRecords / numMaps - - val writeKey = "1" * (recordLength / 2) - val writeValue = "1" * (recordLength / 2) - val executor = Executors.newFixedThreadPool(numMaps) - - val conf = new SparkConf() - .set("spark.shuffle.compress", "false") - .set("spark.shuffle.sync", "true") - .set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - - // This is only used to instantiate a BlockManager. All thread scheduling is done manually. - val sc = new SparkContext("local[4]", "Write Tester", conf) - val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] - - def writeOutputBytes(mapId: Int, total: AtomicLong): Unit = { - val shuffle = hashShuffleManager.shuffleBlockResolver.forMapTask(1, mapId, numOutputSplits, - new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) - val writers = shuffle.writers - for (i <- 1 to recordsPerMap) { - writers(i % numOutputSplits).write(writeKey, writeValue) - } - writers.map { w => - w.commitAndClose() - total.addAndGet(w.fileSegment().length) - } - - shuffle.releaseWriters(true) - } - - val start = System.currentTimeMillis() - val latch = new CountDownLatch(numMaps) - val totalBytes = new AtomicLong() - for (task <- 1 to numMaps) { - executor.submit(new Runnable() { - override def run(): Unit = { - try { - writeOutputBytes(task, totalBytes) - latch.countDown() - } catch { - case e: Exception => - // scalastyle:off println - println("Exception in child thread: " + e + " " + e.getMessage) - // scalastyle:on println - System.exit(1) - } - } - }) - } - latch.await() - val end = System.currentTimeMillis() - val time = (end - start) / 1000.0 - val bytesPerSecond = totalBytes.get() / time - val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong - - // scalastyle:off println - System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) - System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) - System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) - // scalastyle:on println - - executor.shutdown() - sc.stop() - } -} diff --git a/unsafe/pom.xml b/unsafe/pom.xml deleted file mode 100644 index caf1f77890b58..0000000000000 --- a/unsafe/pom.xml +++ /dev/null @@ -1,106 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../pom.xml - - - org.apache.spark - spark-unsafe_2.10 - jar - Spark Project Unsafe - http://spark.apache.org/ - - unsafe - - - - - - - com.google.code.findbugs - jsr305 - - - com.google.guava - guava - - - - - org.slf4j - slf4j-api - provided - - - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - org.mockito - mockito-core - test - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.commons - commons-lang3 - test - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - - net.alchim31.maven - scala-maven-plugin - - - - -XDignore.symbol.file - - - - - org.apache.maven.plugins - maven-compiler-plugin - - - - -XDignore.symbol.file - - - - - - - diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java deleted file mode 100644 index 1c16da982923b..0000000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe; - -import java.lang.reflect.Field; - -import sun.misc.Unsafe; - -public final class Platform { - - private static final Unsafe _UNSAFE; - - public static final int BYTE_ARRAY_OFFSET; - - public static final int INT_ARRAY_OFFSET; - - public static final int LONG_ARRAY_OFFSET; - - public static final int DOUBLE_ARRAY_OFFSET; - - public static int getInt(Object object, long offset) { - return _UNSAFE.getInt(object, offset); - } - - public static void putInt(Object object, long offset, int value) { - _UNSAFE.putInt(object, offset, value); - } - - public static boolean getBoolean(Object object, long offset) { - return _UNSAFE.getBoolean(object, offset); - } - - public static void putBoolean(Object object, long offset, boolean value) { - _UNSAFE.putBoolean(object, offset, value); - } - - public static byte getByte(Object object, long offset) { - return _UNSAFE.getByte(object, offset); - } - - public static void putByte(Object object, long offset, byte value) { - _UNSAFE.putByte(object, offset, value); - } - - public static short getShort(Object object, long offset) { - return _UNSAFE.getShort(object, offset); - } - - public static void putShort(Object object, long offset, short value) { - _UNSAFE.putShort(object, offset, value); - } - - public static long getLong(Object object, long offset) { - return _UNSAFE.getLong(object, offset); - } - - public static void putLong(Object object, long offset, long value) { - _UNSAFE.putLong(object, offset, value); - } - - public static float getFloat(Object object, long offset) { - return _UNSAFE.getFloat(object, offset); - } - - public static void putFloat(Object object, long offset, float value) { - _UNSAFE.putFloat(object, offset, value); - } - - public static double getDouble(Object object, long offset) { - return _UNSAFE.getDouble(object, offset); - } - - public static void putDouble(Object object, long offset, double value) { - _UNSAFE.putDouble(object, offset, value); - } - - public static Object getObjectVolatile(Object object, long offset) { - return _UNSAFE.getObjectVolatile(object, offset); - } - - public static void putObjectVolatile(Object object, long offset, Object value) { - _UNSAFE.putObjectVolatile(object, offset, value); - } - - public static long allocateMemory(long size) { - return _UNSAFE.allocateMemory(size); - } - - public static void freeMemory(long address) { - _UNSAFE.freeMemory(address); - } - - public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } - - /** - * Raises an exception bypassing compiler checks for checked exceptions. - */ - public static void throwException(Throwable t) { - _UNSAFE.throwException(t); - } - - /** - * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to - * allow safepoint polling during a large copy. - */ - private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; - - static { - sun.misc.Unsafe unsafe; - try { - Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); - unsafeField.setAccessible(true); - unsafe = (sun.misc.Unsafe) unsafeField.get(null); - } catch (Throwable cause) { - unsafe = null; - } - _UNSAFE = unsafe; - - if (_UNSAFE != null) { - BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); - INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); - LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); - DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); - } else { - BYTE_ARRAY_OFFSET = 0; - INT_ARRAY_OFFSET = 0; - LONG_ARRAY_OFFSET = 0; - DOUBLE_ARRAY_OFFSET = 0; - } - } -} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java deleted file mode 100644 index 4276f25c2165b..0000000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.hash; - -import org.apache.spark.unsafe.Platform; - -/** - * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. - */ -public final class Murmur3_x86_32 { - private static final int C1 = 0xcc9e2d51; - private static final int C2 = 0x1b873593; - - private final int seed; - - public Murmur3_x86_32(int seed) { - this.seed = seed; - } - - @Override - public String toString() { - return "Murmur3_32(seed=" + seed + ")"; - } - - public int hashInt(int input) { - int k1 = mixK1(input); - int h1 = mixH1(seed, k1); - - return fmix(h1, 4); - } - - public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWords(base, offset, lengthInBytes, seed); - } - - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { - // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. - assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = seed; - for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = Platform.getInt(base, offset + i); - int k1 = mixK1(halfWord); - h1 = mixH1(h1, k1); - } - return fmix(h1, lengthInBytes); - } - - public int hashLong(long input) { - int low = (int) input; - int high = (int) (input >>> 32); - - int k1 = mixK1(low); - int h1 = mixH1(seed, k1); - - k1 = mixK1(high); - h1 = mixH1(h1, k1); - - return fmix(h1, 8); - } - - private static int mixK1(int k1) { - k1 *= C1; - k1 = Integer.rotateLeft(k1, 15); - k1 *= C2; - return k1; - } - - private static int mixH1(int h1, int k1) { - h1 ^= k1; - h1 = Integer.rotateLeft(h1, 13); - h1 = h1 * 5 + 0xe6546b64; - return h1; - } - - // Finalization mix - force all bits of a hash block to avalanche - private static int fmix(int h1, int length) { - h1 ^= length; - h1 ^= h1 >>> 16; - h1 *= 0x85ebca6b; - h1 ^= h1 >>> 13; - h1 *= 0xc2b2ae35; - h1 ^= h1 >>> 16; - return h1; - } -} diff --git a/yarn/pom.xml b/yarn/pom.xml index 989b820bec9ef..328bb6678db99 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT + spark-parent_2.11 + 2.0.0-SNAPSHOT ../pom.xml org.apache.spark - spark-yarn_2.10 + spark-yarn_2.11 jar Spark Project YARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index 56e4741b93873..a6a4fec3ba9e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -24,9 +24,12 @@ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.ThreadUtils /* @@ -60,11 +63,9 @@ private[yarn] class AMDelegationTokenRenewer( private val hadoopUtil = YarnSparkHadoopUtil.get - private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") - private val daysToKeepFiles = - sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) - private val numFilesToKeep = - sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) + private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION) + private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT) private val freshHadoopConf = hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) @@ -76,8 +77,8 @@ private[yarn] class AMDelegationTokenRenewer( * */ private[spark] def scheduleLoginFromKeytab(): Unit = { - val principal = sparkConf.get("spark.yarn.principal") - val keytab = sparkConf.get("spark.yarn.keytab") + val principal = sparkConf.get(PRINCIPAL).get + val keytab = sparkConf.get(KEYTAB).get /** * Schedule re-login and creation of new tokens. If tokens have already expired, this method @@ -115,7 +116,7 @@ private[yarn] class AMDelegationTokenRenewer( } } // Schedule update of credentials. This handles the case of updating the tokens right now - // as well, since the renenwal interval will be 0, and the thread will get scheduled + // as well, since the renewal interval will be 0, and the thread will get scheduled // immediately. scheduleRenewal(driverTokenRenewerRunnable) } @@ -151,7 +152,7 @@ private[yarn] class AMDelegationTokenRenewer( // passed in already has tokens for that FS even if the tokens are expired (it really only // checks if there are tokens for the service, and not if they are valid). So the only real // way to get new tokens is to make sure a different Credentials object is used each time to - // get new tokens and then the new tokens are copied over the the current user's Credentials. + // get new tokens and then the new tokens are copied over the current user's Credentials. // So: // - we login as a different user and get the UGI // - use that UGI to get the tokens (see doAs block below) @@ -172,6 +173,8 @@ private[yarn] class AMDelegationTokenRenewer( override def run(): Void = { val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst hadoopUtil.obtainTokensForNamenodes(nns, freshHadoopConf, tempCreds) + hadoopUtil.obtainTokenForHiveMetastore(sparkConf, freshHadoopConf, tempCreds) + hadoopUtil.obtainTokenForHBase(sparkConf, freshHadoopConf, tempCreds) null } }) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 50ae7ffeec4c5..d447a59937be7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -17,23 +17,25 @@ package org.apache.spark.deploy.yarn -import scala.util.control.NonFatal - import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference +import scala.util.control.NonFatal + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.spark.rpc._ -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv, - SparkException, SparkUserAppException} +import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util._ @@ -66,16 +68,18 @@ private[spark] class ApplicationMaster( // allocation is enabled), with a minimum of 3. private val maxNumExecutorFailures = { - val defaultKey = + val effectiveNumExecutors = if (Utils.isDynamicAllocationEnabled(sparkConf)) { - "spark.dynamicAllocation.maxExecutors" + sparkConf.get(DYN_ALLOCATION_MAX_EXECUTORS) } else { - "spark.executor.instances" + sparkConf.get(EXECUTOR_INSTANCES).getOrElse(0) } - val effectiveNumExecutors = sparkConf.getInt(defaultKey, 0) - val defaultMaxNumExecutorFailures = math.max(3, 2 * effectiveNumExecutors) + // By default, effectiveNumExecutors is Int.MaxValue if dynamic allocation is enabled. We need + // avoid the integer overflow here. + val defaultMaxNumExecutorFailures = math.max(3, + if (effectiveNumExecutors > Int.MaxValue / 2) Int.MaxValue else (2 * effectiveNumExecutors)) - sparkConf.getInt("spark.yarn.max.executor.failures", defaultMaxNumExecutorFailures) + sparkConf.get(MAX_EXECUTOR_FAILURES).getOrElse(defaultMaxNumExecutorFailures) } @volatile private var exitCode = 0 @@ -96,14 +100,13 @@ private[spark] class ApplicationMaster( private val heartbeatInterval = { // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - math.max(0, math.min(expiryInterval / 2, - sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) + math.max(0, math.min(expiryInterval / 2, sparkConf.get(RM_HEARTBEAT_INTERVAL))) } // Initial wait interval before allocator poll, to allow for quicker ramp up when executors are // being requested. private val initialAllocationInterval = math.min(heartbeatInterval, - sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) + sparkConf.get(INITIAL_HEARTBEAT_INTERVAL)) // Next wait interval before allocator poll. private var nextAllocationInterval = initialAllocationInterval @@ -117,6 +120,10 @@ private[spark] class ApplicationMaster( private var delegationTokenRenewerOption: Option[AMDelegationTokenRenewer] = None + def getAttemptId(): ApplicationAttemptId = { + client.getAttemptId() + } + final def run(): Int = { try { val appAttemptId = client.getAttemptId() @@ -126,15 +133,13 @@ private[spark] class ApplicationMaster( // other spark processes running on the same box System.setProperty("spark.ui.port", "0") - // Set the master property to match the requested mode. - System.setProperty("spark.master", "yarn-cluster") + // Set the master and deploy mode property to match the requested mode. + System.setProperty("spark.master", "yarn") + System.setProperty("spark.submit.deployMode", "cluster") - // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. + // Set this internal configuration if it is running on cluster mode, this + // configuration will be checked in SparkContext to avoid misuse of yarn cluster mode. System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) - - // Propagate the attempt if, so that in case of event logging, - // different attempt's logs gets created in different directory - System.setProperty("spark.yarn.app.attemptId", appAttemptId.getAttemptId().toString()) } logInfo("ApplicationAttemptId: " + appAttemptId) @@ -148,13 +153,13 @@ private[spark] class ApplicationMaster( val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts if (!finished) { - // This happens when the user application calls System.exit(). We have the choice - // of either failing or succeeding at this point. We report success to avoid - // retrying applications that have succeeded (System.exit(0)), which means that - // applications that explicitly exit with a non-zero status will also show up as - // succeeded in the RM UI. + // The default state of ApplicationMaster is failed if it is invoked by shut down hook. + // This behavior is different compared to 1.x version. + // If user application is exited ahead of time by calling System.exit(N), here mark + // this application as failed with EXIT_EARLY. For a good shutdown, user shouldn't call + // System.exit(0) to terminate the application. finish(finalStatus, - ApplicationMaster.EXIT_SUCCESS, + ApplicationMaster.EXIT_EARLY, "Shutdown hook called before final status was reported.") } @@ -174,7 +179,7 @@ private[spark] class ApplicationMaster( // If the credentials file config is present, we must periodically renew tokens. So create // a new AMDelegationTokenRenewer - if (sparkConf.contains("spark.yarn.credentials.file")) { + if (sparkConf.contains(CREDENTIALS_FILE_PATH.key)) { delegationTokenRenewerOption = Some(new AMDelegationTokenRenewer(sparkConf, yarnConf)) // If a principal and keytab have been set, use that to create new credentials for executors // periodically @@ -205,7 +210,7 @@ private[spark] class ApplicationMaster( */ final def getDefaultFinalStatus(): FinalApplicationStatus = { if (isClusterMode) { - FinalApplicationStatus.SUCCEEDED + FinalApplicationStatus.FAILED } else { FinalApplicationStatus.UNDEFINED } @@ -271,16 +276,16 @@ private[spark] class ApplicationMaster( val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() val historyAddress = - sparkConf.getOption("spark.yarn.historyServer.address") + sparkConf.get(HISTORY_SERVER_ADDRESS) .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } .getOrElse("") val _sparkConf = if (sc != null) sc.getConf else sparkConf - val driverUrl = _rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + _sparkConf.get("spark.driver.host"), + _sparkConf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString allocator = client.register(driverUrl, driverRef, yarnConf, @@ -306,7 +311,6 @@ private[spark] class ApplicationMaster( port: String, isClusterMode: Boolean): RpcEndpointRef = { val driverEndpoint = rpcEnv.setupEndpointRef( - SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = @@ -352,7 +356,7 @@ private[spark] class ApplicationMaster( private def launchReporterThread(): Thread = { // The number of failures in a row until Reporter thread give up - val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) + val reporterMaxFailures = sparkConf.get(MAX_REPORTER_THREAD_FAILURES) val t = new Thread { override def run() { @@ -370,16 +374,22 @@ private[spark] class ApplicationMaster( failureCount = 0 } catch { case i: InterruptedException => - case e: Throwable => { + case e: Throwable => failureCount += 1 - if (!NonFatal(e) || failureCount >= reporterMaxFailures) { + // this exception was introduced in hadoop 2.4 and this code would not compile + // with earlier versions if we refer it directly. + if ("org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException" == + e.getClass().getName()) { + logError("Exception from Reporter thread.", e) + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, + e.getMessage) + } else if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + s"$failureCount time(s) from Reporter thread.") } else { logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e) } - } } try { val numPendingAllocate = allocator.getPendingAllocate.size @@ -419,7 +429,7 @@ private[spark] class ApplicationMaster( private def cleanupStagingDir(fs: FileSystem) { var stagingDirPath: Path = null try { - val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES) if (!preserveFiles) { stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) if (stagingDirPath == null) { @@ -438,7 +448,7 @@ private[spark] class ApplicationMaster( private def waitForSparkContextInitialized(): SparkContext = { logInfo("Waiting for spark context initialization") sparkContextRef.synchronized { - val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") + val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME) val deadline = System.currentTimeMillis() + totalWaitTime while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) { @@ -463,7 +473,7 @@ private[spark] class ApplicationMaster( // Spark driver should already be up since it launched us, but we don't want to // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTimeMs = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") + val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME) val deadline = System.currentTimeMillis + totalWaitTimeMs while (!driverUp && !finished && System.currentTimeMillis < deadline) { @@ -596,11 +606,12 @@ private[spark] class ApplicationMaster( localityAwareTasks, hostToLocalTaskCount)) { resetAllocatorInterval() } + context.reply(true) case None => logWarning("Container allocator is not ready to request executors yet.") + context.reply(false) } - context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") @@ -642,6 +653,7 @@ object ApplicationMaster extends Logging { private val EXIT_SC_NOT_INITED = 13 private val EXIT_SECURITY = 14 private val EXIT_EXCEPTION_USER_CLASS = 15 + private val EXIT_EARLY = 16 private var master: ApplicationMaster = _ @@ -649,7 +661,7 @@ object ApplicationMaster extends Logging { SignalLogger.register(log) val amArgs = new ApplicationMasterArguments(args) SparkHadoopUtil.get.runAsSparkUser { () => - master = new ApplicationMaster(amArgs, new YarnRMClient(amArgs)) + master = new ApplicationMaster(amArgs, new YarnRMClient) System.exit(master.run()) } } @@ -662,6 +674,10 @@ object ApplicationMaster extends Logging { master.sparkContextStopped(sc) } + private[spark] def getAttemptId(): ApplicationAttemptId = { + master.getAttemptId + } + } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 17d9943c795e3..5cdec87667a5d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -17,9 +17,9 @@ package org.apache.spark.deploy.yarn -import org.apache.spark.util.{MemoryParam, IntParam} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.util.{IntParam, MemoryParam} class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null @@ -27,8 +27,6 @@ class ApplicationMasterArguments(val args: Array[String]) { var primaryPyFile: String = null var primaryRFile: String = null var userArgs: Seq[String] = Nil - var executorMemory = 1024 - var executorCores = 1 var propertiesFile: String = null parseArgs(args.toList) @@ -58,18 +56,10 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryRFile = value args = tail - case ("--args" | "--arg") :: value :: tail => + case ("--arg") :: value :: tail => userArgsBuffer += value args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => - executorMemory = value - args = tail - - case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail => - executorCores = value - args = tail - case ("--properties-file") :: value :: tail => propertiesFile = value args = tail @@ -86,7 +76,7 @@ class ApplicationMasterArguments(val args: Array[String]) { System.exit(-1) } - userArgs = userArgsBuffer.readOnly + userArgs = userArgsBuffer.toList } def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { @@ -101,12 +91,8 @@ class ApplicationMasterArguments(val args: Array[String]) { | --class CLASS_NAME Name of your application's main class | --primary-py-file A main Python file | --primary-r-file A main R file - | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to - | place on the PYTHONPATH for Python apps. - | --args ARGS Arguments to be passed to your application's main class. + | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --executor-cores NUM Number of cores for the executors (Default: 1) - | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) | --properties-file FILE Path to a custom Spark properties file. """.stripMargin) // scalastyle:on println diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index a3f33d80184a3..04e91f8553d51 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -19,31 +19,26 @@ package org.apache.spark.deploy.yarn import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException, OutputStreamWriter} -import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} +import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer -import java.security.PrivilegedExceptionAction +import java.nio.charset.StandardCharsets import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} -import scala.reflect.runtime.universe -import scala.util.{Try, Success, Failure} +import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal -import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files - -import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission -import org.apache.hadoop.io.Text +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier +import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.hadoop.security.token.{TokenIdentifier, Token} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -54,9 +49,12 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} -import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} import org.apache.spark.util.Utils private[spark] class Client( @@ -66,21 +64,44 @@ private[spark] class Client( extends Logging { import Client._ + import YarnSparkHadoopUtil._ def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) - private var credentials: Credentials = null - private val amMemoryOverhead = args.amMemoryOverhead // MB - private val executorMemoryOverhead = args.executorMemoryOverhead // MB + + private val isClusterMode = sparkConf.get("spark.submit.deployMode", "client") == "cluster" + + // AM related configurations + private val amMemory = if (isClusterMode) { + sparkConf.get(DRIVER_MEMORY).toInt + } else { + sparkConf.get(AM_MEMORY).toInt + } + private val amMemoryOverhead = { + val amMemoryOverheadEntry = if (isClusterMode) DRIVER_MEMORY_OVERHEAD else AM_MEMORY_OVERHEAD + sparkConf.get(amMemoryOverheadEntry).getOrElse( + math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt + } + private val amCores = if (isClusterMode) { + sparkConf.get(DRIVER_CORES) + } else { + sparkConf.get(AM_CORES) + } + + // Executor related configurations + private val executorMemory = sparkConf.get(EXECUTOR_MEMORY) + private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( + math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt + private val distCacheMgr = new ClientDistributedCacheManager() - private val isClusterMode = args.isClusterMode private var loginFromKeytab = false private var principal: String = null private var keytab: String = null + private var credentials: Credentials = null private val launcherBackend = new LauncherBackend() { override def onStopRequest(): Unit = { @@ -92,8 +113,7 @@ private[spark] class Client( } } } - private val fireAndForget = isClusterMode && - !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) + private val fireAndForget = isClusterMode && !sparkConf.get(WAIT_FOR_APP_COMPLETION) private var appId: ApplicationId = null @@ -161,9 +181,9 @@ private[spark] class Client( private def cleanupStagingDir(appId: ApplicationId): Unit = { val appStagingDir = getAppStagingDir(appId) try { - val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) - val stagingDirPath = new Path(appStagingDir) + val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES) val fs = FileSystem.get(hadoopConf) + val stagingDirPath = getAppStagingDirPath(sparkConf, fs, appStagingDir) if (!preserveFiles && fs.exists(stagingDirPath)) { logInfo("Deleting staging directory " + stagingDirPath) fs.delete(stagingDirPath, true) @@ -182,50 +202,71 @@ private[spark] class Client( newApp: YarnClientApplication, containerContext: ContainerLaunchContext): ApplicationSubmissionContext = { val appContext = newApp.getApplicationSubmissionContext - appContext.setApplicationName(args.appName) - appContext.setQueue(args.amQueue) + appContext.setApplicationName(sparkConf.get("spark.app.name", "Spark")) + appContext.setQueue(sparkConf.get(QUEUE_NAME)) appContext.setAMContainerSpec(containerContext) appContext.setApplicationType("SPARK") - sparkConf.getOption(CONF_SPARK_YARN_APPLICATION_TAGS) - .map(StringUtils.getTrimmedStringCollection(_)) - .filter(!_.isEmpty()) - .foreach { tagCollection => - try { - // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use - // reflection to set it, printing a warning if a tag was specified but the YARN version - // doesn't support it. - val method = appContext.getClass().getMethod( - "setApplicationTags", classOf[java.util.Set[String]]) - method.invoke(appContext, new java.util.HashSet[String](tagCollection)) - } catch { - case e: NoSuchMethodException => - logWarning(s"Ignoring $CONF_SPARK_YARN_APPLICATION_TAGS because this version of " + - "YARN does not support it") - } + + sparkConf.get(APPLICATION_TAGS).foreach { tags => + try { + // The setApplicationTags method was only introduced in Hadoop 2.4+, so we need to use + // reflection to set it, printing a warning if a tag was specified but the YARN version + // doesn't support it. + val method = appContext.getClass().getMethod( + "setApplicationTags", classOf[java.util.Set[String]]) + method.invoke(appContext, new java.util.HashSet[String](tags.asJava)) + } catch { + case e: NoSuchMethodException => + logWarning(s"Ignoring ${APPLICATION_TAGS.key} because this version of " + + "YARN does not support it") } - sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt) match { + } + sparkConf.get(MAX_APP_ATTEMPTS) match { case Some(v) => appContext.setMaxAppAttempts(v) - case None => logDebug("spark.yarn.maxAppAttempts is not set. " + + case None => logDebug(s"${MAX_APP_ATTEMPTS.key} is not set. " + "Cluster's default value will be used.") } - if (sparkConf.contains("spark.yarn.am.attemptFailuresValidityInterval")) { + sparkConf.get(ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).foreach { interval => try { - val interval = sparkConf.getTimeAsMs("spark.yarn.am.attemptFailuresValidityInterval") val method = appContext.getClass().getMethod( "setAttemptFailuresValidityInterval", classOf[Long]) method.invoke(appContext, interval: java.lang.Long) } catch { case e: NoSuchMethodException => - logWarning("Ignoring spark.yarn.am.attemptFailuresValidityInterval because the version " + - "of YARN does not support it") + logWarning(s"Ignoring ${ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS.key} because " + + "the version of YARN does not support it") } } val capability = Records.newRecord(classOf[Resource]) - capability.setMemory(args.amMemory + amMemoryOverhead) - capability.setVirtualCores(args.amCores) - appContext.setResource(capability) + capability.setMemory(amMemory + amMemoryOverhead) + capability.setVirtualCores(amCores) + + sparkConf.get(AM_NODE_LABEL_EXPRESSION) match { + case Some(expr) => + try { + val amRequest = Records.newRecord(classOf[ResourceRequest]) + amRequest.setResourceName(ResourceRequest.ANY) + amRequest.setPriority(Priority.newInstance(0)) + amRequest.setCapability(capability) + amRequest.setNumContainers(1) + val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String]) + method.invoke(amRequest, expr) + + val setResourceRequestMethod = + appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest]) + setResourceRequestMethod.invoke(appContext, amRequest) + } catch { + case e: NoSuchMethodException => + logWarning(s"Ignoring ${AM_NODE_LABEL_EXPRESSION.key} because the version " + + "of YARN does not support it") + appContext.setResource(capability) + } + case None => + appContext.setResource(capability) + } + appContext } @@ -254,15 +295,16 @@ private[spark] class Client( val maxMem = newAppResponse.getMaximumResourceCapability().getMemory() logInfo("Verifying our application has not requested more than the maximum " + s"memory capability of the cluster ($maxMem MB per container)") - val executorMem = args.executorMemory + executorMemoryOverhead + val executorMem = executorMemory + executorMemoryOverhead if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + + throw new IllegalArgumentException(s"Required executor memory ($executorMemory" + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + - "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") + "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + + "'yarn.nodemanager.resource.memory-mb'.") } - val amMem = args.amMemory + amMemoryOverhead + val amMem = amMemory + amMemoryOverhead if (amMem > maxMem) { - throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + + throw new IllegalArgumentException(s"Required AM memory ($amMemory" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } @@ -315,31 +357,23 @@ private[spark] class Client( // Upload Spark and the application JAR to the remote file system if necessary, // and add them as local resources to the application master. val fs = FileSystem.get(hadoopConf) - val dst = new Path(fs.getHomeDirectory(), appStagingDir) + val dst = getAppStagingDirPath(sparkConf, fs, appStagingDir) val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst YarnSparkHadoopUtil.get.obtainTokensForNamenodes(nns, hadoopConf, credentials) // Used to keep track of URIs added to the distributed cache. If the same URI is added // multiple times, YARN will fail to launch containers for the app with an internal // error. val distributedUris = new HashSet[String] - obtainTokenForHiveMetastore(sparkConf, hadoopConf, credentials) - obtainTokenForHBase(sparkConf, hadoopConf, credentials) + YarnSparkHadoopUtil.get.obtainTokenForHiveMetastore(sparkConf, hadoopConf, credentials) + YarnSparkHadoopUtil.get.obtainTokenForHBase(sparkConf, hadoopConf, credentials) - val replication = sparkConf.getInt("spark.yarn.submit.file.replication", - fs.getDefaultReplication(dst)).toShort + val replication = sparkConf.get(STAGING_FILE_REPLICATION).map(_.toShort) + .getOrElse(fs.getDefaultReplication(dst)) val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val oldLog4jConf = Option(System.getenv("SPARK_LOG4J_CONF")) - if (oldLog4jConf.isDefined) { - logWarning( - "SPARK_LOG4J_CONF detected in the system environment. This variable has been " + - "deprecated. Please refer to the \"Launching Spark on YARN\" documentation " + - "for alternatives.") - } - def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -399,32 +433,74 @@ private[spark] class Client( logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + " via the YARN Secure Distributed Cache.") val (_, localizedPath) = distribute(keytab, - destName = Some(sparkConf.get("spark.yarn.keytab")), + destName = sparkConf.get(KEYTAB), appMasterOnly = true) require(localizedPath != null, "Keytab file already distributed.") } /** - * Copy the given main resource to the distributed cache if the scheme is not "local". + * Add Spark to the cache. There are two settings that control what files to add to the cache: + * - if a Spark archive is defined, use the archive. The archive is expected to contain + * jar files at its root directory. + * - if a list of jars is provided, filter the non-local ones, resolve globs, and + * add the found files to the cache. + * + * Note that the archive cannot be a "local" URI. If none of the above settings are found, + * then upload all files found in $SPARK_HOME/jars. + */ + val sparkArchive = sparkConf.get(SPARK_ARCHIVE) + if (sparkArchive.isDefined) { + val archive = sparkArchive.get + require(!isLocalUri(archive), s"${SPARK_ARCHIVE.key} cannot be a local URI.") + distribute(Utils.resolveURI(archive).toString, + resType = LocalResourceType.ARCHIVE, + destName = Some(LOCALIZED_LIB_DIR)) + } else { + sparkConf.get(SPARK_JARS) match { + case Some(jars) => + // Break the list of jars to upload, and resolve globs. + val localJars = new ArrayBuffer[String]() + jars.foreach { jar => + if (!isLocalUri(jar)) { + val path = getQualifiedLocalPath(Utils.resolveURI(jar), hadoopConf) + val pathFs = FileSystem.get(path.toUri(), hadoopConf) + pathFs.globStatus(path).filter(_.isFile()).foreach { entry => + distribute(entry.getPath().toUri().toString(), + targetDir = Some(LOCALIZED_LIB_DIR)) + } + } else { + localJars += jar + } + } + + // Propagate the local URIs to the containers using the configuration. + sparkConf.set(SPARK_JARS, localJars) + + case None => + // No configuration, so fall back to uploading local jar files. + logWarning(s"Neither ${SPARK_JARS.key} nor ${SPARK_ARCHIVE.key} is set, falling back " + + "to uploading libraries under SPARK_HOME.") + val jarsDir = new File(YarnCommandBuilderUtils.findJarsDir( + sparkConf.getenv("SPARK_HOME"))) + jarsDir.listFiles().foreach { f => + if (f.isFile() && f.getName().toLowerCase().endsWith(".jar")) { + distribute(f.getAbsolutePath(), targetDir = Some(LOCALIZED_LIB_DIR)) + } + } + } + } + + /** + * Copy user jar to the distributed cache if their scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. - * Each resource is represented by a 3-tuple of: - * (1) destination resource name, - * (2) local path to the resource, - * (3) Spark property key to set if the scheme is not local */ - List( - (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), - (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), - ("log4j.properties", oldLog4jConf.orNull, null) - ).foreach { case (destName, path, confKey) => - if (path != null && !path.trim().isEmpty()) { - val (isLocal, localizedPath) = distribute(path, destName = Some(destName)) - if (isLocal && confKey != null) { - require(localizedPath != null, s"Path $path already distributed.") - // If the resource is intended for local use only, handle this downstream - // by setting the appropriate property - sparkConf.set(confKey, localizedPath) - } + Option(args.userJar).filter(_.trim.nonEmpty).foreach { jar => + val (isLocal, localizedPath) = distribute(jar, destName = Some(APP_JAR_NAME)) + if (isLocal) { + require(localizedPath != null, s"Path $jar already distributed") + // If the resource is intended for local use only, handle this downstream + // by setting the appropriate property + sparkConf.set(APP_JAR, localizedPath) } } @@ -437,22 +513,20 @@ private[spark] class Client( */ val cachedSecondaryJarLinks = ListBuffer.empty[String] List( - (args.addJars, LocalResourceType.FILE, true), - (args.files, LocalResourceType.FILE, false), - (args.archives, LocalResourceType.ARCHIVE, false) + (sparkConf.get(JARS_TO_DISTRIBUTE), LocalResourceType.FILE, true), + (sparkConf.get(FILES_TO_DISTRIBUTE), LocalResourceType.FILE, false), + (sparkConf.get(ARCHIVES_TO_DISTRIBUTE), LocalResourceType.ARCHIVE, false) ).foreach { case (flist, resType, addToClasspath) => - if (flist != null && !flist.isEmpty()) { - flist.split(',').foreach { file => - val (_, localizedPath) = distribute(file, resType = resType) - require(localizedPath != null) - if (addToClasspath) { - cachedSecondaryJarLinks += localizedPath - } + flist.foreach { file => + val (_, localizedPath) = distribute(file, resType = resType) + require(localizedPath != null) + if (addToClasspath) { + cachedSecondaryJarLinks += localizedPath } } } if (cachedSecondaryJarLinks.nonEmpty) { - sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) + sparkConf.set(SECONDARY_JARS, cachedSecondaryJarLinks) } if (isClusterMode && args.primaryPyFile != null) { @@ -463,16 +537,15 @@ private[spark] class Client( // The python files list needs to be treated especially. All files that are not an // archive need to be placed in a subdirectory that will be added to PYTHONPATH. - args.pyFiles.foreach { f => + sparkConf.get(PY_FILES).foreach { f => val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None distribute(f, targetDir = targetDir) } - // Distribute an archive with Hadoop and Spark configuration for the AM. + // Distribute an archive with Hadoop and Spark configuration for the AM and executors. val (_, confLocalizedPath) = distribute(createConfArchive().toURI().getPath(), resType = LocalResourceType.ARCHIVE, - destName = Some(LOCALIZED_CONF_DIR), - appMasterOnly = true) + destName = Some(LOCALIZED_CONF_DIR)) require(confLocalizedPath != null) localResources @@ -481,10 +554,10 @@ private[spark] class Client( /** * Create an archive with the config files for distribution. * - * These are only used by the AM, since executors will use the configuration object broadcast by - * the driver. The files are zipped and added to the job as an archive, so that YARN will explode - * it when distributing to the AM. This directory is then added to the classpath of the AM - * process, just to make sure that everybody is using the same default config. + * These will be used by AM and executors. The files are zipped and added to the job as an + * archive, so that YARN will explode it when distributing to AM and executors. This directory + * is then added to the classpath of AM and executor process, just to make sure that everybody + * is using the same default config. * * This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR * shows up in the classpath before YARN_CONF_DIR. @@ -503,20 +576,28 @@ private[spark] class Client( // required when user changes log4j.properties directly to set the log configurations. If // configuration file is provided through --files then executors will be taking configurations // from --files instead of $SPARK_CONF_DIR/log4j.properties. - val log4jFileName = "log4j.properties" - Option(Utils.getContextOrSparkClassLoader.getResource(log4jFileName)).foreach { url => - if (url.getProtocol == "file") { - hadoopConfFiles(log4jFileName) = new File(url.getPath) - } + + // Also uploading metrics.properties to distributed cache if exists in classpath. + // If user specify this file using --files then executors will use the one + // from --files instead. + for { prop <- Seq("log4j.properties", "metrics.properties") + url <- Option(Utils.getContextOrSparkClassLoader.getResource(prop)) + if url.getProtocol == "file" } { + hadoopConfFiles(prop) = new File(url.getPath) } Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) if (dir.isDirectory()) { - dir.listFiles().foreach { file => - if (file.isFile && !hadoopConfFiles.contains(file.getName())) { - hadoopConfFiles(file.getName()) = file + val files = dir.listFiles() + if (files == null) { + logWarning("Failed to list files under directory " + dir) + } else { + files.foreach { file => + if (file.isFile && !hadoopConfFiles.contains(file.getName())) { + hadoopConfFiles(file.getName()) = file + } } } } @@ -541,7 +622,7 @@ private[spark] class Client( val props = new Properties() sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) - val writer = new OutputStreamWriter(confStream, UTF_8) + val writer = new OutputStreamWriter(confStream, StandardCharsets.UTF_8) props.store(writer, "Spark configuration.") writer.flush() confStream.closeEntry() @@ -561,7 +642,7 @@ private[spark] class Client( val creds = new Credentials() val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath YarnSparkHadoopUtil.get.obtainTokensForNamenodes( - nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal"))) + nns, hadoopConf, creds, sparkConf.get(PRINCIPAL)) val t = creds.getAllTokens.asScala .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .head @@ -581,20 +662,18 @@ private[spark] class Client( pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() - val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - populateClasspath(args, yarnConf, sparkConf, env, true, extraCp) + populateClasspath(args, yarnConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() if (loginFromKeytab) { val remoteFs = FileSystem.get(hadoopConf) - val stagingDirPath = new Path(remoteFs.getHomeDirectory, stagingDir) + val stagingDirPath = getAppStagingDirPath(sparkConf, remoteFs, stagingDir) val credentialsFile = "credentials-" + UUID.randomUUID().toString - sparkConf.set( - "spark.yarn.credentials.file", new Path(stagingDirPath, credentialsFile).toString) + sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString) logInfo(s"Credentials file set to: $credentialsFile") val renewalInterval = getTokenRenewalInterval(stagingDirPath) - sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) + sparkConf.set(TOKEN_RENEWAL_INTERVAL, renewalInterval) } // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* @@ -617,7 +696,7 @@ private[spark] class Client( // // NOTE: the code currently does not handle .py files defined with a "local:" scheme. val pythonPath = new ListBuffer[String]() - val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py")) + val (pyFiles, pyArchives) = sparkConf.get(PY_FILES).partition(_.endsWith(".py")) if (pyFiles.nonEmpty) { pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_PYTHON_DIR) @@ -669,6 +748,9 @@ private[spark] class Client( } env("SPARK_JAVA_OPTS") = value } + // propagate PYSPARK_DRIVER_PYTHON and PYSPARK_PYTHON to driver in cluster mode + sys.env.get("PYSPARK_DRIVER_PYTHON").foreach(env("PYSPARK_DRIVER_PYTHON") = _) + sys.env.get("PYSPARK_PYTHON").foreach(env("PYSPARK_PYTHON") = _) } sys.env.get(ENV_DIST_CLASSPATH).foreach { dcp => @@ -688,7 +770,7 @@ private[spark] class Client( val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) val pySparkArchives = - if (sparkConf.getBoolean("spark.yarn.isPython", false)) { + if (sparkConf.get(IS_PYTHON_APP)) { findPySparkArchives() } else { Nil @@ -711,7 +793,7 @@ private[spark] class Client( var prefixEnv: Option[String] = None // Add Xmx for AM memory - javaOpts += "-Xmx" + args.amMemory + "m" + javaOpts += "-Xmx" + amMemory + "m" val tmpDir = new Path( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), @@ -741,36 +823,33 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { - val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") - .orElse(sys.env.get("SPARK_JAVA_OPTS")) + val driverOpts = sparkConf.get(DRIVER_JAVA_OPTIONS).orElse(sys.env.get("SPARK_JAVA_OPTS")) driverOpts.foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } - val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), + val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) } - if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) { - logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode") + if (sparkConf.get(AM_JAVA_OPTIONS).isDefined) { + logWarning(s"${AM_JAVA_OPTIONS.key} will not take effect in cluster mode") } } else { // Validate and include yarn am specific java options in yarn-client mode. - val amOptsKey = "spark.yarn.am.extraJavaOptions" - val amOpts = sparkConf.getOption(amOptsKey) - amOpts.foreach { opts => + sparkConf.get(AM_JAVA_OPTIONS).foreach { opts => if (opts.contains("-Dspark")) { - val msg = s"$amOptsKey is not allowed to set Spark options (was '$opts'). " + val msg = s"${AM_JAVA_OPTIONS.key} is not allowed to set Spark options (was '$opts')." throw new SparkException(msg) } - if (opts.contains("-Xmx") || opts.contains("-Xms")) { - val msg = s"$amOptsKey is not allowed to alter memory settings (was '$opts')." + if (opts.contains("-Xmx")) { + val msg = s"${AM_JAVA_OPTIONS.key} is not allowed to specify max heap memory settings " + + s"(was '$opts'). Use spark.yarn.am.memory instead." throw new SparkException(msg) } javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } - - sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => + sparkConf.get(AM_LIBRARY_PATH).foreach { paths => prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } } @@ -818,8 +897,6 @@ private[spark] class Client( val amArgs = Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++ Seq( - "--executor-memory", args.executorMemory.toString + "m", - "--executor-cores", args.executorCores.toString, "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) @@ -858,17 +935,10 @@ private[spark] class Client( } def setupCredentials(): Unit = { - loginFromKeytab = args.principal != null || sparkConf.contains("spark.yarn.principal") + loginFromKeytab = sparkConf.contains(PRINCIPAL.key) if (loginFromKeytab) { - principal = - if (args.principal != null) args.principal else sparkConf.get("spark.yarn.principal") - keytab = { - if (args.keytab != null) { - args.keytab - } else { - sparkConf.getOption("spark.yarn.keytab").orNull - } - } + principal = sparkConf.get(PRINCIPAL).get + keytab = sparkConf.get(KEYTAB).orNull require(keytab != null, "Keytab must be specified when principal is specified.") logInfo("Attempting to login to the Kerberos" + @@ -877,8 +947,8 @@ private[spark] class Client( // Generate a file name that can be used for the keytab file, that does not conflict // with any user file. val keytabFileName = f.getName + "-" + UUID.randomUUID().toString - sparkConf.set("spark.yarn.keytab", keytabFileName) - sparkConf.set("spark.yarn.principal", principal) + sparkConf.set(KEYTAB.key, keytabFileName) + sparkConf.set(PRINCIPAL.key, principal) } credentials = UserGroupInformation.getCurrentUser.getCredentials } @@ -898,7 +968,7 @@ private[spark] class Client( appId: ApplicationId, returnOnRunning: Boolean = false, logApplicationReport: Boolean = true): (YarnApplicationState, FinalApplicationStatus) = { - val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) + val interval = sparkConf.get(REPORT_INTERVAL) var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) @@ -1021,16 +1091,16 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), "pyspark.zip not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.9-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.9.2-src.zip") require(py4jFile.exists(), - "py4j-0.9-src.zip not found; cannot run pyspark application in YARN mode.") + "py4j-0.9.2-src.zip not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) } } } -object Client extends Logging { +private object Client extends Logging { def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { @@ -1043,17 +1113,12 @@ object Client extends Logging { System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf - val args = new ClientArguments(argStrings, sparkConf) - // to maintain backwards-compatibility - if (!Utils.isDynamicAllocationEnabled(sparkConf)) { - sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString) - } + val args = new ClientArguments(argStrings) new Client(args, sparkConf).run() } - // Alias for the Spark assembly jar and the user jar - val SPARK_JAR: String = "__spark__.jar" - val APP_JAR: String = "__app__.jar" + // Alias for the user jar + val APP_JAR_NAME: String = "__app__.jar" // URI scheme that identifies local resources val LOCAL_SCHEME = "local" @@ -1061,20 +1126,6 @@ object Client extends Logging { // Staging directory for any temporary jars or files val SPARK_STAGING: String = ".sparkStaging" - // Location of any user-defined Spark jars - val CONF_SPARK_JAR = "spark.yarn.jar" - val ENV_SPARK_JAR = "SPARK_JAR" - - // Internal config to propagate the location of the user's jar to the driver/executors - val CONF_SPARK_USER_JAR = "spark.yarn.user.jar" - - // Internal config to propagate the locations of any extra jars to add to the classpath - // of the executors - val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" - - // Comma-separated list of strings to pass through as YARN application tags appearing - // in YARN ApplicationReports, which can be used for filtering when querying YARN. - val CONF_SPARK_YARN_APPLICATION_TAGS = "spark.yarn.tags" // Staging directory is private! -> rwx-------- val STAGING_DIR_PERMISSION: FsPermission = @@ -1096,28 +1147,8 @@ object Client extends Logging { // Subdirectory where the user's python files (not archives) will be placed. val LOCALIZED_PYTHON_DIR = "__pyfiles__" - /** - * Find the user-defined Spark jar if configured, or return the jar containing this - * class if not. - * - * This method first looks in the SparkConf object for the CONF_SPARK_JAR key, and in the - * user environment if that is not found (for backwards compatibility). - */ - private def sparkJar(conf: SparkConf): String = { - if (conf.contains(CONF_SPARK_JAR)) { - conf.get(CONF_SPARK_JAR) - } else if (System.getenv(ENV_SPARK_JAR) != null) { - logWarning( - s"$ENV_SPARK_JAR detected in the system environment. This variable has been deprecated " + - s"in favor of the $CONF_SPARK_JAR configuration variable.") - System.getenv(ENV_SPARK_JAR) - } else { - SparkContext.jarOfClass(this.getClass).getOrElse(throw new SparkException("Could not " - + "find jar containing Spark classes. The jar can be defined using the " - + "spark.yarn.jar configuration option. If testing Spark, either set that option or " - + "make sure SPARK_PREPEND_CLASSES is not set.")) - } - } + // Subdirectory where Spark libraries will be placed. + val LOCALIZED_LIB_DIR = "__spark_libs__" /** * Return the path to the given application's staging directory. @@ -1202,20 +1233,18 @@ object Client extends Logging { conf: Configuration, sparkConf: SparkConf, env: HashMap[String, String], - isAM: Boolean, extraClassPath: Option[String] = None): Unit = { extraClassPath.foreach { cp => addClasspathEntry(getClusterPath(sparkConf, cp), env) } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) - if (isAM) { - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + - LOCALIZED_CONF_DIR, env) - } + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + + LOCALIZED_CONF_DIR, env) - if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { + if (sparkConf.get(USER_CLASS_PATH_FIRST)) { // in order to properly add the app jar when user classpath is first // we have to do the mainJar separate in order to send the right thing // into addFileToClasspath @@ -1223,21 +1252,32 @@ object Client extends Logging { if (args != null) { getMainJarUri(Option(args.userJar)) } else { - getMainJarUri(sparkConf.getOption(CONF_SPARK_USER_JAR)) + getMainJarUri(sparkConf.get(APP_JAR)) } - mainJar.foreach(addFileToClasspath(sparkConf, conf, _, APP_JAR, env)) + mainJar.foreach(addFileToClasspath(sparkConf, conf, _, APP_JAR_NAME, env)) val secondaryJars = if (args != null) { - getSecondaryJarUris(Option(args.addJars)) + getSecondaryJarUris(Option(sparkConf.get(JARS_TO_DISTRIBUTE))) } else { - getSecondaryJarUris(sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + getSecondaryJarUris(sparkConf.get(SECONDARY_JARS)) } secondaryJars.foreach { x => addFileToClasspath(sparkConf, conf, x, null, env) } } - addFileToClasspath(sparkConf, conf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) + + // Add the Spark jars to the classpath, depending on how they were distributed. + addClasspathEntry(buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_LIB_DIR, "*"), env) + if (!sparkConf.get(SPARK_ARCHIVE).isDefined) { + sparkConf.get(SPARK_JARS).foreach { jars => + jars.filter(isLocalUri).foreach { jar => + addClasspathEntry(getClusterPath(sparkConf, jar), env) + } + } + } + populateHadoopClasspath(conf, env) sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => addClasspathEntry(getClusterPath(sparkConf, cp), env) @@ -1250,8 +1290,8 @@ object Client extends Logging { * @param conf Spark configuration. */ def getUserClasspath(conf: SparkConf): Array[URI] = { - val mainUri = getMainJarUri(conf.getOption(CONF_SPARK_USER_JAR)) - val secondaryUris = getSecondaryJarUris(conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + val mainUri = getMainJarUri(conf.get(APP_JAR)) + val secondaryUris = getSecondaryJarUris(conf.get(SECONDARY_JARS)) (mainUri ++ secondaryUris).toArray } @@ -1259,11 +1299,11 @@ object Client extends Logging { mainJar.flatMap { path => val uri = Utils.resolveURI(path) if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None - }.orElse(Some(new URI(APP_JAR))) + }.orElse(Some(new URI(APP_JAR_NAME))) } - private def getSecondaryJarUris(secondaryJars: Option[String]): Seq[URI] = { - secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_)) + private def getSecondaryJarUris(secondaryJars: Option[Seq[String]]): Seq[URI] = { + secondaryJars.getOrElse(Nil).map(new URI(_)) } /** @@ -1311,17 +1351,17 @@ object Client extends Logging { * * This method uses two configuration values: * - * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may - * only be valid in the gateway node. - * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may - * contain, for example, env variable references, which will be expanded by the NMs when - * starting containers. + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. * * If either config is not available, the input path is returned. */ def getClusterPath(conf: SparkConf, path: String): String = { - val localPath = conf.get("spark.yarn.config.gatewayPath", null) - val clusterPath = conf.get("spark.yarn.config.replacementPath", null) + val localPath = conf.get(GATEWAY_ROOT_PATH) + val clusterPath = conf.get(REPLACEMENT_ROOT_PATH) if (localPath != null && clusterPath != null) { path.replace(localPath, clusterPath) } else { @@ -1329,59 +1369,6 @@ object Client extends Logging { } } - /** - * Obtains token for the Hive metastore and adds them to the credentials. - */ - private def obtainTokenForHiveMetastore( - sparkConf: SparkConf, - conf: Configuration, - credentials: Credentials) { - if (shouldGetTokens(sparkConf, "hive") && UserGroupInformation.isSecurityEnabled) { - YarnSparkHadoopUtil.get.obtainTokenForHiveMetastore(conf).foreach { - credentials.addToken(new Text("hive.server2.delegation.token"), _) - } - } - } - - /** - * Obtain security token for HBase. - */ - def obtainTokenForHBase( - sparkConf: SparkConf, - conf: Configuration, - credentials: Credentials): Unit = { - if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { - val mirror = universe.runtimeMirror(getClass.getClassLoader) - - try { - val confCreate = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). - getMethod("create", classOf[Configuration]) - val obtainToken = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). - getMethod("obtainToken", classOf[Configuration]) - - logDebug("Attempting to fetch HBase security token.") - - val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] - if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { - val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] - credentials.addToken(token.getService, token) - logInfo("Added HBase security token to credentials.") - } - } catch { - case e: java.lang.NoSuchMethodException => - logInfo("HBase Method not found: " + e) - case e: java.lang.ClassNotFoundException => - logDebug("HBase Class not found: " + e) - case e: java.lang.NoClassDefFoundError => - logDebug("HBase Class not found: " + e) - case e: Exception => - logError("Exception when obtaining HBase security token: " + e) - } - } - } - /** * Return whether the two file systems are the same. */ @@ -1433,9 +1420,9 @@ object Client extends Logging { */ def isUserClassPathFirst(conf: SparkConf, isDriver: Boolean): Boolean = { if (isDriver) { - conf.getBoolean("spark.driver.userClassPathFirst", false) + conf.get(DRIVER_USER_CLASS_PATH_FIRST) } else { - conf.getBoolean("spark.executor.userClassPathFirst", false) + conf.get(EXECUTOR_USER_CLASS_PATH_FIRST) } } @@ -1446,13 +1433,21 @@ object Client extends Logging { components.mkString(Path.SEPARATOR) } + /** Returns whether the URI is a "local:" URI. */ + def isLocalUri(uri: String): Boolean = { + uri.startsWith(s"$LOCAL_SCHEME:") + } + /** - * Return whether delegation tokens should be retrieved for the given service when security is - * enabled. By default, tokens are retrieved, but that behavior can be changed by setting - * a service-specific configuration. + * Returns the app staging dir based on the STAGING_DIR configuration if configured + * otherwise based on the users home directory. */ - def shouldGetTokens(conf: SparkConf, service: String): Boolean = { - conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true) + private def getAppStagingDirPath( + conf: SparkConf, + fs: FileSystem, + appStagingDir: String): Path = { + val baseDir = conf.get(STAGING_DIR).map { new Path(_) }.getOrElse(fs.getHomeDirectory()) + new Path(baseDir, appStagingDir) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 1165061db21e3..61c027ec4483a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -19,123 +19,20 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.util.{IntParam, MemoryParam, Utils} - // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! -private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { - var addJars: String = null - var files: String = null - var archives: String = null +private[spark] class ClientArguments(args: Array[String]) { + var userJar: String = null var userClass: String = null - var pyFiles: Seq[String] = Nil var primaryPyFile: String = null var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() - var executorMemory = 1024 // MB - var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS - var amQueue = sparkConf.get("spark.yarn.queue", "default") - var amMemory: Int = 512 // MB - var amCores: Int = 1 - var appName: String = "Spark" - var priority = 0 - var principal: String = null - var keytab: String = null - def isClusterMode: Boolean = userClass != null - - private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB - private var driverCores: Int = 1 - private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead" - private val amMemKey = "spark.yarn.am.memory" - private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" - private val driverCoresKey = "spark.driver.cores" - private val amCoresKey = "spark.yarn.am.cores" - private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) - loadEnvironmentArgs() - validateArgs() - - // Additional memory to allocate to containers - val amMemoryOverheadConf = if (isClusterMode) driverMemOverheadKey else amMemOverheadKey - val amMemoryOverhead = sparkConf.getInt(amMemoryOverheadConf, - math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN)) - - val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead", - math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) - - /** Load any default arguments provided through environment variables and Spark properties. */ - private def loadEnvironmentArgs(): Unit = { - // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://, - // while spark.yarn.dist.{archives/files} should be resolved to file:// (SPARK-2051). - files = Option(files) - .orElse(sparkConf.getOption("spark.yarn.dist.files").map(p => Utils.resolveURIs(p))) - .orElse(sys.env.get("SPARK_YARN_DIST_FILES")) - .orNull - archives = Option(archives) - .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p))) - .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) - .orNull - // If dynamic allocation is enabled, start at the configured initial number of executors. - // Default to minExecutors if no initialExecutors is set. - numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) - principal = Option(principal) - .orElse(sparkConf.getOption("spark.yarn.principal")) - .orNull - keytab = Option(keytab) - .orElse(sparkConf.getOption("spark.yarn.keytab")) - .orNull - } - - /** - * Fail fast if any arguments provided are invalid. - * This is intended to be called only after the provided arguments have been parsed. - */ - private def validateArgs(): Unit = { - if (numExecutors < 0 || (!isDynamicAllocationEnabled && numExecutors == 0)) { - throw new IllegalArgumentException( - s""" - |Number of executors was $numExecutors, but must be at least 1 - |(or 0 if dynamic executor allocation is enabled). - |${getUsageMessage()} - """.stripMargin) - } - if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) { - throw new SparkException("Executor cores must not be less than " + - "spark.task.cpus.") - } - // scalastyle:off println - if (isClusterMode) { - for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) { - if (sparkConf.contains(key)) { - println(s"$key is set but does not apply in cluster mode.") - } - } - amMemory = driverMemory - amCores = driverCores - } else { - for (key <- Seq(driverMemOverheadKey, driverCoresKey)) { - if (sparkConf.contains(key)) { - println(s"$key is set but does not apply in client mode.") - } - } - sparkConf.getOption(amMemKey) - .map(Utils.memoryStringToMb) - .foreach { mem => amMemory = mem } - sparkConf.getOption(amCoresKey) - .map(_.toInt) - .foreach { cores => amCores = cores } - } - // scalastyle:on println - } private def parseArgs(inputArgs: List[String]): Unit = { var args = inputArgs - // scalastyle:off println while (!args.isEmpty) { args match { case ("--jar") :: value :: tail => @@ -154,88 +51,16 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) primaryRFile = value args = tail - case ("--args" | "--arg") :: value :: tail => - if (args(0) == "--args") { - println("--args is deprecated. Use --arg instead.") - } + case ("--arg") :: value :: tail => userArgs += value args = tail - case ("--master-class" | "--am-class") :: value :: tail => - println(s"${args(0)} is deprecated and is not used anymore.") - args = tail - - case ("--master-memory" | "--driver-memory") :: MemoryParam(value) :: tail => - if (args(0) == "--master-memory") { - println("--master-memory is deprecated. Use --driver-memory instead.") - } - driverMemory = value - args = tail - - case ("--driver-cores") :: IntParam(value) :: tail => - driverCores = value - args = tail - - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - if (args(0) == "--num-workers") { - println("--num-workers is deprecated. Use --num-executors instead.") - } - numExecutors = value - args = tail - - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => - if (args(0) == "--worker-memory") { - println("--worker-memory is deprecated. Use --executor-memory instead.") - } - executorMemory = value - args = tail - - case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail => - if (args(0) == "--worker-cores") { - println("--worker-cores is deprecated. Use --executor-cores instead.") - } - executorCores = value - args = tail - - case ("--queue") :: value :: tail => - amQueue = value - args = tail - - case ("--name") :: value :: tail => - appName = value - args = tail - - case ("--addJars") :: value :: tail => - addJars = value - args = tail - - case ("--py-files") :: value :: tail => - pyFiles = value.split(",") - args = tail - - case ("--files") :: value :: tail => - files = value - args = tail - - case ("--archives") :: value :: tail => - archives = value - args = tail - - case ("--principal") :: value :: tail => - principal = value - args = tail - - case ("--keytab") :: value :: tail => - keytab = value - args = tail - case Nil => case _ => throw new IllegalArgumentException(getUsageMessage(args)) } } - // scalastyle:on println if (primaryPyFile != null && primaryRFile != null) { throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + @@ -245,7 +70,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" - val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB message + s""" |Usage: org.apache.spark.deploy.yarn.Client [options] @@ -257,20 +81,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --num-executors NUM Number of executors to start (Default: 2) - | --executor-cores NUM Number of cores per executor (Default: 1). - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: $mem_mb Mb) - | --driver-cores NUM Number of cores used by the driver (Default: 1). - | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) - | --name NAME The name of your application (Default: Spark) - | --queue QUEUE The hadoop queue to use for allocation requests (Default: - | 'default') - | --addJars jars Comma separated list of local jars that want SparkContext.addJar - | to work with. - | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to - | place on the PYTHONPATH for Python apps. - | --files files Comma separated list of files to be distributed with the job. - | --archives archives Comma separated list of archives to be distributed with the job. """.stripMargin } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 3d3a966960e9f..869edf6c5b6af 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -25,9 +25,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.util.{Records, ConverterUtils} +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.Logging +import org.apache.spark.internal.Logging /** Client side methods to setup the Hadoop distributed cache */ private[spark] class ClientDistributedCacheManager() extends Logging { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala index 94feb6393fd69..3aa64071d478f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -18,23 +18,25 @@ package org.apache.spark.deploy.yarn import java.util.concurrent.{Executors, TimeUnit} +import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging import org.apache.spark.util.{ThreadUtils, Utils} -import scala.util.control.NonFatal - private[spark] class ExecutorDelegationTokenUpdater( sparkConf: SparkConf, hadoopConf: Configuration) extends Logging { @volatile private var lastCredentialsFileSuffix = 0 - private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) private val freshHadoopConf = SparkHadoopUtil.get.getConfBypassingFSCache( hadoopConf, new Path(credentialsFile).toUri.getScheme) @@ -76,7 +78,10 @@ private[spark] class ExecutorDelegationTokenUpdater( SparkHadoopUtil.get.getTimeFromNowToRenewal( sparkConf, 0.8, UserGroupInformation.getCurrentUser.getCredentials) if (timeFromNowToRenewal <= 0) { - executorUpdaterRunnable.run() + // We just checked for new credentials but none were there, wait a minute and retry. + // This handles the shutdown case where the staging directory may have been removed(see + // SPARK-12316 for more details). + delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.MINUTES) } else { logInfo(s"Scheduling token refresh from HDFS in $timeFromNowToRenewal millis.") delegationTokenRenewer.schedule( diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 2232ffba473b5..ef7908a3ef2ac 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -25,24 +25,27 @@ import java.util.Collections import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.NMClient import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils -class ExecutorRunnable( +private[yarn] class ExecutorRunnable( container: Container, conf: Configuration, sparkConf: SparkConf, @@ -104,7 +107,7 @@ class ExecutorRunnable( // If external shuffle service is enabled, register with the Yarn shuffle service already // started on the NodeManager and, if authentication is enabled, provide it with our secret // key for fetching shuffle files later - if (sparkConf.getBoolean("spark.shuffle.service.enabled", false)) { + if (sparkConf.get(SHUFFLE_SERVICE_ENABLED)) { val secretString = securityMgr.getSecretKey() val secretBytes = if (secretString != null) { @@ -144,17 +147,16 @@ class ExecutorRunnable( // Set the JVM memory val executorMemoryString = executorMemory + "m" - javaOpts += "-Xms" + executorMemoryString javaOpts += "-Xmx" + executorMemoryString // Set extra Java options for the executor, if defined - sys.props.get("spark.executor.extraJavaOptions").foreach { opts => + sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.env.get("SPARK_JAVA_OPTS").foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } - sys.props.get("spark.executor.extraLibraryPath").foreach { p => + sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } @@ -166,7 +168,7 @@ class ExecutorRunnable( // Certain configs need to be passed here because they are needed before the Executor // registers with the Scheduler and transfers the spark configs. Since the Executor backend - // uses Akka to connect to the scheduler, the akka settings are needed as well as the + // uses RPC to connect to the scheduler, the RPC settings are needed as well as the // authentication settings. sparkConf.getAll .filter { case (k, v) => SparkConf.isExecutorStartupConf(k) } @@ -184,9 +186,9 @@ class ExecutorRunnable( else { // If no java_opts specified, default to using -XX:+CMSIncrementalMode // It might be possible that other modes/config is being done in - // spark.executor.extraJavaOptions, so we dont want to mess with it. - // In our expts, using (default) throughput collector has severe perf ramnifications in - // multi-tennent machines + // spark.executor.extraJavaOptions, so we don't want to mess with it. + // In our expts, using (default) throughput collector has severe perf ramifications in + // multi-tenant machines // The options are based on // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use // %20the%20Concurrent%20Low%20Pause%20Collector|outline @@ -286,8 +288,7 @@ class ExecutorRunnable( private def prepareEnvironment(container: Container): HashMap[String, String] = { val env = new HashMap[String, String]() - val extraCp = sparkConf.getOption("spark.executor.extraClassPath") - Client.populateClasspath(null, yarnConf, sparkConf, env, false, extraCp) + Client.populateClasspath(null, yarnConf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH)) sparkConf.getExecutorEnv.foreach { case (key, value) => // This assumes each executor environment variable set here is a path diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala index 2ec189de7c914..8772e26f4314d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.RackResolver import org.apache.spark.SparkConf +import org.apache.spark.internal.config._ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], racks: Array[String]) @@ -84,9 +85,6 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( val yarnConf: Configuration, val resource: Resource) { - // Number of CPUs per task - private val CPUS_PER_TASK = sparkConf.getInt("spark.task.cpus", 1) - /** * Calculate each container's node locality and rack locality * @param numContainer number of containers to calculate @@ -159,7 +157,7 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( */ private def numExecutorsPending(numTasksPending: Int): Int = { val coresPerExecutor = resource.getVirtualCores - (numTasksPending * CPUS_PER_TASK + coresPerExecutor - 1) / coresPerExecutor + (numTasksPending * sparkConf.get(CPUS_PER_TASK) + coresPerExecutor - 1) / coresPerExecutor } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4d9e777cb4134..23742eab6268c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -25,22 +25,23 @@ import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.collection.JavaConverters._ -import com.google.common.util.concurrent.ThreadFactoryBuilder - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.RackResolver - import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor -import org.apache.spark.util.Utils +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId +import org.apache.spark.util.ThreadUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -62,7 +63,6 @@ private[yarn] class YarnAllocator( sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, - args: ApplicationMasterArguments, securityMgr: SecurityManager) extends Logging { @@ -84,8 +84,23 @@ private[yarn] class YarnAllocator( new ConcurrentHashMap[ContainerId, java.lang.Boolean]) @volatile private var numExecutorsRunning = 0 - // Used to generate a unique ID per executor - private var executorIdCounter = 0 + + /** + * Used to generate a unique ID per executor + * + * Init `executorIdCounter`. when AM restart, `executorIdCounter` will reset to 0. Then + * the id of new executor will start from 1, this will conflict with the executor has + * already created before. So, we should initialize the `executorIdCounter` by getting + * the max executorId from driver. + * + * And this situation of executorId conflict is just in yarn client mode, so this is an issue + * in yarn client mode. For more details, can check in jira. + * + * @see SPARK-12864 + */ + private var executorIdCounter: Int = + driverRef.askWithRetry[Int](RetrieveLastAllocatedExecutorId) + @volatile private var numExecutorsFailed = 0 @volatile private var targetNumExecutors = @@ -96,6 +111,10 @@ private[yarn] class YarnAllocator( // was lost. private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]] + // Maintain loss reasons for already released executors, it will be added when executor loss + // reason is got from AM-RM call, and be removed after querying this loss reason. + private val releasedExecutorLossReasons = new HashMap[String, ExecutorLossReason] + // Keep track of which container is running which executor to remove the executors later // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] @@ -104,27 +123,22 @@ private[yarn] class YarnAllocator( private val containerIdToExecutorId = new HashMap[ContainerId, String] // Executor memory in MB. - protected val executorMemory = args.executorMemory + protected val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt // Additional memory overhead. - protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", - math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) + protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( + math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt // Number of cores per executor. - protected val executorCores = args.executorCores + protected val executorCores = sparkConf.get(EXECUTOR_CORES) // Resource capability requested for each executors private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) - private val launcherPool = new ThreadPoolExecutor( - // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue - sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE, - 1, TimeUnit.MINUTES, - new LinkedBlockingQueue[Runnable](), - new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) - launcherPool.allowCoreThreadTimeOut(true) + private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( + "ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS)) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) - private val labelExpression = sparkConf.getOption("spark.yarn.executor.nodeLabelExpression") + private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION) // ContainerRequest constructor that can take a node label expression. We grab it through // reflection because it's only available in later versions of YARN. @@ -134,11 +148,10 @@ private[yarn] class YarnAllocator( classOf[Array[String]], classOf[Array[String]], classOf[Priority], classOf[Boolean], classOf[String])) } catch { - case e: NoSuchMethodException => { + case e: NoSuchMethodException => logWarning(s"Node label expression $expr will be ignored because YARN version on" + " classpath does not support it.") None - } } } @@ -202,8 +215,7 @@ private[yarn] class YarnAllocator( */ def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.remove(executorId).get - containerIdToExecutorId.remove(container.getId) + val container = executorIdToContainer.get(executorId).get internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -269,25 +281,52 @@ private[yarn] class YarnAllocator( // For locality unmatched and locality free container requests, cancel these container // requests, since required locality preference has been changed, recalculating using // container placement strategy. - val (localityMatched, localityUnMatched, localityFree) = splitPendingAllocationsByLocality( + val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality( hostToLocalTaskCounts, pendingAllocate) - // Remove the outdated container request and recalculate the requested container number - localityUnMatched.foreach(amClient.removeContainerRequest) - localityFree.foreach(amClient.removeContainerRequest) - val updatedNumContainer = missing + localityUnMatched.size + localityFree.size + // cancel "stale" requests for locations that are no longer needed + staleRequests.foreach { stale => + amClient.removeContainerRequest(stale) + } + val cancelledContainers = staleRequests.size + logInfo(s"Canceled $cancelledContainers container requests (locality no longer needed)") + + // consider the number of new containers and cancelled stale containers available + val availableContainers = missing + cancelledContainers + + // to maximize locality, include requests with no locality preference that can be cancelled + val potentialContainers = availableContainers + anyHostRequests.size val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers( - updatedNumContainer, numLocalityAwareTasks, hostToLocalTaskCounts, - allocatedHostToContainersMap, localityMatched) + potentialContainers, numLocalityAwareTasks, hostToLocalTaskCounts, + allocatedHostToContainersMap, localRequests) + + val newLocalityRequests = new mutable.ArrayBuffer[ContainerRequest] + containerLocalityPreferences.foreach { + case ContainerLocalityPreferences(nodes, racks) if nodes != null => + newLocalityRequests.append(createContainerRequest(resource, nodes, racks)) + case _ => + } + + if (availableContainers >= newLocalityRequests.size) { + // more containers are available than needed for locality, fill in requests for any host + for (i <- 0 until (availableContainers - newLocalityRequests.size)) { + newLocalityRequests.append(createContainerRequest(resource, null, null)) + } + } else { + val numToCancel = newLocalityRequests.size - availableContainers + // cancel some requests without locality preferences to schedule more local containers + anyHostRequests.slice(0, numToCancel).foreach { nonLocal => + amClient.removeContainerRequest(nonLocal) + } + logInfo(s"Canceled $numToCancel container requests for any host to resubmit with locality") + } - for (locality <- containerLocalityPreferences) { - val request = createContainerRequest(resource, locality.nodes, locality.racks) + newLocalityRequests.foreach { request => amClient.addContainerRequest(request) - val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.asScala.last - logInfo(s"Container request (host: $hostStr, capability: $resource)") + logInfo(s"Submitted container request (host: ${hostStr(request)}, capability: $resource)") } + } else if (missing < 0) { val numToCancel = math.min(numPendingAllocate, -missing) logInfo(s"Canceling requests for $numToCancel executor containers") @@ -302,6 +341,13 @@ private[yarn] class YarnAllocator( } } + private def hostStr(request: ContainerRequest): String = { + Option(request.getNodes) match { + case Some(nodes) => nodes.asScala.mkString(",") + case None => "Any" + } + } + /** * Creates a container request, handling the reflection required to use YARN features that were * added in recent versions. @@ -478,7 +524,7 @@ private[yarn] class YarnAllocator( (true, memLimitExceededLogMessage( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) - case unknown => + case _ => numExecutorsFailed += 1 (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + @@ -490,7 +536,7 @@ private[yarn] class YarnAllocator( } else { logInfo(containerExitReason) } - ExecutorExited(0, exitCausedByApp, containerExitReason) + ExecutorExited(exitStatus, exitCausedByApp, containerExitReason) } else { // If we have already released this container, then it must mean // that the driver has explicitly requested it to be killed @@ -514,9 +560,18 @@ private[yarn] class YarnAllocator( containerIdToExecutorId.remove(containerId).foreach { eid => executorIdToContainer.remove(eid) - pendingLossReasonRequests.remove(eid).foreach { pendingRequests => - // Notify application of executor loss reasons so it can decide whether it should abort - pendingRequests.foreach(_.reply(exitReason)) + pendingLossReasonRequests.remove(eid) match { + case Some(pendingRequests) => + // Notify application of executor loss reasons so it can decide whether it should abort + pendingRequests.foreach(_.reply(exitReason)) + + case None => + // We cannot find executor for pending reasons. This is because completed container + // is processed before querying pending result. We should store it for later query. + // This is usually happened when explicitly killing a container, the result will be + // returned in one AM-RM communication. So query RPC will be later than this completed + // container process. + releasedExecutorLossReasons.put(eid, exitReason) } if (!alreadyReleased) { // The executor could have gone away (like no route to host, node failure, etc) @@ -538,8 +593,14 @@ private[yarn] class YarnAllocator( if (executorIdToContainer.contains(eid)) { pendingLossReasonRequests .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context + } else if (releasedExecutorLossReasons.contains(eid)) { + // Executor is already released explicitly before getting the loss reason, so directly send + // the pre-stored lost reason + context.reply(releasedExecutorLossReasons.remove(eid).get) } else { logWarning(s"Tried to get the loss reason for non-existent executor $eid") + context.sendFailure( + new SparkException(s"Fail to find loss reason for non-existent executor $eid")) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index d2a211f6711ff..e7f75446641cb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} import scala.collection.JavaConverters._ -import scala.collection.{Map, Set} +import scala.collection.Map import scala.util.Try import org.apache.hadoop.conf.Configuration @@ -30,15 +30,16 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.webapp.util.WebAppUtils -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. */ -private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logging { +private[spark] class YarnRMClient extends Logging { private var amClient: AMRMClient[ContainerRequest] = _ private var uiHistoryAddress: String = _ @@ -71,8 +72,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args, - securityMgr) + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr) } /** @@ -118,7 +118,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg /** Returns the maximum number of attempts to register the AM. */ def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = { - val sparkMaxAttempts = sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt) + val sparkMaxAttempts = sparkConf.get(MAX_APP_ATTEMPTS).map(_.toInt) val yarnMaxAttempts = yarnConf.getInt( YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) val retval: Int = sparkMaxAttempts match { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 561ad79ee0228..4b36da309dbd1 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,7 +18,9 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.lang.reflect.UndeclaredThrowableException import java.nio.charset.StandardCharsets.UTF_8 +import java.security.PrivilegedExceptionAction import java.util.regex.Matcher import java.util.regex.Pattern @@ -30,19 +32,21 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.{Master, JobConf} +import org.apache.hadoop.mapred.{JobConf, Master} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.security.token.Token -import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} +import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.ConverterUtils +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils -import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils /** @@ -61,7 +65,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { override def isYarnMode(): Boolean = { true } // Return an appropriate (subclass) of Configuration. Creating a config initializes some Hadoop - // subsystems. Always create a new config, dont reuse yarnConf. + // subsystems. Always create a new config, don't reuse yarnConf. override def newConfiguration(conf: SparkConf): Configuration = new YarnConfiguration(super.newConfiguration(conf)) @@ -95,10 +99,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { * Get the list of namenodes the user may access. */ def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { - sparkConf.get("spark.yarn.access.namenodes", "") - .split(",") - .map(_.trim()) - .filter(!_.isEmpty) + sparkConf.get(NAMENODES_TO_ACCESS) .map(new Path(_)) .toSet } @@ -133,6 +134,44 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } } + /** + * Obtains token for the Hive metastore and adds them to the credentials. + */ + def obtainTokenForHiveMetastore( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials) { + if (shouldGetTokens(sparkConf, "hive") && UserGroupInformation.isSecurityEnabled) { + YarnSparkHadoopUtil.get.obtainTokenForHiveMetastore(conf).foreach { + credentials.addToken(new Text("hive.server2.delegation.token"), _) + } + } + } + + /** + * Obtain a security token for HBase. + */ + def obtainTokenForHBase( + sparkConf: SparkConf, + conf: Configuration, + credentials: Credentials): Unit = { + if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { + YarnSparkHadoopUtil.get.obtainTokenForHBase(conf).foreach { token => + credentials.addToken(token.getService, token) + logInfo("Added HBase security token to credentials.") + } + } + } + + /** + * Return whether delegation tokens should be retrieved for the given service when security is + * enabled. By default, tokens are retrieved, but that behavior can be changed by setting + * a service-specific configuration. + */ + private def shouldGetTokens(conf: SparkConf, service: String): Boolean = { + conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true) + } + private[spark] override def startExecutorDelegationTokenRenewer(sparkConf: SparkConf): Unit = { tokenRenewer = Some(new ExecutorDelegationTokenUpdater(sparkConf, conf)) tokenRenewer.get.updateCredentialsIfRequired() @@ -156,7 +195,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { */ def obtainTokenForHiveMetastore(conf: Configuration): Option[Token[DelegationTokenIdentifier]] = { try { - obtainTokenForHiveMetastoreInner(conf, UserGroupInformation.getCurrentUser().getUserName) + obtainTokenForHiveMetastoreInner(conf) } catch { case e: ClassNotFoundException => logInfo(s"Hive class not found $e") @@ -171,14 +210,14 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { * @param username the username of the principal requesting the delegating token. * @return a delegation token */ - private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration, - username: String): Option[Token[DelegationTokenIdentifier]] = { + private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration): + Option[Token[DelegationTokenIdentifier]] = { val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down // to a Configuration and used without reflection val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - // using the (Configuration, Class) constructor allows the current configuratin to be included + // using the (Configuration, Class) constructor allows the current configuration to be included // in the hive config. val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration], classOf[Object].getClass) @@ -187,11 +226,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { // Check for local metastore if (metastoreUri.nonEmpty) { - require(username.nonEmpty, "Username undefined") val principalKey = "hive.metastore.kerberos.principal" val principal = hiveConf.getTrimmed(principalKey, "") require(principal.nonEmpty, "Hive principal $principalKey undefined") - logDebug(s"Getting Hive delegation token for $username against $principal at $metastoreUri") + val currentUser = UserGroupInformation.getCurrentUser() + logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + + s"$principal at $metastoreUri") val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") val closeCurrent = hiveClass.getMethod("closeCurrent") try { @@ -200,12 +240,14 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { classOf[String], classOf[String]) val getHive = hiveClass.getMethod("get", hiveConfClass) - // invoke - val hive = getHive.invoke(null, hiveConf) - val tokenStr = getDelegationToken.invoke(hive, username, principal).asInstanceOf[String] - val hive2Token = new Token[DelegationTokenIdentifier]() - hive2Token.decodeFromUrlString(tokenStr) - Some(hive2Token) + doAsRealUser { + val hive = getHive.invoke(null, hiveConf) + val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal) + .asInstanceOf[String] + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + Some(hive2Token) + } } finally { Utils.tryLogNonFatalError { closeCurrent.invoke(null) @@ -216,6 +258,74 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { None } } + + /** + * Obtain a security token for HBase. + * + * Requirements + * + * 1. `"hbase.security.authentication" == "kerberos"` + * 2. The HBase classes `HBaseConfiguration` and `TokenUtil` could be loaded + * and invoked. + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. + * @return a token if the requirements were met, `None` if not. + */ + def obtainTokenForHBase(conf: Configuration): Option[Token[TokenIdentifier]] = { + try { + obtainTokenForHBaseInner(conf) + } catch { + case e: ClassNotFoundException => + logInfo(s"HBase class not found $e") + logDebug("HBase class not found", e) + None + } + } + + /** + * Obtain a security token for HBase if `"hbase.security.authentication" == "kerberos"` + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. + * @return a token if one was needed + */ + def obtainTokenForHBaseInner(conf: Configuration): Option[Token[TokenIdentifier]] = { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] + if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { + logDebug("Attempting to fetch HBase security token.") + Some(obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]]) + } else { + None + } + } + + /** + * Run some code as the real logged in user (which may differ from the current user, for + * example, when using proxying). + */ + private def doAsRealUser[T](fn: => T): T = { + val currentUser = UserGroupInformation.getCurrentUser() + val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) + + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. + try { + realUser.doAs(new PrivilegedExceptionAction[T]() { + override def run(): T = fn + }) + } catch { + case e: UndeclaredThrowableException => throw Option(e.getCause()).getOrElse(e) + } + } + } object YarnSparkHadoopUtil { @@ -224,7 +334,7 @@ object YarnSparkHadoopUtil { // the common cases. Memory overhead tends to grow with container size. val MEMORY_OVERHEAD_FACTOR = 0.10 - val MEMORY_OVERHEAD_MIN = 384 + val MEMORY_OVERHEAD_MIN = 384L val ANY_HOST = "*" @@ -308,7 +418,7 @@ object YarnSparkHadoopUtil { * * @return The correct OOM Error handler JVM option, platform dependent. */ - def getOutOfMemoryErrorArgument : String = { + def getOutOfMemoryErrorArgument: String = { if (Utils.isWindows) { escapeForShell("-XX:OnOutOfMemoryError=taskkill /F /PID %%%%p") } else { @@ -392,23 +502,25 @@ object YarnSparkHadoopUtil { /** * Getting the initial target number of executors depends on whether dynamic allocation is * enabled. + * If not using dynamic allocation it gets the number of executors requested by the user. */ - def getInitialTargetExecutorNumber(conf: SparkConf): Int = { + def getInitialTargetExecutorNumber( + conf: SparkConf, + numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = { if (Utils.isDynamicAllocationEnabled(conf)) { - val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) - val initialNumExecutors = - conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) - val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Int.MaxValue) + val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS) + val initialNumExecutors = conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS) + val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS) require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, - s"initial executor number $initialNumExecutors must between min executor number" + + s"initial executor number $initialNumExecutors must between min executor number " + s"$minNumExecutors and max executor number $maxNumExecutors") initialNumExecutors } else { val targetNumExecutors = - sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(DEFAULT_NUMBER_EXECUTORS) + sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(numExecutors) // System property can override environment variable. - conf.getInt("spark.executor.instances", targetNumExecutors) + conf.get(EXECUTOR_INSTANCES).getOrElse(targetNumExecutors) } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala new file mode 100644 index 0000000000000..edfbfc5d58d86 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.util.concurrent.TimeUnit + +import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.network.util.ByteUnit + +package object config { + + /* Common app configuration. */ + + private[spark] val APPLICATION_TAGS = ConfigBuilder("spark.yarn.tags") + .doc("Comma-separated list of strings to pass through as YARN application tags appearing " + + "in YARN Application Reports, which can be used for filtering when querying YARN.") + .stringConf + .toSequence + .createOptional + + private[spark] val ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS = + ConfigBuilder("spark.yarn.am.attemptFailuresValidityInterval") + .doc("Interval after which AM failures will be considered independent and " + + "not accumulate towards the attempt count.") + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + + private[spark] val MAX_APP_ATTEMPTS = ConfigBuilder("spark.yarn.maxAppAttempts") + .doc("Maximum number of AM attempts before failing the app.") + .intConf + .createOptional + + private[spark] val USER_CLASS_PATH_FIRST = ConfigBuilder("spark.yarn.user.classpath.first") + .doc("Whether to place user jars in front of Spark's classpath.") + .booleanConf + .createWithDefault(false) + + private[spark] val GATEWAY_ROOT_PATH = ConfigBuilder("spark.yarn.config.gatewayPath") + .doc("Root of configuration paths that is present on gateway nodes, and will be replaced " + + "with the corresponding path in cluster machines.") + .stringConf + .createWithDefault(null) + + private[spark] val REPLACEMENT_ROOT_PATH = ConfigBuilder("spark.yarn.config.replacementPath") + .doc(s"Path to use as a replacement for ${GATEWAY_ROOT_PATH.key} when launching processes " + + "in the YARN cluster.") + .stringConf + .createWithDefault(null) + + private[spark] val QUEUE_NAME = ConfigBuilder("spark.yarn.queue") + .stringConf + .createWithDefault("default") + + private[spark] val HISTORY_SERVER_ADDRESS = ConfigBuilder("spark.yarn.historyServer.address") + .stringConf + .createOptional + + /* File distribution. */ + + private[spark] val SPARK_ARCHIVE = ConfigBuilder("spark.yarn.archive") + .doc("Location of archive containing jars files with Spark classes.") + .stringConf + .createOptional + + private[spark] val SPARK_JARS = ConfigBuilder("spark.yarn.jars") + .doc("Location of jars containing Spark classes.") + .stringConf + .toSequence + .createOptional + + private[spark] val ARCHIVES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.archives") + .stringConf + .toSequence + .createWithDefault(Nil) + + private[spark] val FILES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.files") + .stringConf + .toSequence + .createWithDefault(Nil) + + private[spark] val JARS_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.jars") + .stringConf + .toSequence + .createWithDefault(Nil) + + private[spark] val PRESERVE_STAGING_FILES = ConfigBuilder("spark.yarn.preserve.staging.files") + .doc("Whether to preserve temporary files created by the job in HDFS.") + .booleanConf + .createWithDefault(false) + + private[spark] val STAGING_FILE_REPLICATION = ConfigBuilder("spark.yarn.submit.file.replication") + .doc("Replication factor for files uploaded by Spark to HDFS.") + .intConf + .createOptional + + private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir") + .doc("Staging directory used while submitting applications.") + .stringConf + .createOptional + + /* Cluster-mode launcher configuration. */ + + private[spark] val WAIT_FOR_APP_COMPLETION = ConfigBuilder("spark.yarn.submit.waitAppCompletion") + .doc("In cluster mode, whether to wait for the application to finish before exiting the " + + "launcher process.") + .booleanConf + .createWithDefault(true) + + private[spark] val REPORT_INTERVAL = ConfigBuilder("spark.yarn.report.interval") + .doc("Interval between reports of the current app status in cluster mode.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("1s") + + /* Shared Client-mode AM / Driver configuration. */ + + private[spark] val AM_MAX_WAIT_TIME = ConfigBuilder("spark.yarn.am.waitTime") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("100s") + + private[spark] val AM_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.am.nodeLabelExpression") + .doc("Node label expression for the AM.") + .stringConf + .createOptional + + private[spark] val CONTAINER_LAUNCH_MAX_THREADS = + ConfigBuilder("spark.yarn.containerLauncherMaxThreads") + .intConf + .createWithDefault(25) + + private[spark] val MAX_EXECUTOR_FAILURES = ConfigBuilder("spark.yarn.max.executor.failures") + .intConf + .createOptional + + private[spark] val MAX_REPORTER_THREAD_FAILURES = + ConfigBuilder("spark.yarn.scheduler.reporterThread.maxFailures") + .intConf + .createWithDefault(5) + + private[spark] val RM_HEARTBEAT_INTERVAL = + ConfigBuilder("spark.yarn.scheduler.heartbeat.interval-ms") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("3s") + + private[spark] val INITIAL_HEARTBEAT_INTERVAL = + ConfigBuilder("spark.yarn.scheduler.initial-allocation.interval") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("200ms") + + private[spark] val SCHEDULER_SERVICES = ConfigBuilder("spark.yarn.services") + .doc("A comma-separated list of class names of services to add to the scheduler.") + .stringConf + .toSequence + .createWithDefault(Nil) + + /* Client-mode AM configuration. */ + + private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores") + .intConf + .createWithDefault(1) + + private[spark] val AM_JAVA_OPTIONS = ConfigBuilder("spark.yarn.am.extraJavaOptions") + .doc("Extra Java options for the client-mode AM.") + .stringConf + .createOptional + + private[spark] val AM_LIBRARY_PATH = ConfigBuilder("spark.yarn.am.extraLibraryPath") + .doc("Extra native library path for the client-mode AM.") + .stringConf + .createOptional + + private[spark] val AM_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.am.memoryOverhead") + .bytesConf(ByteUnit.MiB) + .createOptional + + private[spark] val AM_MEMORY = ConfigBuilder("spark.yarn.am.memory") + .bytesConf(ByteUnit.MiB) + .createWithDefaultString("512m") + + /* Driver configuration. */ + + private[spark] val DRIVER_CORES = ConfigBuilder("spark.driver.cores") + .intConf + .createWithDefault(1) + + private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.driver.memoryOverhead") + .bytesConf(ByteUnit.MiB) + .createOptional + + /* Executor configuration. */ + + private[spark] val EXECUTOR_CORES = ConfigBuilder("spark.executor.cores") + .intConf + .createWithDefault(1) + + private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.executor.memoryOverhead") + .bytesConf(ByteUnit.MiB) + .createOptional + + private[spark] val EXECUTOR_NODE_LABEL_EXPRESSION = + ConfigBuilder("spark.yarn.executor.nodeLabelExpression") + .doc("Node label expression for executors.") + .stringConf + .createOptional + + /* Security configuration. */ + + private[spark] val CREDENTIAL_FILE_MAX_COUNT = + ConfigBuilder("spark.yarn.credentials.file.retention.count") + .intConf + .createWithDefault(5) + + private[spark] val CREDENTIALS_FILE_MAX_RETENTION = + ConfigBuilder("spark.yarn.credentials.file.retention.days") + .intConf + .createWithDefault(5) + + private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes") + .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " + + "fs.defaultFS does not need to be listed here.") + .stringConf + .toSequence + .createWithDefault(Nil) + + private[spark] val TOKEN_RENEWAL_INTERVAL = ConfigBuilder("spark.yarn.token.renewal.interval") + .internal() + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + + /* Private configs. */ + + private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file") + .internal() + .stringConf + .createWithDefault(null) + + // Internal config to propagate the location of the user's jar to the driver/executors + private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar") + .internal() + .stringConf + .createOptional + + // Internal config to propagate the locations of any extra jars to add to the classpath + // of the executors + private[spark] val SECONDARY_JARS = ConfigBuilder("spark.yarn.secondary.jars") + .internal() + .stringConf + .toSequence + .createOptional +} diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala index 7d246bf407121..6c3556a2ee43e 100644 --- a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala +++ b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.launcher import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer +import scala.util.Properties /** * Exposes methods from the launcher library that are used by the YARN backend. @@ -29,6 +30,14 @@ private[spark] object YarnCommandBuilderUtils { CommandBuilderUtils.quoteForBatchScript(arg) } + def findJarsDir(sparkHome: String): String = { + val scalaVer = Properties.versionNumberString + .split("\\.") + .take(2) + .mkString(".") + CommandBuilderUtils.findJarsDir(sparkHome, scalaVer, true) + } + /** * Adds the perm gen configuration to the list of java options if needed and not yet added. * diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala new file mode 100644 index 0000000000000..4ed285230ff81 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +import org.apache.spark.SparkContext +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * An extension service that can be loaded into a Spark YARN scheduler. + * A Service that can be started and stopped. + * + * 1. For implementations to be loadable by `SchedulerExtensionServices`, + * they must provide an empty constructor. + * 2. The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ +trait SchedulerExtensionService { + + /** + * Start the extension service. This should be a no-op if + * called more than once. + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit + + /** + * Stop the service + * The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ + def stop(): Unit +} + +/** + * Binding information for a [[SchedulerExtensionService]]. + * + * The attempt ID will be set if the service is started within a YARN application master; + * there is then a different attempt ID for every time that AM is restarted. + * When the service binding is instantiated in client mode, there's no attempt ID, as it lacks + * this information. + * @param sparkContext current spark context + * @param applicationId YARN application ID + * @param attemptId YARN attemptID. This will always be unset in client mode, and always set in + * cluster mode. + */ +case class SchedulerExtensionServiceBinding( + sparkContext: SparkContext, + applicationId: ApplicationId, + attemptId: Option[ApplicationAttemptId] = None) + +/** + * Container for [[SchedulerExtensionService]] instances. + * + * Loads Extension Services from the configuration property + * `"spark.yarn.services"`, instantiates and starts them. + * When stopped, it stops all child entries. + * + * The order in which child extension services are started and stopped + * is undefined. + */ +private[spark] class SchedulerExtensionServices extends SchedulerExtensionService + with Logging { + private var serviceOption: Option[String] = None + private var services: List[SchedulerExtensionService] = Nil + private val started = new AtomicBoolean(false) + private var binding: SchedulerExtensionServiceBinding = _ + + /** + * Binding operation will load the named services and call bind on them too; the + * entire set of services are then ready for `init()` and `start()` calls. + * + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit = { + if (started.getAndSet(true)) { + logWarning("Ignoring re-entrant start operation") + return + } + require(binding.sparkContext != null, "Null context parameter") + require(binding.applicationId != null, "Null appId parameter") + this.binding = binding + val sparkContext = binding.sparkContext + val appId = binding.applicationId + val attemptId = binding.attemptId + logInfo(s"Starting Yarn extension services with app $appId and attemptId $attemptId") + + services = sparkContext.conf.get(SCHEDULER_SERVICES).map { sClass => + val instance = Utils.classForName(sClass) + .newInstance() + .asInstanceOf[SchedulerExtensionService] + // bind this service + instance.start(binding) + logInfo(s"Service $sClass started") + instance + }.toList + } + + /** + * Get the list of services. + * + * @return a list of services; Nil until the service is started + */ + def getServices: List[SchedulerExtensionService] = services + + /** + * Stop the services; idempotent. + * + */ + override def stop(): Unit = { + if (started.getAndSet(false)) { + logInfo(s"Stopping $this") + services.foreach { s => + Utils.tryLogNonFatalError(s.stop()) + } + } + } + + override def toString(): String = s"""SchedulerExtensionServices + |(serviceOption=$serviceOption, + | services=$services, + | started=$started)""".stripMargin +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 20771f655473c..56dc0004d04cc 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,10 +19,11 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.hadoop.yarn.api.records.YarnApplicationState -import org.apache.spark.{SparkException, Logging, SparkContext} +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle import org.apache.spark.scheduler.TaskSchedulerImpl @@ -33,7 +34,6 @@ private[spark] class YarnClientSchedulerBackend( with Logging { private var client: Client = null - private var appId: ApplicationId = null private var monitorThread: MonitorThread = null /** @@ -48,19 +48,17 @@ private[spark] class YarnClientSchedulerBackend( val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ("--arg", hostport) - argsArrayBuf ++= getExtraClientArguments logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" ")) - val args = new ClientArguments(argsArrayBuf.toArray, conf) - totalExpectedExecutors = args.numExecutors + val args = new ClientArguments(argsArrayBuf.toArray) + totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf) client = new Client(args, conf) - appId = client.submitApplication() + bindToYarn(client.submitApplication(), None) // SPARK-8687: Ensure all necessary properties have already been set before // we initialize our driver scheduler backend, which serves these properties // to the executors super.start() - waitForApplication() // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver @@ -73,51 +71,14 @@ private[spark] class YarnClientSchedulerBackend( monitorThread.start() } - /** - * Return any extra command line arguments to be passed to Client provided in the form of - * environment variables or Spark properties. - */ - private def getExtraClientArguments: Seq[String] = { - val extraArgs = new ArrayBuffer[String] - // List of (target Client argument, environment variable, Spark property) - val optionTuples = - List( - ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), - ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), - ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), - ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), - ("--py-files", null, "spark.submit.pyFiles") - ) - // Warn against the following deprecated environment variables: env var -> suggestion - val deprecatedEnvVars = Map( - "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", - "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") - optionTuples.foreach { case (optionName, envVar, sparkProp) => - if (sc.getConf.contains(sparkProp)) { - extraArgs += (optionName, sc.getConf.get(sparkProp)) - } else if (envVar != null && System.getenv(envVar) != null) { - extraArgs += (optionName, System.getenv(envVar)) - if (deprecatedEnvVars.contains(envVar)) { - logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.") - } - } - } - // The app name is a special case because "spark.app.name" is required of all applications. - // As a result, the corresponding "SPARK_YARN_APP_NAME" is already handled preemptively in - // SparkSubmitArguments if "spark.app.name" is not explicitly set by the user. (SPARK-5222) - sc.getConf.getOption("spark.app.name").foreach(v => extraArgs += ("--name", v)) - extraArgs - } - /** * Report the state of the application until it is running. * If the application has finished, failed or been killed in the process, throw an exception. * This assumes both `client` and `appId` have already been set. */ private def waitForApplication(): Unit = { - assert(client != null && appId != null, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId, returnOnRunning = true) // blocking + assert(client != null && appId.isDefined, "Application has not been submitted yet!") + val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true) // blocking if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -125,7 +86,7 @@ private[spark] class YarnClientSchedulerBackend( "It might have been killed or unable to launch application master.") } if (state == YarnApplicationState.RUNNING) { - logInfo(s"Application $appId has started running.") + logInfo(s"Application ${appId.get} has started running.") } } @@ -141,7 +102,7 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) logError(s"Yarn application has already exited with state $state!") allowInterrupt = false sc.stop() @@ -163,7 +124,7 @@ private[spark] class YarnClientSchedulerBackend( * This assumes both `client` and `appId` have already been set. */ private def asyncMonitorApplication(): MonitorThread = { - assert(client != null && appId != null, "Application has not been submitted yet!") + assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) @@ -193,10 +154,4 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } - override def applicationId(): String = { - Option(appId).map(_.toString).getOrElse { - logWarning("Application ID is not initialized yet.") - super.applicationId - } - } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 50b699f11b21c..ced597bed36d9 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.SparkContext -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil +import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils @@ -31,26 +31,12 @@ private[spark] class YarnClusterSchedulerBackend( extends YarnSchedulerBackend(scheduler, sc) { override def start() { + val attemptId = ApplicationMaster.getAttemptId + bindToYarn(attemptId.getApplicationId(), Some(attemptId)) super.start() totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) } - override def applicationId(): String = - // In YARN Cluster mode, the application ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.id").getOrElse { - logError("Application ID is not set.") - super.applicationId - } - - override def applicationAttemptId(): Option[String] = - // In YARN Cluster mode, the attempt ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.attemptId").orElse { - logError("Application attempt ID is not set.") - super.applicationAttemptId - } - override def getDriverLogUrls: Option[Map[String, String]] = { var driverLogs: Option[Map[String, String]] = None try { diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala index 4ebf3af12b381..029382133ddf2 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala @@ -18,7 +18,6 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.util.RackResolver - import org.apache.log4j.{Level, Logger} import org.apache.spark._ diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala new file mode 100644 index 0000000000000..6b3c831e60472 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.control.NonFatal + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.rpc._ +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.ui.JettyUtils +import org.apache.spark.util.{RpcUtils, ThreadUtils} + +/** + * Abstract Yarn scheduler backend that contains common logic + * between the client and cluster Yarn scheduler backends. + */ +private[spark] abstract class YarnSchedulerBackend( + scheduler: TaskSchedulerImpl, + sc: SparkContext) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { + + override val minRegisteredRatio = + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + 0.8 + } else { + super.minRegisteredRatio + } + + protected var totalExpectedExecutors = 0 + + private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv) + + private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint) + + private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) + + /** Application ID. */ + protected var appId: Option[ApplicationId] = None + + /** Attempt ID. This is unset for client-mode schedulers */ + private var attemptId: Option[ApplicationAttemptId] = None + + /** Scheduler extension services. */ + private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + + // Flag to specify whether this schedulerBackend should be reset. + private var shouldResetOnAmRegister = false + + /** + * Bind to YARN. This *must* be done before calling [[start()]]. + * + * @param appId YARN application ID + * @param attemptId Optional YARN attempt ID + */ + protected def bindToYarn(appId: ApplicationId, attemptId: Option[ApplicationAttemptId]): Unit = { + this.appId = Some(appId) + this.attemptId = attemptId + } + + override def start() { + require(appId.isDefined, "application ID unset") + val binding = SchedulerExtensionServiceBinding(sc, appId.get, attemptId) + services.start(binding) + super.start() + } + + override def stop(): Unit = { + try { + // SPARK-12009: To prevent Yarn allocator from requesting backup for the executors which + // was Stopped by SchedulerBackend. + requestTotalExecutors(0, 0, Map.empty) + super.stop() + } finally { + services.stop() + } + } + + /** + * Get the attempt ID for this run, if the cluster manager supports multiple + * attempts. Applications run in client mode will not have attempt IDs. + * This attempt ID only includes attempt counter, like "1", "2". + * + * @return The application attempt id, if available. + */ + override def applicationAttemptId(): Option[String] = { + attemptId.map(_.getAttemptId.toString) + } + + /** + * Get an application ID associated with the job. + * This returns the string value of [[appId]] if set, otherwise + * the locally-generated ID from the superclass. + * @return The application ID + */ + override def applicationId(): String = { + appId.map(_.toString).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } + + /** + * Request executors from the ApplicationMaster by specifying the total number desired. + * This includes executors already pending or running. + */ + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + yarnSchedulerEndpointRef.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) + } + + /** + * Request that the ApplicationMaster kill the specified executors. + */ + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds)) + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio + } + + /** + * Add filters to the SparkUI. + */ + private def addWebUIFilter( + filterName: String, + filterParams: Map[String, String], + proxyBase: String): Unit = { + if (proxyBase != null && proxyBase.nonEmpty) { + System.setProperty("spark.ui.proxyBase", proxyBase) + } + + val hasFilter = + filterName != null && filterName.nonEmpty && + filterParams != null && filterParams.nonEmpty + if (hasFilter) { + logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") + conf.set("spark.ui.filters", filterName) + filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } + scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } + } + } + + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new YarnDriverEndpoint(rpcEnv, properties) + } + + /** + * Reset the state of SchedulerBackend to the initial state. This is happened when AM is failed + * and re-registered itself to driver after a failure. The stale state in driver should be + * cleaned. + */ + override protected def reset(): Unit = { + super.reset() + sc.executorAllocationManager.foreach(_.reset()) + } + + /** + * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. + * This endpoint communicates with the executors and queries the AM for an executor's exit + * status when the executor is disconnected. + */ + private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + /** + * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint + * handles it by assuming the Executor was lost for a bad reason and removes the executor + * immediately. + * + * In YARN's case however it is crucial to talk to the application master and ask why the + * executor had exited. If the executor exited for some reason unrelated to the running tasks + * (e.g., preemption), according to the application master, then we pass that information down + * to the TaskSetManager to inform the TaskSetManager that tasks on that lost executor should + * not count towards a job failure. + */ + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + if (disableExecutor(executorId)) { + yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + } + } + } + } + + /** + * An [[RpcEndpoint]] that communicates with the ApplicationMaster. + */ + private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { + private var amEndpoint: Option[RpcEndpointRef] = None + + private val askAmThreadPool = + ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) + + private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( + executorId: String, + executorRpcAddress: RpcAddress): Unit = { + amEndpoint match { + case Some(am) => + val lossReasonRequest = GetExecutorLossReason(executorId) + val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) + future onSuccess { + case reason: ExecutorLossReason => + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) + } + future onFailure { + case NonFatal(e) => + logWarning(s"Attempted to get executor loss reason" + + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + + s" but got no response. Marking as slave lost.", e) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) + case t => throw t + } + case None => + logWarning("Attempted to check for an executor loss reason" + + " before the AM has registered!") + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) + } + } + + override def receive: PartialFunction[Any, Unit] = { + case RegisterClusterManager(am) => + logInfo(s"ApplicationMaster registered as $am") + amEndpoint = Option(am) + if (!shouldResetOnAmRegister) { + shouldResetOnAmRegister = true + } else { + // AM is already registered before, this potentially means that AM failed and + // a new one registered after the failure. This will only happen in yarn-client mode. + reset() + } + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + + case RemoveExecutor(executorId, reason) => + logWarning(reason.toString) + removeExecutor(executorId, reason) + } + + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: RequestExecutors => + amEndpoint match { + case Some(am) => + Future { + context.reply(am.askWithRetry[Boolean](r)) + } onFailure { + case NonFatal(e) => + logError(s"Sending $r to AM was unsuccessful", e) + context.sendFailure(e) + } + case None => + logWarning("Attempted to request executors before the AM has registered!") + context.reply(false) + } + + case k: KillExecutors => + amEndpoint match { + case Some(am) => + Future { + context.reply(am.askWithRetry[Boolean](k)) + } onFailure { + case NonFatal(e) => + logError(s"Sending $k to AM was unsuccessful", e) + context.sendFailure(e) + } + case None => + logWarning("Attempted to kill executors before the AM has registered!") + context.reply(false) + } + + case RetrieveLastAllocatedExecutorId => + context.reply(currentExecutorIdCounter) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amEndpoint.exists(_.address == remoteAddress)) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + amEndpoint = None + } + } + + override def onStop(): Unit = { + askAmThreadPool.shutdownNow() + } + } +} + +private[spark] object YarnSchedulerBackend { + val ENDPOINT_NAME = "YarnScheduler" +} diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 6b9a799954bf1..d13454d5ae5d5 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -28,4 +28,4 @@ log4j.logger.com.sun.jersey=WARN log4j.logger.org.apache.hadoop=WARN log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.mortbay=WARN -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 12494b01054ba..9c3b18e4ec5f3 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.nio.charset.StandardCharsets import java.util.Properties import java.util.concurrent.TimeUnit @@ -25,14 +26,16 @@ import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster import org.scalatest.{BeforeAndAfterAll, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging import org.apache.spark.launcher._ import org.apache.spark.util.Utils @@ -50,7 +53,7 @@ abstract class BaseYarnClusterSuite |log4j.logger.org.apache.hadoop=WARN |log4j.logger.org.eclipse.jetty=WARN |log4j.logger.org.mortbay=WARN - |log4j.logger.org.spark-project.jetty=WARN + |log4j.logger.org.spark_project.jetty=WARN """.stripMargin private var yarnCluster: MiniYARNCluster = _ @@ -59,10 +62,13 @@ abstract class BaseYarnClusterSuite protected var hadoopConfDir: File = _ private var logConfDir: File = _ + var oldSystemProperties: Properties = null + def newYarnConfig(): YarnConfiguration override def beforeAll() { super.beforeAll() + oldSystemProperties = SerializationUtils.clone(System.getProperties) tempDir = Utils.createTempDir() logConfDir = new File(tempDir, "log4j") @@ -70,7 +76,7 @@ abstract class BaseYarnClusterSuite System.setProperty("SPARK_YARN_MODE", "true") val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, UTF_8) + Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8) // Disable the disk utilization check to avoid the test hanging when people's disks are // getting full. @@ -115,9 +121,12 @@ abstract class BaseYarnClusterSuite } override def afterAll() { - yarnCluster.stop() - System.clearProperty("SPARK_YARN_MODE") - super.afterAll() + try { + yarnCluster.stop() + } finally { + System.setProperties(oldSystemProperties) + super.afterAll() + } } protected def runSpark( @@ -129,7 +138,7 @@ abstract class BaseYarnClusterSuite extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { - val master = if (clientMode) "yarn-client" else "yarn-cluster" + val deployMode = if (clientMode) "client" else "cluster" val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv @@ -141,7 +150,8 @@ abstract class BaseYarnClusterSuite launcher.setAppResource(fakeSparkJar.getAbsolutePath()) } launcher.setSparkHome(sys.props("spark.test.home")) - .setMaster(master) + .setMaster("yarn") + .setDeployMode(deployMode) .setConf("spark.executor.instances", "1") .setPropertiesFile(propsFile) .addAppArgs(appArgs.toArray: _*) @@ -182,7 +192,7 @@ abstract class BaseYarnClusterSuite result: File, expected: String): Unit = { finalState should be (SparkAppHandle.State.FINISHED) - val resultString = Files.toString(result, UTF_8) + val resultString = Files.toString(result, StandardCharsets.UTF_8) resultString should be (expected) } @@ -194,7 +204,7 @@ abstract class BaseYarnClusterSuite extraClassPath: Seq[String] = Nil, extraConf: Map[String, String] = Map()): String = { val props = new Properties() - props.put("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + props.put(SPARK_JARS.key, "local:" + fakeSparkJar.getAbsolutePath()) val testClasspath = new TestClasspathBuilder() .buildClassPath( @@ -222,7 +232,7 @@ abstract class BaseYarnClusterSuite extraConf.foreach { case (k, v) => props.setProperty(k, v) } val propsFile = File.createTempFile("spark", ".properties", tempDir) - val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), StandardCharsets.UTF_8) props.store(writer, "Spark properties.") writer.close() propsFile.getAbsolutePath() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 804dfecde7867..ac8f663df2fff 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,25 +19,22 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.mock.MockitoSugar -import org.mockito.Mockito.when +import scala.collection.mutable.HashMap +import scala.collection.mutable.Map import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.yarn.api.records.LocalResource -import org.apache.hadoop.yarn.api.records.LocalResourceVisibility import org.apache.hadoop.yarn.api.records.LocalResourceType -import org.apache.hadoop.yarn.util.{Records, ConverterUtils} - -import scala.collection.mutable.HashMap -import scala.collection.mutable.Map +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility +import org.apache.hadoop.yarn.util.ConverterUtils +import org.mockito.Mockito.when +import org.scalatest.mock.MockitoSugar import org.apache.spark.SparkFunSuite - class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index e7f2501e7899f..74e268dc48473 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.deploy.yarn -import java.io.File +import java.io.{File, FileOutputStream} import java.net.URI +import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap => MutableHashMap} import scala.reflect.ClassTag import scala.util.Try +import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig @@ -34,53 +36,66 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.YarnClientApplication import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records -import org.mockito.Matchers._ +import org.mockito.Matchers.{eq => meq, _} import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkFunSuite, TestUtils} +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils} -class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll + with ResetSystemProperties { + + import Client._ + + var oldSystemProperties: Properties = null override def beforeAll(): Unit = { + super.beforeAll() + oldSystemProperties = SerializationUtils.clone(System.getProperties) System.setProperty("SPARK_YARN_MODE", "true") } override def afterAll(): Unit = { - System.clearProperty("SPARK_YARN_MODE") + try { + System.setProperties(oldSystemProperties) + oldSystemProperties = null + } finally { + super.afterAll() + } } test("default Yarn application classpath") { - Client.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP)) + getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP)) } test("default MR application classpath") { - Client.getDefaultMRApplicationClasspath should be(Some(Fixtures.knownDefMRAppCP)) + getDefaultMRApplicationClasspath should be(Some(Fixtures.knownDefMRAppCP)) } test("resultant classpath for an application that defines a classpath for YARN") { withAppConf(Fixtures.mapYARNAppConf) { conf => val env = newEnv - Client.populateHadoopClasspath(conf, env) + populateHadoopClasspath(conf, env) classpath(env) should be( - flatten(Fixtures.knownYARNAppCP, Client.getDefaultMRApplicationClasspath)) + flatten(Fixtures.knownYARNAppCP, getDefaultMRApplicationClasspath)) } } test("resultant classpath for an application that defines a classpath for MR") { withAppConf(Fixtures.mapMRAppConf) { conf => val env = newEnv - Client.populateHadoopClasspath(conf, env) + populateHadoopClasspath(conf, env) classpath(env) should be( - flatten(Client.getDefaultYarnApplicationClasspath, Fixtures.knownMRAppCP)) + flatten(getDefaultYarnApplicationClasspath, Fixtures.knownMRAppCP)) } } test("resultant classpath for an application that defines both classpaths, YARN and MR") { withAppConf(Fixtures.mapAppConf) { conf => val env = newEnv - Client.populateHadoopClasspath(conf, env) + populateHadoopClasspath(conf, env) classpath(env) should be(flatten(Fixtures.knownYARNAppCP, Fixtures.knownMRAppCP)) } } @@ -89,58 +104,58 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { private val USER = "local:/userJar" private val ADDED = "local:/addJar1,local:/addJar2,/addJar3" + private val PWD = + if (classOf[Environment].getMethods().exists(_.getName == "$$")) { + "{{PWD}}" + } else if (Utils.isWindows) { + "%PWD%" + } else { + Environment.PWD.$() + } + test("Local jar URIs") { val conf = new Configuration() - val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK) - .set("spark.yarn.user.classpath.first", "true") + val sparkConf = new SparkConf() + .set(SPARK_JARS, Seq(SPARK)) + .set(USER_CLASS_PATH_FIRST, true) + .set("spark.yarn.dist.jars", ADDED) val env = new MutableHashMap[String, String]() - val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) + val args = new ClientArguments(Array("--jar", USER)) - Client.populateClasspath(args, conf, sparkConf, env, true) + populateClasspath(args, conf, sparkConf, env) val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => val uri = new URI(entry) - if (Client.LOCAL_SCHEME.equals(uri.getScheme())) { + if (LOCAL_SCHEME.equals(uri.getScheme())) { cp should contain (uri.getPath()) } else { cp should not contain (uri.getPath()) } }) - val pwdVar = - if (classOf[Environment].getMethods().exists(_.getName == "$$")) { - "{{PWD}}" - } else if (Utils.isWindows) { - "%PWD%" - } else { - Environment.PWD.$() - } - cp should contain(pwdVar) - cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}") - cp should not contain (Client.SPARK_JAR) - cp should not contain (Client.APP_JAR) + cp should contain(PWD) + cp should contain (s"$PWD${Path.SEPARATOR}${LOCALIZED_CONF_DIR}") + cp should not contain (APP_JAR) } test("Jar path propagation through SparkConf") { val conf = new Configuration() - val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK) - val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) - - val client = spy(new Client(args, conf, sparkConf)) - doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), - any(classOf[Path]), anyShort()) + val sparkConf = new SparkConf() + .set(SPARK_JARS, Seq(SPARK)) + .set("spark.yarn.dist.jars", ADDED) + val client = createClient(sparkConf, args = Array("--jar", USER)) val tempDir = Utils.createTempDir() try { client.prepareLocalResources(tempDir.getAbsolutePath(), Nil) - sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER)) + sparkConf.get(APP_JAR) should be (Some(USER)) // The non-local path should be propagated by name only, since it will end up in the app's // staging dir. val expected = ADDED.split(",") .map(p => { val uri = new URI(p) - if (Client.LOCAL_SCHEME == uri.getScheme()) { + if (LOCAL_SCHEME == uri.getScheme()) { p } else { Option(uri.getFragment()).getOrElse(new File(p).getName()) @@ -148,7 +163,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { }) .mkString(",") - sparkConf.getOption(Client.CONF_SPARK_YARN_SECONDARY_JARS) should be (Some(expected)) + sparkConf.get(SECONDARY_JARS) should be (Some(expected.split(",").toSeq)) } finally { Utils.deleteRecursively(tempDir) } @@ -157,17 +172,16 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { test("Cluster path translation") { val conf = new Configuration() val sparkConf = new SparkConf() - .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar") - .set("spark.yarn.config.gatewayPath", "/localPath") - .set("spark.yarn.config.replacementPath", "/remotePath") + .set(SPARK_JARS, Seq("local:/localPath/spark.jar")) + .set(GATEWAY_ROOT_PATH, "/localPath") + .set(REPLACEMENT_ROOT_PATH, "/remotePath") - Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath") - Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( + getClusterPath(sparkConf, "/localPath") should be ("/remotePath") + getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( "/remotePath/1:/remotePath/2") val env = new MutableHashMap[String, String]() - Client.populateClasspath(null, conf, sparkConf, env, false, - extraClassPath = Some("/localPath/my1.jar")) + populateClasspath(null, conf, sparkConf, env, extraClassPath = Some("/localPath/my1.jar")) val cp = classpath(env) cp should contain ("/remotePath/spark.jar") cp should contain ("/remotePath/my1.jar") @@ -179,11 +193,11 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { // Spaces between non-comma strings should be preserved as single tags. Empty strings may or // may not be removed depending on the version of Hadoop being used. val sparkConf = new SparkConf() - .set(Client.CONF_SPARK_YARN_APPLICATION_TAGS, ",tag1, dup,tag2 , ,multi word , dup") - .set("spark.yarn.maxAppAttempts", "42") - val args = new ClientArguments(Array( - "--name", "foo-test-app", - "--queue", "staging-queue"), sparkConf) + .set(APPLICATION_TAGS.key, ",tag1, dup,tag2 , ,multi word , dup") + .set(MAX_APP_ATTEMPTS, 42) + .set("spark.app.name", "foo-test-app") + .set(QUEUE_NAME, "staging-queue") + val args = new ClientArguments(Array()) val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse]) @@ -201,11 +215,76 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] tags should contain allOf ("tag1", "dup", "tag2", "multi word") - tags.asScala.filter(_.nonEmpty).size should be (4) + tags.asScala.count(_.nonEmpty) should be (4) } appContext.getMaxAppAttempts should be (42) } + test("spark.yarn.jars with multiple paths and globs") { + val libs = Utils.createTempDir() + val single = Utils.createTempDir() + val jar1 = TestUtils.createJarWithFiles(Map(), libs) + val jar2 = TestUtils.createJarWithFiles(Map(), libs) + val jar3 = TestUtils.createJarWithFiles(Map(), single) + val jar4 = TestUtils.createJarWithFiles(Map(), single) + + val jarsConf = Seq( + s"${libs.getAbsolutePath()}/*", + jar3.getPath(), + s"local:${jar4.getPath()}", + s"local:${single.getAbsolutePath()}/*") + + val sparkConf = new SparkConf().set(SPARK_JARS, jarsConf) + val client = createClient(sparkConf) + + val tempDir = Utils.createTempDir() + client.prepareLocalResources(tempDir.getAbsolutePath(), Nil) + + assert(sparkConf.get(SPARK_JARS) === + Some(Seq(s"local:${jar4.getPath()}", s"local:${single.getAbsolutePath()}/*"))) + + verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar1.toURI())), anyShort()) + verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar2.toURI())), anyShort()) + verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar3.toURI())), anyShort()) + + val cp = classpath(client) + cp should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*")) + cp should not contain (jar3.getPath()) + cp should contain (jar4.getPath()) + cp should contain (buildPath(single.getAbsolutePath(), "*")) + } + + test("distribute jars archive") { + val temp = Utils.createTempDir() + val archive = TestUtils.createJarWithFiles(Map(), temp) + + val sparkConf = new SparkConf().set(SPARK_ARCHIVE, archive.getPath()) + val client = createClient(sparkConf) + client.prepareLocalResources(temp.getAbsolutePath(), Nil) + + verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(archive.toURI())), anyShort()) + classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*")) + + sparkConf.set(SPARK_ARCHIVE, LOCAL_SCHEME + ":" + archive.getPath()) + intercept[IllegalArgumentException] { + client.prepareLocalResources(temp.getAbsolutePath(), Nil) + } + } + + test("distribute local spark jars") { + val temp = Utils.createTempDir() + val jarsDir = new File(temp, "jars") + assert(jarsDir.mkdir()) + val jar = TestUtils.createJarWithFiles(Map(), jarsDir) + new FileOutputStream(new File(temp, "RELEASE")).close() + + val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> temp.getAbsolutePath())) + val client = createClient(sparkConf) + client.prepareLocalResources(temp.getAbsolutePath(), Nil) + verify(client).copyFileToRemote(any(classOf[Path]), meq(new Path(jar.toURI())), anyShort()) + classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*")) + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = @@ -266,4 +345,21 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { }.toOption.getOrElse(defaults) } + private def createClient( + sparkConf: SparkConf, + conf: Configuration = new Configuration(), + args: Array[String] = Array()): Client = { + val clientArgs = new ClientArguments(args) + val client = spy(new Client(clientArgs, conf, sparkConf)) + doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), + any(classOf[Path]), anyShort()) + client + } + + private def classpath(client: Client): Array[String] = { + val env = new MutableHashMap[String, String]() + populateClasspath(null, client.hadoopConf, client.sparkConf, env) + classpath(env) + } + } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index bd80036c5cfa7..a641a6e73e853 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -25,15 +25,13 @@ import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.scalatest.{BeforeAndAfterEach, Matchers} - -import org.scalatest.{BeforeAndAfterEach, Matchers} import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.apache.spark.{SecurityManager, SparkFunSuite} -import org.apache.spark.SparkConf -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo @@ -58,7 +56,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val sparkConf = new SparkConf() sparkConf.set("spark.driver.host", "localhost") sparkConf.set("spark.driver.port", "4040") - sparkConf.set("spark.yarn.jar", "notarealjar.jar") + sparkConf.set(SPARK_JARS, Seq("notarealjar.jar")) sparkConf.set("spark.yarn.launchContainers", "false") val appAttemptId = ApplicationAttemptId.newInstance(ApplicationId.newInstance(0, 0), 0) @@ -72,13 +70,18 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter var containerNum = 0 override def beforeEach() { + super.beforeEach() rmClient = AMRMClient.createAMRMClient() rmClient.init(conf) rmClient.start() } override def afterEach() { - rmClient.stop() + try { + rmClient.stop() + } finally { + super.afterEach() + } } class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) { @@ -87,12 +90,13 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--executor-cores", "5", - "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") val sparkConfClone = sparkConf.clone() - sparkConfClone.set("spark.executor.instances", maxExecutors.toString) + sparkConfClone + .set("spark.executor.instances", maxExecutors.toString) + .set("spark.executor.cores", "5") + .set("spark.executor.memory", "2048") new YarnAllocator( "not used", mock(classOf[RpcEndpointRef]), @@ -100,7 +104,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter sparkConfClone, rmClient, appAttemptId, - new ApplicationMasterArguments(args), new SecurityManager(sparkConf)) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 6db012a77a936..b2b4d84f53d85 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -19,19 +19,20 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URL -import java.util.{HashMap => JHashMap, Properties} +import java.nio.charset.StandardCharsets +import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.launcher._ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} @@ -84,6 +85,35 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testBasicYarnApp(false) } + test("run Spark in yarn-client mode with different configurations") { + testBasicYarnApp(true, + Map( + "spark.driver.memory" -> "512m", + "spark.executor.cores" -> "1", + "spark.executor.memory" -> "512m", + "spark.executor.instances" -> "2" + )) + } + + test("run Spark in yarn-cluster mode with different configurations") { + testBasicYarnApp(true, + Map( + "spark.driver.memory" -> "512m", + "spark.driver.cores" -> "1", + "spark.executor.cores" -> "1", + "spark.executor.memory" -> "512m", + "spark.executor.instances" -> "2" + )) + } + + test("run Spark in yarn-client mode with additional jar") { + testWithAddJar(true) + } + + test("run Spark in yarn-cluster mode with additional jar") { + testWithAddJar(false) + } + test("run Spark in yarn-cluster mode unsuccessfully") { // Don't provide arguments so the driver will fail. val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass)) @@ -115,7 +145,8 @@ class YarnClusterSuite extends BaseYarnClusterSuite { .setSparkHome(sys.props("spark.test.home")) .setConf("spark.ui.enabled", "false") .setPropertiesFile(propsFile) - .setMaster("yarn-client") + .setMaster("yarn") + .setDeployMode("client") .setAppResource("spark-internal") .setMainClass(mainClassName(YarnLauncherTestApp.getClass)) .startApplication() @@ -137,23 +168,36 @@ class YarnClusterSuite extends BaseYarnClusterSuite { } } - private def testBasicYarnApp(clientMode: Boolean): Unit = { + private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = { val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), - appArgs = Seq(result.getAbsolutePath())) + appArgs = Seq(result.getAbsolutePath()), + extraConf = conf) checkResult(finalState, result) } + private def testWithAddJar(clientMode: Boolean): Unit = { + val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) + val driverResult = File.createTempFile("driver", null, tempDir) + val executorResult = File.createTempFile("executor", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), + extraClassPath = Seq(originalJar.getPath()), + extraJars = Seq("local:" + originalJar.getPath())) + checkResult(finalState, driverResult, "ORIGINAL") + checkResult(finalState, executorResult, "ORIGINAL") + } + private def testPySpark(clientMode: Boolean): Unit = { val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + Files.write(TEST_PYFILE, primaryPyFile, StandardCharsets.UTF_8) // When running tests, let's not assume the user has built the assembly module, which also // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the // needed locations. - val sparkHome = sys.props("spark.test.home"); + val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.9-src.zip", + s"$sparkHome/python/lib/py4j-0.9.2-src.zip", s"$sparkHome/python") val extraEnv = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), @@ -170,7 +214,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { subdir } val pyModule = new File(moduleDir, "mod1.py") - Files.write(TEST_PYMODULE, pyModule, UTF_8) + Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") @@ -244,7 +288,7 @@ private object YarnClusterDriver extends Logging with Matchers { data should be (Set(1, 2, 3, 4)) result = "success" } finally { - Files.write(result, status, UTF_8) + Files.write(result, status, StandardCharsets.UTF_8) sc.stop() } @@ -318,14 +362,14 @@ private object YarnClasspathTest extends Logging { val ccl = Thread.currentThread().getContextClassLoader() val resource = ccl.getResourceAsStream("test.resource") val bytes = ByteStreams.toByteArray(resource) - result = new String(bytes, 0, bytes.length, UTF_8) + result = new String(bytes, 0, bytes.length, StandardCharsets.UTF_8) } catch { case t: Throwable => error(s"loading test.resource to $resultPath", t) // set the exit code if not yet set exitCode = 2 } finally { - Files.write(result, new File(resultPath), UTF_8) + Files.write(result, new File(resultPath), StandardCharsets.UTF_8) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index c17e8695c24fb..950ebd9a2d4d9 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -18,14 +18,15 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.nio.charset.StandardCharsets -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.commons.io.FileUtils import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers import org.apache.spark._ +import org.apache.spark.internal.Logging import org.apache.spark.network.shuffle.ShuffleTestAccessor import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} import org.apache.spark.tags.ExtendedYarnTest @@ -78,7 +79,7 @@ private object YarnExternalShuffleDriver extends Logging with Matchers { s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: ExternalShuffleDriver [result file] [registed exec file] + |Usage: ExternalShuffleDriver [result file] [registered exec file] """.stripMargin) // scalastyle:on println System.exit(1) @@ -104,7 +105,7 @@ private object YarnExternalShuffleDriver extends Logging with Matchers { } finally { sc.stop() FileUtils.deleteDirectory(execStateCopy) - Files.write(result, status, UTF_8) + Files.write(result, status, StandardCharsets.UTF_8) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index a70e66d39a64e..fe09808ae508d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException +import java.nio.charset.StandardCharsets import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.conf.Configuration @@ -27,18 +28,17 @@ import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.io.Text import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers -import org.apache.hadoop.yarn.api.records.ApplicationAccessType - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.internal.Logging +import org.apache.spark.util.{ResetSystemProperties, Utils} - -class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging + with ResetSystemProperties { val hasBash = try { @@ -61,7 +61,7 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6") try { val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ") - Files.write(("bash -c \"echo " + argLine + "\"").getBytes(), scriptFile) + Files.write(("bash -c \"echo " + argLine + "\"").getBytes(StandardCharsets.UTF_8), scriptFile) scriptFile.setExecutable(true) val proc = Runtime.getRuntime().exec(Array(scriptFile.getAbsolutePath())) @@ -101,22 +101,18 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) viewAcls match { - case Some(vacls) => { + case Some(vacls) => val aclSet = vacls.split(',').map(_.trim).toSet assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } modifyAcls match { - case Some(macls) => { + case Some(macls) => val aclSet = macls.split(',').map(_.trim).toSet assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } } @@ -135,26 +131,22 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) viewAcls match { - case Some(vacls) => { + case Some(vacls) => val aclSet = vacls.split(',').map(_.trim).toSet assert(aclSet.contains("user1")) assert(aclSet.contains("user2")) assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } modifyAcls match { - case Some(macls) => { + case Some(macls) => val aclSet = macls.split(',').map(_.trim).toSet assert(aclSet.contains("user3")) assert(aclSet.contains("user4")) assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } } @@ -257,9 +249,8 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging hadoopConf.set("hive.metastore.uris", "http://localhost:0") val util = new YarnSparkHadoopUtil assertNestedHiveException(intercept[InvocationTargetException] { - util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") + util.obtainTokenForHiveMetastoreInner(hadoopConf) }) - // expect exception trapping code to unwind this hive-side exception assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastore(hadoopConf) }) @@ -276,6 +267,16 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging inner } + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + val util = new YarnSparkHadoopUtil + intercept[ClassNotFoundException] { + util.obtainTokenForHBaseInner(hadoopConf) + } + util.obtainTokenForHBase(hadoopConf) should be (None) + } + // This test needs to live here because it depends on isYarnMode returning true, which can only // happen in the YARN module. test("security manager token generation") { diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala index aa46ec5100f0e..1fed2562fcadb 100644 --- a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -16,10 +16,9 @@ */ package org.apache.spark.network.shuffle -import java.io.{IOException, File} +import java.io.File import java.util.concurrent.ConcurrentMap -import com.google.common.annotations.VisibleForTesting import org.apache.hadoop.yarn.api.records.ApplicationId import org.fusesource.leveldbjni.JniDBFactory import org.iq80.leveldb.{DB, Options} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index 6aa8c814cd4f0..5a426b86d10e0 100644 --- a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -34,6 +34,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration override def beforeEach(): Unit = { + super.beforeEach() yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), classOf[YarnShuffleService].getCanonicalName) @@ -54,17 +55,21 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd var s3: YarnShuffleService = null override def afterEach(): Unit = { - if (s1 != null) { - s1.stop() - s1 = null - } - if (s2 != null) { - s2.stop() - s2 = null - } - if (s3 != null) { - s3.stop() - s3 = null + try { + if (s1 != null) { + s1.stop() + s1 = null + } + if (s2 != null) { + s2.stop() + s2 = null + } + if (s3 != null) { + s3.stop() + s3 = null + } + } finally { + super.afterEach() } } diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala new file mode 100644 index 0000000000000..6ea7984c64514 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging + +/** + * Test the integration with [[SchedulerExtensionServices]] + */ +class ExtensionServiceIntegrationSuite extends SparkFunSuite + with LocalSparkContext with BeforeAndAfter + with Logging { + + val applicationId = new StubApplicationId(0, 1111L) + val attemptId = new StubApplicationAttemptId(applicationId, 1) + + /* + * Setup phase creates the spark context + */ + before { + val sparkConf = new SparkConf() + sparkConf.set(SCHEDULER_SERVICES, Seq(classOf[SimpleExtensionService].getName())) + sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite") + sc = new SparkContext(sparkConf) + } + + test("Instantiate") { + val services = new SchedulerExtensionServices() + assertResult(Nil, "non-nil service list") { + services.getServices + } + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + services.stop() + } + + test("Contains SimpleExtensionService Service") { + val services = new SchedulerExtensionServices() + try { + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + val serviceList = services.getServices + assert(serviceList.nonEmpty, "empty service list") + val (service :: Nil) = serviceList + val simpleService = service.asInstanceOf[SimpleExtensionService] + assert(simpleService.started.get, "service not started") + services.stop() + assert(!simpleService.started.get, "service not stopped") + } finally { + services.stop() + } + } +} + + diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala new file mode 100644 index 0000000000000..9b8c98cda8da8 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +private[spark] class SimpleExtensionService extends SchedulerExtensionService { + + /** started flag; set in the `start()` call, stopped in `stop()`. */ + val started = new AtomicBoolean(false) + + override def start(binding: SchedulerExtensionServiceBinding): Unit = { + started.set(true) + } + + override def stop(): Unit = { + started.set(false) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala new file mode 100644 index 0000000000000..4b57b9509a655 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +/** + * A stub application ID; can be set in constructor and/or updated later. + * @param applicationId application ID + * @param attempt an attempt counter + */ +class StubApplicationAttemptId(var applicationId: ApplicationId, var attempt: Int) + extends ApplicationAttemptId { + + override def setApplicationId(appID: ApplicationId): Unit = { + applicationId = appID + } + + override def getAttemptId: Int = { + attempt + } + + override def setAttemptId(attemptId: Int): Unit = { + attempt = attemptId + } + + override def getApplicationId: ApplicationId = { + applicationId + } + + override def build(): Unit = { + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala new file mode 100644 index 0000000000000..bffa0e09befd2 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.ApplicationId + +/** + * Simple Testing Application Id; ID and cluster timestamp are set in constructor + * and cannot be updated. + * @param id app id + * @param clusterTimestamp timestamp + */ +private[spark] class StubApplicationId(id: Int, clusterTimestamp: Long) extends ApplicationId { + override def getId: Int = { + id + } + + override def getClusterTimestamp: Long = { + clusterTimestamp + } + + override def setId(id: Int): Unit = {} + + override def setClusterTimestamp(clusterTimestamp: Long): Unit = {} + + override def build(): Unit = {} +}